diff --git a/scripts/_utils.py b/scripts/_utils.py index 623eaca9f7..f9a4cd27dc 100644 --- a/scripts/_utils.py +++ b/scripts/_utils.py @@ -1,4 +1,5 @@ import os +import re from pathlib import Path ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -11,3 +12,7 @@ def dir_name(origin: str) -> str: def pascal_case(origin: str) -> str: """Convert from snake-case to PascalCase.""" return "".join(word.title() for word in origin.split("-")) + + +def get_indent(line: str) -> str: + return re.match(r"^\s*", line).group() # pyright: ignore[reportOptionalMemberAccess] diff --git a/scripts/add_plugin.py b/scripts/add_plugin.py index 3fac8231e0..80158415b5 100755 --- a/scripts/add_plugin.py +++ b/scripts/add_plugin.py @@ -11,7 +11,7 @@ Example usage: import argparse import os -from _utils import ROOT_DIR, dir_name, pascal_case +from _utils import ROOT_DIR, dir_name, get_indent, pascal_case def main(*, plugin: str, url: str) -> None: @@ -67,23 +67,21 @@ mod tests { with open(ROOT_DIR / "src/registry.rs", "w") as fp: for line in content.splitlines(): + indent = get_indent(line) + if line.strip() == "// Ruff": - indent = line.split("// Ruff")[0] fp.write(f"{indent}// {plugin}") fp.write("\n") elif line.strip() == "Ruff,": - indent = line.split("Ruff,")[0] fp.write(f"{indent}{pascal_case(plugin)},") fp.write("\n") elif line.strip() == 'RuleOrigin::Ruff => "Ruff-specific rules",': - indent = line.split('RuleOrigin::Ruff => "Ruff-specific rules",')[0] fp.write(f'{indent}RuleOrigin::{pascal_case(plugin)} => "{plugin}",') fp.write("\n") elif line.strip() == "RuleOrigin::Ruff => vec![RuleCodePrefix::RUF],": - indent = line.split("RuleOrigin::Ruff => vec![RuleCodePrefix::RUF],")[0] fp.write( f"{indent}RuleOrigin::{pascal_case(plugin)} => vec![\n" f'{indent} todo!("Fill-in prefix after generating codes")\n' @@ -92,7 +90,6 @@ mod tests { fp.write("\n") elif line.strip() == "RuleOrigin::Ruff => None,": - indent = line.split("RuleOrigin::Ruff => None,")[0] fp.write(f"{indent}RuleOrigin::{pascal_case(plugin)} => " f'Some(("{url}", &Platform::PyPI)),') fp.write("\n") @@ -105,7 +102,7 @@ mod tests { with open(ROOT_DIR / "src/violations.rs", "w") as fp: for line in content.splitlines(): if line.strip() == "// Ruff": - indent = line.split("// Ruff")[0] + indent = get_indent(line) fp.write(f"{indent}// {plugin}") fp.write("\n") diff --git a/scripts/add_rule.py b/scripts/add_rule.py index 988d4934a1..d61a249b38 100755 --- a/scripts/add_rule.py +++ b/scripts/add_rule.py @@ -11,7 +11,7 @@ Example usage: import argparse -from _utils import ROOT_DIR, dir_name +from _utils import ROOT_DIR, dir_name, get_indent def snake_case(name: str) -> str: @@ -34,7 +34,7 @@ def main(*, name: str, code: str, origin: str) -> None: with open(mod_rs, "w") as fp: for line in content.splitlines(): if line.strip() == "fn rules(rule_code: Rule, path: &Path) -> Result<()> {": - indent = line.split("fn rules(rule_code: Rule, path: &Path) -> Result<()> {")[0] + indent = get_indent(line) fp.write(f'{indent}#[test_case(Rule::{code}, Path::new("{code}.py"); "{code}")]') fp.write("\n") @@ -99,7 +99,7 @@ impl Violation for %s { continue if line.strip() == f"// {origin}": - indent = line.split("//")[0] + indent = get_indent(line) fp.write(f"{indent}{code} => violations::{name},") fp.write("\n") has_written = True