fix: type inference bugs

This commit is contained in:
Shunsuke Shibayama 2024-12-29 14:06:12 +09:00
parent faec281fa9
commit 9cd1846216
8 changed files with 83 additions and 24 deletions

20
Cargo.lock generated
View File

@ -150,9 +150,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]] [[package]]
name = "els" name = "els"
version = "0.1.62" version = "0.1.63-nightly.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98d4bb4cbce0f519100ba0d6aa5541aa69aa00b0f61bdb862c81124f9cf38cea" checksum = "99d2ae54e13d256ec9a09554249ccc498c31022836c76bc5dc1b11b072f0a753"
dependencies = [ dependencies = [
"erg_common", "erg_common",
"erg_compiler", "erg_compiler",
@ -166,9 +166,9 @@ dependencies = [
[[package]] [[package]]
name = "erg_common" name = "erg_common"
version = "0.6.50" version = "0.6.51-nightly.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15fcd8b1d8d47238d1488f7a05a8131b77b89adb54c867327b83db272a919344" checksum = "abc654d256215d0d1b2ccb5d1e80c9188148dd59714d0a642f5738bf0271197a"
dependencies = [ dependencies = [
"backtrace-on-stack-overflow", "backtrace-on-stack-overflow",
"erg_proc_macros", "erg_proc_macros",
@ -179,9 +179,9 @@ dependencies = [
[[package]] [[package]]
name = "erg_compiler" name = "erg_compiler"
version = "0.6.50" version = "0.6.51-nightly.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5b708d63c430435aac418e0822ef435b6b26c5338e5795130cce308f319505" checksum = "223b817462901cfef987f38c21a18f9637f7f796dbe58a3b51a1e09489fd25be"
dependencies = [ dependencies = [
"erg_common", "erg_common",
"erg_parser", "erg_parser",
@ -189,9 +189,9 @@ dependencies = [
[[package]] [[package]]
name = "erg_parser" name = "erg_parser"
version = "0.6.50" version = "0.6.51-nightly.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b79c7b5789c93deeeb21cbe9d4e0c62db60712a187a1c260210079d3750b4be" checksum = "13d9eaf0b3076b05cb7290be98de5c1cbd73ab7a72b090f67d5911e03d062501"
dependencies = [ dependencies = [
"erg_common", "erg_common",
"erg_proc_macros", "erg_proc_macros",
@ -200,9 +200,9 @@ dependencies = [
[[package]] [[package]]
name = "erg_proc_macros" name = "erg_proc_macros"
version = "0.6.50" version = "0.6.51-nightly.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99659bb992c4e9da4af751d63fa126034025ffe79e07bfdbf013d9e165e768fc" checksum = "e3d2e883cddea276b76add108da8d16276d9d6962a0079561515e701121a3e1a"
dependencies = [ dependencies = [
"quote", "quote",
"syn 1.0.109", "syn 1.0.109",

View File

@ -24,9 +24,9 @@ edition = "2021"
repository = "https://github.com/mtshiba/pylyzer" repository = "https://github.com/mtshiba/pylyzer"
[workspace.dependencies] [workspace.dependencies]
erg_common = { version = "0.6.50", features = ["py_compat", "els"] } erg_common = { version = "0.6.51-nightly.0", features = ["py_compat", "els"] }
erg_compiler = { version = "0.6.50", features = ["py_compat", "els"] } erg_compiler = { version = "0.6.51-nightly.0", features = ["py_compat", "els"] }
els = { version = "0.1.50", features = ["py_compat"] } els = { version = "0.1.63-nightly.0", features = ["py_compat"] }
# rustpython-parser = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] } # rustpython-parser = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] }
# rustpython-ast = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] } # rustpython-ast = { version = "0.3.0", features = ["all-nodes-with-ranges", "location"] }
rustpython-parser = { git = "https://github.com/RustPython/Parser", version = "0.4.0", features = ["all-nodes-with-ranges", "location"] } rustpython-parser = { git = "https://github.com/RustPython/Parser", version = "0.4.0", features = ["all-nodes-with-ranges", "location"] }

View File

@ -320,6 +320,19 @@ pub struct BlockInfo {
pub kind: BlockKind, pub kind: BlockKind,
} }
#[derive(Debug)]
pub enum ReturnKind {
None,
Return,
Yield,
}
impl ReturnKind {
pub const fn is_none(&self) -> bool {
matches!(self, Self::None)
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct LocalContext { pub struct LocalContext {
pub name: String, pub name: String,
@ -329,6 +342,7 @@ pub struct LocalContext {
type_vars: HashMap<String, TypeVarInfo>, type_vars: HashMap<String, TypeVarInfo>,
// e.g. def id(x: T) -> T: ... => appeared_types = {T} // e.g. def id(x: T) -> T: ... => appeared_types = {T}
appeared_type_names: HashSet<String>, appeared_type_names: HashSet<String>,
return_kind: ReturnKind,
} }
impl LocalContext { impl LocalContext {
@ -339,6 +353,7 @@ impl LocalContext {
names: HashMap::new(), names: HashMap::new(),
type_vars: HashMap::new(), type_vars: HashMap::new(),
appeared_type_names: HashSet::new(), appeared_type_names: HashSet::new(),
return_kind: ReturnKind::None,
} }
} }
} }
@ -667,6 +682,14 @@ impl ASTConverter {
&self.contexts.last().unwrap().name &self.contexts.last().unwrap().name
} }
fn cur_context(&self) -> &LocalContext {
self.contexts.last().unwrap()
}
fn cur_context_mut(&mut self) -> &mut LocalContext {
self.contexts.last_mut().unwrap()
}
fn parent_name(&self) -> &str { fn parent_name(&self) -> &str {
&self.contexts[self.contexts.len().saturating_sub(2)].name &self.contexts[self.contexts.len().saturating_sub(2)].name
} }
@ -2033,6 +2056,11 @@ impl ASTConverter {
); );
Expr::Compound(Compound::new(vec![Expr::Def(def), target])) Expr::Compound(Compound::new(vec![Expr::Def(def), target]))
} }
py_ast::Expr::Yield(_) => {
self.cur_context_mut().return_kind = ReturnKind::Yield;
log!(err "unimplemented: {:?}", expr);
Expr::Dummy(Dummy::new(None, vec![]))
}
_other => { _other => {
log!(err "unimplemented: {:?}", _other); log!(err "unimplemented: {:?}", _other);
Expr::Dummy(Dummy::new(None, vec![])) Expr::Dummy(Dummy::new(None, vec![]))
@ -2087,7 +2115,7 @@ impl ASTConverter {
// self.y = y // self.y = y
// self.z = z // self.z = z
// ↓ // ↓
// requirement : {x: Int, y: Int, z: Never} // requirement : {x: Int, y: Int, z: Any}
// returns : .__call__(x: Int, y: Int, z: Obj): Self = .unreachable() // returns : .__call__(x: Int, y: Int, z: Obj): Self = .unreachable()
fn extract_init(&mut self, base_type: &mut Option<Expr>, init_def: Def) -> Option<Def> { fn extract_init(&mut self, base_type: &mut Option<Expr>, init_def: Def) -> Option<Def> {
self.check_init_sig(&init_def.sig)?; self.check_init_sig(&init_def.sig)?;
@ -2134,9 +2162,8 @@ impl ASTConverter {
} else if let Some(typ) = redef.t_spec.map(|t_spec| t_spec.t_spec_as_expr) { } else if let Some(typ) = redef.t_spec.map(|t_spec| t_spec.t_spec_as_expr) {
*typ *typ
} else { } else {
Expr::from(Accessor::Ident(Identifier::public_with_line( Expr::from(Accessor::Ident(Identifier::private_with_line(
DOT, "Any".into(),
"Never".into(),
attr.obj.ln_begin().unwrap_or(0), attr.obj.ln_begin().unwrap_or(0),
))) )))
}; };
@ -2473,8 +2500,22 @@ impl ASTConverter {
.unwrap_or(func_def.type_params) .unwrap_or(func_def.type_params)
}; };
let bounds = self.get_type_bounds(type_params); let bounds = self.get_type_bounds(type_params);
let sig = Signature::Subr(SubrSignature::new(decos, ident, bounds, params, return_t)); let mut sig =
Signature::Subr(SubrSignature::new(decos, ident, bounds, params, return_t));
let block = self.convert_block(func_def.body, BlockKind::Function); let block = self.convert_block(func_def.body, BlockKind::Function);
if self.cur_context().return_kind.is_none() {
let Signature::Subr(subr) = &mut sig else {
unreachable!()
};
if subr.return_t_spec.is_none() {
let none = TypeSpecWithOp::new(
Token::dummy(TokenKind::Colon, ":"),
TypeSpec::mono(Identifier::private("NoneType".into())),
Expr::static_local("NoneType"),
);
subr.return_t_spec = Some(Box::new(none));
}
}
let body = DefBody::new(EQUAL, block, DefId(0)); let body = DefBody::new(EQUAL, block, DefId(0));
let def = Def::new(sig, body); let def = Def::new(sig, body);
self.pop(); self.pop();
@ -3030,6 +3071,7 @@ impl ASTConverter {
} }
} }
py_ast::Stmt::Return(return_) => { py_ast::Stmt::Return(return_) => {
self.cur_context_mut().return_kind = ReturnKind::Return;
let loc = return_.location(); let loc = return_.location();
let value = return_ let value = return_
.value .value

View File

@ -37,10 +37,11 @@ fn handle_name_error(error: CompileError) -> Option<CompileError> {
|| { || {
main.contains(" is not defined") && { main.contains(" is not defined") && {
let name = StyledStr::destyle(main.trim_end_matches(" is not defined")); let name = StyledStr::destyle(main.trim_end_matches(" is not defined"));
error name == "Any"
.core || error
.get_hint() .core
.is_some_and(|hint| hint.contains(name)) .get_hint()
.is_some_and(|hint| hint.contains(name))
} }
} }
{ {

View File

@ -98,3 +98,10 @@ class MyList(list):
return MyList(lis) return MyList(lis)
else: else:
return None return None
class Implicit:
def __init__(self):
self.foo = False
def set_foo(self):
self.foo = True

4
tests/err/class.py Normal file
View File

@ -0,0 +1,4 @@
class Foo:
def invalid_append(self):
paths: list[str] = []
paths.append(self) # ERR

View File

@ -1,5 +1,5 @@
def imaginary(x): def imaginary(x):
x.imag return x.imag
assert imaginary(1) == 0 assert imaginary(1) == 0
assert imaginary(1.0) <= 0.0 assert imaginary(1.0) <= 0.0
@ -8,7 +8,7 @@ print(imaginary("a")) # ERR
class C: class C:
def method(self, x): return x def method(self, x): return x
def call_method(obj, x): def call_method(obj, x):
obj.method(x) return obj.method(x)
c = C() c = C()
assert call_method(c, 1) == 1 assert call_method(c, 1) == 1

View File

@ -103,6 +103,11 @@ fn exec_class() -> Result<(), String> {
expect("tests/class.py", 0, 6) expect("tests/class.py", 0, 6)
} }
#[test]
fn exec_class_err() -> Result<(), String> {
expect("tests/err/class.py", 0, 1)
}
#[test] #[test]
fn exec_errors() -> Result<(), String> { fn exec_errors() -> Result<(), String> {
expect("tests/errors.py", 0, 3) expect("tests/errors.py", 0, 3)