diff --git a/common/type_system/TypeFieldLookup.cpp b/common/type_system/TypeFieldLookup.cpp index c14abd455b..84541d5fa7 100644 --- a/common/type_system/TypeFieldLookup.cpp +++ b/common/type_system/TypeFieldLookup.cpp @@ -67,6 +67,18 @@ std::string FieldReverseLookupOutput::Token::print() const { } } +std::string FieldReverseLookupOutput::print() const { + std::string result = fmt::format("[{}] {} ", result_type.print(), total_score); + if (addr_of) { + result += '&'; + } + for (const auto& tok : tokens) { + result += tok.print(); + result += ' '; + } + return result; +} + namespace { void try_reverse_lookup(const FieldReverseLookupInput& input, diff --git a/common/type_system/TypeSystem.h b/common/type_system/TypeSystem.h index 641cfa694a..146014053a 100644 --- a/common/type_system/TypeSystem.h +++ b/common/type_system/TypeSystem.h @@ -98,6 +98,7 @@ struct FieldReverseLookupOutput { FieldReverseLookupOutput() = default; FieldReverseLookupOutput(bool addr, TypeSpec type, std::vector tok) : success(true), addr_of(addr), result_type(std::move(type)), tokens(std::move(tok)) {} + std::string print() const; bool success = false; bool addr_of = false; // do we take the address of this result? diff --git a/common/util/CopyOnWrite.h b/common/util/CopyOnWrite.h index e758ce43d1..227c3dba3a 100644 --- a/common/util/CopyOnWrite.h +++ b/common/util/CopyOnWrite.h @@ -1,14 +1,6 @@ #include #include "common/util/assert.h" -/* -template -std::unique_ptr make_unique(Args&&... args) -{ - return std::unique_ptr(new T(std::forward(args)...)); -} - */ - /*! * The CopyOnWrite class acts like a value, but internally uses references to avoid copying * when it is possible to avoid it. diff --git a/decompiler/CMakeLists.txt b/decompiler/CMakeLists.txt index 9fde7dc95a..83d23b7457 100644 --- a/decompiler/CMakeLists.txt +++ b/decompiler/CMakeLists.txt @@ -44,6 +44,7 @@ add_library( IR2/FormExpressionAnalysis.cpp IR2/FormStack.cpp IR2/GenericElementMatcher.cpp + IR2/MultiTypeAnalysis.cpp IR2/OpenGoalMapping.cpp ObjectFile/LinkedObjectFile.cpp diff --git a/decompiler/Disasm/Register.cpp b/decompiler/Disasm/Register.cpp index 1b54d76a7d..dd06d5fc4c 100644 --- a/decompiler/Disasm/Register.cpp +++ b/decompiler/Disasm/Register.cpp @@ -94,45 +94,6 @@ const char* special_to_charp(uint32_t special) { } } // namespace -///////////////////////////// -// Register Class -///////////////////////////// -// A register is stored as a 16-bit integer, with the top 8 bits indicating the "kind" and the lower -// 8 bits representing the register id within that kind. If the integer is -1, it is a special -// "invalid" register used to represent an uninitialized Register. - -// Note: VI / COP2 are separate "kinds" of registers, each with 16 registers. -// It might make sense to make this a single "kind" instead? - -namespace { -constexpr int REG_CATEGORY_SHIFT = 5; -constexpr int REG_IDX_MASK = 0b11111; -} // namespace - -/*! - * Create a register. The kind and num must both be valid. - */ -Register::Register(Reg::RegisterKind kind, uint32_t num) { - // 32 regs/category at most. - id = (kind << REG_CATEGORY_SHIFT) | num; - - // check range: - switch (kind) { - case Reg::GPR: - case Reg::FPR: - case Reg::VF: - case Reg::COP0: - case Reg::VI: - assert(num < 32); - break; - case Reg::SPECIAL: - assert(num < Reg::MAX_SPECIAL); - break; - default: - assert(false); - } -} - Register::Register(const std::string& name) { // first try gprs, for (int i = 0; i < Reg::MAX_GPR; i++) { @@ -182,87 +143,4 @@ std::string Register::to_string() const { return {to_charp()}; } -/*! - * Get the register kind. - */ -Reg::RegisterKind Register::get_kind() const { - uint16_t kind = id >> REG_CATEGORY_SHIFT; - assert(kind < Reg::MAX_KIND); - return (Reg::RegisterKind)kind; -} - -/*! - * Get the GPR number. Must be a GPR. - */ -Reg::Gpr Register::get_gpr() const { - assert(get_kind() == Reg::GPR); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < Reg::MAX_GPR); - return (Reg::Gpr)(kind); -} - -/*! - * Get the FPR number. Must be an FPR. - */ -uint32_t Register::get_fpr() const { - assert(get_kind() == Reg::FPR); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < 32); - return kind; -} - -/*! - * Get the VF number. Must be a VF. - */ -uint32_t Register::get_vf() const { - assert(get_kind() == Reg::VF); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < 32); - return kind; -} - -/*! - * Get the VI number. Must be a VI. - */ -uint32_t Register::get_vi() const { - assert(get_kind() == Reg::VI); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < 32); - return kind; -} - -/*! - * Get the COP0 number. Must be a COP0. - */ -Reg::Cop0 Register::get_cop0() const { - assert(get_kind() == Reg::COP0); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < Reg::MAX_COP0); - return (Reg::Cop0)(kind); -} - -/*! - * Get the PCR number. Must be a PCR. - */ -uint32_t Register::get_special() const { - assert(get_kind() == Reg::SPECIAL); - uint16_t kind = id & REG_IDX_MASK; - assert(kind < Reg::MAX_SPECIAL); - return kind; -} - -bool Register::operator==(const Register& other) const { - return id == other.id; -} - -bool Register::operator!=(const Register& other) const { - return id != other.id; -} - -bool Register::allowed_local_gpr() const { - if (get_kind() != Reg::GPR) { - return false; - } - return Reg::allowed_local_gprs[get_gpr()]; -} } // namespace decompiler \ No newline at end of file diff --git a/decompiler/Disasm/Register.h b/decompiler/Disasm/Register.h index 7cc7b50862..ad9f92694e 100644 --- a/decompiler/Disasm/Register.h +++ b/decompiler/Disasm/Register.h @@ -141,11 +141,42 @@ constexpr int MAX_VAR_REG_ID = 32 * 2; // gprs/fprs. } // namespace Reg -// Representation of a register. Uses a 16-bit integer internally. +///////////////////////////// +// Register Class +///////////////////////////// +// A register is stored as a 16-bit integer, with the top 8 bits indicating the "kind" and the lower +// 8 bits representing the register id within that kind. If the integer is -1, it is a special +// "invalid" register used to represent an uninitialized Register. + +// Note: VI / COP2 are separate "kinds" of registers, each with 16 registers. +// It might make sense to make this a single "kind" instead? class Register { + private: + static constexpr int REG_CATEGORY_SHIFT = 5; + static constexpr int REG_IDX_MASK = 0b11111; + public: Register() = default; - Register(Reg::RegisterKind kind, uint32_t num); + Register(Reg::RegisterKind kind, uint32_t num) { + // 32 regs/category at most. + id = (kind << REG_CATEGORY_SHIFT) | num; + + // check range: + switch (kind) { + case Reg::GPR: + case Reg::FPR: + case Reg::VF: + case Reg::COP0: + case Reg::VI: + assert(num < 32); + break; + case Reg::SPECIAL: + assert(num < Reg::MAX_SPECIAL); + break; + default: + assert(false); + } + } explicit Register(int reg_id) { assert(reg_id < Reg::MAX_REG_ID); id = reg_id; @@ -160,22 +191,66 @@ class Register { uint16_t reg_id() const { return id; } const char* to_charp() const; std::string to_string() const; - Reg::RegisterKind get_kind() const; + /*! + * Get the register kind. + */ + Reg::RegisterKind get_kind() const { + uint16_t kind = id >> REG_CATEGORY_SHIFT; + assert(kind < Reg::MAX_KIND); + return (Reg::RegisterKind)kind; + } bool is_vu_float() const { return get_kind() == Reg::VF || (get_kind() == Reg::SPECIAL && (get_special() == Reg::MACRO_Q || get_special() == Reg::MACRO_ACC)); } - Reg::Gpr get_gpr() const; - uint32_t get_fpr() const; - uint32_t get_vf() const; - uint32_t get_vi() const; - Reg::Cop0 get_cop0() const; - uint32_t get_special() const; - bool allowed_local_gpr() const; + bool is_gpr() const { return get_kind() == Reg::GPR; } + Reg::Gpr get_gpr() const { + assert(get_kind() == Reg::GPR); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < Reg::MAX_GPR); + return (Reg::Gpr)(kind); + } - bool operator==(const Register& other) const; - bool operator!=(const Register& other) const; + uint32_t get_fpr() const { + assert(get_kind() == Reg::FPR); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < 32); + return kind; + } + uint32_t get_vf() const { + assert(get_kind() == Reg::VF); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < 32); + return kind; + } + uint32_t get_vi() const { + assert(get_kind() == Reg::VI); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < 32); + return kind; + } + Reg::Cop0 get_cop0() const { + assert(get_kind() == Reg::COP0); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < Reg::MAX_COP0); + return (Reg::Cop0)(kind); + } + uint32_t get_special() const { + assert(get_kind() == Reg::SPECIAL); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < Reg::MAX_SPECIAL); + return kind; + } + bool allowed_local_gpr() const { + if (get_kind() != Reg::GPR) { + return false; + } + return Reg::allowed_local_gprs[get_gpr()]; + } + + bool operator==(const Register& other) const { return id == other.id; } + bool operator!=(const Register& other) const { return id != other.id; } bool operator<(const Register& other) const { return id < other.id; } struct hash { diff --git a/decompiler/IR2/MultiTypeAnalysis.cpp b/decompiler/IR2/MultiTypeAnalysis.cpp new file mode 100644 index 0000000000..feeb6f7e1d --- /dev/null +++ b/decompiler/IR2/MultiTypeAnalysis.cpp @@ -0,0 +1,217 @@ +/*! + * @file MultiTypeAnalysis.cpp + * The "new" type analysis pass which considers multiple possible types that can be at each + * register, due to overlapping fields in types. When it encounters a function call, set, or + * certain math operation, it will attempt to prune the decision tree to remove incompatible types. + * + * When there are multiple ways to get the same type, or the type is ambiguous, it will use the one + * with the highest score. + * + * Compared to the previous type analysis pass, there is more of an focus on being fast, as this is + * historically the slowest part of decompilation. + * + * It will attempt to propagate these decision trees across basic block boundaries, but any time + * there is a "phi node" where a registers can possible come from two different sources, it will + * prune the tree to a single decision there. + */ + +#include + +#include "common/util/assert.h" +#include "decompiler/Function/Warnings.h" +#include "MultiTypeAnalysis.h" +#include "decompiler/IR2/Env.h" + +namespace decompiler { + +using RegState = CopyOnWrite; + +bool DerefHint::matches(const FieldReverseLookupOutput& value) const { + if (value.tokens.size() != tokens.size()) { + return false; + } + + for (size_t i = 0; i < value.tokens.size(); i++) { + if (!tokens[i].matches(value.tokens[i])) { + return false; + } + } + + return true; +} + +bool DerefHint::Token::matches(const FieldReverseLookupOutput::Token& other) const { + switch (kind) { + case Kind::INTEGER: + return other.kind == FieldReverseLookupOutput::Token::Kind::CONSTANT_IDX && + other.idx == integer; + case Kind::FIELD: + return other.kind == FieldReverseLookupOutput::Token::Kind::FIELD && other.name == name; + case Kind::VAR: + return other.kind == FieldReverseLookupOutput::Token::Kind::VAR_IDX; + default: + assert(false); + } +} + +/*! + * Safely access the decision referenced by this TypeDecisionParent. + * This will work even if the actual RegisterTypeState has been modified since the reference was + * created. + */ +const PossibleType& TypeDecisionParent::get() const { + return instruction->get_const(reg).possible_types.at(type_index); +} + +/*! + * Figure out if this has been eliminated or not. Caches the result to avoid looking it up again and + * again. Elimination cannot be undone. + */ +bool PossibleType::is_valid() const { + if (!m_valid_cache) { + return false; + } + + if (parent.instruction) { + // we have a parent in the tree, check if that parent is eliminated. + if (!parent.get().is_valid()) { + m_valid_cache = false; + return false; + } + } + + return true; +} + +/*! + * If we have multiple types, pick the one with the highest deref path score. + * If warnings is set, and we have to throw away a valid type, prints a warning that we made a + * somewhat arbitrary decision to throw a possible type. + * + * After calling this, you can use get_single_tp_type and get_single_type_decision. + */ +void RegisterTypeState::reduce_to_single_type(DecompWarnings* warnings, + int op_idx, + const DerefHint* hint) { + double best_score = -std::numeric_limits::infinity(); + int best_idx = -1; + bool printed_first_warning = false; + std::string warning_string; + + // find the highest score that's valid. + for (int i = 0; i < (int)possible_types.size(); i++) { + if (possible_types[i].deref_score > best_score && possible_types[i].is_valid()) { + best_idx = i; + best_score = possible_types[i].deref_score; + } + + // if we match the hint, just use that. + if (possible_types[i].deref_path && hint->matches(*possible_types[i].deref_path)) { + best_idx = i; + warnings = nullptr; // never warn if we take the hint + break; + } + } + assert(best_idx != -1); + + // eliminate stuff that isn't the best. + for (int i = 0; i < (int)possible_types.size(); i++) { + if (i != best_idx) { + // warn if we eliminate something that is possibly valid. + if (warnings && possible_types[i].is_valid()) { + if (!printed_first_warning) { + warning_string += fmt::format("Ambiguous type selection at op {}\n", op_idx); + printed_first_warning = true; + } + if (possible_types[best_idx].deref_path) { + warning_string += fmt::format(" {}\n", possible_types[best_idx].deref_path->print()); + } else { + warning_string += fmt::format(" {}\n", possible_types[best_idx].type.print()); + } + } + + possible_types[i].eliminate(); + } + } + + // cache the winner + single_type_cache = best_idx; + + if (warnings && printed_first_warning) { + warnings->general_warning(warning_string); + } +} + +/*! + * After this has been pruned to a single type, gets that type decision. + */ +const PossibleType& RegisterTypeState::get_single_type_decision() const { + assert(single_type_cache.has_value()); + assert(possible_types.at(*single_type_cache).is_valid()); // todo remove. + return possible_types[*single_type_cache]; +} + +/*! + * After this has been pruned to a single type, gets it as a TP_Type. + */ +const TP_Type& RegisterTypeState::get_single_tp_type() const { + return get_single_type_decision().type; +} + +/*! + * If there is at least one possibility to get a desired_type, removes anything that's not a + * desired_type. If it's not possible to get a desired type, does nothing. + */ +void RegisterTypeState::try_elimination(const TypeSpec& desired_types, const TypeSystem& ts) { + std::vector to_eliminate; + int keep_count = 0; + for (int i = 0; i < (int)possible_types.size(); i++) { + const auto& possibility = possible_types[i]; + if (possibility.is_valid()) { + if (ts.tc(desired_types, possibility.type.typespec())) { + keep_count++; + } else { + to_eliminate.push_back(i); + } + } + } + + if (keep_count > 0) { + for (auto idx : to_eliminate) { + possible_types.at(idx).eliminate(); + } + } +} + +namespace { + +/*! + * Create a register type state with no parent and the given typespec. + */ +RegState make_typespec_parent_regstate(const TypeSpec& typespec) { + RegState result = make_cow +} + +/*! + * Create an instruction type state for the first instruction of a function. + */ +InstrTypeState construct_initial_typestate(const TypeSpec& function_type, + const Env& env, + const RegState& uninitialized) { + // start with everything unintialized + InstrTypeState result(uninitialized); + assert(function_type.base_type() == "function"); + assert(function_type.arg_count() >= 1); // must know the function type. + assert(function_type.arg_count() <= 8 + 1); // 8 args + 1 return. + + for (int i = 0; i < int(function_type.arg_count()) - 1; i++) { + auto reg_id = Register::get_arg_reg(i); + const auto& reg_type = function_type.get_arg(i); + result.get(Register(Reg::GPR, reg_id)) = TP_Type::make_from_ts(reg_type); + } + +} + +} // namespace + +} // namespace decompiler \ No newline at end of file diff --git a/decompiler/IR2/MultiTypeAnalysis.h b/decompiler/IR2/MultiTypeAnalysis.h new file mode 100644 index 0000000000..e5bc32d64a --- /dev/null +++ b/decompiler/IR2/MultiTypeAnalysis.h @@ -0,0 +1,89 @@ +#pragma once + +#include +#include +#include "common/util/CopyOnWrite.h" +#include "decompiler/Disasm/Register.h" +#include "decompiler/util/TP_Type.h" +#include "common/type_system/TypeSystem.h" + +namespace decompiler { + +class InstrTypeState; +class DecompWarnings; +struct PossibleType; + +struct DerefHint { + struct Token { + enum class Kind { + INTEGER, FIELD, VAR, INVALID + } kind = Kind::INVALID; + int integer = 0; + std::string name; + bool matches(const FieldReverseLookupOutput::Token& other)const ; + }; + std::vector tokens; + + bool matches(const FieldReverseLookupOutput& value)const ; +}; + +/*! + * Represents a reference to a type decision made on a previous instruction. + */ +struct TypeDecisionParent { + InstrTypeState* instruction = nullptr; + Register reg; + int type_index = -1; + + const PossibleType& get() const; +}; + +/*! + * Represents a possibility for the type in a register. + * Can be "invalid", meaning it is eliminated from the possible types due to a constraint. + * Use is_valid to check that it hasn't been eliminated. + */ +struct PossibleType { + TP_Type type; // the actual type. + std::optional deref_path; // the field accessed to get here + double deref_score = 0.; + TypeDecisionParent parent; // the decision we made to allow this. + void eliminate() { m_valid_cache = false; } + bool is_valid() const; // true, unless we were eliminated. + + private: + mutable bool m_valid_cache = true; +}; + +/*! + * The set of all possible types in a register. + */ +struct RegisterTypeState { + std::optional single_type_cache; + std::vector possible_types; + void reduce_to_single_type(DecompWarnings* warnings, int op_idx, const DerefHint* hint); + const PossibleType& get_single_type_decision() const; + const TP_Type& get_single_tp_type() const; + void try_elimination(const TypeSpec& desired_types, const TypeSystem& ts); +}; + +class InstrTypeState { + public: + explicit InstrTypeState(const CopyOnWrite& default_value) { + m_regs.fill(default_value); + } + + const RegisterTypeState& get_const(const Register& reg) const { + assert(reg.reg_id() < Reg::MAX_VAR_REG_ID); + return *m_regs[reg.reg_id()]; + } + + CopyOnWrite& get(const Register& reg) { + assert(reg.reg_id() < Reg::MAX_VAR_REG_ID); + return m_regs[reg.reg_id()]; + } + + private: + std::array, Reg::MAX_VAR_REG_ID> m_regs; +}; +} // namespace decompiler \ No newline at end of file diff --git a/decompiler/analysis/type_analysis.h b/decompiler/analysis/type_analysis.h index 7957f77fdb..1c1da853b5 100644 --- a/decompiler/analysis/type_analysis.h +++ b/decompiler/analysis/type_analysis.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include "common/type_system/TypeSpec.h" @@ -10,4 +11,5 @@ namespace decompiler { bool run_type_analysis_ir2(const TypeSpec& my_type, DecompilerTypeSystem& dts, Function& func); -} \ No newline at end of file + +} // namespace decompiler \ No newline at end of file