diff --git a/common/goos/PrettyPrinter.cpp b/common/goos/PrettyPrinter.cpp index 796273ab8c..f0ccb5afd5 100644 --- a/common/goos/PrettyPrinter.cpp +++ b/common/goos/PrettyPrinter.cpp @@ -572,9 +572,8 @@ void breakList(NodePool& pool, PrettyPrinterNode* leftParen, PrettyPrinterNode* } namespace { -const std::unordered_set control_flow_start_forms = { - "while", "dotimes", "until", "if", "when", -}; +const std::unordered_set control_flow_start_forms = {"while", "dotimes", "until", + "if", "when", "countdown"}; } PrettyPrinterNode* seek_to_next_non_whitespace(PrettyPrinterNode* in) { diff --git a/decompiler/IR2/Form.cpp b/decompiler/IR2/Form.cpp index 757b73192f..48f622246f 100644 --- a/decompiler/IR2/Form.cpp +++ b/decompiler/IR2/Form.cpp @@ -2203,45 +2203,58 @@ void LetElement::set_body(Form* new_body) { } ///////////////////////////// -// DoTimesElement +// CounterLoopElement ///////////////////////////// -DoTimesElement::DoTimesElement(RegisterAccess var_init, - RegisterAccess var_check, - RegisterAccess var_inc, - Form* check_value, - Form* body) +CounterLoopElement::CounterLoopElement(Kind kind, + RegisterAccess var_init, + RegisterAccess var_check, + RegisterAccess var_inc, + Form* check_value, + Form* body) : m_var_init(var_init), m_var_check(var_check), m_var_inc(var_inc), m_check_value(check_value), - m_body(body) { + m_body(body), + m_kind(kind) { m_body->parent_element = this; m_check_value->parent_element = this; assert(m_var_inc.reg() == m_var_check.reg()); assert(m_var_init.reg() == m_var_inc.reg()); } -goos::Object DoTimesElement::to_form_internal(const Env& env) const { +goos::Object CounterLoopElement::to_form_internal(const Env& env) const { + std::string loop_name; + switch (m_kind) { + case Kind::DOTIMES: + loop_name = "dotimes"; + break; + case Kind::COUNTDOWN: + loop_name = "countdown"; + break; + default: + assert(false); + } std::vector outer = { - pretty_print::to_symbol("dotimes"), + pretty_print::to_symbol(loop_name), pretty_print::build_list(m_var_init.to_form(env), m_check_value->to_form(env))}; m_body->inline_forms(outer, env); return pretty_print::build_list(outer); } -void DoTimesElement::apply(const std::function& f) { +void CounterLoopElement::apply(const std::function& f) { f(this); m_check_value->apply(f); m_body->apply(f); } -void DoTimesElement::apply_form(const std::function& f) { +void CounterLoopElement::apply_form(const std::function& f) { m_check_value->apply_form(f); m_body->apply_form(f); } -void DoTimesElement::collect_vars(RegAccessSet& vars, bool recursive) const { +void CounterLoopElement::collect_vars(RegAccessSet& vars, bool recursive) const { vars.insert(m_var_init); vars.insert(m_var_check); vars.insert(m_var_inc); @@ -2251,7 +2264,7 @@ void DoTimesElement::collect_vars(RegAccessSet& vars, bool recursive) const { } } -void DoTimesElement::get_modified_regs(RegSet& regs) const { +void CounterLoopElement::get_modified_regs(RegSet& regs) const { regs.insert(m_var_inc.reg()); m_body->get_modified_regs(regs); m_check_value->get_modified_regs(regs); diff --git a/decompiler/IR2/Form.h b/decompiler/IR2/Form.h index 9c7c7a2553..cc9a368354 100644 --- a/decompiler/IR2/Form.h +++ b/decompiler/IR2/Form.h @@ -1288,13 +1288,15 @@ class LetElement : public FormElement { bool m_star = false; }; -class DoTimesElement : public FormElement { +class CounterLoopElement : public FormElement { public: - DoTimesElement(RegisterAccess var_init, - RegisterAccess var_check, - RegisterAccess var_inc, - Form* check_value, - Form* body); + enum class Kind { DOTIMES, COUNTDOWN, INVALID }; + CounterLoopElement(Kind kind, + RegisterAccess var_init, + RegisterAccess var_check, + RegisterAccess var_inc, + Form* check_value, + Form* body); goos::Object to_form_internal(const Env& env) const override; void apply(const std::function& f) override; void apply_form(const std::function& f) override; @@ -1306,6 +1308,7 @@ class DoTimesElement : public FormElement { RegisterAccess m_var_init, m_var_check, m_var_inc; Form* m_check_value = nullptr; Form* m_body = nullptr; + Kind m_kind = Kind::INVALID; }; class LambdaDefinitionElement : public FormElement { diff --git a/decompiler/analysis/insert_lets.cpp b/decompiler/analysis/insert_lets.cpp index 702bd51da3..3e2a1cbb33 100644 --- a/decompiler/analysis/insert_lets.cpp +++ b/decompiler/analysis/insert_lets.cpp @@ -162,8 +162,82 @@ FormElement* rewrite_as_dotimes(LetElement* in, const Env& env, FormPool& pool) // first, remove the increment body->pop_back(); - return pool.alloc_element(in->entries().at(0).dest, *lt_var, *inc_var, - mr.maps.forms.at(1), body); + return pool.alloc_element(CounterLoopElement::Kind::DOTIMES, + in->entries().at(0).dest, *lt_var, *inc_var, + mr.maps.forms.at(1), body); +} + +FormElement* rewrite_as_countdown(LetElement* in, const Env& env, FormPool& pool) { + // dotimes OpenGOAL: + /* + (defmacro countdown (var &rest body) + "Loop like for (int i = end; i-- > 0)" + `(let ((,(first var) ,(second var))) + (while (!= ,(first var) 0) + (set! ,(first var) (- ,(first var) 1)) + ,@body + ) + ) + ) + */ + + // should have this anyway, but double check so we don't throw this away. + if (in->entries().size() != 1) { + return nullptr; + } + + // look for setting a var to the initial value. + auto ra = in->entries().at(0).dest; + auto idx_var = env.get_variable_name(ra); + + // still have to check body for the increment and have to check that the lt operates on the right + // thing. + Matcher while_matcher = Matcher::while_loop( + Matcher::op(GenericOpMatcher::condition(IR2_Condition::Kind::NONZERO), {Matcher::any_reg(0)}), + Matcher::any(2)); + + auto mr = match(while_matcher, in->body()); + if (!mr.matched) { + return nullptr; + } + + // check the zero operation: + auto lt_var = mr.maps.regs.at(0); + assert(lt_var); + if (env.get_variable_name(*lt_var) != idx_var) { + return nullptr; // wrong variable checked + } + + // check the body + auto body = mr.maps.forms.at(2); + auto first_in_body = body->elts().front(); + + // kind hacky + Form fake_form; + fake_form.elts().push_back(first_in_body); + Matcher increment_matcher = + Matcher::op(GenericOpMatcher::fixed(FixedOperatorKind::ADDITION_IN_PLACE), + {Matcher::any_reg(0), Matcher::integer(-1)}); + + auto int_mr = match(increment_matcher, &fake_form); + if (!int_mr.matched) { + return nullptr; + } + + auto inc_var = int_mr.maps.regs.at(0); + assert(inc_var); + if (env.get_variable_name(*inc_var) != idx_var) { + return nullptr; // wrong variable incremented + } + + // success! here we commit to modifying this: + + // first, remove the increment + body->elts().erase(body->elts().begin()); + + return pool.alloc_element(CounterLoopElement::Kind::COUNTDOWN, + in->entries().at(0).dest, *lt_var, *inc_var, + in->entries().at(0).src, body); } FormElement* fix_up_abs(LetElement* in, const Env& env, FormPool& pool) { @@ -365,6 +439,11 @@ FormElement* rewrite_let(LetElement* in, const Env& env, FormPool& pool) { return as_dotimes; } + auto as_countdown = rewrite_as_countdown(in, env, pool); + if (as_countdown) { + return as_countdown; + } + auto as_abs = fix_up_abs(in, env, pool); if (as_abs) { return as_abs; diff --git a/goal_src/engine/collide/collide-touch-h.gc b/goal_src/engine/collide/collide-touch-h.gc index 3694e64bfe..2854368bfc 100644 --- a/goal_src/engine/collide/collide-touch-h.gc +++ b/goal_src/engine/collide/collide-touch-h.gc @@ -52,26 +52,21 @@ (defmethod init-list! touching-prims-entry-pool ((obj touching-prims-entry-pool)) "Initialize all entries to be not allocated and in a linked list." - (local-vars - (prev touching-prims-entry) - (idx int) - (current (inline-array touching-prims-entry)) - (next touching-prims-entry) - ) - (set! prev #f) - (set! current (-> obj nodes)) - (set! (-> obj head) (-> current 0)) - (set! idx 64) - (while (nonzero? idx) - (+! idx -1) - (set! (-> current 0 prev) prev) - (set! next (-> current 1)) - (set! (-> current 0 next) next) - (set! (-> current 0 allocated?) #f) - (set! prev (-> current 0)) - (set! current (the (inline-array touching-prims-entry) next)) + (let ((prev (the-as touching-prims-entry #f))) + (let ((current (the-as touching-prims-entry (-> obj nodes)))) + (set! (-> obj head) current) + (countdown (a0-1 64) + (set! (-> current prev) prev) + (let ((next (&+ current 240))) + (set! (-> current next) (the-as touching-prims-entry next)) + (set! (-> current allocated?) #f) + (set! prev current) + (set! current (the-as touching-prims-entry next)) + ) + ) + ) + (set! (-> prev next) #f) ) - (set! (-> prev next) #f) (none) ) diff --git a/goal_src/engine/math/vector-h.gc b/goal_src/engine/math/vector-h.gc index 761d343543..957ae330a3 100644 --- a/goal_src/engine/math/vector-h.gc +++ b/goal_src/engine/math/vector-h.gc @@ -91,16 +91,12 @@ (defmethod clear bit-array ((obj bit-array)) "Set all bits to zero." - (local-vars (idx int)) - (let ((idx (sar (logand -8 (+ (-> obj allocated-length) 7)) 3))) - (while (nonzero? idx) - (set! idx (+ idx -1)) - (nop!) - (nop!) - (set! (-> obj bytes idx) 0) - ) - obj + (countdown (idx (/ (logand -8 (+ (-> obj allocated-length) 7)) 8)) + (nop!) + (nop!) + (set! (-> obj bytes idx) (the-as uint 0)) ) + obj ) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/goal_src/kernel/gcommon.gc b/goal_src/kernel/gcommon.gc index 0908899342..223ac5d24e 100644 --- a/goal_src/kernel/gcommon.gc +++ b/goal_src/kernel/gcommon.gc @@ -995,16 +995,13 @@ - size in bytes will be rounded up to 16-bytes - Ascending address copy." (let ((result dst)) - (let ((qwc (/ (+ size 15) 16))) - (while (nonzero? qwc) - (+! qwc -1) - (set! - (-> (the-as (pointer uint128) dst)) - (-> (the-as (pointer uint128) src)) - ) - (&+! dst 16) - (&+! src 16) - ) + (countdown (qwc (/ (+ size 15) 16)) + (set! + (-> (the-as (pointer uint128) dst)) + (-> (the-as (pointer uint128) src)) + ) + (&+! dst 16) + (&+! src 16) ) result ) diff --git a/test/decompiler/reference/engine/collide/collide-touch-h_REF.gc b/test/decompiler/reference/engine/collide/collide-touch-h_REF.gc index 1a2593bed1..bcdd28957c 100644 --- a/test/decompiler/reference/engine/collide/collide-touch-h_REF.gc +++ b/test/decompiler/reference/engine/collide/collide-touch-h_REF.gc @@ -84,16 +84,13 @@ (let ((prev (the-as touching-prims-entry #f))) (let ((current (the-as touching-prims-entry (-> obj nodes)))) (set! (-> obj head) current) - (let ((a0-1 64)) - (while (nonzero? a0-1) - (+! a0-1 -1) - (set! (-> current prev) prev) - (let ((next (&+ current 240))) - (set! (-> current next) (the-as touching-prims-entry next)) - (set! (-> current allocated?) #f) - (set! prev current) - (set! current (the-as touching-prims-entry next)) - ) + (countdown (a0-1 64) + (set! (-> current prev) prev) + (let ((next (&+ current 240))) + (set! (-> current next) (the-as touching-prims-entry next)) + (set! (-> current allocated?) #f) + (set! prev current) + (set! current (the-as touching-prims-entry next)) ) ) ) diff --git a/test/decompiler/reference/engine/math/vector-h_REF.gc b/test/decompiler/reference/engine/math/vector-h_REF.gc index 8ea121d595..7e76c35edd 100644 --- a/test/decompiler/reference/engine/math/vector-h_REF.gc +++ b/test/decompiler/reference/engine/math/vector-h_REF.gc @@ -95,13 +95,10 @@ ;; definition for method 12 of type bit-array (defmethod clear bit-array ((obj bit-array)) - (let ((idx (/ (logand -8 (+ (-> obj allocated-length) 7)) 8))) - (while (nonzero? idx) - (+! idx -1) - (nop!) - (nop!) - (set! (-> obj bytes idx) (the-as uint 0)) - ) + (countdown (idx (/ (logand -8 (+ (-> obj allocated-length) 7)) 8)) + (nop!) + (nop!) + (set! (-> obj bytes idx) (the-as uint 0)) ) obj ) diff --git a/test/decompiler/reference/kernel/gcommon_REF.gc b/test/decompiler/reference/kernel/gcommon_REF.gc index 7b528a0a2b..bd10892553 100644 --- a/test/decompiler/reference/kernel/gcommon_REF.gc +++ b/test/decompiler/reference/kernel/gcommon_REF.gc @@ -828,16 +828,13 @@ ;; Used lq/sq (defun qmem-copy<-! ((dst pointer) (src pointer) (size int)) (let ((result dst)) - (let ((qwc (/ (+ size 15) 16))) - (while (nonzero? qwc) - (+! qwc -1) - (set! - (-> (the-as (pointer uint128) dst)) - (-> (the-as (pointer uint128) src)) - ) - (&+! dst 16) - (&+! src 16) + (countdown (qwc (/ (+ size 15) 16)) + (set! + (-> (the-as (pointer uint128) dst)) + (-> (the-as (pointer uint128) src)) ) + (&+! dst 16) + (&+! src 16) ) result ) diff --git a/test/decompiler/reference/kernel/gkernel_REF.gc b/test/decompiler/reference/kernel/gkernel_REF.gc index 0589a5b2c4..660902ee3d 100644 --- a/test/decompiler/reference/kernel/gkernel_REF.gc +++ b/test/decompiler/reference/kernel/gkernel_REF.gc @@ -657,13 +657,10 @@ (set! (-> obj child) (the-as (pointer process-tree) #f)) (set! (-> obj self) obj) (set! (-> obj ppointer) (&-> obj self)) - (let ((v1-4 arg1)) - (while (nonzero? v1-4) - (+! v1-4 -1) - (let ((a0-4 (-> obj process-list v1-4))) - (set! (-> a0-4 process) *null-process*) - (set! (-> a0-4 next) (-> obj process-list (+ v1-4 1))) - ) + (countdown (v1-4 arg1) + (let ((a0-4 (-> obj process-list v1-4))) + (set! (-> a0-4 process) *null-process*) + (set! (-> a0-4 next) (-> obj process-list (+ v1-4 1))) ) ) (set! diff --git a/test/decompiler/test_gkernel_decomp.cpp b/test/decompiler/test_gkernel_decomp.cpp index 0c865e44e5..ff881ff3cb 100644 --- a/test/decompiler/test_gkernel_decomp.cpp +++ b/test/decompiler/test_gkernel_decomp.cpp @@ -1263,19 +1263,16 @@ TEST_F(FormRegressionTest, ExprMethod0DeadPoolHeap) { " (set! (-> obj child) (the-as (pointer process-tree) #f))\n" " (set! (-> obj self) obj)\n" " (set! (-> obj ppointer) (&-> obj self))\n" - " (let\n" - " ((v1-4 arg3))\n" - " (while\n" - " (nonzero? v1-4)\n" - " (+! v1-4 -1)\n" - " (let\n" - " ((a0-4 (-> obj process-list v1-4)))\n" - " (set! (-> a0-4 process) *null-process*)\n" - " (set! (-> a0-4 next) (-> obj process-list (+ v1-4 1)))\n" - " )\n" + " (countdown (v1-4 arg3)\n" + " (let ((a0-4 (-> obj process-list v1-4)))\n" + " (set! (-> a0-4 process) *null-process*)\n" + " (set! (-> a0-4 next) (-> obj process-list (+ v1-4 1)))\n" " )\n" " )\n" - " (set! (-> obj dead-list next) (the-as dead-pool-heap-rec (-> obj process-list)))\n" + " (set!\n" + " (-> obj dead-list next)\n" + " (the-as dead-pool-heap-rec (-> obj process-list))\n" + " )\n" " (set! (-> obj alive-list process) #f)\n" " (set! (-> obj process-list (+ arg3 -1) next) #f)\n" " (set! (-> obj alive-list prev) (-> obj alive-list))\n"