From 7248c00448eda9e97fa9be5a4601b814a62b800c Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 22 Feb 2025 10:52:14 +0900 Subject: [PATCH] feat: attrs can be registered in methods other than `__init__` --- Cargo.lock | 56 +++++------ Cargo.toml | 6 +- crates/py2erg/convert.rs | 204 +++++++++++++++++++++++++++------------ tests/class.py | 17 ++++ tests/test.rs | 2 +- 5 files changed, 191 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a21fbe..d5eee2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,9 +31,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" [[package]] name = "autocfg" @@ -93,9 +93,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "shlex", ] @@ -152,9 +152,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "els" -version = "0.1.65-nightly.1" +version = "0.1.65-nightly.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93c98062a1591263c7bdb7444ffae5eb7151fb3fc8f09eddeba13163e86706cf" +checksum = "1dd708f21ce87a184ff5c9010afe61169b021a986fa1f0480ce93f7069308b34" dependencies = [ "erg_common", "erg_compiler", @@ -168,9 +168,9 @@ dependencies = [ [[package]] name = "erg_common" -version = "0.6.53-nightly.1" +version = "0.6.53-nightly.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d0fdaad7b1913c304d070a5e4737e7495e2eadca1a67fddca77649995c4d52" +checksum = "fc22c2d3966dfd49dc4e3d142a4c066acfbc2debbf807bd80239640263f82906" dependencies = [ "backtrace-on-stack-overflow", "erg_proc_macros", @@ -181,9 +181,9 @@ dependencies = [ [[package]] name = "erg_compiler" -version = "0.6.53-nightly.1" +version = "0.6.53-nightly.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26ebe561fa98ff75a6a99167e86826f01baca96e0a696f3fcb4af86101f9c309" +checksum = "33a71fd95255b255147476dba4b043babe82350d1b32c3bd2b716b5c801ac455" dependencies = [ "erg_common", "erg_parser", @@ -191,9 +191,9 @@ dependencies = [ [[package]] name = "erg_parser" -version = "0.6.53-nightly.1" +version = "0.6.53-nightly.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c06268f345b640bcb6b3f0bf2390b9da072b05b72c697f51502e47e11db667" +checksum = "283375be7368b88ab745eeb2530c4a276325d36d41f664935c6521cd2afe4043" dependencies = [ "erg_common", "erg_proc_macros", @@ -202,9 +202,9 @@ dependencies = [ [[package]] name = "erg_proc_macros" -version = "0.6.53-nightly.1" +version = "0.6.53-nightly.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b6f03ae04101a0079ac52948edd33a08b3349d78a5cf632bddc22f020bd52e1" +checksum = "222a96d889f1f49149a1d20e634096a82fc575dfe49661d8e713ebcfb152ef71" dependencies = [ "quote", "syn 1.0.109", @@ -468,9 +468,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.25" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "lsp-types" @@ -560,9 +560,9 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", ] @@ -798,9 +798,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f" dependencies = [ "bitflags 2.8.0", ] @@ -890,18 +890,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -910,9 +910,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" dependencies = [ "itoa", "memchr", @@ -1077,9 +1077,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" [[package]] name = "unicode-width" diff --git a/Cargo.toml b/Cargo.toml index 6b48f61..a404772 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,9 +24,9 @@ edition = "2021" repository = "https://github.com/mtshiba/pylyzer" [workspace.dependencies] -erg_common = { version = "0.6.53-nightly.1", features = ["py_compat", "els"] } -erg_compiler = { version = "0.6.53-nightly.1", features = ["py_compat", "els"] } -els = { version = "0.1.65-nightly.1", features = ["py_compat"] } +erg_common = { version = "0.6.53-nightly.2", features = ["py_compat", "els"] } +erg_compiler = { version = "0.6.53-nightly.2", features = ["py_compat", "els"] } +els = { version = "0.1.65-nightly.2", features = ["py_compat"] } # rustpython-parser = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] } # rustpython-ast = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] } rustpython-parser = { git = "https://github.com/RustPython/Parser", version = "0.4.0", features = ["all-nodes-with-ranges", "location"] } diff --git a/crates/py2erg/convert.rs b/crates/py2erg/convert.rs index 4923ba3..6e0d9c4 100644 --- a/crates/py2erg/convert.rs +++ b/crates/py2erg/convert.rs @@ -6,7 +6,7 @@ use erg_common::dict::Dict as HashMap; use erg_common::error::Location as ErgLocation; use erg_common::fresh::FRESH_GEN; use erg_common::set::Set as HashSet; -use erg_common::traits::{Locational, Stream}; +use erg_common::traits::{Locational, Stream, Traversable}; use erg_common::{fmt_vec, log, set}; use erg_compiler::artifact::IncompleteArtifact; use erg_compiler::erg_parser::ast::{ @@ -2165,32 +2165,59 @@ impl ASTConverter { Block::new(new_block) } - fn check_init_sig(&mut self, sig: &Signature) -> Option<()> { + #[allow(clippy::result_large_err)] + fn check_init_sig(&self, sig: &Signature) -> Result<(), (bool, CompileError)> { match sig { Signature::Subr(subr) => { if let Some(first) = subr.params.non_defaults.first() { if first.inspect().map(|s| &s[..]) == Some("self") { - return Some(()); + return Ok(()); } } - self.errs.push(self_not_found_error( - self.cfg.input.clone(), - subr.loc(), - self.cur_namespace(), - )); - Some(()) - } - Signature::Var(var) => { - self.errs.push(init_var_error( - self.cfg.input.clone(), - var.loc(), - self.cur_namespace(), - )); - None + Err(( + true, + self_not_found_error(self.cfg.input.clone(), subr.loc(), self.cur_namespace()), + )) } + Signature::Var(var) => Err(( + false, + init_var_error(self.cfg.input.clone(), var.loc(), self.cur_namespace()), + )), } } + // ```python + // def __init__(self, x: Int, y: Int, z): + // self.x = x + // self.y = y + // if True: + // self.set_z(z) + // + // def set_z(self, z): + // self.z = z + // ``` + // ↓ + // methods: {"set_z"} + #[allow(clippy::only_used_in_recursion)] + fn collect_called_methods(&self, expr: &Expr, methods: &mut HashSet) { + expr.traverse(&mut |ex| { + if let Expr::Call(call) = ex { + match call.obj.as_ref() { + Expr::Accessor(Accessor::Ident(ident)) if ident.inspect() == "self" => { + if let Some(method_ident) = &call.attr_name { + methods.insert(method_ident.inspect().to_string()); + } else { + self.collect_called_methods(&call.obj, methods); + } + } + _ => self.collect_called_methods(ex, methods), + } + } else { + self.collect_called_methods(ex, methods); + } + }); + } + // def __init__(self, x: Int, y: Int, z): // self.x = x // self.y = y @@ -2198,8 +2225,21 @@ impl ASTConverter { // ↓ // requirement : {x: Int, y: Int, z: Any} // returns : .__call__(x: Int, y: Int, z: Obj): Self = .unreachable() - fn extract_init(&mut self, base_type: &mut Option, init_def: Def) -> Option { - self.check_init_sig(&init_def.sig)?; + fn extract_init( + &mut self, + base_type: &mut Option, + init_def: Def, + attrs: &[ClassAttr], + pre_check: bool, + ) -> Option { + if let Err((continuable, err)) = self.check_init_sig(&init_def.sig) { + if pre_check { + self.errs.push(err); + } + if !continuable { + return None; + } + } let l_brace = Token::new( TokenKind::LBrace, "{", @@ -2212,50 +2252,21 @@ impl ASTConverter { init_def.ln_end().unwrap_or(0), init_def.col_end().unwrap_or(0), ); + let expr = Expr::Def(init_def.clone()); let Signature::Subr(sig) = init_def.sig else { unreachable!() }; let mut fields = vec![]; - for chunk in init_def.body.block { - #[allow(clippy::single_match)] - match chunk { - Expr::ReDef(redef) => { - let Accessor::Attr(attr) = redef.attr else { - continue; - }; - // if `self.foo == ...` - if attr.obj.get_name().map(|s| &s[..]) == Some("self") { - // get attribute types - let typ = if let Some(t_spec_op) = sig - .params - .non_defaults - .iter() - .find(|¶m| param.inspect() == Some(attr.ident.inspect())) - .and_then(|param| param.t_spec.as_ref()) - .or_else(|| { - sig.params - .defaults - .iter() - .find(|¶m| param.inspect() == Some(attr.ident.inspect())) - .and_then(|param| param.sig.t_spec.as_ref()) - }) { - *t_spec_op.t_spec_as_expr.clone() - } else if let Some(typ) = redef.t_spec.map(|t_spec| t_spec.t_spec_as_expr) { - *typ - } else { - Expr::from(Accessor::Ident(Identifier::private_with_line( - "Any".into(), - attr.obj.ln_begin().unwrap_or(0), - ))) - }; - let sig = - Signature::Var(VarSignature::new(VarPattern::Ident(attr.ident), None)); - let body = DefBody::new(EQUAL, Block::new(vec![typ]), DefId(0)); - let field_type_def = Def::new(sig, body); - fields.push(field_type_def); + self.extract_instance_attrs(&sig, &init_def.body.block, &mut fields); + let mut method_names = HashSet::new(); + self.collect_called_methods(&expr, &mut method_names); + for class_attr in attrs { + if let ClassAttr::Def(def) = class_attr { + if let Signature::Subr(sig) = &def.sig { + if method_names.contains(&sig.ident.inspect()[..]) { + self.extract_instance_attrs(sig, &def.body.block, &mut fields); } } - _ => {} } } if let Some(Expr::Record(Record::Normal(rec))) = base_type.as_mut() { @@ -2312,6 +2323,55 @@ impl ASTConverter { Some(def) } + fn extract_instance_attrs(&self, sig: &SubrSignature, block: &Block, fields: &mut Vec) { + for chunk in block.iter() { + if let Expr::ReDef(redef) = chunk { + let Accessor::Attr(attr) = &redef.attr else { + continue; + }; + // if `self.foo == ...` + if attr.obj.get_name().map(|s| &s[..]) == Some("self") + && fields + .iter() + .all(|field| field.sig.ident().unwrap() != &attr.ident) + { + // get attribute types + let typ = if let Some(t_spec_op) = sig + .params + .non_defaults + .iter() + .find(|¶m| param.inspect() == Some(attr.ident.inspect())) + .and_then(|param| param.t_spec.as_ref()) + .or_else(|| { + sig.params + .defaults + .iter() + .find(|¶m| param.inspect() == Some(attr.ident.inspect())) + .and_then(|param| param.sig.t_spec.as_ref()) + }) { + *t_spec_op.t_spec_as_expr.clone() + } else if let Some(typ) = + redef.t_spec.clone().map(|t_spec| t_spec.t_spec_as_expr) + { + *typ + } else { + Expr::from(Accessor::Ident(Identifier::private_with_line( + "Any".into(), + attr.obj.ln_begin().unwrap_or(0), + ))) + }; + let sig = Signature::Var(VarSignature::new( + VarPattern::Ident(attr.ident.clone()), + None, + )); + let body = DefBody::new(EQUAL, Block::new(vec![typ]), DefId(0)); + let field_type_def = Def::new(sig, body); + fields.push(field_type_def); + } + } + } + } + fn gen_default_init(&self, line: usize) -> Def { let call_ident = Identifier::new( VisModifierSpec::Public(ErgLocation::Unknown), @@ -2346,7 +2406,6 @@ impl ASTConverter { ) -> (Option, ClassAttrs) { let mut base_type = None; let mut attrs = vec![]; - let mut init_is_defined = false; let mut call_params_len = None; for stmt in body { match self.convert_statement(stmt, true) { @@ -2402,12 +2461,15 @@ impl ASTConverter { .ident() .is_some_and(|id| &id.inspect()[..] == "__init__") { - if let Some(call_def) = self.extract_init(&mut base_type, def) { + // We will generate `__init__` and determine the shape of the class at the end + // Here we just extract the signature + if let Some(call_def) = + self.extract_init(&mut base_type, def.clone(), &attrs, true) + { if let Some(params) = call_def.sig.params() { call_params_len = Some(params.len()); } - attrs.insert(0, ClassAttr::Def(call_def)); - init_is_defined = true; + attrs.push(ClassAttr::Def(def)); } } else { attrs.push(ClassAttr::Def(def)); @@ -2457,7 +2519,25 @@ impl ASTConverter { _other => {} // TODO: } } - if !init_is_defined && !inherit { + let mut init = None; + for (i, attr) in attrs.iter().enumerate() { + if let ClassAttr::Def(def) = attr { + if def + .sig + .ident() + .is_some_and(|id| &id.inspect()[..] == "__init__") + { + if let Some(def) = self.extract_init(&mut base_type, def.clone(), &attrs, false) + { + init = Some((i, def)); + } + } + } + } + if let Some((i, def)) = init { + attrs.remove(i); + attrs.insert(0, ClassAttr::Def(def)); + } else if !inherit { attrs.insert(0, ClassAttr::Def(self.gen_default_init(0))); } (base_type, ClassAttrs::new(attrs)) diff --git a/tests/class.py b/tests/class.py index 86b867a..0ee61b1 100644 --- a/tests/class.py +++ b/tests/class.py @@ -120,3 +120,20 @@ class Cs: self.cs.append(c) self.cs2.append(c) self.cs_list.append([c]) + +class I: + def __init__(self): + self.ix: int = 1 + if True: + self.init_y() + + def init_y(self): + self.iy: int = 2 + + def foo(self): + self.iz: int = 1 # ERR + +i = I() +_ = i.ix +_ = i.iy # OK +_ = i.iz # ERR diff --git a/tests/test.rs b/tests/test.rs index b609be3..bea60a3 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -100,7 +100,7 @@ fn exec_func() -> Result<(), String> { #[test] fn exec_class() -> Result<(), String> { - expect("tests/class.py", 0, 6) + expect("tests/class.py", 0, 8) } #[test]