From 770b7f3439c590b7e0410b8f18da1207f8462c32 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sun, 2 Feb 2025 18:16:07 +0000 Subject: [PATCH] Vendor benchmark test files (#15878) Co-authored-by: Alex Waygood --- .pre-commit-config.yaml | 1 + Cargo.lock | 108 +- Cargo.toml | 14 +- crates/ruff_benchmark/Cargo.toml | 4 - crates/ruff_benchmark/benches/formatter.rs | 30 +- crates/ruff_benchmark/benches/lexer.rs | 28 +- crates/ruff_benchmark/benches/linter.rs | 28 +- crates/ruff_benchmark/benches/parser.rs | 28 +- crates/ruff_benchmark/benches/red_knot.rs | 61 +- crates/ruff_benchmark/resources/README.md | 16 + .../ruff_benchmark/resources/large/dataset.py | 1617 +++++++++++++++++ .../resources/numpy/ctypeslib.py | 547 ++++++ .../ruff_benchmark/resources/numpy/globals.py | 95 + .../resources/pydantic/types.py | 834 +++++++++ crates/ruff_benchmark/resources/pypinyin.py | 161 ++ .../resources/tomllib/__init__.py | 10 + .../resources/tomllib/_parser.py | 691 +++++++ .../ruff_benchmark/resources/tomllib/_re.py | 107 ++ .../resources/tomllib/_types.py | 10 + crates/ruff_benchmark/src/lib.rs | 159 +- 20 files changed, 4235 insertions(+), 314 deletions(-) create mode 100644 crates/ruff_benchmark/resources/README.md create mode 100644 crates/ruff_benchmark/resources/large/dataset.py create mode 100644 crates/ruff_benchmark/resources/numpy/ctypeslib.py create mode 100644 crates/ruff_benchmark/resources/numpy/globals.py create mode 100644 crates/ruff_benchmark/resources/pydantic/types.py create mode 100644 crates/ruff_benchmark/resources/pypinyin.py create mode 100644 crates/ruff_benchmark/resources/tomllib/__init__.py create mode 100644 crates/ruff_benchmark/resources/tomllib/_parser.py create mode 100644 crates/ruff_benchmark/resources/tomllib/_re.py create mode 100644 crates/ruff_benchmark/resources/tomllib/_types.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42d6180a56..b336b6cc11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,7 @@ exclude: | .github/workflows/release.yml| crates/red_knot_vendored/vendor/.*| crates/red_knot_project/resources/.*| + crates/ruff_benchmark/resources/.*| crates/ruff_linter/resources/.*| crates/ruff_linter/src/rules/.*/snapshots/.*| crates/ruff_notebook/resources/.*| diff --git a/Cargo.lock b/Cargo.lock index 9b34e20e78..a2dd59109e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -190,12 +190,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - [[package]] name = "bincode" version = "1.3.3" @@ -2577,28 +2571,13 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" -[[package]] -name = "ring" -version = "0.17.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" -dependencies = [ - "cc", - "cfg-if", - "getrandom", - "libc", - "spin", - "untrusted", - "windows-sys 0.52.0", -] - [[package]] name = "ron" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a" dependencies = [ - "base64 0.13.1", + "base64", "bitflags 
1.3.2", "serde", ] @@ -2692,11 +2671,7 @@ dependencies = [ "ruff_python_parser", "ruff_python_trivia", "rustc-hash 2.1.0", - "serde", - "serde_json", "tikv-jemallocator", - "ureq", - "url", ] [[package]] @@ -3253,38 +3228,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "rustls" -version = "0.23.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" -dependencies = [ - "log", - "once_cell", - "ring", - "rustls-pki-types", - "rustls-webpki", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pki-types" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" - -[[package]] -name = "rustls-webpki" -version = "0.102.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - [[package]] name = "rustversion" version = "1.0.19" @@ -3567,12 +3510,6 @@ dependencies = [ "anstream", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -3622,12 +3559,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - [[package]] name = "syn" version = "1.0.109" @@ -4116,28 +4047,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9df2af067a7953e9c3831320f35c1cc0600c30d44d9f7a12b01db1cd88d6b47" -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "ureq" -version = "2.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" -dependencies = [ - "base64 0.22.1", - "flate2", - "log", - "once_cell", - "rustls", - "rustls-pki-types", - "url", - "webpki-roots", -] - [[package]] name = "url" version = "2.5.4" @@ -4406,15 +4315,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-roots" -version = "0.26.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "which" version = "7.0.1" @@ -4723,12 +4623,6 @@ dependencies = [ "synstructure", ] -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" - [[package]] name = "zerovec" version = "0.10.4" diff --git a/Cargo.toml b/Cargo.toml index c74e1b8bbc..f638160ea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,7 +134,12 @@ serde_with = { version = "3.6.0", default-features = false, features = [ shellexpand = { version = "3.0.0" } similar = { version = "2.4.0", features = ["inline"] } smallvec = { version = "1.13.2" } -snapbox = { version = "0.6.0", features = ["diff", "term-svg", "cmd", "examples"] } +snapbox = 
{ version = "0.6.0", features = [ + "diff", + "term-svg", + "cmd", + "examples", +] } static_assertions = "1.1.0" strum = { version = "0.26.0", features = ["strum_macros"] } strum_macros = { version = "0.26.0" } @@ -159,7 +164,6 @@ unicode-ident = { version = "1.0.12" } unicode-width = { version = "0.2.0" } unicode_names2 = { version = "1.2.2" } unicode-normalization = { version = "0.1.23" } -ureq = { version = "2.9.6" } url = { version = "2.5.0" } uuid = { version = "1.6.1", features = [ "v4", @@ -305,7 +309,11 @@ local-artifacts-jobs = ["./build-binaries", "./build-docker"] # Publish jobs to run in CI publish-jobs = ["./publish-pypi", "./publish-wasm"] # Post-announce jobs to run in CI -post-announce-jobs = ["./notify-dependents", "./publish-docs", "./publish-playground"] +post-announce-jobs = [ + "./notify-dependents", + "./publish-docs", + "./publish-playground", +] # Custom permissions for GitHub Jobs github-custom-job-permissions = { "build-docker" = { packages = "write", contents = "read" }, "publish-wasm" = { contents = "read", id-token = "write", packages = "write" } } # Whether to install an updater program diff --git a/crates/ruff_benchmark/Cargo.toml b/crates/ruff_benchmark/Cargo.toml index 8ea87a478a..2c56cd3194 100644 --- a/crates/ruff_benchmark/Cargo.toml +++ b/crates/ruff_benchmark/Cargo.toml @@ -41,10 +41,6 @@ codspeed-criterion-compat = { workspace = true, default-features = false, option criterion = { workspace = true, default-features = false } rayon = { workspace = true } rustc-hash = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -url = { workspace = true } -ureq = { workspace = true } [dev-dependencies] ruff_db = { workspace = true } diff --git a/crates/ruff_benchmark/benches/formatter.rs b/crates/ruff_benchmark/benches/formatter.rs index af2b1caa76..1320cd7228 100644 --- a/crates/ruff_benchmark/benches/formatter.rs +++ b/crates/ruff_benchmark/benches/formatter.rs @@ -3,7 +3,10 @@ use std::path::Path; use ruff_benchmark::criterion::{ criterion_group, criterion_main, BenchmarkId, Criterion, Throughput, }; -use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError}; + +use ruff_benchmark::{ + TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN, +}; use ruff_python_formatter::{format_module_ast, PreviewMode, PyFormatOptions}; use ruff_python_parser::{parse, Mode}; use ruff_python_trivia::CommentRanges; @@ -24,27 +27,20 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -fn create_test_cases() -> Result, TestFileDownloadError> { - Ok(vec![ - TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?), - TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?), - TestCase::normal(TestFile::try_download( - "pydantic/types.py", - "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py", - )?), - TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?), - TestCase::slow(TestFile::try_download( - "large/dataset.py", - 
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py", - )?), - ]) +fn create_test_cases() -> Vec { + vec![ + TestCase::fast(NUMPY_GLOBALS.clone()), + TestCase::fast(UNICODE_PYPINYIN.clone()), + TestCase::normal(PYDANTIC_TYPES.clone()), + TestCase::normal(NUMPY_CTYPESLIB.clone()), + TestCase::slow(LARGE_DATASET.clone()), + ] } fn benchmark_formatter(criterion: &mut Criterion) { let mut group = criterion.benchmark_group("formatter"); - let test_cases = create_test_cases().unwrap(); - for case in test_cases { + for case in create_test_cases() { group.throughput(Throughput::Bytes(case.code().len() as u64)); group.bench_with_input( diff --git a/crates/ruff_benchmark/benches/lexer.rs b/crates/ruff_benchmark/benches/lexer.rs index 64b68a7a35..0fa89cdf5c 100644 --- a/crates/ruff_benchmark/benches/lexer.rs +++ b/crates/ruff_benchmark/benches/lexer.rs @@ -1,7 +1,9 @@ use ruff_benchmark::criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput, }; -use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError}; +use ruff_benchmark::{ + TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN, +}; use ruff_python_parser::{lexer, Mode, TokenKind}; #[cfg(target_os = "windows")] @@ -20,24 +22,18 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -fn create_test_cases() -> Result, TestFileDownloadError> { - Ok(vec![ - TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?), - TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?), - TestCase::normal(TestFile::try_download( - "pydantic/types.py", - "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py", - )?), - TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?), - TestCase::slow(TestFile::try_download( - "large/dataset.py", - "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py", - )?), - ]) +fn create_test_cases() -> Vec { + vec![ + TestCase::fast(NUMPY_GLOBALS.clone()), + TestCase::fast(UNICODE_PYPINYIN.clone()), + TestCase::normal(PYDANTIC_TYPES.clone()), + TestCase::normal(NUMPY_CTYPESLIB.clone()), + TestCase::slow(LARGE_DATASET.clone()), + ] } fn benchmark_lexer(criterion: &mut Criterion) { - let test_cases = create_test_cases().unwrap(); + let test_cases = create_test_cases(); let mut group = criterion.benchmark_group("lexer"); for case in test_cases { diff --git a/crates/ruff_benchmark/benches/linter.rs b/crates/ruff_benchmark/benches/linter.rs index eb8667b9e6..d2ce1fb9ed 100644 --- a/crates/ruff_benchmark/benches/linter.rs +++ b/crates/ruff_benchmark/benches/linter.rs @@ -1,7 +1,9 @@ use ruff_benchmark::criterion::{ criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion, Throughput, }; -use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError}; +use ruff_benchmark::{ + TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN, +}; use ruff_linter::linter::{lint_only, ParseSource}; use 
ruff_linter::rule_selector::PreviewOptions; use ruff_linter::settings::rule_table::RuleTable; @@ -46,24 +48,18 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[allow(unsafe_code)] pub static _rjem_malloc_conf: &[u8] = b"dirty_decay_ms:-1,muzzy_decay_ms:-1\0"; -fn create_test_cases() -> Result, TestFileDownloadError> { - Ok(vec![ - TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?), - TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?), - TestCase::normal(TestFile::try_download( - "pydantic/types.py", - "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py", - )?), - TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?), - TestCase::slow(TestFile::try_download( - "large/dataset.py", - "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py", - )?), - ]) +fn create_test_cases() -> Vec { + vec![ + TestCase::fast(NUMPY_GLOBALS.clone()), + TestCase::fast(UNICODE_PYPINYIN.clone()), + TestCase::normal(PYDANTIC_TYPES.clone()), + TestCase::normal(NUMPY_CTYPESLIB.clone()), + TestCase::slow(LARGE_DATASET.clone()), + ] } fn benchmark_linter(mut group: BenchmarkGroup, settings: &LinterSettings) { - let test_cases = create_test_cases().unwrap(); + let test_cases = create_test_cases(); for case in test_cases { group.throughput(Throughput::Bytes(case.code().len() as u64)); diff --git a/crates/ruff_benchmark/benches/parser.rs b/crates/ruff_benchmark/benches/parser.rs index ec2fa671c1..0509bc5b86 100644 --- a/crates/ruff_benchmark/benches/parser.rs +++ b/crates/ruff_benchmark/benches/parser.rs @@ -1,7 +1,9 @@ use ruff_benchmark::criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput, }; -use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError}; +use ruff_benchmark::{ + TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN, +}; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use ruff_python_ast::Stmt; use ruff_python_parser::parse_module; @@ -22,20 +24,14 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -fn create_test_cases() -> Result, TestFileDownloadError> { - Ok(vec![ - TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?), - TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?), - TestCase::normal(TestFile::try_download( - "pydantic/types.py", - "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py", - )?), - TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?), - TestCase::slow(TestFile::try_download( - "large/dataset.py", - 
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py", - )?), - ]) +fn create_test_cases() -> Vec { + vec![ + TestCase::fast(NUMPY_GLOBALS.clone()), + TestCase::fast(UNICODE_PYPINYIN.clone()), + TestCase::normal(PYDANTIC_TYPES.clone()), + TestCase::normal(NUMPY_CTYPESLIB.clone()), + TestCase::slow(LARGE_DATASET.clone()), + ] } struct CountVisitor { @@ -50,7 +46,7 @@ impl<'a> StatementVisitor<'a> for CountVisitor { } fn benchmark_parser(criterion: &mut Criterion) { - let test_cases = create_test_cases().unwrap(); + let test_cases = create_test_cases(); let mut group = criterion.benchmark_group("parser"); for case in test_cases { diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index a47a2c0388..5c6362dc0a 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -24,7 +24,25 @@ struct Case { re_path: SystemPathBuf, } -const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib"; +// "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib"; +static TOMLLIB_FILES: [TestFile; 4] = [ + TestFile::new( + "tomllib/__init__.py", + include_str!("../resources/tomllib/__init__.py"), + ), + TestFile::new( + "tomllib/_parser.py", + include_str!("../resources/tomllib/_parser.py"), + ), + TestFile::new( + "tomllib/_re.py", + include_str!("../resources/tomllib/_re.py"), + ), + TestFile::new( + "tomllib/_types.py", + include_str!("../resources/tomllib/_types.py"), + ), +]; /// A structured set of fields we use to do diagnostic comparisons. /// @@ -80,27 +98,19 @@ static EXPECTED_DIAGNOSTICS: &[KeyDiagnosticFields] = &[ ), ]; -fn get_test_file(name: &str) -> TestFile { - let path = format!("tomllib/{name}"); - let url = format!("{TOMLLIB_312_URL}/{name}"); - TestFile::try_download(&path, &url).unwrap() -} - -fn tomllib_path(filename: &str) -> SystemPathBuf { - SystemPathBuf::from(format!("/src/tomllib/{filename}").as_str()) +fn tomllib_path(file: &TestFile) -> SystemPathBuf { + SystemPathBuf::from("src").join(file.name()) } fn setup_case() -> Case { let system = TestSystem::default(); let fs = system.memory_file_system().clone(); - let tomllib_filenames = ["__init__.py", "_parser.py", "_re.py", "_types.py"]; - fs.write_files(tomllib_filenames.iter().map(|filename| { - ( - tomllib_path(filename), - get_test_file(filename).code().to_string(), - ) - })) + fs.write_files( + TOMLLIB_FILES + .iter() + .map(|file| (tomllib_path(file), file.code().to_string())), + ) .unwrap(); let src_root = SystemPath::new("/src"); @@ -114,15 +124,22 @@ fn setup_case() -> Case { }); let mut db = ProjectDatabase::new(metadata, system).unwrap(); + let mut tomllib_files = FxHashSet::default(); + let mut re: Option = None; + + for test_file in &TOMLLIB_FILES { + let file = system_path_to_file(&db, tomllib_path(test_file)).unwrap(); + if test_file.name().ends_with("_re.py") { + re = Some(file); + } + tomllib_files.insert(file); + } + + let re = re.unwrap(); - let tomllib_files: FxHashSet = tomllib_filenames - .iter() - .map(|filename| system_path_to_file(&db, tomllib_path(filename)).unwrap()) - .collect(); db.project().set_open_files(&mut db, tomllib_files); - let re_path = tomllib_path("_re.py"); - let re = system_path_to_file(&db, &re_path).unwrap(); + let re_path = re.path(&db).as_system_path().unwrap().to_owned(); Case { db, fs, diff --git 
a/crates/ruff_benchmark/resources/README.md b/crates/ruff_benchmark/resources/README.md new file mode 100644 index 0000000000..4a40a70ee5 --- /dev/null +++ b/crates/ruff_benchmark/resources/README.md @@ -0,0 +1,16 @@ +This directory vendors some files from actual projects. +This is to benchmark Ruff's performance against real-world +code instead of synthetic benchmarks. + +The following files are included: + +* [`numpy/globals`](https://github.com/numpy/numpy/blob/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py) +* [`numpy/ctypeslib.py`](https://github.com/numpy/numpy/blob/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py) +* [`pypinyin.py`](https://github.com/mozillazg/python-pinyin/blob/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py) +* [`pydantic/types.py`](https://github.com/pydantic/pydantic/blob/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py) +* [`large/dataset.py`](https://github.com/DHI/mikeio/blob/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py) +* [`tomllib`](https://github.com/python/cpython/tree/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib) (3.12) + +The files are included in the `resources` directory to allow +running benchmarks offline and for simplicity. They're licensed +according to their original licenses (see link). diff --git a/crates/ruff_benchmark/resources/large/dataset.py b/crates/ruff_benchmark/resources/large/dataset.py new file mode 100644 index 0000000000..e53287e619 --- /dev/null +++ b/crates/ruff_benchmark/resources/large/dataset.py @@ -0,0 +1,1617 @@ +import os +from datetime import datetime +import numpy as np +import pandas as pd +import pytest + +import mikeio +from mikeio.eum import EUMType, ItemInfo, EUMUnit +from mikeio.exceptions import OutsideModelDomainError + + +@pytest.fixture +def ds1(): + nt = 10 + ne = 7 + + d1 = np.zeros([nt, ne]) + 0.1 + d2 = np.zeros([nt, ne]) + 0.2 + data = [d1, d2] + + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + return mikeio.Dataset( + data=data, time=time, items=items, geometry=mikeio.Grid1D(nx=7, dx=1) + ) + + +@pytest.fixture +def ds2(): + nt = 10 + ne = 7 + + d1 = np.zeros([nt, ne]) + 1.0 + d2 = np.zeros([nt, ne]) + 2.0 + data = [d1, d2] + + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + return mikeio.Dataset(data, time, items) + + +@pytest.fixture +def ds3(): + + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + d3 = np.zeros([nt, 100, 30]) + 3.0 + + data = [d1, d2, d3] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo(x) for x in ["Foo", "Bar", "Baz"]] + return mikeio.Dataset(data, time, items) + + +def test_create_wrong_data_type_error(): + + data = ["item 1", "item 2"] + + nt = 2 + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + + with pytest.raises(TypeError, match="numpy"): + mikeio.Dataset(data=data, time=time) + + +def test_get_names(): + + nt = 100 + d = np.zeros([nt, 100, 30]) + 1.0 + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + data_vars = {"Foo": mikeio.DataArray(data=d, time=time, item=ItemInfo(name="Foo"))} + + ds = mikeio.Dataset(data_vars) + + assert ds["Foo"].name == "Foo" + assert ds["Foo"].type == EUMType.Undefined + assert repr(ds["Foo"].unit) == "undefined" + assert ds.names == ["Foo"] + + +def test_properties(ds1): + nt = 10 + ne = 7 + time = 
pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + + assert ds1.names == ["Foo", "Bar"] + assert ds1.n_items == 2 + + assert np.all(ds1.time == time) + assert ds1.n_timesteps == nt + assert ds1.timestep == 1 + assert ds1.start_time == time[0] + assert ds1.end_time == time[-1] + + assert ds1.shape == (nt, ne) + assert ds1.dims == ("time", "x") + assert ds1.geometry.nx == 7 + assert ds1._zn is None + + # assert not hasattr(ds1, "keys") # TODO: remove this + # assert not hasattr(ds1, "values") # TODO: remove this + assert isinstance(ds1.items[0], ItemInfo) + + +def test_pop(ds1): + da = ds1.pop("Foo") + assert len(ds1) == 1 + assert ds1.names == ["Bar"] + assert isinstance(da, mikeio.DataArray) + assert da.name == "Foo" + + ds1["Foo2"] = da # re-insert + assert len(ds1) == 2 + + da = ds1.pop(-1) + assert len(ds1) == 1 + assert ds1.names == ["Bar"] + assert isinstance(da, mikeio.DataArray) + assert da.name == "Foo2" + + +def test_popitem(ds1): + da = ds1.popitem() + assert len(ds1) == 1 + assert ds1.names == ["Bar"] + assert isinstance(da, mikeio.DataArray) + assert da.name == "Foo" + + +def test_insert(ds1): + da = ds1[0].copy() + da.name = "Baz" + + ds1.insert(2, da) + assert len(ds1) == 3 + assert ds1.names == ["Foo", "Bar", "Baz"] + assert ds1[-1] == da + + +def test_insert_wrong_type(ds1): + + with pytest.raises(ValueError): + ds1["Foo"] = "Bar" + + +def test_insert_fail(ds1): + da = ds1[0] + with pytest.raises(ValueError, match="Cannot add the same object"): + ds1.insert(2, da) + + vals = ds1[0].values + da = ds1[0].copy() + da.values = vals + with pytest.raises(ValueError, match="refer to the same data"): + ds1.insert(2, da) + + +def test_remove(ds1): + ds1.remove(-1) + assert len(ds1) == 1 + assert ds1.names == ["Foo"] + + ds1.remove("Foo") + assert len(ds1) == 0 + + +def test_index_with_attribute(): + + nt = 10000 + d = np.zeros([nt, 100, 30]) + 1.0 + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + + # We cannot create a mikeio.Dataset with multiple references to the same DataArray + da = mikeio.DataArray(data=d, time=time) + data = [da, da] + with pytest.raises(ValueError): + mikeio.Dataset(data) + + # We cannot create a mikeio.Dataset with multiple references to the same data + da1 = mikeio.DataArray(item="Foo", data=d, time=time) + da2 = mikeio.DataArray(item="Bar", data=d, time=time) + data = [da1, da2] + with pytest.raises(ValueError): + mikeio.Dataset(data) + + da1 = mikeio.DataArray(item="Foo", data=d, time=time) + da2 = mikeio.DataArray(item="Bar", data=d.copy(), time=time) + data = [da1, da2] + ds = mikeio.Dataset(data) + assert ds["Foo"].name == "Foo" + assert ds.Bar.name == "Bar" + + assert ds["Foo"] is ds.Foo # This is the same object + + ds["Foo"] = ds.Foo + 2.0 + assert ( + ds["Foo"] is ds.Foo + ) # This is now modfied, but both methods points to the same object + + +def test_getitem_time(ds3): + # time = pd.date_range("2000-1-2", freq="H", periods=100) + ds_sel = ds3["2000-1-2"] + assert ds_sel.n_timesteps == 24 + assert ds_sel.is_equidistant + + ds_sel = ds3["2000-1-2":"2000-1-3 00:00"] + assert ds_sel.n_timesteps == 25 + assert ds_sel.is_equidistant + + time = ["2000-1-2 04:00:00", "2000-1-2 08:00:00", "2000-1-2 12:00:00"] + ds_sel = ds3[time] + assert ds_sel.n_timesteps == 3 + assert ds_sel.is_equidistant + + time = [ds3.time[0], ds3.time[1], ds3.time[7], ds3.time[23]] + ds_sel = ds3[time] + assert ds_sel.n_timesteps == 4 + assert not ds_sel.is_equidistant + + ds_sel = ds3[ds3.time[:10]] + assert ds_sel.n_timesteps == 10 + 
assert ds_sel.is_equidistant + + +def test_getitem_multi_indexing_attempted(ds3): + with pytest.raises(TypeError, match="not allow multi-index"): + ds3[0, 0] + with pytest.warns(Warning, match="ambiguity"): + ds3[0, 1] # indistinguishable from ds3[(0,1)] + with pytest.raises(TypeError, match="not allow multi-index"): + ds3[:, 1] + with pytest.raises(TypeError, match="not allow multi-index"): + ds3[-1, [0, 1], 1] + + +def test_select_subset_isel(): + + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + d1[0, 10, :] = 2.0 + d2[0, 10, :] = 3.0 + + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + + geometry = mikeio.Grid2D(nx=30, ny=100, bbox=[0, 0, 1, 1]) + + data = { + "Foo": mikeio.DataArray( + data=d1, time=time, geometry=geometry, item=ItemInfo("Foo") + ), + "Bar": mikeio.DataArray( + data=d2, time=time, geometry=geometry, item=ItemInfo("Bar") + ), + } + + ds = mikeio.Dataset(data) + + selds = ds.isel(10, axis=1) + + assert len(selds.items) == 2 + assert len(selds.to_numpy()) == 2 + assert selds["Foo"].shape == (100, 30) + assert selds["Foo"].to_numpy()[0, 0] == 2.0 + assert selds["Bar"].to_numpy()[0, 0] == 3.0 + + selds_named_axis = ds.isel(10, axis="y") + + assert len(selds_named_axis.items) == 2 + assert selds_named_axis["Foo"].shape == (100, 30) + + +def test_select_subset_isel_axis_out_of_range_error(ds2): + + assert len(ds2.shape) == 2 + dss = ds2.isel(idx=0) + + # After subsetting there is only one dimension + assert len(dss.shape) == 1 + + with pytest.raises(IndexError): + dss.isel(idx=0, axis=1) + + +def test_isel_named_axis(ds2: mikeio.Dataset): + dss = ds2.isel(time=0) + assert len(dss.shape) == 1 + + +def test_select_temporal_subset_by_idx(): + + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + d1[0, 10, :] = 2.0 + d2[0, 10, :] = 3.0 + data = [d1, d2] + + time = pd.date_range(start=datetime(2000, 1, 1), freq="S", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + selds = ds.isel([0, 1, 2], axis=0) + + assert len(selds) == 2 + assert selds["Foo"].shape == (3, 100, 30) + + +def test_temporal_subset_fancy(): + + nt = (24 * 31) + 1 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + data = [d1, d2] + + time = pd.date_range("2000-1-1", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + assert ds.time[0].hour == 0 + assert ds.time[-1].hour == 0 + + selds = ds["2000-01-01 00:00":"2000-01-02 00:00"] + + assert len(selds) == 2 + assert selds["Foo"].shape == (25, 100, 30) + + selds = ds[:"2000-01-02 00:00"] + assert selds["Foo"].shape == (25, 100, 30) + + selds = ds["2000-01-31 00:00":] + assert selds["Foo"].shape == (25, 100, 30) + + selds = ds["2000-01-30":] + assert selds["Foo"].shape == (49, 100, 30) + + +def test_subset_with_datetime(): + nt = (24 * 31) + 1 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + dssub = ds[datetime(2000, 1, 2)] + assert dssub.n_timesteps == 1 + + dssub = ds[pd.Timestamp(datetime(2000, 1, 2))] + assert dssub.n_timesteps == 1 + + dssub = ds["2000-1-2"] + assert dssub.n_timesteps == 24 + + +def test_select_item_by_name(): + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + d1[0, 10, :] = 2.0 + d2[0, 10, 
:] = 3.0 + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + foo_data = ds["Foo"] + assert foo_data.to_numpy()[0, 10, 0] == 2.0 + + +def test_missing_item_error(): + nt = 100 + + da1 = mikeio.DataArray( + data=np.zeros(nt), + time=pd.date_range("2000-1-2", freq="H", periods=nt), + item="Foo", + ) + + da2 = mikeio.DataArray( + data=np.ones(nt), + time=pd.date_range("2000-1-2", freq="H", periods=nt), + item="Bar", + ) + + ds = mikeio.Dataset([da1, da2]) + + with pytest.raises(KeyError, match="Baz"): + ds["Baz"] # there is no Bar item + + +def test_select_multiple_items_by_name(): + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + d3 = np.zeros([nt, 100, 30]) + 3.0 + + data = [d1, d2, d3] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + # items = [ItemInfo("Foo"), ItemInfo("Bar"), ItemInfo("Baz")] + items = [ItemInfo(x) for x in ["Foo", "Bar", "Baz"]] + ds = mikeio.Dataset(data, time, items) + + assert len(ds) == 3 # Length of a dataset is the number of items + + newds = ds[["Baz", "Foo"]] + assert newds.items[0].name == "Baz" + assert newds.items[1].name == "Foo" + assert newds["Foo"].to_numpy()[0, 10, 0] == 1.5 + + assert len(newds) == 2 + + +def test_select_multiple_items_by_index(ds3): + assert len(ds3) == 3 # Length of a dataset is the number of items + + newds = ds3[[2, 0]] + assert len(newds) == 2 + assert newds.items[0].name == "Baz" + assert newds.items[1].name == "Foo" + assert newds["Foo"].to_numpy()[0, 10, 0] == 1.5 + + +def test_select_multiple_items_by_slice(ds3): + assert len(ds3) == 3 # Length of a dataset is the number of items + + newds = ds3[:2] + assert len(newds) == 2 + assert newds.items[0].name == "Foo" + assert newds.items[1].name == "Bar" + assert newds["Foo"].to_numpy()[0, 10, 0] == 1.5 + + +def test_select_item_by_iteminfo(): + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + d1[0, 10, :] = 2.0 + d2[0, 10, :] = 3.0 + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + foo_item = items[0] + + foo_data = ds[foo_item] + assert foo_data.to_numpy()[0, 10, 0] == 2.0 + + +def test_select_subset_isel_multiple_idxs(): + + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + selds = ds.isel([10, 15], axis=1) + + assert len(selds.items) == 2 + assert len(selds.to_numpy()) == 2 + assert selds["Foo"].shape == (100, 2, 30) + + +def test_decribe(ds1): + df = ds1.describe() + assert df.columns[0] == "Foo" + assert df.loc["mean"][1] == pytest.approx(0.2) + assert df.loc["max"][0] == pytest.approx(0.1) + + +def test_create_undefined(): + + nt = 100 + d1 = np.zeros([nt]) + d2 = np.zeros([nt]) + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + data = { + "Item 1": mikeio.DataArray( + data=d1, time=time, item=ItemInfo("Item 1") + ), # TODO redundant name + "Item 2": mikeio.DataArray(data=d2, time=time, item=ItemInfo("Item 2")), + } + + ds = mikeio.Dataset(data) + + assert len(ds.items) == 2 + assert len(ds.to_numpy()) == 2 + assert ds[0].name == "Item 1" + assert ds[0].type == EUMType.Undefined + + +def test_create_named_undefined(): + + nt = 100 + d1 = np.zeros([nt]) + d2 
= np.zeros([nt]) + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + ds = mikeio.Dataset(data=data, time=time, items=["Foo", "Bar"]) + + assert len(ds.items) == 2 + assert len(ds.to_numpy()) == 2 + assert ds[1].name == "Bar" + assert ds[1].type == EUMType.Undefined + + +def test_to_dataframe_single_timestep(): + + nt = 1 + d1 = np.zeros([nt]) + d2 = np.zeros([nt]) + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + df = ds.to_dataframe() + + assert "Bar" in df.columns + # assert isinstance(df.index, pd.DatetimeIndex) + + +def test_to_dataframe(): + + nt = 100 + d1 = np.zeros([nt]) + d2 = np.zeros([nt]) + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + df = ds.to_dataframe() + + assert list(df.columns) == ["Foo", "Bar"] + assert isinstance(df.index, pd.DatetimeIndex) + + +def test_multidimensional_to_dataframe_no_supported(): + + nt = 100 + d1 = np.zeros([nt, 2]) + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset([d1], time, items) + + with pytest.raises(ValueError): + ds.to_dataframe() + + +def test_get_data(): + + data = [] + nt = 100 + d = np.zeros([nt, 100, 30]) + 1.0 + data.append(d) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + assert ds.shape == (100, 100, 30) + + +def test_interp_time(): + + nt = 4 + d = np.zeros([nt, 10, 3]) + d[1] = 2.0 + d[3] = 4.0 + data = [d] + time = pd.date_range("2000-1-1", freq="D", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + assert ds[0].shape == (nt, 10, 3) + + dsi = ds.interp_time(dt=3600) + + assert ds.time[0] == dsi.time[0] + assert dsi[0].shape == (73, 10, 3) + + dsi2 = ds.interp_time(freq="2H") + assert dsi2.timestep == 2 * 3600 + + +def test_interp_time_to_other_dataset(): + + # Arrange + ## mikeio.Dataset 1 + nt = 4 + data = [np.zeros([nt, 10, 3])] + time = pd.date_range("2000-1-1", freq="D", periods=nt) + items = [ItemInfo("Foo")] + ds1 = mikeio.Dataset(data, time, items) + assert ds1.shape == (nt, 10, 3) + + ## mikeio.Dataset 2 + nt = 12 + data = [np.ones([nt, 10, 3])] + time = pd.date_range("2000-1-1", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds2 = mikeio.Dataset(data, time, items) + + # Act + ## interp + dsi = ds1.interp_time(dt=ds2.time) + + # Assert + assert dsi.time[0] == ds2.time[0] + assert dsi.time[-1] == ds2.time[-1] + assert len(dsi.time) == len(ds2.time) + assert dsi[0].shape[0] == ds2[0].shape[0] + + # Accept dataset as argument + dsi2 = ds1.interp_time(ds2) + assert dsi2.time[0] == ds2.time[0] + + +def test_extrapolate(): + # Arrange + ## mikeio.Dataset 1 + nt = 2 + data = [np.zeros([nt, 10, 3])] + time = pd.date_range("2000-1-1", freq="D", periods=nt) + items = [ItemInfo("Foo")] + ds1 = mikeio.Dataset(data, time, items) + assert ds1.shape == (nt, 10, 3) + + ## mikeio.Dataset 2 partly overlapping with mikeio.Dataset 1 + nt = 3 + data = [np.ones([nt, 10, 3])] + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds2 = mikeio.Dataset(data, time, items) + + # Act + ## interp + dsi = ds1.interp_time(dt=ds2.time, fill_value=1.0) + + # Assert + assert dsi.time[0] == ds2.time[0] + assert dsi.time[-1] == ds2.time[-1] + assert len(dsi.time) == len(ds2.time) + 
assert dsi[0].values[0] == pytest.approx(0.0) + assert dsi[0].values[1] == pytest.approx(1.0) # filled + assert dsi[0].values[2] == pytest.approx(1.0) # filled + + +def test_extrapolate_not_allowed(): + ## mikeio.Dataset 1 + nt = 2 + data = [np.zeros([nt, 10, 3])] + time = pd.date_range("2000-1-1", freq="D", periods=nt) + items = [ItemInfo("Foo")] + ds1 = mikeio.Dataset(data, time, items) + assert ds1.shape == (nt, 10, 3) + + ## mikeio.Dataset 2 partly overlapping with mikeio.Dataset 1 + nt = 3 + data = [np.ones([nt, 10, 3])] + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds2 = mikeio.Dataset(data, time, items) + + with pytest.raises(ValueError): + dsi = ds1.interp_time(dt=ds2.time, fill_value=1.0, extrapolate=False) + + +def test_get_data_2(): + + nt = 100 + data = [] + d = np.zeros([nt, 100, 30]) + 1.0 + data.append(d) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + assert data[0].shape == (100, 100, 30) + + +def test_get_data_name(): + + nt = 100 + data = [] + d = np.zeros([nt, 100, 30]) + 1.0 + data.append(d) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + assert ds["Foo"].shape == (100, 100, 30) + + +def test_modify_selected_variable(): + + nt = 100 + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset([np.zeros((nt, 10))], time, items) + + assert ds.Foo.to_numpy()[0, 0] == 0.0 + + foo = ds.Foo + foo_mod = foo + 1.0 + + ds["Foo"] = foo_mod + assert ds.Foo.to_numpy()[0, 0] == 1.0 + + +def test_get_bad_name(): + nt = 100 + data = [] + d = np.zeros([100, 100, 30]) + 1.0 + data.append(d) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + with pytest.raises(Exception): + ds["BAR"] + + +def test_flipud(): + + nt = 2 + d = np.random.random([nt, 100, 30]) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset([d], time, items) + + dsud = ds.copy() + dsud.flipud() + + assert dsud.shape == ds.shape + assert dsud["Foo"].to_numpy()[0, 0, 0] == ds["Foo"].to_numpy()[0, -1, 0] + + +def test_aggregation_workflows(tmpdir): + filename = "tests/testdata/HD2D.dfsu" + dfs = mikeio.open(filename) + + ds = dfs.read(items=["Surface elevation", "Current speed"]) + ds2 = ds.max(axis=1) + + outfilename = os.path.join(tmpdir.dirname, "max.dfs0") + ds2.to_dfs(outfilename) + assert os.path.isfile(outfilename) + + ds3 = ds.min(axis=1) + + outfilename = os.path.join(tmpdir.dirname, "min.dfs0") + ds3.to_dfs(outfilename) + assert os.path.isfile(outfilename) + + +def test_aggregation_dataset_no_time(): + filename = "tests/testdata/HD2D.dfsu" + dfs = mikeio.open(filename) + ds = dfs.read(time=-1, items=["Surface elevation", "Current speed"]) + + ds2 = ds.max() + assert ds2["Current speed"].values == pytest.approx(1.6463733) + + +def test_aggregations(): + filename = "tests/testdata/gebco_sound.dfs2" + ds = mikeio.read(filename) + + for axis in [0, 1, 2]: + ds.mean(axis=axis) + ds.nanmean(axis=axis) + ds.nanmin(axis=axis) + ds.nanmax(axis=axis) + + assert ds.mean().shape == (264, 216) + assert ds.mean(axis="time").shape == (264, 216) + assert ds.mean(axis="spatial").shape == (1,) + assert ds.mean(axis="space").shape == (1,) + + with pytest.raises(ValueError, match="space"): + ds.mean(axis="spaghetti") + + dsm = ds.mean(axis="time") + 
assert dsm.geometry is not None + + +def test_to_dfs_extension_validation(tmpdir): + + outfilename = os.path.join(tmpdir, "not_gonna_happen.dfs2") + + ds = mikeio.read( + "tests/testdata/HD2D.dfsu", items=["Surface elevation", "Current speed"] + ) + with pytest.raises(ValueError) as excinfo: + ds.to_dfs(outfilename) + + assert "dfsu" in str(excinfo.value) + + +def test_weighted_average(tmpdir): + filename = "tests/testdata/HD2D.dfsu" + dfs = mikeio.open(filename) + + ds = dfs.read(items=["Surface elevation", "Current speed"]) + + area = dfs.get_element_area() + ds2 = ds.average(weights=area, axis=1) + + outfilename = os.path.join(tmpdir.dirname, "average.dfs0") + ds2.to_dfs(outfilename) + assert os.path.isfile(outfilename) + + +def test_quantile_axis1(ds1): + dsq = ds1.quantile(q=0.345, axis=1) + assert dsq[0].to_numpy()[0] == 0.1 + assert dsq[1].to_numpy()[0] == 0.2 + + assert dsq.n_items == ds1.n_items + assert dsq.n_timesteps == ds1.n_timesteps + + # q as list + dsq = ds1.quantile(q=[0.25, 0.75], axis=1) + assert dsq.n_items == 2 * ds1.n_items + assert "Quantile 0.75, " in dsq.items[1].name + assert "Quantile 0.25, " in dsq.items[2].name + + +def test_quantile_axis0(ds1): + dsq = ds1.quantile(q=0.345) # axis=0 is default + assert dsq[0].to_numpy()[0] == 0.1 + assert dsq[1].to_numpy()[0] == 0.2 + + assert dsq.n_items == ds1.n_items + assert dsq.n_timesteps == 1 + assert dsq.shape[-1] == ds1.shape[-1] + + # q as list + dsq = ds1.quantile(q=[0.25, 0.75], axis=0) + assert dsq.n_items == 2 * ds1.n_items + assert dsq[0].to_numpy()[0] == 0.1 + assert dsq[1].to_numpy()[0] == 0.1 + assert dsq[2].to_numpy()[0] == 0.2 + assert dsq[3].to_numpy()[0] == 0.2 + + assert "Quantile 0.75, " in dsq.items[1].name + assert "Quantile 0.25, " in dsq.items[2].name + assert "Quantile 0.75, " in dsq.items[3].name + + +def test_nanquantile(): + q = 0.99 + fn = "tests/testdata/random.dfs0" # has delete value + ds = mikeio.read(fn) + + dsq1 = ds.quantile(q=q) + dsq2 = ds.nanquantile(q=q) + + assert np.isnan(dsq1[0].to_numpy()) + assert not np.isnan(dsq2[0].to_numpy()) + + qnt = np.quantile(ds[0].to_numpy(), q=q) + nqnt = np.nanquantile(ds[0].to_numpy(), q=q) + + assert np.isnan(qnt) + assert dsq2[0].to_numpy() == nqnt + + +def test_aggregate_across_items(): + + ds = mikeio.read("tests/testdata/State_wlbc_north_err.dfs1") + + with pytest.warns(FutureWarning): # TODO: remove in 1.5.0 + dam = ds.max(axis="items") + + assert isinstance(dam, mikeio.DataArray) + assert dam.geometry == ds.geometry + assert dam.dims == ds.dims + assert dam.type == ds[-1].type + + dsm = ds.mean(axis="items", keepdims=True) + + assert isinstance(dsm, mikeio.Dataset) + assert dsm.geometry == ds.geometry + assert dsm.dims == ds.dims + + +def test_aggregate_selected_items_dfsu_save_to_new_file(tmpdir): + + ds = mikeio.read("tests/testdata/State_Area.dfsu", items="*Level*") + + assert ds.n_items == 5 + + with pytest.warns(FutureWarning): # TODO: remove keepdims in 1.5.0 + dsm = ds.max( + axis="items", keepdims=True, name="Max Water Level" + ) # add a nice name + assert len(dsm) == 1 + assert dsm[0].name == "Max Water Level" + assert dsm.geometry == ds.geometry + assert dsm.dims == ds.dims + assert dsm[0].type == ds[-1].type + + outfilename = os.path.join(tmpdir, "maxwl.dfsu") + dsm.to_dfs(outfilename) + + +def test_copy(): + nt = 100 + d1 = np.zeros([nt, 100, 30]) + 1.5 + d2 = np.zeros([nt, 100, 30]) + 2.0 + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = 
mikeio.Dataset(data, time, items) + + assert len(ds.items) == 2 + assert len(ds.to_numpy()) == 2 + assert ds[0].name == "Foo" + + ds2 = ds.copy() + + ds2[0].name = "New name" + + assert ds2[0].name == "New name" + assert ds[0].name == "Foo" + + +def test_dropna(): + nt = 10 + d1 = np.zeros([nt, 100, 30]) + d2 = np.zeros([nt, 100, 30]) + + d1[9:] = np.nan + d2[8:] = np.nan + + data = [d1, d2] + + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] + ds = mikeio.Dataset(data, time, items) + + assert len(ds.items) == 2 + assert len(ds.to_numpy()) == 2 + + ds2 = ds.dropna() + + assert ds2.n_timesteps == 8 + + +def test_default_type(): + + item = ItemInfo("Foo") + assert item.type == EUMType.Undefined + assert repr(item.unit) == "undefined" + + +def test_int_is_valid_type_info(): + + item = ItemInfo("Foo", 100123) + assert item.type == EUMType.Viscosity + + item = ItemInfo("U", 100002) + assert item.type == EUMType.Wind_Velocity + + +def test_int_is_valid_unit_info(): + + item = ItemInfo("U", 100002, 2000) + assert item.type == EUMType.Wind_Velocity + assert item.unit == EUMUnit.meter_per_sec + assert repr(item.unit) == "meter per sec" # TODO replace _per_ with / + + +def test_default_unit_from_type(): + + item = ItemInfo("Foo", EUMType.Water_Level) + assert item.type == EUMType.Water_Level + assert item.unit == EUMUnit.meter + assert repr(item.unit) == "meter" + + item = ItemInfo("Tp", EUMType.Wave_period) + assert item.type == EUMType.Wave_period + assert item.unit == EUMUnit.second + assert repr(item.unit) == "second" + + item = ItemInfo("Temperature", EUMType.Temperature) + assert item.type == EUMType.Temperature + assert item.unit == EUMUnit.degree_Celsius + assert repr(item.unit) == "degree Celsius" + + +def test_default_name_from_type(): + + item = ItemInfo(EUMType.Current_Speed) + assert item.name == "Current Speed" + assert item.unit == EUMUnit.meter_per_sec + + item2 = ItemInfo(EUMType.Current_Direction, EUMUnit.degree) + assert item2.unit == EUMUnit.degree + item3 = ItemInfo( + "Current direction (going to)", EUMType.Current_Direction, EUMUnit.degree + ) + assert item3.type == EUMType.Current_Direction + assert item3.unit == EUMUnit.degree + + +def test_iteminfo_string_type_should_fail_with_helpful_message(): + + with pytest.raises(ValueError): + + item = ItemInfo("Water level", "Water level") + + +def test_item_search(): + + res = EUMType.search("level") + + assert len(res) > 0 + assert isinstance(res[0], EUMType) + + +def test_dfsu3d_dataset(): + + filename = "tests/testdata/oresund_sigma_z.dfsu" + + dfsu = mikeio.open(filename) + + ds = dfsu.read() + + text = repr(ds) + + assert len(ds) == 2 # Salinity, Temperature + + dsagg = ds.nanmean(axis=0) # Time averaged + + assert len(dsagg) == 2 # Salinity, Temperature + + assert dsagg[0].shape[0] == 17118 + + assert dsagg.time[0] == ds.time[0] # Time-averaged data index by start time + + ds_elm = dfsu.read(elements=[0]) + + assert len(ds_elm) == 2 # Salinity, Temperature + + dss = ds_elm.squeeze() + + assert len(dss) == 2 # Salinity, Temperature + + +def test_items_data_mismatch(): + + nt = 100 + d = np.zeros([nt, 100, 30]) + 1.0 + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo"), ItemInfo("Bar")] # Two items is not correct! 
+ + with pytest.raises(ValueError): + mikeio.Dataset([d], time, items) + + +def test_time_data_mismatch(): + + nt = 100 + d = np.zeros([nt, 100, 30]) + 1.0 + time = pd.date_range( + "2000-1-2", freq="H", periods=nt + 1 + ) # 101 timesteps is not correct! + items = [ItemInfo("Foo")] + + with pytest.raises(ValueError): + mikeio.Dataset([d], time, items) + + +def test_properties_dfs2(): + filename = "tests/testdata/gebco_sound.dfs2" + ds = mikeio.read(filename) + + assert ds.n_timesteps == 1 + assert ds.n_items == 1 + assert np.all(ds.shape == (1, 264, 216)) + assert ds.n_elements == (264 * 216) + assert ds.is_equidistant + + +def test_properties_dfsu(): + filename = "tests/testdata/oresund_vertical_slice.dfsu" + ds = mikeio.read(filename) + + assert ds.n_timesteps == 3 + assert ds.start_time == datetime(1997, 9, 15, 21, 0, 0) + assert ds.end_time == datetime(1997, 9, 16, 3, 0, 0) + assert ds.timestep == (3 * 3600) + assert ds.n_items == 2 + assert np.all(ds.shape == (3, 441)) + assert ds.n_elements == 441 + assert ds.is_equidistant + + +def test_create_empty_data(): + ne = 34 + d = mikeio.Dataset.create_empty_data(n_elements=ne) + assert len(d) == 1 + assert d[0].shape == (1, ne) + + nt = 100 + d = mikeio.Dataset.create_empty_data(n_timesteps=nt, shape=(3, 4, 6)) + assert len(d) == 1 + assert d[0].shape == (nt, 3, 4, 6) + + ni = 4 + d = mikeio.Dataset.create_empty_data(n_items=ni, n_elements=ne) + assert len(d) == ni + assert d[-1].shape == (1, ne) + + with pytest.raises(Exception): + d = mikeio.Dataset.create_empty_data() + + with pytest.raises(Exception): + d = mikeio.Dataset.create_empty_data(n_elements=None, shape=None) + + +def test_create_infer_name_from_eum(): + + nt = 100 + d = np.random.uniform(size=nt) + + ds = mikeio.Dataset( + data=[d], + time=pd.date_range("2000-01-01", freq="H", periods=nt), + items=[EUMType.Wind_speed], + ) + + assert isinstance(ds.items[0], ItemInfo) + assert ds.items[0].type == EUMType.Wind_speed + assert ds.items[0].name == "Wind speed" + + +def test_add_scalar(ds1): + ds2 = ds1 + 10.0 + assert np.all(ds2[0].to_numpy() - ds1[0].to_numpy() == 10.0) + + ds3 = 10.0 + ds1 + assert np.all(ds3[0].to_numpy() == ds2[0].to_numpy()) + assert np.all(ds3[1].to_numpy() == ds2[1].to_numpy()) + + +def test_add_inconsistent_dataset(ds1): + + ds2 = ds1[[0]] + + assert len(ds1) != len(ds2) + + with pytest.raises(ValueError): + ds1 + ds2 + + with pytest.raises(ValueError): + ds1 * ds2 + + +def test_add_bad_value(ds1): + + with pytest.raises(ValueError): + ds1 + ["one"] + + +def test_multiple_bad_value(ds1): + + with pytest.raises(ValueError): + ds1 * ["pi"] + + +def test_sub_scalar(ds1): + ds2 = ds1 - 10.0 + assert isinstance(ds2, mikeio.Dataset) + assert np.all(ds1[0].to_numpy() - ds2[0].to_numpy() == 10.0) + + ds3 = 10.0 - ds1 + assert isinstance(ds3, mikeio.Dataset) + assert np.all(ds3[0].to_numpy() == 9.9) + assert np.all(ds3[1].to_numpy() == 9.8) + + +def test_mul_scalar(ds1): + ds2 = ds1 * 2.0 + assert np.all(ds2[0].to_numpy() * 0.5 == ds1[0].to_numpy()) + + ds3 = 2.0 * ds1 + assert np.all(ds3[0].to_numpy() == ds2[0].to_numpy()) + assert np.all(ds3[1].to_numpy() == ds2[1].to_numpy()) + + +def test_add_dataset(ds1, ds2): + ds3 = ds1 + ds2 + assert np.all(ds3[0].to_numpy() == 1.1) + assert np.all(ds3[1].to_numpy() == 2.2) + + ds4 = ds2 + ds1 + assert np.all(ds3[0].to_numpy() == ds4[0].to_numpy()) + assert np.all(ds3[1].to_numpy() == ds4[1].to_numpy()) + + ds2b = ds2.copy() + ds2b[0].item = ItemInfo(EUMType.Wind_Velocity) + with pytest.raises(ValueError): + # item type 
does not match + ds1 + ds2b + + ds2c = ds2.copy() + tt = ds2c.time.to_numpy() + tt[-1] = tt[-1] + np.timedelta64(1, "s") + ds2c.time = pd.DatetimeIndex(tt) + with pytest.raises(ValueError): + # time does not match + ds1 + ds2c + + +def test_sub_dataset(ds1, ds2): + ds3 = ds2 - ds1 + assert np.all(ds3[0].to_numpy() == 0.9) + assert np.all(ds3[1].to_numpy() == 1.8) + + +def test_non_equidistant(): + nt = 4 + d = np.random.uniform(size=nt) + + ds = mikeio.Dataset( + data=[d], + time=[ + datetime(2000, 1, 1), + datetime(2001, 1, 1), + datetime(2002, 1, 1), + datetime(2003, 1, 1), + ], + ) + + assert not ds.is_equidistant + + +def test_concat_dataarray_by_time(): + da1 = mikeio.read("tests/testdata/tide1.dfs1")[0] + da2 = mikeio.read("tests/testdata/tide2.dfs1")[0] + da3 = mikeio.DataArray.concat([da1, da2]) + + assert da3.start_time == da1.start_time + assert da3.start_time < da2.start_time + assert da3.end_time == da2.end_time + assert da3.end_time > da1.end_time + assert da3.n_timesteps == 145 + assert da3.is_equidistant + + +def test_concat_by_time(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1") + ds2 = mikeio.read("tests/testdata/tide2.dfs1") + 0.5 # add offset + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert isinstance(ds3, mikeio.Dataset) + assert len(ds1) == len(ds2) == len(ds3) + assert ds3.start_time == ds1.start_time + assert ds3.start_time < ds2.start_time + assert ds3.end_time == ds2.end_time + assert ds3.end_time > ds1.end_time + assert ds3.n_timesteps == 145 + assert ds3.is_equidistant + + +def test_concat_by_time_ndim1(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1").isel(x=0) + ds2 = mikeio.read("tests/testdata/tide2.dfs1").isel(x=0) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert isinstance(ds3, mikeio.Dataset) + assert len(ds1) == len(ds2) == len(ds3) + assert ds3.start_time == ds1.start_time + assert ds3.start_time < ds2.start_time + assert ds3.end_time == ds2.end_time + assert ds3.end_time > ds1.end_time + assert ds3.n_timesteps == 145 + assert ds3.is_equidistant + + +def test_concat_by_time_inconsistent_shape_not_possible(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1").isel(x=[0, 1]) + ds2 = mikeio.read("tests/testdata/tide2.dfs1").isel(x=[0, 1, 2]) + with pytest.raises(ValueError, match="Shape"): + mikeio.Dataset.concat([ds1, ds2]) + + +# TODO: implement this +def test_concat_by_time_no_time(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1", time=0) + ds2 = mikeio.read("tests/testdata/tide2.dfs1", time=1) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert ds3.n_timesteps == 2 + + +def test_concat_by_time_2(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1", time=range(0, 12)) + ds2 = mikeio.read("tests/testdata/tide2.dfs1") + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert ds3.n_timesteps == 109 + assert not ds3.is_equidistant + + # create concatd datasets in 8 chunks of 6 hours + dsall = [] + for j in range(8): + dsall.append( + mikeio.read( + "tests/testdata/tide1.dfs1", time=range(j * 12, 1 + (j + 1) * 12) + ) + ) + ds4 = mikeio.Dataset.concat(dsall) + assert len(dsall) == 8 + assert ds4.n_timesteps == 97 + assert ds4.is_equidistant + + +def test_renamed_dataset_has_updated_attributes(ds1: mikeio.Dataset): + assert hasattr(ds1, "Foo") + assert isinstance(ds1.Foo, mikeio.DataArray) + ds2 = ds1.rename(dict(Foo="Baz")) + assert not hasattr(ds2, "Foo") + assert hasattr(ds2, "Baz") + assert isinstance(ds2.Baz, mikeio.DataArray) + + # inplace version + ds1.rename(dict(Foo="Baz"), inplace=True) + assert not hasattr(ds1, "Foo") + assert hasattr(ds1, 
"Baz") + assert isinstance(ds1.Baz, mikeio.DataArray) + + +def test_merge_by_item(): + ds1 = mikeio.read("tests/testdata/tide1.dfs1") + ds2 = mikeio.read("tests/testdata/tide1.dfs1") + old_name = ds2[0].name + new_name = old_name + " v2" + # ds2[0].name = ds2[0].name + " v2" + ds2.rename({old_name: new_name}, inplace=True) + ds3 = mikeio.Dataset.merge([ds1, ds2]) + + assert isinstance(ds3, mikeio.Dataset) + assert ds3.n_items == 2 + assert ds3[1].name == ds1[0].name + " v2" + + +def test_merge_by_item_dfsu_3d(): + ds1 = mikeio.read("tests/testdata/oresund_sigma_z.dfsu", items=[0]) + assert ds1.n_items == 1 + ds2 = mikeio.read("tests/testdata/oresund_sigma_z.dfsu", items=[1]) + assert ds2.n_items == 1 + + ds3 = mikeio.Dataset.merge([ds1, ds2]) + + assert isinstance(ds3, mikeio.Dataset) + itemnames = [x.name for x in ds3.items] + assert "Salinity" in itemnames + assert "Temperature" in itemnames + assert ds3.n_items == 2 + + +def test_to_numpy(ds2): + + X = ds2.to_numpy() + + assert X.shape == (ds2.n_items,) + ds2.shape + assert isinstance(X, np.ndarray) + + +def test_concat(): + filename = "tests/testdata/HD2D.dfsu" + ds1 = mikeio.read(filename, time=[0, 1]) + ds2 = mikeio.read(filename, time=[2, 3]) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + ds3.n_timesteps + + assert ds1.n_items == ds2.n_items == ds3.n_items + assert ds3.n_timesteps == (ds1.n_timesteps + ds2.n_timesteps) + assert ds3.start_time == ds1.start_time + assert ds3.end_time == ds2.end_time + assert type(ds3.geometry) == type(ds1.geometry) + assert ds3.geometry.n_elements == ds1.geometry.n_elements + + +def test_concat_dfsu3d(): + filename = "tests/testdata/basin_3d.dfsu" + ds = mikeio.read(filename) + ds1 = mikeio.read(filename, time=[0, 1]) + ds2 = mikeio.read(filename, time=[1, 2]) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert ds1.n_items == ds2.n_items == ds3.n_items + assert ds3.start_time == ds.start_time + assert ds3.end_time == ds.end_time + assert type(ds3.geometry) == type(ds.geometry) + assert ds3.geometry.n_elements == ds1.geometry.n_elements + assert ds3._zn.shape == ds._zn.shape + assert np.all(ds3._zn == ds._zn) + + +def test_concat_dfsu3d_single_timesteps(): + filename = "tests/testdata/basin_3d.dfsu" + ds = mikeio.read(filename) + ds1 = mikeio.read(filename, time=0) + ds2 = mikeio.read(filename, time=2) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert ds1.n_items == ds2.n_items == ds3.n_items + assert ds3.start_time == ds1.start_time + assert ds3.end_time == ds2.end_time + + +def test_concat_dfs2_single_timesteps(): + filename = "tests/testdata/single_row.dfs2" + ds = mikeio.read(filename) + ds1 = mikeio.read(filename, time=0) + ds2 = mikeio.read(filename, time=2) + ds3 = mikeio.Dataset.concat([ds1, ds2]) + + assert ds1.n_items == ds2.n_items == ds3.n_items + assert ds3.start_time == ds1.start_time + assert ds3.end_time == ds2.end_time + assert ds3.n_timesteps == 2 + + +def test_merge_same_name_error(): + filename = "tests/testdata/HD2D.dfsu" + ds1 = mikeio.read(filename, items=0) + ds2 = mikeio.read(filename, items=0) + + assert ds1.items[0].name == ds2.items[0].name + + with pytest.raises(ValueError): + mikeio.Dataset.merge([ds1, ds2]) + + +def test_incompatible_data_not_allowed(): + + da1 = mikeio.read("tests/testdata/HD2D.dfsu")[0] + da2 = mikeio.read("tests/testdata/oresundHD_run1.dfsu")[1] + + with pytest.raises(ValueError) as excinfo: + mikeio.Dataset([da1, da2]) + + assert "shape" in str(excinfo.value).lower() + + da1 = mikeio.read("tests/testdata/tide1.dfs1")[0] + da2 = 
mikeio.read("tests/testdata/tide2.dfs1")[0] + + with pytest.raises(ValueError) as excinfo: + mikeio.Dataset([da1, da2]) + + assert "name" in str(excinfo.value).lower() + + da1 = mikeio.read("tests/testdata/tide1.dfs1")[0] + da2 = mikeio.read("tests/testdata/tide2.dfs1")[0] + da2.name = "Foo" + + with pytest.raises(ValueError) as excinfo: + mikeio.Dataset([da1, da2]) + + assert "time" in str(excinfo.value).lower() + + +def test_xzy_selection(): + # select in space via x,y,z coordinates test + filename = "tests/testdata/oresund_sigma_z.dfsu" + ds = mikeio.read(filename) + + with pytest.raises(OutsideModelDomainError): + ds.sel(x=340000, y=15.75, z=0) + + +def test_layer_selection(): + # select layer test + filename = "tests/testdata/oresund_sigma_z.dfsu" + ds = mikeio.read(filename) + + dss_layer = ds.sel(layers=0) + # should not be layered after selection + assert type(dss_layer.geometry) == mikeio.spatial.FM_geometry.GeometryFM + + +def test_time_selection(): + # select time test + nt = 100 + data = [] + d = np.random.rand(nt) + data.append(d) + time = pd.date_range("2000-1-2", freq="H", periods=nt) + items = [ItemInfo("Foo")] + ds = mikeio.Dataset(data, time, items) + + # check for string input + dss_t = ds.sel(time="2000-01-05") + # and index based + dss_tix = ds.sel(time=80) + + assert dss_t.shape == (24,) + assert len(dss_tix) == 1 + + +def test_create_dataset_with_many_items(): + n_items = 800 + nt = 2 + time = pd.date_range("2000", freq="H", periods=nt) + + das = [] + + for i in range(n_items): + x = np.random.random(nt) + da = mikeio.DataArray(data=x, time=time, item=mikeio.ItemInfo(f"Item {i+1}")) + das.append(da) + + ds = mikeio.Dataset(das) + + assert ds.n_items == n_items + + +def test_create_array_with_defaults_from_dataset(): + + filename = "tests/testdata/oresund_sigma_z.dfsu" + ds: mikeio.Dataset = mikeio.read(filename) + + values = np.zeros(ds.Temperature.shape) + + da = ds.create_data_array(values) + + assert isinstance(da, mikeio.DataArray) + assert da.geometry == ds.geometry + assert all(da.time == ds.time) + assert da.item.type == mikeio.EUMType.Undefined + + da_name = ds.create_data_array(values, "Foo") + + assert isinstance(da, mikeio.DataArray) + assert da_name.geometry == ds.geometry + assert da_name.item.type == mikeio.EUMType.Undefined + assert da_name.name == "Foo" + + da_eum = ds.create_data_array( + values, item=mikeio.ItemInfo("TS", mikeio.EUMType.Temperature) + ) + + assert isinstance(da_eum, mikeio.DataArray) + assert da_eum.geometry == ds.geometry + assert da_eum.item.type == mikeio.EUMType.Temperature + + +def test_dataset_plot(ds1): + ax = ds1.isel(x=0).plot() + assert len(ax.lines) == 2 + ds2 = ds1 + 0.01 + ax = ds2.isel(x=-1).plot(ax=ax) + assert len(ax.lines) == 4 + + +def test_interp_na(): + time = pd.date_range("2000", periods=5, freq="D") + da = mikeio.DataArray( + data=np.array([np.nan, 1.0, np.nan, np.nan, 4.0]), + time=time, + item=ItemInfo(name="Foo"), + ) + da2 = mikeio.DataArray( + data=np.array([np.nan, np.nan, np.nan, 4.0, 4.0]), + time=time, + item=ItemInfo(name="Bar"), + ) + + ds = mikeio.Dataset([da, da2]) + + dsi = ds.interp_na() + assert np.isnan(dsi[0].to_numpy()[0]) + assert dsi.Foo.to_numpy()[2] == pytest.approx(2.0) + assert np.isnan(dsi.Foo.to_numpy()[0]) + + dsi = ds.interp_na(fill_value="extrapolate") + assert dsi.Foo.to_numpy()[0] == pytest.approx(0.0) + assert dsi.Bar.to_numpy()[2] == pytest.approx(4.0) + + +def test_plot_scatter(): + ds = mikeio.read("tests/testdata/oresund_sigma_z.dfsu", time=0) + 
ds.plot.scatter(x="Salinity", y="Temperature", title="S-vs-T") diff --git a/crates/ruff_benchmark/resources/numpy/ctypeslib.py b/crates/ruff_benchmark/resources/numpy/ctypeslib.py new file mode 100644 index 0000000000..c4bafca1bd --- /dev/null +++ b/crates/ruff_benchmark/resources/numpy/ctypeslib.py @@ -0,0 +1,547 @@ +""" +============================ +``ctypes`` Utility Functions +============================ + +See Also +-------- +load_library : Load a C library. +ndpointer : Array restype/argtype with verification. +as_ctypes : Create a ctypes array from an ndarray. +as_array : Create an ndarray from a ctypes array. + +References +---------- +.. [1] "SciPy Cookbook: ctypes", https://scipy-cookbook.readthedocs.io/items/Ctypes.html + +Examples +-------- +Load the C library: + +>>> _lib = np.ctypeslib.load_library('libmystuff', '.') #doctest: +SKIP + +Our result type, an ndarray that must be of type double, be 1-dimensional +and is C-contiguous in memory: + +>>> array_1d_double = np.ctypeslib.ndpointer( +... dtype=np.double, +... ndim=1, flags='CONTIGUOUS') #doctest: +SKIP + +Our C-function typically takes an array and updates its values +in-place. For example:: + + void foo_func(double* x, int length) + { + int i; + for (i = 0; i < length; i++) { + x[i] = i*i; + } + } + +We wrap it using: + +>>> _lib.foo_func.restype = None #doctest: +SKIP +>>> _lib.foo_func.argtypes = [array_1d_double, c_int] #doctest: +SKIP + +Then, we're ready to call ``foo_func``: + +>>> out = np.empty(15, dtype=np.double) +>>> _lib.foo_func(out, len(out)) #doctest: +SKIP + +""" +__all__ = ['load_library', 'ndpointer', 'c_intp', 'as_ctypes', 'as_array', + 'as_ctypes_type'] + +import os +from numpy import ( + integer, ndarray, dtype as _dtype, asarray, frombuffer +) +from numpy.core.multiarray import _flagdict, flagsobj + +try: + import ctypes +except ImportError: + ctypes = None + +if ctypes is None: + def _dummy(*args, **kwds): + """ + Dummy object that raises an ImportError if ctypes is not available. + + Raises + ------ + ImportError + If ctypes is not available. + + """ + raise ImportError("ctypes is not available.") + load_library = _dummy + as_ctypes = _dummy + as_array = _dummy + from numpy import intp as c_intp + _ndptr_base = object +else: + import numpy.core._internal as nic + c_intp = nic._getintp_ctype() + del nic + _ndptr_base = ctypes.c_void_p + + # Adapted from Albert Strasheim + def load_library(libname, loader_path): + """ + It is possible to load a library using + + >>> lib = ctypes.cdll[] # doctest: +SKIP + + But there are cross-platform considerations, such as library file extensions, + plus the fact Windows will just load the first library it finds with that name. + NumPy supplies the load_library function as a convenience. + + .. versionchanged:: 1.20.0 + Allow libname and loader_path to take any + :term:`python:path-like object`. + + Parameters + ---------- + libname : path-like + Name of the library, which can have 'lib' as a prefix, + but without an extension. + loader_path : path-like + Where the library can be found. + + Returns + ------- + ctypes.cdll[libpath] : library object + A ctypes library object + + Raises + ------ + OSError + If there is no library with the expected extension, or the + library is defective and cannot be loaded. 
+ """ + if ctypes.__version__ < '1.0.1': + import warnings + warnings.warn("All features of ctypes interface may not work " + "with ctypes < 1.0.1", stacklevel=2) + + # Convert path-like objects into strings + libname = os.fsdecode(libname) + loader_path = os.fsdecode(loader_path) + + ext = os.path.splitext(libname)[1] + if not ext: + # Try to load library with platform-specific name, otherwise + # default to libname.[so|pyd]. Sometimes, these files are built + # erroneously on non-linux platforms. + from numpy.distutils.misc_util import get_shared_lib_extension + so_ext = get_shared_lib_extension() + libname_ext = [libname + so_ext] + # mac, windows and linux >= py3.2 shared library and loadable + # module have different extensions so try both + so_ext2 = get_shared_lib_extension(is_python_ext=True) + if not so_ext2 == so_ext: + libname_ext.insert(0, libname + so_ext2) + else: + libname_ext = [libname] + + loader_path = os.path.abspath(loader_path) + if not os.path.isdir(loader_path): + libdir = os.path.dirname(loader_path) + else: + libdir = loader_path + + for ln in libname_ext: + libpath = os.path.join(libdir, ln) + if os.path.exists(libpath): + try: + return ctypes.cdll[libpath] + except OSError: + ## defective lib file + raise + ## if no successful return in the libname_ext loop: + raise OSError("no file with expected extension") + + +def _num_fromflags(flaglist): + num = 0 + for val in flaglist: + num += _flagdict[val] + return num + +_flagnames = ['C_CONTIGUOUS', 'F_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', + 'OWNDATA', 'WRITEBACKIFCOPY'] +def _flags_fromnum(num): + res = [] + for key in _flagnames: + value = _flagdict[key] + if (num & value): + res.append(key) + return res + + +class _ndptr(_ndptr_base): + @classmethod + def from_param(cls, obj): + if not isinstance(obj, ndarray): + raise TypeError("argument must be an ndarray") + if cls._dtype_ is not None \ + and obj.dtype != cls._dtype_: + raise TypeError("array must have data type %s" % cls._dtype_) + if cls._ndim_ is not None \ + and obj.ndim != cls._ndim_: + raise TypeError("array must have %d dimension(s)" % cls._ndim_) + if cls._shape_ is not None \ + and obj.shape != cls._shape_: + raise TypeError("array must have shape %s" % str(cls._shape_)) + if cls._flags_ is not None \ + and ((obj.flags.num & cls._flags_) != cls._flags_): + raise TypeError("array must have flags %s" % + _flags_fromnum(cls._flags_)) + return obj.ctypes + + +class _concrete_ndptr(_ndptr): + """ + Like _ndptr, but with `_shape_` and `_dtype_` specified. + + Notably, this means the pointer has enough information to reconstruct + the array, which is not generally true. + """ + def _check_retval_(self): + """ + This method is called when this class is used as the .restype + attribute for a shared-library function, to automatically wrap the + pointer into an array. + """ + return self.contents + + @property + def contents(self): + """ + Get an ndarray viewing the data pointed to by this pointer. + + This mirrors the `contents` attribute of a normal ctypes pointer + """ + full_dtype = _dtype((self._dtype_, self._shape_)) + full_ctype = ctypes.c_char * full_dtype.itemsize + buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents + return frombuffer(buffer, dtype=full_dtype).squeeze(axis=0) + + +# Factory for an array-checking class with from_param defined for +# use with ctypes argtypes mechanism +_pointer_type_cache = {} +def ndpointer(dtype=None, ndim=None, shape=None, flags=None): + """ + Array-checking restype/argtypes. 
+ + An ndpointer instance is used to describe an ndarray in restypes + and argtypes specifications. This approach is more flexible than + using, for example, ``POINTER(c_double)``, since several restrictions + can be specified, which are verified upon calling the ctypes function. + These include data type, number of dimensions, shape and flags. If a + given array does not satisfy the specified restrictions, + a ``TypeError`` is raised. + + Parameters + ---------- + dtype : data-type, optional + Array data-type. + ndim : int, optional + Number of array dimensions. + shape : tuple of ints, optional + Array shape. + flags : str or tuple of str + Array flags; may be one or more of: + + - C_CONTIGUOUS / C / CONTIGUOUS + - F_CONTIGUOUS / F / FORTRAN + - OWNDATA / O + - WRITEABLE / W + - ALIGNED / A + - WRITEBACKIFCOPY / X + + Returns + ------- + klass : ndpointer type object + A type object, which is an ``_ndtpr`` instance containing + dtype, ndim, shape and flags information. + + Raises + ------ + TypeError + If a given array does not satisfy the specified restrictions. + + Examples + -------- + >>> clib.somefunc.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64, + ... ndim=1, + ... flags='C_CONTIGUOUS')] + ... #doctest: +SKIP + >>> clib.somefunc(np.array([1, 2, 3], dtype=np.float64)) + ... #doctest: +SKIP + + """ + + # normalize dtype to an Optional[dtype] + if dtype is not None: + dtype = _dtype(dtype) + + # normalize flags to an Optional[int] + num = None + if flags is not None: + if isinstance(flags, str): + flags = flags.split(',') + elif isinstance(flags, (int, integer)): + num = flags + flags = _flags_fromnum(num) + elif isinstance(flags, flagsobj): + num = flags.num + flags = _flags_fromnum(num) + if num is None: + try: + flags = [x.strip().upper() for x in flags] + except Exception as e: + raise TypeError("invalid flags specification") from e + num = _num_fromflags(flags) + + # normalize shape to an Optional[tuple] + if shape is not None: + try: + shape = tuple(shape) + except TypeError: + # single integer -> 1-tuple + shape = (shape,) + + cache_key = (dtype, ndim, shape, num) + + try: + return _pointer_type_cache[cache_key] + except KeyError: + pass + + # produce a name for the new type + if dtype is None: + name = 'any' + elif dtype.names is not None: + name = str(id(dtype)) + else: + name = dtype.str + if ndim is not None: + name += "_%dd" % ndim + if shape is not None: + name += "_"+"x".join(str(x) for x in shape) + if flags is not None: + name += "_"+"_".join(flags) + + if dtype is not None and shape is not None: + base = _concrete_ndptr + else: + base = _ndptr + + klass = type("ndpointer_%s"%name, (base,), + {"_dtype_": dtype, + "_shape_" : shape, + "_ndim_" : ndim, + "_flags_" : num}) + _pointer_type_cache[cache_key] = klass + return klass + + +if ctypes is not None: + def _ctype_ndarray(element_type, shape): + """ Create an ndarray of the given element type and shape """ + for dim in shape[::-1]: + element_type = dim * element_type + # prevent the type name include np.ctypeslib + element_type.__module__ = None + return element_type + + + def _get_scalar_type_map(): + """ + Return a dictionary mapping native endian scalar dtype to ctypes types + """ + ct = ctypes + simple_types = [ + ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong, + ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong, + ct.c_float, ct.c_double, + ct.c_bool, + ] + return {_dtype(ctype): ctype for ctype in simple_types} + + + _scalar_type_map = _get_scalar_type_map() + + + def 
_ctype_from_dtype_scalar(dtype): + # swapping twice ensure that `=` is promoted to <, >, or | + dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S') + dtype_native = dtype.newbyteorder('=') + try: + ctype = _scalar_type_map[dtype_native] + except KeyError as e: + raise NotImplementedError( + "Converting {!r} to a ctypes type".format(dtype) + ) from None + + if dtype_with_endian.byteorder == '>': + ctype = ctype.__ctype_be__ + elif dtype_with_endian.byteorder == '<': + ctype = ctype.__ctype_le__ + + return ctype + + + def _ctype_from_dtype_subarray(dtype): + element_dtype, shape = dtype.subdtype + ctype = _ctype_from_dtype(element_dtype) + return _ctype_ndarray(ctype, shape) + + + def _ctype_from_dtype_structured(dtype): + # extract offsets of each field + field_data = [] + for name in dtype.names: + field_dtype, offset = dtype.fields[name][:2] + field_data.append((offset, name, _ctype_from_dtype(field_dtype))) + + # ctypes doesn't care about field order + field_data = sorted(field_data, key=lambda f: f[0]) + + if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data): + # union, if multiple fields all at address 0 + size = 0 + _fields_ = [] + for offset, name, ctype in field_data: + _fields_.append((name, ctype)) + size = max(size, ctypes.sizeof(ctype)) + + # pad to the right size + if dtype.itemsize != size: + _fields_.append(('', ctypes.c_char * dtype.itemsize)) + + # we inserted manual padding, so always `_pack_` + return type('union', (ctypes.Union,), dict( + _fields_=_fields_, + _pack_=1, + __module__=None, + )) + else: + last_offset = 0 + _fields_ = [] + for offset, name, ctype in field_data: + padding = offset - last_offset + if padding < 0: + raise NotImplementedError("Overlapping fields") + if padding > 0: + _fields_.append(('', ctypes.c_char * padding)) + + _fields_.append((name, ctype)) + last_offset = offset + ctypes.sizeof(ctype) + + + padding = dtype.itemsize - last_offset + if padding > 0: + _fields_.append(('', ctypes.c_char * padding)) + + # we inserted manual padding, so always `_pack_` + return type('struct', (ctypes.Structure,), dict( + _fields_=_fields_, + _pack_=1, + __module__=None, + )) + + + def _ctype_from_dtype(dtype): + if dtype.fields is not None: + return _ctype_from_dtype_structured(dtype) + elif dtype.subdtype is not None: + return _ctype_from_dtype_subarray(dtype) + else: + return _ctype_from_dtype_scalar(dtype) + + + def as_ctypes_type(dtype): + r""" + Convert a dtype into a ctypes type. + + Parameters + ---------- + dtype : dtype + The dtype to convert + + Returns + ------- + ctype + A ctype scalar, union, array, or struct + + Raises + ------ + NotImplementedError + If the conversion is not possible + + Notes + ----- + This function does not losslessly round-trip in either direction. + + ``np.dtype(as_ctypes_type(dt))`` will: + + - insert padding fields + - reorder fields to be sorted by offset + - discard field titles + + ``as_ctypes_type(np.dtype(ctype))`` will: + + - discard the class names of `ctypes.Structure`\ s and + `ctypes.Union`\ s + - convert single-element `ctypes.Union`\ s into single-element + `ctypes.Structure`\ s + - insert padding fields + + """ + return _ctype_from_dtype(_dtype(dtype)) + + + def as_array(obj, shape=None): + """ + Create a numpy array from a ctypes array or POINTER. + + The numpy array shares the memory with the ctypes object. + + The shape parameter must be given if converting from a ctypes POINTER. 
+ The shape parameter is ignored if converting from a ctypes array + """ + if isinstance(obj, ctypes._Pointer): + # convert pointers to an array of the desired shape + if shape is None: + raise TypeError( + 'as_array() requires a shape argument when called on a ' + 'pointer') + p_arr_type = ctypes.POINTER(_ctype_ndarray(obj._type_, shape)) + obj = ctypes.cast(obj, p_arr_type).contents + + return asarray(obj) + + + def as_ctypes(obj): + """Create and return a ctypes object from a numpy array. Actually + anything that exposes the __array_interface__ is accepted.""" + ai = obj.__array_interface__ + if ai["strides"]: + raise TypeError("strided arrays not supported") + if ai["version"] != 3: + raise TypeError("only __array_interface__ version 3 supported") + addr, readonly = ai["data"] + if readonly: + raise TypeError("readonly arrays unsupported") + + # can't use `_dtype((ai["typestr"], ai["shape"]))` here, as it overflows + # dtype.itemsize (gh-14214) + ctype_scalar = as_ctypes_type(ai["typestr"]) + result_type = _ctype_ndarray(ctype_scalar, ai["shape"]) + result = result_type.from_address(addr) + result.__keep = obj + return result diff --git a/crates/ruff_benchmark/resources/numpy/globals.py b/crates/ruff_benchmark/resources/numpy/globals.py new file mode 100644 index 0000000000..416a20f5e1 --- /dev/null +++ b/crates/ruff_benchmark/resources/numpy/globals.py @@ -0,0 +1,95 @@ +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if numpy itself is reloaded. In particular, a function like the following +will still work correctly after numpy is reloaded:: + + def foo(arg=np._NoValue): + if arg is np._NoValue: + ... + +That was not the case when the singleton classes were defined in the numpy +``__init__.py`` file. See gh-7844 for a discussion of the reload problem that +motivated this module. + +""" +import enum + +from ._utils import set_module as _set_module + +__all__ = ['_NoValue', '_CopyMode'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading numpy._globals is not allowed') +_is_loaded = True + + +class _NoValueType: + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + keyword if no other obvious default (e.g., `None`) is suitable, + + Common reasons for using this keyword are: + + - A new keyword is added to a function, and that function forwards its + inputs to another function or method which can be defined outside of + NumPy. For example, ``np.std(x)`` calls ``x.std``, so when a ``keepdims`` + keyword was added that could only be forwarded if the user explicitly + specified ``keepdims``; downstream array libraries may not have added + the same keyword, so adding ``x.std(..., keepdims=keepdims)`` + unconditionally could have broken previously working code. + - A keyword is being deprecated, and a deprecation warning must only be + emitted when the keyword is used. + + """ + __instance = None + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super().__new__(cls) + return cls.__instance + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() + + +@_set_module("numpy") +class _CopyMode(enum.Enum): + """ + An enumeration for the copy modes supported + by numpy.copy() and numpy.array(). 
The following three modes are supported, + + - ALWAYS: This means that a deep copy of the input + array will always be taken. + - IF_NEEDED: This means that a deep copy of the input + array will be taken only if necessary. + - NEVER: This means that the deep copy will never be taken. + If a copy cannot be avoided then a `ValueError` will be + raised. + + Note that the buffer-protocol could in theory do copies. NumPy currently + assumes an object exporting the buffer protocol will never do this. + """ + + ALWAYS = True + IF_NEEDED = False + NEVER = 2 + + def __bool__(self): + # For backwards compatibility + if self == _CopyMode.ALWAYS: + return True + + if self == _CopyMode.IF_NEEDED: + return False + + raise ValueError(f"{self} is neither True nor False.") diff --git a/crates/ruff_benchmark/resources/pydantic/types.py b/crates/ruff_benchmark/resources/pydantic/types.py new file mode 100644 index 0000000000..df455020d0 --- /dev/null +++ b/crates/ruff_benchmark/resources/pydantic/types.py @@ -0,0 +1,834 @@ +from __future__ import annotations as _annotations + +import abc +import dataclasses as _dataclasses +import re +from datetime import date, datetime +from decimal import Decimal +from enum import Enum +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + FrozenSet, + Generic, + Hashable, + List, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +import annotated_types +from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema +from typing_extensions import Annotated, Literal + +from ._internal import _fields, _validators + +__all__ = [ + 'Strict', + 'StrictStr', + 'conbytes', + 'conlist', + 'conset', + 'confrozenset', + 'constr', + 'ImportString', + 'conint', + 'PositiveInt', + 'NegativeInt', + 'NonNegativeInt', + 'NonPositiveInt', + 'confloat', + 'PositiveFloat', + 'NegativeFloat', + 'NonNegativeFloat', + 'NonPositiveFloat', + 'FiniteFloat', + 'condecimal', + 'UUID1', + 'UUID3', + 'UUID4', + 'UUID5', + 'FilePath', + 'DirectoryPath', + 'Json', + 'SecretField', + 'SecretStr', + 'SecretBytes', + 'StrictBool', + 'StrictBytes', + 'StrictInt', + 'StrictFloat', + 'PaymentCardNumber', + 'ByteSize', + 'PastDate', + 'FutureDate', + 'condate', + 'AwareDatetime', + 'NaiveDatetime', +] + +from ._internal._core_metadata import build_metadata_dict +from ._internal._utils import update_not_none +from .json_schema import JsonSchemaMetadata + + +@_dataclasses.dataclass +class Strict(_fields.PydanticMetadata): + strict: bool = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +StrictBool = Annotated[bool, Strict()] + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +def conint( + *, + strict: bool | None = None, + gt: int | None = None, + ge: int | None = None, + lt: int | None = None, + le: int | None = None, + multiple_of: int | None = None, +) -> type[int]: + return Annotated[ # type: ignore[return-value] + int, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + ] + + +PositiveInt = Annotated[int, annotated_types.Gt(0)] +NegativeInt = Annotated[int, annotated_types.Lt(0)] +NonPositiveInt = Annotated[int, annotated_types.Le(0)] +NonNegativeInt = Annotated[int, annotated_types.Ge(0)] +StrictInt = Annotated[int, Strict()] + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + 
+@_dataclasses.dataclass +class AllowInfNan(_fields.PydanticMetadata): + allow_inf_nan: bool = True + + +def confloat( + *, + strict: bool | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + multiple_of: float | None = None, + allow_inf_nan: bool | None = None, +) -> type[float]: + return Annotated[ # type: ignore[return-value] + float, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None, + ] + + +PositiveFloat = Annotated[float, annotated_types.Gt(0)] +NegativeFloat = Annotated[float, annotated_types.Lt(0)] +NonPositiveFloat = Annotated[float, annotated_types.Le(0)] +NonNegativeFloat = Annotated[float, annotated_types.Ge(0)] +StrictFloat = Annotated[float, Strict(True)] +FiniteFloat = Annotated[float, AllowInfNan(False)] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +def conbytes( + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, +) -> type[bytes]: + return Annotated[ # type: ignore[return-value] + bytes, + Strict(strict) if strict is not None else None, + annotated_types.Len(min_length or 0, max_length), + ] + + +StrictBytes = Annotated[bytes, Strict()] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +def constr( + *, + strip_whitespace: bool | None = None, + to_upper: bool | None = None, + to_lower: bool | None = None, + strict: bool | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, +) -> type[str]: + return Annotated[ # type: ignore[return-value] + str, + Strict(strict) if strict is not None else None, + annotated_types.Len(min_length or 0, max_length), + _fields.PydanticGeneralMetadata( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + pattern=pattern, + ), + ] + + +StrictStr = Annotated[str, Strict()] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ COLLECTION TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +HashableItemType = TypeVar('HashableItemType', bound=Hashable) + + +def conset( + item_type: Type[HashableItemType], *, min_length: int = None, max_length: int = None +) -> Type[Set[HashableItemType]]: + return Annotated[ # type: ignore[return-value] + Set[item_type], annotated_types.Len(min_length or 0, max_length) # type: ignore[valid-type] + ] + + +def confrozenset( + item_type: Type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None +) -> Type[FrozenSet[HashableItemType]]: + return Annotated[ # type: ignore[return-value] + FrozenSet[item_type], # type: ignore[valid-type] + annotated_types.Len(min_length or 0, max_length), + ] + + +AnyItemType = TypeVar('AnyItemType') + + +def conlist( + item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None +) -> Type[List[AnyItemType]]: + return Annotated[ # type: ignore[return-value] + List[item_type], # type: ignore[valid-type] + annotated_types.Len(min_length or 0, max_length), + ] + + +def contuple( + item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None +) -> Type[Tuple[AnyItemType]]: + return Annotated[ # type: ignore[return-value] + Tuple[item_type], + annotated_types.Len(min_length or 0, max_length), + ] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT STRING 
TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +AnyType = TypeVar('AnyType') +if TYPE_CHECKING: + ImportString = Annotated[AnyType, ...] +else: + + class ImportString: + @classmethod + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] + + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.CoreSchema: + if schema is None or schema == {'type': 'any'}: + # Treat bare usage of ImportString (`schema is None`) as the same as ImportString[Any] + return core_schema.function_plain_schema(lambda v, _: _validators.import_string(v)) + else: + return core_schema.function_before_schema(lambda v, _: _validators.import_string(v), schema) + + def __repr__(self) -> str: + return 'ImportString' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +def condecimal( + *, + strict: bool | None = None, + gt: int | Decimal | None = None, + ge: int | Decimal | None = None, + lt: int | Decimal | None = None, + le: int | Decimal | None = None, + multiple_of: int | Decimal | None = None, + max_digits: int | None = None, + decimal_places: int | None = None, + allow_inf_nan: bool | None = None, +) -> Type[Decimal]: + return Annotated[ # type: ignore[return-value] + Decimal, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + _fields.PydanticGeneralMetadata(max_digits=max_digits, decimal_places=decimal_places), + AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None, + ] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +@_dataclasses.dataclass(frozen=True) # Add frozen=True to make it hashable +class UuidVersion: + uuid_version: Literal[1, 3, 4, 5] + + def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None: + field_schema.pop('anyOf', None) # remove the bytes/str union + field_schema.update(type='string', format=f'uuid{self.uuid_version}') + + def __get_pydantic_core_schema__( + self, schema: core_schema.CoreSchema, **_kwargs: Any + ) -> core_schema.FunctionSchema: + return core_schema.function_after_schema(schema, cast(core_schema.ValidatorFunction, self.validate)) + + def validate(self, value: UUID, _: core_schema.ValidationInfo) -> UUID: + if value.version != self.uuid_version: + raise PydanticCustomError( + 'uuid_version', 'uuid version {required_version} expected', {'required_version': self.uuid_version} + ) + return value + + +UUID1 = Annotated[UUID, UuidVersion(1)] +UUID3 = Annotated[UUID, UuidVersion(3)] +UUID4 = Annotated[UUID, UuidVersion(4)] +UUID5 = Annotated[UUID, UuidVersion(5)] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +@_dataclasses.dataclass +class PathType: + path_type: Literal['file', 'dir', 'new'] + + def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None: + format_conversion = {'file': 'file-path', 'dir': 'directory-path'} + field_schema.update(format=format_conversion.get(self.path_type, 'path'), type='string') + + def __get_pydantic_core_schema__( + self, schema: core_schema.CoreSchema, **_kwargs: Any + ) -> core_schema.FunctionSchema: + function_lookup = { + 'file': cast(core_schema.ValidatorFunction, self.validate_file), + 'dir': cast(core_schema.ValidatorFunction, self.validate_directory), + 'new': cast(core_schema.ValidatorFunction, self.validate_new), + } + + return 
core_schema.function_after_schema( + schema, + function_lookup[self.path_type], + ) + + @staticmethod + def validate_file(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.is_file(): + return path + else: + raise PydanticCustomError('path_not_file', 'Path does not point to a file') + + @staticmethod + def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.is_dir(): + return path + else: + raise PydanticCustomError('path_not_directory', 'Path does not point to a directory') + + @staticmethod + def validate_new(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.exists(): + raise PydanticCustomError('path_exists', 'path already exists') + elif not path.parent.exists(): + raise PydanticCustomError('parent_does_not_exist', 'Parent directory does not exist') + else: + return path + + +FilePath = Annotated[Path, PathType('file')] +DirectoryPath = Annotated[Path, PathType('dir')] +NewPath = Annotated[Path, PathType('new')] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + Json = Annotated[AnyType, ...] # Json[list[str]] will be recognized by type checkers as list[str] + +else: + + class Json: + @classmethod + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] + + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.JsonSchema: + return core_schema.json_schema(schema) + + @classmethod + def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None: + field_schema.update(type='string', format='json-string') + + def __repr__(self) -> str: + return 'Json' + + def __hash__(self) -> int: + return hash(type(self)) + + def __eq__(self, other: Any) -> bool: + return type(other) == type(self) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +SecretType = TypeVar('SecretType', str, bytes) + + +class SecretField(abc.ABC, Generic[SecretType]): + _error_kind: str + + def __init__(self, secret_value: SecretType) -> None: + self._secret_value: SecretType = secret_value + + def get_secret_value(self) -> SecretType: + return self._secret_value + + @classmethod + def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema: + validator = SecretFieldValidator(cls) + if issubclass(cls, SecretStr): + # Use a lambda here so that `apply_metadata` can be called on the validator before the override is generated + override = lambda: core_schema.str_schema( # noqa E731 + min_length=validator.min_length, + max_length=validator.max_length, + ) + elif issubclass(cls, SecretBytes): + override = lambda: core_schema.bytes_schema( # noqa E731 + min_length=validator.min_length, + max_length=validator.max_length, + ) + else: + override = None + metadata = build_metadata_dict( + update_cs_function=validator.__pydantic_update_schema__, + js_metadata=JsonSchemaMetadata(core_schema_override=override), + ) + return core_schema.function_after_schema( + core_schema.union_schema( + core_schema.is_instance_schema(cls), + cls._pre_core_schema(), + strict=True, + custom_error_type=cls._error_kind, + ), + validator, + metadata=metadata, + serialization=core_schema.function_plain_ser_schema(cls._serialize, json_return_type='str'), + ) + + @classmethod + def _serialize( + cls, value: SecretField[SecretType], info: core_schema.SerializationInfo + ) -> str | SecretField[SecretType]: + if info.mode == 'json': + # we want the output to always be 
string without the `b'` prefix for byties, + # hence we just use `secret_display` + return secret_display(value) + else: + return value + + @classmethod + @abc.abstractmethod + def _pre_core_schema(cls) -> core_schema.CoreSchema: + ... + + @classmethod + def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None: + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + ) + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value() + + def __hash__(self) -> int: + return hash(self.get_secret_value()) + + def __len__(self) -> int: + return len(self._secret_value) + + @abc.abstractmethod + def _display(self) -> SecretType: + ... + + def __str__(self) -> str: + return str(self._display()) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self._display()!r})' + + +def secret_display(secret_field: SecretField[Any]) -> str: + return '**********' if secret_field.get_secret_value() else '' + + +class SecretFieldValidator(_fields.CustomValidator, Generic[SecretType]): + __slots__ = 'field_type', 'min_length', 'max_length', 'error_prefix' + + def __init__( + self, field_type: Type[SecretField[SecretType]], min_length: int | None = None, max_length: int | None = None + ) -> None: + self.field_type: Type[SecretField[SecretType]] = field_type + self.min_length = min_length + self.max_length = max_length + self.error_prefix: Literal['string', 'bytes'] = 'string' if field_type is SecretStr else 'bytes' + + def __call__(self, __value: SecretField[SecretType] | SecretType, _: core_schema.ValidationInfo) -> Any: + if self.min_length is not None and len(__value) < self.min_length: + short_kind: core_schema.ErrorType = f'{self.error_prefix}_too_short' # type: ignore[assignment] + raise PydanticKnownError(short_kind, {'min_length': self.min_length}) + if self.max_length is not None and len(__value) > self.max_length: + long_kind: core_schema.ErrorType = f'{self.error_prefix}_too_long' # type: ignore[assignment] + raise PydanticKnownError(long_kind, {'max_length': self.max_length}) + + if isinstance(__value, self.field_type): + return __value + else: + return self.field_type(__value) # type: ignore[arg-type] + + def __pydantic_update_schema__(self, schema: core_schema.CoreSchema, **constraints: Any) -> None: + self._update_attrs(constraints, {'min_length', 'max_length'}) + + +class SecretStr(SecretField[str]): + _error_kind = 'string_type' + + @classmethod + def _pre_core_schema(cls) -> core_schema.CoreSchema: + return core_schema.str_schema() + + def _display(self) -> str: + return secret_display(self) + + +class SecretBytes(SecretField[bytes]): + _error_kind = 'bytes_type' + + @classmethod + def _pre_core_schema(cls) -> core_schema.CoreSchema: + return core_schema.bytes_schema() + + def _display(self) -> bytes: + return secret_display(self).encode() + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class PaymentCardBrand(str, Enum): + # If you add another card type, please also add it to the + # Hypothesis strategy in `pydantic._hypothesis_plugin`. 
+ amex = 'American Express' + mastercard = 'Mastercard' + visa = 'Visa' + other = 'other' + + def __str__(self) -> str: + return self.value + + +class PaymentCardNumber(str): + """ + Based on: https://en.wikipedia.org/wiki/Payment_card_number + """ + + strip_whitespace: ClassVar[bool] = True + min_length: ClassVar[int] = 12 + max_length: ClassVar[int] = 19 + bin: str + last4: str + brand: PaymentCardBrand + + def __init__(self, card_number: str): + self.validate_digits(card_number) + + card_number = self.validate_luhn_check_digit(card_number) + + self.bin = card_number[:6] + self.last4 = card_number[-4:] + self.brand = self.validate_brand(card_number) + + @classmethod + def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema: + return core_schema.function_after_schema( + core_schema.str_schema( + min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace + ), + cls.validate, + ) + + @classmethod + def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> 'PaymentCardNumber': + return cls(__input_value) + + @property + def masked(self) -> str: + num_masked = len(self) - 10 # len(bin) + len(last4) == 10 + return f'{self.bin}{"*" * num_masked}{self.last4}' + + @classmethod + def validate_digits(cls, card_number: str) -> None: + if not card_number.isdigit(): + raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits') + + @classmethod + def validate_luhn_check_digit(cls, card_number: str) -> str: + """ + Based on: https://en.wikipedia.org/wiki/Luhn_algorithm + """ + sum_ = int(card_number[-1]) + length = len(card_number) + parity = length % 2 + for i in range(length - 1): + digit = int(card_number[i]) + if i % 2 == parity: + digit *= 2 + if digit > 9: + digit -= 9 + sum_ += digit + valid = sum_ % 10 == 0 + if not valid: + raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid') + return card_number + + @staticmethod + def validate_brand(card_number: str) -> PaymentCardBrand: + """ + Validate length based on BIN for major brands: + https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN) + """ + if card_number[0] == '4': + brand = PaymentCardBrand.visa + elif 51 <= int(card_number[:2]) <= 55: + brand = PaymentCardBrand.mastercard + elif card_number[:2] in {'34', '37'}: + brand = PaymentCardBrand.amex + else: + brand = PaymentCardBrand.other + + required_length: Union[None, int, str] = None + if brand in PaymentCardBrand.mastercard: + required_length = 16 + valid = len(card_number) == required_length + elif brand == PaymentCardBrand.visa: + required_length = '13, 16 or 19' + valid = len(card_number) in {13, 16, 19} + elif brand == PaymentCardBrand.amex: + required_length = 15 + valid = len(card_number) == required_length + else: + valid = True + + if not valid: + raise PydanticCustomError( + 'payment_card_number_brand', + 'Length for a {brand} card must be {required_length}', + {'brand': brand, 'required_length': required_length}, + ) + return brand + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +BYTE_SIZES = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, +} +BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) +byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) + + +class ByteSize(int): + 
@classmethod + def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionPlainSchema: + # TODO better schema + return core_schema.function_plain_schema(cls.validate) + + @classmethod + def validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> 'ByteSize': + try: + return cls(int(__input_value)) + except ValueError: + pass + + str_match = byte_string_re.match(str(__input_value)) + if str_match is None: + raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string') + + scalar, unit = str_match.groups() + if unit is None: + unit = 'b' + + try: + unit_mult = BYTE_SIZES[unit.lower()] + except KeyError: + raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit}) + + return cls(int(float(scalar) * unit_mult)) + + def human_readable(self, decimal: bool = False) -> str: + if decimal: + divisor = 1000 + units = 'B', 'KB', 'MB', 'GB', 'TB', 'PB' + final_unit = 'EB' + else: + divisor = 1024 + units = 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB' + final_unit = 'EiB' + + num = float(self) + for unit in units: + if abs(num) < divisor: + if unit == 'B': + return f'{num:0.0f}{unit}' + else: + return f'{num:0.1f}{unit}' + num /= divisor + + return f'{num:0.1f}{final_unit}' + + def to(self, unit: str) -> float: + try: + unit_div = BYTE_SIZES[unit.lower()] + except KeyError: + raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit}) + + return self / unit_div + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + PastDate = Annotated[date, ...] + FutureDate = Annotated[date, ...] +else: + + class PastDate: + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.CoreSchema: + if schema is None: + # used directly as a type + return core_schema.date_schema(now_op='past') + else: + assert schema['type'] == 'date' + schema['now_op'] = 'past' + return schema + + def __repr__(self) -> str: + return 'PastDate' + + class FutureDate: + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.CoreSchema: + if schema is None: + # used directly as a type + return core_schema.date_schema(now_op='future') + else: + assert schema['type'] == 'date' + schema['now_op'] = 'future' + return schema + + def __repr__(self) -> str: + return 'FutureDate' + + +def condate(*, strict: bool = None, gt: date = None, ge: date = None, lt: date = None, le: date = None) -> type[date]: + return Annotated[ # type: ignore[return-value] + date, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + ] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATETIME TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + AwareDatetime = Annotated[datetime, ...] + NaiveDatetime = Annotated[datetime, ...] 
+else: + + class AwareDatetime: + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.CoreSchema: + if schema is None: + # used directly as a type + return core_schema.datetime_schema(tz_constraint='aware') + else: + assert schema['type'] == 'datetime' + schema['tz_constraint'] = 'aware' + return schema + + def __repr__(self) -> str: + return 'AwareDatetime' + + class NaiveDatetime: + @classmethod + def __get_pydantic_core_schema__( + cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any + ) -> core_schema.CoreSchema: + if schema is None: + # used directly as a type + return core_schema.datetime_schema(tz_constraint='naive') + else: + assert schema['type'] == 'datetime' + schema['tz_constraint'] = 'naive' + return schema + + def __repr__(self) -> str: + return 'NaiveDatetime' diff --git a/crates/ruff_benchmark/resources/pypinyin.py b/crates/ruff_benchmark/resources/pypinyin.py new file mode 100644 index 0000000000..4dfe41f9d7 --- /dev/null +++ b/crates/ruff_benchmark/resources/pypinyin.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +处理汉语拼音方案中的一些特殊情况 + + +汉语拼音方案: + +* https://zh.wiktionary.org/wiki/%E9%99%84%E5%BD%95:%E6%B1%89%E8%AF%AD%E6%8B%BC%E9%9F%B3%E6%96%B9%E6%A1%88 +* http://www.moe.edu.cn/s78/A19/yxs_left/moe_810/s230/195802/t19580201_186000.html +""" # noqa +from __future__ import unicode_literals + +import re + +from pypinyin.style._constants import _FINALS + +# u -> ü +UV_MAP = { + 'u': 'ü', + 'ū': 'ǖ', + 'ú': 'ǘ', + 'ǔ': 'ǚ', + 'ù': 'ǜ', +} +U_TONES = set(UV_MAP.keys()) +# ü行的韵跟声母j,q,x拼的时候,写成ju(居),qu(区),xu(虚) +UV_RE = re.compile( + r'^(j|q|x)({tones})(.*)$'.format( + tones='|'.join(UV_MAP.keys()))) +I_TONES = set(['i', 'ī', 'í', 'ǐ', 'ì']) + +# iu -> iou +IU_MAP = { + 'iu': 'iou', + 'iū': 'ioū', + 'iú': 'ioú', + 'iǔ': 'ioǔ', + 'iù': 'ioù', +} +IU_TONES = set(IU_MAP.keys()) +IU_RE = re.compile(r'^([a-z]+)({tones})$'.format(tones='|'.join(IU_TONES))) + +# ui -> uei +UI_MAP = { + 'ui': 'uei', + 'uī': 'ueī', + 'uí': 'ueí', + 'uǐ': 'ueǐ', + 'uì': 'ueì', +} +UI_TONES = set(UI_MAP.keys()) +UI_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UI_TONES))) + +# un -> uen +UN_MAP = { + 'un': 'uen', + 'ūn': 'ūen', + 'ún': 'úen', + 'ǔn': 'ǔen', + 'ùn': 'ùen', +} +UN_TONES = set(UN_MAP.keys()) +UN_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UN_TONES))) + + +def convert_zero_consonant(pinyin): + """零声母转换,还原原始的韵母 + + i行的韵母,前面没有声母的时候,写成yi(衣),ya(呀),ye(耶),yao(腰), + you(忧),yan(烟),yin(因),yang(央),ying(英),yong(雍)。 + + u行的韵母,前面没有声母的时候,写成wu(乌),wa(蛙),wo(窝),wai(歪), + wei(威),wan(弯),wen(温),wang(汪),weng(翁)。 + + ü行的韵母,前面没有声母的时候,写成yu(迂),yue(约),yuan(冤), + yun(晕);ü上两点省略。 + """ + raw_pinyin = pinyin + # y: yu -> v, yi -> i, y -> i + if raw_pinyin.startswith('y'): + # 去除 y 后的拼音 + no_y_py = pinyin[1:] + first_char = no_y_py[0] if len(no_y_py) > 0 else None + + # yu -> ü: yue -> üe + if first_char in U_TONES: + pinyin = UV_MAP[first_char] + pinyin[2:] + # yi -> i: yi -> i + elif first_char in I_TONES: + pinyin = no_y_py + # y -> i: ya -> ia + else: + pinyin = 'i' + no_y_py + + # w: wu -> u, w -> u + if raw_pinyin.startswith('w'): + # 去除 w 后的拼音 + no_w_py = pinyin[1:] + first_char = no_w_py[0] if len(no_w_py) > 0 else None + + # wu -> u: wu -> u + if first_char in U_TONES: + pinyin = pinyin[1:] + # w -> u: wa -> ua + else: + pinyin = 'u' + pinyin[1:] + + # 确保不会出现韵母表中不存在的韵母 + if pinyin not in _FINALS: + return raw_pinyin + + return pinyin + + +def convert_uv(pinyin): + """ü 
转换,还原原始的韵母 + + ü行的韵跟声母j,q,x拼的时候,写成ju(居),qu(区),xu(虚), + ü上两点也省略;但是跟声母n,l拼的时候,仍然写成nü(女),lü(吕)。 + """ + return UV_RE.sub( + lambda m: ''.join((m.group(1), UV_MAP[m.group(2)], m.group(3))), + pinyin) + + +def convert_iou(pinyin): + """iou 转换,还原原始的韵母 + + iou,uei,uen前面加声母的时候,写成iu,ui,un。 + 例如niu(牛),gui(归),lun(论)。 + """ + return IU_RE.sub(lambda m: m.group(1) + IU_MAP[m.group(2)], pinyin) + + +def convert_uei(pinyin): + """uei 转换,还原原始的韵母 + + iou,uei,uen前面加声母的时候,写成iu,ui,un。 + 例如niu(牛),gui(归),lun(论)。 + """ + return UI_RE.sub(lambda m: m.group(1) + UI_MAP[m.group(2)], pinyin) + + +def convert_uen(pinyin): + """uen 转换,还原原始的韵母 + + iou,uei,uen前面加声母的时候,写成iu,ui,un。 + 例如niu(牛),gui(归),lun(论)。 + """ + return UN_RE.sub(lambda m: m.group(1) + UN_MAP[m.group(2)], pinyin) + + +def convert_finals(pinyin): + """还原原始的韵母""" + pinyin = convert_zero_consonant(pinyin) + pinyin = convert_uv(pinyin) + pinyin = convert_iou(pinyin) + pinyin = convert_uei(pinyin) + pinyin = convert_uen(pinyin) + return pinyin diff --git a/crates/ruff_benchmark/resources/tomllib/__init__.py b/crates/ruff_benchmark/resources/tomllib/__init__.py new file mode 100644 index 0000000000..ef91cb9d25 --- /dev/null +++ b/crates/ruff_benchmark/resources/tomllib/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: MIT +# SPDX-FileCopyrightText: 2021 Taneli Hukkinen +# Licensed to PSF under a Contributor Agreement. + +__all__ = ("loads", "load", "TOMLDecodeError") + +from ._parser import TOMLDecodeError, load, loads + +# Pretend this exception was created here. +TOMLDecodeError.__module__ = __name__ diff --git a/crates/ruff_benchmark/resources/tomllib/_parser.py b/crates/ruff_benchmark/resources/tomllib/_parser.py new file mode 100644 index 0000000000..45ca7a8963 --- /dev/null +++ b/crates/ruff_benchmark/resources/tomllib/_parser.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: MIT +# SPDX-FileCopyrightText: 2021 Taneli Hukkinen +# Licensed to PSF under a Contributor Agreement. + +from __future__ import annotations + +from collections.abc import Iterable +import string +from types import MappingProxyType +from typing import Any, BinaryIO, NamedTuple + +from ._re import ( + RE_DATETIME, + RE_LOCALTIME, + RE_NUMBER, + match_to_datetime, + match_to_localtime, + match_to_number, +) +from ._types import Key, ParseFloat, Pos + +ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127)) + +# Neither of these sets include quotation mark or backslash. They are +# currently handled as separate cases in the parser functions. 
+ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t") +ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n") + +ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS +ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS + +ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS + +TOML_WS = frozenset(" \t") +TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n") +BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_") +KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'") +HEXDIGIT_CHARS = frozenset(string.hexdigits) + +BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType( + { + "\\b": "\u0008", # backspace + "\\t": "\u0009", # tab + "\\n": "\u000A", # linefeed + "\\f": "\u000C", # form feed + "\\r": "\u000D", # carriage return + '\\"': "\u0022", # quote + "\\\\": "\u005C", # backslash + } +) + + +class TOMLDecodeError(ValueError): + """An error raised if a document is not valid TOML.""" + + +def load(fp: BinaryIO, /, *, parse_float: ParseFloat = float) -> dict[str, Any]: + """Parse TOML from a binary file object.""" + b = fp.read() + try: + s = b.decode() + except AttributeError: + raise TypeError( + "File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`" + ) from None + return loads(s, parse_float=parse_float) + + +def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]: # noqa: C901 + """Parse TOML from a string.""" + + # The spec allows converting "\r\n" to "\n", even in string + # literals. Let's do so to simplify parsing. + src = s.replace("\r\n", "\n") + pos = 0 + out = Output(NestedDict(), Flags()) + header: Key = () + parse_float = make_safe_parse_float(parse_float) + + # Parse one statement at a time + # (typically means one line in TOML source) + while True: + # 1. Skip line leading whitespace + pos = skip_chars(src, pos, TOML_WS) + + # 2. Parse rules. Expect one of the following: + # - end of file + # - end of line + # - comment + # - key/value pair + # - append dict to list (and move to its namespace) + # - create dict (and move to its namespace) + # Skip trailing whitespace when applicable. + try: + char = src[pos] + except IndexError: + break + if char == "\n": + pos += 1 + continue + if char in KEY_INITIAL_CHARS: + pos = key_value_rule(src, pos, out, header, parse_float) + pos = skip_chars(src, pos, TOML_WS) + elif char == "[": + try: + second_char: str | None = src[pos + 1] + except IndexError: + second_char = None + out.flags.finalize_pending() + if second_char == "[": + pos, header = create_list_rule(src, pos, out) + else: + pos, header = create_dict_rule(src, pos, out) + pos = skip_chars(src, pos, TOML_WS) + elif char != "#": + raise suffixed_err(src, pos, "Invalid statement") + + # 3. Skip comment + pos = skip_comment(src, pos) + + # 4. Expect end of line or end of file + try: + char = src[pos] + except IndexError: + break + if char != "\n": + raise suffixed_err( + src, pos, "Expected newline or end of document after a statement" + ) + pos += 1 + + return out.data.dict + + +class Flags: + """Flags that map to parsed keys/namespaces.""" + + # Marks an immutable namespace (inline array or inline table). + FROZEN = 0 + # Marks a nest that has been explicitly created and can no longer + # be opened using the "[table]" syntax. 
+ EXPLICIT_NEST = 1 + + def __init__(self) -> None: + self._flags: dict[str, dict] = {} + self._pending_flags: set[tuple[Key, int]] = set() + + def add_pending(self, key: Key, flag: int) -> None: + self._pending_flags.add((key, flag)) + + def finalize_pending(self) -> None: + for key, flag in self._pending_flags: + self.set(key, flag, recursive=False) + self._pending_flags.clear() + + def unset_all(self, key: Key) -> None: + cont = self._flags + for k in key[:-1]: + if k not in cont: + return + cont = cont[k]["nested"] + cont.pop(key[-1], None) + + def set(self, key: Key, flag: int, *, recursive: bool) -> None: # noqa: A003 + cont = self._flags + key_parent, key_stem = key[:-1], key[-1] + for k in key_parent: + if k not in cont: + cont[k] = {"flags": set(), "recursive_flags": set(), "nested": {}} + cont = cont[k]["nested"] + if key_stem not in cont: + cont[key_stem] = {"flags": set(), "recursive_flags": set(), "nested": {}} + cont[key_stem]["recursive_flags" if recursive else "flags"].add(flag) + + def is_(self, key: Key, flag: int) -> bool: + if not key: + return False # document root has no flags + cont = self._flags + for k in key[:-1]: + if k not in cont: + return False + inner_cont = cont[k] + if flag in inner_cont["recursive_flags"]: + return True + cont = inner_cont["nested"] + key_stem = key[-1] + if key_stem in cont: + cont = cont[key_stem] + return flag in cont["flags"] or flag in cont["recursive_flags"] + return False + + +class NestedDict: + def __init__(self) -> None: + # The parsed content of the TOML document + self.dict: dict[str, Any] = {} + + def get_or_create_nest( + self, + key: Key, + *, + access_lists: bool = True, + ) -> dict: + cont: Any = self.dict + for k in key: + if k not in cont: + cont[k] = {} + cont = cont[k] + if access_lists and isinstance(cont, list): + cont = cont[-1] + if not isinstance(cont, dict): + raise KeyError("There is no nest behind this key") + return cont + + def append_nest_to_list(self, key: Key) -> None: + cont = self.get_or_create_nest(key[:-1]) + last_key = key[-1] + if last_key in cont: + list_ = cont[last_key] + if not isinstance(list_, list): + raise KeyError("An object other than list found behind this key") + list_.append({}) + else: + cont[last_key] = [{}] + + +class Output(NamedTuple): + data: NestedDict + flags: Flags + + +def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos: + try: + while src[pos] in chars: + pos += 1 + except IndexError: + pass + return pos + + +def skip_until( + src: str, + pos: Pos, + expect: str, + *, + error_on: frozenset[str], + error_on_eof: bool, +) -> Pos: + try: + new_pos = src.index(expect, pos) + except ValueError: + new_pos = len(src) + if error_on_eof: + raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None + + if not error_on.isdisjoint(src[pos:new_pos]): + while src[pos] not in error_on: + pos += 1 + raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}") + return new_pos + + +def skip_comment(src: str, pos: Pos) -> Pos: + try: + char: str | None = src[pos] + except IndexError: + char = None + if char == "#": + return skip_until( + src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False + ) + return pos + + +def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos: + while True: + pos_before_skip = pos + pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE) + pos = skip_comment(src, pos) + if pos == pos_before_skip: + return pos + + +def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]: + pos += 1 # Skip "[" + pos = 
skip_chars(src, pos, TOML_WS) + pos, key = parse_key(src, pos) + + if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN): + raise suffixed_err(src, pos, f"Cannot declare {key} twice") + out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False) + try: + out.data.get_or_create_nest(key) + except KeyError: + raise suffixed_err(src, pos, "Cannot overwrite a value") from None + + if not src.startswith("]", pos): + raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration") + return pos + 1, key + + +def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]: + pos += 2 # Skip "[[" + pos = skip_chars(src, pos, TOML_WS) + pos, key = parse_key(src, pos) + + if out.flags.is_(key, Flags.FROZEN): + raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}") + # Free the namespace now that it points to another empty list item... + out.flags.unset_all(key) + # ...but this key precisely is still prohibited from table declaration + out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False) + try: + out.data.append_nest_to_list(key) + except KeyError: + raise suffixed_err(src, pos, "Cannot overwrite a value") from None + + if not src.startswith("]]", pos): + raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration") + return pos + 2, key + + +def key_value_rule( + src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat +) -> Pos: + pos, key, value = parse_key_value_pair(src, pos, parse_float) + key_parent, key_stem = key[:-1], key[-1] + abs_key_parent = header + key_parent + + relative_path_cont_keys = (header + key[:i] for i in range(1, len(key))) + for cont_key in relative_path_cont_keys: + # Check that dotted key syntax does not redefine an existing table + if out.flags.is_(cont_key, Flags.EXPLICIT_NEST): + raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}") + # Containers in the relative path can't be opened with the table syntax or + # dotted key/value syntax in following table sections. 
+ out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST) + + if out.flags.is_(abs_key_parent, Flags.FROZEN): + raise suffixed_err( + src, pos, f"Cannot mutate immutable namespace {abs_key_parent}" + ) + + try: + nest = out.data.get_or_create_nest(abs_key_parent) + except KeyError: + raise suffixed_err(src, pos, "Cannot overwrite a value") from None + if key_stem in nest: + raise suffixed_err(src, pos, "Cannot overwrite a value") + # Mark inline table and array namespaces recursively immutable + if isinstance(value, (dict, list)): + out.flags.set(header + key, Flags.FROZEN, recursive=True) + nest[key_stem] = value + return pos + + +def parse_key_value_pair( + src: str, pos: Pos, parse_float: ParseFloat +) -> tuple[Pos, Key, Any]: + pos, key = parse_key(src, pos) + try: + char: str | None = src[pos] + except IndexError: + char = None + if char != "=": + raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair") + pos += 1 + pos = skip_chars(src, pos, TOML_WS) + pos, value = parse_value(src, pos, parse_float) + return pos, key, value + + +def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]: + pos, key_part = parse_key_part(src, pos) + key: Key = (key_part,) + pos = skip_chars(src, pos, TOML_WS) + while True: + try: + char: str | None = src[pos] + except IndexError: + char = None + if char != ".": + return pos, key + pos += 1 + pos = skip_chars(src, pos, TOML_WS) + pos, key_part = parse_key_part(src, pos) + key += (key_part,) + pos = skip_chars(src, pos, TOML_WS) + + +def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]: + try: + char: str | None = src[pos] + except IndexError: + char = None + if char in BARE_KEY_CHARS: + start_pos = pos + pos = skip_chars(src, pos, BARE_KEY_CHARS) + return pos, src[start_pos:pos] + if char == "'": + return parse_literal_str(src, pos) + if char == '"': + return parse_one_line_basic_str(src, pos) + raise suffixed_err(src, pos, "Invalid initial character for a key part") + + +def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]: + pos += 1 + return parse_basic_str(src, pos, multiline=False) + + +def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]: + pos += 1 + array: list = [] + + pos = skip_comments_and_array_ws(src, pos) + if src.startswith("]", pos): + return pos + 1, array + while True: + pos, val = parse_value(src, pos, parse_float) + array.append(val) + pos = skip_comments_and_array_ws(src, pos) + + c = src[pos : pos + 1] + if c == "]": + return pos + 1, array + if c != ",": + raise suffixed_err(src, pos, "Unclosed array") + pos += 1 + + pos = skip_comments_and_array_ws(src, pos) + if src.startswith("]", pos): + return pos + 1, array + + +def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]: + pos += 1 + nested_dict = NestedDict() + flags = Flags() + + pos = skip_chars(src, pos, TOML_WS) + if src.startswith("}", pos): + return pos + 1, nested_dict.dict + while True: + pos, key, value = parse_key_value_pair(src, pos, parse_float) + key_parent, key_stem = key[:-1], key[-1] + if flags.is_(key, Flags.FROZEN): + raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}") + try: + nest = nested_dict.get_or_create_nest(key_parent, access_lists=False) + except KeyError: + raise suffixed_err(src, pos, "Cannot overwrite a value") from None + if key_stem in nest: + raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}") + nest[key_stem] = value + pos = skip_chars(src, pos, TOML_WS) + c = src[pos : pos + 1] + if c == "}": + return pos + 1, 
nested_dict.dict + if c != ",": + raise suffixed_err(src, pos, "Unclosed inline table") + if isinstance(value, (dict, list)): + flags.set(key, Flags.FROZEN, recursive=True) + pos += 1 + pos = skip_chars(src, pos, TOML_WS) + + +def parse_basic_str_escape( + src: str, pos: Pos, *, multiline: bool = False +) -> tuple[Pos, str]: + escape_id = src[pos : pos + 2] + pos += 2 + if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}: + # Skip whitespace until next non-whitespace character or end of + # the doc. Error if non-whitespace is found before newline. + if escape_id != "\\\n": + pos = skip_chars(src, pos, TOML_WS) + try: + char = src[pos] + except IndexError: + return pos, "" + if char != "\n": + raise suffixed_err(src, pos, "Unescaped '\\' in a string") + pos += 1 + pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE) + return pos, "" + if escape_id == "\\u": + return parse_hex_char(src, pos, 4) + if escape_id == "\\U": + return parse_hex_char(src, pos, 8) + try: + return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id] + except KeyError: + raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None + + +def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]: + return parse_basic_str_escape(src, pos, multiline=True) + + +def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]: + hex_str = src[pos : pos + hex_len] + if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str): + raise suffixed_err(src, pos, "Invalid hex value") + pos += hex_len + hex_int = int(hex_str, 16) + if not is_unicode_scalar_value(hex_int): + raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value") + return pos, chr(hex_int) + + +def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]: + pos += 1 # Skip starting apostrophe + start_pos = pos + pos = skip_until( + src, pos, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True + ) + return pos + 1, src[start_pos:pos] # Skip ending apostrophe + + +def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]: + pos += 3 + if src.startswith("\n", pos): + pos += 1 + + if literal: + delim = "'" + end_pos = skip_until( + src, + pos, + "'''", + error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS, + error_on_eof=True, + ) + result = src[pos:end_pos] + pos = end_pos + 3 + else: + delim = '"' + pos, result = parse_basic_str(src, pos, multiline=True) + + # Add at maximum two extra apostrophes/quotes if the end sequence + # is 4 or 5 chars long instead of just 3. 
+ if not src.startswith(delim, pos): + return pos, result + pos += 1 + if not src.startswith(delim, pos): + return pos, result + delim + pos += 1 + return pos, result + (delim * 2) + + +def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]: + if multiline: + error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS + parse_escapes = parse_basic_str_escape_multiline + else: + error_on = ILLEGAL_BASIC_STR_CHARS + parse_escapes = parse_basic_str_escape + result = "" + start_pos = pos + while True: + try: + char = src[pos] + except IndexError: + raise suffixed_err(src, pos, "Unterminated string") from None + if char == '"': + if not multiline: + return pos + 1, result + src[start_pos:pos] + if src.startswith('"""', pos): + return pos + 3, result + src[start_pos:pos] + pos += 1 + continue + if char == "\\": + result += src[start_pos:pos] + pos, parsed_escape = parse_escapes(src, pos) + result += parsed_escape + start_pos = pos + continue + if char in error_on: + raise suffixed_err(src, pos, f"Illegal character {char!r}") + pos += 1 + + +def parse_value( # noqa: C901 + src: str, pos: Pos, parse_float: ParseFloat +) -> tuple[Pos, Any]: + try: + char: str | None = src[pos] + except IndexError: + char = None + + # IMPORTANT: order conditions based on speed of checking and likelihood + + # Basic strings + if char == '"': + if src.startswith('"""', pos): + return parse_multiline_str(src, pos, literal=False) + return parse_one_line_basic_str(src, pos) + + # Literal strings + if char == "'": + if src.startswith("'''", pos): + return parse_multiline_str(src, pos, literal=True) + return parse_literal_str(src, pos) + + # Booleans + if char == "t": + if src.startswith("true", pos): + return pos + 4, True + if char == "f": + if src.startswith("false", pos): + return pos + 5, False + + # Arrays + if char == "[": + return parse_array(src, pos, parse_float) + + # Inline tables + if char == "{": + return parse_inline_table(src, pos, parse_float) + + # Dates and times + datetime_match = RE_DATETIME.match(src, pos) + if datetime_match: + try: + datetime_obj = match_to_datetime(datetime_match) + except ValueError as e: + raise suffixed_err(src, pos, "Invalid date or datetime") from e + return datetime_match.end(), datetime_obj + localtime_match = RE_LOCALTIME.match(src, pos) + if localtime_match: + return localtime_match.end(), match_to_localtime(localtime_match) + + # Integers and "normal" floats. + # The regex will greedily match any type starting with a decimal + # char, so needs to be located after handling of dates and times. 
+ number_match = RE_NUMBER.match(src, pos) + if number_match: + return number_match.end(), match_to_number(number_match, parse_float) + + # Special floats + first_three = src[pos : pos + 3] + if first_three in {"inf", "nan"}: + return pos + 3, parse_float(first_three) + first_four = src[pos : pos + 4] + if first_four in {"-inf", "+inf", "-nan", "+nan"}: + return pos + 4, parse_float(first_four) + + raise suffixed_err(src, pos, "Invalid value") + + +def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError: + """Return a `TOMLDecodeError` where error message is suffixed with + coordinates in source.""" + + def coord_repr(src: str, pos: Pos) -> str: + if pos >= len(src): + return "end of document" + line = src.count("\n", 0, pos) + 1 + if line == 1: + column = pos + 1 + else: + column = pos - src.rindex("\n", 0, pos) + return f"line {line}, column {column}" + + return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})") + + +def is_unicode_scalar_value(codepoint: int) -> bool: + return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111) + + +def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat: + """A decorator to make `parse_float` safe. + + `parse_float` must not return dicts or lists, because these types + would be mixed with parsed TOML tables and arrays, thus confusing + the parser. The returned decorated callable raises `ValueError` + instead of returning illegal types. + """ + # The default `float` callable never returns illegal types. Optimize it. + if parse_float is float: # type: ignore[comparison-overlap] + return float + + def safe_parse_float(float_str: str) -> Any: + float_value = parse_float(float_str) + if isinstance(float_value, (dict, list)): + raise ValueError("parse_float must not return dicts or lists") + return float_value + + return safe_parse_float diff --git a/crates/ruff_benchmark/resources/tomllib/_re.py b/crates/ruff_benchmark/resources/tomllib/_re.py new file mode 100644 index 0000000000..994bb7493f --- /dev/null +++ b/crates/ruff_benchmark/resources/tomllib/_re.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: MIT +# SPDX-FileCopyrightText: 2021 Taneli Hukkinen +# Licensed to PSF under a Contributor Agreement. + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, timezone, tzinfo +from functools import lru_cache +import re +from typing import Any + +from ._types import ParseFloat + +# E.g. +# - 00:32:00.999999 +# - 00:32:00 +_TIME_RE_STR = r"([01][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])(?:\.([0-9]{1,6})[0-9]*)?" + +RE_NUMBER = re.compile( + r""" +0 +(?: + x[0-9A-Fa-f](?:_?[0-9A-Fa-f])* # hex + | + b[01](?:_?[01])* # bin + | + o[0-7](?:_?[0-7])* # oct +) +| +[+-]?(?:0|[1-9](?:_?[0-9])*) # dec, integer part +(?P + (?:\.[0-9](?:_?[0-9])*)? # optional fractional part + (?:[eE][+-]?[0-9](?:_?[0-9])*)? # optional exponent part +) +""", + flags=re.VERBOSE, +) +RE_LOCALTIME = re.compile(_TIME_RE_STR) +RE_DATETIME = re.compile( + rf""" +([0-9]{{4}})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01]) # date, e.g. 1988-10-27 +(?: + [Tt ] + {_TIME_RE_STR} + (?:([Zz])|([+-])([01][0-9]|2[0-3]):([0-5][0-9]))? # optional time offset +)? +""", + flags=re.VERBOSE, +) + + +def match_to_datetime(match: re.Match) -> datetime | date: + """Convert a `RE_DATETIME` match to `datetime.datetime` or `datetime.date`. + + Raises ValueError if the match does not correspond to a valid date + or datetime. 
+ """ + ( + year_str, + month_str, + day_str, + hour_str, + minute_str, + sec_str, + micros_str, + zulu_time, + offset_sign_str, + offset_hour_str, + offset_minute_str, + ) = match.groups() + year, month, day = int(year_str), int(month_str), int(day_str) + if hour_str is None: + return date(year, month, day) + hour, minute, sec = int(hour_str), int(minute_str), int(sec_str) + micros = int(micros_str.ljust(6, "0")) if micros_str else 0 + if offset_sign_str: + tz: tzinfo | None = cached_tz( + offset_hour_str, offset_minute_str, offset_sign_str + ) + elif zulu_time: + tz = timezone.utc + else: # local date-time + tz = None + return datetime(year, month, day, hour, minute, sec, micros, tzinfo=tz) + + +@lru_cache(maxsize=None) +def cached_tz(hour_str: str, minute_str: str, sign_str: str) -> timezone: + sign = 1 if sign_str == "+" else -1 + return timezone( + timedelta( + hours=sign * int(hour_str), + minutes=sign * int(minute_str), + ) + ) + + +def match_to_localtime(match: re.Match) -> time: + hour_str, minute_str, sec_str, micros_str = match.groups() + micros = int(micros_str.ljust(6, "0")) if micros_str else 0 + return time(int(hour_str), int(minute_str), int(sec_str), micros) + + +def match_to_number(match: re.Match, parse_float: ParseFloat) -> Any: + if match.group("floatpart"): + return parse_float(match.group()) + return int(match.group(), 0) diff --git a/crates/ruff_benchmark/resources/tomllib/_types.py b/crates/ruff_benchmark/resources/tomllib/_types.py new file mode 100644 index 0000000000..d949412e03 --- /dev/null +++ b/crates/ruff_benchmark/resources/tomllib/_types.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: MIT +# SPDX-FileCopyrightText: 2021 Taneli Hukkinen +# Licensed to PSF under a Contributor Agreement. + +from typing import Any, Callable, Tuple + +# Type annotations +ParseFloat = Callable[[str], Any] +Key = Tuple[str, ...] +Pos = int diff --git a/crates/ruff_benchmark/src/lib.rs b/crates/ruff_benchmark/src/lib.rs index cb6236c29a..3ecde5e8f8 100644 --- a/crates/ruff_benchmark/src/lib.rs +++ b/crates/ruff_benchmark/src/lib.rs @@ -1,10 +1,32 @@ +use std::path::PathBuf; + pub mod criterion; -use std::fmt::{Display, Formatter}; -use std::path::PathBuf; -use std::process::Command; +pub static NUMPY_GLOBALS: TestFile = TestFile::new( + "numpy/globals.py", + include_str!("../resources/numpy/globals.py"), +); -use url::Url; +pub static UNICODE_PYPINYIN: TestFile = TestFile::new( + "unicode/pypinyin.py", + include_str!("../resources/pypinyin.py"), +); + +pub static PYDANTIC_TYPES: TestFile = TestFile::new( + "pydantic/types.py", + include_str!("../resources/pydantic/types.py"), +); + +pub static NUMPY_CTYPESLIB: TestFile = TestFile::new( + "numpy/ctypeslib.py", + include_str!("../resources/numpy/ctypeslib.py"), +); + +// "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py" +pub static LARGE_DATASET: TestFile = TestFile::new( + "large/dataset.py", + include_str!("../resources/large/dataset.py"), +); /// Relative size of a test case. Benchmarks can use it to configure the time for how long a benchmark should run to get stable results. 
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] @@ -26,35 +48,33 @@ pub struct TestCase { } impl TestCase { - pub fn fast(file: TestFile) -> Self { + pub const fn fast(file: TestFile) -> Self { Self { file, speed: TestCaseSpeed::Fast, } } - pub fn normal(file: TestFile) -> Self { + pub const fn normal(file: TestFile) -> Self { Self { file, speed: TestCaseSpeed::Normal, } } - pub fn slow(file: TestFile) -> Self { + pub const fn slow(file: TestFile) -> Self { Self { file, speed: TestCaseSpeed::Slow, } } -} -impl TestCase { pub fn code(&self) -> &str { - &self.file.code + self.file.code } pub fn name(&self) -> &str { - &self.file.name + self.file.name } pub fn speed(&self) -> TestCaseSpeed { @@ -62,119 +82,32 @@ impl TestCase { } pub fn path(&self) -> PathBuf { - TARGET_DIR.join(self.name()) + PathBuf::from(file!()) + .parent() + .unwrap() + .parent() + .unwrap() + .join("resources") + .join(self.name()) } } #[derive(Debug, Clone)] pub struct TestFile { - name: String, - code: String, + name: &'static str, + code: &'static str, } impl TestFile { - pub fn code(&self) -> &str { - &self.code - } - - pub fn name(&self) -> &str { - &self.name - } -} - -static TARGET_DIR: std::sync::LazyLock = std::sync::LazyLock::new(|| { - cargo_target_directory().unwrap_or_else(|| PathBuf::from("target")) -}); - -fn cargo_target_directory() -> Option { - #[derive(serde::Deserialize)] - struct Metadata { - target_directory: PathBuf, - } - - std::env::var_os("CARGO_TARGET_DIR") - .map(PathBuf::from) - .or_else(|| { - let output = Command::new(std::env::var_os("CARGO")?) - .args(["metadata", "--format-version", "1"]) - .output() - .ok()?; - let metadata: Metadata = serde_json::from_slice(&output.stdout).ok()?; - Some(metadata.target_directory) - }) -} - -impl TestFile { - pub fn new(name: String, code: String) -> Self { + pub const fn new(name: &'static str, code: &'static str) -> Self { Self { name, code } } - #[allow(clippy::print_stderr)] - pub fn try_download(name: &str, url: &str) -> Result { - let url = Url::parse(url)?; + pub fn code(&self) -> &str { + self.code + } - let cached_filename = TARGET_DIR.join(name); - - if let Ok(content) = std::fs::read_to_string(&cached_filename) { - Ok(TestFile::new(name.to_string(), content)) - } else { - // File not yet cached, download and cache it in the target directory - let response = ureq::get(url.as_str()).call()?; - - let content = response.into_string()?; - - // SAFETY: There's always the `target` directory - let parent = cached_filename.parent().unwrap(); - if let Err(error) = std::fs::create_dir_all(parent) { - eprintln!("Failed to create the directory for the test case {name}: {error}"); - } else if let Err(error) = std::fs::write(cached_filename, &content) { - eprintln!("Failed to cache test case file downloaded from {url}: {error}"); - } - - Ok(TestFile::new(name.to_string(), content)) - } + pub fn name(&self) -> &str { + self.name } } - -#[derive(Debug)] -pub enum TestFileDownloadError { - UrlParse(url::ParseError), - Request(Box), - Download(std::io::Error), -} - -impl From for TestFileDownloadError { - fn from(value: url::ParseError) -> Self { - Self::UrlParse(value) - } -} - -impl From for TestFileDownloadError { - fn from(value: ureq::Error) -> Self { - Self::Request(Box::new(value)) - } -} - -impl From for TestFileDownloadError { - fn from(value: std::io::Error) -> Self { - Self::Download(value) - } -} - -impl Display for TestFileDownloadError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - 
TestFileDownloadError::UrlParse(inner) => { - write!(f, "Failed to parse url: {inner}") - } - TestFileDownloadError::Request(inner) => { - write!(f, "Failed to download file: {inner}") - } - TestFileDownloadError::Download(inner) => { - write!(f, "Failed to download file: {inner}") - } - } - } -} - -impl std::error::Error for TestFileDownloadError {}
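
For reference, a minimal sketch (not part of the patch) of how the reworked `TestFile`/`TestCase` API reads from a consumer's side once the sources are embedded with `include_str!`. The statics and methods used here (`NUMPY_GLOBALS`, `PYDANTIC_TYPES`, `TestCase::fast`, `code()`, `name()`, `path()`) come from the diff above; the standalone `main` wrapper is purely illustrative.

use ruff_benchmark::{TestCase, NUMPY_GLOBALS, PYDANTIC_TYPES};

fn main() {
    // The vendored sources are embedded at compile time, so constructing a
    // test case needs no network access and no cargo target-dir cache.
    let fast = TestCase::fast(NUMPY_GLOBALS.clone());
    assert_eq!(fast.name(), "numpy/globals.py");
    println!("{}: {} bytes", fast.name(), fast.code().len());

    // `path()` now points back into `crates/ruff_benchmark/resources/`,
    // resolved relative to `src/lib.rs` via `file!()`.
    let slow = TestCase::slow(PYDANTIC_TYPES.clone());
    println!("{}", slow.path().display());
}

Compared to the removed `try_download` path, this drops the `ureq`/`url` dependencies and the cargo target-directory cache, so benchmark inputs are available offline and identical across runs.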
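
A second sketch shows how a benchmark might iterate over the embedded cases. It assumes the upstream `criterion` crate's API; the group name and the per-iteration work are placeholders, not something this change introduces.

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruff_benchmark::{TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS};

fn bench_embedded_cases(c: &mut Criterion) {
    let cases = [
        TestCase::fast(NUMPY_GLOBALS.clone()),
        TestCase::normal(NUMPY_CTYPESLIB.clone()),
        TestCase::slow(LARGE_DATASET.clone()),
    ];
    let mut group = c.benchmark_group("example");
    for case in &cases {
        // Report throughput in bytes of embedded source per iteration.
        group.throughput(Throughput::Bytes(case.code().len() as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(case.name()),
            case,
            // Placeholder body: a real benchmark would lex/parse/lint `case.code()`.
            |b, case| b.iter(|| case.code().lines().count()),
        );
    }
    group.finish();
}

criterion_group!(benches, bench_embedded_cases);
criterion_main!(benches);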