262 lines
9.3 KiB
Python
Executable File
262 lines
9.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
from fnmatch import fnmatch
|
|
from glob import glob
|
|
import os
|
|
import sys
|
|
from objlib.obj import get_obj_funcs
|
|
from statistics import quantiles
|
|
from termcolor import colored
|
|
from Levenshtein import ratio
|
|
from capstone import Cs, CS_ARCH_MIPS, CS_MODE_MIPS32
|
|
from capstone.mips import *
|
|
from functools import cache
|
|
|
|
root_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))
|
|
asm_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '../asm'))
|
|
|
|
reloc_insts = [
|
|
MIPS_INS_ADDIU,
|
|
MIPS_INS_LB,
|
|
MIPS_INS_LBU,
|
|
MIPS_INS_LH,
|
|
MIPS_INS_LHU,
|
|
MIPS_INS_LW,
|
|
# MIPS_INS_LWU,
|
|
MIPS_INS_SB,
|
|
MIPS_INS_SH,
|
|
MIPS_INS_SW,
|
|
]
|
|
|
|
def load_all_funcs():
|
|
funcs = {}
|
|
|
|
for obj_file in glob(os.path.join(root_dir, 'obj*/**/*.obj'), recursive=True):
|
|
if "_fixme" in obj_file:
|
|
continue
|
|
|
|
for func_name, code in get_obj_funcs(obj_file):
|
|
code = b''.join(code for _, code in code)
|
|
funcs[func_name.decode("utf-8")] = code
|
|
|
|
return funcs
|
|
|
|
def not_matched_functions():
|
|
return set(os.path.basename(f).replace('.s', '') for f in glob(os.path.join(asm_dir, '**/*.s'), recursive=True))
|
|
|
|
def byte_equality_distance(lhs_func, rhs_func):
|
|
# Naive approach: how many bytes are different in equal-sized funcs
|
|
if len(lhs_func) != len(rhs_func):
|
|
return None
|
|
|
|
diff_bytes = sum(1 if lhs_byte != rhs_byte else 0 for lhs_byte, rhs_byte in zip(lhs_func, rhs_func))
|
|
return int(diff_bytes / len(lhs_func) * 100)
|
|
|
|
def levenshtein_distance_on_bytes(lhs_func, rhs_func):
|
|
return int((1.0 - ratio(lhs_func, rhs_func)) * 100)
|
|
|
|
@cache
|
|
def disasm(code):
|
|
md = Cs(CS_ARCH_MIPS, CS_MODE_MIPS32)
|
|
md.detail = True
|
|
|
|
insts = []
|
|
l = len(code)
|
|
processing_addr = 0
|
|
last_processed = 0
|
|
|
|
# Disassembles an instruction to format:
|
|
# opcode | number of operands | operand1 | operand2 | ... | operandN
|
|
while processing_addr < l:
|
|
for inst in md.disasm(code, processing_addr):
|
|
if inst.id == MIPS_INS_INVALID:
|
|
raise Exception('MIPS_INS_INVALID is unhandled')
|
|
|
|
processing_addr += 4
|
|
non_reloc_ops = [str(inst.id), str(len(inst.operands))]
|
|
is_reloc_inst = inst.id in reloc_insts
|
|
|
|
for operand in inst.operands:
|
|
if operand.type == MIPS_OP_REG:
|
|
non_reloc_ops.append(str(operand.value.reg))
|
|
elif operand.type == MIPS_OP_IMM:
|
|
if not is_reloc_inst:
|
|
non_reloc_ops.append(str(operand.value.imm))
|
|
elif operand.type == MIPS_OP_MEM:
|
|
non_reloc_ops.append(str(operand.value.mem.base)) # register
|
|
if not is_reloc_inst:
|
|
non_reloc_ops.append(str(operand.value.mem.disp)) # offset
|
|
|
|
insts.append('|'.join(non_reloc_ops))
|
|
|
|
if processing_addr >= l:
|
|
break
|
|
|
|
code = code[processing_addr - last_processed:]
|
|
insts.append('raw|' + '|'.join([str(x) for x in [code[0], code[1], code[2], code[3]]]))
|
|
|
|
processing_addr += 4
|
|
last_processed = processing_addr
|
|
code = code[4:]
|
|
|
|
assert len(insts) * 4 == l
|
|
|
|
return insts
|
|
|
|
def instruction_equality_distance(lhs_func, rhs_func):
|
|
return byte_equality_distance(disasm(lhs_func), disasm(rhs_func))
|
|
|
|
def levenshtein_distance_on_instructions(lhs_func, rhs_func):
|
|
return levenshtein_distance_on_bytes(disasm(lhs_func), disasm(rhs_func))
|
|
|
|
def quantiles_wrapper(data, n):
|
|
if len(data) == 1:
|
|
return [data[0] for _ in range(n)]
|
|
return quantiles(data, n=n)
|
|
|
|
def main():
|
|
# Legend
|
|
print(colored('Legend', 'white', attrs=['underline']))
|
|
print(' Matches:', colored('PERFECT MATCH', 'green', attrs=['reverse']),
|
|
colored('good match', 'green'), colored('OK match', 'white', attrs=['dark']),
|
|
colored('mediocre match', 'yellow'), colored('bad match', 'red'))
|
|
print(' Functions:', colored('B', 'white', attrs=['bold']), "(big function)", ". (small function)")
|
|
print('')
|
|
|
|
# Load all functions
|
|
print('Loading all functions... ', end='', flush=True)
|
|
funcs = load_all_funcs()
|
|
print('\r', end='')
|
|
|
|
# Decide which functions to actually compare
|
|
lhs_pattern = input('Left-hand side of comparison - glob filter (e.g. *, sub_8034*, *s00a*, s???r_*): ')
|
|
lhs_only_not_matched = input('Left-hand side of comparison - only NOT matched functions? y/n ').strip().lower()
|
|
if lhs_only_not_matched == 'y' or lhs_only_not_matched == 'yes':
|
|
lhs_only_not_matched = True
|
|
elif lhs_only_not_matched == 'n' or lhs_only_not_matched == 'no':
|
|
lhs_only_not_matched = False
|
|
else:
|
|
print("Invalid choice:", lhs_only_not_matched)
|
|
sys.exit(1)
|
|
|
|
rhs_pattern = input('Right-hand side of comparison - glob filter (e.g. *, sub_8034*, *s00a*, s???r_*): ')
|
|
|
|
rhs_only_matched = input('Right-hand side of comparison - only matched functions? y/n ').strip().lower()
|
|
if rhs_only_matched == 'y' or rhs_only_matched == 'yes':
|
|
rhs_only_matched = True
|
|
elif rhs_only_matched == 'n' or rhs_only_matched == 'no':
|
|
rhs_only_matched = False
|
|
else:
|
|
print("Invalid choice:", rhs_only_matched)
|
|
sys.exit(1)
|
|
|
|
print('Diffing algorithm:')
|
|
print(" PURE BYTES")
|
|
print(" ===========================================================")
|
|
print(" 1. Byte equality: fastest, compares how many bytes")
|
|
print(" are different between functions of the same size,")
|
|
print(" so it can detect fully identical functions or functions")
|
|
print(" that differ in relocations.")
|
|
print(" 2. Levensthein distance (on pure bytes): slower")
|
|
print("")
|
|
print(" INSTRUCTIONS")
|
|
print(" ===========================================================")
|
|
print(" 3. Instruction equality: faster, compares how many decoded")
|
|
print(" instructions are different between functions of the same")
|
|
print(" instruction count.")
|
|
print(" 4. Levensthein distance (on instructions): slowest")
|
|
|
|
algorithm = input('Diffing algorithm (1, 2, 3 or 4): ').strip()
|
|
if algorithm == '1':
|
|
calc_distance = byte_equality_distance
|
|
elif algorithm == '2':
|
|
calc_distance = levenshtein_distance_on_bytes
|
|
elif algorithm == '3':
|
|
calc_distance = instruction_equality_distance
|
|
elif algorithm == '4':
|
|
calc_distance = levenshtein_distance_on_instructions
|
|
else:
|
|
print("Invalid choice of algorithm:", algorithm)
|
|
sys.exit(1)
|
|
|
|
print('Comparing... ', end='', flush=True)
|
|
|
|
lhs_funcnames = [name for name in funcs if fnmatch(name, lhs_pattern)]
|
|
rhs_funcnames = [name for name in funcs if fnmatch(name, rhs_pattern)]
|
|
|
|
lhs_funcnames = sorted(lhs_funcnames, key=lambda x: x.upper()[-8:])
|
|
rhs_funcnames = sorted(rhs_funcnames, key=lambda x: x.upper()[-8:])
|
|
|
|
if lhs_only_not_matched:
|
|
allowed_funcs = not_matched_functions()
|
|
lhs_funcnames = [f for f in lhs_funcnames if f in allowed_funcs]
|
|
|
|
if rhs_only_matched:
|
|
disallowed_funcs = not_matched_functions()
|
|
rhs_funcnames = [f for f in rhs_funcnames if f not in disallowed_funcs]
|
|
|
|
# Compare the functions, we love O(N^2):
|
|
results = {}
|
|
for lhs_funcname in lhs_funcnames:
|
|
distances = []
|
|
for rhs_funcname in rhs_funcnames:
|
|
if lhs_funcname == rhs_funcname:
|
|
continue
|
|
distance = calc_distance(funcs[lhs_funcname], funcs[rhs_funcname])
|
|
if distance is None:
|
|
continue
|
|
distances.append((rhs_funcname, distance))
|
|
distances = sorted(distances, key=lambda v: v[1])
|
|
distances = distances[:3] # Top 3
|
|
results[lhs_funcname] = distances
|
|
|
|
print('\r', len('Comparing... ') * ' ')
|
|
|
|
# Prepare for displaying the results
|
|
all_distances = [distance for sublist in results.values() for name, distance in sublist]
|
|
percentiles = quantiles_wrapper(sorted(all_distances), n=100)
|
|
|
|
# green - best matches
|
|
# white - in-between
|
|
# yellow - in-between
|
|
# red - worst matches
|
|
green_percentile = percentiles[25]
|
|
white_percentile = percentiles[50]
|
|
yellow_percentile = percentiles[75]
|
|
|
|
percentiles = quantiles_wrapper(sorted(len(funcs[func]) for func in lhs_funcnames), n=100)
|
|
small_function = percentiles[25]
|
|
big_function = percentiles[75]
|
|
|
|
# Display the results
|
|
for lhs_funcname in lhs_funcnames:
|
|
result = results[lhs_funcname]
|
|
|
|
function_indicator = ' '
|
|
if len(funcs[lhs_funcname]) > big_function:
|
|
function_indicator = colored('B', 'white', attrs=['bold'])
|
|
elif len(funcs[lhs_funcname]) < small_function:
|
|
function_indicator = '.'
|
|
|
|
print(f"{lhs_funcname} ({len(funcs[lhs_funcname])} bytes)".rjust(50, ' '), function_indicator, end=' | ')
|
|
for name, dist in result:
|
|
attrs = []
|
|
if dist <= green_percentile:
|
|
color = 'green'
|
|
if dist == 0:
|
|
# Perfect match!
|
|
attrs.append('reverse')
|
|
elif dist <= white_percentile:
|
|
color = 'white'
|
|
attrs.append('dark')
|
|
elif dist <= yellow_percentile:
|
|
color = 'yellow'
|
|
else:
|
|
color = 'red'
|
|
print(colored(f"{name} ({dist} diff)", color, attrs=attrs), end=' ')
|
|
print('')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|