diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 4b3bf4af42..9f805edf6f 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1426,13 +1426,14 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_string_literal_expression(&mut self, literal: &ast::ExprStringLiteral) -> Type<'db> { - let value = if literal.value.len() <= Self::MAX_STRING_LITERAL_SIZE { - literal.value.to_str().into() + if literal.value.len() <= Self::MAX_STRING_LITERAL_SIZE { + Type::StringLiteral(StringLiteralType::new( + self.db, + literal.value.to_str().into(), + )) } else { - Box::default() - }; - - Type::StringLiteral(StringLiteralType::new(self.db, value)) + Type::LiteralString + } } fn infer_bytes_literal_expression(&mut self, literal: &ast::ExprBytesLiteral) -> Type<'db> { @@ -2041,13 +2042,23 @@ impl<'db> TypeInferenceBuilder<'db> { } (Type::StringLiteral(lhs), Type::StringLiteral(rhs), ast::Operator::Add) => { - Type::StringLiteral(StringLiteralType::new(self.db, { - let lhs_value = lhs.value(self.db).to_string(); - let rhs_value = rhs.value(self.db).as_ref(); - (lhs_value + rhs_value).into() - })) + let lhs_value = lhs.value(self.db).to_string(); + let rhs_value = rhs.value(self.db).as_ref(); + if lhs_value.len() + rhs_value.len() <= Self::MAX_STRING_LITERAL_SIZE { + Type::StringLiteral(StringLiteralType::new(self.db, { + (lhs_value + rhs_value).into() + })) + } else { + Type::LiteralString + } } + ( + Type::StringLiteral(_) | Type::LiteralString, + Type::StringLiteral(_) | Type::LiteralString, + ast::Operator::Add, + ) => Type::LiteralString, + (Type::StringLiteral(s), Type::IntLiteral(n), ast::Operator::Mult) | (Type::IntLiteral(n), Type::StringLiteral(s), ast::Operator::Mult) => { if n < 1 { @@ -2066,6 +2077,15 @@ impl<'db> TypeInferenceBuilder<'db> { } } + (Type::LiteralString, Type::IntLiteral(n), ast::Operator::Mult) + | (Type::IntLiteral(n), Type::LiteralString, ast::Operator::Mult) => { + if n < 1 { + Type::StringLiteral(StringLiteralType::new(self.db, Box::default())) + } else { + Type::LiteralString + } + } + _ => Type::Unknown, // TODO } } @@ -2892,6 +2912,70 @@ mod tests { Ok(()) } + #[test] + fn multiplied_literal_string() -> anyhow::Result<()> { + let mut db = setup_db(); + let content = format!( + r#" + v = "{y}" + w = 10*"{y}" + x = "{y}"*10 + z = 0*"{y}" + u = (-100)*"{y}" + "#, + y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), + ); + db.write_dedented("src/a.py", &content)?; + + assert_public_ty(&db, "src/a.py", "v", "LiteralString"); + assert_public_ty(&db, "src/a.py", "w", "LiteralString"); + assert_public_ty(&db, "src/a.py", "x", "LiteralString"); + assert_public_ty(&db, "src/a.py", "z", r#"Literal[""]"#); + assert_public_ty(&db, "src/a.py", "u", r#"Literal[""]"#); + Ok(()) + } + + #[test] + fn truncated_string_literals_become_literal_string() -> anyhow::Result<()> { + let mut db = setup_db(); + let content = format!( + r#" + w = "{y}" + x = "a" + "{z}" + "#, + y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), + z = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE), + ); + db.write_dedented("src/a.py", &content)?; + + assert_public_ty(&db, "src/a.py", "w", "LiteralString"); + assert_public_ty(&db, "src/a.py", "x", "LiteralString"); + + Ok(()) + } + + #[test] + fn adding_string_literals_and_literal_string() -> anyhow::Result<()> { + let mut db = setup_db(); + let content = format!( + r#" + v = "{y}" + w = "{y}" + "a" + x = "a" + "{y}" + z = "{y}" + "{y}" + "#, + y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), + ); + db.write_dedented("src/a.py", &content)?; + + assert_public_ty(&db, "src/a.py", "v", "LiteralString"); + assert_public_ty(&db, "src/a.py", "w", "LiteralString"); + assert_public_ty(&db, "src/a.py", "x", "LiteralString"); + assert_public_ty(&db, "src/a.py", "z", "LiteralString"); + + Ok(()) + } + #[test] fn bytes_type() -> anyhow::Result<()> { let mut db = setup_db();