mirror of https://github.com/mongodb/mongo
300 lines
10 KiB
Python
Executable File
300 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import json
|
|
import multiprocessing
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed, wait
|
|
from datetime import datetime
|
|
from typing import Any, TypedDict
|
|
|
|
import pyzstd
|
|
import typer # nicer error dump on exceptions
|
|
from progressbar import ProgressBar, progressbar
|
|
|
|
# Repository root. Under `bazel run`, BUILD_WORKSPACE_DIRECTORY points at the
# workspace root; otherwise fall back to the parent of this script's directory.
REPO_ROOT = os.environ.get(
    "BUILD_WORKSPACE_DIRECTORY", os.path.dirname(os.path.abspath(sys.argv[0])) + "/.."
)
|
|
|
|
|
|
class Decl(TypedDict):
    """One declaration record produced by the mod-scanner tooling.

    Records for the same ``usr`` are merged across input files by
    merge_decl(); fields marked "merged" accumulate values from every file.
    """

    display_name: str  # human-readable name of the declaration
    kind: str  # e.g. "CLASS_DECL", "CLASS_TEMPLATE" (presumably clang cursor kinds)
    loc: str  # source location of this decl
    mod: str  # module that owns this decl (dotted name)
    other_mods: dict[str, set[str]]  # merged - extra decl locations keyed by module
    used_from: dict[str, set[str]]  # merged - usage locations keyed by using module
    usr: str  # unique symbol id - the key used for merging in all_decls
    visibility: str  # e.g. "private", "parent_private", "file_private" (see main)
    defined: bool  # True if this record is the definition (wins over mere decls)
|
|
|
|
|
|
# Global accumulator of every decl seen so far, keyed by its "usr" string.
# Each worker process fills its own copy; the parent merges worker results.
all_decls: dict[str, Decl] = {}
|
|
|
|
|
|
def merge_decls(decls: list[Decl]):
    """Fold every record in *decls* into the global all_decls map."""
    pending = iter(decls)
    decls = []  # drop the local ref so error dumps don't print the huge list
    for record in pending:
        merge_decl(record)
|
|
|
|
|
|
def merge_decl(decl: Decl):
    """Merge one decl record into the global all_decls map, keyed by "usr".

    Accumulates used_from and other_mods across records, and promotes a
    record that is a definition over a previously-stored non-definition.
    Mutates both *decl* and the stored record in place.
    """
    other_mods = decl.get("other_mods", {})
    used_from = decl["used_from"]
    usr = decl["usr"]
    if usr not in all_decls:
        # First time seeing this decl - no merging needed
        # used_from values arrive as JSON arrays (lists); normalize to sets so
        # later merges can .update() them.
        for mod in used_from:
            if type(used_from[mod]) != set:
                used_from[mod] = set(used_from[mod])
        all_decls[usr] = decl
        return

    old = all_decls[usr]

    # Merge used_from into old_used_from
    old_used_from = old["used_from"]
    for mod, locs in used_from.items():
        if not mod:
            # NOTE(review): empty module names are bucketed under "__NONE__"
            # here but not in the first-seen path above - confirm intended.
            mod = "__NONE__"
        old_used_from.setdefault(mod, set()).update(locs)

    old_other_mods = old.get("other_mods", {})

    # Merge other_mods into old_other_mods
    for other, val in other_mods.items():
        if isinstance(val, set):
            old_other_mods.setdefault(other, set()).update(val)
        else:
            # val is a single (hashable) location value rather than a set
            old_other_mods.setdefault(other, set()).add(val)

    mod = decl["mod"]
    # A definition supersedes a previously-stored mere declaration.
    replace = decl["defined"] and not old["defined"]
    if replace:
        # Make this the primary decl, even if from same mod
        all_decls[usr] = decl
        decl["used_from"] = old_used_from  # keep the fully-merged usage map
        if decl["loc"] != old["loc"]:
            # Remember the superseded decl's location under its module.
            old_other_mods.setdefault(old["mod"], set()).add(old["loc"])
        if mod in old_other_mods and decl["loc"] in old_other_mods[mod]:
            # The new primary location is no longer an "other" location.
            old_other_mods[mod].remove(decl["loc"])
            if not old_other_mods[mod]:
                del old_other_mods[mod]
        if old_other_mods:
            decl["other_mods"] = old_other_mods
    else:
        if decl["loc"] != old["loc"]:
            # Record the incoming decl's distinct location as an "other" loc.
            old_other_mods.setdefault(mod, set()).add(decl["loc"])
        if old_other_mods:
            old["other_mods"] = old_other_mods

    # assert decl["loc"] == old["loc"]
    # CLASS_DECL and CLASS_TEMPLATE records may legitimately share a usr.
    assert (
        decl["kind"] == old["kind"]
        or (decl["kind"] == "CLASS_DECL" and old["kind"] == "CLASS_TEMPLATE")
        or (decl["kind"] == "CLASS_TEMPLATE" and old["kind"] == "CLASS_DECL")
    )
    # assert decl["display_name"] == old["display_name"] # TODO ugh sometimes mongo:: screws it up
|
|
|
|
|
|
class Timer:
    """Wall-clock timer that prints labeled elapsed times since construction."""

    def __init__(self):
        # Fixed reference point: mark() always measures from construction,
        # so successive marks report cumulative time, not deltas.
        self.start = datetime.now()

    def mark(self, label: str):
        """Print ``label: elapsed`` where elapsed is the time since __init__.

        (The original wrapped the body in a vestigial, always-true ``if 1:``
        guard - removed here as dead code; behavior is unchanged.)
        """
        elapsed = datetime.now() - self.start
        print(f"{label}: {elapsed}")
|
|
|
|
|
|
def worker(paths: list[str]):
    """Pop and process decl files from *paths* until the list is empty.

    *paths* is a shared (multiprocessing.Manager) list drained cooperatively
    by all worker processes; pop() is the unit of work distribution, so the
    IndexError on an empty list is the normal termination signal.

    Returns this process's accumulated decls for the parent to merge.

    Note: the annotation was ``list[bytes]``, but get_paths() builds ``str``
    paths (text-mode aquery output joined onto REPO_ROOT), so ``list[str]``
    is the accurate type.
    """
    while True:
        try:
            path = paths.pop()
        except IndexError:
            # Shared list drained (possibly by another worker) - we're done.
            return list(all_decls.values())

        with pyzstd.ZstdFile(path, read_size=2 * 1024 * 1024) as f:
            merge_decls(json.loads(f.read()))
|
|
|
|
|
|
def parent_mod(mod: str):
    """Return *mod* with its last dot-separated component removed.

    Raises ValueError if *mod* contains no dot.
    """
    last_dot = mod.rindex(".")
    return mod[:last_dot]
|
|
|
|
|
|
def is_submodule_usage(decl_mod: str, *, usage_mod: str) -> bool:
    """usage_mod is keyword-only to avoid mixing up str arguments"""
    if usage_mod == decl_mod:
        return True
    # A true submodule must have a "." right after the parent's name.
    return usage_mod.startswith(decl_mod + ".")
|
|
|
|
|
|
def get_paths(timer: Timer):
    """Run the bazel mod-scanner build and return the decl output paths.

    Exits the process with bazel's exit code if the build fails on every
    attempt. Marks *timer* after the build and after the aquery.
    """
    # Retry the bazel build up to 3 times to handle transient failures
    max_retries = 3

    for attempt in range(1, max_retries + 1):
        print(f"Bazel build attempt {attempt}/{max_retries}...")
        build = subprocess.run(
            ["bazel", "build", "--config=mod-scanner", "//src/mongo/..."],
            text=True,  # unnecessary since we don't use stdout, but makes the types match
            cwd=REPO_ROOT,
            check=False,
        )
        if build.returncode == 0:
            break
        if attempt == max_retries:
            print(f"Bazel build failed after {max_retries} attempts")
            sys.exit(build.returncode)
        print(f"Bazel build failed with exit code {build.returncode}, retrying...")

    timer.mark("scanned sources")

    # Ask bazel which actions produce the mod_scanner decl files; each
    # matching action's environment carries a MOD_SCANNER_OUTPUT entry.
    query = subprocess.run(
        [
            "bazel",
            "aquery",
            "--config=mod-scanner",
            'outputs(".*mod_scanner_decls.json.*", mnemonic(ModScanner, //src/mongo/...))',
            "--noinclude_commandline",
            "--noinclude_artifacts",
        ],
        capture_output=True,
        text=True,
        cwd=REPO_ROOT,
        check=True,
    )

    outputs = []
    for line in query.stdout.split("\n"):
        if not (line.startswith(" Environment:") and "MOD_SCANNER_OUTPUT=" in line):
            continue
        m = re.search("MOD_SCANNER_OUTPUT=([^,]+),", line)
        if m:
            outputs.append(REPO_ROOT + "/" + m.group(1))
    timer.mark("queried bazel for mod_scanner outputs")
    return outputs
|
|
|
|
|
|
def get_file_family_regex(path: str) -> re.Pattern:
    """Build a regex matching every file in *path*'s "file family".

    The family of src/mongo/db/foo.h includes foo.cpp plus its _test, _bm,
    _mock, _internal(s), and _detail(s) variants, optionally under a build
    output prefix such as bazel-out/... .

    Raises ValueError if "src/mongo/" does not occur in *path*.
    """
    # file_base is the portion of the file name that defines the family
    # e.g. bazel-out/blah/src/mongo/db/foo_details.h -> src/mongo/db/foo
    file_base = path.split(".")[0]
    # Strip any build-output prefix before "src/mongo/". (The original used a
    # truthiness test on the walrus-bound index, which skipped the slice only
    # when the index was 0 - where slicing is a no-op anyway; slicing
    # unconditionally is equivalent and clearer.)
    index = file_base.index("src/mongo/")
    file_base = file_base[index:]
    file_base = re.sub(r"_(internal|detail)s?$", "", file_base)
    assert file_base.startswith("src/mongo/")

    # re.escape guards against regex metacharacters in the path (e.g. "+");
    # the original interpolated file_base unescaped.
    file_family_regex = re.compile(
        rf"(?:.*/)?{re.escape(file_base)}(?:_(?:internals?|details?|test|bm|mock)(_.*)?)?\."
    )
    assert file_family_regex.match(path)  # sanity check

    return file_family_regex
|
|
|
|
|
|
def main(
    jobs: int = typer.Option(os.cpu_count(), "--jobs", "-j"),
    intra_module: bool = typer.Option(True, help="Include intra-module accesses"),
):
    """Merge mod-scanner decl files, dump merged_decls.json, and report
    visibility violations.

    Exits with status 1 if any violation was found, 0 otherwise.
    """
    timer = Timer()
    paths = get_paths(timer)
    num_paths = len(paths)

    if jobs > 1:
        # Parallel path: workers cooperatively drain a Manager-backed shared
        # list, then each returns its local decls for the parent to merge.
        with multiprocessing.Manager() as manager:
            with ProcessPoolExecutor(jobs) as pool:
                workers = set()
                shared_paths = manager.list(paths)
                for _ in range(jobs):
                    workers.add(pool.submit(worker, shared_paths))  # type:ignore

                with ProgressBar(max_value=num_paths, prefix="processing inputs: ") as bar:
                    while True:
                        # Poll so the bar updates; bail out early if a worker dies.
                        done, _ = wait(workers, timeout=0.1, return_when="FIRST_EXCEPTION")
                        for d in done:
                            if d.exception():
                                raise d.exception()

                        # Estimate: each still-running worker is assumed to be
                        # holding one popped path in flight.
                        remaining_files = len(shared_paths) + jobs - len(done)
                        bar.update(num_paths - remaining_files)
                        if remaining_files == 0:
                            break

                timer.mark("all paths consumed")
                for result in as_completed(workers):
                    merge_decls(result.result())
                timer.mark("merged results")

    else:
        # Serial path: process each decl file in this process.
        for path in progressbar(paths):
            with pyzstd.ZstdFile(path, read_size=2 * 1024 * 1024) as f:
                merge_decls(json.loads(f.read()))
        timer.mark("processed input")

    out: Any = [dict(d) for d in all_decls.values()]  # shallow copy each decl
    if not intra_module:
        # Drop usages from a module (or its submodules) on its own decls,
        # then drop decls left with no usages at all.
        for decl in out:
            decl["used_from"] = {
                mod: locs
                for mod, locs in decl["used_from"].items()
                if not is_submodule_usage(decl["mod"], usage_mod=mod)
            }
        out = [d for d in out if d["used_from"]]

    for decl in out:
        # go from {$MOD: $LOCS} map to [{mod: $MOD, locs: $LOCS}] list of
        # objects which is easier to work with in mongo aggregations
        decl["used_from"] = [{"mod": k, "locs": sorted(v)} for k, v in decl["used_from"].items()]  # type: ignore
        if "other_mods" in decl:
            decl["other_mods"] = {k: sorted(v) for k, v in decl["other_mods"].items()}  # type: ignore
    timer.mark("massaged output for json")

    with open(f"{REPO_ROOT}/merged_decls.json", "w") as f:
        json.dump(out, f)
    timer.mark("dumped json")

    # Violation pass runs over the unfiltered all_decls (not `out`), sorted
    # for stable, readable output.
    found_violations = False
    for decl in sorted(all_decls.values(), key=lambda d: d["display_name"]):
        violations = []
        match decl["visibility"]:
            case "private" | "parent_private":
                decl_mod = decl["mod"]
                if decl["visibility"] == "parent_private":
                    # parent_private widens the allowed scope to the parent module.
                    decl_mod = parent_mod(decl_mod)
                err = f"Illegal use of {decl['display_name']} outside of module {decl_mod}:"
                for mod, locs in decl["used_from"].items():
                    if not is_submodule_usage(decl_mod, usage_mod=mod):
                        for loc in locs:
                            violations.append(f" {loc} ({mod})")

            case "file_private":
                err = f"Illegal use of {decl['display_name']} outside of its file family:"

                file_family_regex = get_file_family_regex(decl["loc"])

                for mod, locs in decl["used_from"].items():
                    for loc in locs:
                        # Must be in the same module even if file family matches.
                        # This helps prevent accidental matches.
                        if mod != decl["mod"] or not file_family_regex.match(loc):
                            violations.append(f" {loc} ({mod})")
            case _:  # ignore other visibility types
                continue

        # `err` is always bound here: the only visibility cases that reach
        # this point assign it before collecting violations.
        if violations:
            found_violations = True
            print(err)
            print(f" loc: {decl['loc']}")
            print(" usages:")
            print("\n".join(violations))
    timer.mark("checked for privacy violations")

    sys.exit(found_violations)  # bools are ints, so False(0) is success and True(1) is failure
|
|
|
|
|
|
if __name__ == "__main__":
|
|
typer.run(main)
|