WIP: Add support for `TypeAlias` and `TypeParam`

This commit is contained in:
Zanie 2023-07-17 17:52:59 -05:00
parent bfaa1f9530
commit e34cfeb475
8 changed files with 202 additions and 2 deletions

View File

@ -915,6 +915,7 @@ pub struct StmtFunctionDef<'a> {
args: ComparableArguments<'a>, args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>, body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>, decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>, returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>, type_comment: Option<&'a str>,
} }
@ -925,6 +926,7 @@ pub struct StmtAsyncFunctionDef<'a> {
args: ComparableArguments<'a>, args: ComparableArguments<'a>,
body: Vec<ComparableStmt<'a>>, body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>, decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
returns: Option<ComparableExpr<'a>>, returns: Option<ComparableExpr<'a>>,
type_comment: Option<&'a str>, type_comment: Option<&'a str>,
} }
@ -936,6 +938,7 @@ pub struct StmtClassDef<'a> {
keywords: Vec<ComparableKeyword<'a>>, keywords: Vec<ComparableKeyword<'a>>,
body: Vec<ComparableStmt<'a>>, body: Vec<ComparableStmt<'a>>,
decorator_list: Vec<ComparableDecorator<'a>>, decorator_list: Vec<ComparableDecorator<'a>>,
type_params: Vec<ComparableTypeParam<'a>>,
} }
#[derive(Debug, PartialEq, Eq, Hash)] #[derive(Debug, PartialEq, Eq, Hash)]
@ -948,6 +951,59 @@ pub struct StmtDelete<'a> {
targets: Vec<ComparableExpr<'a>>, targets: Vec<ComparableExpr<'a>>,
} }
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtTypeAlias<'a> {
pub name: Box<ComparableExpr<'a>>,
pub type_params: Vec<ComparableTypeParam<'a>>,
pub value: Box<ComparableExpr<'a>>,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum ComparableTypeParam<'a> {
TypeVar(TypeParamTypeVar<'a>),
ParamSpec(TypeParamParamSpec<'a>),
TypeVarTuple(TypeParamTypeVarTuple<'a>),
}
impl<'a> From<&'a ast::TypeParam> for ComparableTypeParam<'a> {
fn from(type_param: &'a ast::TypeParam) -> Self {
match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, bound, .. }) => {
Self::TypeVar(TypeParamTypeVar {
name: name.as_str(),
bound: bound.as_ref().map(Into::into),
})
}
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => {
Self::TypeVarTuple(TypeParamTypeVarTuple {
name: name.as_str(),
})
}
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => {
Self::ParamSpec(TypeParamParamSpec {
name: name.as_str(),
})
}
}
}
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVar<'a> {
pub name: &'a str,
pub bound: Option<Box<ComparableExpr<'a>>>,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamParamSpec<'a> {
pub name: &'a str,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TypeParamTypeVarTuple<'a> {
pub name: &'a str,
}
#[derive(Debug, PartialEq, Eq, Hash)] #[derive(Debug, PartialEq, Eq, Hash)]
pub struct StmtAssign<'a> { pub struct StmtAssign<'a> {
targets: Vec<ComparableExpr<'a>>, targets: Vec<ComparableExpr<'a>>,
@ -1097,6 +1153,7 @@ pub enum ComparableStmt<'a> {
Raise(StmtRaise<'a>), Raise(StmtRaise<'a>),
Try(StmtTry<'a>), Try(StmtTry<'a>),
TryStar(StmtTryStar<'a>), TryStar(StmtTryStar<'a>),
TypeAlias(StmtTypeAlias<'a>),
Assert(StmtAssert<'a>), Assert(StmtAssert<'a>),
Import(StmtImport<'a>), Import(StmtImport<'a>),
ImportFrom(StmtImportFrom<'a>), ImportFrom(StmtImportFrom<'a>),
@ -1118,6 +1175,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list, decorator_list,
returns, returns,
type_comment, type_comment,
type_params,
range: _range, range: _range,
}) => Self::FunctionDef(StmtFunctionDef { }) => Self::FunctionDef(StmtFunctionDef {
name: name.as_str(), name: name.as_str(),
@ -1126,6 +1184,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list: decorator_list.iter().map(Into::into).collect(), decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into), returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str), type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}), }),
ast::Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { ast::Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef {
name, name,
@ -1134,6 +1193,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list, decorator_list,
returns, returns,
type_comment, type_comment,
type_params,
range: _range, range: _range,
}) => Self::AsyncFunctionDef(StmtAsyncFunctionDef { }) => Self::AsyncFunctionDef(StmtAsyncFunctionDef {
name: name.as_str(), name: name.as_str(),
@ -1142,6 +1202,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
decorator_list: decorator_list.iter().map(Into::into).collect(), decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into), returns: returns.as_ref().map(Into::into),
type_comment: type_comment.as_ref().map(String::as_str), type_comment: type_comment.as_ref().map(String::as_str),
type_params: type_params.iter().map(Into::into).collect(),
}), }),
ast::Stmt::ClassDef(ast::StmtClassDef { ast::Stmt::ClassDef(ast::StmtClassDef {
name, name,
@ -1149,6 +1210,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
keywords, keywords,
body, body,
decorator_list, decorator_list,
type_params,
range: _range, range: _range,
}) => Self::ClassDef(StmtClassDef { }) => Self::ClassDef(StmtClassDef {
name: name.as_str(), name: name.as_str(),
@ -1156,6 +1218,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
keywords: keywords.iter().map(Into::into).collect(), keywords: keywords.iter().map(Into::into).collect(),
body: body.iter().map(Into::into).collect(), body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(), decorator_list: decorator_list.iter().map(Into::into).collect(),
type_params: type_params.iter().map(Into::into).collect(),
}), }),
ast::Stmt::Return(ast::StmtReturn { ast::Stmt::Return(ast::StmtReturn {
value, value,
@ -1169,6 +1232,16 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
}) => Self::Delete(StmtDelete { }) => Self::Delete(StmtDelete {
targets: targets.iter().map(Into::into).collect(), targets: targets.iter().map(Into::into).collect(),
}), }),
ast::Stmt::TypeAlias(ast::StmtTypeAlias {
range: _range,
name,
type_params,
value,
}) => Self::TypeAlias(StmtTypeAlias {
name: name.into(),
type_params: type_params.iter().map(Into::into).collect(),
value: value.into(),
}),
ast::Stmt::Assign(ast::StmtAssign { ast::Stmt::Assign(ast::StmtAssign {
targets, targets,
value, value,

View File

@ -7,7 +7,7 @@ use ruff_text_size::{TextRange, TextSize};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use rustpython_ast::CmpOp; use rustpython_ast::CmpOp;
use rustpython_parser::ast::{ use rustpython_parser::ast::{
self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt, self, Arguments, Constant, ExceptHandler, Expr, Keyword, MatchCase, Pattern, Ranged, Stmt, TypeParam
}; };
use rustpython_parser::{lexer, Mode, Tok}; use rustpython_parser::{lexer, Mode, Tok};
use smallvec::SmallVec; use smallvec::SmallVec;
@ -265,6 +265,24 @@ where
} }
} }
pub fn any_over_type_param<F>(type_param: &TypeParam, func: &F) -> bool
where
F: Fn(&Expr) -> bool,
{
match type_param {
TypeParam::TypeVar(ast::TypeParamTypeVar { bound, .. }) => {
bound.as_ref().map_or(false, |value| any_over_expr(value, func))
}
TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { .. }) => {
false
}
TypeParam::ParamSpec(ast::TypeParamParamSpec { .. }) => {
false
}
}
}
pub fn any_over_pattern<F>(pattern: &Pattern, func: &F) -> bool pub fn any_over_pattern<F>(pattern: &Pattern, func: &F) -> bool
where where
F: Fn(&Expr) -> bool, F: Fn(&Expr) -> bool,
@ -391,6 +409,11 @@ where
targets, targets,
range: _range, range: _range,
}) => targets.iter().any(|expr| any_over_expr(expr, func)), }) => targets.iter().any(|expr| any_over_expr(expr, func)),
Stmt::TypeAlias(ast::StmtTypeAlias { name, type_params, value, .. }) => {
any_over_expr(name, func)
|| type_params.iter().any(|type_param| any_over_type_param(type_param, func))
|| any_over_expr(value, func)
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func) targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func)
} }

View File

@ -30,6 +30,7 @@ pub enum AnyNode {
StmtClassDef(ast::StmtClassDef), StmtClassDef(ast::StmtClassDef),
StmtReturn(ast::StmtReturn), StmtReturn(ast::StmtReturn),
StmtDelete(ast::StmtDelete), StmtDelete(ast::StmtDelete),
StmtTypeAlias(ast::StmtTypeAlias),
StmtAssign(ast::StmtAssign), StmtAssign(ast::StmtAssign),
StmtAugAssign(ast::StmtAugAssign), StmtAugAssign(ast::StmtAugAssign),
StmtAnnAssign(ast::StmtAnnAssign), StmtAnnAssign(ast::StmtAnnAssign),
@ -108,6 +109,7 @@ impl AnyNode {
AnyNode::StmtClassDef(node) => Some(Stmt::ClassDef(node)), AnyNode::StmtClassDef(node) => Some(Stmt::ClassDef(node)),
AnyNode::StmtReturn(node) => Some(Stmt::Return(node)), AnyNode::StmtReturn(node) => Some(Stmt::Return(node)),
AnyNode::StmtDelete(node) => Some(Stmt::Delete(node)), AnyNode::StmtDelete(node) => Some(Stmt::Delete(node)),
AnyNode::StmtTypeAlias(node) => Some(Stmt::TypeAlias(node)),
AnyNode::StmtAssign(node) => Some(Stmt::Assign(node)), AnyNode::StmtAssign(node) => Some(Stmt::Assign(node)),
AnyNode::StmtAugAssign(node) => Some(Stmt::AugAssign(node)), AnyNode::StmtAugAssign(node) => Some(Stmt::AugAssign(node)),
AnyNode::StmtAnnAssign(node) => Some(Stmt::AnnAssign(node)), AnyNode::StmtAnnAssign(node) => Some(Stmt::AnnAssign(node)),
@ -223,6 +225,7 @@ impl AnyNode {
| AnyNode::StmtClassDef(_) | AnyNode::StmtClassDef(_)
| AnyNode::StmtReturn(_) | AnyNode::StmtReturn(_)
| AnyNode::StmtDelete(_) | AnyNode::StmtDelete(_)
| AnyNode::StmtTypeAlias(_)
| AnyNode::StmtAssign(_) | AnyNode::StmtAssign(_)
| AnyNode::StmtAugAssign(_) | AnyNode::StmtAugAssign(_)
| AnyNode::StmtAnnAssign(_) | AnyNode::StmtAnnAssign(_)
@ -279,6 +282,7 @@ impl AnyNode {
| AnyNode::StmtClassDef(_) | AnyNode::StmtClassDef(_)
| AnyNode::StmtReturn(_) | AnyNode::StmtReturn(_)
| AnyNode::StmtDelete(_) | AnyNode::StmtDelete(_)
| AnyNode::StmtTypeAlias(_)
| AnyNode::StmtAssign(_) | AnyNode::StmtAssign(_)
| AnyNode::StmtAugAssign(_) | AnyNode::StmtAugAssign(_)
| AnyNode::StmtAnnAssign(_) | AnyNode::StmtAnnAssign(_)
@ -370,6 +374,7 @@ impl AnyNode {
| AnyNode::StmtClassDef(_) | AnyNode::StmtClassDef(_)
| AnyNode::StmtReturn(_) | AnyNode::StmtReturn(_)
| AnyNode::StmtDelete(_) | AnyNode::StmtDelete(_)
| AnyNode::StmtTypeAlias(_)
| AnyNode::StmtAssign(_) | AnyNode::StmtAssign(_)
| AnyNode::StmtAugAssign(_) | AnyNode::StmtAugAssign(_)
| AnyNode::StmtAnnAssign(_) | AnyNode::StmtAnnAssign(_)
@ -446,6 +451,7 @@ impl AnyNode {
| AnyNode::StmtClassDef(_) | AnyNode::StmtClassDef(_)
| AnyNode::StmtReturn(_) | AnyNode::StmtReturn(_)
| AnyNode::StmtDelete(_) | AnyNode::StmtDelete(_)
| AnyNode::StmtTypeAlias(_)
| AnyNode::StmtAssign(_) | AnyNode::StmtAssign(_)
| AnyNode::StmtAugAssign(_) | AnyNode::StmtAugAssign(_)
| AnyNode::StmtAnnAssign(_) | AnyNode::StmtAnnAssign(_)
@ -529,6 +535,7 @@ impl AnyNode {
| AnyNode::StmtClassDef(_) | AnyNode::StmtClassDef(_)
| AnyNode::StmtReturn(_) | AnyNode::StmtReturn(_)
| AnyNode::StmtDelete(_) | AnyNode::StmtDelete(_)
| AnyNode::StmtTypeAlias(_)
| AnyNode::StmtAssign(_) | AnyNode::StmtAssign(_)
| AnyNode::StmtAugAssign(_) | AnyNode::StmtAugAssign(_)
| AnyNode::StmtAnnAssign(_) | AnyNode::StmtAnnAssign(_)
@ -634,6 +641,7 @@ impl AnyNode {
Self::StmtClassDef(node) => AnyNodeRef::StmtClassDef(node), Self::StmtClassDef(node) => AnyNodeRef::StmtClassDef(node),
Self::StmtReturn(node) => AnyNodeRef::StmtReturn(node), Self::StmtReturn(node) => AnyNodeRef::StmtReturn(node),
Self::StmtDelete(node) => AnyNodeRef::StmtDelete(node), Self::StmtDelete(node) => AnyNodeRef::StmtDelete(node),
Self::StmtTypeAlias(node) => AnyNodeRef::StmtTypeAlias(node),
Self::StmtAssign(node) => AnyNodeRef::StmtAssign(node), Self::StmtAssign(node) => AnyNodeRef::StmtAssign(node),
Self::StmtAugAssign(node) => AnyNodeRef::StmtAugAssign(node), Self::StmtAugAssign(node) => AnyNodeRef::StmtAugAssign(node),
Self::StmtAnnAssign(node) => AnyNodeRef::StmtAnnAssign(node), Self::StmtAnnAssign(node) => AnyNodeRef::StmtAnnAssign(node),
@ -963,6 +971,34 @@ impl AstNode for ast::StmtDelete {
AnyNode::from(self) AnyNode::from(self)
} }
} }
impl AstNode for ast::StmtTypeAlias {
fn cast(kind: AnyNode) -> Option<Self>
where
Self: Sized,
{
if let AnyNode::StmtTypeAlias(node) = kind {
Some(node)
} else {
None
}
}
fn cast_ref(kind: AnyNodeRef) -> Option<&Self> {
if let AnyNodeRef::StmtTypeAlias(node) = kind {
Some(node)
} else {
None
}
}
fn as_any_node_ref(&self) -> AnyNodeRef {
AnyNodeRef::from(self)
}
fn into_any_node(self) -> AnyNode {
AnyNode::from(self)
}
}
impl AstNode for ast::StmtAssign { impl AstNode for ast::StmtAssign {
fn cast(kind: AnyNode) -> Option<Self> fn cast(kind: AnyNode) -> Option<Self>
where where
@ -2878,6 +2914,7 @@ impl From<Stmt> for AnyNode {
Stmt::ClassDef(node) => AnyNode::StmtClassDef(node), Stmt::ClassDef(node) => AnyNode::StmtClassDef(node),
Stmt::Return(node) => AnyNode::StmtReturn(node), Stmt::Return(node) => AnyNode::StmtReturn(node),
Stmt::Delete(node) => AnyNode::StmtDelete(node), Stmt::Delete(node) => AnyNode::StmtDelete(node),
Stmt::TypeAlias(node) => AnyNode::StmtTypeAlias(node),
Stmt::Assign(node) => AnyNode::StmtAssign(node), Stmt::Assign(node) => AnyNode::StmtAssign(node),
Stmt::AugAssign(node) => AnyNode::StmtAugAssign(node), Stmt::AugAssign(node) => AnyNode::StmtAugAssign(node),
Stmt::AnnAssign(node) => AnyNode::StmtAnnAssign(node), Stmt::AnnAssign(node) => AnyNode::StmtAnnAssign(node),
@ -3034,6 +3071,12 @@ impl From<ast::StmtDelete> for AnyNode {
} }
} }
impl From<ast::StmtTypeAlias> for AnyNode {
fn from(node: ast::StmtTypeAlias) -> Self {
AnyNode::StmtTypeAlias(node)
}
}
impl From<ast::StmtAssign> for AnyNode { impl From<ast::StmtAssign> for AnyNode {
fn from(node: ast::StmtAssign) -> Self { fn from(node: ast::StmtAssign) -> Self {
AnyNode::StmtAssign(node) AnyNode::StmtAssign(node)
@ -3446,6 +3489,7 @@ impl Ranged for AnyNode {
AnyNode::StmtClassDef(node) => node.range(), AnyNode::StmtClassDef(node) => node.range(),
AnyNode::StmtReturn(node) => node.range(), AnyNode::StmtReturn(node) => node.range(),
AnyNode::StmtDelete(node) => node.range(), AnyNode::StmtDelete(node) => node.range(),
AnyNode::StmtTypeAlias(node) => node.range(),
AnyNode::StmtAssign(node) => node.range(), AnyNode::StmtAssign(node) => node.range(),
AnyNode::StmtAugAssign(node) => node.range(), AnyNode::StmtAugAssign(node) => node.range(),
AnyNode::StmtAnnAssign(node) => node.range(), AnyNode::StmtAnnAssign(node) => node.range(),
@ -3529,6 +3573,7 @@ pub enum AnyNodeRef<'a> {
StmtClassDef(&'a ast::StmtClassDef), StmtClassDef(&'a ast::StmtClassDef),
StmtReturn(&'a ast::StmtReturn), StmtReturn(&'a ast::StmtReturn),
StmtDelete(&'a ast::StmtDelete), StmtDelete(&'a ast::StmtDelete),
StmtTypeAlias(&'a ast::StmtTypeAlias),
StmtAssign(&'a ast::StmtAssign), StmtAssign(&'a ast::StmtAssign),
StmtAugAssign(&'a ast::StmtAugAssign), StmtAugAssign(&'a ast::StmtAugAssign),
StmtAnnAssign(&'a ast::StmtAnnAssign), StmtAnnAssign(&'a ast::StmtAnnAssign),
@ -3611,6 +3656,7 @@ impl AnyNodeRef<'_> {
AnyNodeRef::StmtClassDef(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtClassDef(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtReturn(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtReturn(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtDelete(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtDelete(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtTypeAlias(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtAssign(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtAssign(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtAugAssign(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtAugAssign(node) => NonNull::from(*node).cast(),
AnyNodeRef::StmtAnnAssign(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtAnnAssign(node) => NonNull::from(*node).cast(),
@ -3699,6 +3745,7 @@ impl AnyNodeRef<'_> {
AnyNodeRef::StmtClassDef(_) => NodeKind::StmtClassDef, AnyNodeRef::StmtClassDef(_) => NodeKind::StmtClassDef,
AnyNodeRef::StmtReturn(_) => NodeKind::StmtReturn, AnyNodeRef::StmtReturn(_) => NodeKind::StmtReturn,
AnyNodeRef::StmtDelete(_) => NodeKind::StmtDelete, AnyNodeRef::StmtDelete(_) => NodeKind::StmtDelete,
AnyNodeRef::StmtTypeAlias(_) => NodeKind::StmtTypeAlias,
AnyNodeRef::StmtAssign(_) => NodeKind::StmtAssign, AnyNodeRef::StmtAssign(_) => NodeKind::StmtAssign,
AnyNodeRef::StmtAugAssign(_) => NodeKind::StmtAugAssign, AnyNodeRef::StmtAugAssign(_) => NodeKind::StmtAugAssign,
AnyNodeRef::StmtAnnAssign(_) => NodeKind::StmtAnnAssign, AnyNodeRef::StmtAnnAssign(_) => NodeKind::StmtAnnAssign,
@ -3777,6 +3824,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -3892,6 +3940,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -3948,6 +3997,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -4039,6 +4089,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -4115,6 +4166,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -4198,6 +4250,7 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtReturn(_)
| AnyNodeRef::StmtDelete(_) | AnyNodeRef::StmtDelete(_)
| AnyNodeRef::StmtTypeAlias(_)
| AnyNodeRef::StmtAssign(_) | AnyNodeRef::StmtAssign(_)
| AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAugAssign(_)
| AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtAnnAssign(_)
@ -4341,6 +4394,12 @@ impl<'a> From<&'a ast::StmtDelete> for AnyNodeRef<'a> {
} }
} }
impl<'a> From<&'a ast::StmtTypeAlias> for AnyNodeRef<'a> {
fn from(node: &'a ast::StmtTypeAlias) -> Self {
AnyNodeRef::StmtTypeAlias(node)
}
}
impl<'a> From<&'a ast::StmtAssign> for AnyNodeRef<'a> { impl<'a> From<&'a ast::StmtAssign> for AnyNodeRef<'a> {
fn from(node: &'a ast::StmtAssign) -> Self { fn from(node: &'a ast::StmtAssign) -> Self {
AnyNodeRef::StmtAssign(node) AnyNodeRef::StmtAssign(node)
@ -4709,6 +4768,7 @@ impl<'a> From<&'a Stmt> for AnyNodeRef<'a> {
Stmt::ClassDef(node) => AnyNodeRef::StmtClassDef(node), Stmt::ClassDef(node) => AnyNodeRef::StmtClassDef(node),
Stmt::Return(node) => AnyNodeRef::StmtReturn(node), Stmt::Return(node) => AnyNodeRef::StmtReturn(node),
Stmt::Delete(node) => AnyNodeRef::StmtDelete(node), Stmt::Delete(node) => AnyNodeRef::StmtDelete(node),
Stmt::TypeAlias(node) => AnyNodeRef::StmtTypeAlias(node),
Stmt::Assign(node) => AnyNodeRef::StmtAssign(node), Stmt::Assign(node) => AnyNodeRef::StmtAssign(node),
Stmt::AugAssign(node) => AnyNodeRef::StmtAugAssign(node), Stmt::AugAssign(node) => AnyNodeRef::StmtAugAssign(node),
Stmt::AnnAssign(node) => AnyNodeRef::StmtAnnAssign(node), Stmt::AnnAssign(node) => AnyNodeRef::StmtAnnAssign(node),
@ -4866,6 +4926,7 @@ impl Ranged for AnyNodeRef<'_> {
AnyNodeRef::StmtClassDef(node) => node.range(), AnyNodeRef::StmtClassDef(node) => node.range(),
AnyNodeRef::StmtReturn(node) => node.range(), AnyNodeRef::StmtReturn(node) => node.range(),
AnyNodeRef::StmtDelete(node) => node.range(), AnyNodeRef::StmtDelete(node) => node.range(),
AnyNodeRef::StmtTypeAlias(node) => node.range(),
AnyNodeRef::StmtAssign(node) => node.range(), AnyNodeRef::StmtAssign(node) => node.range(),
AnyNodeRef::StmtAugAssign(node) => node.range(), AnyNodeRef::StmtAugAssign(node) => node.range(),
AnyNodeRef::StmtAnnAssign(node) => node.range(), AnyNodeRef::StmtAnnAssign(node) => node.range(),
@ -4949,6 +5010,7 @@ pub enum NodeKind {
StmtClassDef, StmtClassDef,
StmtReturn, StmtReturn,
StmtDelete, StmtDelete,
StmtTypeAlias,
StmtAssign, StmtAssign,
StmtAugAssign, StmtAugAssign,
StmtAnnAssign, StmtAnnAssign,

View File

@ -6,7 +6,7 @@ use std::ops::Deref;
use rustpython_literal::escape::{AsciiEscape, Escape, UnicodeEscape}; use rustpython_literal::escape::{AsciiEscape, Escape, UnicodeEscape};
use rustpython_parser::ast::{ use rustpython_parser::ast::{
self, Alias, Arg, Arguments, BoolOp, CmpOp, Comprehension, Constant, ConversionFlag, self, Alias, Arg, Arguments, BoolOp, CmpOp, Comprehension, Constant, ConversionFlag,
ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, WithItem, ExceptHandler, Expr, Identifier, MatchCase, Operator, Pattern, Stmt, Suite, WithItem, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple
}; };
use ruff_python_whitespace::LineEnding; use ruff_python_whitespace::LineEnding;
@ -271,6 +271,7 @@ impl<'a> Generator<'a> {
keywords, keywords,
body, body,
decorator_list, decorator_list,
type_params: _,
range: _range, range: _range,
}) => { }) => {
self.newlines(if self.indent_depth == 0 { 2 } else { 1 }); self.newlines(if self.indent_depth == 0 { 2 } else { 1 });
@ -538,6 +539,16 @@ impl<'a> Generator<'a> {
self.indent_depth = self.indent_depth.saturating_sub(1); self.indent_depth = self.indent_depth.saturating_sub(1);
} }
} }
Stmt::TypeAlias(ast::StmtTypeAlias { name, range: _range, type_params, value}) => {
self.p("type ");
self.unparse_expr(name, precedence::MAX);
for type_param in type_params {
self.unparse_type_param(type_param);
self.p(", ")
}
self.p(" = ");
self.unparse_expr(value, precedence::MAX);
}
Stmt::Raise(ast::StmtRaise { Stmt::Raise(ast::StmtRaise {
exc, exc,
cause, cause,
@ -842,6 +853,26 @@ impl<'a> Generator<'a> {
self.body(&ast.body); self.body(&ast.body);
} }
pub(crate) fn unparse_type_param(&mut self, ast: &TypeParam) {
match ast {
TypeParam::TypeVar(TypeParamTypeVar { name, bound, .. }) => {
self.p_id(name);
if let Some(expr) = bound {
self.p(": ");
self.unparse_expr(expr, precedence::MAX);
}
}
TypeParam::TypeVarTuple(TypeParamTypeVarTuple { name, .. }) => {
self.p("*");
self.p_id(name);
}
TypeParam::ParamSpec(TypeParamParamSpec { name, .. }) => {
self.p("**");
self.p_id(name);
}
}
}
pub(crate) fn unparse_expr(&mut self, ast: &Expr, level: u8) { pub(crate) fn unparse_expr(&mut self, ast: &Expr, level: u8) {
macro_rules! opprec { macro_rules! opprec {
($opty:ident, $x:expr, $enu:path, $($var:ident($op:literal, $prec:ident)),*$(,)?) => { ($opty:ident, $x:expr, $enu:path, $($var:ident($op:literal, $prec:ident)),*$(,)?) => {

View File

@ -156,6 +156,7 @@ pub enum TokenKind {
Try, Try,
While, While,
Match, Match,
Type,
Case, Case,
With, With,
Yield, Yield,
@ -426,6 +427,7 @@ impl TokenKind {
Tok::While => TokenKind::While, Tok::While => TokenKind::While,
Tok::Match => TokenKind::Match, Tok::Match => TokenKind::Match,
Tok::Case => TokenKind::Case, Tok::Case => TokenKind::Case,
Tok::Type => TokenKind::Type,
Tok::With => TokenKind::With, Tok::With => TokenKind::With,
Tok::Yield => TokenKind::Yield, Tok::Yield => TokenKind::Yield,
Tok::StartModule => TokenKind::StartModule, Tok::StartModule => TokenKind::StartModule,

View File

@ -151,6 +151,9 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
visitor.visit_expr(expr); visitor.visit_expr(expr);
} }
} }
Stmt::TypeAlias(ast::StmtTypeAlias { value, .. }) => {
visitor.visit_expr(value)
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
visitor.visit_expr(value); visitor.visit_expr(value);
for expr in targets { for expr in targets {

View File

@ -214,6 +214,11 @@ where
} }
} }
Stmt::TypeAlias(ast::StmtTypeAlias { value, .. }) => {
visitor.visit_expr(value)
}
Stmt::Assign(ast::StmtAssign { Stmt::Assign(ast::StmtAssign {
targets, targets,
value, value,

View File

@ -18,6 +18,7 @@ impl FormatNodeRule<StmtClassDef> for FormatStmtClassDef {
bases, bases,
keywords, keywords,
body, body,
type_params: _,
decorator_list, decorator_list,
} = item; } = item;