mgs_reversing/build/find_duplicates.py

262 lines
9.3 KiB
Python
Executable File

#!/usr/bin/env python3
from fnmatch import fnmatch
from glob import glob
import os
import sys
from objlib.obj import get_obj_funcs
from statistics import quantiles
from termcolor import colored
from Levenshtein import ratio
from capstone import Cs, CS_ARCH_MIPS, CS_MODE_MIPS32
from capstone.mips import *
from functools import cache
root_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))
asm_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '../asm'))
reloc_insts = [
MIPS_INS_ADDIU,
MIPS_INS_LB,
MIPS_INS_LBU,
MIPS_INS_LH,
MIPS_INS_LHU,
MIPS_INS_LW,
# MIPS_INS_LWU,
MIPS_INS_SB,
MIPS_INS_SH,
MIPS_INS_SW,
]
def load_all_funcs():
funcs = {}
for obj_file in glob(os.path.join(root_dir, 'obj*/**/*.obj'), recursive=True):
if "_fixme" in obj_file:
continue
for func_name, code in get_obj_funcs(obj_file):
code = b''.join(code for _, code in code)
funcs[func_name.decode("utf-8")] = code
return funcs
def not_matched_functions():
return set(os.path.basename(f).replace('.s', '') for f in glob(os.path.join(asm_dir, '**/*.s'), recursive=True))
def byte_equality_distance(lhs_func, rhs_func):
# Naive approach: how many bytes are different in equal-sized funcs
if len(lhs_func) != len(rhs_func):
return None
diff_bytes = sum(1 if lhs_byte != rhs_byte else 0 for lhs_byte, rhs_byte in zip(lhs_func, rhs_func))
return int(diff_bytes / len(lhs_func) * 100)
def levenshtein_distance_on_bytes(lhs_func, rhs_func):
return int((1.0 - ratio(lhs_func, rhs_func)) * 100)
@cache
def disasm(code):
md = Cs(CS_ARCH_MIPS, CS_MODE_MIPS32)
md.detail = True
insts = []
l = len(code)
processing_addr = 0
last_processed = 0
# Disassembles an instruction to format:
# opcode | number of operands | operand1 | operand2 | ... | operandN
while processing_addr < l:
for inst in md.disasm(code, processing_addr):
if inst.id == MIPS_INS_INVALID:
raise Exception('MIPS_INS_INVALID is unhandled')
processing_addr += 4
non_reloc_ops = [str(inst.id), str(len(inst.operands))]
is_reloc_inst = inst.id in reloc_insts
for operand in inst.operands:
if operand.type == MIPS_OP_REG:
non_reloc_ops.append(str(operand.value.reg))
elif operand.type == MIPS_OP_IMM:
if not is_reloc_inst:
non_reloc_ops.append(str(operand.value.imm))
elif operand.type == MIPS_OP_MEM:
non_reloc_ops.append(str(operand.value.mem.base)) # register
if not is_reloc_inst:
non_reloc_ops.append(str(operand.value.mem.disp)) # offset
insts.append('|'.join(non_reloc_ops))
if processing_addr >= l:
break
code = code[processing_addr - last_processed:]
insts.append('raw|' + '|'.join([str(x) for x in [code[0], code[1], code[2], code[3]]]))
processing_addr += 4
last_processed = processing_addr
code = code[4:]
assert len(insts) * 4 == l
return insts
def instruction_equality_distance(lhs_func, rhs_func):
return byte_equality_distance(disasm(lhs_func), disasm(rhs_func))
def levenshtein_distance_on_instructions(lhs_func, rhs_func):
return levenshtein_distance_on_bytes(disasm(lhs_func), disasm(rhs_func))
def quantiles_wrapper(data, n):
if len(data) == 1:
return [data[0] for _ in range(n)]
return quantiles(data, n=n)
def main():
# Legend
print(colored('Legend', 'white', attrs=['underline']))
print(' Matches:', colored('PERFECT MATCH', 'green', attrs=['reverse']),
colored('good match', 'green'), colored('OK match', 'white', attrs=['dark']),
colored('mediocre match', 'yellow'), colored('bad match', 'red'))
print(' Functions:', colored('B', 'white', attrs=['bold']), "(big function)", ". (small function)")
print('')
# Load all functions
print('Loading all functions... ', end='', flush=True)
funcs = load_all_funcs()
print('\r', end='')
# Decide which functions to actually compare
lhs_pattern = input('Left-hand side of comparison - glob filter (e.g. *, sub_8034*, *s00a*, s???r_*): ')
lhs_only_not_matched = input('Left-hand side of comparison - only NOT matched functions? y/n ').strip().lower()
if lhs_only_not_matched == 'y' or lhs_only_not_matched == 'yes':
lhs_only_not_matched = True
elif lhs_only_not_matched == 'n' or lhs_only_not_matched == 'no':
lhs_only_not_matched = False
else:
print("Invalid choice:", lhs_only_not_matched)
sys.exit(1)
rhs_pattern = input('Right-hand side of comparison - glob filter (e.g. *, sub_8034*, *s00a*, s???r_*): ')
rhs_only_matched = input('Right-hand side of comparison - only matched functions? y/n ').strip().lower()
if rhs_only_matched == 'y' or rhs_only_matched == 'yes':
rhs_only_matched = True
elif rhs_only_matched == 'n' or rhs_only_matched == 'no':
rhs_only_matched = False
else:
print("Invalid choice:", rhs_only_matched)
sys.exit(1)
print('Diffing algorithm:')
print(" PURE BYTES")
print(" ===========================================================")
print(" 1. Byte equality: fastest, compares how many bytes")
print(" are different between functions of the same size,")
print(" so it can detect fully identical functions or functions")
print(" that differ in relocations.")
print(" 2. Levensthein distance (on pure bytes): slower")
print("")
print(" INSTRUCTIONS")
print(" ===========================================================")
print(" 3. Instruction equality: faster, compares how many decoded")
print(" instructions are different between functions of the same")
print(" instruction count.")
print(" 4. Levensthein distance (on instructions): slowest")
algorithm = input('Diffing algorithm (1, 2, 3 or 4): ').strip()
if algorithm == '1':
calc_distance = byte_equality_distance
elif algorithm == '2':
calc_distance = levenshtein_distance_on_bytes
elif algorithm == '3':
calc_distance = instruction_equality_distance
elif algorithm == '4':
calc_distance = levenshtein_distance_on_instructions
else:
print("Invalid choice of algorithm:", algorithm)
sys.exit(1)
print('Comparing... ', end='', flush=True)
lhs_funcnames = [name for name in funcs if fnmatch(name, lhs_pattern)]
rhs_funcnames = [name for name in funcs if fnmatch(name, rhs_pattern)]
lhs_funcnames = sorted(lhs_funcnames, key=lambda x: x.upper()[-8:])
rhs_funcnames = sorted(rhs_funcnames, key=lambda x: x.upper()[-8:])
if lhs_only_not_matched:
allowed_funcs = not_matched_functions()
lhs_funcnames = [f for f in lhs_funcnames if f in allowed_funcs]
if rhs_only_matched:
disallowed_funcs = not_matched_functions()
rhs_funcnames = [f for f in rhs_funcnames if f not in disallowed_funcs]
# Compare the functions, we love O(N^2):
results = {}
for lhs_funcname in lhs_funcnames:
distances = []
for rhs_funcname in rhs_funcnames:
if lhs_funcname == rhs_funcname:
continue
distance = calc_distance(funcs[lhs_funcname], funcs[rhs_funcname])
if distance is None:
continue
distances.append((rhs_funcname, distance))
distances = sorted(distances, key=lambda v: v[1])
distances = distances[:3] # Top 3
results[lhs_funcname] = distances
print('\r', len('Comparing... ') * ' ')
# Prepare for displaying the results
all_distances = [distance for sublist in results.values() for name, distance in sublist]
percentiles = quantiles_wrapper(sorted(all_distances), n=100)
# green - best matches
# white - in-between
# yellow - in-between
# red - worst matches
green_percentile = percentiles[25]
white_percentile = percentiles[50]
yellow_percentile = percentiles[75]
percentiles = quantiles_wrapper(sorted(len(funcs[func]) for func in lhs_funcnames), n=100)
small_function = percentiles[25]
big_function = percentiles[75]
# Display the results
for lhs_funcname in lhs_funcnames:
result = results[lhs_funcname]
function_indicator = ' '
if len(funcs[lhs_funcname]) > big_function:
function_indicator = colored('B', 'white', attrs=['bold'])
elif len(funcs[lhs_funcname]) < small_function:
function_indicator = '.'
print(f"{lhs_funcname} ({len(funcs[lhs_funcname])} bytes)".rjust(50, ' '), function_indicator, end=' | ')
for name, dist in result:
attrs = []
if dist <= green_percentile:
color = 'green'
if dist == 0:
# Perfect match!
attrs.append('reverse')
elif dist <= white_percentile:
color = 'white'
attrs.append('dark')
elif dist <= yellow_percentile:
color = 'yellow'
else:
color = 'red'
print(colored(f"{name} ({dist} diff)", color, attrs=attrs), end=' ')
print('')
if __name__ == '__main__':
main()