fix: type specification parsing bugs

This commit is contained in:
Shunsuke Shibayama 2024-12-28 18:43:07 +09:00
parent 0da66dceca
commit faec281fa9
3 changed files with 94 additions and 28 deletions

View File

@ -1052,8 +1052,7 @@ impl ASTConverter {
)) ))
} }
#[allow(clippy::only_used_in_recursion)] fn convert_const_expr(&mut self, expr: ConstExpr) -> ConstExpr {
fn convert_const_expr(&self, expr: ConstExpr) -> ConstExpr {
match expr { match expr {
ConstExpr::UnaryOp(un) if un.op.is(TokenKind::Mutate) => *un.expr, ConstExpr::UnaryOp(un) if un.op.is(TokenKind::Mutate) => *un.expr,
ConstExpr::App(app) ConstExpr::App(app)
@ -1150,12 +1149,24 @@ impl ASTConverter {
ConstExpr::Set(ConstSet::Normal(set)) ConstExpr::Set(ConstSet::Normal(set))
} }
Some("Callable") => { Some("Callable") => {
let params = match args.pos_args.remove(0).expr { let params = if args.pos_args.is_empty() {
self.errs.push(CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
self.cur_namespace(),
"`Callable` takes an input type list and a return type".into(),
None,
));
ConstArgs::empty()
} else {
match args.pos_args.remove(0).expr {
ConstExpr::List(ConstList::Normal(list)) => list.elems, ConstExpr::List(ConstList::Normal(list)) => list.elems,
other => { other => {
args.pos_args.insert(0, ConstPosArg::new(other)); args.pos_args.insert(0, ConstPosArg::new(other));
args.clone() args.clone()
} }
}
}; };
let non_defaults = params let non_defaults = params
.pos_args .pos_args
@ -1184,13 +1195,25 @@ impl ASTConverter {
}) })
.collect(); .collect();
let params = Params::new(non_defaults, None, vec![], None, None); let params = Params::new(non_defaults, None, vec![], None, None);
let ret = match args.pos_args.remove(0).expr { let ret = if args.pos_args.is_empty() {
self.errs.push(CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
self.cur_namespace(),
"Expected a return type".into(),
None,
));
ConstExpr::Accessor(ConstAccessor::Local(Identifier::private("Any".into())))
} else {
match args.pos_args.remove(0).expr {
ConstExpr::Lit(lit) if lit.is(TokenKind::NoneLit) => { ConstExpr::Lit(lit) if lit.is(TokenKind::NoneLit) => {
ConstExpr::Accessor(ConstAccessor::Local(Identifier::private( ConstExpr::Accessor(ConstAccessor::Local(Identifier::private(
"NoneType".into(), "NoneType".into(),
))) )))
} }
other => other, other => other,
}
}; };
let op = Token::dummy(TokenKind::ProcArrow, "=>"); let op = Token::dummy(TokenKind::ProcArrow, "=>");
let body = ConstBlock::new(vec![ret]); let body = ConstBlock::new(vec![ret]);
@ -1234,6 +1257,9 @@ impl ASTConverter {
return Self::gen_dummy_type_spec(args.location()); return Self::gen_dummy_type_spec(args.location());
}; };
let lhs = self.convert_type_spec(tuple.elts.remove(0)); let lhs = self.convert_type_spec(tuple.elts.remove(0));
if tuple.elts.is_empty() {
return lhs;
}
let rhs = self.convert_type_spec(tuple.elts.remove(0)); let rhs = self.convert_type_spec(tuple.elts.remove(0));
let mut union = TypeSpec::or(lhs, rhs); let mut union = TypeSpec::or(lhs, rhs);
for elem in tuple.elts { for elem in tuple.elts {
@ -1264,8 +1290,20 @@ impl ASTConverter {
} }
// TODO: distinguish from collections.abc.Callable // TODO: distinguish from collections.abc.Callable
"Callable" => { "Callable" => {
let py_ast::Expr::Tuple(mut tuple) = args else { let mut tuple = match args {
py_ast::Expr::Tuple(tuple) if tuple.elts.len() == 2 => tuple,
_ => {
let err = CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
pyloc_to_ergloc(args.range()),
self.cur_namespace(),
"`Callable` takes an input type list and a return type".into(),
None,
);
self.errs.push(err);
return Self::gen_dummy_type_spec(args.location()); return Self::gen_dummy_type_spec(args.location());
}
}; };
let params = tuple.elts.remove(0); let params = tuple.elts.remove(0);
let mut non_defaults = vec![]; let mut non_defaults = vec![];
@ -1318,17 +1356,20 @@ impl ASTConverter {
TypeSpec::poly(acc, ConstArgs::pos_only(vec![elem_t], None)) TypeSpec::poly(acc, ConstArgs::pos_only(vec![elem_t], None))
} }
"Mapping" | "MutableMapping" => { "Mapping" | "MutableMapping" => {
let py_ast::Expr::Tuple(mut tuple) = args else { let mut tuple = match args {
py_ast::Expr::Tuple(tuple) if tuple.elts.len() == 2 => tuple,
_ => {
let err = CompileError::syntax_error( let err = CompileError::syntax_error(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
pyloc_to_ergloc(args.range()), pyloc_to_ergloc(args.range()),
self.cur_namespace(), self.cur_namespace(),
format!("`{name}` takes 2 types"), "`Mapping` takes 2 types".into(),
None, None,
); );
self.errs.push(err); self.errs.push(err);
return Self::gen_dummy_type_spec(args.location()); return Self::gen_dummy_type_spec(args.location());
}
}; };
let key_t = match self.convert_expr_to_const(tuple.elts.remove(0)) { let key_t = match self.convert_expr_to_const(tuple.elts.remove(0)) {
Some(key_t) => key_t, Some(key_t) => key_t,
@ -1377,8 +1418,20 @@ impl ASTConverter {
) )
} }
"Dict" | "dict" => { "Dict" | "dict" => {
let py_ast::Expr::Tuple(mut tuple) = args else { let mut tuple = match args {
py_ast::Expr::Tuple(tuple) if tuple.elts.len() == 2 => tuple,
_ => {
let err = CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
pyloc_to_ergloc(args.range()),
self.cur_namespace(),
"`dict` takes 2 types".into(),
None,
);
self.errs.push(err);
return Self::gen_dummy_type_spec(args.location()); return Self::gen_dummy_type_spec(args.location());
}
}; };
let (l_brace, r_brace) = Self::gen_enclosure_tokens(TokenKind::LBrace, tuple.range); let (l_brace, r_brace) = Self::gen_enclosure_tokens(TokenKind::LBrace, tuple.range);
let key_t = match self.convert_expr_to_const(tuple.elts.remove(0)) { let key_t = match self.convert_expr_to_const(tuple.elts.remove(0)) {

8
tests/err/type_spec.py Normal file
View File

@ -0,0 +1,8 @@
from typing import Callable, Mapping
_: Mapping[int, str, str] = ... # ERR
_: Mapping[int] = ... # ERR
_: Callable[[int, str]] = ... # ERR
_: Callable[int] = ... # ERR
_: dict[int] = ... # ERR
_: dict[int, int, int] = ... # ERR

View File

@ -173,6 +173,11 @@ fn exec_typevar() -> Result<(), String> {
expect("tests/typevar.py", 0, 3) expect("tests/typevar.py", 0, 3)
} }
#[test]
fn exec_type_spec() -> Result<(), String> {
expect("tests/err/type_spec.py", 0, 6)
}
#[test] #[test]
fn exec_union() -> Result<(), String> { fn exec_union() -> Result<(), String> {
expect("tests/union.py", 0, 0) expect("tests/union.py", 0, 0)