feat: support arguments expansion

This commit is contained in:
Shunsuke Shibayama 2024-08-20 02:53:19 +09:00
parent d448aaf974
commit f5503d6f9e
3 changed files with 67 additions and 24 deletions

View File

@ -6,7 +6,7 @@ use erg_common::dict::Dict as HashMap;
use erg_common::fresh::FRESH_GEN;
use erg_common::set::Set as HashSet;
use erg_common::traits::{Locational, Stream};
use erg_common::{log, set};
use erg_common::{fmt_vec, log, set};
use erg_compiler::artifact::IncompleteArtifact;
use erg_compiler::erg_parser::ast::{
Accessor, Args, BinOp, Block, ClassAttr, ClassAttrs, ClassDef, ConstAccessor, ConstArgs,
@ -228,9 +228,20 @@ pub struct TypeVarInfo {
impl fmt::Display for TypeVarInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(bound) = &self.bound {
write!(f, "TypeVarInfo({} bound={})", self.name, bound)
write!(
f,
"TypeVarInfo({}, [{}], bound={})",
self.name,
fmt_vec(&self.constraints),
bound
)
} else {
write!(f, "TypeVarInfo({})", self.name)
write!(
f,
"TypeVarInfo({}, [{}])",
self.name,
fmt_vec(&self.constraints)
)
}
}
}
@ -1156,21 +1167,40 @@ impl ASTConverter {
py_ast::Expr::Call(call) => {
let loc = call.location();
let function = self.convert_expr(*call.func);
let pos_args = call
let (pos_args, var_args): (Vec<_>, _) = call
.args
.into_iter()
.partition(|arg| !arg.is_starred_expr());
let pos_args = pos_args
.into_iter()
.map(|ex| PosArg::new(self.convert_expr(ex)))
.collect::<Vec<_>>();
let kw_args = call
.keywords
let var_args = var_args
.into_iter()
.map(|ex| {
let py_ast::Expr::Starred(star) = ex else {
unreachable!()
};
PosArg::new(self.convert_expr(*star.value))
})
.next();
let (kw_args, kw_var): (Vec<_>, _) =
call.keywords.into_iter().partition(|kw| kw.arg.is_some());
let kw_args = kw_args
.into_iter()
.map(|Keyword { arg, value, range }| {
let name = arg.unwrap_or(rustpython_ast::Identifier::new("_"));
let name = Token::symbol_with_loc(name.to_string(), pyloc_to_ergloc(range));
let name = Token::symbol_with_loc(
arg.unwrap().to_string(),
pyloc_to_ergloc(range),
);
let ex = self.convert_expr(value);
KwArg::new(name, None, ex)
})
.collect::<Vec<_>>();
.collect();
let kw_var = kw_var
.into_iter()
.map(|Keyword { value, .. }| PosArg::new(self.convert_expr(value)))
.next();
let last_col = pos_args
.last()
.and_then(|last| last.col_end())
@ -1185,7 +1215,7 @@ impl ASTConverter {
let rp = Token::new(TokenKind::RParen, ")", loc.row.get(), last_col);
(lp, rp)
};
let args = Args::new(pos_args, None, kw_args, None, Some(paren));
let args = Args::new(pos_args, var_args, kw_args, kw_var, Some(paren));
function.call_expr(args)
}
py_ast::Expr::BinOp(bin) => {
@ -1667,7 +1697,6 @@ impl ASTConverter {
fn get_type_bounds(&mut self, type_params: Vec<TypeParam>) -> TypeBoundSpecs {
let mut bounds = TypeBoundSpecs::empty();
let mut errs = vec![];
if type_params.is_empty() {
for ty in self.cur_appeared_type_names() {
let name = VarName::from_str(ty.clone().into());
@ -1679,17 +1708,6 @@ impl ASTConverter {
let spec = TypeSpecWithOp::new(op, t_spec, bound.clone());
TypeBoundSpec::non_default(name, spec)
} else if !tv_info.constraints.is_empty() {
if tv_info.constraints.len() == 1 {
let err = CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
pyloc_to_ergloc(type_params[0].range()),
self.cur_namespace(),
"TypeVar must have at least two constrained types".into(),
None,
);
errs.push(err);
}
let op = Token::dummy(TokenKind::Colon, ":");
let mut elems = vec![];
for constraint in tv_info.constraints.iter() {
@ -1738,7 +1756,6 @@ impl ASTConverter {
};
bounds.push(spec);
}
self.errs.extend(errs);
bounds
}
@ -1942,6 +1959,18 @@ impl ASTConverter {
constraints.push(constr.clone());
nth += 1;
}
if constraints.len() == 1 {
let err = CompileError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
call.args.get_nth(1).unwrap().loc(),
self.cur_namespace(),
"TypeVar must have at least two constrained types"
.into(),
None,
);
self.errs.push(err);
}
let bound = call.args.get_with_key("bound").cloned();
let info = TypeVarInfo::new(arg, constraints, bound);
self.define_type_var(name.id.to_string(), info);

View File

@ -8,3 +8,17 @@ def f(x, y=1):
print(f(1, 2)) # OK
print(f(1)) # OK
print(f(1, y="a")) # ERR
def g(first, second):
pass
g(**{"first": "bar", "second": 1}) # OK
g(**[1, 2]) # ERR
g(1, *[2]) # OK
g(*[1, 2]) # OK
g(1, 2, *[3, 4]) # ERR
g(*1) # ERR
g(*[1], **{"second": 1}) # OK
_ = f(1, *[2]) # OK
_ = f(**{"x": 1, "y": 2}) # OK

View File

@ -129,7 +129,7 @@ fn exec_collection() -> Result<(), String> {
#[test]
fn exec_call() -> Result<(), String> {
expect("tests/call.py", 0, 3)
expect("tests/call.py", 0, 6)
}
#[test]