SERVER-111295 Set python as formatter in format_multirun (#41677)

GitOrigin-RevId: fd3c58d1f5a9230a9fb728d2678c8c614c20437f
Zack Winter 2025-09-22 17:19:49 -07:00 committed by MongoDB Bot
parent 471b0c8d7a
commit 5c24a13a7d
84 changed files with 990 additions and 753 deletions

.gitattributes vendored
View File

@ -6,6 +6,7 @@
external rules-lint-ignored=true
**/*.tpl.h rules-lint-ignored=true
**/*.tpl.cpp rules-lint-ignored=true
rpm/*.spec rules-lint-ignored=true
src/mongo/bson/column/bson_column_compressed_data.inl rules-lint-ignored=true
*.idl linguist-language=yaml

View File

@ -17,11 +17,15 @@ bazel_cache = os.path.expanduser(args.bazel_cache)
# the cc_library and cc_binaries in our build. There is not a good way from
# within the build to get all those targets, so we will generate the list via query
# https://sig-product-docs.synopsys.com/bundle/coverity-docs/page/coverity-analysis/topics/building_with_bazel.html#build_with_bazel
cmd = [
cmd = (
[
bazel_executable,
bazel_cache,
"aquery",
] + bazel_cmd_args + [args.bazel_query]
]
+ bazel_cmd_args
+ [args.bazel_query]
)
print(f"Running command: {cmd}")
proc = subprocess.run(
cmd,
@ -33,9 +37,7 @@ proc = subprocess.run(
print(proc.stderr)
targets = set()
with open('coverity_targets.list', 'w') as f:
with open("coverity_targets.list", "w") as f:
for line in proc.stdout.splitlines():
if line.startswith(" Target: "):
f.write(line.split()[-1] + "\n")

View File

@ -36,6 +36,7 @@ format_multirun(
html = "//:prettier",
javascript = "//:prettier",
markdown = "//:prettier",
python = "@aspect_rules_lint//format:ruff",
shell = "@shfmt//:shfmt",
sql = "//:prettier",
starlark = "@buildifier_prebuilt//:buildifier",
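This hunk is the substantive change in the commit: pointing format_multirun's python attribute at aspect_rules_lint's ruff target routes Python files through ruff, and the remaining hunks are the resulting one-time reformat. As a rough sketch of what the formatter does per batch of files (the real invocation uses the Bazel-managed ruff binary; a `ruff` executable on PATH is assumed here):

import subprocess
import sys


def ruff_format(paths, check=False):
    """Format (or verify formatting of) the given Python files with ruff.

    Sketch only: rules_lint drives its own hermetic ruff binary; this assumes
    `ruff` is available on PATH.
    """
    cmd = ["ruff", "format"]
    if check:
        # --check reports files that would be reformatted without modifying them.
        cmd.append("--check")
    return subprocess.run(cmd + list(paths)).returncode


if __name__ == "__main__":
    sys.exit(ruff_format(sys.argv[1:]))

ruff format is a Black-compatible formatter, and the changes in the hunks below (quote normalization, call arguments wrapped one per line, long strings split) are typical of its output.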

View File

@ -212,7 +212,8 @@ def main() -> int:
return (
0
if run_prettier(prettier_path, args.check, files_to_format) and run_rules_lint(
if run_prettier(prettier_path, args.check, files_to_format)
and run_rules_lint(
args.rules_lint_format, args.rules_lint_format_check, args.check, files_to_format
)
else 1

View File

@ -86,7 +86,9 @@ def install(src, install_type):
if exc.strerror == "Invalid argument":
print("Encountered OSError: Invalid argument. Retrying...")
time.sleep(1)
os.link(os.path.join(root, name), os.path.join(dest_dir, name))
os.link(
os.path.join(root, name), os.path.join(dest_dir, name)
)
else:
try:
os.link(src, dst)

View File

@ -19,12 +19,13 @@ import yaml
# --------------------------- YAML helpers ---------------------------
def load_yaml(file_path: str | pathlib.Path) -> Dict[str, Any]:
path = pathlib.Path(file_path)
if not path.exists():
raise SystemExit(f"Error: Config file '{file_path}' not found.")
with open(path, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
def load_yaml(file_path: str | pathlib.Path) -> Dict[str, Any]:
path = pathlib.Path(file_path)
if not path.exists():
raise SystemExit(f"Error: Config file '{file_path}' not found.")
with open(path, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
def split_checks_to_list(value: Any) -> List[str]:
@ -40,8 +41,9 @@ def split_checks_to_list(value: Any) -> List[str]:
return [s for s in parts if s]
def merge_checks_into_config(target_config: Dict[str, Any],
incoming_config: Dict[str, Any]) -> None:
def merge_checks_into_config(
target_config: Dict[str, Any], incoming_config: Dict[str, Any]
) -> None:
"""Append incoming Checks onto target Checks (string-concatenated)."""
accumulated = split_checks_to_list(target_config.get("Checks"))
additions = split_checks_to_list(incoming_config.get("Checks"))
@ -59,8 +61,9 @@ def check_options_list_to_map(value: Any) -> Dict[str, Any]:
return out
def merge_check_options_into_config(target_config: Dict[str, Any],
incoming_config: Dict[str, Any]) -> None:
def merge_check_options_into_config(
target_config: Dict[str, Any], incoming_config: Dict[str, Any]
) -> None:
"""
Merge CheckOptions so later configs override earlier by 'key'.
Stores back as list[{key,value}] sorted by key for determinism.
@ -69,9 +72,7 @@ def merge_check_options_into_config(target_config: Dict[str, Any],
override = check_options_list_to_map(incoming_config.get("CheckOptions"))
if override:
base.update(override) # later wins
target_config["CheckOptions"] = [
{"key": k, "value": v} for k, v in sorted(base.items())
]
target_config["CheckOptions"] = [{"key": k, "value": v} for k, v in sorted(base.items())]
def deep_merge_dicts(base: Any, override: Any) -> Any:
@ -91,57 +92,58 @@ def deep_merge_dicts(base: Any, override: Any) -> Any:
return override
# --------------------------- path helpers ---------------------------
def is_ancestor_directory(ancestor: pathlib.Path, descendant: pathlib.Path) -> bool:
"""
True if 'ancestor' is the same as or a parent of 'descendant'.
Resolution ensures symlinks and relative parts are normalized.
"""
try:
ancestor = ancestor.resolve()
descendant = descendant.resolve()
except FileNotFoundError:
# If either path doesn't exist yet, still resolve purely lexically
ancestor = ancestor.absolute()
descendant = descendant.absolute()
return ancestor == descendant or ancestor in descendant.parents
def filter_and_sort_config_paths(
config_paths: list[str | pathlib.Path],
scope_directory: str | None
) -> list[pathlib.Path]:
"""
Keep only config files whose parent directory is an ancestor
of the provided scope directory.
Sort shallow to deep so deeper configs apply later and override earlier ones.
If scope_directory is None, keep paths in the order given.
"""
config_paths = [pathlib.Path(p) for p in config_paths]
if not scope_directory:
return config_paths
workspace_root = pathlib.Path.cwd().resolve()
scope_abs = (workspace_root / scope_directory).resolve()
selected: list[tuple[int, pathlib.Path]] = []
for cfg in config_paths:
parent_dir = cfg.parent
if is_ancestor_directory(parent_dir, scope_abs):
# Depth is number of path components from root
selected.append((len(parent_dir.parts), cfg.resolve()))
# Sort by depth ascending so root-most files merge first
selected.sort(key=lambda t: t[0])
# --------------------------- path helpers ---------------------------
def is_ancestor_directory(ancestor: pathlib.Path, descendant: pathlib.Path) -> bool:
"""
True if 'ancestor' is the same as or a parent of 'descendant'.
Resolution ensures symlinks and relative parts are normalized.
"""
try:
ancestor = ancestor.resolve()
descendant = descendant.resolve()
except FileNotFoundError:
# If either path doesn't exist yet, still resolve purely lexically
ancestor = ancestor.absolute()
descendant = descendant.absolute()
return ancestor == descendant or ancestor in descendant.parents
def filter_and_sort_config_paths(
config_paths: list[str | pathlib.Path], scope_directory: str | None
) -> list[pathlib.Path]:
"""
Keep only config files whose parent directory is an ancestor
of the provided scope directory.
Sort shallow to deep so deeper configs apply later and override earlier ones.
If scope_directory is None, keep paths in the order given.
"""
config_paths = [pathlib.Path(p) for p in config_paths]
if not scope_directory:
return config_paths
workspace_root = pathlib.Path.cwd().resolve()
scope_abs = (workspace_root / scope_directory).resolve()
selected: list[tuple[int, pathlib.Path]] = []
for cfg in config_paths:
parent_dir = cfg.parent
if is_ancestor_directory(parent_dir, scope_abs):
# Depth is number of path components from root
selected.append((len(parent_dir.parts), cfg.resolve()))
# Sort by depth ascending so root-most files merge first
selected.sort(key=lambda t: t[0])
return [cfg for _, cfg in selected]
# --------------------------- main ---------------------------
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--baseline", required=True, help="Baseline clang-tidy YAML.")
@ -175,7 +177,7 @@ def main() -> None:
# then generic merge:
merged_config = deep_merge_dicts(merged_config, incoming_config)
merged_config["Checks"] = ",".join(split_checks_to_list(merged_config.get("Checks")))
merged_config["Checks"] = ",".join(split_checks_to_list(merged_config.get("Checks")))
output_path = pathlib.Path(args.out)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
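For context on the merge semantics this script implements (Checks concatenated in shallow-to-deep order, CheckOptions overridden by key with later configs winning, sorted for determinism), a small self-contained sketch with invented config values:

def split_checks(value):
    # Mirror of the helper above: normalize a comma-separated Checks string into a list.
    return [s.strip() for s in str(value or "").split(",") if s.strip()]


base = {
    "Checks": "-*,bugprone-*",
    "CheckOptions": [{"key": "readability-function-size.LineThreshold", "value": 200}],
}
override = {
    "Checks": "modernize-*",
    "CheckOptions": [{"key": "readability-function-size.LineThreshold", "value": 100}],
}

merged = dict(base)
# Checks: append, then join back into clang-tidy's comma-separated string form.
merged["Checks"] = ",".join(split_checks(base["Checks"]) + split_checks(override["Checks"]))
# CheckOptions: later config wins per key; sort keys for a deterministic output file.
options = {d["key"]: d["value"] for d in base["CheckOptions"]}
options.update({d["key"]: d["value"] for d in override["CheckOptions"]})
merged["CheckOptions"] = [{"key": k, "value": v} for k, v in sorted(options.items())]

print(merged["Checks"])        # -*,bugprone-*,modernize-*
print(merged["CheckOptions"])  # [{'key': 'readability-function-size.LineThreshold', 'value': 100}]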

View File

@ -48,10 +48,11 @@ def add_evergreen_build_info(args):
add_volatile_arg(args, "--versionId=", "version_id")
add_volatile_arg(args, "--requester=", "requester")
class ResmokeShimContext:
def __init__(self):
self.links = []
def __enter__(self):
# Use the Bazel provided TEST_TMPDIR. Note this must occur after uses of acquire_local_resource
# which relies on a shared temporary directory among all test shards.
@ -86,6 +87,7 @@ class ResmokeShimContext:
p = psutil.Process(pid)
signal_python(new_resmoke_logger(), p.name, pid)
if __name__ == "__main__":
sys.argv[0] = (
"buildscripts/resmoke.py" # Ensure resmoke's local invocation is printed using resmoke.py directly
@ -127,7 +129,9 @@ if __name__ == "__main__":
lock, base_port = acquire_local_resource("port_block")
resmoke_args.append(f"--basePort={base_port}")
resmoke_args.append(f"--archiveDirectory={os.path.join(os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), 'data_archives')}")
resmoke_args.append(
f"--archiveDirectory={os.path.join(os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), 'data_archives')}"
)
if (
os.path.isfile("bazel/resmoke/test_runtimes.json")

View File

@ -45,7 +45,6 @@ def run_pty_command(cmd):
def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
# compiledb ignores command line args so just make a version rc file in any case
write_mongo_variables_bazelrc([])
if persistent_compdb:
@ -215,7 +214,6 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
else:
shutil.copyfile(pathlib.Path("bazel-bin") / ".clang-tidy", clang_tidy_file)
if platform.system() == "Linux":
# TODO: SERVER-110144 optimize this to only generate the extensions source code
# instead of build the extension target entirely.

View File

@ -77,7 +77,7 @@ def write_workstation_bazelrc(args):
filtered_args = args[1:]
if "--" in filtered_args:
filtered_args = filtered_args[:filtered_args.index("--")] + ["--", "(REDACTED)"]
filtered_args = filtered_args[: filtered_args.index("--")] + ["--", "(REDACTED)"]
developer_build = os.environ.get("CI") is None
b64_cmd_line = base64.b64encode(json.dumps(filtered_args).encode()).decode()

View File

@ -88,11 +88,13 @@ def search_for_modules(deps, deps_installed, lockfile_changed=False):
wrapper_debug(f"deps_not_found: {deps_not_found}")
return deps_not_found
def skip_cplusplus_toolchain(args):
if any("no_c++_toolchain" in arg for arg in args):
return True
return False
def install_modules(bazel, args):
need_to_install = False
pwd_hash = hashlib.md5(str(REPO_ROOT).encode()).hexdigest()

View File

@ -10,14 +10,29 @@ from typing import List
REPO_ROOT = pathlib.Path(__file__).parent.parent.parent
sys.path.append(str(REPO_ROOT))
LARGE_FILE_THRESHOLD = 10 * 1024 * 1024 #10MiB
LARGE_FILE_THRESHOLD = 10 * 1024 * 1024 # 10MiB
SUPPORTED_EXTENSIONS = (".cpp", ".c", ".h", ".hpp", ".py", ".js", ".mjs", ".json", ".lock", ".toml", ".defs", ".inl", ".idl")
SUPPORTED_EXTENSIONS = (
".cpp",
".c",
".h",
".hpp",
".py",
".js",
".mjs",
".json",
".lock",
".toml",
".defs",
".inl",
".idl",
)
class LinterFail(Exception):
pass
def create_build_files_in_new_js_dirs() -> None:
base_dirs = ["src/mongo/db/modules/enterprise/jstests", "jstests"]
for base_dir in base_dirs:
@ -56,6 +71,7 @@ def list_files_with_targets(bazel_bin: str) -> List:
).stdout.splitlines()
]
class LintRunner:
def __init__(self, keep_going: bool, bazel_bin: str):
self.keep_going = keep_going
@ -205,6 +221,7 @@ def _get_files_changed_since_fork_point(origin_branch: str = "origin/master") ->
return list(file_set)
def get_parsed_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
@ -242,11 +259,7 @@ def get_parsed_args(args):
default="origin/master",
help="Base branch to compare changes against",
)
parser.add_argument(
"--large-files",
action="store_true",
default=False
)
parser.add_argument("--large-files", action="store_true", default=False)
parser.add_argument(
"--keep-going",
action="store_true",
@ -255,11 +268,13 @@ def get_parsed_args(args):
)
return parser.parse_known_args(args)
def lint_mod(lint_runner: LintRunner):
lint_runner.run_bazel("//modules_poc:mod_mapping", ["--validate-modules"])
#TODO add support for the following steps
#subprocess.run([bazel_bin, "run", "//modules_poc:merge_decls"], check=True)
#subprocess.run([bazel_bin, "run", "//modules_poc:browse", "--", "merged_decls.json", "--parse-only"], check=True)
# TODO add support for the following steps
# subprocess.run([bazel_bin, "run", "//modules_poc:merge_decls"], check=True)
# subprocess.run([bazel_bin, "run", "//modules_poc:browse", "--", "merged_decls.json", "--parse-only"], check=True)
def run_rules_lint(bazel_bin: str, args: List[str]):
parsed_args, args = get_parsed_args(args)
@ -276,10 +291,16 @@ def run_rules_lint(bazel_bin: str, args: List[str]):
files_with_targets = list_files_with_targets(bazel_bin)
lr.list_files_without_targets(files_with_targets, "C++", "cpp", ["src/mongo"])
lr.list_files_without_targets(
files_with_targets, "javascript", "js", ["src/mongo", "jstests"],
files_with_targets,
"javascript",
"js",
["src/mongo", "jstests"],
)
lr.list_files_without_targets(
files_with_targets, "python", "py", ["src/mongo", "buildscripts", "evergreen"],
files_with_targets,
"python",
"py",
["src/mongo", "buildscripts", "evergreen"],
)
lint_all = parsed_args.all or "..." in args or "//..." in args
files_to_lint = [arg for arg in args if not arg.startswith("-")]
@ -309,8 +330,7 @@ def run_rules_lint(bazel_bin: str, args: List[str]):
lr.run_bazel("//buildscripts:quickmongolint", ["lint"])
if lint_all or any(
file.endswith((".cpp", ".c", ".h", ".py", ".idl"))
for file in files_to_lint
file.endswith((".cpp", ".c", ".h", ".py", ".idl")) for file in files_to_lint
):
lr.run_bazel("//buildscripts:errorcodes", ["--quiet"])
@ -323,15 +343,19 @@ def run_rules_lint(bazel_bin: str, args: List[str]):
lr.run_bazel("//buildscripts:poetry_lock_check")
if lint_all or any(file.endswith(".yml") for file in files_to_lint):
lr.run_bazel("buildscripts:validate_evg_project_config", [f"--evg-project-name={parsed_args.lint_yaml_project}", "--evg-auth-config=.evergreen.yml"])
lr.run_bazel(
"buildscripts:validate_evg_project_config",
[
f"--evg-project-name={parsed_args.lint_yaml_project}",
"--evg-auth-config=.evergreen.yml",
],
)
if lint_all or parsed_args.large_files:
lr.run_bazel("buildscripts:large_file_check", ["--exclude", "src/third_party/*"])
else:
lr.simple_file_size_check(files_to_lint)
if lint_all or any(
file.endswith((".cpp", ".c", ".h", ".hpp", ".idl", ".inl", ".defs"))
for file in files_to_lint

View File

@ -21,6 +21,7 @@ class BinAndSourceIncompatible(Exception):
class DuplicateSourceNames(Exception):
pass
def get_buildozer_output(autocomplete_query):
from buildscripts.install_bazel import install_bazel

View File

@ -13,6 +13,7 @@ ARCH_NORMALIZE_MAP = {
"s390x": "s390x",
}
def get_mongo_arch(args):
arch = platform.machine().lower()
if arch in ARCH_NORMALIZE_MAP:
@ -20,14 +21,16 @@ def get_mongo_arch(args):
else:
return arch
def get_mongo_version(args):
proc = subprocess.run(["git", "describe", "--abbrev=0"], capture_output=True, text=True)
return proc.stdout.strip()[1:]
def write_mongo_variables_bazelrc(args):
mongo_version = get_mongo_version(args)
mongo_arch = get_mongo_arch(args)
repo_root = pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
version_file = os.path.join(repo_root, ".bazelrc.mongo_variables")
existing_hash = ""
@ -42,4 +45,4 @@ common --define=MONGO_VERSION={mongo_version}
current_hash = hashlib.md5(bazelrc_contents.encode()).hexdigest()
if existing_hash != current_hash:
with open(version_file, "w", encoding="utf-8") as f:
f.write(bazelrc_contents)
f.write(bazelrc_contents)

View File

@ -23,7 +23,7 @@ def create_tarball(output_filename, file_patterns, exclude_patterns):
else:
for f in found_files:
if os.path.isfile(f) or os.path.islink(f):
included_files.add(f)
included_files.add(f)
except Exception as e:
print(f"Error processing pattern '{pattern}': {e}", file=sys.stderr)
@ -45,14 +45,19 @@ def create_tarball(output_filename, file_patterns, exclude_patterns):
if shutil.which("pigz"):
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as tmp_file:
for file in sorted(list(files_to_add)):
tmp_file.write(file + '\n')
tmp_file.write(file + "\n")
tmp_file.flush()
tar_command = ["tar", "--dereference", "--use-compress-program", "pigz", "-cf", output_filename, "-T", tmp_file.name]
subprocess.run(
tar_command,
check=True,
text=True
)
tar_command = [
"tar",
"--dereference",
"--use-compress-program",
"pigz",
"-cf",
output_filename,
"-T",
tmp_file.name,
]
subprocess.run(tar_command, check=True, text=True)
else:
print("pigz not found. Using serial compression")
with tarfile.open(output_filename, "w:gz", dereference=True) as tar:
@ -74,28 +79,26 @@ if __name__ == "__main__":
)
parser.add_argument(
"-o", "--output",
"-o",
"--output",
required=True,
help="The name of the output tarball file (e.g., archive.tar.gz)."
help="The name of the output tarball file (e.g., archive.tar.gz).",
)
parser.add_argument("--base_dir", default=".", help="Directory to run in.")
parser.add_argument(
"--base_dir",
default=".",
help="Directory to run in."
)
parser.add_argument(
"-e", "--exclude",
action='append',
"-e",
"--exclude",
action="append",
default=[],
help="A file pattern to exclude (e.g., '**/__pycache__/*'). Can be specified multiple times."
help="A file pattern to exclude (e.g., '**/__pycache__/*'). Can be specified multiple times.",
)
parser.add_argument(
"patterns",
nargs='+',
help="One or more file patterns to include. Use quotes around patterns with wildcards."
nargs="+",
help="One or more file patterns to include. Use quotes around patterns with wildcards.",
)
args = parser.parse_args()
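The pattern in this hunk, preferring tar with pigz for multi-core gzip compression and falling back to the stdlib's single-threaded tarfile, reads roughly as the standalone sketch below (output name and file list are placeholders):

import shutil
import subprocess
import tarfile
import tempfile


def make_tarball(output_filename, files):
    """Create a gzip tarball of `files`, using pigz for parallel compression when available."""
    if shutil.which("pigz"):
        # Hand tar a file list via -T and let pigz compress the stream on multiple cores.
        with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as listing:
            for path in sorted(files):
                listing.write(path + "\n")
            listing.flush()
            subprocess.run(
                ["tar", "--dereference", "--use-compress-program", "pigz",
                 "-cf", output_filename, "-T", listing.name],
                check=True,
            )
    else:
        # Single-threaded fallback through the standard library.
        with tarfile.open(output_filename, "w:gz", dereference=True) as tar:
            for path in sorted(files):
                tar.add(path)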

View File

@ -171,6 +171,7 @@ def iter_clang_tidy_files(root: str | Path) -> list[Path]:
continue
return results
def validate_clang_tidy_configs(generate_report, fix):
buildozer = download_buildozer()
@ -193,14 +194,16 @@ def validate_clang_tidy_configs(generate_report, fix):
print(p.stderr)
raise Exception(f"could not parse tidy config targets from '{p.stdout}'")
if tidy_targets == ['']:
if tidy_targets == [""]:
tidy_targets = []
all_targets = []
for tidy_file in tidy_files:
tidy_file_target = "//" + os.path.dirname(os.path.join(mongo_dir, tidy_file)) + ":clang_tidy_config"
tidy_file_target = (
"//" + os.path.dirname(os.path.join(mongo_dir, tidy_file)) + ":clang_tidy_config"
)
all_targets.append(tidy_file_target)
if all_targets != tidy_targets:
msg = f"Incorrect clang tidy config targets: {all_targets} != {tidy_targets}"
print(msg)
@ -210,7 +213,9 @@ def validate_clang_tidy_configs(generate_report, fix):
put_report(report)
if fix:
subprocess.run([buildozer, f"set srcs {' '.join(all_targets)}", "//:clang_tidy_config_files"])
subprocess.run(
[buildozer, f"set srcs {' '.join(all_targets)}", "//:clang_tidy_config_files"]
)
def validate_bazel_groups(generate_report, fix):

View File

@ -18,6 +18,7 @@ from buildscripts.util.read_config import read_config_file
# depends_on is only evaluated on task creation/validation, so all dependencies must exist prior to streams_build_and_publish.
# Streams currently depends on multiple generated test suite tasks, which is why this task must also be generated.
def make_task(compile_variant: str, additional_dependencies: set[str]) -> Task:
commands = [
BuiltInCommand("manifest.load", {}),
@ -26,16 +27,21 @@ def make_task(compile_variant: str, additional_dependencies: set[str]) -> Task:
FunctionCall("set up venv"),
FunctionCall("fetch binaries"),
FunctionCall("extract binaries"),
FunctionCall("set up remote credentials", {
"aws_key_remote": "${repo_aws_key}",
"aws_secret_remote": "${repo_aws_secret}"
}),
BuiltInCommand("ec2.assume_role", {"role_arn": "arn:aws:iam::664315256653:role/mongo-tf-project"}),
BuiltInCommand("subprocess.exec", {
"add_expansions_to_env": True,
"binary": "bash",
"args": ["./src/evergreen/streams_image_push.sh"]
}),
FunctionCall(
"set up remote credentials",
{"aws_key_remote": "${repo_aws_key}", "aws_secret_remote": "${repo_aws_secret}"},
),
BuiltInCommand(
"ec2.assume_role", {"role_arn": "arn:aws:iam::664315256653:role/mongo-tf-project"}
),
BuiltInCommand(
"subprocess.exec",
{
"add_expansions_to_env": True,
"binary": "bash",
"args": ["./src/evergreen/streams_image_push.sh"],
},
),
]
dependencies = {
TaskDependency("archive_dist_test", compile_variant),
@ -46,6 +52,7 @@ def make_task(compile_variant: str, additional_dependencies: set[str]) -> Task:
dependencies.add(TaskDependency(dep))
return Task(f"streams_build_and_publish_{compile_variant}", commands, dependencies)
def main(
expansions_file: Annotated[str, typer.Argument()] = "expansions.yml",
output_file: Annotated[str, typer.Option("--output-file")] = "streams_build_and_publish.json",
@ -69,14 +76,14 @@ def main(
else:
# is not a display task
task_deps.append(task.display_name)
required_tasks.remove(task.display_name)
print(task_deps)
if required_tasks:
print("The following required tasks were not found", required_tasks)
raise RuntimeError("Could not find all required tasks")
distro = expansions.get("distro_id")
compile_variant_name = expansions.get("compile_variant")
current_task_name = expansions.get("task_name", "streams_build_and_publish_gen")

View File

@ -16,22 +16,28 @@ from buildscripts.util.read_config import read_config_file
# This file is for generating the task that creates a docker manifest for the distro images produced via streams_build_and_publish.
# The docker manifest is used in order for the different architecture images to be pulled correctly without needing the particular architecture tag.
def make_task(compile_variant: str) -> Task:
commands = [
BuiltInCommand("manifest.load", {}),
FunctionCall("git get project and add git tag"),
FunctionCall("f_expansions_write"),
FunctionCall("set up venv"),
FunctionCall("set up remote credentials", {
"aws_key_remote": "${repo_aws_key}",
"aws_secret_remote": "${repo_aws_secret}"
}),
BuiltInCommand("ec2.assume_role", {"role_arn": "arn:aws:iam::664315256653:role/mongo-tf-project"}),
BuiltInCommand("subprocess.exec", {
"add_expansions_to_env": True,
"binary": "bash",
"args": ["./src/evergreen/streams_docker_manifest.sh"]
}),
FunctionCall(
"set up remote credentials",
{"aws_key_remote": "${repo_aws_key}", "aws_secret_remote": "${repo_aws_secret}"},
),
BuiltInCommand(
"ec2.assume_role", {"role_arn": "arn:aws:iam::664315256653:role/mongo-tf-project"}
),
BuiltInCommand(
"subprocess.exec",
{
"add_expansions_to_env": True,
"binary": "bash",
"args": ["./src/evergreen/streams_docker_manifest.sh"],
},
),
]
dependencies = {
TaskDependency(f"streams_build_and_publish_{compile_variant.replace('-arm64', '')}"),
@ -50,7 +56,7 @@ def main(
current_task_name = expansions.get("task_name", "streams_publish_manifest_gen")
compile_variant_name = expansions.get("compile_variant")
if (not compile_variant_name.endswith("-arm64")):
if not compile_variant_name.endswith("-arm64"):
raise RuntimeError("This task should only run on the arm64 compile variant")
build_variant = BuildVariant(name=build_variant_name)

View File

@ -34,9 +34,7 @@ def _relink_binaries_with_symbols(failed_tests: List[str]):
bazel_build_flags += " --remote_download_outputs=toplevel"
# Remap //src/mongo/testabc to //src/mongo:testabc
failed_test_labels = [
":".join(test.rsplit("/", 1)) for test in failed_tests
]
failed_test_labels = [":".join(test.rsplit("/", 1)) for test in failed_tests]
relink_command = [
arg for arg in ["bazel", "build", *bazel_build_flags.split(" "), *failed_test_labels] if arg
@ -53,12 +51,17 @@ def _relink_binaries_with_symbols(failed_tests: List[str]):
f.write(repro_test_command)
print(f"Repro command written to .failed_unittest_repro.txt: {repro_test_command}")
def _copy_bins_to_upload(failed_tests: List[str], upload_bin_dir: str, upload_lib_dir: str) -> bool:
success = True
bazel_bin_dir = Path("./bazel-bin/src")
# Search both in the top level remote exec shellscript wrapper output directory, and in the
# binary output directory.
failed_tests += [failed_test.replace("_remote_exec", "") for failed_test in failed_tests if "_remote_exec" in failed_test]
failed_tests += [
failed_test.replace("_remote_exec", "")
for failed_test in failed_tests
if "_remote_exec" in failed_test
]
for failed_test in failed_tests:
full_binary_path = bazel_bin_dir / failed_test
binary_name = failed_test.split(os.sep)[-1]

View File

@ -17,7 +17,6 @@ if not gdb:
def detect_toolchain(progspace):
readelf_bin = os.environ.get("MONGO_GDB_READELF", "/opt/mongodbtoolchain/v5/bin/llvm-readelf")
if not os.path.exists(readelf_bin):
readelf_bin = "readelf"
@ -76,10 +75,10 @@ STDERR:
-----------------
Assuming {toolchain_ver} as a default, this could cause issues with the printers.""")
base_toolchain_dir = os.environ.get("MONGO_GDB_PP_DIR", f"/opt/mongodbtoolchain/{toolchain_ver}/share")
pp = glob.glob(
f"{base_toolchain_dir}/gcc-*/python/libstdcxx/v6/printers.py"
base_toolchain_dir = os.environ.get(
"MONGO_GDB_PP_DIR", f"/opt/mongodbtoolchain/{toolchain_ver}/share"
)
pp = glob.glob(f"{base_toolchain_dir}/gcc-*/python/libstdcxx/v6/printers.py")
printers = pp[0]
return os.path.dirname(os.path.dirname(os.path.dirname(printers)))

View File

@ -25,6 +25,7 @@ from buildscripts.util.fileops import read_yaml_file
assert sys.version_info >= (3, 7)
@dataclass
class Output:
name: str
@ -35,6 +36,7 @@ class Output:
iso_date = ts.strftime("%Y-%m-%d %H:%M:%S")
return f"{self.name} {iso_date}"
class AppError(Exception):
"""Application execution error."""
@ -206,7 +208,7 @@ class GoldenTestApp(object):
if not os.path.isdir(self.output_parent_path):
return []
return [
Output(name = name, ctime = os.path.getctime(self.get_output_path(name)))
Output(name=name, ctime=os.path.getctime(self.get_output_path(name)))
for name in os.listdir(self.output_parent_path)
if re.match(self.output_name_regex, name)
and os.path.isdir(os.path.join(self.output_parent_path, name))
@ -216,12 +218,13 @@ class GoldenTestApp(object):
"""Return the output name wit most recent created timestamp."""
self.vprint("Searching for output with latest creation time")
outputs = self.get_outputs()
if (len(outputs) == 0):
if len(outputs) == 0:
raise AppError("No outputs found")
else:
latest = max(outputs, key=lambda output: output.ctime)
self.vprint(
f"Found output with latest creation time: {latest.name} " + f"created at {latest.ctime}"
f"Found output with latest creation time: {latest.name} "
+ f"created at {latest.ctime}"
)
return latest

View File

@ -1269,6 +1269,7 @@ def _bind_field(ctxt, parsed_spec, field):
ctxt.add_must_be_query_shape_component(ast_field, ast_field.type.name, ast_field.name)
return ast_field
def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct, nested_chained_parent=None):
# type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct, ast.Field) -> None
"""Bind the specified chained struct."""
@ -1291,7 +1292,6 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct, nested_c
ast_struct, ast_struct.name, chained_struct.name
)
# Configure a field for the chained struct.
ast_chained_field = ast.Field(ast_struct.file_name, ast_struct.line, ast_struct.column)
ast_chained_field.name = struct.name
@ -1302,7 +1302,9 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct, nested_c
if struct.chained_structs:
for nested_chained_struct in struct.chained_structs or []:
_bind_chained_struct(ctxt, parsed_spec, ast_struct, nested_chained_struct, ast_chained_field)
_bind_chained_struct(
ctxt, parsed_spec, ast_struct, nested_chained_struct, ast_chained_field
)
if nested_chained_parent:
ast_chained_field.nested_chained_parent = nested_chained_parent
@ -1712,13 +1714,19 @@ def _bind_feature_flags(ctxt, param):
ctxt.add_feature_flag_default_true_missing_version(param)
return None
if (param.enable_on_transitional_fcv_UNSAFE and
"(Enable on transitional FCV):" not in param.description):
if (
param.enable_on_transitional_fcv_UNSAFE
and "(Enable on transitional FCV):" not in param.description
):
ctxt.add_feature_flag_enabled_on_transitional_fcv_missing_safety_explanation(param)
return None
else:
# Feature flags that should not be FCV gated must not have unsupported options.
for option_name in ("version", "enable_on_transitional_fcv_UNSAFE", "fcv_context_unaware"):
for option_name in (
"version",
"enable_on_transitional_fcv_UNSAFE",
"fcv_context_unaware",
):
if getattr(param, option_name):
ctxt.add_feature_flag_fcv_gated_false_has_unsupported_option(param, option_name)
return None

View File

@ -200,7 +200,7 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
# type: () -> str
cpp_type = self.get_cpp_type_name()
deserializer = self._get_enum_deserializer_name()
return f"{cpp_type} {deserializer}(std::int32_t value, const IDLParserContext& ctxt = IDLParserContext(\"{self.get_cpp_type_name()}\"))"
return f'{cpp_type} {deserializer}(std::int32_t value, const IDLParserContext& ctxt = IDLParserContext("{self.get_cpp_type_name()}"))'
def gen_deserializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
@ -264,7 +264,7 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
# type: () -> str
cpp_type = self.get_cpp_type_name()
func = self._get_enum_deserializer_name()
return f"{cpp_type} {func}(StringData value, const IDLParserContext& ctxt = IDLParserContext(\"{cpp_type}\"))"
return f'{cpp_type} {func}(StringData value, const IDLParserContext& ctxt = IDLParserContext("{cpp_type}"))'
def gen_deserializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None

View File

@ -1431,7 +1431,11 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
# Write member variables
for field in struct.fields:
if not field.ignore and not field.chained_struct_field and not field.nested_chained_parent:
if (
not field.ignore
and not field.chained_struct_field
and not field.nested_chained_parent
):
if not (field.type and field.type.internal_only):
self.gen_member(field)
@ -2911,7 +2915,7 @@ class _CppSourceFileWriter(_CppFileWriterBase):
if field.chained_struct_field:
continue
if field.nested_chained_parent:
continue
@ -3238,19 +3242,17 @@ class _CppSourceFileWriter(_CppFileWriterBase):
for alias_no, alias in enumerate(param.deprecated_name):
varname = f"scp_{param_no}_deprecated_alias"
with self.get_initializer_lambda(
f"auto {varname}",
return_type="std::unique_ptr<ServerParameter>",
capture_ref=True,
):
f"auto {varname}",
return_type="std::unique_ptr<ServerParameter>",
capture_ref=True,
):
self._writer.write_line(
f"""\
auto {varname} = std::make_unique<IDLServerParameterDeprecatedAlias>({_encaps(alias)}, scp_{param_no}.get());
{varname}->setIsDeprecated(true);
return std::move({varname});"""
)
self._writer.write_line(
f"registerServerParameter(std::move({varname}));"
)
self._writer.write_line(f"registerServerParameter(std::move({varname}));")
def gen_server_parameters(self, params, header_file_name):
# type: (List[ast.ServerParameter], str) -> None

View File

@ -296,7 +296,7 @@ class _StructTypeInfo(StructTypeInfoBase):
"parseSharingOwnership",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
class_name,
@ -315,7 +315,7 @@ class _StructTypeInfo(StructTypeInfoBase):
"parseOwned",
[
"BSONObj&& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
class_name,
@ -337,7 +337,7 @@ class _StructTypeInfo(StructTypeInfoBase):
"parse",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
class_name,
@ -353,7 +353,7 @@ class _StructTypeInfo(StructTypeInfoBase):
"parseProtected",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
"void",
@ -430,7 +430,7 @@ class _CommandBaseTypeInfo(_StructTypeInfo):
"parse",
[
"const OpMsgRequest& request",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
class_name,
@ -445,7 +445,7 @@ class _CommandBaseTypeInfo(_StructTypeInfo):
"parseProtected",
[
"const OpMsgRequest& request",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
"void",
@ -551,7 +551,7 @@ class _CommandFromType(_CommandBaseTypeInfo):
"parseProtected",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
"void",
@ -633,7 +633,7 @@ class _CommandWithNamespaceTypeInfo(_CommandBaseTypeInfo):
"parseProtected",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
"void",
@ -641,7 +641,9 @@ class _CommandWithNamespaceTypeInfo(_CommandBaseTypeInfo):
def gen_methods(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
indented_writer.write_line("void setNamespace(NamespaceString nss) { _nss = std::move(nss); }")
indented_writer.write_line(
"void setNamespace(NamespaceString nss) { _nss = std::move(nss); }"
)
indented_writer.write_line("const NamespaceString& getNamespace() const { return _nss; }")
if self._struct.non_const_getter:
indented_writer.write_line("NamespaceString& getNamespace() { return _nss; }")
@ -739,7 +741,7 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
"parseProtected",
[
"const BSONObj& bsonObject",
f"const IDLParserContext& ctxt = IDLParserContext(\"{self._struct.name}\")",
f'const IDLParserContext& ctxt = IDLParserContext("{self._struct.name}")',
"DeserializationContext* dctx = nullptr",
],
"void",

View File

@ -855,7 +855,7 @@ class ServerParameter(common.SourceLocation):
self.validator = None # type: Validator
self.on_update = None # type: str
self.is_deprecated = False # type: bool
self.is_deprecated = False # type: bool
super(ServerParameter, self).__init__(file_name, line, column)

View File

@ -2896,7 +2896,7 @@ class TestBinder(testcase.IDLTestcase):
fcv_gated: true
enable_on_transitional_fcv_UNSAFE: true
"""),
idl.errors.ERROR_ID_FEATURE_FLAG_ENABLED_ON_TRANSITIONAL_FCV_MISSING_SAFETY_EXPLANATION
idl.errors.ERROR_ID_FEATURE_FLAG_ENABLED_ON_TRANSITIONAL_FCV_MISSING_SAFETY_EXPLANATION,
)
# if fcv_gated is false, fcv_context_unaware is not allowed

View File

@ -470,7 +470,9 @@ class TestGenerator(testcase.IDLTestcase):
""")
)
expected = dedent("constexpr inline auto kTestServerParameterName = \"testServerParameter\"_sd;")
expected = dedent(
'constexpr inline auto kTestServerParameterName = "testServerParameter"_sd;'
)
self.assertIn(expected, header)
def test_command_view_type_generates_anchor(self) -> None:
@ -1047,18 +1049,25 @@ class TestGenerator(testcase.IDLTestcase):
"""
)
)
self.assertStringsInFile(header, [
"mongo::NestedChainedBase& getNestedChainedBase() { return getNestedChainedBottom().getNestedChainedBase();",
"void setNestedChainedBase(mongo::NestedChainedBase value) {\n getNestedChainedBottom().setNestedChainedBase(std::move(value));",
"void setBase_field(std::int32_t value) {\n getNestedChainedBase().setBase_field(std::move(value));",
"mongo::NestedChainedBottom& getNestedChainedBottom() { return getNestedChainedMiddle().getNestedChainedBottom();",
"void setNestedChainedBottom(mongo::NestedChainedBottom value) {\n getNestedChainedMiddle().setNestedChainedBottom(std::move(value));",
])
self.assertStringsInFile(source, ["getNestedChainedBase().setBase_field(element._numberInt());",
"getNestedChainedBottom().setBottom_field(element._numberInt());",
"getNestedChainedMiddle().setMiddle_field(element.str());",
"_top_field = element.boolean();",
])
self.assertStringsInFile(
header,
[
"mongo::NestedChainedBase& getNestedChainedBase() { return getNestedChainedBottom().getNestedChainedBase();",
"void setNestedChainedBase(mongo::NestedChainedBase value) {\n getNestedChainedBottom().setNestedChainedBase(std::move(value));",
"void setBase_field(std::int32_t value) {\n getNestedChainedBase().setBase_field(std::move(value));",
"mongo::NestedChainedBottom& getNestedChainedBottom() { return getNestedChainedMiddle().getNestedChainedBottom();",
"void setNestedChainedBottom(mongo::NestedChainedBottom value) {\n getNestedChainedMiddle().setNestedChainedBottom(std::move(value));",
],
)
self.assertStringsInFile(
source,
[
"getNestedChainedBase().setBase_field(element._numberInt());",
"getNestedChainedBottom().setBottom_field(element._numberInt());",
"getNestedChainedMiddle().setMiddle_field(element.str());",
"_top_field = element.boolean();",
],
)
header, source = self.assert_generate_with_basic_types(
dedent(

View File

@ -105,6 +105,7 @@ def validate_help(exe_path):
print(f"Error while calling help for {exe_path}: {e}")
sys.exit(1)
# Make sure we have a proper git version in the windows release
def validate_version(exe_path):
try:
@ -124,6 +125,7 @@ def validate_version(exe_path):
print(f"Error while calling version for {exe_path}: {e}")
sys.exit(1)
def main():
if len(sys.argv) != 2:
print("Usage: python msi_validation.py <path_to_msi>")

View File

@ -533,12 +533,16 @@ def get_edition_alias(edition_name: str) -> str:
return "org"
return edition_name
def validate_top_level_directory(tar_name: str):
command = f"tar -tf {tar_name} | head -n 1 | awk -F/ '{{print $1}}'"
proc = subprocess.run(command, capture_output=True, shell=True, text=True)
top_level_directory = proc.stdout.strip()
if all(os_arch not in top_level_directory for os_arch in VALID_TAR_DIRECTORY_ARCHITECTURES):
raise Exception(f"Found an unexpected os-arch pairing as the top level directory. Top level directory: {top_level_directory}")
raise Exception(
f"Found an unexpected os-arch pairing as the top level directory. Top level directory: {top_level_directory}"
)
def validate_enterprise(sources_text, edition, binfile):
if edition != "enterprise" and edition != "atlas":
@ -548,6 +552,7 @@ def validate_enterprise(sources_text, edition, binfile):
if "src/mongo/db/modules/enterprise" not in sources_text:
raise Exception(f"Failed to find enterprise code in {edition} binary {binfile}.")
def validate_atlas(sources_text, edition, binfile):
if edition != "atlas":
if "/modules/atlas/" in sources_text:
@ -556,6 +561,7 @@ def validate_atlas(sources_text, edition, binfile):
if "/modules/enterprise/" not in sources_text:
raise Exception(f"Failed to find atlas code in {edition} binary {binfile}.")
arches: Set[str] = set()
oses: Set[str] = set()
editions: Set[str] = set()
@ -743,7 +749,7 @@ if args.command == "branch":
)
output_text = p.stdout + p.stderr
logging.info(output_text)
validate_enterprise(output_text, args.edition, binfile)
validate_atlas(output_text, args.edition, binfile)

View File

@ -50,6 +50,7 @@ DISTROS = ["suse", "debian", "redhat", "ubuntu", "amazon", "amazon2", "amazon202
unexpected_lts_release_series = ("8.2",)
def get_suffix(version, stable_name: str, unstable_name: str) -> str:
parts = version.split(".")
@ -59,11 +60,12 @@ def get_suffix(version, stable_name: str, unstable_name: str) -> str:
series = f"{major}.{minor}"
if major >= 5:
is_stable_version = (minor == 0 or series in unexpected_lts_release_series)
is_stable_version = minor == 0 or series in unexpected_lts_release_series
return stable_name if is_stable_version else unstable_name
else:
return stable_name if minor % 2 == 0 else unstable_name
class Spec(object):
"""Spec class."""

View File

@ -22,6 +22,7 @@ OWNER_NAME = "10gen"
REPO_NAME = "mongo"
PROFILE_DATA_FILE_PATH = "bazel/repository_rules/profiling_data.bzl"
def get_mongo_repository(app_id, private_key):
"""
Gets the mongo github repository
@ -31,6 +32,7 @@ def get_mongo_repository(app_id, private_key):
g = installation.get_github_for_installation()
return g.get_repo(f"{OWNER_NAME}/{REPO_NAME}")
def compute_sha256(file_path: str) -> str:
"""
Compute the sha256 hash of a file
@ -41,11 +43,12 @@ def compute_sha256(file_path: str) -> str:
sha256.update(block)
return sha256.hexdigest()
def download_file(url: str, output_location: str) -> bool:
"""
Download a file to a specific output_location and return whether the file existed remotely
"""
try:
try:
response = requests.get(url)
response.raise_for_status()
with open(output_location, "wb") as file:
@ -54,6 +57,7 @@ def download_file(url: str, output_location: str) -> bool:
except requests.exceptions.RequestException:
return False
def replace_quoted_text_in_tagged_line(text: str, tag: str, new_text: str) -> str:
"""
Replace the text between quotes in a line that starts with a specific tag
@ -65,6 +69,7 @@ def replace_quoted_text_in_tagged_line(text: str, tag: str, new_text: str) -> st
pattern = rf'({tag}.*?"(.*?)")'
return re.sub(pattern, lambda match: match.group(0).replace(match.group(2), new_text), text)
def update_bolt_info(file_content: str, new_url: str, new_checksum: str) -> str:
"""
Updates the bolt url and checksum lines in a file
@ -74,6 +79,7 @@ def update_bolt_info(file_content: str, new_url: str, new_checksum: str) -> str:
updated_text = replace_quoted_text_in_tagged_line(file_content, bolt_url_tag, new_url)
return replace_quoted_text_in_tagged_line(updated_text, bolt_checksum_tag, new_checksum)
def update_clang_pgo_info(file_content: str, new_url: str, new_checksum: str) -> str:
"""
Updates the clang pgo url and checksum lines in a file
@ -83,6 +89,7 @@ def update_clang_pgo_info(file_content: str, new_url: str, new_checksum: str) ->
updated_text = replace_quoted_text_in_tagged_line(file_content, clang_pgo_url_tag, new_url)
return replace_quoted_text_in_tagged_line(updated_text, clang_pgo_checksum_tag, new_checksum)
def update_gcc_pgo_info(file_content: str, new_url: str, new_checksum: str) -> str:
"""
Updates the gcc pgo url and checksum lines in a file
@ -92,6 +99,7 @@ def update_gcc_pgo_info(file_content: str, new_url: str, new_checksum: str) -> s
updated_text = replace_quoted_text_in_tagged_line(file_content, gcc_pgo_url_tag, new_url)
return replace_quoted_text_in_tagged_line(updated_text, gcc_pgo_checksum_tag, new_checksum)
def create_pr(target_branch: str, new_branch: str, original_file, new_content: str):
"""
Opens up a pr for a single file with new contents
@ -105,10 +113,22 @@ def create_pr(target_branch: str, new_branch: str, original_file, new_content: s
print(f"Branch doesn't exist, creating branch {new_branch}.")
repo.create_git_ref(ref=ref, sha=target_repo_branch.commit.sha)
else:
raise
raise
repo.update_file(
path=PROFILE_DATA_FILE_PATH,
content=new_content,
branch=new_branch,
message="Updating profile files.",
sha=original_file.sha,
)
repo.create_pull(
base=target_branch,
head=new_branch,
title="SERVER-110427 Update profiling data",
body="Automated PR updating the profiling data.",
)
repo.update_file(path=PROFILE_DATA_FILE_PATH, content=new_content, branch=new_branch, message="Updating profile files.", sha=original_file.sha)
repo.create_pull(base=target_branch, head=new_branch, title="SERVER-110427 Update profiling data", body="Automated PR updating the profiling data.")
def create_profile_data_pr(repo, args, target_branch, new_branch):
"""
@ -129,27 +149,42 @@ def create_profile_data_pr(repo, args, target_branch, new_branch):
sys.exit(0)
if clang_pgo_file_exists and gcc_pgo_file_exists:
print(f"Both clang and gcc had pgo files that existed. Clang: {args.clang_pgo_url} GCC: {args.gcc_pgo_url}. Only one should be updated at a time. Not creating PR.")
print(
f"Both clang and gcc had pgo files that existed. Clang: {args.clang_pgo_url} GCC: {args.gcc_pgo_url}. Only one should be updated at a time. Not creating PR."
)
sys.exit(1)
if not clang_pgo_file_exists and not gcc_pgo_file_exists:
print(f"Neither clang nor gcc had pgo files that existed at either {args.clang_pgo_url} or {args.gcc_pgo_url}. Not creating PR.")
print(
f"Neither clang nor gcc had pgo files that existed at either {args.clang_pgo_url} or {args.gcc_pgo_url}. Not creating PR."
)
sys.exit(0)
profiling_data_file = repo.get_contents(PROFILE_DATA_FILE_PATH, ref=f"refs/heads/{target_branch}")
profiling_data_file = repo.get_contents(
PROFILE_DATA_FILE_PATH, ref=f"refs/heads/{target_branch}"
)
profiling_data_file_content = profiling_data_file.decoded_content.decode()
profiling_file_updated_text = update_bolt_info(profiling_data_file_content, args.bolt_url, compute_sha256(bolt_file))
profiling_file_updated_text = update_bolt_info(
profiling_data_file_content, args.bolt_url, compute_sha256(bolt_file)
)
if clang_pgo_file_exists:
profiling_file_updated_text = update_clang_pgo_info(profiling_file_updated_text, args.clang_pgo_url, compute_sha256(clang_pgo_file))
profiling_file_updated_text = update_clang_pgo_info(
profiling_file_updated_text, args.clang_pgo_url, compute_sha256(clang_pgo_file)
)
else:
profiling_file_updated_text = update_gcc_pgo_info(profiling_file_updated_text, args.gcc_pgo_url, compute_sha256(gcc_pgo_file))
profiling_file_updated_text = update_gcc_pgo_info(
profiling_file_updated_text, args.gcc_pgo_url, compute_sha256(gcc_pgo_file)
)
create_pr(target_branch, new_branch, profiling_data_file, profiling_file_updated_text)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="This script uses bolt file url, clang pgo file url and gcc pgo file url to create a PR updating the links to these files.")
parser = argparse.ArgumentParser(
description="This script uses bolt file url, clang pgo file url and gcc pgo file url to create a PR updating the links to these files."
)
parser.add_argument("bolt_url", help="URL that BOLT data was uploaded to.")
parser.add_argument("clang_pgo_url", help="URL that clang pgo data was uploaded to.")
parser.add_argument("gcc_pgo_url", help="URL that gcc pgo data was uploaded to.")

View File

@ -31,13 +31,7 @@ def lint(paths: List[str]):
"""Lint specified paths (files or directories) using Pyright."""
if "BUILD_WORKSPACE_DIRECTORY" in os.environ:
subprocess.run(
[
"python",
"-m",
"pyright",
"-p",
"pyproject.toml"
] + paths,
["python", "-m", "pyright", "-p", "pyproject.toml"] + paths,
env=os.environ,
check=True,
cwd=REPO_ROOT,

View File

@ -1,6 +1,5 @@
"""Minimum and maximum dictionary declarations for the different randomized parameters (mongod and mongos)."""
"""
For context and maintenance, see:
https://github.com/10gen/mongo/blob/master/buildscripts/resmokelib/generate_fuzz_config/README.md#adding-new-mongo-parameters

View File

@ -53,9 +53,9 @@ def should_activate_core_analysis_task(task: Task) -> bool:
# Expected format is like dump_mongod.429814.core or dump_mongod-8.2.429814.core, where 429814 is the PID.
assert len(core_file_parts) >= 3, "Unknown core dump file name format"
assert str.isdigit(core_file_parts[-2]), (
"PID not in expected location of core dump file name"
)
assert str.isdigit(
core_file_parts[-2]
), "PID not in expected location of core dump file name"
pid = core_file_parts[-2]
core_dump_pids.add(pid)

View File

@ -188,7 +188,6 @@ class HangAnalyzer(Subcommand):
str(pinfo.pidv),
)
# Dump info of all processes, except python & java.
for pinfo in [pinfo for pinfo in processes if not re.match("^(java|python)", pinfo.name)]:
try:
@ -291,9 +290,7 @@ class HangAnalyzer(Subcommand):
)
except (KeyError, OSError):
# The error from getpass.getuser() when there is no username for a UID.
self.root_logger.warning(
"No username set for the current UID."
)
self.root_logger.warning("No username set for the current UID.")
def _check_enough_free_space(self):
usage_percent = psutil.disk_usage(".").percent

View File

@ -20,7 +20,8 @@ if _IS_WINDOWS:
import win32event
PROCS_TIMEOUT_SECS = 60
TYPICAL_MONGOD_DUMP_SECS = 5 # How long a mongod usually takes to core dump.
TYPICAL_MONGOD_DUMP_SECS = 5 # How long a mongod usually takes to core dump.
def call(args, logger, timeout_seconds=None, pinfo=None, check=True) -> int:
"""Call subprocess on args list."""
@ -159,7 +160,9 @@ def teardown_processes(logger, processes, dump_pids):
else:
logger.info("Killing process %s with pid %d", pinfo.name, pid)
proc.kill()
proc.wait(timeout=TYPICAL_MONGOD_DUMP_SECS) # A zombie or defunct process won't end until it is reaped by its parent.
proc.wait(
timeout=TYPICAL_MONGOD_DUMP_SECS
) # A zombie or defunct process won't end until it is reaped by its parent.
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
# Process has already terminated or will need to be reaped by its parent.
pass
@ -170,7 +173,7 @@ def _await_cores(dump_pids, logger):
start_time = datetime.now()
for pid in dump_pids:
while not os.path.exists(dump_pids[pid]):
time.sleep(TYPICAL_MONGOD_DUMP_SECS)
time.sleep(TYPICAL_MONGOD_DUMP_SECS)
if (datetime.now() - start_time).total_seconds() > PROCS_TIMEOUT_SECS:
logger.error("Timed out while awaiting process.")
return

View File

@ -306,7 +306,7 @@ class _TestList(object):
break
if not in_disabled_module:
new_filtered.append(test)
self._filtered = new_filtered
def match_tag_expression(self, tag_expression, get_tags):

View File

@ -37,7 +37,6 @@ def is_s3_presigned_url(url: str) -> bool:
return "X-Amz-Signature" in qs
def download_from_s3(url):
"""Download file from S3 bucket by a given URL."""

View File

@ -273,7 +273,9 @@ class ExplicitSuiteConfig(SuiteConfigInterface):
"""Populate the named suites by scanning config_dir/suites."""
with cls._name_suites_lock:
if not cls._named_suites:
suites_dirs = [os.path.join(_config.CONFIG_DIR, "suites")] + _config.MODULE_SUITE_DIRS
suites_dirs = [
os.path.join(_config.CONFIG_DIR, "suites")
] + _config.MODULE_SUITE_DIRS
for suites_dir in suites_dirs:
root = os.path.abspath(suites_dir)
if not os.path.exists(root):

View File

@ -227,11 +227,7 @@ class ReplSetBuilder(FixtureBuilder):
if replset.disagg_base_config:
members = []
for idx, node in enumerate(replset.nodes):
member = {
"_id": idx,
"host": node.get_internal_connection_string(),
"priority": 1
}
member = {"_id": idx, "host": node.get_internal_connection_string(), "priority": 1}
members.append(member)
disagg_base_config = {
**replset.disagg_base_config,
@ -240,15 +236,17 @@ class ReplSetBuilder(FixtureBuilder):
"version": 1,
"term": 1,
"members": members,
}
},
}
for node in replset.nodes:
opts = node.get_mongod_options()
opts["set_parameters"]["disaggregatedStorageConfig"] = json.dumps(
disagg_base_config)
disagg_base_config
)
opts["set_parameters"]["disaggregatedStorageEnabled"] = True
opts["set_parameters"]["logComponentVerbosity"] = json.dumps(
{"disaggregatedStorage": 5})
{"disaggregatedStorage": 5}
)
node.set_mongod_options(opts)
if replset.start_initial_sync_node:

View File

@ -124,7 +124,9 @@ class HookTestArchival(object):
test_name, config.EVERGREEN_EXECUTION, self._tests_repeat[test_name]
)
logger.info("Archiving data files for test %s from %s", test_name, input_files)
status, message = self.archive_instance.archive_files(input_files, archive_name, display_name)
status, message = self.archive_instance.archive_files(
input_files, archive_name, display_name
)
if status:
logger.warning("Archive failed for %s: %s", test_name, message)
else:

View File

@ -20,7 +20,9 @@ from buildscripts.util.cedar_report import CedarMetric, CedarTestReport
THRESHOLD_LOCATION = "etc/performance_thresholds.yml"
SEP_BENCHMARKS_PROJECT = "mongodb-mongo-master"
SEP_BENCHMARKS_TASK_NAME = "benchmarks_sep"
GET_TIMESERIES_URL = "https://performance-monitoring-api.corp.mongodb.com/time_series/?summarized_executions=false"
GET_TIMESERIES_URL = (
"https://performance-monitoring-api.corp.mongodb.com/time_series/?summarized_executions=false"
)
MAINLINE_REQUESTERS = frozenset(["git_tag_request", "gitter_request"])
@ -115,19 +117,17 @@ class GenerateAndCheckPerfResults(interface.Hook):
"No variant information was given to resmoke. Please set the --variantName flag to let resmoke know what thresholds to use when checking."
)
return
# For mainline builds, Evergreen does not make the base commit available in the expansions,
# so we retrieve it by looking for the previous commit in the Git log
if _config.EVERGREEN_REQUESTER in MAINLINE_REQUESTERS:
base_commit_hash = subprocess.check_output(
["git", "log", "-1", "--pretty=format:%H", "HEAD~1"],
cwd=".",
text=True
["git", "log", "-1", "--pretty=format:%H", "HEAD~1"], cwd=".", text=True
).strip()
# For patch builds the evergreen revision is set to the base commit
else:
base_commit_hash = _config.EVERGREEN_REVISION
for test_name in benchmark_reports.keys():
variant_thresholds = self.performance_thresholds.get(test_name, None)
if variant_thresholds is None:
@ -274,10 +274,7 @@ class GenerateAndCheckPerfResults(interface.Hook):
project: str,
) -> int:
"""Retrieve the base commit value for a given timeseries for a specific commit hash."""
headers = {
"accept": "application/json",
"Content-Type": "application/json"
}
headers = {"accept": "application/json", "Content-Type": "application/json"}
payload = {
"infos": [
{
@ -286,7 +283,7 @@ class GenerateAndCheckPerfResults(interface.Hook):
"task": task_name,
"test": test_name,
"measurement": measurement,
"args": args
"args": args,
}
]
}
@ -333,6 +330,7 @@ class GenerateAndCheckPerfResults(interface.Hook):
f"No value found for test {test_name}, measurement {measurement} on variant {variant} in project {project}"
)
class CheckPerfResultTestCase(interface.DynamicTestCase):
"""CheckPerfResultTestCase class."""

View File

@ -17,7 +17,8 @@ class ContinuousMaintenance(interface.Hook):
"""Regularly connect to replica sets and send a replSetMaintenance command."""
DESCRIPTION = (
"Continuous maintenance (causes a secondary node to enter maintenance mode at regular" " intervals)"
"Continuous maintenance (causes a secondary node to enter maintenance mode at regular"
" intervals)"
)
IS_BACKGROUND = True
@ -152,8 +153,8 @@ class _MaintenanceThread(threading.Thread):
self.logger.warning("No replica set on which to run maintenances.")
return
try:
while True:
try:
while True:
permitted = self.__lifecycle.wait_for_action_permitted()
if not permitted:
break
@ -177,16 +178,16 @@ class _MaintenanceThread(threading.Thread):
)
client = fixture_interface.build_client(chosen, self._auth_options)
self.logger.info("Putting secondary into maintenance mode...")
self._toggle_maintenance_mode(client, enable=True)
self.logger.info(f"Sleeping for {self._maintenance_interval_secs} seconds...")
self.__lifecycle.wait_for_action_interval(self._maintenance_interval_secs)
self.logger.info("Disabling maintenance mode...")
self._toggle_maintenance_mode(client, enable=False)
self.logger.info(f"Sleeping for {self._maintenance_interval_secs} seconds...")
self.logger.info("Putting secondary into maintenance mode...")
self._toggle_maintenance_mode(client, enable=True)
self.logger.info(f"Sleeping for {self._maintenance_interval_secs} seconds...")
self.__lifecycle.wait_for_action_interval(self._maintenance_interval_secs)
self.logger.info("Disabling maintenance mode...")
self._toggle_maintenance_mode(client, enable=False)
self.logger.info(f"Sleeping for {self._maintenance_interval_secs} seconds...")
self.__lifecycle.wait_for_action_interval(self._maintenance_interval_secs)
except Exception as e:
# Proactively log the exception when it happens so it will be
@ -214,7 +215,6 @@ class _MaintenanceThread(threading.Thread):
self._check_thread()
self._none_maintenance_mode()
# Check that fixtures are still running
for rs_fixture in self._rs_fixtures:
if not rs_fixture.is_running():
@ -239,19 +239,19 @@ class _MaintenanceThread(threading.Thread):
secondaries = rs_fixture.get_secondaries()
for secondary in secondaries:
client = fixture_interface.build_client(secondary, self._auth_options)
self._toggle_maintenance_mode(client, enable=False)
self._toggle_maintenance_mode(client, enable=False)
def _toggle_maintenance_mode(self, client, enable):
"""
Toggles a secondary node into and out of maintenance mode.
Args:
client (MongoClient): A PyMongo client connected to the secondary node.
enable (bool): True to enable maintenance mode, False to disable it.
"""
try:
result = client.admin.command('replSetMaintenance', enable)
self.logger.info(f"Maintenance mode {'enabled' if enable else 'disabled'}: {result}")
def _toggle_maintenance_mode(self, client, enable):
"""
Toggles a secondary node into and out of maintenance mode.
Args:
client (MongoClient): A PyMongo client connected to the secondary node.
enable (bool): True to enable maintenance mode, False to disable it.
"""
try:
result = client.admin.command("replSetMaintenance", enable)
self.logger.info(f"Maintenance mode {'enabled' if enable else 'disabled'}: {result}")
except pymongo.errors.OperationFailure as e:
# Note: this log is expected when we try to disable maintenance mode while the node is not in maintenance mode.
self.logger.info(f"Failed to toggle maintenance mode: {e}")

View File

@ -31,8 +31,12 @@ def validate(mdb, logger, acceptable_err_codes):
if "code" in res and res["code"] in acceptable_err_codes:
# Command not supported on view.
pass
elif ("code" in res and "errmsg" in res and res["code"] == 26 and
"Timeseries buckets collection does not exist" in res["errmsg"]):
elif (
"code" in res
and "errmsg" in res
and res["code"] == 26
and "Timeseries buckets collection does not exist" in res["errmsg"]
):
# TODO(SERVER-109819): Remove this workaround once v9.0 is last LTS
# Validating a timeseries view without a matching buckets collection fails with
# NamespaceNotFound. This can happen with this create+drop interleaving:

View File

@ -270,7 +270,7 @@ class Job(object):
)
self.fixture.setup()
self.fixture.await_ready()
# We are intentionally only checking the individual 'test' status and not calling
# report.wasSuccessful() here. It is possible that a thread running in the background as
# part of a hook has added a failed test case to 'self.report'. Checking the individual

View File

@ -181,7 +181,9 @@ class TestReport(unittest.TestResult):
if test.timed_out.is_set():
test_info.status = "timeout"
test_info.evergreen_status = "timeout"
test_status = "no failures detected" if test_info.status == "pass" else test_info.status
test_status = (
"no failures detected" if test_info.status == "pass" else test_info.status
)
time_taken = test_info.end_time - test_info.start_time
self.job_logger.info(

View File

@ -114,7 +114,7 @@ class Suite(object):
@tests.setter
def tests(self, tests):
self._tests = tests
self._tests = tests
@property
def excluded(self):
@ -157,7 +157,9 @@ class Suite(object):
try:
evg_api = evergreen_conn.get_evergreen_api()
except RuntimeError:
loggers.ROOT_EXECUTOR_LOGGER.warning("Failed to create Evergreen API client. Evergreen test selection will be skipped even if it was enabled.")
loggers.ROOT_EXECUTOR_LOGGER.warning(
"Failed to create Evergreen API client. Evergreen test selection will be skipped even if it was enabled."
)
else:
test_selection_strategy = (
_config.EVERGREEN_TEST_SELECTION_STRATEGY
@ -207,9 +209,9 @@ class Suite(object):
# ensures that the select_tests result is only used if no exceptions or type errors are thrown from it
if select_tests_succeeds_flag and use_select_tests:
evergreen_filtered_tests = result["tests"]
evergreen_excluded_tests = set(evergreen_filtered_tests).symmetric_difference(
set(tests)
)
evergreen_excluded_tests = set(
evergreen_filtered_tests
).symmetric_difference(set(tests))
loggers.ROOT_EXECUTOR_LOGGER.info(
f"Evergreen applied the following test selection strategies: {test_selection_strategy}"
)

View File

@ -170,7 +170,7 @@ class ResmokeSymbolizer:
unsymbolized_content_dict = {}
try:
with open(f, "r") as file:
unsymbolized_content = ','.join([line.rstrip('\n') for line in file])
unsymbolized_content = ",".join([line.rstrip("\n") for line in file])
unsymbolized_content_dict = ast.literal_eval(unsymbolized_content)
except Exception:
test.logger.error(f"Failed to parse stacktrace file {f}", exc_info=1)

View File

@ -54,7 +54,6 @@ class DBTestCase(interface.ProcessTestCase):
interface.append_process_tracking_options(process_kwargs, self._id)
self.dbtest_options["process_kwargs"] = process_kwargs
def _execute(self, process):
interface.ProcessTestCase._execute(self, process)
self._clear_dbpath()

View File

@ -44,7 +44,7 @@ class JSRunnerFileTestCase(interface.ProcessTestCase):
global_vars["TestData"] = test_data
self.shell_options["global_vars"] = global_vars
process_kwargs = copy.deepcopy(self.shell_options.get("process_kwargs", {}))
interface.append_process_tracking_options(process_kwargs, self._id)
self.shell_options["process_kwargs"] = process_kwargs

View File

@ -318,7 +318,7 @@ class MultiClientsTestCase(interface.TestCase):
def _raise_if_unsafe_exit(self, return_code: int):
"""Determine if a return code represents and unsafe exit."""
if self.timed_out.is_set():
# If the test timed out, it is assumed a non-zero exit code is
# If the test timed out, it is assumed a non-zero exit code is
# from the hang-analyzer intentionally killing the process.
return
# 252 and 253 may be returned in failed test executions.

View File

@ -27,7 +27,6 @@ class PrettyPrinterTestCase(interface.ProcessTestCase):
self.program_executable = program_executables[0]
self.program_options = utils.default_if_none(program_options, {}).copy()
interface.append_process_tracking_options(self.program_options, self._id)
def _make_process(self):

View File

@ -22,7 +22,7 @@ class QueryTesterSelfTestCase(interface.ProcessTestCase):
interface.ProcessTestCase.__init__(self, logger, "QueryTesterSelfTest", test_filenames[0])
self.test_file = test_filenames[0]
def _make_process(self):
def _make_process(self):
program_options = {}
interface.append_process_tracking_options(program_options, self._id)
return core.programs.generic_program(

View File

@ -72,8 +72,8 @@ class QueryTesterServerTestCase(interface.ProcessTestCase):
]
if self.override:
command = command + ["--override", self.override]
program_options = {}
interface.append_process_tracking_options(program_options, self._id)
return core.programs.generic_program(self.logger, command, program_options)

View File

@ -34,7 +34,6 @@ class SDAMJsonTestCase(interface.ProcessTestCase):
interface.append_process_tracking_options(self.program_options, self._id)
def _find_executable(self):
binary = os.path.join(config.INSTALL_DIR, "sdam_json_test")
if os.name == "nt":

View File

@ -7,7 +7,7 @@ from buildscripts.resmokelib import config
def test_analysis(logger, pids):
"""
Write the pids out to a file and kill them instead of running analysis.
Write the pids out to a file and kill them instead of running analysis.
This option will only be specified in resmoke selftests.
"""
with open(os.path.join(config.DBPATH_PREFIX, "test_analysis.txt"), "w") as analysis_file:
@ -17,7 +17,9 @@ def test_analysis(logger, pids):
proc = psutil.Process(pid)
logger.info("Killing process pid %d", pid)
proc.kill()
proc.wait(timeout=5) # A zombie or defunct process won't end until it is reaped by its parent.
proc.wait(
timeout=5
) # A zombie or defunct process won't end until it is reaped by its parent.
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
# Process has already terminated or will need to be reaped by its parent.
pass
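A minimal sketch of the kill-then-reap pattern used above; psutil is assumed to be installed and the pid is a hypothetical placeholder.
import psutil

PID = 12345  # hypothetical pid

try:
    proc = psutil.Process(PID)
    proc.kill()
    # A zombie/defunct process is only fully gone once its parent reaps it,
    # so wait() can raise TimeoutExpired.
    proc.wait(timeout=5)
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
    pass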

View File

@ -22,6 +22,7 @@ def read_sha_file(filename):
content = f.read()
return content.strip().split()[0]
def _fetch_remote_sha256_hash(s3_path: str):
downloaded = False
result = None
@ -40,7 +41,7 @@ def _fetch_remote_sha256_hash(s3_path: str):
if downloaded:
result = read_sha_file(tempfile_name)
if tempfile_name and os.path.exists(tempfile_name):
os.unlink(tempfile_name)
@ -62,13 +63,14 @@ def _verify_s3_hash(s3_path: str, local_path: str, expected_hash: str) -> None:
f"Hash mismatch for {s3_path}, expected {expected_hash} but got {hash_string}"
)
def validate_file(s3_path, output_path, remote_sha_allowed):
hexdigest = S3_SHA256_HASHES.get(s3_path)
if hexdigest:
print(f"Validating against hard coded sha256: {hexdigest}")
_verify_s3_hash(s3_path, output_path, hexdigest)
return True
if not remote_sha_allowed:
raise ValueError(f"No SHA256 hash available for {s3_path}")
@ -81,13 +83,13 @@ def validate_file(s3_path, output_path, remote_sha_allowed):
print(f"Validating against remote sha256 {hexdigest}\n({s3_path}.sha256)")
else:
print(f"Failed to download remote sha256 at {s3_path}.sha256)")
if hexdigest:
_verify_s3_hash(s3_path, output_path, hexdigest)
return True
else:
raise ValueError(f"No SHA256 hash available for {s3_path}")
def _download_and_verify(s3_path, output_path, remote_sha_allowed):
for i in range(5):
@ -97,7 +99,7 @@ def _download_and_verify(s3_path, output_path, remote_sha_allowed):
download_from_s3_with_boto(s3_path, output_path)
except Exception:
download_from_s3_with_requests(s3_path, output_path)
validate_file(s3_path, output_path, remote_sha_allowed)
break
@ -153,7 +155,6 @@ def download_s3_binary(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download and verify S3 binary.")
parser.add_argument("s3_path", help="S3 URL to download from")
parser.add_argument("local_path", nargs="?", help="Optional output file path")

View File

@ -12,6 +12,7 @@ def compute_sha256(file_path: str) -> str:
sha256.update(block)
return sha256.hexdigest()
def write_sha256_file(file_path: str, hash_value: str):
sha256_path = file_path + ".sha256"
file_name = os.path.basename(file_path)
@ -19,6 +20,7 @@ def write_sha256_file(file_path: str, hash_value: str):
f.write(f"{hash_value} {file_name}\n")
print(f"Wrote SHA-256 to {sha256_path}")
def main():
if len(sys.argv) != 2:
print("Usage: sha256sum.py <file>")
@ -32,5 +34,6 @@ def main():
hash_value = compute_sha256(file_path)
write_sha256_file(file_path, hash_value)
if __name__ == "__main__":
main()
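A self-contained sketch of the streaming SHA-256 pattern that compute_sha256 and write_sha256_file build on above, using only the standard library; the block size is an arbitrary choice.
import hashlib


def sha256_of(path: str, block_size: int = 1 << 20) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read fixed-size blocks so large files never sit fully in memory.
        for block in iter(lambda: fh.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()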

View File

@ -116,11 +116,12 @@ def get_script_version(
def strip_extra_prefixes(string_with_prefix: str) -> str:
return string_with_prefix.removeprefix("mongo/").removeprefix("v")
def validate_license(component: dict, error_manager: ErrorManager) -> None:
if "licenses" not in component:
error_manager.append_full_error_message(MISSING_LICENSE_IN_SBOM_COMPONENT_ERROR)
return
valid_license = False
for license in component["licenses"]:
if "expression" in license:
@ -132,15 +133,15 @@ def validate_license(component: dict, error_manager: ErrorManager) -> None:
elif "name" in license["license"]:
# If SPDX does not define the license used, the name field may be used to provide the license name
valid_license = True
if not valid_license:
licensing_validate = get_spdx_licensing().validate( expression, validate=True )
licensing_validate = get_spdx_licensing().validate(expression, validate=True)
# ExpressionInfo(
# original_expression='',
# normalized_expression='',
# errors=[],
# invalid_symbols=[]
#)
# )
valid_license = not licensing_validate.errors or not licensing_validate.invalid_symbols
if not valid_license:
error_manager.append_full_error_message(licensing_validate)
@ -179,18 +180,22 @@ def validate_properties(component: dict, error_manager: ErrorManager) -> None:
return
# Include the .pedigree.descendants[0] version for version matching
if "pedigree" in component and "descendants" in component["pedigree"] and "version" in component["pedigree"]["descendants"][0]:
if (
"pedigree" in component
and "descendants" in component["pedigree"]
and "version" in component["pedigree"]["descendants"][0]
):
comp_pedigree_version = component["pedigree"]["descendants"][0]["version"]
else:
comp_pedigree_version = ""
# At this point we attempt to read a version from the import script file
script_version = get_script_version(script_path, "VERSION", error_manager)
if script_version == "":
error_manager.append_full_error_message(MISSING_VERSION_IN_IMPORT_FILE_ERROR + script_path)
elif strip_extra_prefixes(script_version) != strip_extra_prefixes(comp_version) and \
strip_extra_prefixes(script_version) != strip_extra_prefixes(comp_pedigree_version):
elif strip_extra_prefixes(script_version) != strip_extra_prefixes(
comp_version
) and strip_extra_prefixes(script_version) != strip_extra_prefixes(comp_pedigree_version):
error_manager.append_full_error_message(
VERSION_MISMATCH_ERROR
+ f"\nscript version:{script_version}\nsbom component version:{comp_version}\nsbom component pedigree version:{comp_pedigree_version}"
@ -217,7 +222,7 @@ def validate_location(component: dict, third_party_libs: set, error_manager: Err
error_manager.append_full_error_message(
"'evidence.occurrences' field must include at least one location."
)
occurrences = component["evidence"]["occurrences"]
for occurrence in occurrences:
if "location" in occurrence:
@ -230,7 +235,7 @@ def validate_location(component: dict, third_party_libs: set, error_manager: Err
lib = location.removeprefix(THIRD_PARTY_LOCATION_PREFIX)
if lib in third_party_libs:
third_party_libs.remove(lib)
def lint_sbom(
input_file: str, output_file: str, third_party_libs: set, should_format: bool
@ -315,4 +320,4 @@ def main() -> int:
if __name__ == "__main__":
sys.exit(main())
sys.exit(main())
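A minimal sketch of the SPDX validation call used in validate_license above; it assumes the license_expression package is installed, and "MIT OR Apache-2.0" is just an illustrative expression.
from license_expression import get_spdx_licensing

info = get_spdx_licensing().validate("MIT OR Apache-2.0", validate=True)
# A valid expression yields empty errors and invalid_symbols, which is what
# validate_license checks before accepting the component's license.
print(info.errors, info.invalid_symbols)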

View File

@ -21,13 +21,14 @@ from evergreen.api import RetryingEvergreenApi
# this will be populated by the github jwt tokens (1 hour lifetimes)
REDACTED_STRINGS = []
# This is the list of file globs to check for
# This is the list of file globs to check for
# after the dryrun has created the destination output tree
EXCLUDED_PATTERNS = [
EXCLUDED_PATTERNS = [
"src/mongo/db/modules/",
"buildscripts/modules/",
]
class CopybaraRepoConfig(NamedTuple):
"""Copybara source and destination repo sync configuration."""
@ -98,39 +99,38 @@ class CopybaraConfig(NamedTuple):
return self.source is not None and self.destination is not None
def run_command(command):
print(command)
try:
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout
text=True,
bufsize=1
)
output_lines = []
for line in process.stdout:
for redact in filter(None, REDACTED_STRINGS): # avoid None replacements
line = line.replace(redact, "<REDACTED>")
print(line, end="")
output_lines.append(line)
full_output = ''.join(output_lines)
process.wait()
if process.returncode != 0:
# Attach output so except block can read it
raise subprocess.CalledProcessError(
process.returncode, command, output=full_output
)
return full_output
except subprocess.CalledProcessError:
# Let main handle it
raise
def run_command(command):
print(command)
try:
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout
text=True,
bufsize=1,
)
output_lines = []
for line in process.stdout:
for redact in filter(None, REDACTED_STRINGS): # avoid None replacements
line = line.replace(redact, "<REDACTED>")
print(line, end="")
output_lines.append(line)
full_output = "".join(output_lines)
process.wait()
if process.returncode != 0:
# Attach output so except block can read it
raise subprocess.CalledProcessError(process.returncode, command, output=full_output)
return full_output
except subprocess.CalledProcessError:
# Let main handle it
raise
def create_mongodb_bot_gitconfig():
"""Create the mongodb-bot.gitconfig file with the desired content."""
@ -216,6 +216,7 @@ def check_destination_branch_exists(copybara_config: CopybaraConfig) -> bool:
output = run_command(command)
return copybara_config.destination.branch in output
def find_matching_commit(dir_source_repo: str, dir_destination_repo: str) -> Optional[str]:
"""
Finds a matching commit in the destination repository based on the commit hash from the source repository.
@ -323,6 +324,7 @@ def push_branch_to_destination_repo(
f"git push {copybara_config.destination.git_url} {copybara_config.destination.branch}"
)
def handle_failure(expansions, error_message, output_logs):
acceptable_error_messages = [
# Indicates the two repositories are identical
@ -338,6 +340,7 @@ def handle_failure(expansions, error_message, output_logs):
):
send_failure_message_to_slack(expansions, error_message)
def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None:
"""
Create a new branch in the copybara destination repository based on a matching commit found in
@ -397,61 +400,69 @@ def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None:
# Change back to the original directory
os.chdir(original_dir)
def is_current_repo_origin(expected_repo: str) -> bool:
"""Check if the current repo's origin matches 'owner/repo'."""
try:
url = run_command("git config --get remote.origin.url").strip()
except subprocess.CalledProcessError:
return False
m = re.search(r"([^/:]+/[^/:]+)\.git$", url)
return bool(m and m.group(1) == expected_repo)
def sky_file_has_version_id(config_file: str, version_id: str) -> bool:
contents = Path(config_file).read_text()
return str(version_id) in contents
def is_current_repo_origin(expected_repo: str) -> bool:
"""Check if the current repo's origin matches 'owner/repo'."""
try:
url = run_command("git config --get remote.origin.url").strip()
except subprocess.CalledProcessError:
return False
m = re.search(r"([^/:]+/[^/:]+)\.git$", url)
return bool(m and m.group(1) == expected_repo)
def branch_exists_remote(remote_url: str, branch_name: str) -> bool:
"""Return True if branch exists on the remote."""
try:
output = run_command(f"git ls-remote --heads {remote_url} {branch_name}")
return bool(output.strip())
except subprocess.CalledProcessError:
return False
def delete_remote_branch(remote_url: str, branch_name: str):
"""Delete branch from remote if it exists."""
if branch_exists_remote(remote_url, branch_name):
print(f"Deleting remote branch {branch_name} from {remote_url}")
run_command(f"git push {remote_url} --delete {branch_name}")
def push_test_branches(copybara_config, expansions):
"""Push test branch with Evergreen patch changes to source, and clean revision to destination."""
# Safety checks
if copybara_config.source.branch != copybara_config.destination.branch:
print(f"ERROR: test branches must match: source={copybara_config.source.branch} dest={copybara_config.destination.branch}")
sys.exit(1)
if not copybara_config.source.branch.startswith("copybara_test_branch") \
or not copybara_config.destination.branch.startswith("copybara_test_branch"):
print(f"ERROR: can not push non copybara test branch: {copybara_config.source.branch}")
sys.exit(1)
if not is_current_repo_origin("10gen/mongo"):
print("Refusing to push copybara_test_branch to non 10gen/mongo repo")
sys.exit(1)
# First, delete stale remote branches if present
delete_remote_branch(copybara_config.source.git_url, copybara_config.source.branch)
delete_remote_branch(copybara_config.destination.git_url, copybara_config.destination.branch)
# --- Push patched branch to DEST repo (local base Evergreen state) ---
run_command(f"git remote add dest_repo {copybara_config.destination.git_url}")
def sky_file_has_version_id(config_file: str, version_id: str) -> bool:
contents = Path(config_file).read_text()
return str(version_id) in contents
def branch_exists_remote(remote_url: str, branch_name: str) -> bool:
"""Return True if branch exists on the remote."""
try:
output = run_command(f"git ls-remote --heads {remote_url} {branch_name}")
return bool(output.strip())
except subprocess.CalledProcessError:
return False
def delete_remote_branch(remote_url: str, branch_name: str):
"""Delete branch from remote if it exists."""
if branch_exists_remote(remote_url, branch_name):
print(f"Deleting remote branch {branch_name} from {remote_url}")
run_command(f"git push {remote_url} --delete {branch_name}")
def push_test_branches(copybara_config, expansions):
"""Push test branch with Evergreen patch changes to source, and clean revision to destination."""
# Safety checks
if copybara_config.source.branch != copybara_config.destination.branch:
print(
f"ERROR: test branches must match: source={copybara_config.source.branch} dest={copybara_config.destination.branch}"
)
sys.exit(1)
if not copybara_config.source.branch.startswith(
"copybara_test_branch"
) or not copybara_config.destination.branch.startswith("copybara_test_branch"):
print(f"ERROR: can not push non copybara test branch: {copybara_config.source.branch}")
sys.exit(1)
if not is_current_repo_origin("10gen/mongo"):
print("Refusing to push copybara_test_branch to non 10gen/mongo repo")
sys.exit(1)
# First, delete stale remote branches if present
delete_remote_branch(copybara_config.source.git_url, copybara_config.source.branch)
delete_remote_branch(copybara_config.destination.git_url, copybara_config.destination.branch)
# --- Push patched branch to DEST repo (local base Evergreen state) ---
run_command(f"git remote add dest_repo {copybara_config.destination.git_url}")
run_command(f"git checkout -B {copybara_config.destination.branch}")
run_command(f"git push dest_repo {copybara_config.destination.branch}")
# --- Push patched branch to SOURCE repo (local patched Evergreen state) ---
run_command(f'git commit -am "Evergreen patch for version_id {expansions["version_id"]}"')
run_command(f"git remote add source_repo {copybara_config.source.git_url}")
run_command(f"git push source_repo {copybara_config.source.branch}")
# --- Push patched branch to SOURCE repo (local patched Evergreen state) ---
run_command(f'git commit -am "Evergreen patch for version_id {expansions["version_id"]}"')
run_command(f"git remote add source_repo {copybara_config.source.git_url}")
run_command(f"git push source_repo {copybara_config.source.branch}")
def main():
global REDACTED_STRINGS
@ -468,7 +479,7 @@ def main():
parser.add_argument(
"--workflow",
default="test",
choices = ["prod", "test"],
choices=["prod", "test"],
help="The copybara workflow to use (test is a dryrun)",
)
@ -516,7 +527,7 @@ def main():
branch = f"copybara_test_branch_{expansions['version_id']}"
test_branch_str = 'testBranch = "copybara_test_branch"'
elif args.workflow == "prod":
if expansions['is_patch'] == "true":
if expansions["is_patch"] == "true":
print("ERROR: prod workflow should not be run in patch builds!")
sys.exit(1)
test_args = []
@ -525,43 +536,43 @@ def main():
raise Exception(f"invalid workflow {args.workflow}")
# Overwrite repo urls in copybara config in-place
with fileinput.FileInput(config_file, inplace=True) as file:
for line in file:
token = None
# Replace GitHub URL with token-authenticated URL
for repo, value in tokens_map.items():
if repo in line:
token = value
break # no need to check other repos
if token:
print(
line.replace(
"https://github.com",
f"https://x-access-token:{token}@github.com",
),
end="",
)
# Update testBranch in .sky file if running test workflow
elif args.workflow == "test" and test_branch_str in line:
print(
line.replace(
test_branch_str,
test_branch_str[:-1] + f"_{expansions['version_id']}\"\n",
),
end="",
)
else:
print(line, end="")
with fileinput.FileInput(config_file, inplace=True) as file:
for line in file:
token = None
if args.workflow == "test":
if not sky_file_has_version_id(config_file, expansions["version_id"]):
print(
f"Copybara test branch in {config_file} does not contain version_id {expansions['version_id']}"
)
# Replace GitHub URL with token-authenticated URL
for repo, value in tokens_map.items():
if repo in line:
token = value
break # no need to check other repos
if token:
print(
line.replace(
"https://github.com",
f"https://x-access-token:{token}@github.com",
),
end="",
)
# Update testBranch in .sky file if running test workflow
elif args.workflow == "test" and test_branch_str in line:
print(
line.replace(
test_branch_str,
test_branch_str[:-1] + f"_{expansions['version_id']}\"\n",
),
end="",
)
else:
print(line, end="")
if args.workflow == "test":
if not sky_file_has_version_id(config_file, expansions["version_id"]):
print(
f"Copybara test branch in {config_file} does not contain version_id {expansions['version_id']}"
)
sys.exit(1)
copybara_config = CopybaraConfig.from_copybara_sky_file(args.workflow, branch, config_file)
@ -593,45 +604,52 @@ def main():
os.makedirs("tmp_copybara")
docker_cmd = [
"docker", "run", "--rm",
"-v", f"{os.path.expanduser('~/.ssh')}:/root/.ssh",
"-v", f"{os.path.expanduser('~/mongodb-bot.gitconfig')}:/root/.gitconfig",
"-v", f"{config_file}:/usr/src/app/copy.bara.sky",
"-v", f"{os.getcwd()}/tmp_copybara:/tmp/copybara-preview",
docker_cmd = [
"docker",
"run",
"--rm",
"-v",
f"{os.path.expanduser('~/.ssh')}:/root/.ssh",
"-v",
f"{os.path.expanduser('~/mongodb-bot.gitconfig')}:/root/.gitconfig",
"-v",
f"{config_file}:/usr/src/app/copy.bara.sky",
"-v",
f"{os.getcwd()}/tmp_copybara:/tmp/copybara-preview",
"copybara_container",
"migrate", "/usr/src/app/copy.bara.sky", args.workflow,
"-v", "--output-root=/tmp/copybara-preview",
"migrate",
"/usr/src/app/copy.bara.sky",
args.workflow,
"-v",
"--output-root=/tmp/copybara-preview",
]
try:
run_command(" ".join(docker_cmd + ["--dry-run"] + test_args))
found_forbidden = False
preview_dir = Path("tmp_copybara")
for file_path in preview_dir.rglob("*"):
if file_path.is_file():
for pattern in EXCLUDED_PATTERNS:
if pattern in str(file_path):
print(f"ERROR: Found excluded path: {file_path}")
found_forbidden = True
if found_forbidden:
sys.exit(1)
found_forbidden = False
preview_dir = Path("tmp_copybara")
for file_path in preview_dir.rglob("*"):
if file_path.is_file():
for pattern in EXCLUDED_PATTERNS:
if pattern in str(file_path):
print(f"ERROR: Found excluded path: {file_path}")
found_forbidden = True
if found_forbidden:
sys.exit(1)
except subprocess.CalledProcessError as err:
if args.workflow == "prod":
error_message = f"Copybara failed with error: {err.returncode}"
error_message = f"Copybara failed with error: {err.returncode}"
handle_failure(expansions, error_message, err.output)
# dry run successful, time to push
try:
run_command(" ".join(docker_cmd + test_args))
except subprocess.CalledProcessError as err:
if args.workflow == "prod":
error_message = f"Copybara failed with error: {err.returncode}"
try:
run_command(" ".join(docker_cmd + test_args))
except subprocess.CalledProcessError as err:
if args.workflow == "prod":
error_message = f"Copybara failed with error: {err.returncode}"
handle_failure(expansions, error_message, err.output)

View File

@ -518,9 +518,7 @@ class TestVariant(unittest.TestCase):
def test_test_flags(self):
variant_ubuntu = self.conf.get_variant("ubuntu")
self.assertEqual(
"--param=value --ubuntu --modules=none", variant_ubuntu.test_flags
)
self.assertEqual("--param=value --ubuntu --modules=none", variant_ubuntu.test_flags)
variant_osx = self.conf.get_variant("osx-108")
self.assertIsNone(variant_osx.test_flags)
@ -559,9 +557,7 @@ class TestVariant(unittest.TestCase):
# Check combined_suite_to_resmoke_args_map when the task doesn't have resmoke_args.
passing_task = variant_ubuntu.get_task("passing_test")
self.assertEqual(
{
"passing_test": "--suites=passing_test --param=value --ubuntu --modules=none"
},
{"passing_test": "--suites=passing_test --param=value --ubuntu --modules=none"},
passing_task.combined_suite_to_resmoke_args_map,
)

View File

@ -14,10 +14,7 @@ class TestSummarize(unittest.TestCase):
summary = under_test.MonitorBuildStatusOrchestrator._summarize("scope", scope_percentages)
expected_summary = (
f"`SUMMARY [scope]` "
f"{under_test.SummaryMsg.BELOW_THRESHOLDS.value}"
)
expected_summary = f"`SUMMARY [scope]` " f"{under_test.SummaryMsg.BELOW_THRESHOLDS.value}"
self.assertEqual(summary, expected_summary)
@ -31,10 +28,7 @@ class TestSummarize(unittest.TestCase):
summary = under_test.MonitorBuildStatusOrchestrator._summarize("scope", scope_percentages)
expected_summary = (
f"`SUMMARY [scope]` "
f"{under_test.SummaryMsg.BELOW_THRESHOLDS.value}"
)
expected_summary = f"`SUMMARY [scope]` " f"{under_test.SummaryMsg.BELOW_THRESHOLDS.value}"
self.assertEqual(summary, expected_summary)

View File

@ -280,11 +280,14 @@ class TestTestTimeout(_ResmokeSelftest):
report = json.load(f)
timeout = [test for test in report["results"] if test["status"] == "timeout"]
passed = [test for test in report["results"] if test["status"] == "pass"]
self.assertEqual(len(timeout), 1, f"Expected one timed out test. Got {timeout}") # one jstest
self.assertEqual(
len(timeout), 1, f"Expected one timed out test. Got {timeout}"
) # one jstest
self.assertEqual(
len(passed), 3, f"Expected 3 passing tests. Got {passed}"
) # one jstest, one fixture setup, one fixture teardown
class TestTestSelection(_ResmokeSelftest):
def parse_reports_json(self):
with open(self.report_file) as fd:
@ -645,7 +648,7 @@ class TestDiscovery(_ResmokeSelftest):
)
def execute_resmoke(resmoke_args: List[str], subcommand: str="run"):
def execute_resmoke(resmoke_args: List[str], subcommand: str = "run"):
return subprocess.run(
[sys.executable, "buildscripts/resmoke.py", subcommand] + resmoke_args,
text=True,
@ -928,10 +931,11 @@ class TestValidateCollections(unittest.TestCase):
self.assertIn(expected, result.stdout)
self.assertNotEqual(result.returncode, 0)
class TestModules(unittest.TestCase):
def test_files_included(self):
# this suite uses a fixture and hook from the module so it will fail if they are not loaded
# it also uses a
# it also uses a
resmoke_args = [
"--resmokeModulesPath=buildscripts/tests/resmoke_end2end/test_resmoke_modules.yml",
"--suite=resmoke_test_module_worked",
@ -939,7 +943,7 @@ class TestModules(unittest.TestCase):
result = execute_resmoke(resmoke_args)
self.assertEqual(result.returncode, 0)
def test_jstests_excluded(self):
# this first command should not include any of the tests from the module
resmoke_args = [
@ -948,10 +952,10 @@ class TestModules(unittest.TestCase):
"--suite=buildscripts/tests/resmoke_end2end/suites/resmoke_test_module_jstests.yml",
"--dryRun=included-tests",
]
result_without_module = execute_resmoke(resmoke_args)
self.assertEqual(result_without_module.returncode, 0)
# this second invocation should include all of the base jstests and all of the module jstests.
resmoke_args = [
"--resmokeModulesPath=buildscripts/tests/resmoke_end2end/test_resmoke_modules.yml",
@ -959,11 +963,15 @@ class TestModules(unittest.TestCase):
"--suite=buildscripts/tests/resmoke_end2end/suites/resmoke_test_module_jstests.yml",
"--dryRun=included-tests",
]
result_with_module = execute_resmoke(resmoke_args)
self.assertEqual(result_with_module.returncode, 0)
# assert the test is in the list of tests when the module is included
self.assertIn("buildscripts/tests/resmoke_end2end/testfiles/one.js", result_with_module.stdout)
self.assertIn(
"buildscripts/tests/resmoke_end2end/testfiles/one.js", result_with_module.stdout
)
# assert the test is not in the list of tests when the module is excluded
self.assertNotIn("buildscripts/tests/resmoke_end2end/testfiles/one.js", result_without_module.stdout)
self.assertNotIn(
"buildscripts/tests/resmoke_end2end/testfiles/one.js", result_without_module.stdout
)

View File

@ -33,7 +33,6 @@ class TestMochaRunner(unittest.TestCase):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
def test_mocha_runner(self):
resmoke_args = [

View File

@ -102,7 +102,7 @@ class TestSbom(unittest.TestCase):
third_party_libs = {"librdkafka"}
error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False)
self.assert_message_in_errors(error_manager, sbom_linter.VERSION_MISMATCH_ERROR)
def test_pedigree_version_match(self):
test_file = os.path.join(self.input_dir, "sbom_pedigree_version_match.json")
third_party_libs = {"kafka"}
@ -124,7 +124,7 @@ class TestSbom(unittest.TestCase):
self.assert_message_in_errors(
error_manager, sbom_linter.MISSING_VERSION_IN_SBOM_COMPONENT_ERROR
)
def test_missing_license(self):
test_file = os.path.join(self.input_dir, "sbom_missing_license.json")
third_party_libs = {"librdkafka"}
@ -137,10 +137,8 @@ class TestSbom(unittest.TestCase):
test_file = os.path.join(self.input_dir, "sbom_invalid_license_expression.json")
third_party_libs = {"librdkafka"}
error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False)
#print(error_manager.errors)
self.assert_message_in_errors(
error_manager, "ExpressionInfo"
)
# print(error_manager.errors)
self.assert_message_in_errors(error_manager, "ExpressionInfo")
def test_named_license(self):
test_file = os.path.join(self.input_dir, "sbom_named_license.json")
@ -148,4 +146,4 @@ class TestSbom(unittest.TestCase):
error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False)
if not error_manager.zero_error():
error_manager.print_errors()
self.assertTrue(error_manager.zero_error())
self.assertTrue(error_manager.zero_error())

View File

@ -15,94 +15,83 @@ from bazel.merge_tidy_configs import (
)
class TestClangTidyMergeHelpers(unittest.TestCase):
def test_split_checks_to_list_from_str(self):
self.assertEqual(
split_checks_to_list("foo, bar ,baz"),
["foo", "bar", "baz"]
)
def test_split_checks_to_list_from_list(self):
self.assertEqual(
split_checks_to_list(["a, b", "c"]),
["a", "b", "c"]
)
def test_merge_checks_into_config(self):
base = {"Checks": "a,b"}
incoming = {"Checks": "c,d"}
merge_checks_into_config(base, incoming)
self.assertEqual(base["Checks"], "a,b,c,d")
def test_merge_check_options_into_config(self):
base = {"CheckOptions": [{"key": "A", "value": "1"}]}
incoming = {"CheckOptions": [{"key": "B", "value": "2"}]}
merge_check_options_into_config(base, incoming)
self.assertEqual(
base["CheckOptions"],
[{"key": "A", "value": "1"}, {"key": "B", "value": "2"}]
)
def test_merge_check_options_override(self):
base = {"CheckOptions": [{"key": "A", "value": "1"}]}
incoming = {"CheckOptions": [{"key": "A", "value": "2"}]}
merge_check_options_into_config(base, incoming)
self.assertEqual(base["CheckOptions"], [{"key": "A", "value": "2"}])
def test_deep_merge_dicts(self):
base = {"Outer": {"Inner": 1}, "Keep": True}
override = {"Outer": {"Added": 2}, "New": False}
merged = deep_merge_dicts(base, override)
self.assertEqual(
merged,
{"Outer": {"Inner": 1, "Added": 2}, "Keep": True, "New": False}
)
def test_is_ancestor_directory_true(self):
tmpdir = pathlib.Path(tempfile.mkdtemp())
child = tmpdir / "subdir"
child.mkdir()
self.assertTrue(is_ancestor_directory(tmpdir, child))
def test_is_ancestor_directory_false(self):
tmp1 = pathlib.Path(tempfile.mkdtemp())
tmp2 = pathlib.Path(tempfile.mkdtemp())
self.assertFalse(is_ancestor_directory(tmp1, tmp2))
def test_filter_and_sort_config_paths_no_scope(self):
files = ["/tmp/file1", "/tmp/file2"]
res = filter_and_sort_config_paths(files, None)
self.assertEqual([pathlib.Path("/tmp/file1"), pathlib.Path("/tmp/file2")], res)
def test_filter_and_sort_config_paths_with_scope(self):
tmpdir = pathlib.Path(tempfile.mkdtemp())
(tmpdir / "a").mkdir()
cfg_root = tmpdir / "root.yaml"
cfg_child = tmpdir / "a" / "child.yaml"
cfg_root.write_text("Checks: a")
cfg_child.write_text("Checks: b")
old_cwd = pathlib.Path.cwd()
try:
# Simulate repo root being tmpdir
os.chdir(tmpdir)
res = filter_and_sort_config_paths([cfg_root, cfg_child], "a")
finally:
os.chdir(old_cwd)
self.assertEqual([p.name for p in res], ["root.yaml", "child.yaml"])
def test_load_yaml_empty_file(self):
tmpfile = pathlib.Path(tempfile.mktemp())
tmpfile.write_text("")
self.assertEqual(load_yaml(tmpfile), {})
def test_load_yaml_valid_yaml(self):
tmpfile = pathlib.Path(tempfile.mktemp())
yaml.safe_dump({"a": 1}, open(tmpfile, "w"))
self.assertEqual(load_yaml(tmpfile), {"a": 1})
if __name__ == "__main__":
unittest.main()
class TestClangTidyMergeHelpers(unittest.TestCase):
def test_split_checks_to_list_from_str(self):
self.assertEqual(split_checks_to_list("foo, bar ,baz"), ["foo", "bar", "baz"])
def test_split_checks_to_list_from_list(self):
self.assertEqual(split_checks_to_list(["a, b", "c"]), ["a", "b", "c"])
def test_merge_checks_into_config(self):
base = {"Checks": "a,b"}
incoming = {"Checks": "c,d"}
merge_checks_into_config(base, incoming)
self.assertEqual(base["Checks"], "a,b,c,d")
def test_merge_check_options_into_config(self):
base = {"CheckOptions": [{"key": "A", "value": "1"}]}
incoming = {"CheckOptions": [{"key": "B", "value": "2"}]}
merge_check_options_into_config(base, incoming)
self.assertEqual(
base["CheckOptions"], [{"key": "A", "value": "1"}, {"key": "B", "value": "2"}]
)
def test_merge_check_options_override(self):
base = {"CheckOptions": [{"key": "A", "value": "1"}]}
incoming = {"CheckOptions": [{"key": "A", "value": "2"}]}
merge_check_options_into_config(base, incoming)
self.assertEqual(base["CheckOptions"], [{"key": "A", "value": "2"}])
def test_deep_merge_dicts(self):
base = {"Outer": {"Inner": 1}, "Keep": True}
override = {"Outer": {"Added": 2}, "New": False}
merged = deep_merge_dicts(base, override)
self.assertEqual(merged, {"Outer": {"Inner": 1, "Added": 2}, "Keep": True, "New": False})
def test_is_ancestor_directory_true(self):
tmpdir = pathlib.Path(tempfile.mkdtemp())
child = tmpdir / "subdir"
child.mkdir()
self.assertTrue(is_ancestor_directory(tmpdir, child))
def test_is_ancestor_directory_false(self):
tmp1 = pathlib.Path(tempfile.mkdtemp())
tmp2 = pathlib.Path(tempfile.mkdtemp())
self.assertFalse(is_ancestor_directory(tmp1, tmp2))
def test_filter_and_sort_config_paths_no_scope(self):
files = ["/tmp/file1", "/tmp/file2"]
res = filter_and_sort_config_paths(files, None)
self.assertEqual([pathlib.Path("/tmp/file1"), pathlib.Path("/tmp/file2")], res)
def test_filter_and_sort_config_paths_with_scope(self):
tmpdir = pathlib.Path(tempfile.mkdtemp())
(tmpdir / "a").mkdir()
cfg_root = tmpdir / "root.yaml"
cfg_child = tmpdir / "a" / "child.yaml"
cfg_root.write_text("Checks: a")
cfg_child.write_text("Checks: b")
old_cwd = pathlib.Path.cwd()
try:
# Simulate repo root being tmpdir
os.chdir(tmpdir)
res = filter_and_sort_config_paths([cfg_root, cfg_child], "a")
finally:
os.chdir(old_cwd)
self.assertEqual([p.name for p in res], ["root.yaml", "child.yaml"])
def test_load_yaml_empty_file(self):
tmpfile = pathlib.Path(tempfile.mktemp())
tmpfile.write_text("")
self.assertEqual(load_yaml(tmpfile), {})
def test_load_yaml_valid_yaml(self):
tmpfile = pathlib.Path(tempfile.mktemp())
yaml.safe_dump({"a": 1}, open(tmpfile, "w"))
self.assertEqual(load_yaml(tmpfile), {"a": 1})
if __name__ == "__main__":
unittest.main()

View File

@ -60,21 +60,9 @@ class TestPackager(TestCase):
want: str
cases = [
Case(
name="Old unstable",
version="4.3.0",
want="-org-unstable"
),
Case(
name="Old stable 4.2",
version="4.2.0",
want="-org"
),
Case(
name="Old stable 4.4",
version="4.4.0",
want="-org"
),
Case(name="Old unstable", version="4.3.0", want="-org-unstable"),
Case(name="Old stable 4.2", version="4.2.0", want="-org"),
Case(name="Old stable 4.4", version="4.4.0", want="-org"),
Case(
name="New stable standard",
version="8.0.0",

View File

@ -34,6 +34,7 @@ def get_s3_client():
boto3.setup_default_session(botocore_session=botocore_session)
return boto3.client("s3")
def extract_s3_bucket_key(url: str) -> tuple[str, str]:
"""
Extracts the S3 bucket name and object key from an HTTP(s) S3 URL.

View File

@ -138,7 +138,6 @@ def get_non_merge_queue_squashed_commits(
# required fields, but faked out - these aren't helpful in user-facing logs
repo=fake_repo,
binsha=b"00000000000000000000",
)
]

View File

@ -74,7 +74,7 @@ def main(
evg_project_config_map = {
DEFAULT_EVG_NIGHTLY_PROJECT_NAME: DEFAULT_EVG_NIGHTLY_PROJECT_CONFIG,
}
if RELEASE_BRANCH:
for _, project_config in evg_project_config_map.items():
cmd = [

View File

@ -11,6 +11,7 @@ app = typer.Typer(
add_completion=False,
)
def get_changed_files_from_latest_commit(local_repo_path: str, branch_name: str = "master") -> dict:
try:
repo = Repo(local_repo_path)
@ -28,12 +29,8 @@ def get_changed_files_from_latest_commit(local_repo_path: str, branch_name: str
else:
# Comparing the last commit with its parent to find changed files
files = [file.a_path for file in last_commit.diff(last_commit.parents[0])]
return {
"title": title,
"hash": commit_hash,
"files": files
}
return {"title": title, "hash": commit_hash, "files": files}
except Exception as e:
print(f"Error retrieving changed files: {e}")
raise e
@ -48,38 +45,60 @@ def upload_sbom_via_silkbomb(
creds_file_path: pathlib.Path,
container_command: str,
container_image: str,
timeout_seconds: int = 60 * 5
timeout_seconds: int = 60 * 5,
):
container_options = ["--pull=always", "--platform=linux/amd64", "--rm"]
container_env_files = ["--env-file", str(creds_file_path.resolve())]
container_volumes = ["-v", f"{workdir}:/workdir"]
silkbomb_command = "augment" # it augment first and uses upload command
silkbomb_command = "augment" # it augment first and uses upload command
silkbomb_args = [
"--sbom-in", f"/workdir/{local_repo_path}/{sbom_repo_path}",
"--project", "tarun_test", #kept for tests
"--branch", branch_name,
"--repo", repo_name,
"--sbom-in",
f"/workdir/{local_repo_path}/{sbom_repo_path}",
"--project",
"tarun_test", # kept for tests
"--branch",
branch_name,
"--repo",
repo_name,
]
command = [
container_command, "run", *container_options, *container_env_files,
*container_volumes, container_image,
silkbomb_command, *silkbomb_args
container_command,
"run",
*container_options,
*container_env_files,
*container_volumes,
container_image,
silkbomb_command,
*silkbomb_args,
]
aws_region = "us-east-1"
ecr_registry_url = "901841024863.dkr.ecr.us-east-1.amazonaws.com/release-infrastructure/silkbomb"
ecr_registry_url = (
"901841024863.dkr.ecr.us-east-1.amazonaws.com/release-infrastructure/silkbomb"
)
print(f"Attempting to authenticate to AWS ECR registry '{ecr_registry_url}'...")
try:
login_cmd = f"aws ecr get-login-password --region {aws_region} | {container_command} login --username AWS --password-stdin {ecr_registry_url}"
subprocess.run(login_cmd, shell=True, check=True, text=True, capture_output=True, timeout=timeout_seconds)
subprocess.run(
login_cmd,
shell=True,
check=True,
text=True,
capture_output=True,
timeout=timeout_seconds,
)
print("ECR authentication successful.")
except FileNotFoundError:
print(f"Error: A required command was not found. Please ensure AWS CLI and '{container_command}' are installed and in your PATH.")
print(
f"Error: A required command was not found. Please ensure AWS CLI and '{container_command}' are installed and in your PATH."
)
raise
except subprocess.TimeoutExpired as e:
print(f"Error: Command timed out after {timeout_seconds} seconds. Please check Evergreen network state and try again.")
print(
f"Error: Command timed out after {timeout_seconds} seconds. Please check Evergreen network state and try again."
)
raise e
except subprocess.CalledProcessError as e:
print(f"Error during ECR authentication:\n--- STDERR ---\n{e.stderr}")
@ -93,33 +112,82 @@ def upload_sbom_via_silkbomb(
print(f"Error: '{container_command}' command not found.")
raise e
except subprocess.TimeoutExpired as e:
print(f"Error: Command timed out after {timeout_seconds} seconds. Please check Evergreen network state and try again.")
print(
f"Error: Command timed out after {timeout_seconds} seconds. Please check Evergreen network state and try again."
)
raise e
except subprocess.CalledProcessError as e:
print(f"Error during container execution:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}")
print(
f"Error during container execution:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}"
)
raise e
# TODO (SERVER-109205): Add Slack Alerts for failures
@app.command()
def run(
github_org: Annotated[str, typer.Option(...,envvar="GITHUB_ORG", help="Name of the github organization (e.g. 10gen)")],
github_repo: Annotated[str, typer.Option(..., envvar="GITHUB_REPO", help="Repo name in 'owner/repo' format.")],
local_repo_path: Annotated[str, typer.Option(..., envvar="LOCAL_REPO_PATH", help="Path to the local git repository."), ],
branch_name: Annotated[str, typer.Option(..., envvar="BRANCH_NAME", help="The head branch (e.g., the PR branch name).")],
sbom_repo_path: Annotated[str, typer.Option(..., "--sbom-in", envvar="SBOM_REPO_PATH", help="Path to the SBOM file to check and upload.")] = "sbom.json",
requester: Annotated[str, typer.Option(..., envvar="REQUESTER", help="The entity requesting the run (e.g., 'github_merge_queue').")] = "",
container_command: Annotated[str, typer.Option(..., envvar="CONTAINER_COMMAND", help="Container engine to use ('podman' or 'docker').")] = "podman",
container_image: Annotated[str, typer.Option(..., envvar="CONTAINER_IMAGE", help="Silkbomb container image.")] = "901841024863.dkr.ecr.us-east-1.amazonaws.com/release-infrastructure/silkbomb:2.0",
creds_file: Annotated[pathlib.Path, typer.Option(..., envvar="CONTAINER_ENV_FILES", help="Path for the temporary credentials file.")] = pathlib.Path("kondukto_credentials.env"),
workdir: Annotated[str, typer.Option(..., envvar="WORKING_DIR", help="Path for the container volumes.")]= "/workdir",
dry_run: Annotated[bool, typer.Option("--dry-run/--run", help="Check for changes without uploading.")] = True,
check_sbom_file_change: Annotated[bool, typer.Option("--check-sbom-file-change", help="Check for changes to the SBOM file.")] = False,
github_org: Annotated[
str,
typer.Option(..., envvar="GITHUB_ORG", help="Name of the github organization (e.g. 10gen)"),
],
github_repo: Annotated[
str, typer.Option(..., envvar="GITHUB_REPO", help="Repo name in 'owner/repo' format.")
],
local_repo_path: Annotated[
str,
typer.Option(..., envvar="LOCAL_REPO_PATH", help="Path to the local git repository."),
],
branch_name: Annotated[
str,
typer.Option(..., envvar="BRANCH_NAME", help="The head branch (e.g., the PR branch name)."),
],
sbom_repo_path: Annotated[
str,
typer.Option(
...,
"--sbom-in",
envvar="SBOM_REPO_PATH",
help="Path to the SBOM file to check and upload.",
),
] = "sbom.json",
requester: Annotated[
str,
typer.Option(
...,
envvar="REQUESTER",
help="The entity requesting the run (e.g., 'github_merge_queue').",
),
] = "",
container_command: Annotated[
str,
typer.Option(
..., envvar="CONTAINER_COMMAND", help="Container engine to use ('podman' or 'docker')."
),
] = "podman",
container_image: Annotated[
str, typer.Option(..., envvar="CONTAINER_IMAGE", help="Silkbomb container image.")
] = "901841024863.dkr.ecr.us-east-1.amazonaws.com/release-infrastructure/silkbomb:2.0",
creds_file: Annotated[
pathlib.Path,
typer.Option(
..., envvar="CONTAINER_ENV_FILES", help="Path for the temporary credentials file."
),
] = pathlib.Path("kondukto_credentials.env"),
workdir: Annotated[
str, typer.Option(..., envvar="WORKING_DIR", help="Path for the container volumes.")
] = "/workdir",
dry_run: Annotated[
bool, typer.Option("--dry-run/--run", help="Check for changes without uploading.")
] = True,
check_sbom_file_change: Annotated[
bool, typer.Option("--check-sbom-file-change", help="Check for changes to the SBOM file.")
] = False,
):
if requester != "commit" and not dry_run:
print(f"Skipping: Run can only be triggered for 'commit', but requester was '{requester}'.")
sys.exit(0)
major_branches = ["v7.0", "v8.0", "v8.1", "master"] # Only major branches that MongoDB supports
major_branches = ["v7.0", "v8.0", "v8.1", "master"] # Only major branches that MongoDB supports
if False and branch_name not in major_branches:
print(f"Skipping: Branch '{branch_name}' is not a major branch. Exiting.")
sys.exit(0)
@ -129,18 +197,22 @@ def run(
if not sbom_path.resolve().exists():
print(f"Error: SBOM file not found at path: {str(sbom_path.resolve())}")
sys.exit(1)
try:
sbom_file_changed = True
if check_sbom_file_change:
commit_changed_files = get_changed_files_from_latest_commit(repo_path,branch_name)
commit_changed_files = get_changed_files_from_latest_commit(repo_path, branch_name)
if commit_changed_files:
print(f"Latest commit '{commit_changed_files['title']}' ({commit_changed_files['hash']}) in branch '{branch_name}' has the following changed files:")
print(
f"Latest commit '{commit_changed_files['title']}' ({commit_changed_files['hash']}) in branch '{branch_name}' has the following changed files:"
)
print(f"{commit_changed_files['files']}")
else:
print(f"No changed files found in the commit '{commit_changed_files['title']}' ({commit_changed_files['hash']}) in branch '{branch_name}'. Exiting without upload.")
print(
f"No changed files found in the commit '{commit_changed_files['title']}' ({commit_changed_files['hash']}) in branch '{branch_name}'. Exiting without upload."
)
sys.exit(0)
print(f"Checking for changes to file: {sbom_path} ({sbom_repo_path})")
sbom_file_changed = sbom_repo_path in commit_changed_files["files"]
@ -163,7 +235,10 @@ def run(
)
else:
print("--dry-run enabled, skipping upload.")
print(f"File '{sbom_repo_path}'" + (" was modified." if sbom_file_changed else " was not modified."))
print(
f"File '{sbom_repo_path}'"
+ (" was modified." if sbom_file_changed else " was not modified.")
)
if dry_run:
print("Upload metadata:")
@ -173,12 +248,15 @@ def run(
print(f" Container Command: {container_command}")
print(f" Container Image: {container_image}")
print(f" Workdir: {workdir}")
if check_sbom_file_change:
print(f"Latest commit '{commit_changed_files['title']}' ({commit_changed_files['hash']})")
if check_sbom_file_change:
print(
f"Latest commit '{commit_changed_files['title']}' ({commit_changed_files['hash']})"
)
except Exception as e:
print(f"Exception during script execution: {e}")
sys.exit(1)
if __name__ == "__main__":
app()
app()
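A simplified sketch of the Annotated/typer.Option pattern that the run() command above uses for its envvar-backed options and boolean flags; the option names here are illustrative and typer is assumed to be installed.
import typer
from typing_extensions import Annotated

app = typer.Typer(add_completion=False)


@app.command()
def run(
    branch_name: Annotated[
        str, typer.Option(envvar="BRANCH_NAME", help="The head branch name.")
    ],
    dry_run: Annotated[
        bool, typer.Option("--dry-run/--run", help="Check for changes without uploading.")
    ] = True,
):
    print(f"branch={branch_name} dry_run={dry_run}")


if __name__ == "__main__":
    app()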

View File

@ -13,8 +13,8 @@ def url_exists(url, timeout=5):
except requests.RequestException:
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download and verify S3 binary.")
parser.add_argument("s3_path", help="S3 URL to download from")
parser.add_argument("local_path", nargs="?", help="Optional output file path")
@ -23,4 +23,4 @@ if __name__ == "__main__":
if url_exists(args.s3_path):
if not download_s3_binary(args.s3_path, args.local_path, True):
sys.exit(1)
sys.exit(1)

View File

@ -26,11 +26,13 @@ from proxyprotocol.server.main import main
# [1]: https://github.com/python/cpython/blob/5c19c5bac6abf3da97d1d9b80cfa16e003897096/Lib/asyncio/base_events.py#L1429
original_create_server = BaseEventLoop.create_server
async def monkeypatched_create_server(self, protocol_factory, host, port, *args, **kwargs):
result = await original_create_server(self, protocol_factory, host, port, *args, **kwargs)
print(f'Now listening on {host}:{port}')
print(f"Now listening on {host}:{port}")
return result
if __name__ == "__main__":
BaseEventLoop.create_server = monkeypatched_create_server
sys.exit(main())

View File

@ -125,12 +125,16 @@ class DatabaseInstance:
check=True,
)
async def analyze_field(self, collection_name: str, field: str, number_buckets: int = 1000) -> None:
async def analyze_field(
self, collection_name: str, field: str, number_buckets: int = 1000
) -> None:
"""
Run 'analyze' on a given field.
Analyze is currently not persisted across restarts, or when dumping or restoring.
"""
await self.database.command({"analyze": collection_name, "key": field, "numberBuckets": number_buckets})
await self.database.command(
{"analyze": collection_name, "key": field, "numberBuckets": number_buckets}
)
async def set_parameter(self, name: str, value: any) -> None:
"""Set MongoDB Parameter."""
@ -182,7 +186,12 @@ class DatabaseInstance:
"""Drop collection."""
await self.database[collection_name].drop()
async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]], context_manager: ContextManager) -> None:
async def insert_many(
self,
collection_name: str,
docs: Sequence[Mapping[str, any]],
context_manager: ContextManager,
) -> None:
"""Insert documents into the collection with the given name.
The context_manager can be:
- `contextlib.nullcontext` to enable maximum concurrency

View File

@ -90,7 +90,6 @@ class FieldStatistic:
class FieldStatisticByScalarType:
def __init__(self):
self.min = None
self.max = None
@ -152,8 +151,7 @@ def serialize_supported(v):
return v.isoformat()
elif isinstance(v, (bson.decimal128.Decimal128, decimal.Decimal)):
return str(v)
elif isinstance(v,
(bson.datetime_ms.DatetimeMS, bson.timestamp.Timestamp)):
elif isinstance(v, (bson.datetime_ms.DatetimeMS, bson.timestamp.Timestamp)):
return v.as_datetime().replace(tzinfo=datetime.timezone.utc).timestamp()
elif issubclass(type(v), Enum):
# We expect that the Enum will have a __repr__ method that returns

View File

@ -188,7 +188,9 @@ class Specification:
)
return f"{self.type.__name__}({', '.join(specs)})"
async def analyze(self, database_instance: DatabaseInstance, collection_name: str, field_name: str) -> None:
async def analyze(
self, database_instance: DatabaseInstance, collection_name: str, field_name: str
) -> None:
"""
Runs 'analyze' on all fields in the Specification.

View File

@ -49,9 +49,9 @@ class RangeGenerator(typing.Generic[TVar]):
ndv: int = -1
def __post_init__(self):
assert type(self.interval_begin) == type(self.interval_end), (
"Interval ends must of the same type."
)
assert type(self.interval_begin) == type(
self.interval_end
), "Interval ends must of the same type."
if type(self.interval_begin) == int or type(self.interval_begin) == float:
self.ndv = round((self.interval_end - self.interval_begin) / self.step)
elif type(self.interval_begin) == datetime.datetime:
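A quick worked check of the ndv computation above for a numeric range; the values are made up.
interval_begin, interval_end, step = 0, 10, 2
ndv = round((interval_end - interval_begin) / step)
print(ndv)  # 5 distinct values expected from this range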

View File

@ -131,7 +131,7 @@ async def upstream(
collection_name: str,
source: typing.Generator,
count: int,
context_manager: ContextManager
context_manager: ContextManager,
):
"""Bulk insert generated objects into a collection."""
import dataclasses
@ -150,7 +150,7 @@ async def upstream(
datagen.serialize.serialize_doc(dataclasses.asdict(next(source)))
for _ in range(num)
],
context_manager
context_manager,
)
)
)
@ -184,10 +184,14 @@ async def main():
)
parser.add_argument("--drop", action="store_true", help="Drop the collection before inserting.")
parser.add_argument("--dump", nargs="?", const="", help="Dump the collection after inserting.")
parser.add_argument("--analyze", action="store_true", help="""
parser.add_argument(
"--analyze",
action="store_true",
help="""
Run the 'analyze' command against each field of the collection.
Analyze is not preserved across restarts, or when dumping or restoring.
""")
""",
)
parser.add_argument("--indexes", action="append", help="An index set to load.")
parser.add_argument("--restore-args", type=str, help="Parameters to pass to mongorestore.")
parser.add_argument(
@ -215,7 +219,12 @@ async def main():
help="Number of objects to generate. Set to 0 to skip data generation.",
)
parser.add_argument("--seed", type=str, help="The seed to use.")
parser.add_argument("--serial-inserts", action='store_true', default=False, help="Force single-threaded insertion")
parser.add_argument(
"--serial-inserts",
action="store_true",
default=False,
help="Force single-threaded insertion",
)
args = parser.parse_args()
module = importlib.import_module(args.module)
@ -249,7 +258,9 @@ async def main():
generator_factory = CorrelatedGeneratorFactory(spec, seed)
generator = generator_factory.make_generator()
context_manager = asyncio.Semaphore(1) if args.serial_inserts else nullcontext()
await upstream(database_instance, collection_name, generator, args.size, context_manager)
await upstream(
database_instance, collection_name, generator, args.size, context_manager
)
generator_factory.dump_metadata(collection_name, args.size, seed, metadata_path)
# 3. Create indexes after documents.
@ -284,6 +295,7 @@ async def main():
for field in fields(spec):
await field.type.analyze(database_instance, collection_name, field.name)
if __name__ == "__main__":
import asyncio

View File

@ -12,13 +12,15 @@ class PartiallyCorrelated:
@staticmethod
def make_field1(fkr: faker.proxy.Faker) -> int:
return fkr.random.choice(['a','b'])
return fkr.random.choice(["a", "b"])
@staticmethod
def make_field2(fkr: faker.proxy.Faker) -> str:
return fkr.random_element(
collections.OrderedDict([
(uncorrelated_faker().random.choice(['c','d']), 0.1),
(fkr.random.choice(['a','b']), 0.9)
])
collections.OrderedDict(
[
(uncorrelated_faker().random.choice(["c", "d"]), 0.1),
(fkr.random.choice(["a", "b"]), 0.9),
]
)
)
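A minimal sketch of the weighted random_element pattern used by make_field2 above; it assumes Faker is installed, and the values/weights are illustrative.
import collections

import faker

fkr = faker.Faker()
fkr.seed_instance(1)

# Roughly 90% of draws come from the first entry, 10% from the second.
value = fkr.random_element(collections.OrderedDict([("a", 0.9), ("c", 0.1)]))
print(value)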

View File

@ -15,26 +15,17 @@ from faker import Faker
NUM_FIELDS = 48
# 50% chance of no correlation
CORRELATIONS = ['a', 'b', 'c', None, None, None]
CORRELATIONS = ["a", "b", "c", None, None, None]
class mixed:
"""Used to designate mixed-type fields"""
AVAILABLE_TYPES = [
str,
int,
bool,
datetime,
Timestamp,
Decimal128,
list,
dict,
mixed
]
AVAILABLE_TYPES = [str, int, bool, datetime, Timestamp, Decimal128, list, dict, mixed]
START_DATE = datetime(2024, 1, 1, tzinfo=timezone.utc)
END_DATE = datetime(2025, 12, 31, tzinfo=timezone.utc)
END_DATE = datetime(2025, 12, 31, tzinfo=timezone.utc)
# Ideally we would want to seed our uncorrelated Faker based on the --seed argument to driver.py
# but it is not available here.
@ -42,14 +33,16 @@ ufkr = Faker()
ufkr.seed_instance(1)
universal_generators = {
'missing' : lambda fkr: MISSING,
'null' : lambda fkr: None,
"missing": lambda fkr: MISSING,
"null": lambda fkr: None,
}
def pareto(fkr) -> int:
"""In the absence of a Zipfian implementation to generate skewed datasets, we use pareto"""
return int(fkr.random.paretovariate(2))
def lambda_sources(l: Specification) -> str:
"""Returns the code of the lambdas that participate in generating the values of a Specification."""
signature = inspect.signature(l.source)
@ -59,48 +52,53 @@ def lambda_sources(l: Specification) -> str:
for generator, probability in params[1].default.items()
)
type_generators: dict[type, dict[str, Callable]] = {
str: {
'p1' : lambda fkr: ascii_lowercase[min(25, pareto(fkr) % 26)],
's1' : lambda fkr: fkr.pystr(min_chars=1, max_chars=1),
's2' : lambda fkr: fkr.pystr(min_chars=1, max_chars=2),
's4' : lambda fkr: fkr.pystr(min_chars=1, max_chars=4),
},
"p1": lambda fkr: ascii_lowercase[min(25, pareto(fkr) % 26)],
"s1": lambda fkr: fkr.pystr(min_chars=1, max_chars=1),
"s2": lambda fkr: fkr.pystr(min_chars=1, max_chars=2),
"s4": lambda fkr: fkr.pystr(min_chars=1, max_chars=4),
},
int: {
'const1': lambda fkr: 1,
'i10': lambda fkr: fkr.random_int(min=1, max=10),
'i100': lambda fkr: fkr.random_int(min=1, max=100),
'i1000': lambda fkr: fkr.random_int(min=1, max=1000),
'i10000': lambda fkr: fkr.random_int(min=1, max=10000),
'i100000': lambda fkr: fkr.random_int(min=1, max=100000),
'pareto': pareto
},
"const1": lambda fkr: 1,
"i10": lambda fkr: fkr.random_int(min=1, max=10),
"i100": lambda fkr: fkr.random_int(min=1, max=100),
"i1000": lambda fkr: fkr.random_int(min=1, max=1000),
"i10000": lambda fkr: fkr.random_int(min=1, max=10000),
"i100000": lambda fkr: fkr.random_int(min=1, max=100000),
"pareto": pareto,
},
bool: {
'br': lambda fkr: fkr.boolean(),
'b10': lambda fkr: fkr.boolean(10),
'b100': lambda fkr: fkr.boolean(1),
'b1000': lambda fkr: fkr.boolean(0.1),
'b10000': lambda fkr: fkr.boolean(0.01),
'b100000': lambda fkr: fkr.boolean(0.001),
"br": lambda fkr: fkr.boolean(),
"b10": lambda fkr: fkr.boolean(10),
"b100": lambda fkr: fkr.boolean(1),
"b1000": lambda fkr: fkr.boolean(0.1),
"b10000": lambda fkr: fkr.boolean(0.01),
"b100000": lambda fkr: fkr.boolean(0.001),
},
datetime: {
'dt_pareto': lambda fkr: START_DATE + timedelta(days=pareto(fkr)),
"dt_pareto": lambda fkr: START_DATE + timedelta(days=pareto(fkr)),
},
Timestamp: {
# Note that we can not generate timestamps with i > 0 as the i is not preserved in the .schema file
'ts_const': lambda fkr: Timestamp(fkr.random_element([START_DATE, END_DATE]), 0),
'ts_triangular': lambda fkr: Timestamp(fkr.random.triangular(START_DATE, END_DATE, END_DATE), 0)
},
Decimal128: {
'decimal_pareto': lambda fkr: Decimal128(f"{pareto(fkr)}.{pareto(fkr)}")
"ts_const": lambda fkr: Timestamp(fkr.random_element([START_DATE, END_DATE]), 0),
"ts_triangular": lambda fkr: Timestamp(
fkr.random.triangular(START_DATE, END_DATE, END_DATE), 0
),
},
Decimal128: {"decimal_pareto": lambda fkr: Decimal128(f"{pareto(fkr)}.{pareto(fkr)}")},
list: {
'list_int_pareto': lambda fkr: [pareto(fkr) for _ in range(pareto(fkr) % 10)],
'list_str_pareto': lambda fkr: [ascii_lowercase[min(25, pareto(fkr) % 26)] for _ in range(pareto(fkr) % 10)],
"list_int_pareto": lambda fkr: [pareto(fkr) for _ in range(pareto(fkr) % 10)],
"list_str_pareto": lambda fkr: [
ascii_lowercase[min(25, pareto(fkr) % 26)] for _ in range(pareto(fkr) % 10)
],
},
dict: {
'dict_str_pareto': lambda fkr: {ascii_lowercase[min(25, pareto(fkr) % 26)]: pareto(fkr) for _ in range(pareto(fkr) % 10)}
}
"dict_str_pareto": lambda fkr: {
ascii_lowercase[min(25, pareto(fkr) % 26)]: pareto(fkr) for _ in range(pareto(fkr) % 10)
}
},
}
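# A minimal sketch of how the type_generators table above is consumed: each
# leaf is a callable that takes a Faker instance and returns one value of the
# outer key's type. _demo_generators is illustrative only and reuses ufkr.
def _demo_generators() -> None:
    print(type_generators[int]["i10"](ufkr))    # an int in [1, 10]
    print(type_generators[bool]["b100"](ufkr))  # True about 1% of the time
    print(type_generators[str]["p1"](ufkr))     # a pareto-skewed lowercase letter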
specifications = {}
@ -114,8 +112,7 @@ for f in range(NUM_FIELDS):
if chosen_type is mixed:
available_generators = [
generator for type in type_generators.values()
for generator in type.values()
generator for type in type_generators.values() for generator in type.values()
]
else:
available_generators = list(type_generators[chosen_type].values())
@ -146,9 +143,7 @@ for f in range(NUM_FIELDS):
chosen_generator = fkr.random_element(generators)
return chosen_generator(fkr)
specification = Specification(chosen_type,
correlation=chosen_correlation,
source=source)
specification = Specification(chosen_type, correlation=chosen_correlation, source=source)
# pylint: disable=invalid-name
field_name = f"field{f}_{chosen_type.__name__}"
@ -168,16 +163,14 @@ for field_name, specification in specifications.items():
# Convert the dictionary into a dataclass that driver.py can then use.
plan_stability2 = dataclasses.make_dataclass(
"plan_stability2", # Name of the dataclass
specifications.items()
specifications.items(),
)
def indexes() -> list[pymongo.IndexModel]:
"""Return a set of pymongo.IndexModel objects that the data generator will create."""
indexed_fields = [
field_name for field_name in specifications if "idx" in field_name
]
indexed_fields = [field_name for field_name in specifications if "idx" in field_name]
assert len(indexed_fields) > 0
chosen_indexes: dict[str, pymongo.IndexModel] = {}
@ -186,28 +179,23 @@ def indexes() -> list[pymongo.IndexModel]:
# The first field of each index is one of the fields we definitely
# want to be indexed ...
chosen_fields: dict[str, int] = {
indexed_field:
ufkr.random_element([pymongo.ASCENDING, pymongo.DESCENDING])
indexed_field: ufkr.random_element([pymongo.ASCENDING, pymongo.DESCENDING])
}
# ... and we will make some indexes multi-field by tacking on more fields.
secondary_field_count = round(ufkr.random.triangular(low=0, high=2,
mode=0))
secondary_field_count = round(ufkr.random.triangular(low=0, high=2, mode=0))
for _ in range(secondary_field_count):
secondary_field = ufkr.random_element(indexed_fields)
if secondary_field in chosen_fields:
continue
has_array_field = any("mixed" in f or "list" in f
for f in chosen_fields)
has_array_field = any("mixed" in f or "list" in f for f in chosen_fields)
if ("mixed" in secondary_field
or "list" in secondary_field) and has_array_field >= 1:
# We can not have two array fields in a compound index
continue
if ("mixed" in secondary_field or "list" in secondary_field) and has_array_field >= 1:
# We can not have two array fields in a compound index
continue
secondary_dir = ufkr.random_element(
[pymongo.ASCENDING, pymongo.DESCENDING])
secondary_dir = ufkr.random_element([pymongo.ASCENDING, pymongo.DESCENDING])
chosen_fields[secondary_field] = secondary_dir
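# A minimal sketch of turning a chosen_fields mapping like the one assembled
# above into a compound pymongo.IndexModel; the field names and the naming
# scheme here are illustrative, not the generator's own.
import pymongo

example_fields = {"field3_int_idx": pymongo.ASCENDING, "field7_str": pymongo.DESCENDING}
example_index = pymongo.IndexModel(
    keys=list(example_fields.items()),
    name="_".join(example_fields),  # one compound index over both fields
)
print(example_index.document)  # index definition sent to createIndexes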
@ -40,10 +40,12 @@ from datagen.util import Specification
class TestGrandchild:
g: Specification(int)
@dataclasses.dataclass
class TestChild:
c: Specification(TestGrandchild)
@dataclasses.dataclass
class Test:
i: Specification(int)
@ -53,6 +55,7 @@ class Test:
def compute_il(fkr: faker.proxy.Faker) -> float:
return 1
def test_index() -> list[pymongo.IndexModel]:
return [
pymongo.IndexModel(keys="i", name="i_idx"),
@ -48,37 +48,39 @@ class TestEnum(enum.Enum):
def __repr__(self) -> str:
return self.name
class TestIntEnum(enum.IntEnum):
A = enum.auto()
def __repr__(self) -> str:
return str(self.value)
@dataclasses.dataclass
class NestedObject:
str_field: Specification(str, source=lambda fkr: "A")
@dataclasses.dataclass
class TypesTest:
float_field: Specification(float, source=lambda fkr: float(1.1))
int_field: Specification(int, source=lambda fkr: 1)
str_field: Specification(str, source=lambda fkr: "A")
bool_field: Specification(bool, source=lambda fkr: True)
datetime_datetime_field: Specification(datetime.datetime,
source=lambda fkr: DATETIME)
datetime_datetime_field: Specification(datetime.datetime, source=lambda fkr: DATETIME)
bson_datetime_ms_field: Specification(
datetime, source=lambda fkr: bson.datetime_ms.DatetimeMS(DATETIME))
datetime, source=lambda fkr: bson.datetime_ms.DatetimeMS(DATETIME)
)
bson_timestamp_field: Specification(
bson.timestamp.Timestamp,
source=lambda fkr: bson.timestamp.Timestamp(DATETIME, 123))
bson.timestamp.Timestamp, source=lambda fkr: bson.timestamp.Timestamp(DATETIME, 123)
)
bson_decimal128_field: Specification(
bson.decimal128.Decimal128,
source=lambda fkr: bson.decimal128.Decimal128("1.1"))
bson.decimal128.Decimal128, source=lambda fkr: bson.decimal128.Decimal128("1.1")
)
array_field: Specification(list, source=lambda fkr: [1, 2])
obj_field: Specification(NestedObject)
dict_field: Specification(dict, source=lambda fkr: {'a': 1})
dict_field: Specification(dict, source=lambda fkr: {"a": 1})
enum_field: Specification(TestEnum, source=lambda fkr: TestEnum.A)
int_enum_field: Specification(TestIntEnum,
source=lambda fkr: TestIntEnum.A)
int_enum_field: Specification(TestIntEnum, source=lambda fkr: TestIntEnum.A)
null_field: Specification(type(None), source=lambda fkr: None)
missing_field: Specification(type(None), source=lambda fkr: MISSING)
@ -11,8 +11,8 @@ class Uncorrelated:
@staticmethod
def make_field1(fkr: faker.proxy.Faker) -> int:
return fkr.random.choice(['a','b'])
return fkr.random.choice(["a", "b"])
@staticmethod
def make_field2(fkr: faker.proxy.Faker) -> int:
return fkr.random.choice(['a','b'])
return fkr.random.choice(["a", "b"])
@ -208,16 +208,24 @@ class MongoTidyTests(unittest.TestCase):
self.run_clang_tidy()
def test_MongoBannedNamesCheck(self):
stdx_replacement_str = "Consider using alternatives such as the polyfills from the mongo::stdx:: namespace."
stdx_replacement_str = (
"Consider using alternatives such as the polyfills from the mongo::stdx:: namespace."
)
test_names = [
("std::get_terminate()", stdx_replacement_str),
("std::future<int> myFuture", "Consider using mongo::Future instead."),
("std::recursive_mutex recursiveMut", "Do not use. A recursive mutex is often an indication of a design problem and is prone to deadlocks because you don't know what code you are calling while holding the lock."),
(
"std::recursive_mutex recursiveMut",
"Do not use. A recursive mutex is often an indication of a design problem and is prone to deadlocks because you don't know what code you are calling while holding the lock.",
),
("const std::condition_variable cv", stdx_replacement_str),
("static std::unordered_map<int, int> myMap", stdx_replacement_str),
("boost::unordered_map<int, int> boostMap", stdx_replacement_str),
("std::regex_search(std::string(\"\"), std::regex(\"\"))", "Consider using mongo::pcre::Regex instead."),
(
'std::regex_search(std::string(""), std::regex(""))',
"Consider using mongo::pcre::Regex instead.",
),
("std::atomic<int> atomicVar", "Consider using mongo::Atomic<T> instead."),
("std::optional<std::string> strOpt", "Consider using boost::optional instead."),
("std::atomic<int> fieldDecl", "Consider using mongo::Atomic<T> instead."),
@ -227,7 +235,9 @@ class MongoTidyTests(unittest.TestCase):
self.expected_output = [
"error: Forbidden use of banned name in "
+ name + ". " + msg
+ name
+ ". "
+ msg
+ " Use '// NOLINT' if usage is absolutely necessary. Be especially careful doing so outside of test code."
for (name, msg) in test_names
]
@ -308,7 +318,7 @@ class MongoTidyTests(unittest.TestCase):
self.run_clang_tidy()
def test_MongoBannedAutoGetUsageCheck(self):
self.expected_output = ("AutoGetCollection is not allowed to be used from the query modules. Use ShardRole CollectionAcquisitions instead.")
self.expected_output = "AutoGetCollection is not allowed to be used from the query modules. Use ShardRole CollectionAcquisitions instead."
self.run_clang_tidy()