#pragma once

#include <type_traits>
#include <utility>

/*!
 * CopyOnWrite<T> acts like a value of type T, but internally shares a single heap allocation
 * between copies so that copying a CopyOnWrite never copies the T.
 *
 * It is used like a shared pointer. But, if you try to modify an object that has multiple
 * owners, it will first make a private copy so the other owners don't see any changes. In this
 * way, it does not act like a reference.
 *
 * To construct a new object, use CopyOnWrite<T>(args...). The arguments are forwarded to T's
 * constructor. This is different from the usual smart pointer pattern.
 *
 * Like shared pointers, a CopyOnWrite can be null (default constructed, or moved from).
 * get() and mut() on a null CopyOnWrite return a null pointer.
 *
 * The default .get(), ->, and * operators give you const access. If you need to modify, use
 * .mut(). It will create a copy if needed, then give you a mutable pointer.
 *
 * The reference count is a plain int with no locking, so handles that share an object must not
 * be used from multiple threads at the same time.
 */
template <typename T>
class CopyOnWrite {
 private:
  // we store the object and its reference count in the same heap allocation.
  struct ObjectAndCount {
    T object;

    // construct the object in-place, or copy construct from an existing.
    template <typename... Args>
    explicit ObjectAndCount(Args&&... args) : object(std::forward<Args>(args)...) {}
    explicit ObjectAndCount(const T& existing) : object(existing) {}

    // in case we ever want this to have locks.
    void add_ref() { m_count++; }
    void remove_ref() { m_count--; }
    bool unique() const { return m_count == 1; }
    bool dead() const { return m_count == 0; }

   private:
    int m_count = 0;
  };

  // True when the argument pack is exactly one (possibly cv/ref qualified) CopyOnWrite.
  // Used to keep the forwarding constructor from hijacking copy/move construction.
  template <typename... Args>
  struct is_self : std::false_type {};
  template <typename Arg>
  struct is_self<Arg> : std::is_same<std::decay_t<Arg>, CopyOnWrite> {};

 public:
  CopyOnWrite() = default;  // allow nulls.

  /*!
   * Construct a new object, forwarding the arguments to T's constructor.
   * Disabled when the only argument is another CopyOnWrite, so that
   * CopyOnWrite<T> b(a) shares a's object instead of trying to build a T from a.
   */
  template <typename... Args, typename = std::enable_if_t<!is_self<Args...>::value>>
  explicit CopyOnWrite(Args&&... args)
      : m_data(new ObjectAndCount(std::forward<Args>(args)...)) {
    m_data->add_ref();
  }

  /*!
   * Copy a handle. The underlying object is shared, not copied.
   */
  CopyOnWrite(const CopyOnWrite& other) : m_data(other.m_data) {
    if (m_data) {
      m_data->add_ref();
    }
  }

  /*!
   * Move a handle. The source becomes null and the reference count is untouched.
   */
  CopyOnWrite(CopyOnWrite&& other) noexcept : m_data(other.m_data) { other.m_data = nullptr; }

  CopyOnWrite& operator=(const CopyOnWrite& other) {
    // covers self-assignment and handles which already share the same object.
    if (m_data != other.m_data) {
      ObjectAndCount* incoming = other.m_data;
      if (incoming) {
        incoming->add_ref();
      }
      release();
      m_data = incoming;
    }
    return *this;
  }

  CopyOnWrite& operator=(CopyOnWrite&& other) noexcept {
    if (this != &other) {
      release();
      m_data = other.m_data;
      other.m_data = nullptr;
    }
    return *this;
  }

  ~CopyOnWrite() { release(); }

  // constant access. get() is null-safe, -> and * require a non-null handle.
  const T* get() const { return m_data ? &m_data->object : nullptr; }
  const T* operator->() const {
    assert(m_data);
    return &m_data->object;
  }
  const T& operator*() const {
    assert(m_data);
    return m_data->object;
  }
  explicit operator bool() const { return m_data != nullptr; }

  /*!
   * Get a mutable pointer to the object, or nullptr if this handle is null.
   * If the object is shared with other handles, this handle first switches to a private copy,
   * so the other handles never observe the modification.
   */
  T* mut() {
    if (!m_data) {
      return nullptr;
    }

    if (!m_data->unique()) {
      assert(!m_data->dead());
      // Build the private copy before giving up our reference: if the allocation or T's copy
      // constructor throws, this handle still owns a correctly counted reference.
      auto* fresh = new ObjectAndCount(static_cast<const T&>(m_data->object));
      m_data->remove_ref();  // can't hit zero here, there's another ref somewhere.
      assert(!m_data->dead());
      fresh->add_ref();
      m_data = fresh;
    }
    return &m_data->object;
  }

 private:
  // Drop our reference (deleting the object if we were the last owner) and become null.
  void release() noexcept {
    if (m_data) {
      m_data->remove_ref();
      if (m_data->dead()) {
        delete m_data;
      }
      m_data = nullptr;
    }
  }

  ObjectAndCount* m_data = nullptr;
};
assert(special < Reg::MAX_SPECIAL); + return special_names[special]; } } // namespace @@ -111,11 +104,17 @@ const char* cop2_macro_special_to_charp(uint32_t reg) { // Note: VI / COP2 are separate "kinds" of registers, each with 16 registers. // It might make sense to make this a single "kind" instead? +namespace { +constexpr int REG_CATEGORY_SHIFT = 5; +constexpr int REG_IDX_MASK = 0b11111; +} // namespace + /*! * Create a register. The kind and num must both be valid. */ Register::Register(Reg::RegisterKind kind, uint32_t num) { - id = (kind << 8) | num; + // 32 regs/category at most. + id = (kind << REG_CATEGORY_SHIFT) | num; // check range: switch (kind) { @@ -126,9 +125,8 @@ Register::Register(Reg::RegisterKind kind, uint32_t num) { case Reg::VI: assert(num < 32); break; - case Reg::PCR: - case Reg::COP2_MACRO_SPECIAL: - assert(num < 2); + case Reg::SPECIAL: + assert(num < Reg::MAX_SPECIAL); break; default: assert(false); @@ -139,7 +137,7 @@ Register::Register(const std::string& name) { // first try gprs, for (int i = 0; i < Reg::MAX_GPR; i++) { if (name == gpr_names[i]) { - id = (Reg::GPR << 8) | i; + id = (Reg::GPR << REG_CATEGORY_SHIFT) | i; return; } } @@ -147,7 +145,7 @@ Register::Register(const std::string& name) { // next fprs for (int i = 0; i < 32; i++) { if (name == fpr_names[i]) { - id = (Reg::FPR << 8) | i; + id = (Reg::FPR << REG_CATEGORY_SHIFT) | i; return; } } @@ -170,10 +168,8 @@ const char* Register::to_charp() const { return vi_to_charp(get_vi()); case Reg::COP0: return cop0_to_charp(get_cop0()); - case Reg::PCR: - return pcr_to_charp(get_pcr()); - case Reg::COP2_MACRO_SPECIAL: - return cop2_macro_special_to_charp(get_cop2_macro_special()); + case Reg::SPECIAL: + return special_to_charp(get_special()); default: throw std::runtime_error("Unsupported Register"); } @@ -190,7 +186,7 @@ std::string Register::to_string() const { * Get the register kind. 
*/ Reg::RegisterKind Register::get_kind() const { - uint16_t kind = id >> 8; + uint16_t kind = id >> REG_CATEGORY_SHIFT; assert(kind < Reg::MAX_KIND); return (Reg::RegisterKind)kind; } @@ -200,7 +196,7 @@ Reg::RegisterKind Register::get_kind() const { */ Reg::Gpr Register::get_gpr() const { assert(get_kind() == Reg::GPR); - uint16_t kind = id & 0xff; + uint16_t kind = id & REG_IDX_MASK; assert(kind < Reg::MAX_GPR); return (Reg::Gpr)(kind); } @@ -210,7 +206,7 @@ Reg::Gpr Register::get_gpr() const { */ uint32_t Register::get_fpr() const { assert(get_kind() == Reg::FPR); - uint16_t kind = id & 0xff; + uint16_t kind = id & REG_IDX_MASK; assert(kind < 32); return kind; } @@ -220,7 +216,7 @@ uint32_t Register::get_fpr() const { */ uint32_t Register::get_vf() const { assert(get_kind() == Reg::VF); - uint16_t kind = id & 0xff; + uint16_t kind = id & REG_IDX_MASK; assert(kind < 32); return kind; } @@ -230,7 +226,7 @@ uint32_t Register::get_vf() const { */ uint32_t Register::get_vi() const { assert(get_kind() == Reg::VI); - uint16_t kind = id & 0xff; + uint16_t kind = id & REG_IDX_MASK; assert(kind < 32); return kind; } @@ -240,7 +236,7 @@ uint32_t Register::get_vi() const { */ Reg::Cop0 Register::get_cop0() const { assert(get_kind() == Reg::COP0); - uint16_t kind = id & 0xff; + uint16_t kind = id & REG_IDX_MASK; assert(kind < Reg::MAX_COP0); return (Reg::Cop0)(kind); } @@ -248,20 +244,13 @@ Reg::Cop0 Register::get_cop0() const { /*! * Get the PCR number. Must be a PCR. 
*/ -uint32_t Register::get_pcr() const { - assert(get_kind() == Reg::PCR); - uint16_t kind = id & 0xff; - assert(kind < 2); +uint32_t Register::get_special() const { + assert(get_kind() == Reg::SPECIAL); + uint16_t kind = id & REG_IDX_MASK; + assert(kind < Reg::MAX_SPECIAL); return kind; } -Reg::Cop2MacroSpecial Register::get_cop2_macro_special() const { - assert(get_kind() == Reg::COP2_MACRO_SPECIAL); - uint16_t k = id & 0xff; - assert(k < 2); - return (Reg::Cop2MacroSpecial)k; -} - bool Register::operator==(const Register& other) const { return id == other.id; } diff --git a/decompiler/Disasm/Register.h b/decompiler/Disasm/Register.h index d44dd68967..7cc7b50862 100644 --- a/decompiler/Disasm/Register.h +++ b/decompiler/Disasm/Register.h @@ -11,17 +11,22 @@ namespace decompiler { // Namespace for register name constants + +// Note on registers: +// Registers are assigned a unique Register ID as an integer from 0 to 164 (not including 164). +// Don't change these enums without updating the indexing scheme. +// It is important that each register is a unique register ID, and that we don't have gaps. + namespace Reg { enum RegisterKind { - GPR = 0, // EE General purpose registers, these have nicknames. - FPR = 1, // EE Floating point registers, just called f0 - f31 - VF = 2, // VU0 Floating point vector registers from EE, just called vf0 - vf31 - VI = - 3, // VU0 Integer registers from EE, the first 16 are vi00 - vi15, the rest are control regs. - COP0 = 4, // EE COP0 Control Registers: full of fancy names (there are 32 of them) - PCR = 5, // Performance Counter registers (PCR0, PCR1) - COP2_MACRO_SPECIAL = 6, // COP2 Q, ACC accessed from macro mode instructions. 
- MAX_KIND = 7 + GPR = 0, // EE General purpose registers, these have nicknames (32 regs) + FPR = 1, // EE Floating point registers, just called f0 - f31 (32 regs) + VF = 2, // VU0 Floating point vector registers from EE, just called vf0 - vf31 (32 regs) + VI = 3, // VU0 Integer registers from EE, the first 16 are vi00 - vi15, the rest are control + // regs. (32 regs) + COP0 = 4, // EE COP0 Control Registers: full of fancy names (there are 32 of them) (32 regs) + SPECIAL = 5, // COP2 Q, ACC accessed from macro mode instructions and PCR + MAX_KIND = 6 }; // nicknames for GPRs @@ -121,35 +126,52 @@ enum Vi { MAX_COP2 = 32 }; -enum Cop2MacroSpecial { - MACRO_Q = 0, - MACRO_ACC = 1, +enum SpecialRegisters { + PCR0 = 0, + PCR1 = 1, + MACRO_Q = 2, + MACRO_ACC = 3, + MAX_SPECIAL = 4, }; const extern bool allowed_local_gprs[Reg::MAX_GPR]; +constexpr int MAX_REG_ID = 32 * 5 + MAX_SPECIAL; +constexpr int MAX_VAR_REG_ID = 32 * 2; // gprs/fprs. + } // namespace Reg -// Representation of a register. Uses a 32-bit integer internally. +// Representation of a register. Uses a 16-bit integer internally. 
class Register { public: Register() = default; Register(Reg::RegisterKind kind, uint32_t num); + explicit Register(int reg_id) { + assert(reg_id < Reg::MAX_REG_ID); + id = reg_id; + } + Register(const std::string& name); static Register get_arg_reg(int idx) { assert(idx >= 0 && idx < 8); return Register(Reg::GPR, Reg::A0 + idx); } + + uint16_t reg_id() const { return id; } const char* to_charp() const; std::string to_string() const; Reg::RegisterKind get_kind() const; + bool is_vu_float() const { + return get_kind() == Reg::VF || + (get_kind() == Reg::SPECIAL && + (get_special() == Reg::MACRO_Q || get_special() == Reg::MACRO_ACC)); + } Reg::Gpr get_gpr() const; uint32_t get_fpr() const; uint32_t get_vf() const; uint32_t get_vi() const; Reg::Cop0 get_cop0() const; - uint32_t get_pcr() const; - Reg::Cop2MacroSpecial get_cop2_macro_special() const; + uint32_t get_special() const; bool allowed_local_gpr() const; bool operator==(const Register& other) const; diff --git a/decompiler/IR2/AtomicOp.cpp b/decompiler/IR2/AtomicOp.cpp index ee19ae4da2..ddcce356b6 100644 --- a/decompiler/IR2/AtomicOp.cpp +++ b/decompiler/IR2/AtomicOp.cpp @@ -573,20 +573,20 @@ void AsmOp::update_register_info() { if (m_instr.kind >= FIRST_COP2_MACRO && m_instr.kind <= LAST_COP2_MACRO) { switch (m_instr.kind) { case InstructionKind::VMSUBQ: - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q)); - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; case InstructionKind::VMULAQ: - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q)); - m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q)); + m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; // Read Q register case InstructionKind::VADDQ: case 
InstructionKind::VSUBQ: case InstructionKind::VMULQ: - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q)); break; // Write ACC register @@ -595,14 +595,14 @@ void AsmOp::update_register_info() { case InstructionKind::VMULA: case InstructionKind::VMULA_BC: case InstructionKind::VOPMULA: - m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; // Write Q register case InstructionKind::VDIV: case InstructionKind::VSQRT: case InstructionKind::VRSQRT: - m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q)); + m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q)); break; // Read acc register @@ -610,18 +610,18 @@ void AsmOp::update_register_info() { case InstructionKind::VMADD_BC: case InstructionKind::VMSUB: case InstructionKind::VMSUB_BC: - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; case InstructionKind::VOPMSUB: - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; // Read/Write acc register case InstructionKind::VMADDA: case InstructionKind::VMADDA_BC: case InstructionKind::VMSUBA_BC: - m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); - m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC)); + m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); + m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC)); break; case InstructionKind::VMOVE: diff --git a/decompiler/IR2/Form.cpp b/decompiler/IR2/Form.cpp index 53eca6a596..7acfc350e6 100644 --- a/decompiler/IR2/Form.cpp +++ b/decompiler/IR2/Form.cpp @@ -571,22 +571,19 @@ void OpenGoalAsmOpElement::collect_vars(RegAccessSet& vars, bool) const { void 
OpenGoalAsmOpElement::collect_vf_regs(RegSet& regs) const { for (auto r : m_op->read_regs()) { - if (r.get_kind() == Reg::RegisterKind::VF || - r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) { + if (r.is_vu_float()) { regs.insert(r); } } for (auto r : m_op->write_regs()) { - if (r.get_kind() == Reg::RegisterKind::VF || - r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) { + if (r.is_vu_float()) { regs.insert(r); } } for (auto r : m_op->clobber_regs()) { - if (r.get_kind() == Reg::RegisterKind::VF || - r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) { + if (r.is_vu_float()) { regs.insert(r); } } @@ -947,8 +944,7 @@ void RLetElement::apply(const std::function& f) { goos::Object RLetElement::reg_list() const { std::vector regs; for (auto& reg : sorted_regs) { - if (reg.get_kind() == Reg::RegisterKind::VF || - reg.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) { + if (reg.is_vu_float()) { std::string reg_name = reg.to_string() == "ACC" ? "acc" : reg.to_string(); regs.push_back( pretty_print::build_list(pretty_print::to_symbol(fmt::format("{} :class vf", reg_name)))); diff --git a/decompiler/ObjectFile/ObjectFileDB_IR2.cpp b/decompiler/ObjectFile/ObjectFileDB_IR2.cpp index 16ded1d9cf..86878ff43f 100644 --- a/decompiler/ObjectFile/ObjectFileDB_IR2.cpp +++ b/decompiler/ObjectFile/ObjectFileDB_IR2.cpp @@ -416,7 +416,7 @@ void ObjectFileDB::ir2_register_usage_pass() { func.warnings.bad_vf_dependency("{}", x.to_string()); } - if (x.get_kind() == Reg::COP2_MACRO_SPECIAL) { + if (x.get_kind() == Reg::SPECIAL) { lg::error("Bad vf dependency on {} in {}", x.to_charp(), func.guessed_name.to_string()); func.warnings.bad_vf_dependency("{}", x.to_string()); } diff --git a/decompiler/config.cpp b/decompiler/config.cpp index 5e1c055777..71cd1d4dd5 100644 --- a/decompiler/config.cpp +++ b/decompiler/config.cpp @@ -79,7 +79,7 @@ Config read_config_file(const std::string& path_to_config_file) { for (auto idx : idx_range) { RegisterTypeCast type_cast; 
type_cast.atomic_op_idx = idx; - type_cast.reg = Register(cast.at(1)); + type_cast.reg = Register(cast.at(1).get()); type_cast.type_name = cast.at(2).get(); config.register_type_casts_by_function_by_atomic_op_idx[function_name][idx].push_back( type_cast); diff --git a/decompiler/util/config_parsers.cpp b/decompiler/util/config_parsers.cpp index c5349074d7..44a3836054 100644 --- a/decompiler/util/config_parsers.cpp +++ b/decompiler/util/config_parsers.cpp @@ -40,7 +40,7 @@ std::unordered_map> parse_cast_hi for (auto idx : idx_range) { RegisterTypeCast type_cast; type_cast.atomic_op_idx = idx; - type_cast.reg = Register(cast.at(1)); + type_cast.reg = Register(cast.at(1).get()); type_cast.type_name = cast.at(2).get(); out[idx].push_back(type_cast); } diff --git a/test/test_common_util.cpp b/test/test_common_util.cpp index 7ed3eb7579..b9683a605f 100644 --- a/test/test_common_util.cpp +++ b/test/test_common_util.cpp @@ -11,6 +11,7 @@ #include "common/util/Range.h" #include "third-party/fmt/core.h" #include "common/util/print_float.h" +#include "common/util/CopyOnWrite.h" TEST(CommonUtil, get_file_path) { std::vector test = {"cabbage", "banana", "apple"}; @@ -140,4 +141,45 @@ TEST(CommonUtil, PowerOfTwo) { EXPECT_EQ(get_power_of_two(3), std::nullopt); EXPECT_EQ(get_power_of_two(4), 2); EXPECT_EQ(get_power_of_two(u64(1) << 63), 63); +} + +TEST(CommonUtil, CopyOnWrite) { + CopyOnWrite x(2); + + EXPECT_EQ(*x, 2); + *x.mut() = 3; + EXPECT_EQ(*x, 3); + + CopyOnWrite y = x; + EXPECT_EQ(*x, 3); + EXPECT_EQ(*y, 3); + EXPECT_EQ(x.get(), y.get()); + + *x.mut() = 12; + EXPECT_EQ(*x, 12); + EXPECT_EQ(*y, 3); + + x = y; + EXPECT_EQ(*x, 3); + EXPECT_EQ(*y, 3); + EXPECT_EQ(x.get(), y.get()); + + y = x; + EXPECT_EQ(*x, 3); + EXPECT_EQ(*y, 3); + EXPECT_EQ(x.get(), y.get()); + + EXPECT_TRUE(x); + EXPECT_TRUE(y); + + CopyOnWrite z; + EXPECT_FALSE(z); + + z = x; + EXPECT_TRUE(z); + EXPECT_EQ(x.get(), z.get()); + *z.mut() = 15; + EXPECT_EQ(*x, 3); + EXPECT_EQ(*y, 3); + EXPECT_EQ(*z, 15); } \ 
No newline at end of file