cleanup of python script

This commit is contained in:
Julgodis
2021-01-29 18:57:04 +01:00
parent 4d2d73d15c
commit b29bb9c1d0
+323 -67
View File
@@ -1,10 +1,24 @@
#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK
"""
This script will extract literal and strings data
from secific section located in the baserom.dol.
Useful when trying to match .rodata and .sdata2
in translation units.
usage:
./tools/section2cpp.py --section .rodata --string --object JKRSolidHeap.o
"""
import argparse
import sys
import os
import struct
import shlex
from decimal import getcontext, Decimal
from pathlib import Path, PurePath, PureWindowsPath
from typing import (
Any,
@@ -21,13 +35,19 @@ from typing import (
Pattern,
)
try:
import numpy
except:
print("error: missing numpy")
sys.exit(1)
try:
import argcomplete # type: ignore
except ModuleNotFoundError:
argcomplete = None
parser = argparse.ArgumentParser(description="Extract section data and generate C++ code (arrays).")
parser = argparse.ArgumentParser(description="Extract section data and generate C++ code.")
parser.add_argument(
"--section",
@@ -51,8 +71,7 @@ parser.add_argument(
dest="object_name",
type=str,
metavar="OBJECT",
help="OBJECT filename to extract data from. (e.g. JKRSolidHeap.o)",
required=True
help="OBJECT filename to extract data from. (e.g. JKRSolidHeap.o)"
)
parser.add_argument(
@@ -71,6 +90,20 @@ parser.add_argument(
help="Print arrays as strings"
)
parser.add_argument(
"--array",
dest="as_array",
action="store_true",
help="Print everything as u8 arrays"
)
parser.add_argument(
"--shift-jis",
dest="shift_jis",
action="store_true",
help="Convert shift-jis to utf-8"
)
#
#
@@ -90,6 +123,9 @@ def magicsplit(l, *splitters):
return [subl for subl in _itersplit(l, splitters) ]
def str_encoding(data):
if data[-1] != 0:
return None
try:
data.decode("utf-8")
return "utf-8"
@@ -100,13 +136,23 @@ def str_encoding(data):
data.decode("shift_jisx0213")
return "shift-jis"
except:
pass
pass
return None, None
return None
def encoding_char_list(encoding, data):
if args.shift_jis and encoding == "shift-jis":
try:
return escape(data.decode("shift_jisx0213"))
except:
pass
return [ str(bytes([x]))[2:-1].replace("\"", "\\\"") for x in data ]
def raw_string(data):
assert data[-1] == 0
return str(data[:-1])[2:-1].replace("\"", "\\\"")
return "".join(data)
def raw_array(data):
return ",".join([hex(x) for x in list(data)])
def escape_char(v):
if v == "\n":
@@ -135,6 +181,75 @@ def escape_char(v):
def escape(v):
return "".join([ escape_char(x) for x in list(v) ])
def bytes2float32(data):
if len(data) < 4:
return None
result = numpy.frombuffer(data[0:4][::-1], dtype='float32')
if result:
return result[0]
else:
return None
def bytes2float64(data):
if len(data) < 8:
return None
result = numpy.frombuffer(data[0:8][::-1], dtype='float64')
if result:
return result[0]
else:
return None
def is_nice_float32(f):
try:
if int(f*1000) == f*1000:
return True
if int(f*100) == f*100:
return True
if int(f*10) == f*10:
return True
if int(f) == f:
return True
except:
return False
return False
def is_nice_float64(f):
try:
if int(f*1000) == f*1000:
return True
if int(f*100) == f*100:
return True
if int(f*10) == f*10:
return True
if int(f) == f:
return True
except:
return False
return False
float32_exact: Dict[numpy.float32, Tuple[int,int]] = {}
float64_exact: Dict[numpy.float64, Tuple[int,int]] = {}
getcontext().prec = 64
for i in range(1,32):
for j in range(1,32):
if i%j == 0:
continue
d = Decimal(i)/Decimal(j)
f = numpy.float32(d)
if str(f) != str(d):
if not f in float32_exact:
float32_exact[f] = (i,j)
for i in range(1,32):
for j in range(1,32):
if i%j == 0:
continue
d = Decimal(i)/Decimal(j)
f = numpy.float64(d)
if str(f) != str(d):
if not f in float64_exact:
float64_exact[f] = (i,j)
class Symbol:
def __init__(self, name, addr, size):
@@ -224,85 +339,226 @@ def find_symbols():
last_addr = last_symbol.addr + last_symbol.size
last_symbol.padding = ((last_addr + 31) & ~31) - last_addr
file.close()
def chunks(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i + n]
def data_as_string(data):
return ", ".join([ "0x" + hex(x)[2:].rjust(2, '0') for x in data ])
class Literal:
def __init__(self, name, type, value, comment=None):
self.name = name
self.type = type
self.value = value
self.comment = comment
def format(self):
return str(self.value)
def lines(self):
line = "static const %s %s = %s;" % (self.type, self.name, self.format())
if self.comment:
line = line.ljust(90, ' ') + " // " + self.comment
return [ line ]
def __str__(self):
return "\n".join(self.lines())
class Label(Literal):
def __init__(self, name):
super().__init__(name, "", None, None)
def lines(self):
return [ "", "", "// " + self.name ]
class Float32Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "float", value, comment)
def format(self):
return "%sf" % self.value
class Float64Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "double", value, comment)
class FractionFloat32Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "float", value, comment)
def format(self):
return "%i.0f / %i.0f" % self.value
class FractionFloat64Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "double", value, comment)
def format(self):
return "%i.0 / %i.0" % self.value
class U32Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "u32", value, comment)
class S32Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "s32", value, comment)
class S64Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "s64", value, comment)
class U64Literal(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "u64", value, comment)
class ArrayLiteral(Literal):
def __init__(self, name, value, comment=None):
super().__init__(name, "u8", value, comment)
def lines(self):
one_line = "static const %s %s[%i] = { %s };" % (self.type, self.name, len(self.value), data_as_string(self.value))
lines = []
if len(one_line) < 90:
lines += [ one_line ]
else:
lines += [ "static const %s %s[%i] = {" % (self.type, self.name, len(self.value)) ]
data_chunks = chunks(list(self.value), 16)
for chunk in data_chunks:
lines += [ " " + data_as_string(chunk) ]
lines += [ "};" ]
if lines and self.comment:
lines[0] = lines[0].ljust(90, ' ') + " // " + self.comment
return lines
class StringLiteral(Literal):
def __init__(self, name, encoding, value, comment=None):
assert value[-1] == 0
super().__init__(name, "char", value[:-1], comment)
self.encoding = encoding
def lines(self):
char_list = encoding_char_list(self.encoding, self.value)
one_line = "static const %s %s = \"%s\";" % (self.type, self.name, raw_string(char_list))
lines = []
if len(one_line) < 90:
lines += [ one_line ]
else:
lines += [ "static const %s %s = " % (self.type, self.name) ]
data_chunks = chunks(char_list, 16)
for chunk in data_chunks:
lines += [ " \"%s\"" % raw_string(chunk) ]
lines[-1] += ";"
if lines and self.comment:
lines[0] = lines[0].ljust(90, ' ') + " // " + self.comment
return lines
def output_cpp():
if not object_name in object_map:
print("error: %s object file not found!" % object_name)
sys.exit(1)
object_names = []
if object_name:
if not object_name in object_map:
print("error: %s object file not found!" % object_name)
sys.exit(1)
object_names += [ object_name ]
else:
object_names = [*object_map.keys()]
br = baserom.open("rb")
br.seek(0, os.SEEK_END)
br_size = br.tell()
br.seek(0, os.SEEK_SET)
obj = object_map[object_name]
for symbol in obj.symbols:
literals = []
for obj_name in object_names:
literals += [ Label(obj_name) ]
label = "lbl_%s" % (hex(symbol.addr).upper()[2:])
obj = object_map[obj_name]
for symbol in obj.symbols:
label = "lbl_%s" % (hex(symbol.addr).upper()[2:])
symbol_file_offset = symbol.addr - file_offset
symbol_file_size = symbol.size + symbol.padding
symbol_file_offset = symbol.addr - file_offset
symbol_file_size = symbol.size + symbol.padding
if symbol_file_offset + symbol_file_size > br_size:
print("error: reading outside baserom file. (%i, %i)" % (symbol_file_offset + symbol_file_size, br_size))
if symbol_file_offset + symbol_file_size > br_size:
print("error: reading outside baserom file. (%i, %i)" % (symbol_file_offset + symbol_file_size, br_size))
br.seek(symbol_file_offset, os.SEEK_SET)
data = br.read(symbol.size)
padding = br.read(symbol.padding)
br.seek(symbol_file_offset, os.SEEK_SET)
data = br.read(symbol.size)
padding = br.read(symbol.padding)
if args.as_string:
offset = 0
str_segments = [ x for x in magicsplit(data, 0) ]
for segment in str_segments[:-1]:
str_data = bytes(segment + [0])
encoding = str_encoding(str_data)
value = "???"
if len(data) == 4:
u32_data = struct.unpack('>I', data)[0]
s32_data = struct.unpack('>i', data)[0]
float_data = struct.unpack('>f', data)[0]
if s32_data == 0 or (s32_data >= -4096 and s32_data <= 4096):
value = str(s32_data)
elif u32_data == 0 or u32_data <= 4096:
value = str(u32_data)
elif int(float_data) == float_data and float_data >= -4096 and float_data <= 4096:
value = "%sf (%s)" % (str(float_data), hex(u32_data))
elif len(data) == 8:
u64_data = struct.unpack('>Q', data)[0]
s64_data = struct.unpack('>q', data)[0]
double_data = struct.unpack('>d', data)[0]
str_label = "lbl_%s" % (hex(symbol.addr + offset).upper()[2:])
if encoding == "shift-jis":
literals += [ StringLiteral(str_label, "shift-jis", str_data, "TODO: shift-jis strings in Metrowerks") ]
elif encoding == "utf-8":
literals += [ StringLiteral(str_label, "utf-8", str_data) ]
else:
literals += [ ArrayLiteral(str_label, str_data, "undecodable string") ]
offset += len(str_data)
if s64_data == 0 or (s64_data >= -4096 and s64_data <= 4096):
value = str(s64_data)
elif u64_data == 0 or u64_data <= 4096:
value = str(u64_data)
elif int(double_data) == double_data and double_data >= -4096 and double_data <= 4096:
value = "%s (%s)" % (str(double_data), hex(u64_data))
if padding:
padding_label = "lbl_%s" % (hex(symbol.addr + symbol.size).upper()[2:])
literals += [ StringLiteral(padding_label, None, padding, "padding") ]
padding = None
elif args.as_array:
literals += [ ArrayLiteral(label, data) ]
else:
lit = None
if len(data) == 4:
u32_data = struct.unpack('>I', data)[0]
s32_data = struct.unpack('>i', data)[0]
float_data = bytes2float32(data)
if s32_data == 0 or (s32_data >= -4096 and s32_data <= 4096):
lit = S32Literal(label, s32_data)
elif u32_data == 0 or (u32_data < 4096):
lit = U32Literal(label, u32_data)
elif float_data in float32_exact:
lit = FractionFloat32Literal(label, float32_exact[float_data], "%sf %s" % (float_data, hex(u32_data)))
elif is_nice_float32(float_data):
lit = Float32Literal(label, float_data, hex(u32_data))
print("// %s %s %s = %s" % (label, obj.path, symbol.name, value))
if args.as_string:
offset = 0
str_segments = [ x + [0] for x in magicsplit(data, 0) ]
for segment in str_segments[:-1]:
str_data = bytes(segment)
encoding = str_encoding(str_data)
elif len(data) == 8:
u64_data = struct.unpack('>Q', data)[0]
s64_data = struct.unpack('>q', data)[0]
double_data = bytes2float64(data)
str_label = "lbl_%s" % (hex(symbol.addr + offset).upper()[2:])
if encoding == "shift-jis" :
print("const char* %s = \"%s\"; /* shift-jis encoded (TODO) */" % (str_label, raw_string(str_data)))
elif encoding == "utf-8" :
print("const char* %s = \"%s\";" % (str_label, raw_string(str_data)))
else:
print("const char* %s = \"%s\"; /* undecodable string */" % (str_label, raw_string(str_data)))
offset += len(str_data)
if padding:
padding_label = "lbl_%s" % (hex(symbol.addr + symbol.size).upper()[2:])
print("const char* %s = \"%s\"; /* padding */" % (padding_label, raw_string(padding)))
else:
cpp_array = ",".join([hex(x) for x in list(data)])
print("static const u8 %s[%i] = { %s };" % (label, len(data), cpp_array))
if u64_data == 0x4330000000000000:
lit = Float64Literal(label, double_data, "%s | u32 to float (compiler-generated)" % hex(u64_data))
elif u64_data == 0x4330000080000000:
lit = Float64Literal(label, double_data, "%s | s32 to float (compiler-generated)" % hex(u64_data))
elif s64_data == 0 or (s64_data >= -4096 and s64_data <= 4096):
lit = S64Literal(label, s64_data)
elif u64_data == 0 or (u64_data < 4096):
lit = U64Literal(label, u64_data)
elif double_data in float64_exact:
lit = FractionFloat64Literal(label, float64_exact[double_data], "%s %s" % (double_data, hex(u64_data)))
elif is_nice_float64(double_data):
lit = Float64Literal(label, double_data, hex(u64_data))
if not lit:
lit = ArrayLiteral(label, data)
literals += [ lit ]
if padding:
padding_label = "lbl_%s" % (hex(symbol.addr + symbol.size).upper()[2:])
cpp_array = ",".join([hex(x) for x in list(padding)])
print("static const u8 %s[%i] = { %s }; /* padding */" % (padding_label, len(padding), cpp_array))
literals += [ ArrayLiteral(padding_label, data, "padding") ]
for lit in literals:
print(lit)
br.close()