diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index d916de29e1..5deca032ed 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -3268,14 +3268,14 @@ impl<'db> StaticClassLiteral<'db> { } .with_annotated_type(field_ty); - if matches!(name, "__replace__" | "_replace") { + parameter = if matches!(name, "__replace__" | "_replace") { // When replacing, we know there is a default value for the field // (the value that is currently assigned to the field) // assume this to be the declared type of the field - parameter = parameter.with_default_type(field_ty); - } else if let Some(default_ty) = default_ty { - parameter = parameter.with_default_type(default_ty); - } + parameter.with_default_type(field_ty) + } else { + parameter.with_optional_default_type(default_ty) + }; parameters.push(parameter); } @@ -5321,19 +5321,15 @@ fn synthesize_namedtuple_class_member<'db>( let generic_context = GenericContext::from_typevar_instances(db, variables); - let mut parameters = vec![ - Parameter::positional_or_keyword(Name::new_static("cls")) - .with_annotated_type(SubclassOfType::from(db, self_typevar)), - ]; + let first_parameter = Parameter::positional_or_keyword(Name::new_static("cls")) + .with_annotated_type(SubclassOfType::from(db, self_typevar)); - for (field_name, field_ty, default_ty) in fields { - let mut param = - Parameter::positional_or_keyword(field_name).with_annotated_type(field_ty); - if let Some(default) = default_ty { - param = param.with_default_type(default); - } - parameters.push(param); - } + let parameters = + std::iter::once(first_parameter).chain(fields.map(|(name, ty, default)| { + Parameter::positional_or_keyword(name) + .with_annotated_type(ty) + .with_optional_default_type(default) + })); let signature = Signature::new_generic( Some(generic_context), @@ -5364,18 +5360,14 @@ fn synthesize_namedtuple_class_member<'db>( BindingContext::Synthetic, )); - let mut parameters = vec![ - Parameter::positional_or_keyword(Name::new_static("self")) - .with_annotated_type(self_ty), - ]; + let first_parameter = Parameter::positional_or_keyword(Name::new_static("self")) + .with_annotated_type(self_ty); - for (field_name, field_ty, _) in fields { - parameters.push( - Parameter::keyword_only(field_name) - .with_annotated_type(field_ty) - .with_default_type(field_ty), - ); - } + let parameters = std::iter::once(first_parameter).chain(fields.map(|(name, ty, _)| { + Parameter::keyword_only(name) + .with_annotated_type(ty) + .with_default_type(ty) + })); let signature = Signature::new(Parameters::new(db, parameters), self_ty); Some(Type::function_like_callable(db, signature)) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index af24987081..965466f6a1 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6162,7 +6162,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let known_class = func_ty .as_class_literal() .and_then(|cls| cls.known(self.db())); - if let Some(KnownClass::NewType) = known_class { + if known_class == Some(KnownClass::NewType) { self.infer_newtype_assignment_deferred(arguments); return; } @@ -10536,28 +10536,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .posonlyargs .iter() .map(|param| { - let mut parameter = Parameter::positional_only(Some(param.name().id.clone())); - if let Some(default) = param.default() { - parameter = parameter.with_default_type( - self.infer_expression(default, TypeContext::default()) - .replace_parameter_defaults(self.db()), - ); - } - parameter + Parameter::positional_only(Some(param.name().id.clone())) + .with_optional_default_type(param.default().map(|default_expr| { + self.infer_expression(default_expr, TypeContext::default()) + .replace_parameter_defaults(self.db()) + })) }) .collect::>(); let positional_or_keyword = parameters .args .iter() .map(|param| { - let mut parameter = Parameter::positional_or_keyword(param.name().id.clone()); - if let Some(default) = param.default() { - parameter = parameter.with_default_type( - self.infer_expression(default, TypeContext::default()) - .replace_parameter_defaults(self.db()), - ); - } - parameter + Parameter::positional_or_keyword(param.name().id.clone()) + .with_optional_default_type(param.default().map(|default_expr| { + self.infer_expression(default_expr, TypeContext::default()) + .replace_parameter_defaults(self.db()) + })) }) .collect::>(); let variadic = parameters @@ -10568,14 +10562,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .kwonlyargs .iter() .map(|param| { - let mut parameter = Parameter::keyword_only(param.name().id.clone()); - if let Some(default) = param.default() { - parameter = parameter.with_default_type( - self.infer_expression(default, TypeContext::default()) - .replace_parameter_defaults(self.db()), - ); - } - parameter + Parameter::keyword_only(param.name().id.clone()).with_optional_default_type( + param.default().map(|default_expr| { + self.infer_expression(default_expr, TypeContext::default()) + .replace_parameter_defaults(self.db()) + }), + ) }) .collect::>(); let keyword_variadic = parameters diff --git a/crates/ty_python_semantic/src/types/property_tests/type_generation.rs b/crates/ty_python_semantic/src/types/property_tests/type_generation.rs index 4e789bb94c..cfd5250846 100644 --- a/crates/ty_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/ty_python_semantic/src/types/property_tests/type_generation.rs @@ -80,7 +80,7 @@ impl CallableParams { CallableParams::List(params) => Parameters::new( db, params.into_iter().map(|param| { - let mut parameter = match param.kind { + let parameter = match param.kind { ParamKind::PositionalOnly => Parameter::positional_only(param.name), ParamKind::PositionalOrKeyword => { Parameter::positional_or_keyword(param.name.unwrap()) @@ -91,11 +91,9 @@ impl CallableParams { Parameter::keyword_variadic(param.name.unwrap()) } }; - parameter = parameter.with_annotated_type(param.annotated_ty.into_type(db)); - if let Some(default_ty) = param.default_ty { - parameter = parameter.with_default_type(default_ty.into_type(db)); - } parameter + .with_annotated_type(param.annotated_ty.into_type(db)) + .with_optional_default_type(param.default_ty.map(|t| t.into_type(db))) }), ), } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 5747249c16..75232d3314 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -2297,6 +2297,14 @@ impl<'db> Parameter<'db> { self } + pub(crate) fn with_optional_default_type(self, default: Option>) -> Self { + if let Some(default) = default { + self.with_default_type(default) + } else { + self + } + } + pub(crate) fn type_form(mut self) -> Self { self.form = ParameterForm::Type; self