diff --git a/.gitignore b/.gitignore index 9ecdc1e833..ce16eaafbf 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,8 @@ tools/ido_recomp/* binary ctx.c graphs/ *.c.m2c +tools/decomp-permuter/ +tools/mips_to_c/ # Assets *.png diff --git a/tools/decomp-permuter/.github/workflows/systray.yml b/tools/decomp-permuter/.github/workflows/systray.yml deleted file mode 100644 index b815d64cc3..0000000000 --- a/tools/decomp-permuter/.github/workflows/systray.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Systray - -on: - push: - branches: [ main ] - paths: - - 'src/net/cmd/systray/*' - pull_request: - branches: [ main ] - paths: - - 'src/net/cmd/systray/*' - -jobs: - - build: - name: Build on ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - include: - - os: ubuntu-16.04 - binary: permuter-systray-linux - - os: windows-latest - binary: permuter-systray.exe - - os: macos-latest - binary: permuter-systray-macos - steps: - - uses: actions/checkout@main - - - name: Install gtk3 - if: ${{ matrix.os == 'ubuntu-16.04' }} - run: sudo apt-get install libgtk-3-dev libappindicator3-dev - - - name: Setup Go environment - uses: actions/setup-go@v2.1.3 - - - name: Build - run: go build -o ${{ matrix.binary }} -ldflags "-s -w" tray.go - working-directory: src/net/cmd/systray/ - - - name: Upload artifact - uses: actions/upload-artifact@v2 - with: - name: ${{ matrix.binary }} - path: src/net/cmd/systray/${{ matrix.binary }} diff --git a/tools/decomp-permuter/.gitignore b/tools/decomp-permuter/.gitignore deleted file mode 100644 index 2dcae2143e..0000000000 --- a/tools/decomp-permuter/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -*.o -*.s -*.c -*.py[cod] -.mypy_cache/ -.cache/ -__pycache__/ -!test/*.c -/nonmatchings -.vscode/ -pah.conf diff --git a/tools/decomp-permuter/.gitrepo b/tools/decomp-permuter/.gitrepo deleted file mode 100644 index 270e3df564..0000000000 --- a/tools/decomp-permuter/.gitrepo +++ /dev/null @@ -1,12 +0,0 @@ -; DO NOT EDIT (unless you know what you are doing) -; -; This subdirectory is a git "subrepo", and this file is maintained by the -; git-subrepo command. See https://github.com/git-commands/git-subrepo#readme -; -[subrepo] - remote = https://github.com/simonlindholm/decomp-permuter.git - branch = main - commit = a20bac9422b6d8adbf7c06473c2ae3c3fee16be5 - parent = 2668eec556c01fa2f4c16a203c93c208dc03e639 - method = merge - cmdver = 0.4.3 diff --git a/tools/decomp-permuter/.pre-commit-config.yaml b/tools/decomp-permuter/.pre-commit-config.yaml deleted file mode 100644 index 6695f71ac9..0000000000 --- a/tools/decomp-permuter/.pre-commit-config.yaml +++ /dev/null @@ -1,6 +0,0 @@ -repos: -- repo: https://github.com/psf/black - rev: 20.8b1 - hooks: - - id: black - language_version: python3.6 diff --git a/tools/decomp-permuter/LICENSE b/tools/decomp-permuter/LICENSE deleted file mode 100644 index ce6aab3a31..0000000000 --- a/tools/decomp-permuter/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2019 Simon Lindholm and contributors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/tools/decomp-permuter/README.md b/tools/decomp-permuter/README.md deleted file mode 100644 index eb31c9d652..0000000000 --- a/tools/decomp-permuter/README.md +++ /dev/null @@ -1,120 +0,0 @@ -# Decomp permuter - -Automatically permutes C files to better match a target binary. The permuter has two modes of operation: -- Random: purely at random, introduce temporary variables for values, change types, put statements on the same line... -- Manual: test all combinations of user-specified variations, using macros like `PERM_GENERAL(a = b ? c : d;, if (b) a = c; else a = d;)` to try both specified alternatives. - -The modes can also be combined, by using the `PERM_RANDOMIZE` macro. - -[](https://asciinema.org/a/232846) - -The main target for the tool is MIPS code compiled by old compilers (IDO, possibly GCC). -Getting it to work on other architectures shouldn't be too hard, however. -https://github.com/laqieer/decomp-permuter-arm has an ARM port. - -## Usage - -`./permuter.py directory/` runs the permuter; see below for the meaning of the directory. -Pass `-h` to see possible flags. `-j` is suggested (enables multi-threaded mode). - -You'll first need to install a couple of prerequisites: `python3 -m pip install pycparser pynacl toml` (also `dataclasses` if on Python 3.6 or below) - -The permuter expects as input one or more directory containing: - - a .c file with a single function, - - a .o file to match, - - a .sh file that compiles the .c file. - -For projects with a properly configured makefile, you should be able to set these up by running -``` -./import.py -``` -where file.c contains the function to be permuted, and file.s is its assembly in a self-contained file. -Otherwise, see USAGE.md for more details. - -For projects using Ninja instead of Make, add a `permuter_settings.toml` in the root or `tools/` directory of the project: -```toml -build_system = "ninja" -``` -Then `import.py` should work as expected if `build.ninja` is at the root of the project. - -The .c file may be modified with any of the following macros which affect manual permutation: - -- `PERM_GENERAL(a, b, ...)` expands to any of `a`, `b`, ... -- `PERM_VAR(a, b)` sets the meta-variable `a` to `b`, `PERM_VAR(a)` expands to the meta-variable `a`. -- `PERM_RANDOMIZE(code)` expands to `code`, but allows randomization within that region. Multiple regions may be specified. -- `PERM_LINESWAP(lines)` expands to a permutation of the ordered set of non-whitespace lines (split by `\n`). Each line must contain zero or more complete C statements. (For incomplete statements use `PERM_LINESWAP_TEXT`, which is slower because it has to repeatedly parse C code.) -- `PERM_INT(lo, hi)` expands to an integer between `lo` and `hi` (which must be constants). -- `PERM_IGNORE(code)` expands to `code`, without passing it through the C parser library (pycparser)/randomizer. This can be used to avoid parse errors for non-standard C, e.g. `asm` blocks. -- `PERM_PRETEND(code)` expands to `code` for the purpose of the C parser/randomizer, but gets removed afterwards. This can be used together with `PERM_IGNORE` to enable the permuter to deal with input it isn't designed for (e.g. inline functions, C++, non-code). -- `PERM_ONCE([key,] code)` expands to either `code` or to nothing, such that each unique key gets expanded exactly once. `key` defaults to `code`. For example, `PERM_ONCE(a;) b; PERM_ONCE(a;)` expands to either `a; b;` or `b; a;`. - -Arguments are split by a commas, exluding commas inside parenthesis. `(,)` is a special escape sequence that resolves to `,`. - -Nested macros are allowed, so e.g. -``` -PERM_VAR(delayed, ) -PERM_GENERAL(stmt;, PERM_VAR(delayed, stmt;)) -... -PERM_VAR(delayed) -``` -is an alternative way of writing `PERM_ONCE`. - -## permuter@home - -The permuter supports a distributed mode, where people can donate processor power to your permuter runs to speed them up. -To use this, pass `-J` when running `permuter.py` and follow the instructions. -You will need to be granted access by someone who is already connected to a permuter network. - -To allow others to use your computer for permuter runs, do the following: - -- install Docker (used for sandboxing and to ensure a consistent environment) -- if on Linux, add yourself to the Docker group: `sudo usermod -aG docker $USER` -- install required packages: `python3 -m pip install docker` -- open a terminal, and run `./pah.py run-server` to start the server. - There are a few required arguments (e.g. how many cores to use), see `--help` for more details. - -Please be aware that being in the Docker group implies (password-less) sudo rights. -You can avoid that for your personal account by running the permuter under a separate user. -Unfortunately, there is currently no way to run a sandboxed permuter server without sudo rights. 😢 - -Anyone who is granted access to permuter@home can run a server. - -To set up a new permuter network, see [src/net/controller/README.md](./src/net/controller/README.md). - -## FAQ - -**What do the scores mean?** The scores are computed by taking diffs of objdump'd .o -files, and giving different penalties for lines that are the same/use the same -instruction/are reordered/don't match at all. 0 means the function matches fully. -Stack positions are ignored unless --stack-diffs is passed (but beware that the -permuter is currently quite bad at resolving stack differences). For more details, -see scorer.py. It's far from a perfect system, and should probably be tweaked to -look at e.g. the register diff graph. - -**What sort of non-matchings are the permuter good at?** It's generally best towards -the end, when mostly regalloc changes remain. If there are reorderings or functional -changes, it's often easy to resolve those by hand, and neither the scorer nor the -randomizer tends to play well with them. - -**Should I use this instead of trying to match code by hand?** No, but it can be a good -complement. PERM macros can be used to quickly test lots of variations of a function at -once, in cases where there are interactions between several parts of a function. -The randomization mode often finds lots of nonsensical changes that improve regalloc -"by accident"; it's up to you to pick out the ones that look sensible. If none do, -it can still be useful to know which parts of the function need to be changed to get the -code nearer to matching. Having made one of the improvements, and the function can then be -permuted again, to find further possible improvements. - -## Helping out - -There's tons of room for helping out with the permuter! -Many more randomization passes could be added, the scoring function is far from optimal, -the permuter could be made easier to use, etc. etc. The GitHub Issues list has some ideas. - -Ideally, `mypy permuter.py` and `./run-tests.sh` should succeed with no errors, and files -formatted with `black`. To setup a pre-commit hook for black, run: -``` -pip install pre-commit black -pre-commit install -``` -PRs that skip this are still welcome, however. diff --git a/tools/decomp-permuter/USAGE.md b/tools/decomp-permuter/USAGE.md deleted file mode 100644 index d73b3a8980..0000000000 --- a/tools/decomp-permuter/USAGE.md +++ /dev/null @@ -1,25 +0,0 @@ -This file describes how to manually set up a directory for use with the permuter. -**You probably don't need to do this!** In normal circumstances, `./import.py` -does all this for you. See README.md for more details. - -* create a directory that will contain all of the input files for the invokation -* put a compile command into `/compile.sh` (see e.g. `compile_example.sh`; it will be invoked as `./compile.sh input.c -o output.o`) -* `gcc -E -P -I header_dir -D'__attribute__(x)=' orig_c_file.c > /base.c` -* `python3 strip_other_fns.py /base.c func_name` -* put asm for `func_name` into `/target.s`, with the following header: - -```asm -.set noat -.set noreorder -.set gp=64 -.macro glabel label - .global \label - .type \label, @function - \label: -.endm -``` -* `mips-linux-gnu-as -march=vr4300 -mabi=32 /target.s -o /target.o` -* optional sanity checks: - - `/compile.sh /base.c -o /base.o` - - `./diff.sh /target.o /base.o` -* `./permuter.py ` diff --git a/tools/decomp-permuter/compile_example.sh b/tools/decomp-permuter/compile_example.sh deleted file mode 100755 index f87bdfa18d..0000000000 --- a/tools/decomp-permuter/compile_example.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -mips-linux-gnu-gcc -O2 "$@" diff --git a/tools/decomp-permuter/diff.sh b/tools/decomp-permuter/diff.sh deleted file mode 100755 index 20ae8bba08..0000000000 --- a/tools/decomp-permuter/diff.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -if [[ $# < 2 ]]; then - echo "Usage: $0 orig.o new.o [flags]" - exit 1 -fi - -if [ ! -f $1 -o ! -f $2 ]; then - echo Source files not readable - exit 1 -fi - -INPUT1="$1" -INPUT2="$2" -shift -shift -wdiff -n <(python3 ./src/objdump.py "$INPUT1" "$@") <(python3 ./src/objdump.py "$INPUT2" "$@") | colordiff | less -Ric diff --git a/tools/decomp-permuter/import.py b/tools/decomp-permuter/import.py deleted file mode 100755 index 9a314621cf..0000000000 --- a/tools/decomp-permuter/import.py +++ /dev/null @@ -1,808 +0,0 @@ -#!/usr/bin/env python3 -# usage: ./import.py path/to/file.c path/to/asm.s [make flags] -import argparse -from collections import defaultdict -import json -import os -import platform -import re -import shlex -import shutil -import subprocess -import sys -import toml -from typing import Callable, Dict, List, Match, Mapping, Optional, Pattern, Set, Tuple -import urllib.request -import urllib.parse - -from src import ast_util -from src.compiler import Compiler -from src.error import CandidateConstructionFailure - -is_macos = platform.system() == "Darwin" - - -def homebrew_gcc_cpp() -> str: - lookup_paths = ["/usr/local/bin", "/opt/homebrew/bin"] - - for lookup_path in lookup_paths: - try: - return max(f for f in os.listdir(lookup_path) if f.startswith("cpp-")) - except ValueError: - pass - - print( - "Error while looking up in " + ":".join(lookup_paths) + " for cpp- executable" - ) - sys.exit(1) - - -cpp_cmd = homebrew_gcc_cpp() if is_macos else "cpp" -make_cmd = "gmake" if is_macos else "make" - -ASM_PRELUDE: str = """ -.set noat -.set noreorder -.set gp=64 -.macro glabel label - .global \label - .type \label, @function - \label: -.endm -""" - -DEFAULT_AS_CMDLINE: List[str] = ["mips-linux-gnu-as", "-march=vr4300", "-mabi=32"] - -CPP: List[str] = [cpp_cmd, "-P", "-undef"] - -STUB_FN_MACROS: List[str] = [ - "-D_Static_assert(x, y)=", - "-D__attribute__(x)=", - "-DGLOBAL_ASM(...)=", -] - -SETTINGS_FILES = ["permuter_settings.toml", "tools/permuter_settings.toml"] - - -def formatcmd(cmdline: List[str]) -> str: - return " ".join(shlex.quote(arg) for arg in cmdline) - - -def parse_asm(asm_file: str) -> Tuple[str, str]: - func_name = None - asm_lines = [] - try: - with open(asm_file, encoding="utf-8") as f: - cur_section = ".text" - for line in f: - if line.strip().startswith(".section"): - cur_section = line.split()[1] - elif line.strip() in [ - ".text", - ".rdata", - ".rodata", - ".late_rodata", - ".bss", - ".data", - ]: - cur_section = line.strip() - if cur_section == ".text": - if func_name is None and line.strip().startswith("glabel "): - func_name = line.split()[1] - asm_lines.append(line) - except OSError as e: - print("Could not open assembly file:", e, file=sys.stderr) - sys.exit(1) - - if func_name is None: - print( - "Missing function name in assembly file! The file should start with 'glabel function_name'.", - file=sys.stderr, - ) - sys.exit(1) - - if not re.fullmatch(r"[a-zA-Z0-9_$]+", func_name): - print(f"Bad function name: {func_name}", file=sys.stderr) - sys.exit(1) - - return func_name, "".join(asm_lines) - - -def create_directory(func_name: str) -> str: - os.makedirs(f"nonmatchings/", exist_ok=True) - ctr = 0 - while True: - ctr += 1 - dirname = f"{func_name}-{ctr}" if ctr > 1 else func_name - dirname = f"nonmatchings/{dirname}" - try: - os.mkdir(dirname) - return dirname - except FileExistsError: - pass - - -def find_root_dir(filename: str, pattern: List[str]) -> Optional[str]: - old_dirname = None - dirname = os.path.abspath(os.path.dirname(filename)) - - while dirname and (not old_dirname or len(dirname) < len(old_dirname)): - for fname in pattern: - if os.path.isfile(os.path.join(dirname, fname)): - return dirname - old_dirname = dirname - dirname = os.path.dirname(dirname) - - return None - - -def fixup_build_command( - parts: List[str], ignore_part: str -) -> Tuple[List[str], Optional[List[str]]]: - res = [] - skip_count = 0 - assembler = None - for part in parts: - if skip_count > 0: - skip_count -= 1 - continue - if part in ["-MF", "-o"]: - skip_count = 1 - continue - if part == ignore_part: - continue - res.append(part) - - try: - ind0 = min( - i - for i, arg in enumerate(res) - if any( - cmd in arg - for cmd in ["asm_processor", "asm-processor", "preprocess.py"] - ) - ) - ind1 = res.index("--", ind0 + 1) - ind2 = res.index("--", ind1 + 1) - assembler = res[ind1 + 1 : ind2] - res = res[ind0 + 1 : ind1] + res[ind2 + 1 :] - except ValueError: - pass - - return res, assembler - - -def find_build_command_line( - root_dir: str, c_file: str, make_flags: List[str], build_system: str -) -> Tuple[List[str], List[str]]: - if build_system == "make": - build_invocation = [ - make_cmd, - "--always-make", - "--dry-run", - "--debug=j", - "PERMUTER=1", - ] + make_flags - elif build_system == "ninja": - build_invocation = ["ninja", "-t", "commands"] + make_flags - else: - print("Unknown build system '" + build_system + "'.") - sys.exit(1) - - rel_c_file = os.path.relpath(c_file, root_dir) - debug_output = ( - subprocess.check_output(build_invocation, cwd=root_dir) - .decode("utf-8") - .split("\n") - ) - - output = [] - close_match = False - - assembler = DEFAULT_AS_CMDLINE - for line in debug_output: - while "//" in line: - line = line.replace("//", "/") - while "/./" in line: - line = line.replace("/./", "/") - if rel_c_file not in line: - continue - - close_match = True - parts = shlex.split(line) - - # extract actual command from 'bash -c "..."' - if parts[0] == "bash" and "-c" in parts: - for part in parts: - if rel_c_file in part: - parts = shlex.split(part) - break - - if rel_c_file not in parts: - continue - if "-o" not in parts: - continue - if "-fsyntax-only" in parts: - continue - cmdline, asmproc_assembler = fixup_build_command(parts, rel_c_file) - if asmproc_assembler: - assembler = asmproc_assembler - output.append(cmdline) - - if not output: - close_extra = ( - "\n(Found one possible candidate, but didn't match due to " - "either spaces in paths, having -fsyntax-only, or missing an -o flag.)" - if close_match - else "" - ) - print( - "Failed to find compile command from build script output. " - f"Please ensure running '{' '.join(build_invocation)}' " - f"contains a line with the string '{rel_c_file}'.{close_extra}", - file=sys.stderr, - ) - sys.exit(1) - - if len(output) > 1: - output_lines = "\n".join(map(formatcmd, output)) - print( - f"Error: found multiple compile commands for {rel_c_file}:\n{output_lines}\n" - f"Please modify the build script such that '{' '.join(build_invocation)}' " - "produces a single compile command.", - file=sys.stderr, - ) - sys.exit(1) - - return output[0], assembler - - -PreserveMacros = Tuple[Pattern[str], Callable[[str], str]] - - -def build_preserve_macros( - cwd: str, preserve_regex: Optional[str], settings: Mapping[str, object] -) -> Optional[PreserveMacros]: - - subdata = settings.get("preserve_macros", {}) - assert isinstance(subdata, dict) - regexes = [] - for regex, value in subdata.items(): - assert isinstance(value, str) - regexes.append((re.compile(f"^(?:{regex})$"), value)) - - if preserve_regex == "" or (preserve_regex is None and not regexes): - return None - - if preserve_regex is None: - global_regex_text = "(?:" + ")|(?:".join(subdata.keys()) + ")" - else: - global_regex_text = preserve_regex - global_regex = re.compile(f"^(?:{global_regex_text})$") - - def type_fn(macro: str) -> str: - for regex, value in regexes: - if regex.match(macro): - return value - return "int" - - return global_regex, type_fn - - -def preprocess_c_with_macros( - cpp_command: List[str], cwd: str, preserve_macros: PreserveMacros -) -> Tuple[str, List[str]]: - """Import C file, preserving function macros. Subroutine of import_c_file. - - Returns the source code and a list of preserved macros.""" - - preserve_regex, preserve_type_fn = preserve_macros - - # Start by running 'cpp' in a mode that just processes ifdefs and includes. - source = subprocess.check_output( - cpp_command + ["-dD", "-fdirectives-only"], cwd=cwd, encoding="utf-8" - ) - - # Modify function macros that match preserved names so the preprocessor - # doesn't touch them, and at the same time normalize their syntax. Some - # of these instances may be in comments, but that's fine. - def repl(match: Match[str]) -> str: - name = match.group(1) - after = "(" if match.group(2) == "(" else " " - if preserve_regex.match(name): - return f"_permuter define {name}{after}" - else: - return f"#define {name}{after}" - - source = re.sub( - r"^\s*#\s*define\s+([a-zA-Z0-9_]+)([ \t\(]|$)", - repl, - source, - flags=re.MULTILINE, - ) - - # Get rid of auto-inserted macros which the second cpp invocation will - # warn about. - source = re.sub(r"^#define __STDC_.*\n", "", source, flags=re.MULTILINE) - - # Now, run the preprocessor again for real. - source = subprocess.check_output( - CPP + STUB_FN_MACROS, cwd=cwd, encoding="utf-8", input=source - ) - - # Finally, find all function-like defines that we hid (some might have - # been comments, so we couldn't do this before), and construct fake - # function declarations for them in a specially demarcated section of - # the file. When the compiler runs, this section will be replaced by - # the real defines and the preprocessor invoked once more. - late_defines = [] - lines = [] - graph = defaultdict(set) - reg_token = re.compile(r"[a-zA-Z0-9_]+") - for line in source.splitlines(): - is_macro = line.startswith("_permuter define ") - params = [] - if is_macro: - ind1 = line.find("(") - ind2 = line.find(" ", len("_permuter define ")) - ind = min(ind1, ind2) - if ind == -1: - ind = len(line) if ind1 == ind2 == -1 else max(ind1, ind2) - before = line[:ind] - after = line[ind:] - name = before.split()[2] - late_defines.append((name, after)) - if after.startswith("("): - params = [w.strip() for w in after[1 : after.find(")")].split(",")] - else: - lines.append(line) - name = "" - for m in reg_token.finditer(line): - name2 = m.group(0) - has_wildcard = False - if is_macro and name2 not in params: - wcbefore = line[: m.start()].rstrip().endswith("##") - wcafter = line[m.end() :].lstrip().startswith("##") - if wcbefore or wcafter: - graph[name].add(name2 + "*") - has_wildcard = True - if not has_wildcard: - graph[name].add(name2) - - # Prune away (recursively) unused macros, for cleanliness. - used_anywhere = set() - used_by_nonmacro = graph[""] - queue = [""] - while queue: - name = queue.pop() - if name not in used_anywhere: - used_anywhere.add(name) - if name.endswith("*"): - wildcard = name[:-1] - for name2 in graph: - if wildcard in name2: - queue.extend(graph[name2]) - else: - queue.extend(graph[name]) - - def get_decl(name: str, after: str) -> str: - typ = preserve_type_fn(name) - if after.startswith("("): - return f"{typ} {name}();" - else: - return f"extern {typ} {name};" - - used_macros = [name for (name, after) in late_defines if name in used_by_nonmacro] - - return ( - "\n".join( - ["#pragma _permuter latedefine start"] - + [ - f"#pragma _permuter define {name}{after}" - for (name, after) in late_defines - if name in used_anywhere - ] - + [ - get_decl(name, after) - for (name, after) in late_defines - if name in used_by_nonmacro - ] - + ["#pragma _permuter latedefine end"] - + lines - + [""] - ), - used_macros, - ) - - -def import_c_file( - compiler: List[str], - cwd: str, - in_file: str, - preserve_macros: Optional[PreserveMacros], -) -> str: - """Preprocess a C file into permuter-usable source. - - Prints preserved macros as a side effect. - - Returns source for base.c and compilable (macro-expanded) source.""" - in_file = os.path.relpath(in_file, cwd) - include_next = 0 - cpp_command = CPP + [in_file, "-D__sgi", "-D_LANGUAGE_C", "-DNON_MATCHING"] - - for arg in compiler: - if include_next > 0: - include_next -= 1 - cpp_command.append(arg) - continue - if arg in ["-D", "-U", "-I"]: - cpp_command.append(arg) - include_next = 1 - continue - if ( - arg.startswith("-D") - or arg.startswith("-U") - or arg.startswith("-I") - or arg in ["-nostdinc"] - ): - cpp_command.append(arg) - - try: - if preserve_macros is None: - # Simple codepath, should work even if the more complex one breaks. - source = subprocess.check_output( - cpp_command + STUB_FN_MACROS, cwd=cwd, encoding="utf-8" - ) - macros: List[str] = [] - else: - source, macros = preprocess_c_with_macros(cpp_command, cwd, preserve_macros) - - except subprocess.CalledProcessError as e: - print( - "Failed to preprocess input file, when running command:\n" - + formatcmd(e.cmd), - file=sys.stderr, - ) - sys.exit(1) - - if macros: - macro_str = "macros: " + ", ".join(macros) - else: - macro_str = "no macros" - print(f"Preserving {macro_str}. Use --preserve-macros='' to override.") - - return source - - -def prune_source( - source: str, should_prune: bool, func_name: str -) -> Tuple[str, Optional[str]]: - """Normalize the source by round-tripping it through pycparser, and - optionally reduce it to a smaller version that includes only the imported - function and functions/struct/variables that it uses. - - Returns (source, compilable_source).""" - try: - ast = ast_util.parse_c(source, from_import=True) - orig_fn, _ = ast_util.extract_fn(ast, func_name) - if should_prune: - try: - ast_util.prune_ast(orig_fn, ast) - source = ast_util.to_c_raw(ast) - except Exception: - print( - "Source minimization failed! " - "You could try --no-prune as a workaround." - ) - raise - return source, ast_util.to_c(ast, from_import=True) - except CandidateConstructionFailure as e: - print(e.message) - if should_prune and "PERM_" in source: - print( - "Please put in PERM macros after import, otherwise source " - "minimization does not work." - ) - else: - print("Proceeding anyway, but expect errors when permuting!") - return source, None - - -def prune_and_separate_context( - source: str, should_prune: bool, func_name: str -) -> Tuple[str, str]: - """Normalize the source by round-tripping it through pycparser, optionally - reduce it to a smaller version that includes only the imported function and - functions/struct/variables that it uses, and split the result into source - for the function itself, and the rest of the file (the "context"). - - Returns (source, context).""" - try: - ast = ast_util.parse_c(source, from_import=True) - orig_fn, ind = ast_util.extract_fn(ast, func_name) - if should_prune: - try: - ind = ast_util.prune_ast(orig_fn, ast) - except Exception: - print( - "Source minimization failed! " - "You could try --no-prune as a workaround." - ) - raise - del ast.ext[ind] - source = ast_util.to_c(orig_fn, from_import=True) - context = ast_util.to_c(ast, from_import=True) - return source, context - except CandidateConstructionFailure as e: - print(e.message) - print("Unable to split context from source.") - print("Proceeding anyway, but expect compile errors!") - return ast_util.process_pragmas(source), "" - - -def get_decompme_compiler_name( - compiler: List[str], settings: Mapping[str, object], api_base: str -) -> str: - decompme_settings = settings.get("decompme", {}) - assert isinstance(decompme_settings, dict) - compiler_mappings = decompme_settings.get("compilers", {}) - assert isinstance(compiler_mappings, dict) - - compiler_path = compiler[0] - - for path, compiler_name in compiler_mappings.items(): - assert isinstance(compiler_name, str) - if path == compiler_path: - return compiler_name - - try: - with urllib.request.urlopen(f"{api_base}/api/compilers") as f: - json_data = json.load(f) - available = json_data["compiler_ids"] - if not isinstance(available, list): - raise Exception("compiler_ids must be a list") - if not all(isinstance(name, str) for name in available): - raise Exception("compiler_ids must be a list of strings") - except Exception as e: - print(f"Failed to request available compilers from decomp.me:\n{e}") - - print() - print( - f'Unable to map compiler path "{compiler_path}" to something ' - "decomp.me understands." - ) - trail = "permuter_settings.toml, where ... is one of: " + ", ".join(available) - if compiler_mappings: - print( - "Please add an entry:\n\n" - f'"{compiler_path}" = "..."\n\n' - f"to the [decompme.compilers] section of {trail}" - ) - else: - print( - "Please add an section:\n\n" - "[decompme.compilers]\n" - f'"{compiler_path}" = "..."\n\n' - f"to {trail}" - ) - sys.exit(1) - - -def finalize_compile_command(cmdline: List[str]) -> str: - quoted = [arg if arg == "|" else shlex.quote(arg) for arg in cmdline] - ind = (quoted + ["|"]).index("|") - return " ".join(quoted[:ind] + ['"$INPUT"'] + quoted[ind:] + ["-o", '"$OUTPUT"']) - - -def get_compiler_flags(cmdline: List[str]) -> str: - flags = [b for a, b in zip(cmdline, cmdline[1:]) if a != "|" and b != "|"] - return " ".join(shlex.quote(flag) for flag in flags) - - -def write_compile_command(compiler: List[str], cwd: str, out_file: str) -> None: - - with open(out_file, "w", encoding="utf-8") as f: - f.write("#!/usr/bin/env bash\n") - f.write('INPUT="$(realpath "$1")"\n') - f.write('OUTPUT="$(realpath "$3")"\n') - f.write(f"cd {shlex.quote(cwd)}\n") - f.write(finalize_compile_command(compiler)) - os.chmod(out_file, 0o755) - - -def write_asm(asm_cont: str, out_file: str) -> None: - with open(out_file, "w", encoding="utf-8") as f: - f.write(ASM_PRELUDE) - f.write(asm_cont) - - -def compile_asm(assembler: List[str], cwd: str, in_file: str, out_file: str) -> None: - in_file = os.path.abspath(in_file) - out_file = os.path.abspath(out_file) - cmdline = assembler + [in_file, "-o", out_file] - try: - subprocess.check_call(cmdline, cwd=cwd) - except subprocess.CalledProcessError: - print( - f"Failed to assemble .s file, command line:\n{formatcmd(cmdline)}", - file=sys.stderr, - ) - sys.exit(1) - - -def compile_base(compile_script: str, source: str, c_file: str, out_file: str) -> None: - if "PERM_" in source: - print( - "Cannot test-compile imported code because it contains PERM macros. " - "It is recommended to put in PERM macros after import." - ) - return - escaped_c_file = json.dumps(c_file) - source = "#line 1 " + escaped_c_file + "\n" + source - compiler = Compiler(compile_script, show_errors=True) - o_file = compiler.compile(source) - if o_file: - shutil.move(o_file, out_file) - else: - print("Warning: failed to compile .c file.") - - -def write_to_file(cont: str, filename: str) -> None: - with open(filename, "w", encoding="utf-8") as f: - f.write(cont) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="""Import a function for use with the permuter. - Will create a new directory nonmatchings/-/.""" - ) - parser.add_argument( - "c_file", - help="""File containing the function. - Assumes that the file can be built with 'make' to create an .o file.""", - ) - parser.add_argument( - "asm_file", - help="""File containing assembly for the function. - Must start with 'glabel ' and contain no other functions.""", - ) - parser.add_argument( - "make_flags", - nargs="*", - help="Arguments to pass to 'make'. PERMUTER=1 will always be passed.", - ) - parser.add_argument( - "--keep", action="store_true", help="Keep the directory on error." - ) - settings_files = ", ".join(SETTINGS_FILES[:-1]) + " or " + SETTINGS_FILES[-1] - parser.add_argument( - "--preserve-macros", - metavar="REGEX", - dest="preserve_macros_regex", - help=f"""Regex for which macros to preserve, or empty string for no macros. - By default, this is read from {settings_files} in a parent directory of - the imported file. Type information is also read from this file.""", - ) - parser.add_argument( - "--no-prune", - dest="prune", - action="store_false", - help="""Don't minimize the source to keep only the imported function and - functions/struct/variables that it uses. Normally this behavior is - useful to make the permuter faster, but in cases where unrelated code - affects the generated assembly asm it can be necessary to turn off. - Note that regardless of this setting the permuter always removes all - other functions by replacing them with declarations.""", - ) - parser.add_argument( - "--decompme", - dest="decompme", - action="store_true", - help="""Upload the function to decomp.me to share with other people, - instead of importing.""", - ) - args = parser.parse_args() - - root_dir = find_root_dir( - args.c_file, SETTINGS_FILES + ["Makefile", "makefile", "build.ninja"] - ) - - if not root_dir: - print(f"Can't find root dir of project!", file=sys.stderr) - sys.exit(1) - - settings: Mapping[str, object] = {} - for filename in SETTINGS_FILES: - filename = os.path.join(root_dir, filename) - if os.path.exists(filename): - with open(filename) as f: - settings = toml.load(f) - break - - build_system = settings.get("build_system", "make") - compiler = settings.get("compiler_command") - assembler = settings.get("assembler_command") - make_flags = args.make_flags - - func_name, asm_cont = parse_asm(args.asm_file) - print(f"Function name: {func_name}") - - if compiler or assembler: - assert isinstance(compiler, str) - assert isinstance(assembler, str) - assert settings.get("build_system") is None - - compiler = shlex.split(compiler) - assembler = shlex.split(assembler) - else: - assert isinstance(build_system, str) - compiler, assembler = find_build_command_line( - root_dir, args.c_file, make_flags, build_system - ) - - print(f"Compiler: {formatcmd(compiler)} {{input}} -o {{output}}") - print(f"Assembler: {formatcmd(assembler)} {{input}} -o {{output}}") - - preserve_macros = build_preserve_macros( - root_dir, args.preserve_macros_regex, settings - ) - source = import_c_file(compiler, root_dir, args.c_file, preserve_macros) - - if args.decompme: - api_base = os.environ.get("DECOMPME_API_BASE", "https://decomp.me") - compiler_name = get_decompme_compiler_name(compiler, settings, api_base) - source, context = prune_and_separate_context(source, args.prune, func_name) - print("Uploading...") - try: - post_data = urllib.parse.urlencode( - { - "target_asm": asm_cont, - "context": context, - "source_code": source, - "compiler": compiler_name, - "compiler_flags": get_compiler_flags(compiler), - } - ).encode("ascii") - with urllib.request.urlopen(f"{api_base}/api/scratch", post_data) as f: - resp = f.read() - json_data: Dict[str, str] = json.loads(resp) - if "slug" in json_data: - slug = json_data["slug"] - print(f"https://decomp.me/scratch/{slug}") - else: - error = json_data.get("error", resp) - print(f"Server error: {error}") - except Exception as e: - print(e) - return - - source, compilable_source = prune_source(source, args.prune, func_name) - - dirname = create_directory(func_name) - base_c_file = f"{dirname}/base.c" - base_o_file = f"{dirname}/base.o" - target_s_file = f"{dirname}/target.s" - target_o_file = f"{dirname}/target.o" - compile_script = f"{dirname}/compile.sh" - func_name_file = f"{dirname}/function.txt" - - try: - write_to_file(source, base_c_file) - write_to_file(func_name, func_name_file) - write_compile_command(compiler, root_dir, compile_script) - write_asm(asm_cont, target_s_file) - compile_asm(assembler, root_dir, target_s_file, target_o_file) - if compilable_source is not None: - compile_base(compile_script, compilable_source, base_c_file, base_o_file) - except: - if not args.keep: - print(f"\nDeleting directory {dirname} (run with --keep to preserve it).") - shutil.rmtree(dirname) - raise - - print(f"\nDone. Imported into {dirname}") - - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/mypy.ini b/tools/decomp-permuter/mypy.ini deleted file mode 100644 index 2d21c54b76..0000000000 --- a/tools/decomp-permuter/mypy.ini +++ /dev/null @@ -1,27 +0,0 @@ -[mypy] -check_untyped_defs = True -disallow_any_generics = False -disallow_incomplete_defs = True -disallow_subclassing_any = True -disallow_untyped_calls = True -disallow_untyped_decorators = True -disallow_untyped_defs = True -no_implicit_optional = True -warn_redundant_casts = True -warn_return_any = True -warn_unused_ignores = True -mypy_path = stubs -python_version = 3.7 -files = import.py, pah.py, permuter.py, src/net/evaluator.py - -[mypy-nacl.*] -ignore_missing_imports = True - -[mypy-pystray.*] -ignore_missing_imports = True - -[mypy-docker.*] -ignore_missing_imports = True - -[mypy-PIL.*] -ignore_missing_imports = True diff --git a/tools/decomp-permuter/pah.py b/tools/decomp-permuter/pah.py deleted file mode 100755 index 0adbb369f4..0000000000 --- a/tools/decomp-permuter/pah.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -from src.net.cmd.main import main - -main() diff --git a/tools/decomp-permuter/permuter.py b/tools/decomp-permuter/permuter.py deleted file mode 100755 index 8433c82180..0000000000 --- a/tools/decomp-permuter/permuter.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python3 -from src.main import main - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/permuter_settings_example.toml b/tools/decomp-permuter/permuter_settings_example.toml deleted file mode 100644 index 46523f3315..0000000000 --- a/tools/decomp-permuter/permuter_settings_example.toml +++ /dev/null @@ -1,9 +0,0 @@ -# Optional configuration file for import.py. Put it in the root or in tools/ -# of the repo you are importing from. - -build_system = "ninja" - -[preserve_macros] -"g[DS]P.*" = "void" -"gDma.*" = "void" -"_SHIFTL" = "unsigned int" diff --git a/tools/decomp-permuter/run-tests.sh b/tools/decomp-permuter/run-tests.sh deleted file mode 100755 index eff96637e3..0000000000 --- a/tools/decomp-permuter/run-tests.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -python3 -m unittest discover -s test/ -# python3 -m pytest test/ diff --git a/tools/decomp-permuter/sort_cands.sh b/tools/decomp-permuter/sort_cands.sh deleted file mode 100755 index 2f30a84692..0000000000 --- a/tools/decomp-permuter/sort_cands.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -if [[ $1 < 2 ]]; then - echo "Usage: $0 output_dir" - echo "Ex: $0 nonmatchings/func_80000000" - exit 1 -fi -if [[ ! -d $1 ]]; then - echo "Argument must be a directory" - exit 1 -fi - -find $1 -name score.txt -exec echo -n {}\ \; -exec cat {} \; | sort -rnk2 diff --git a/tools/decomp-permuter/src/__init__.py b/tools/decomp-permuter/src/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/decomp-permuter/src/ast_types.py b/tools/decomp-permuter/src/ast_types.py deleted file mode 100644 index 9eb21c28bf..0000000000 --- a/tools/decomp-permuter/src/ast_types.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Functions and classes for dealing with types in a C AST. - -They make a number of simplifying assumptions: - - const and volatile doesn't matter. - - arithmetic promotes all int-like types to 'int'. - - no two variables can have the same name, even across functions. - -For the purposes of the randomizer these restrictions are acceptable.""" - -from dataclasses import dataclass, field -from typing import Union, Dict, Set, List - -from pycparser import c_ast -from pycparser.c_ast import ArrayDecl, TypeDecl, PtrDecl, FuncDecl, IdentifierType - -Type = Union[PtrDecl, ArrayDecl, TypeDecl, FuncDecl] -SimpleType = Union[PtrDecl, TypeDecl] - -StructUnion = Union[c_ast.Struct, c_ast.Union] - - -@dataclass -class TypeMap: - typedefs: Dict[str, Type] = field(default_factory=dict) - fn_ret_types: Dict[str, Type] = field(default_factory=dict) - var_types: Dict[str, Type] = field(default_factory=dict) - struct_defs: Dict[str, StructUnion] = field(default_factory=dict) - - -def basic_type(name: Union[str, List[str]]) -> TypeDecl: - names = [name] if isinstance(name, str) else name - idtype = IdentifierType(names=names) - return TypeDecl(declname=None, quals=[], type=idtype) - - -def pointer(type: Type) -> Type: - return PtrDecl(quals=[], type=type) - - -def resolve_typedefs(type: Type, typemap: TypeMap) -> Type: - while ( - isinstance(type, TypeDecl) - and isinstance(type.type, IdentifierType) - and len(type.type.names) == 1 - and type.type.names[0] in typemap.typedefs - ): - type = typemap.typedefs[type.type.names[0]] - return type - - -def pointer_decay(type: Type, typemap: TypeMap) -> SimpleType: - real_type = resolve_typedefs(type, typemap) - if isinstance(real_type, ArrayDecl): - return PtrDecl(quals=[], type=real_type.type) - if isinstance(real_type, FuncDecl): - return PtrDecl(quals=[], type=type) - if isinstance(real_type, TypeDecl) and isinstance(real_type.type, c_ast.Enum): - return basic_type("int") - assert not isinstance( - type, (ArrayDecl, FuncDecl) - ), "resolve_typedefs can't hide arrays/functions" - return type - - -def get_decl_type(decl: c_ast.Decl) -> Type: - """For a Decl that declares a variable (and not just a struct/union/enum), - return its type.""" - assert decl.name is not None - assert isinstance(decl.type, (PtrDecl, ArrayDecl, FuncDecl, TypeDecl)) - return decl.type - - -def deref_type(type: Type, typemap: TypeMap) -> Type: - type = resolve_typedefs(type, typemap) - assert isinstance(type, (ArrayDecl, PtrDecl)), "dereferencing non-pointer" - return type.type - - -def struct_member_type(struct: StructUnion, field_name: str, typemap: TypeMap) -> Type: - if not struct.decls: - assert ( - struct.name in typemap.struct_defs - ), f"Accessing field {field_name} of undefined struct {struct.name}" - struct = typemap.struct_defs[struct.name] - assert struct.decls, "struct_defs never points to an incomplete type" - for decl in struct.decls: - if isinstance(decl, c_ast.Decl): - if decl.name == field_name: - return get_decl_type(decl) - if decl.name == None and isinstance(decl.type, (c_ast.Struct, c_ast.Union)): - try: - return struct_member_type(decl.type, field_name, typemap) - except AssertionError: - pass - - assert False, f"No field {field_name} in struct {struct.name}" - - -def expr_type(node: c_ast.Node, typemap: TypeMap) -> Type: - def rec(sub_expr: c_ast.Node) -> Type: - return expr_type(sub_expr, typemap) - - if isinstance(node, c_ast.Assignment): - return rec(node.lvalue) - if isinstance(node, c_ast.StructRef): - lhs_type = rec(node.name) - if node.type == "->": - lhs_type = deref_type(lhs_type, typemap) - struct_type = resolve_typedefs(lhs_type, typemap) - assert isinstance(struct_type, TypeDecl) - assert isinstance( - struct_type.type, (c_ast.Struct, c_ast.Union) - ), f"struct deref of non-struct {struct_type.declname}" - return struct_member_type(struct_type.type, node.field.name, typemap) - if isinstance(node, c_ast.Cast): - return node.to_type.type - if isinstance(node, c_ast.Constant): - if node.type == "string": - return pointer(basic_type("char")) - if node.type == "char": - return basic_type("int") - return basic_type(node.type.split(" ")) - if isinstance(node, c_ast.ID): - return typemap.var_types[node.name] - if isinstance(node, c_ast.UnaryOp): - if node.op in ["p++", "p--", "++", "--"]: - return rec(node.expr) - if node.op == "&": - return pointer(rec(node.expr)) - if node.op == "*": - subtype = rec(node.expr) - return deref_type(subtype, typemap) - if node.op in ["-", "+"]: - subtype = pointer_decay(rec(node.expr), typemap) - if allowed_basic_type(subtype, typemap, ["double"]): - return basic_type("double") - if allowed_basic_type(subtype, typemap, ["float"]): - return basic_type("float") - if node.op in ["sizeof", "-", "+", "~", "!"]: - return basic_type("int") - assert False, f"unknown unary op {node.op}" - if isinstance(node, c_ast.BinaryOp): - lhs_type = pointer_decay(rec(node.left), typemap) - rhs_type = pointer_decay(rec(node.right), typemap) - if node.op in [">>", "<<"]: - return lhs_type - if node.op in ["<", "<=", ">", ">=", "==", "!=", "&&", "||"]: - return basic_type("int") - if node.op in "&|^%": - return basic_type("int") - real_lhs = resolve_typedefs(lhs_type, typemap) - real_rhs = resolve_typedefs(rhs_type, typemap) - if node.op in "+-": - lptr = isinstance(real_lhs, PtrDecl) - rptr = isinstance(real_rhs, PtrDecl) - if lptr or rptr: - if lptr and rptr: - assert node.op != "+", "pointer + pointer" - return basic_type("int") - if lptr: - return lhs_type - assert node.op == "+", "int - pointer" - return rhs_type - if node.op in "*/+-": - assert isinstance(real_lhs, TypeDecl) - assert isinstance(real_rhs, TypeDecl) - assert isinstance(real_lhs.type, IdentifierType) - assert isinstance(real_rhs.type, IdentifierType) - if "double" in real_lhs.type.names + real_rhs.type.names: - return basic_type("double") - if "float" in real_lhs.type.names + real_rhs.type.names: - return basic_type("float") - return basic_type("int") - if isinstance(node, c_ast.FuncCall): - expr = node.name - if isinstance(expr, c_ast.ID): - if expr.name not in typemap.fn_ret_types: - raise Exception(f"Called function {expr.name} is missing a prototype") - return typemap.fn_ret_types[expr.name] - else: - fptr_type = resolve_typedefs(rec(expr), typemap) - if isinstance(fptr_type, PtrDecl): - fptr_type = fptr_type.type - fptr_type = resolve_typedefs(fptr_type, typemap) - assert isinstance(fptr_type, FuncDecl), "call to non-function" - return fptr_type.type - if isinstance(node, c_ast.ExprList): - return rec(node.exprs[-1]) - if isinstance(node, c_ast.ArrayRef): - subtype = rec(node.name) - return deref_type(subtype, typemap) - if isinstance(node, c_ast.TernaryOp): - return rec(node.iftrue) - assert False, f"Unknown expression node type: {node}" - - -def decayed_expr_type(expr: c_ast.Node, typemap: TypeMap) -> SimpleType: - return pointer_decay(expr_type(expr, typemap), typemap) - - -def same_type( - type1: Type, type2: Type, typemap: TypeMap, allow_similar: bool = False -) -> bool: - while True: - type1 = resolve_typedefs(type1, typemap) - type2 = resolve_typedefs(type2, typemap) - if isinstance(type1, ArrayDecl) and isinstance(type2, ArrayDecl): - type1 = type1.type - type2 = type2.type - continue - if isinstance(type1, PtrDecl) and isinstance(type2, PtrDecl): - type1 = type1.type - type2 = type2.type - continue - if isinstance(type1, TypeDecl) and isinstance(type2, TypeDecl): - sub1 = type1.type - sub2 = type2.type - if isinstance(sub1, c_ast.Struct) and isinstance(sub2, c_ast.Struct): - return sub1.name == sub2.name - if isinstance(sub1, c_ast.Union) and isinstance(sub2, c_ast.Union): - return sub1.name == sub2.name - if ( - allow_similar - and isinstance(sub1, (IdentifierType, c_ast.Enum)) - and isinstance(sub2, (IdentifierType, c_ast.Enum)) - ): - # All int-ish types are similar (except void, but whatever) - return True - if isinstance(sub1, c_ast.Enum) and isinstance(sub2, c_ast.Enum): - return sub1.name == sub2.name - if isinstance(sub1, IdentifierType) and isinstance(sub2, IdentifierType): - return sorted(sub1.names) == sorted(sub2.names) - return False - - -def allowed_basic_type( - type: SimpleType, typemap: TypeMap, allowed_types: List[str] -) -> bool: - """Check if a type resolves to a basic type with one of the allowed_types - keywords in it.""" - base_type = resolve_typedefs(type, typemap) - if not isinstance(base_type, c_ast.TypeDecl): - return False - if not isinstance(base_type.type, c_ast.IdentifierType): - return False - if all(x not in base_type.type.names for x in allowed_types): - return False - return True - - -def build_typemap(ast: c_ast.FileAST) -> TypeMap: - ret = TypeMap() - for item in ast.ext: - if isinstance(item, c_ast.Typedef): - ret.typedefs[item.name] = item.type - if isinstance(item, c_ast.FuncDef): - assert item.decl.name is not None, "cannot define anonymous function" - assert isinstance(item.decl.type, FuncDecl) - ret.fn_ret_types[item.decl.name] = item.decl.type.type - if isinstance(item, c_ast.Decl) and isinstance(item.type, FuncDecl): - assert item.name is not None, "cannot define anonymous function" - ret.fn_ret_types[item.name] = item.type.type - defined_function_decls: Set[c_ast.Decl] = set() - - class Visitor(c_ast.NodeVisitor): - def visit_Struct(self, struct: c_ast.Struct) -> None: - if struct.decls and struct.name is not None: - ret.struct_defs[struct.name] = struct - # Do not visit decls of this struct - - def visit_Union(self, union: c_ast.Union) -> None: - if union.decls and union.name is not None: - ret.struct_defs[union.name] = union - # Do not visit decls of this union - - def visit_Decl(self, decl: c_ast.Decl) -> None: - if decl.name is not None: - ret.var_types[decl.name] = get_decl_type(decl) - if not isinstance(decl.type, FuncDecl) or decl in defined_function_decls: - # Do not visit declarations in parameter lists of functions - # other than our own. - self.visit(decl.type) - - def visit_Enumerator(self, enumerator: c_ast.Enumerator) -> None: - ret.var_types[enumerator.name] = basic_type("int") - - def visit_FuncDef(self, fn: c_ast.FuncDef) -> None: - if fn.decl.name is not None: - ret.var_types[fn.decl.name] = get_decl_type(fn.decl) - defined_function_decls.add(fn.decl) - self.generic_visit(fn) - - Visitor().visit(ast) - return ret - - -def set_decl_name(decl: c_ast.Decl) -> None: - name = decl.name - assert name is not None - type = get_decl_type(decl) - while not isinstance(type, TypeDecl): - type = type.type - type.declname = name diff --git a/tools/decomp-permuter/src/ast_util.py b/tools/decomp-permuter/src/ast_util.py deleted file mode 100644 index 91716fb97c..0000000000 --- a/tools/decomp-permuter/src/ast_util.py +++ /dev/null @@ -1,505 +0,0 @@ -from base64 import b64decode -from collections import defaultdict -import copy -from dataclasses import dataclass -from random import Random -import re -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union - -from pycparser import CParser, c_ast as ca, c_generator -from pycparser.plyparser import ParseError - -from .error import CandidateConstructionFailure -from .ast_types import SimpleType, set_decl_name - - -@dataclass -class Indices: - starts: Dict[ca.Node, int] - ends: Dict[ca.Node, int] - - -Block = Union[ca.Compound, ca.Case, ca.Default] -if TYPE_CHECKING: - # ca.Expression and ca.Statement don't actually exist, they live only in - # the stubs file. - Expression = ca.Expression - Statement = ca.Statement -else: - Expression = Statement = None - - -def to_c_raw(node: ca.Node) -> str: - source: str = c_generator.CGenerator().visit(node) - return source - - -def to_c(node: ca.Node, *, from_import: bool = False) -> str: - source = to_c_raw(node) if from_import else PatchedCGenerator().visit(node) - return process_pragmas(source) - - -def process_pragmas(source: str) -> str: - if "#pragma" not in source: - return source - lines = source.split("\n") - out: List[str] = [] - same_line = 0 - ignore = 0 - for line in lines: - stripped = line.strip() - if stripped.startswith("#pragma _permuter "): - # Expand permuter pragmas to nothing, by default. Still, keep one - # output line per input line to preserve line numbers for import.py - # error messages. - line = "" - - stripped = stripped[len("#pragma _permuter ") :] - if stripped == "sameline start": - same_line += 1 - elif stripped == "sameline end": - same_line -= 1 - elif stripped == "latedefine start": - ignore += 1 - elif stripped == "latedefine end": - assert ignore > 0, "mismatched ignore pragmas" - ignore -= 1 - elif stripped.startswith("define "): - assert ignore > 0, "define pragma must be within latedefine block" - line = "#" + stripped - elif stripped.startswith("b64literal "): - line = b64decode(stripped.split(" ", 1)[1]).decode("utf-8") - elif ignore > 0: - # Ignore non-pragma lines within latedefine section - line = "" - - if not same_line: - line += "\n" - elif line and out and not out[-1].endswith("\n"): - line = " " + line.lstrip() - out.append(line) - assert same_line == 0 - assert ignore == 0, "unbalanced ignore pragmas" - return "".join(out).rstrip() + "\n" - - -class PatchedCGenerator(c_generator.CGenerator): - """Like a CGenerator, except it keeps else if's prettier despite - the terrible things we've done to them in normalize_ast.""" - - def visit_If(self, n: ca.If) -> str: - n2 = n - if ( - n.iffalse - and isinstance(n.iffalse, ca.Compound) - and n.iffalse.block_items - and len(n.iffalse.block_items) == 1 - and isinstance(n.iffalse.block_items[0], ca.If) - ): - n2 = ca.If(cond=n.cond, iftrue=n.iftrue, iffalse=n.iffalse.block_items[0]) - return super().visit_If(n2) # type: ignore - - -def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]: - ret = [] - for i, node in enumerate(ast.ext): - if isinstance(node, ca.FuncDef): - if node.decl.name == fn_name: - ret.append((node, i)) - else: - node = node.decl - ast.ext[i] = node - if isinstance(node, ca.Decl) and isinstance(node.type, ca.FuncDecl): - node.funcspec = [spec for spec in node.funcspec if spec != "static"] - if len(ret) == 0: - raise CandidateConstructionFailure(f"Function {fn_name} not found in base.c.") - if len(ret) > 1: - raise CandidateConstructionFailure( - f"Found multiple copies of function {fn_name} in base.c." - ) - return ret[0] - - -def parse_c(source: str, *, from_import: bool = False) -> ca.FileAST: - try: - parser = CParser() - return parser.parse(source, "") - except ParseError as e: - msg = str(e) - position, msg = msg.split(": ", 1) - parts = position.split(":") - if len(parts) >= 2: - lineno = int(parts[1]) - posstr = f" at approximately line {lineno}" - if len(parts) >= 3: - posstr += f", column {parts[2]}" - if not from_import: - posstr += " (after PERM expansion)" - try: - line = source.split("\n")[lineno - 1].rstrip() - posstr += "\n\n" + line - except IndexError: - posstr += "(out of bounds?)" - else: - posstr = "" - raise CandidateConstructionFailure( - f"Syntax error in base.c.\n{msg}{posstr}" - ) from None - - -def compute_node_indices(top_node: ca.Node) -> Indices: - starts: Dict[ca.Node, int] = {} - ends: Dict[ca.Node, int] = {} - cur_index = 1 - - class Visitor(ca.NodeVisitor): - def generic_visit(self, node: ca.Node) -> None: - nonlocal cur_index - assert node not in starts, "nodes should only appear once in AST" - starts[node] = cur_index - cur_index += 2 - super().generic_visit(node) - ends[node] = cur_index - cur_index += 2 - - Visitor().visit(top_node) - return Indices(starts, ends) - - -def equal_ast(a: ca.Node, b: ca.Node) -> bool: - def equal(a: Any, b: Any) -> bool: - if type(a) != type(b): - return False - if a is None: - return b is None - if isinstance(a, list): - assert isinstance(b, list) - if len(a) != len(b): - return False - for i in range(len(a)): - if not equal(a[i], b[i]): - return False - return True - if isinstance(a, (int, str)): - return bool(a == b) - assert isinstance(a, ca.Node) - for name in a.__slots__[:-2]: # type: ignore - if not equal(getattr(a, name), getattr(b, name)): - return False - return True - - return equal(a, b) - - -def is_lvalue(expr: Expression) -> bool: - if isinstance(expr, (ca.ID, ca.StructRef, ca.ArrayRef)): - return True - if isinstance(expr, ca.UnaryOp): - return expr.op == "*" - return False - - -def is_effectful(expr: Expression) -> bool: - found = False - - class Visitor(ca.NodeVisitor): - def visit_UnaryOp(self, node: ca.UnaryOp) -> None: - nonlocal found - if node.op in ["p++", "p--", "++", "--"]: - found = True - else: - self.generic_visit(node.expr) - - def visit_FuncCall(self, _: ca.Node) -> None: - nonlocal found - found = True - - def visit_Assignment(self, _: ca.Node) -> None: - nonlocal found - found = True - - Visitor().visit(expr) - return found - - -def get_block_stmts(block: Block, force: bool) -> List[Statement]: - if isinstance(block, ca.Compound): - ret = block.block_items or [] - if force and not block.block_items: - block.block_items = ret - else: - ret = block.stmts or [] - if force and not block.stmts: - block.stmts = ret - return ret - - -def insert_decl( - fn: ca.FuncDef, var: str, type: SimpleType, random: Optional[Random] = None -) -> None: - type = copy.deepcopy(type) - decl = ca.Decl( - name=var, quals=[], storage=[], funcspec=[], type=type, init=None, bitsize=None - ) - set_decl_name(decl) - assert fn.body.block_items, "Non-empty function" - for index, stmt in enumerate(fn.body.block_items): - if not isinstance(stmt, ca.Decl): - break - else: - index = len(fn.body.block_items) - - if random: - index = random.randint(0, index) - fn.body.block_items[index:index] = [decl] - - -def insert_statement(block: Block, index: int, stmt: Statement) -> None: - stmts = get_block_stmts(block, True) - stmts[index:index] = [stmt] - - -def brace_nested_blocks(stmt: Statement) -> None: - def brace(stmt: Statement) -> Block: - if isinstance(stmt, (ca.Compound, ca.Case, ca.Default)): - return stmt - return ca.Compound([stmt]) - - if isinstance(stmt, (ca.For, ca.While, ca.DoWhile)): - stmt.stmt = brace(stmt.stmt) - elif isinstance(stmt, ca.If): - stmt.iftrue = brace(stmt.iftrue) - if stmt.iffalse: - stmt.iffalse = brace(stmt.iffalse) - elif isinstance(stmt, ca.Switch): - stmt.stmt = brace(stmt.stmt) - elif isinstance(stmt, ca.Label): - brace_nested_blocks(stmt.stmt) - - -def has_nested_block(node: ca.Node) -> bool: - return isinstance( - node, - ( - ca.Compound, - ca.For, - ca.While, - ca.DoWhile, - ca.If, - ca.Switch, - ca.Case, - ca.Default, - ), - ) - - -def for_nested_blocks(stmt: Statement, callback: Callable[[Block], None]) -> None: - def invoke(stmt: Statement) -> None: - assert isinstance( - stmt, (ca.Compound, ca.Case, ca.Default) - ), "brace_nested_blocks should have turned nested statements into blocks" - callback(stmt) - - if isinstance(stmt, ca.Compound): - invoke(stmt) - elif isinstance(stmt, (ca.For, ca.While, ca.DoWhile)): - invoke(stmt.stmt) - elif isinstance(stmt, ca.If): - if stmt.iftrue: - invoke(stmt.iftrue) - if stmt.iffalse: - invoke(stmt.iffalse) - elif isinstance(stmt, ca.Switch): - invoke(stmt.stmt) - elif isinstance(stmt, (ca.Case, ca.Default)): - invoke(stmt) - elif isinstance(stmt, ca.Label): - for_nested_blocks(stmt.stmt, callback) - - -def normalize_ast(fn: ca.FuncDef, ast: ca.FileAST) -> None: - """Add braces to all ifs/fors/etc., to make it easier to insert statements.""" - - def rec(block: Block) -> None: - stmts = get_block_stmts(block, False) - for stmt in stmts: - brace_nested_blocks(stmt) - for_nested_blocks(stmt, rec) - - rec(fn.body) - - -def prune_ast(fn: ca.FuncDef, ast: ca.FileAST) -> int: - """Prune away unnecessary parts of the AST, to reduce overhead from serialization - and from the compiler's C parser.""" - - # Create a GC graph that maps names of declarations and enumerators to indices - # in ast.ext, as well an initial list of GC roots, consisting of everything - # that isn't a Decl and or Typedef. - edges: Dict[str, List[int]] = defaultdict(list) - gc_roots: List[int] = [] - can_fwd_declare_typedef: Set[str] = set() - can_fwd_declare_tagged: Set[str] = set() - - def add_type_edges( - tp: Union["ca.Type", ca.Struct, ca.Union, ca.Enum], i: int - ) -> None: - while isinstance(tp, (ca.PtrDecl, ca.ArrayDecl)): - tp = tp.type - if isinstance(tp, ca.FuncDecl): - return - inner_type = tp.type if isinstance(tp, ca.TypeDecl) else tp - if isinstance(inner_type, ca.IdentifierType): - return - if inner_type.name: - edges[inner_type.name].append(i) - if isinstance(inner_type, ca.Enum) and inner_type.values: - for value in inner_type.values.enumerators: - edges[value.name].append(i) - if isinstance(inner_type, (ca.Struct, ca.Union)) and inner_type.decls: - for decl in inner_type.decls: - if isinstance(decl, ca.Decl): - add_type_edges(decl.type, i) - - for i in range(len(ast.ext)): - item = ast.ext[i] - if isinstance(item, ca.Decl) and not item.init: - # (Exclude declarations with initializers, since taking function - # pointers can affect regalloc on IDO.) - if item.name: - edges[item.name].append(i) - if isinstance(item.type, (ca.Struct, ca.Union, ca.Enum)) and item.type.name: - can_fwd_declare_tagged.add(item.type.name) - add_type_edges(item.type, i) - elif isinstance(item, ca.Typedef): - edges[item.name].append(i) - if isinstance(item.type, ca.TypeDecl) and isinstance( - item.type.type, (ca.Struct, ca.Union, ca.Enum) - ): - can_fwd_declare_typedef.add(item.name) - add_type_edges(item.type, i) - elif isinstance(item, ca.Pragma) and "GLOBAL_ASM" in item.string: - pass - else: - gc_roots.append(i) - - mentioned_ids: Set[str] = set() - - class IdVisitor(ca.NodeVisitor): - def visit_Pragma(self, node: ca.Pragma) -> None: - for token in re.findall(r"[a-zA-Z0-9_$]+", node.string): - mentioned_ids.add(token) - - def visit_ID(self, node: ca.ID) -> None: - mentioned_ids.add(node.name) - - IdVisitor().visit(ast) - - # Do the GC as a DFS traversal of the graph. Visiting a node searches its - # AST for all kinds of mentioned IDs, and adds more nodes to the stack - # using the edges we found before. - gc_todo: List[int] = gc_roots - need_fwd_decl_typedef: List[str] = [] - need_fwd_decl_tagged: List[str] = [] - - def add_name(name: str) -> None: - if name in edges: - gc_todo.extend(edges[name]) - del edges[name] - - class Visitor(ca.NodeVisitor): - def visit_Pragma(self, node: ca.Pragma) -> None: - for token in re.findall(r"[a-zA-Z0-9_$]+", node.string): - add_name(token) - - def visit_ID(self, node: ca.ID) -> None: - add_name(node.name) - - def visit_IdentifierType(self, node: ca.IdentifierType) -> None: - for name in node.names: - add_name(name) - - def visit_Enum(self, node: ca.Enum) -> None: - if node.name and not node.values: - add_name(node.name) - self.generic_visit(node) - - def visit_Struct(self, node: ca.Struct) -> None: - if node.name and not node.decls: - add_name(node.name) - self.generic_visit(node) - - def visit_Union(self, node: ca.Union) -> None: - if node.name and not node.decls: - add_name(node.name) - self.generic_visit(node) - - def visit_PtrDecl(self, node: ca.PtrDecl) -> None: - # For pointer declarations which haven't been accessed, forward - # declarations suffice. - if ( - isinstance(node.type, ca.TypeDecl) - and node.type.declname - and node.type.declname not in mentioned_ids - ): - tp = node.type.type - if isinstance(tp, ca.IdentifierType): - if all(name in can_fwd_declare_typedef for name in tp.names): - need_fwd_decl_typedef.extend(tp.names) - return - elif tp.name and tp.name in can_fwd_declare_tagged: - if not (tp.values if isinstance(tp, ca.Enum) else tp.decls): - need_fwd_decl_tagged.append(tp.name) - return - self.generic_visit(node) - - def visit_TypeDecl(self, node: ca.TypeDecl) -> None: - if node.declname: - add_name(node.declname) - self.generic_visit(node) - - keep_exts: Set[int] = set() - while gc_todo: - i = gc_todo.pop() - if i not in keep_exts: - keep_exts.add(i) - Visitor().visit(ast.ext[i]) - - temp_id = 0 - - def fwd_declare(tp: Union[ca.Struct, ca.Union, ca.Enum]) -> None: - nonlocal temp_id - if not tp.name: - temp_id += 1 - tp.name = f"_PermuterTemp{temp_id}" - if isinstance(tp, (ca.Struct, ca.Union)): - tp.decls = None - elif isinstance(tp, ca.Enum): - tp.values = None - else: - assert False - - new_ext = [] - - for i, item in enumerate(ast.ext): - if i in keep_exts: - pass - elif isinstance(item, ca.Typedef) and item.name in need_fwd_decl_typedef: - assert item.name in can_fwd_declare_typedef - assert isinstance(item.type, ca.TypeDecl) - assert isinstance(item.type.type, (ca.Struct, ca.Union, ca.Enum)) - fwd_declare(item.type.type) - elif ( - isinstance(item, ca.Decl) - and isinstance(item.type, (ca.Struct, ca.Union, ca.Enum)) - and item.type.name - and item.type.name in need_fwd_decl_tagged - ): - assert item.type.name in can_fwd_declare_tagged - fwd_declare(item.type) - else: - continue - new_ext.append(item) - - ast.ext = new_ext - return ast.ext.index(fn) diff --git a/tools/decomp-permuter/src/candidate.py b/tools/decomp-permuter/src/candidate.py deleted file mode 100644 index 5504b3d3dd..0000000000 --- a/tools/decomp-permuter/src/candidate.py +++ /dev/null @@ -1,99 +0,0 @@ -import copy -from dataclasses import dataclass, field -import functools -from typing import Optional, Tuple - -from pycparser import c_ast as ca - -from .compiler import Compiler -from .randomizer import Randomizer -from .scorer import Scorer -from .perm.perm import EvalState -from .perm.ast import apply_ast_perms -from .helpers import try_remove -from .profiler import Profiler -from . import ast_util - - -@dataclass -class CandidateResult: - """Represents the result of scoring a candidate, and is sent from child to - parent processes, or server to client with p@h.""" - - score: int - hash: Optional[str] - source: Optional[str] - profiler: Optional[Profiler] = None - - -@dataclass -class Candidate: - """ - Represents a AST candidate created from a source which can be randomized - (possibly multiple times), compiled, and scored. - """ - - ast: ca.FileAST - - fn_index: int - rng_seed: int - randomizer: Randomizer - score_value: Optional[int] = field(init=False, default=None) - score_hash: Optional[str] = field(init=False, default=None) - _cache_source: Optional[str] = field(init=False, default=None) - - @staticmethod - @functools.lru_cache(maxsize=16) - def _cached_shared_ast( - source: str, fn_name: str - ) -> Tuple[ca.FuncDef, int, ca.FileAST]: - ast = ast_util.parse_c(source) - orig_fn, fn_index = ast_util.extract_fn(ast, fn_name) - ast_util.normalize_ast(orig_fn, ast) - return orig_fn, fn_index, ast - - @staticmethod - def from_source( - source: str, eval_state: EvalState, fn_name: str, rng_seed: int - ) -> "Candidate": - # Use the same AST for all instances of the same original source, but - # with the target function deeply copied. Since we never change the - # AST outside of the target function, this is fine, and it saves us - # performance (deepcopy is really slow). - orig_fn, fn_index, ast = Candidate._cached_shared_ast(source, fn_name) - ast = copy.copy(ast) - ast.ext = copy.copy(ast.ext) - fn_copy = copy.deepcopy(orig_fn) - ast.ext[fn_index] = fn_copy - apply_ast_perms(fn_copy, eval_state) - return Candidate( - ast=ast, - fn_index=fn_index, - rng_seed=rng_seed, - randomizer=Randomizer(rng_seed), - ) - - def randomize_ast(self) -> None: - self.randomizer.randomize(self.ast, self.fn_index) - self._cache_source = None - - def get_source(self) -> str: - if self._cache_source is None: - self._cache_source = ast_util.to_c(self.ast) - return self._cache_source - - def compile(self, compiler: Compiler, show_errors: bool = False) -> Optional[str]: - source: str = self.get_source() - return compiler.compile(source, show_errors=show_errors) - - def score(self, scorer: Scorer, o_file: Optional[str]) -> CandidateResult: - self.score_value = None - self.score_hash = None - try: - self.score_value, self.score_hash = scorer.score(o_file) - finally: - if o_file: - try_remove(o_file) - return CandidateResult( - score=self.score_value, hash=self.score_hash, source=self.get_source() - ) diff --git a/tools/decomp-permuter/src/compiler.py b/tools/decomp-permuter/src/compiler.py deleted file mode 100644 index 284ed316ba..0000000000 --- a/tools/decomp-permuter/src/compiler.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional -import tempfile -import subprocess - -from .helpers import try_remove - - -class Compiler: - def __init__(self, compile_cmd: str, *, show_errors: bool) -> None: - self.compile_cmd = compile_cmd - self.show_errors = show_errors - - def compile(self, source: str, *, show_errors: bool = False) -> Optional[str]: - """Try to compile a piece of C code. Returns the filename of the resulting .o - temp file if it succeeds.""" - show_errors = show_errors or self.show_errors - with tempfile.NamedTemporaryFile( - prefix="permuter", suffix=".c", mode="w", delete=False - ) as f: - c_name = f.name - f.write(source) - - with tempfile.NamedTemporaryFile( - prefix="permuter", suffix=".o", delete=False - ) as f2: - o_name = f2.name - - try: - stderr = 2 if show_errors else subprocess.DEVNULL - subprocess.check_call( - [self.compile_cmd, c_name, "-o", o_name], - stdout=stderr, - stderr=stderr, - ) - except subprocess.CalledProcessError: - if not show_errors: - try_remove(c_name) - try_remove(o_name) - return None - except KeyboardInterrupt: - # If Ctrl+C happens during this call, make a best effort in - # removing the .c and .o files. This is totally racy, but oh well... - try_remove(c_name) - try_remove(o_name) - raise - - try_remove(c_name) - return o_name diff --git a/tools/decomp-permuter/src/error.py b/tools/decomp-permuter/src/error.py deleted file mode 100644 index a3f53985f1..0000000000 --- a/tools/decomp-permuter/src/error.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class ServerError(Exception): - message: str - - -@dataclass -class CandidateConstructionFailure(Exception): - message: str diff --git a/tools/decomp-permuter/src/helpers.py b/tools/decomp-permuter/src/helpers.py deleted file mode 100644 index 5daab76055..0000000000 --- a/tools/decomp-permuter/src/helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from typing import NoReturn - - -def plural(n: int, noun: str) -> str: - s = "s" if n != 1 else "" - return f"{n} {noun}{s}" - - -def exception_to_string(e: object) -> str: - return str(e) or e.__class__.__name__ - - -def static_assert_unreachable(x: NoReturn) -> NoReturn: - raise Exception("Unreachable! " + repr(x)) - - -def try_remove(path: str) -> None: - try: - os.remove(path) - except FileNotFoundError: - pass diff --git a/tools/decomp-permuter/src/main.py b/tools/decomp-permuter/src/main.py deleted file mode 100644 index 22320bbe2a..0000000000 --- a/tools/decomp-permuter/src/main.py +++ /dev/null @@ -1,654 +0,0 @@ -import argparse -from dataclasses import dataclass, field -import itertools -import multiprocessing -from multiprocessing import Queue -import os -import queue -import sys -import threading -import time -from typing import ( - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Tuple, -) - -from .candidate import CandidateResult -from .compiler import Compiler -from .error import CandidateConstructionFailure -from .helpers import plural, static_assert_unreachable -from .net.client import start_client -from .net.core import ServerError, connect, enable_debug_mode, MAX_PRIO, MIN_PRIO -from .permuter import ( - EvalError, - EvalResult, - Feedback, - Finished, - Message, - NeedMoreWork, - Permuter, - Task, - WorkDone, -) -from .preprocess import preprocess -from .printer import Printer -from .profiler import Profiler -from .scorer import Scorer - -# The probability that the randomizer continues transforming the output it -# generated last time. -DEFAULT_RAND_KEEP_PROB = 0.6 - - -@dataclass -class Options: - directories: List[str] - show_errors: bool = False - show_timings: bool = False - print_diffs: bool = False - stack_differences: bool = False - abort_exceptions: bool = False - better_only: bool = False - best_only: bool = False - quiet: bool = False - stop_on_zero: bool = False - keep_prob: float = DEFAULT_RAND_KEEP_PROB - force_seed: Optional[str] = None - threads: int = 1 - use_network: bool = False - network_debug: bool = False - network_priority: float = 1.0 - - -def restricted_float(lo: float, hi: float) -> Callable[[str], float]: - def convert(x: str) -> float: - try: - ret = float(x) - except ValueError: - raise argparse.ArgumentTypeError(f"invalid float value: '{x}'") - - if ret < lo or ret > hi: - raise argparse.ArgumentTypeError( - f"value {x} is out of range (must be between {lo} and {hi})" - ) - return ret - - return convert - - -@dataclass -class EvalContext: - options: Options - printer: Printer = field(default_factory=Printer) - iteration: int = 0 - errors: int = 0 - overall_profiler: Profiler = field(default_factory=Profiler) - permuters: List[Permuter] = field(default_factory=list) - - -def write_candidate(perm: Permuter, result: CandidateResult) -> None: - """Write the candidate's C source and score to the next output directory""" - ctr = 0 - while True: - ctr += 1 - try: - output_dir = os.path.join(perm.dir, f"output-{result.score}-{ctr}") - os.mkdir(output_dir) - break - except FileExistsError: - pass - source = result.source - assert source is not None, "Permuter._need_to_send_source is wrong!" - with open(os.path.join(output_dir, "source.c"), "x", encoding="utf-8") as f: - f.write(source) - with open(os.path.join(output_dir, "score.txt"), "x", encoding="utf-8") as f: - f.write(f"{result.score}\n") - with open(os.path.join(output_dir, "diff.txt"), "x", encoding="utf-8") as f: - f.write(perm.diff(source) + "\n") - print(f"wrote to {output_dir}") - - -def post_score( - context: EvalContext, permuter: Permuter, result: EvalResult, who: Optional[str] -) -> bool: - if isinstance(result, EvalError): - if result.exc_str is not None: - context.printer.print( - "internal permuter failure.", permuter, who, keep_progress=True - ) - print(result.exc_str) - if result.seed is not None: - seed_str = str(result.seed[1]) - if result.seed[0] != 0: - seed_str = f"{result.seed[0]},{seed_str}" - print(f"To reproduce the failure, rerun with: --seed {seed_str}") - if context.options.abort_exceptions: - sys.exit(1) - else: - return False - - if context.options.print_diffs: - assert result.source is not None, "Permuter._need_to_send_source is wrong" - print() - print(permuter.diff(result.source)) - input("Press any key to continue...") - - profiler = result.profiler - score_value = result.score - - if profiler is not None: - for stattype in profiler.time_stats: - context.overall_profiler.add_stat(stattype, profiler.time_stats[stattype]) - - context.iteration += 1 - if score_value == permuter.scorer.PENALTY_INF: - disp_score = "inf" - context.errors += 1 - else: - disp_score = str(score_value) - timings = "" - if context.options.show_timings: - timings = " \t" + context.overall_profiler.get_str_stats() - status_line = f"iteration {context.iteration}, {context.errors} errors, score = {disp_score}{timings}" - - if permuter.should_output(result): - former_best = permuter.best_score - permuter.record_result(result) - if score_value < former_best: - color = "\u001b[32;1m" - msg = f"found new best score! ({score_value} vs {permuter.base_score})" - elif score_value == former_best: - color = "\u001b[32;1m" - msg = f"tied best score! ({score_value} vs {permuter.base_score})" - elif score_value < permuter.base_score: - color = "\u001b[33m" - msg = f"found a better score! ({score_value} vs {permuter.base_score})" - else: - color = "\u001b[33m" - msg = f"found different asm with same score ({score_value})" - context.printer.print(msg, permuter, who, color=color) - - write_candidate(permuter, result) - if not context.options.quiet: - context.printer.progress(status_line) - return score_value == 0 - - -def cycle_seeds(permuters: List[Permuter]) -> Iterable[Tuple[int, int]]: - """ - Return all possible (permuter index, seed) pairs, cycling over permuters. - If a permuter is randomized, it will keep repeating seeds infinitely. - """ - iterators: List[Iterator[Tuple[int, int]]] = [] - for perm_ind, permuter in enumerate(permuters): - it = permuter.seed_iterator() - iterators.append(zip(itertools.repeat(perm_ind), it)) - - i = 0 - while iterators: - i %= len(iterators) - item = next(iterators[i], None) - if item is None: - del iterators[i] - i -= 1 - else: - yield item - i += 1 - - -def multiprocess_worker( - permuters: List[Permuter], - input_queue: "Queue[Task]", - output_queue: "Queue[Feedback]", -) -> None: - try: - while True: - # Read a work item from the queue. If none is immediately available, - # tell the main thread to fill the queues more, and then block on - # the queue. - queue_item: Task - try: - queue_item = input_queue.get(block=False) - except queue.Empty: - output_queue.put((NeedMoreWork(), -1, None)) - queue_item = input_queue.get() - if isinstance(queue_item, Finished): - output_queue.put((queue_item, -1, None)) - output_queue.close() - break - permuter_index, seed = queue_item - permuter = permuters[permuter_index] - result = permuter.try_eval_candidate(seed) - if isinstance(result, CandidateResult) and permuter.should_output(result): - permuter.record_result(result) - output_queue.put((WorkDone(permuter_index, result), -1, None)) - output_queue.put((NeedMoreWork(), -1, None)) - except KeyboardInterrupt: - # Don't clutter the output with stack traces; Ctrl+C is the expected - # way to quit and sends KeyboardInterrupt to all processes. - # A heartbeat thing here would be good but is too complex. - # Don't join the queue background thread -- thread joins in relation - # to KeyboardInterrupt usually result in deadlocks. - input_queue.cancel_join_thread() - output_queue.cancel_join_thread() - - -def run(options: Options) -> List[int]: - last_time = time.time() - try: - - def heartbeat() -> None: - nonlocal last_time - last_time = time.time() - - return run_inner(options, heartbeat) - except KeyboardInterrupt: - if time.time() - last_time > 5: - print() - print("Aborting stuck process.") - raise - print() - print("Exiting.") - sys.exit(0) - - -def run_inner(options: Options, heartbeat: Callable[[], None]) -> List[int]: - print("Loading...") - - context = EvalContext(options) - - force_seed: Optional[int] = None - force_rng_seed: Optional[int] = None - if options.force_seed: - seed_parts = list(map(int, options.force_seed.split(","))) - force_rng_seed = seed_parts[-1] - force_seed = 0 if len(seed_parts) == 1 else seed_parts[0] - - name_counts: Dict[str, int] = {} - for i, d in enumerate(options.directories): - heartbeat() - compile_cmd = os.path.join(d, "compile.sh") - target_o = os.path.join(d, "target.o") - base_c = os.path.join(d, "base.c") - for fname in [compile_cmd, target_o, base_c]: - if not os.path.isfile(fname): - print(f"Missing file {fname}", file=sys.stderr) - sys.exit(1) - if not os.stat(compile_cmd).st_mode & 0o100: - print(f"{compile_cmd} must be marked executable.", file=sys.stderr) - sys.exit(1) - - fn_name: Optional[str] = None - try: - with open(os.path.join(d, "function.txt"), encoding="utf-8") as f: - fn_name = f.read().strip() - except FileNotFoundError: - pass - - if fn_name: - print(f"{base_c} ({fn_name})") - else: - print(base_c) - - compiler = Compiler(compile_cmd, show_errors=options.show_errors) - scorer = Scorer(target_o, stack_differences=options.stack_differences) - c_source = preprocess(base_c) - - try: - permuter = Permuter( - d, - fn_name, - compiler, - scorer, - base_c, - c_source, - force_seed=force_seed, - force_rng_seed=force_rng_seed, - keep_prob=options.keep_prob, - need_profiler=options.show_timings, - need_all_sources=options.print_diffs, - show_errors=options.show_errors, - best_only=options.best_only, - better_only=options.better_only, - ) - except CandidateConstructionFailure as e: - print(e.message, file=sys.stderr) - sys.exit(1) - - context.permuters.append(permuter) - name_counts[permuter.fn_name] = name_counts.get(permuter.fn_name, 0) + 1 - print() - - if not context.permuters: - print("No permuters!") - return [] - - for permuter in context.permuters: - if name_counts[permuter.fn_name] > 1: - permuter.unique_name += f" ({permuter.dir})" - print(f"[{permuter.unique_name}] base score = {permuter.best_score}") - - found_zero = False - if options.threads == 1 and not options.use_network: - # Simple single-threaded mode. This is not technically needed, but - # makes the permuter easier to debug. - for permuter_index, seed in cycle_seeds(context.permuters): - heartbeat() - permuter = context.permuters[permuter_index] - result = permuter.try_eval_candidate(seed) - if post_score(context, permuter, result, None): - found_zero = True - if options.stop_on_zero: - break - else: - seed_iterators: List[Optional[Iterator[int]]] = [ - permuter.seed_iterator() - for perm_ind, permuter in enumerate(context.permuters) - ] - seed_iterators_remaining = len(seed_iterators) - next_iterator_index = 0 - - # Create queues. - worker_task_queue: "Queue[Task]" = Queue() - feedback_queue: "Queue[Feedback]" = Queue() - - # Connect to network and create client threads and queues. - net_conns: "List[Tuple[threading.Thread, Queue[Task]]]" = [] - if options.use_network: - print("Connecting to permuter@home...") - if options.network_debug: - enable_debug_mode() - first_stats: Optional[Tuple[int, int, float]] = None - for perm_index in range(len(context.permuters)): - try: - port = connect() - except (EOFError, ServerError) as e: - print("Error:", e) - sys.exit(1) - thread, queue, stats = start_client( - port, - context.permuters[perm_index], - perm_index, - feedback_queue, - options.network_priority, - ) - net_conns.append((thread, queue)) - if first_stats is None: - first_stats = stats - assert first_stats is not None, "has at least one permuter" - clients_str = plural(first_stats[0], "other client") - servers_str = plural(first_stats[1], "server") - cores_str = plural(int(first_stats[2]), "core") - print(f"Connected! {servers_str} online ({cores_str}, {clients_str})") - - # Start local worker threads - processes: List[multiprocessing.Process] = [] - for i in range(options.threads): - p = multiprocessing.Process( - target=multiprocess_worker, - args=(context.permuters, worker_task_queue, feedback_queue), - ) - p.start() - processes.append(p) - - active_workers = len(processes) - - if not active_workers and not net_conns: - print("No workers available! Exiting.") - sys.exit(1) - - def process_finish(finish: Finished, source: int) -> None: - nonlocal active_workers - - if finish.reason: - permuter: Optional[Permuter] = None - if source != -1 and len(context.permuters) > 1: - permuter = context.permuters[source] - context.printer.print(finish.reason, permuter, None, keep_progress=True) - - if source == -1: - active_workers -= 1 - - def process_result(work: WorkDone, who: Optional[str]) -> bool: - permuter = context.permuters[work.perm_index] - return post_score(context, permuter, work.result, who) - - def get_task(perm_index: int) -> Optional[Tuple[int, int]]: - nonlocal next_iterator_index, seed_iterators_remaining - if perm_index == -1: - while seed_iterators_remaining > 0: - task = get_task(next_iterator_index) - next_iterator_index += 1 - next_iterator_index %= len(seed_iterators) - if task is not None: - return task - else: - it = seed_iterators[perm_index] - if it is not None: - seed = next(it, None) - if seed is None: - seed_iterators[perm_index] = None - seed_iterators_remaining -= 1 - else: - return (perm_index, seed) - return None - - # Feed the task queue with work and read from results queue. - # We generally match these up one-by-one to avoid overfilling queues, - # but workers can ask us to add more tasks into the system if they run - # out of work. (This will happen e.g. at the very beginning, when the - # queues are empty.) - while seed_iterators_remaining > 0: - heartbeat() - feedback, source, who = feedback_queue.get() - if isinstance(feedback, Finished): - process_finish(feedback, source) - elif isinstance(feedback, Message): - context.printer.print(feedback.text, None, who, keep_progress=True) - elif isinstance(feedback, WorkDone): - if process_result(feedback, who): - # Found score 0! - found_zero = True - if options.stop_on_zero: - break - elif isinstance(feedback, NeedMoreWork): - task = get_task(source) - if task is not None: - if source == -1: - worker_task_queue.put(task) - else: - net_conns[source][1].put(task) - else: - static_assert_unreachable(feedback) - - # Signal workers to stop. - for i in range(active_workers): - worker_task_queue.put(Finished()) - - for conn in net_conns: - conn[1].put(Finished()) - - # Await final results. - while active_workers > 0 or net_conns: - heartbeat() - feedback, source, who = feedback_queue.get() - if isinstance(feedback, Finished): - process_finish(feedback, source) - elif isinstance(feedback, Message): - context.printer.print(feedback.text, None, who, keep_progress=True) - elif isinstance(feedback, WorkDone): - if not (options.stop_on_zero and found_zero): - if process_result(feedback, who): - found_zero = True - elif isinstance(feedback, NeedMoreWork): - pass - else: - static_assert_unreachable(feedback) - - # Wait for workers to finish. - for p in processes: - p.join() - - # Wait for network connections to close (currently does not happen). - for conn in net_conns: - conn[0].join() - - if found_zero: - print("\nFound zero score! Exiting.") - return [permuter.best_score for permuter in context.permuters] - - -def main() -> None: - multiprocessing.freeze_support() - sys.setrecursionlimit(10000) - - # Ideally we would do: - # multiprocessing.set_start_method("spawn") - # here, to make multiprocessing behave the same across operating systems. - # However, that means that arguments to Process are passed across using - # pickling, which mysteriously breaks with pycparser... - # (AttributeError: 'CParser' object has no attribute 'p_abstract_declarator_opt') - # So, for now we live with the defaults, which make multiprocessing work on Linux, - # where it uses fork and doesn't pickle arguments, and break on Windows. Sigh. - - parser = argparse.ArgumentParser( - description="Randomly permute C files to better match a target binary." - ) - parser.add_argument( - "directories", - nargs="+", - metavar="directory", - help="Directory containing base.c, target.o and compile.sh. Multiple directories may be given.", - ) - parser.add_argument( - "--show-errors", - dest="show_errors", - action="store_true", - help="Display compiler error/warning messages, and keep .c files for failed compiles.", - ) - parser.add_argument( - "--show-timings", - dest="show_timings", - action="store_true", - help="Display the time taken by permuting vs. compiling vs. scoring.", - ) - parser.add_argument( - "--print-diffs", - dest="print_diffs", - action="store_true", - help="Instead of compiling generated sources, display diffs against a base version.", - ) - parser.add_argument( - "--abort-exceptions", - dest="abort_exceptions", - action="store_true", - help="Stop execution when an internal permuter exception occurs.", - ) - parser.add_argument( - "--better-only", - dest="better_only", - action="store_true", - help="Only report scores better than the base.", - ) - parser.add_argument( - "--best-only", - dest="best_only", - action="store_true", - help="Only report ties or new high scores.", - ) - parser.add_argument( - "--stop-on-zero", - dest="stop_on_zero", - action="store_true", - help="Stop after producing an output with score 0.", - ) - parser.add_argument( - "--quiet", - dest="quiet", - action="store_true", - help="Don't print a status line with the number of iterations.", - ) - parser.add_argument( - "--stack-diffs", - dest="stack_differences", - action="store_true", - help="Take stack differences into account when computing the score.", - ) - parser.add_argument( - "--keep-prob", - dest="keep_prob", - metavar="PROB", - type=restricted_float(0.0, 1.0), - default=DEFAULT_RAND_KEEP_PROB, - help="""Continue randomizing the previous output with the given probability - (float in 0..1, default %(default)s).""", - ) - parser.add_argument("--seed", dest="force_seed", type=str, help=argparse.SUPPRESS) - parser.add_argument( - "-j", - dest="threads", - type=int, - default=0, - help="Number of own threads to use (default: 1 without -J, 0 with -J).", - ) - parser.add_argument( - "-J", - dest="use_network", - action="store_true", - help="Harness extra compute power through cyberspace (permuter@home).", - ) - parser.add_argument( - "--pah-debug", - dest="network_debug", - action="store_true", - help="Enable debug prints for permuter@home.", - ) - parser.add_argument( - "--priority", - dest="network_priority", - metavar="PRIORITY", - type=restricted_float(MIN_PRIO, MAX_PRIO), - default=1.0, - help=f"""Proportion of server resources to use when multiple people - are using -J at the same time. - Defaults to 1.0, meaning resources are split equally, but can be - set to any value within [{MIN_PRIO}, {MAX_PRIO}]. - Each server runs with a priority threshold, which defaults to 0.1, - below which they will not run permuter jobs at all.""", - ) - - args = parser.parse_args() - - threads = args.threads - if not threads and not args.use_network: - threads = 1 - - options = Options( - directories=args.directories, - show_errors=args.show_errors, - show_timings=args.show_timings, - print_diffs=args.print_diffs, - abort_exceptions=args.abort_exceptions, - better_only=args.better_only, - best_only=args.best_only, - quiet=args.quiet, - stack_differences=args.stack_differences, - stop_on_zero=args.stop_on_zero, - keep_prob=args.keep_prob, - force_seed=args.force_seed, - threads=threads, - use_network=args.use_network, - network_debug=args.network_debug, - network_priority=args.network_priority, - ) - - run(options) - - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/src/net/__init__.py b/tools/decomp-permuter/src/net/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/decomp-permuter/src/net/client.py b/tools/decomp-permuter/src/net/client.py deleted file mode 100644 index 1a3442c325..0000000000 --- a/tools/decomp-permuter/src/net/client.py +++ /dev/null @@ -1,272 +0,0 @@ -from multiprocessing import Queue -import re -import threading -from typing import Optional, Tuple -import zlib - -from ..candidate import CandidateResult -from ..helpers import exception_to_string -from ..permuter import ( - EvalError, - EvalResult, - Feedback, - FeedbackItem, - Finished, - Message, - NeedMoreWork, - Permuter, - Task, - WorkDone, -) -from ..profiler import Profiler -from .core import ( - PermuterData, - SocketPort, - json_prop, - permuter_data_to_json, -) - - -def _profiler_from_json(obj: dict) -> Profiler: - ret = Profiler() - for key in obj: - assert isinstance(key, str), "json properties are strings" - stat = Profiler.StatType[key] - time = json_prop(obj, key, float) - ret.add_stat(stat, time) - return ret - - -def _result_from_json(obj: dict, source: Optional[str]) -> EvalResult: - if "error" in obj: - return EvalError(exc_str=json_prop(obj, "error", str), seed=None) - - profiler: Optional[Profiler] = None - if "profiler" in obj: - profiler = _profiler_from_json(json_prop(obj, "profiler", dict)) - return CandidateResult( - score=json_prop(obj, "score", int), - hash=json_prop(obj, "hash", str) if "hash" in obj else None, - source=source, - profiler=profiler, - ) - - -def _make_script_portable(source: str) -> str: - """Parse a shell script and get rid of the machine-specific parts that - import.py introduces. The resulting script must be run in an environment - that has the right binaries in its $PATH, and with a current working - directory similar to where import.py found its target's make root.""" - lines = [] - for line in source.split("\n"): - if re.match("cd '?/", line): - # Skip cd's to absolute directory paths. Note that shlex quotes - # its argument with ' if it contains spaces/single quotes. - continue - if re.match("'?/", line): - quote = "'" if line[0] == "'" else "" - ind = line.find(quote + " ") - if ind == -1: - ind = len(line) - else: - ind += len(quote) - lastind = line.rfind("/", 0, ind) - assert lastind != -1 - # Emit a call to "which" as the first part, to ensure the called - # binary still sees an absolute path. qemu-irix requires this, - # for some reason. - line = "$(which " + quote + line[lastind + 1 : ind] + ")" + line[ind:] - lines.append(line) - return "\n".join(lines) - - -def make_portable_permuter(permuter: Permuter) -> PermuterData: - with open(permuter.scorer.target_o, "rb") as f: - target_o_bin = f.read() - - with open(permuter.compiler.compile_cmd, "r") as f2: - compile_script = _make_script_portable(f2.read()) - - return PermuterData( - base_score=permuter.base_score, - base_hash=permuter.base_hash, - fn_name=permuter.fn_name, - filename=permuter.source_file, - keep_prob=permuter.keep_prob, - need_profiler=permuter.need_profiler, - stack_differences=permuter.scorer.stack_differences, - compile_script=compile_script, - source=permuter.source, - target_o_bin=target_o_bin, - ) - - -class Connection: - _port: SocketPort - _permuter_data: PermuterData - _perm_index: int - _task_queue: "Queue[Task]" - _feedback_queue: "Queue[Feedback]" - - def __init__( - self, - port: SocketPort, - permuter_data: PermuterData, - perm_index: int, - task_queue: "Queue[Task]", - feedback_queue: "Queue[Feedback]", - ) -> None: - self._port = port - self._permuter_data = permuter_data - self._perm_index = perm_index - self._task_queue = task_queue - self._feedback_queue = feedback_queue - - def _send_permuter(self) -> None: - data = self._permuter_data - self._port.send_json(permuter_data_to_json(data)) - self._port.send(zlib.compress(data.source.encode("utf-8"))) - self._port.send(zlib.compress(data.target_o_bin)) - - def _feedback(self, feedback: FeedbackItem, server_nick: Optional[str]) -> None: - self._feedback_queue.put((feedback, self._perm_index, server_nick)) - - def _receive_one(self) -> bool: - """Receive a result/progress message and send it on. Returns true if - more work should be requested.""" - msg = self._port.receive_json() - msg_type = json_prop(msg, "type", str) - if msg_type == "need_work": - return True - - server_nick = json_prop(msg, "server", str) - if msg_type == "init_done": - base_hash = json_prop(msg, "hash", str) - my_base_hash = self._permuter_data.base_hash - text = "connected" - if base_hash != my_base_hash: - text += " (note: mismatching hash)" - self._feedback(Message(text), server_nick) - return True - - if msg_type == "init_failed": - text = "failed to initialize: " + json_prop(msg, "reason", str) - self._feedback(Message(text), server_nick) - return False - - if msg_type == "disconnect": - self._feedback(Message("disconnected"), server_nick) - return False - - if msg_type == "result": - source: Optional[str] = None - if msg.get("has_source") == True: - # Source is sent separately, compressed, since it can be - # large (hundreds of kilobytes is not uncommon). - compressed_source = self._port.receive() - try: - source = zlib.decompress(compressed_source).decode("utf-8") - except Exception as e: - text = "failed to decompress: " + exception_to_string(e) - self._feedback(Message(text), server_nick) - return True - try: - result = _result_from_json(msg, source) - self._feedback(WorkDone(self._perm_index, result), server_nick) - except Exception as e: - text = "failed to parse result message: " + exception_to_string(e) - self._feedback(Message(text), server_nick) - return True - - raise ValueError(f"Invalid message type {msg_type}") - - def run(self) -> None: - finish_reason: Optional[str] = None - try: - self._send_permuter() - self._port.receive_json() - - finished = False - - # Main loop: send messages from the queue on to the server, and - # vice versa. Currently we are being lazy and alternate between - # sending and receiving; this is nicely simple and keeps us on a - # single thread, however it could cause deadlocks if the server - # receiver stops reading because we aren't reading fast enough. - while True: - if not self._receive_one(): - continue - self._feedback(NeedMoreWork(), None) - - # Read a task and send it on, unless there are no more tasks. - if not finished: - task = self._task_queue.get() - if isinstance(task, Finished): - # We don't have a way of indicating to the server that - # all is done: the server currently doesn't track - # outstanding work so it doesn't know when to close - # the connection. (Even with this fixed we'll have the - # problem that servers may disconnect, losing work, so - # the task never truly finishes. But it might work well - # enough in practice.) - finished = True - else: - work = { - "type": "work", - "work": { - "seed": task[1], - }, - } - self._port.send_json(work) - - except EOFError: - finish_reason = "disconnected from permuter@home" - - except Exception as e: - errmsg = exception_to_string(e) - finish_reason = f"permuter@home error: {errmsg}" - - finally: - self._feedback(Finished(reason=finish_reason), None) - self._port.shutdown() - self._port.close() - - -def start_client( - port: SocketPort, - permuter: Permuter, - perm_index: int, - feedback_queue: "Queue[Feedback]", - priority: float, -) -> "Tuple[threading.Thread, Queue[Task], Tuple[int, int, float]]": - port.send_json( - { - "method": "connect_client", - "priority": priority, - } - ) - obj = port.receive_json() - if "error" in obj: - err = json_prop(obj, "error", str) - # TODO use another exception type - raise Exception(f"Failed to connect: {err}") - num_servers = json_prop(obj, "servers", int) - num_clients = json_prop(obj, "clients", int) - num_cores = json_prop(obj, "cores", float) - permuter_data = make_portable_permuter(permuter) - task_queue: "Queue[Task]" = Queue() - - conn = Connection( - port, - permuter_data, - perm_index, - task_queue, - feedback_queue, - ) - - thread = threading.Thread(target=conn.run, daemon=True) - thread.start() - - stats = (num_clients, num_servers, num_cores) - - return thread, task_queue, stats diff --git a/tools/decomp-permuter/src/net/cmd/__init__.py b/tools/decomp-permuter/src/net/cmd/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/decomp-permuter/src/net/cmd/base.py b/tools/decomp-permuter/src/net/cmd/base.py deleted file mode 100644 index 492044350b..0000000000 --- a/tools/decomp-permuter/src/net/cmd/base.py +++ /dev/null @@ -1,17 +0,0 @@ -import abc -from argparse import ArgumentParser, Namespace - - -class Command(abc.ABC): - command: str - help: str - - @staticmethod - @abc.abstractmethod - def add_arguments(parser: ArgumentParser) -> None: - ... - - @staticmethod - @abc.abstractmethod - def run(args: Namespace) -> None: - ... diff --git a/tools/decomp-permuter/src/net/cmd/icons/notok.ico b/tools/decomp-permuter/src/net/cmd/icons/notok.ico deleted file mode 100644 index 2825d0d9c9..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/icons/notok.ico and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/icons/ok.ico b/tools/decomp-permuter/src/net/cmd/icons/ok.ico deleted file mode 100644 index bc0b0534b1..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/icons/ok.ico and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/icons/okthink.ico b/tools/decomp-permuter/src/net/cmd/icons/okthink.ico deleted file mode 100644 index d3a8eacf84..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/icons/okthink.ico and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/main.py b/tools/decomp-permuter/src/net/cmd/main.py deleted file mode 100644 index 9d1e4db1c7..0000000000 --- a/tools/decomp-permuter/src/net/cmd/main.py +++ /dev/null @@ -1,70 +0,0 @@ -from argparse import ArgumentParser, RawDescriptionHelpFormatter -import sys - -from ..core import ServerError, enable_debug_mode -from .run_server import RunServerCommand -from .setup import SetupCommand -from .ping import PingCommand -from .vouch import VouchCommand - - -def main() -> None: - try: - # We currently sometimes log stuff to stdout, so it's preferable if it's - # line-buffered even when redirected to a non-tty (e.g. when running a - # permuter server as a systemd service). This is supported by Python 3.7 - # and up. - sys.stdout.reconfigure(line_buffering=True) # type: ignore - except Exception: - pass - - parser = ArgumentParser( - description="permuter@home - run the permuter across the Internet!\n\n" - "To use p@h as a client, just pass -J when running the permuter. " - "This script is\nonly necessary for configuration or when running a server.", - formatter_class=RawDescriptionHelpFormatter, - ) - - commands = [ - PingCommand, - RunServerCommand, - SetupCommand, - VouchCommand, - ] - - parser.add_argument( - "--debug", - dest="debug", - action="store_true", - help="Enable debug logging.", - ) - - subparsers = parser.add_subparsers(metavar="") - for command in commands: - subparser = subparsers.add_parser( - command.command, - help=command.help, - description=command.help, - ) - command.add_arguments(subparser) - subparser.set_defaults(subcommand_handler=command.run) - - args = parser.parse_args() - if args.debug: - enable_debug_mode() - - if "subcommand_handler" in args: - try: - args.subcommand_handler(args) - except EOFError as e: - print("Network error:", e) - sys.exit(1) - except ServerError as e: - print("Error:", e.message) - sys.exit(1) - else: - parser.print_help() - - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/src/net/cmd/ping.py b/tools/decomp-permuter/src/net/cmd/ping.py deleted file mode 100644 index e49f1e6825..0000000000 --- a/tools/decomp-permuter/src/net/cmd/ping.py +++ /dev/null @@ -1,32 +0,0 @@ -from argparse import ArgumentParser, Namespace -import time - -from ...helpers import plural -from ..core import connect, json_prop -from .base import Command - - -class PingCommand(Command): - command = "ping" - help = "Check server connectivity." - - @staticmethod - def add_arguments(parser: ArgumentParser) -> None: - pass - - @staticmethod - def run(args: Namespace) -> None: - run_ping() - - -def run_ping() -> None: - port = connect() - t0 = time.time() - port.send_json({"method": "ping"}) - msg = port.receive_json() - rtt = (time.time() - t0) * 1000 - print(f"Connected successfully! Round-trip time: {rtt:.1f} ms") - servers_str = plural(json_prop(msg, "servers", int), "server") - clients_str = plural(json_prop(msg, "clients", int), "client") - cores_str = plural(int(json_prop(msg, "cores", float)), "core") - print(f"{servers_str} online ({cores_str}, {clients_str})") diff --git a/tools/decomp-permuter/src/net/cmd/run_server.py b/tools/decomp-permuter/src/net/cmd/run_server.py deleted file mode 100644 index 429ddb1b6b..0000000000 --- a/tools/decomp-permuter/src/net/cmd/run_server.py +++ /dev/null @@ -1,616 +0,0 @@ -from argparse import ArgumentParser, Namespace -import base64 -from dataclasses import dataclass -from enum import Enum -from functools import partial -import json -import os -import platform -import queue -import random -import shutil -from subprocess import Popen, PIPE -import sys -import time -import threading -import traceback -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING - -from ...helpers import static_assert_unreachable -from ..core import CancelToken, ServerError, read_config -from ..server import ( - Client, - Config, - IoActivity, - IoConnect, - IoDisconnect, - IoImmediateDisconnect, - IoReconnect, - IoServerFailed, - IoShutdown, - IoUserRemovePermuter, - IoWorkDone, - PermuterHandle, - Server, - ServerOptions, -) -from .base import Command -from .util import ask - - -class RunServerCommand(Command): - command = "run-server" - help = """Run a permuter server, allowing anyone with access to the central - server to run sandboxed permuter jobs on your machine. Requires docker.""" - - @staticmethod - def add_arguments(parser: ArgumentParser) -> None: - parser.add_argument( - "--cores", - dest="num_cores", - metavar="CORES", - type=float, - required=True, - help="Number of cores to use (float).", - ) - parser.add_argument( - "--memory", - dest="max_memory_gb", - metavar="MEMORY_GB", - type=float, - required=True, - help="""Restrict the sandboxed process to the given amount of memory in - gigabytes (float). If this limit is hit, the permuter will crash - horribly, but at least your system won't lock up.""", - ) - parser.add_argument( - "--systray", - dest="systray", - action="store_true", - help="""Make the server controllable through the system tray.""", - ) - parser.add_argument( - "--min-priority", - dest="min_priority", - metavar="PRIORITY", - type=float, - default=0.1, - help="""Only accept jobs from clients who pass --priority with a number - higher or equal to this value. (default: %(default)s)""", - ) - - @staticmethod - def run(args: Namespace) -> None: - options = ServerOptions( - num_cores=args.num_cores, - max_memory_gb=args.max_memory_gb, - min_priority=args.min_priority, - ) - - server_main(options, args.systray) - - -class SystrayState: - def server_reconnecting(self) -> None: - pass - - def server_connected(self) -> None: - pass - - def server_failed(self, graceful: bool, message: Optional[str] = None) -> None: - pass - - def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None: - pass - - def disconnect(self, handle: PermuterHandle) -> None: - pass - - def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None: - pass - - def stop(self) -> None: - pass - - -@dataclass -class Permuter: - nickname: str - fn_name: str - iterations: int = 0 - improvements: int = 0 - last_systray_update: float = 0.0 - slot: "Optional[ClientSlot]" = None - - -@dataclass -class ClientSlot: - menu_id: int - iterations_id: int - improvements_id: int - stop_id: int - permuter: Optional[PermuterHandle] = None - - -class SystrayStatus(Enum): - CONNECTING = 0 - CONNECTED = 1 - FAILED = 2 - RECONNECTING = 3 - - -class RealSystrayState(SystrayState): - _CLIENT_SLOTS = 10 - _UPDATE_INTERVAL = 2.0 - _MENU_TOOLTIP = "permuter@home" - _permuters: Dict[PermuterHandle, Permuter] - _onclick: Dict[int, Callable[[], None]] - _client_slots: List[ClientSlot] - - def __init__( - self, - config: Config, - io_queue: "queue.Queue[IoActivity]", - ) -> None: - self._io_queue = io_queue - self._permuters = {} - self._onclick = {} - self._status = SystrayStatus.CONNECTING - self._fail_message: Optional[str] = None - - def load_icon(fname: str) -> str: - path = os.path.join(os.path.dirname(__file__), "icons", fname) - with open(path, "rb") as f: - data = f.read() - return base64.b64encode(data).decode("ascii") - - self._icons = { - "working": load_icon("okthink.ico"), - "passive": load_icon("ok.ico"), - "fail": load_icon("notok.ico"), - } - self._current_icon = "working" - - next_id = 100 - - def add_item( - menu: List[dict], - title: str, - onclick: Optional[Callable[[], None]] = None, - *, - submenu: Optional[List[dict]] = None, - hidden: bool = False, - ) -> int: - nonlocal next_id - next_id += 1 - obj = { - "title": title, - "enabled": onclick is not None or submenu is not None, - "hidden": hidden, - "__id": next_id, - } - if onclick is not None: - self._onclick[next_id] = onclick - if submenu is not None: - obj["items"] = submenu - menu.append(obj) - return next_id - - menu: List[dict] = [] - self._status_id = add_item(menu, "Connecting...") - self._client_slots = [] - for i in range(self._CLIENT_SLOTS): - submenu: List[dict] = [] - remove_cb = partial(self._remove_permuter, i) - self._client_slots.append( - ClientSlot( - iterations_id=add_item(submenu, ""), - improvements_id=add_item(submenu, ""), - stop_id=add_item(submenu, "Stop", remove_cb), - menu_id=add_item(menu, "", submenu=submenu, hidden=True), - ) - ) - self._more_id = add_item(menu, "", hidden=True) - add_item(menu, "Quit", self._quit) - - try: - path = self._setup_helper() - self._proc = Popen( - [path], - stdout=PIPE, - stdin=PIPE, - universal_newlines=True, - ) - assert self._proc.stdout is not None - self._proc_stdout = self._proc.stdout - assert self._proc.stdin is not None - self._proc_stdin = self._proc.stdin - - self._send( - { - "icon": self._icons[self._current_icon], - "tooltip": self._MENU_TOOLTIP, - "items": menu, - } - ) - - resp_str = self._proc_stdout.readline() - assert resp_str - resp = json.loads(resp_str) - assert isinstance(resp, dict) - assert resp.get("type") == "ready" - except Exception: - print("Failed to initialize systray!") - print() - print("See src/net/cmd/systray/README.md for details on how to set it up.") - traceback.print_exc() - sys.exit(1) - - self._read_thread = threading.Thread(target=self._read_loop, daemon=True) - self._read_thread.start() - - @staticmethod - def _setup_helper() -> str: - fname = "permuter-systray" - suffix = "" - osname = sys.platform.replace("darwin", "macos") - arch = platform.machine().replace("AMD64", "x86_64") - if ( - osname in ("win32", "msys", "cygwin") - or "microsoft" in platform.uname().release.lower() - ): - osname = "win" - suffix = ".exe" - - dir = os.path.join(os.path.dirname(__file__), "systray") - target_binary = os.path.join(dir, fname + suffix) - if os.path.exists(target_binary): - return target_binary - - prebuilt_file = f"{fname}-{osname}-{arch}{suffix}" - prebuilt_file = os.path.join(dir, "prebuilt", prebuilt_file) - - print("An external helper binary is required for systray support.") - print( - "To build it from source (requires Go), see src/net/cmd/systray/README.md." - ) - - if os.path.exists(prebuilt_file): - print("Alternatively, a pre-built binary can be used.") - if ask("Use pre-built binary?", default=False): - shutil.copy(prebuilt_file, target_binary) - os.chmod(target_binary, 0o755) - return target_binary - - print("Aborting.") - sys.exit(1) - - def _send(self, msg: dict) -> None: - data = json.dumps(msg) - self._proc_stdin.write(data + "\n") - self._proc_stdin.flush() - - def _update_item( - self, id: int, title: str, *, hidden: bool = False, enabled: bool = False - ) -> None: - self._send( - { - "type": "update-item", - "item": { - "title": title, - "enabled": enabled, - "hidden": hidden, - "__id": id, - }, - "seq_id": -1, - } - ) - - def _remove_permuter(self, slot_index: int) -> None: - slot = self._client_slots[slot_index] - if not slot.permuter: - return - handle = slot.permuter - self._io_queue.put((None, (handle, IoUserRemovePermuter()))) - - def _quit(self) -> None: - self._io_queue.put((None, IoShutdown())) - - def _read_loop(self) -> None: - while True: - resp_str = self._proc_stdout.readline() - if not resp_str: - break - try: - resp = json.loads(resp_str) - except Exception: - raise Exception(f"Failed to parse systray JSON: {resp_str}") from None - if resp["type"] == "clicked": - id = resp["__id"] - if id in self._onclick: - self._onclick[id]() - - def _permuter_slot(self, perm: Permuter) -> Optional[ClientSlot]: - for slot in self._client_slots: - if slot.permuter is not None and self._permuters[slot.permuter] is perm: - return slot - return None - - def _update_permuter(self, perm: Permuter, slot: ClientSlot) -> None: - self._update_item( - slot.iterations_id, - f"Iterations: {perm.iterations}", - ) - self._update_item( - slot.improvements_id, - f"Improvements found: {perm.improvements}", - ) - - def _update_status(self) -> None: - if self._status == SystrayStatus.CONNECTING: - status = "Reconnecting..." - icon = "working" - elif self._status == SystrayStatus.RECONNECTING: - status = "Disconnected, will reconnect..." - icon = "fail" - elif self._status == SystrayStatus.CONNECTED: - if self._permuters: - status = "Currently permuting:" - icon = "working" - else: - status = "Not running" - icon = "passive" - elif self._status == SystrayStatus.FAILED: - if self._fail_message: - status = f"Error: {self._fail_message}" - else: - status = "Error occurred" - icon = "fail" - else: - assert False, f"bad status {self._status}" - - self._update_item(self._status_id, status) - if self._current_icon != icon: - self._current_icon = icon - self._send( - { - "type": "update-menu", - "menu": { - "tooltip": self._MENU_TOOLTIP, - "icon": self._icons[icon], - }, - } - ) - - def _fill_slots(self) -> None: - has_more = False - while True: - key = next((k for k, p in self._permuters.items() if p.slot is None), None) - if key is None: - break - chosen_slot: Optional[ClientSlot] = None - for i in range(self._CLIENT_SLOTS - 1, -1, -1): - slot = self._client_slots[i] - if slot.permuter is None: - chosen_slot = slot - elif chosen_slot is not None: - break - if chosen_slot is None: - has_more = True - break - perm = self._permuters[key] - perm.slot = chosen_slot - chosen_slot.permuter = key - self._update_permuter(perm, chosen_slot) - self._update_item( - chosen_slot.menu_id, f"{perm.fn_name} ({perm.nickname})", enabled=True - ) - self._update_item(self._more_id, "More...", hidden=not has_more) - - def _hide_slot(self, slot: ClientSlot) -> None: - if slot.permuter is not None: - self._update_item(slot.menu_id, "", hidden=True) - slot.permuter = None - - def server_reconnecting(self) -> None: - self._status = SystrayStatus.CONNECTING - self._update_status() - - def server_connected(self) -> None: - self._status = SystrayStatus.CONNECTED - self._update_status() - - def server_failed(self, graceful: bool, message: Optional[str] = None) -> None: - self._status = SystrayStatus.RECONNECTING if graceful else SystrayStatus.FAILED - self._fail_message = message - self._permuters = {} - self._update_status() - for slot in self._client_slots: - self._hide_slot(slot) - self._fill_slots() - - def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None: - perm = Permuter(nickname, fn_name) - self._permuters[handle] = perm - self._fill_slots() - self._update_status() - - def disconnect(self, handle: PermuterHandle) -> None: - slot = self._permuters[handle].slot - del self._permuters[handle] - self._update_status() - if slot: - self._hide_slot(slot) - self._fill_slots() - - def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None: - perm = self._permuters[handle] - perm.iterations += 1 - if is_improvement: - perm.improvements += 1 - if perm.slot and time.time() > perm.last_systray_update + self._UPDATE_INTERVAL: - perm.last_systray_update = time.time() - self._update_permuter(perm, perm.slot) - - def stop(self) -> None: - try: - self._send({"type": "exit"}) - except BrokenPipeError: - # The systray process may have been killed by Ctrl+C. - pass - self._proc.wait() - self._read_thread.join() - - -class Reconnector: - _RESET_BACKOFF_AFTER_UPTIME: float = 60.0 - _RANDOM_ADDEND_MAX: float = 60.0 - _BACKOFF_MULTIPLIER: float = 2.0 - _INITIAL_DELAY: float = 5.0 - - _io_queue: "queue.Queue[IoActivity]" - _reconnect_token: CancelToken - _reconnect_delay: float - _reconnect_timer: Optional[threading.Timer] - _start_time: float - _stop_time: float - - def __init__(self, io_queue: "queue.Queue[IoActivity]") -> None: - self._io_queue = io_queue - self._reconnect_token = CancelToken() - self._reconnect_delay = self._INITIAL_DELAY - self._reconnect_timer = None - self._start_time = self._stop_time = time.time() - - def mark_start(self) -> None: - self._start_time = time.time() - - def mark_stop(self) -> None: - self._stop_time = time.time() - - def stop(self) -> None: - self._reconnect_token.cancelled = True - if self._reconnect_timer is not None: - self._reconnect_timer.cancel() - self._reconnect_timer.join() - self._reconnect_timer = None - - def reconnect_eventually(self) -> int: - if self._stop_time - self._start_time > self._RESET_BACKOFF_AFTER_UPTIME: - delay = self._reconnect_delay = self._INITIAL_DELAY - else: - delay = self._reconnect_delay - self._reconnect_delay = ( - self._reconnect_delay * self._BACKOFF_MULTIPLIER - + random.uniform(1.0, self._RANDOM_ADDEND_MAX) - ) - token = CancelToken() - self._reconnect_token = token - self._reconnect_timer = threading.Timer( - delay, lambda: self._io_queue.put((token, IoReconnect())) - ) - self._reconnect_timer.daemon = True - self._reconnect_timer.start() - return int(delay) - - -def main_loop( - io_queue: "queue.Queue[IoActivity]", - server: Server, - systray: SystrayState, -) -> None: - reconnector = Reconnector(io_queue) - handle_clients: Dict[PermuterHandle, Client] = {} - while True: - token, activity = io_queue.get() - if token and token.cancelled: - continue - - if not isinstance(activity, tuple): - if isinstance(activity, IoShutdown): - break - - elif isinstance(activity, IoReconnect): - print("reconnecting...") - try: - systray.server_reconnecting() - reconnector.mark_start() - server.start() - systray.server_connected() - except EOFError: - delay = reconnector.reconnect_eventually() - print(f"failed again, reconnecting in {delay} seconds...") - systray.server_failed(True) - except ServerError as e: - print("failed!", e.message) - systray.server_failed(False, e.message) - except Exception: - print("failed!") - traceback.print_exc() - systray.server_failed(False) - - elif isinstance(activity, IoServerFailed): - if activity.message: - print("Server error:", activity.message) - print("disconnected from permuter@home") - server.stop() - reconnector.mark_stop() - systray.server_failed(activity.graceful, activity.message) - - if activity.graceful: - delay = reconnector.reconnect_eventually() - print(f"will reconnect in {delay} seconds...") - - else: - static_assert_unreachable(activity) - - else: - handle, msg = activity - - if isinstance(msg, IoConnect): - client = msg.client - handle_clients[handle] = client - systray.connect(handle, client.nickname, msg.fn_name) - print(f"[{client.nickname}] connected ({msg.fn_name})") - - elif isinstance(msg, IoDisconnect): - systray.disconnect(handle) - nickname = handle_clients[handle].nickname - del handle_clients[handle] - print(f"[{nickname}] {msg.reason}") - - elif isinstance(msg, IoImmediateDisconnect): - print(f"[{msg.client.nickname}] {msg.reason}") - - elif isinstance(msg, IoWorkDone): - # TODO: statistics - systray.work_done(handle, msg.is_improvement) - - elif isinstance(msg, IoUserRemovePermuter): - server.remove_permuter(handle) - - else: - static_assert_unreachable(msg) - - -def server_main(options: ServerOptions, use_systray: bool) -> None: - io_queue: "queue.Queue[IoActivity]" = queue.Queue() - config = read_config() - - systray: SystrayState - if use_systray: - systray = RealSystrayState(config, io_queue) - else: - systray = SystrayState() - - try: - server = Server(options, config, io_queue) - server.start() - - try: - systray.server_connected() - main_loop(io_queue, server, systray) - finally: - server.stop() - finally: - systray.stop() diff --git a/tools/decomp-permuter/src/net/cmd/setup.py b/tools/decomp-permuter/src/net/cmd/setup.py deleted file mode 100644 index 8238ded2b1..0000000000 --- a/tools/decomp-permuter/src/net/cmd/setup.py +++ /dev/null @@ -1,86 +0,0 @@ -from argparse import ArgumentParser, Namespace -import base64 -import os -import random -import string -import sys -from typing import Optional - -from nacl.public import SealedBox -from nacl.signing import SigningKey, VerifyKey - -from .base import Command -from ..core import connect, read_config, sign_with_magic, write_config -from .util import ask - - -class SetupCommand(Command): - command = "setup" - help = """Set up permuter@home. This will require someone else to grant you - access to the central server.""" - - @staticmethod - def add_arguments(parser: ArgumentParser) -> None: - pass - - @staticmethod - def run(args: Namespace) -> None: - _run_initial_setup() - - -def _random_name() -> str: - return "".join(random.choice(string.ascii_lowercase) for _ in range(5)) - - -def _run_initial_setup() -> None: - config = read_config() - signing_key: Optional[SigningKey] = config.signing_key - if not signing_key or not ask("Keep previous secret key", default=True): - signing_key = SigningKey.generate() - config.signing_key = signing_key - write_config(config) - verify_key = signing_key.verify_key - - nickname: Optional[str] = config.initial_setup_nickname - if not nickname or not ask(f"Keep previous nickname [{nickname}]", default=True): - default_nickname = os.environ.get("USER") or _random_name() - nickname = ( - input(f"Nickname [default: {default_nickname}]: ") or default_nickname - ) - config.initial_setup_nickname = nickname - write_config(config) - - signed_nickname = sign_with_magic(b"NAME", signing_key, nickname.encode("utf-8")) - - vouch_data = verify_key.encode() + signed_nickname - vouch_text = base64.b64encode(vouch_data).decode("utf-8") - print("Ask someone to run the following command:") - print(f"./pah.py vouch {vouch_text}") - print() - print("They should give you a token back in return. Paste that here:") - inp = input().strip() - - try: - token = base64.b64decode(inp.encode("utf-8")) - data = SealedBox(signing_key.to_curve25519_private_key()).decrypt(token) - config.server_address = data[32:].decode("utf-8") - config.server_verify_key = VerifyKey(data[:32]) - config.initial_setup_nickname = None - except Exception: - print("Invalid token!") - sys.exit(1) - - print(f"Server: {config.server_address}") - print("Testing connection...") - - port = connect(config) - port.send_json({"method": "ping"}) - port.receive_json() - - try: - write_config(config) - except Exception as e: - print("Failed to write config:", e) - sys.exit(1) - - print("permuter@home successfully set up!") diff --git a/tools/decomp-permuter/src/net/cmd/systray/.gitignore b/tools/decomp-permuter/src/net/cmd/systray/.gitignore deleted file mode 100644 index 49f8d190ce..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -permuter-systray -permuter-systray.exe diff --git a/tools/decomp-permuter/src/net/cmd/systray/LICENSE b/tools/decomp-permuter/src/net/cmd/systray/LICENSE deleted file mode 100644 index dc9823bd1e..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2017 Zack Young - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/tools/decomp-permuter/src/net/cmd/systray/README.md b/tools/decomp-permuter/src/net/cmd/systray/README.md deleted file mode 100644 index 5dee2f87ea..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# systray - -This directory contains a Go application that shows a system tray, which the Python code interacts with. - -It is a fork of https://github.com/felixhao28/systray-portable. - -To build it: - -- install Go -- if on Linux, install dependencies: `libgtk-3-dev`, `libappindicator3-dev` -- run `go build` - -If on Windows, this needs to be done *outside* of WSL. diff --git a/tools/decomp-permuter/src/net/cmd/systray/go.mod b/tools/decomp-permuter/src/net/cmd/systray/go.mod deleted file mode 100644 index 4abd8814ce..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module permuter-systray - -go 1.15 - -require github.com/getlantern/systray v1.1.0 - -replace github.com/getlantern/systray v1.1.0 => github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56 diff --git a/tools/decomp-permuter/src/net/cmd/systray/go.sum b/tools/decomp-permuter/src/net/cmd/systray/go.sum deleted file mode 100644 index bd6cf25417..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/go.sum +++ /dev/null @@ -1,4 +0,0 @@ -github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56 h1:UZcM1HdV25CQhhJD340jxRLRGl0V11V0wIoUDKTOZMI= -github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56/go.mod h1:N5dpnnWiJhCxh+gXuNgDS2p5MjgcVR/TGwWuaDc4gLk= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 h1:YTzHMGlqJu67/uEo1lBv0n3wBXhXNeUbB1XfN2vmTm0= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-linux-x86_64 b/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-linux-x86_64 deleted file mode 100755 index d23a6990d3..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-linux-x86_64 and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-macos-x86_64 b/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-macos-x86_64 deleted file mode 100755 index 4eee7ff130..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-macos-x86_64 and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-win-x86_64.exe b/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-win-x86_64.exe deleted file mode 100755 index 8de9fc216b..0000000000 Binary files a/tools/decomp-permuter/src/net/cmd/systray/prebuilt/permuter-systray-win-x86_64.exe and /dev/null differ diff --git a/tools/decomp-permuter/src/net/cmd/systray/tray.go b/tools/decomp-permuter/src/net/cmd/systray/tray.go deleted file mode 100644 index 67c387844a..0000000000 --- a/tools/decomp-permuter/src/net/cmd/systray/tray.go +++ /dev/null @@ -1,287 +0,0 @@ -package main - -import ( - "bufio" - "encoding/base64" - "encoding/json" - "fmt" - "os" - "os/signal" - "reflect" - "strings" - "syscall" - - "github.com/getlantern/systray" -) - -func main() { - systray.Run(onReady, onExit) -} - -func onExit() { - os.Exit(0) -} - -// Item represents an item in the menu -type Item struct { - Icon string `json:"icon"` - Title string `json:"title"` - Tooltip string `json:"tooltip"` - Enabled bool `json:"enabled"` - Checked bool `json:"checked"` - Hidden bool `json:"hidden"` - Items []Item `json:"items"` - InternalID int `json:"__id"` -} - -// Menu has an icon, title and list of items -type Menu struct { - Icon string `json:"icon"` - Title string `json:"title"` - Tooltip string `json:"tooltip"` - Items []Item `json:"items"` -} - -// Action for an item?.. -type Action struct { - Type string `json:"type"` - Item Item `json:"item"` - Menu Menu `json:"menu"` -} - -// ClickEvent for an click event -type ClickEvent struct { - Type string `json:"type"` - InternalID int `json:"__id"` -} - -func readJSON(reader *bufio.Reader, v interface{}) error { - input, err := reader.ReadString('\n') - if err != nil { - return err - } - if len(input) < 1 { - return fmt.Errorf("Empty line") - } - - lineReader := strings.NewReader(input[0 : len(input)-1]) - if err := json.NewDecoder(lineReader).Decode(v); err != nil { - return err - } - - return nil -} - -func addMenuItem(items *[]*systray.MenuItem, seqID2InternalID *[]int, internalID2SeqID *map[int]int, item *Item, parent *systray.MenuItem) { - if item.Title == "" { - systray.AddSeparator() - *items = append(*items, nil) - } else { - var menuItem *systray.MenuItem - if parent == nil { - menuItem = systray.AddMenuItem(item.Title, item.Tooltip) - } else { - menuItem = parent.AddSubMenuItem(item.Title, item.Tooltip) - } - if item.Checked { - menuItem.Check() - } else { - menuItem.Uncheck() - } - if item.Enabled { - menuItem.Enable() - } else { - menuItem.Disable() - } - if len(item.Icon) > 0 { - icon, err := base64.StdEncoding.DecodeString(item.Icon) - if err != nil { - fmt.Fprintln(os.Stderr, err) - } else { - menuItem.SetIcon(icon) - } - } - for i := 0; i < len(item.Items); i++ { - subitem := item.Items[i] - addMenuItem(items, seqID2InternalID, internalID2SeqID, &subitem, menuItem) - } - if item.Hidden { - menuItem.Hide() - } - *items = append(*items, menuItem) - } - seqID := len(*items) - 1 - (*internalID2SeqID)[item.InternalID] = seqID - *seqID2InternalID = append(*seqID2InternalID, item.InternalID) -} - -func onReady() { - signalChannel := make(chan os.Signal, 2) - signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM) - go func() { - for sig := range signalChannel { - switch sig { - case os.Interrupt, syscall.SIGTERM: - // handle SIGINT, SIGTERM - fmt.Fprintln(os.Stderr, "Quit") - systray.Quit() - default: - fmt.Fprintln(os.Stderr, "Unhandled signal:", sig) - } - } - }() - - items := make([]*systray.MenuItem, 0) - seqID2InternalID := make([]int, 0) - internalID2SeqID := make(map[int]int) - fmt.Println(`{"type": "ready"}`) - reader := bufio.NewReader(os.Stdin) - - var menu Menu - if err := readJSON(reader, &menu); err != nil { - fmt.Fprintln(os.Stderr, err) - systray.Quit() - return - } - - icon, err := base64.StdEncoding.DecodeString(menu.Icon) - if err != nil { - fmt.Fprintln(os.Stderr, err) - systray.Quit() - return - } - - systray.SetIcon(icon) - systray.SetTitle(menu.Title) - systray.SetTooltip(menu.Tooltip) - - updateItem := func(action Action) { - item := action.Item - seqID := internalID2SeqID[action.Item.InternalID] - menuItem := items[seqID] - if menuItem == nil { - return - } - if item.Hidden { - menuItem.Hide() - } else { - if item.Checked { - menuItem.Check() - } else { - menuItem.Uncheck() - } - if item.Enabled { - menuItem.Enable() - } else { - menuItem.Disable() - } - menuItem.SetTitle(item.Title) - menuItem.SetTooltip(item.Tooltip) - if len(item.Icon) > 0 { - icon, err := base64.StdEncoding.DecodeString(item.Icon) - if err != nil { - fmt.Fprintln(os.Stderr, err) - } else { - menuItem.SetIcon(icon) - } - } - menuItem.Show() - for _, child := range item.Items { - seqID = internalID2SeqID[child.InternalID] - items[seqID].Show() - } - } - } - - updateMenu := func(action Action) { - m := action.Menu - if menu.Title != m.Title { - menu.Title = m.Title - systray.SetTitle(menu.Title) - } - if menu.Icon != m.Icon && m.Icon != "" { - menu.Icon = m.Icon - icon, err := base64.StdEncoding.DecodeString(menu.Icon) - if err != nil { - fmt.Fprintln(os.Stderr, err) - } else { - systray.SetIcon(icon) - } - } - if menu.Tooltip != m.Tooltip { - menu.Tooltip = m.Tooltip - systray.SetTooltip(menu.Tooltip) - } - } - - update := func(action Action) { - switch action.Type { - case "update-item": - updateItem(action) - case "update-menu": - updateMenu(action) - case "update-item-and-menu": - updateItem(action) - updateMenu(action) - case "exit": - systray.Quit() - } - } - - for i := 0; i < len(menu.Items); i++ { - item := menu.Items[i] - addMenuItem(&items, &seqID2InternalID, &internalID2SeqID, &item, nil) - } - - go func(reader *bufio.Reader) { - for { - var action Action - if err := readJSON(reader, &action); err != nil { - fmt.Fprintln(os.Stderr, err) - systray.Quit() - break - } - update(action) - } - }(reader) - - stdoutEnc := json.NewEncoder(os.Stdout) - for { - itemsCnt := 0 - for _, ch := range items { - if ch != nil { - itemsCnt++ - } - } - cases := make([]reflect.SelectCase, itemsCnt) - caseCnt2SeqID := make([]int, len(items)) - itemsCnt = 0 - for i, ch := range items { - if ch == nil { - continue - } - cases[itemsCnt] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.ClickedCh)} - caseCnt2SeqID[itemsCnt] = i - itemsCnt++ - } - - remaining := len(cases) - for remaining > 0 { - chosen, _, ok := reflect.Select(cases) - if !ok { - // The chosen channel has been closed, so zero out the channel to disable the case - cases[chosen].Chan = reflect.ValueOf(nil) - remaining-- - continue - } - seqID := caseCnt2SeqID[chosen] - err := stdoutEnc.Encode(ClickEvent{ - Type: "clicked", - InternalID: seqID2InternalID[seqID], - }) - if err != nil { - fmt.Fprintln(os.Stderr, err) - } - } - } -} diff --git a/tools/decomp-permuter/src/net/cmd/util.py b/tools/decomp-permuter/src/net/cmd/util.py deleted file mode 100644 index dfe9b47e40..0000000000 --- a/tools/decomp-permuter/src/net/cmd/util.py +++ /dev/null @@ -1,15 +0,0 @@ -import sys - - -def ask(msg: str, *, default: bool) -> bool: - if default: - msg += " (Y/n)? " - else: - msg += " (y/N)? " - res = input(msg).strip().lower() - if not res: - return default - if res in ["y", "yes", "n", "no"]: - return res[0] == "y" - print("Bad response!") - sys.exit(1) diff --git a/tools/decomp-permuter/src/net/cmd/vouch.py b/tools/decomp-permuter/src/net/cmd/vouch.py deleted file mode 100644 index 82b19a8b08..0000000000 --- a/tools/decomp-permuter/src/net/cmd/vouch.py +++ /dev/null @@ -1,73 +0,0 @@ -from argparse import ArgumentParser, Namespace -import base64 -import sys - -from nacl.encoding import HexEncoder -from nacl.public import SealedBox -from nacl.signing import VerifyKey - -from ..core import connect, read_config, verify_with_magic -from .base import Command -from .util import ask - - -class VouchCommand(Command): - command = "vouch" - help = "Give someone access to the central server." - - @staticmethod - def add_arguments(parser: ArgumentParser) -> None: - parser.add_argument( - "magic", - help="Opaque hex string generated by 'setup'.", - ) - - @staticmethod - def run(args: Namespace) -> None: - run_vouch(args.magic) - - -def run_vouch(magic: str) -> None: - try: - vouch_data = base64.b64decode(magic.encode("utf-8")) - verify_key = VerifyKey(vouch_data[:32]) - signed_nickname = vouch_data[32:] - msg = verify_with_magic(b"NAME", verify_key, signed_nickname) - nickname = msg.decode("utf-8") - except Exception: - print("Could not parse data!") - sys.exit(1) - - try: - config = read_config() - port = connect(config) - port.send_json( - { - "method": "vouch", - "who": verify_key.encode(HexEncoder).decode("utf-8"), - "signed_name": HexEncoder.encode(signed_nickname).decode("utf-8"), - } - ) - port.receive_json() - except Exception as e: - print(f"Error: {e}") - sys.exit(1) - - if not ask(f"Grant permuter server access to {nickname}", default=True): - return - - try: - port.send_json({}) - port.receive_json() - except Exception as e: - print(f"Failed to grant access: {e}") - sys.exit(1) - - assert config.server_address, "checked by connect" - assert config.server_verify_key, "checked by connect" - data = config.server_verify_key.encode() + config.server_address.encode("utf-8") - token = SealedBox(verify_key.to_curve25519_public_key()).encrypt(data) - print("Granted!") - print() - print("Send them the following token:") - print(base64.b64encode(token).decode("utf-8")) diff --git a/tools/decomp-permuter/src/net/controller/.gitignore b/tools/decomp-permuter/src/net/controller/.gitignore deleted file mode 100644 index a2274adae7..0000000000 --- a/tools/decomp-permuter/src/net/controller/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -target/ -config.toml -*.json diff --git a/tools/decomp-permuter/src/net/controller/Cargo.lock b/tools/decomp-permuter/src/net/controller/Cargo.lock deleted file mode 100644 index 9c9fab8c19..0000000000 --- a/tools/decomp-permuter/src/net/controller/Cargo.lock +++ /dev/null @@ -1,607 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -[[package]] -name = "argh" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91792f088f87cdc7a2cfb1d617fa5ea18d7f1dc22ef0e1b5f82f3157cdc522be" -dependencies = [ - "argh_derive", - "argh_shared", -] - -[[package]] -name = "argh_derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4eb0c0c120ad477412dc95a4ce31e38f2113e46bd13511253f79196ca68b067" -dependencies = [ - "argh_shared", - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "argh_shared" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "781f336cc9826dbaddb9754cb5db61e64cab4f69668bd19dcc4a0394a86f4cb1" - -[[package]] -name = "autocfg" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" - -[[package]] -name = "bitflags" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" - -[[package]] -name = "bytes" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" - -[[package]] -name = "cc" -version = "1.0.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "chrono" -version = "0.4.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" -dependencies = [ - "libc", - "num-integer", - "num-traits", - "time", - "winapi", -] - -[[package]] -name = "getrandom" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9495705279e7140bf035dde1f6e750c162df8b625267cd52cc44e0b156732c8" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "heck" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cbf45460356b7deeb5e3415b5563308c0a9b057c85e12b06ad551f98d0a6ac" -dependencies = [ - "unicode-segmentation", -] - -[[package]] -name = "hermit-abi" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" -dependencies = [ - "libc", -] - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "instant" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "itoa" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" - -[[package]] -name = "libc" -version = "0.2.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41" - -[[package]] -name = "libsodium-sys" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a685b64f837b339074115f2e7f7b431ac73681d08d75b389db7498b8892b8a58" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - -[[package]] -name = "lock_api" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3c91c24eae6777794bb1997ad98bbb87daf92890acab859f7eaa4320333176" -dependencies = [ - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "memchr" -version = "2.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" - -[[package]] -name = "mio" -version = "0.7.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf80d3e903b34e0bd7282b218398aec54e082c840d9baf8339e0080a0c542956" -dependencies = [ - "libc", - "log", - "miow", - "ntapi", - "winapi", -] - -[[package]] -name = "miow" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" -dependencies = [ - "winapi", -] - -[[package]] -name = "ntapi" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" -dependencies = [ - "winapi", -] - -[[package]] -name = "num-integer" -version = "0.1.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "once_cell" -version = "1.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3" - -[[package]] -name = "pahserver" -version = "0.0.1" -dependencies = [ - "argh", - "chrono", - "hex", - "pin-project", - "serde", - "serde_json", - "serde_tuple", - "slotmap", - "sodiumoxide", - "tempfile", - "tokio", - "toml", -] - -[[package]] -name = "parking_lot" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall", - "smallvec", - "winapi", -] - -[[package]] -name = "pin-project" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905" - -[[package]] -name = "pkg-config" -version = "0.3.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" - -[[package]] -name = "ppv-lite86" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" - -[[package]] -name = "proc-macro2" -version = "1.0.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152013215dca273577e18d2bf00fa862b89b24169fb78c4c95aeb07992c9cec" -dependencies = [ - "unicode-xid", -] - -[[package]] -name = "quote" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", - "rand_hc", -] - -[[package]] -name = "rand_chacha" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rand_hc" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73" -dependencies = [ - "rand_core", -] - -[[package]] -name = "redox_syscall" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8270314b5ccceb518e7e578952f0b72b88222d02e8f77f5ecf7abbb673539041" -dependencies = [ - "bitflags", -] - -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] - -[[package]] -name = "ryu" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" - -[[package]] -name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" - -[[package]] -name = "serde" -version = "1.0.125" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "558dc50e1a5a5fa7112ca2ce4effcb321b0300c0d4ccf0776a9f60cd89031171" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.125" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "serde_tuple" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f025b91216f15a2a32aa39669329a475733590a015835d1783549a56d09427" -dependencies = [ - "serde", - "serde_tuple_macros", -] - -[[package]] -name = "serde_tuple_macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4076151d1a2b688e25aaf236997933c66e18b870d0369f8b248b8ab2be630d7e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "signal-hook-registry" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6" -dependencies = [ - "libc", -] - -[[package]] -name = "slotmap" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585cd5dffe4e9e06f6dfdf66708b70aca3f781bed561f4f667b2d9c0d4559e36" -dependencies = [ - "version_check", -] - -[[package]] -name = "smallvec" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" - -[[package]] -name = "sodiumoxide" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7038b67c941e23501573cb7242ffb08709abe9b11eb74bceff875bbda024a6a8" -dependencies = [ - "libc", - "libsodium-sys", - "serde", -] - -[[package]] -name = "syn" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48fe99c6bd8b1cc636890bcc071842de909d902c81ac7dab53ba33c421ab8ffb" -dependencies = [ - "proc-macro2", - "quote", - "unicode-xid", -] - -[[package]] -name = "tempfile" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" -dependencies = [ - "cfg-if", - "libc", - "rand", - "redox_syscall", - "remove_dir_all", - "winapi", -] - -[[package]] -name = "time" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "tokio" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83f0c8e7c0addab50b663055baf787d0af7f413a46e6e7fb9559a4e4db7137a5" -dependencies = [ - "autocfg", - "bytes", - "libc", - "memchr", - "mio", - "num_cpus", - "once_cell", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "tokio-macros", - "winapi", -] - -[[package]] -name = "tokio-macros" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "toml" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" -dependencies = [ - "serde", -] - -[[package]] -name = "unicode-segmentation" -version = "1.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796" - -[[package]] -name = "unicode-xid" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" - -[[package]] -name = "version_check" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" - -[[package]] -name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/tools/decomp-permuter/src/net/controller/Cargo.toml b/tools/decomp-permuter/src/net/controller/Cargo.toml deleted file mode 100644 index 40b050b0a5..0000000000 --- a/tools/decomp-permuter/src/net/controller/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "pahserver" -version = "0.0.1" -edition = "2018" -resolver = "2" - -[dependencies] -tokio = { version = "1", features = ["full"] } -sodiumoxide = "0.2" -toml = "0.5" -serde = { version = "1.0", features = ["derive", "rc"] } -serde_json = "1.0" -serde_tuple = "0.5" -hex = "0.4" -argh = "0.1" -tempfile = "3" -slotmap = "1" -pin-project = "1" -chrono = "*" diff --git a/tools/decomp-permuter/src/net/controller/README.md b/tools/decomp-permuter/src/net/controller/README.md deleted file mode 100644 index 9661e6cf3d..0000000000 --- a/tools/decomp-permuter/src/net/controller/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# controller - -This directory contains code for the central permuter@home controller server, -written in Rust. All p@h traffic passes through here. - -If you just want to run a regular p@h server, you don't need to care about this. - -To setup your own copy of the controller server: - -- Install Rust and (for the libsodium dependency) GCC. -- Run `cargo build --release`. -- Run `./target/release/pahserver setup --db path/to/database.json` and follow - the instructions there. This will set the `priv_seed` part of `config.toml`, and - set up an initial trusted client. The rest of `config.toml` can be copied from - `config_example.toml`. -- Set up a reverse proxy that forwards HTTPS traffic from an external port or route - to HTTP for a port of your choice, e.g. using Nginx or Traefik. - If applicable, configure your firewall to let the external port through. -- Start the server with: - ``` -./target/release/pahserver run --listen-on 0.0.0.0: --config config.toml --db path/to/database.json -``` -and configure the system to run this at startup. diff --git a/tools/decomp-permuter/src/net/controller/config_example.toml b/tools/decomp-permuter/src/net/controller/config_example.toml deleted file mode 100644 index 9e09270859..0000000000 --- a/tools/decomp-permuter/src/net/controller/config_example.toml +++ /dev/null @@ -1,2 +0,0 @@ -docker_image = "" -priv_seed = "0000000000000000000000000000000000000000000000000000000000000000" diff --git a/tools/decomp-permuter/src/net/controller/src/client.rs b/tools/decomp-permuter/src/net/controller/src/client.rs deleted file mode 100644 index 72a486d1ba..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/client.rs +++ /dev/null @@ -1,205 +0,0 @@ -use std::collections::VecDeque; -use std::sync::Arc; - -use serde::{Deserialize, Serialize}; -use serde_json::json; -use tokio::sync::mpsc; - -use crate::db::UserId; -use crate::flimsy_semaphore::FlimsySemaphore; -use crate::port::{ReadPort, WritePort}; -use crate::stats; -use crate::util::SimpleResult; -use crate::{ - current_load, Permuter, PermuterData, PermuterId, PermuterResult, PermuterWork, ServerUpdate, - State, -}; - -const MIN_PERMUTER_VERSION: u32 = 1; - -const CLIENT_MAX_QUEUES_SIZE: usize = 100; -const MIN_PRIORITY: f64 = 0.001; -const MAX_PRIORITY: f64 = 10.0; - -#[derive(Debug, Deserialize)] -pub(crate) struct ConnectClientData { - priority: f64, -} - -#[derive(Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum ClientMessage { - Work { work: PermuterWork }, -} - -#[derive(Serialize)] -struct PermuterResultMessage<'a> { - server: String, - #[serde(flatten)] - update: &'a ServerUpdate, -} - -async fn client_read( - port: &mut ReadPort<'_>, - perm_id: &PermuterId, - semaphore: &FlimsySemaphore, - state: &State, -) -> SimpleResult<()> { - loop { - let msg = port.recv().await?; - let msg: ClientMessage = serde_json::from_slice(&msg)?; - let ClientMessage::Work { work } = msg; - - // Avoid the work and result queues growing indefinitely by restricting - // their combined size with a semaphore. - semaphore.acquire().await; - - let mut m = state.m.lock().unwrap(); - let perm = m.permuters.get_mut(perm_id).unwrap(); - if perm.work_queue.is_empty() { - state.new_work_notification.notify_waiters(); - } - perm.work_queue.push_back(work); - } -} - -async fn client_write( - port: &mut WritePort<'_>, - fn_name: &str, - semaphore: &FlimsySemaphore, - state: &State, - mut result_rx: mpsc::UnboundedReceiver, - client_id: &UserId, -) -> SimpleResult<()> { - loop { - let res = result_rx.recv().await.unwrap(); - semaphore.release(); - - match res { - PermuterResult::NeedWork => { - port.send_json(&json!({ - "type": "need_work", - })) - .await?; - } - PermuterResult::Result(server_id, server_name, server_update) => { - port.send_json(&PermuterResultMessage { - server: server_name, - update: &server_update, - }) - .await?; - - if let ServerUpdate::Result { - compressed_source, - ref more_props, - .. - } = server_update - { - if let Some(ref data) = compressed_source { - port.send(data).await?; - } - - let score = more_props.get("score").and_then(|score| score.as_i64()); - let outcome = if compressed_source.is_none() { - stats::Outcome::Unhelpful - } else if matches!(score, Some(0)) { - stats::Outcome::Matched - } else { - stats::Outcome::Improved - }; - state - .log_stats(stats::Record::WorkDone { - server: server_id, - client: client_id.clone(), - fn_name: fn_name.to_string(), - outcome, - }) - .await?; - } - } - } - } -} - -pub(crate) async fn handle_connect_client<'a>( - mut read_port: ReadPort<'a>, - mut write_port: WritePort<'a>, - who_id: UserId, - who_name: &str, - permuter_version: u32, - state: &State, - data: ConnectClientData, -) -> SimpleResult<()> { - if permuter_version < MIN_PERMUTER_VERSION { - Err("Permuter version too old!")?; - } - - if !(MIN_PRIORITY <= data.priority && data.priority <= MAX_PRIORITY) { - Err("Priority out of range")?; - } - - let load = current_load(state, Some(data.priority)); - write_port.send_json(&load).await?; - - let permuter_data = read_port.recv().await?; - let mut permuter_data: PermuterData = serde_json::from_slice(&permuter_data)?; - permuter_data.compressed_source = read_port.recv().await?; - permuter_data.compressed_target_o_bin = read_port.recv().await?; - write_port.send_json(&json!({})).await?; - - eprintln!( - "[{}] start client ({}, {})", - &who_name, &permuter_data.fn_name, data.priority - ); - - state - .log_stats(stats::Record::ClientNewFunction { - client: who_id.clone(), - fn_name: permuter_data.fn_name.clone(), - }) - .await?; - - let energy_add = 1.0 / data.priority; - let fn_name = permuter_data.fn_name.clone(); - - let (result_tx, result_rx) = mpsc::unbounded_channel(); - let semaphore = Arc::new(FlimsySemaphore::new(CLIENT_MAX_QUEUES_SIZE)); - - let perm_id = { - let mut m = state.m.lock().unwrap(); - let id = m.next_permuter_id; - m.next_permuter_id += 1; - m.permuters.insert( - id, - Permuter { - data: permuter_data.into(), - client_id: who_id.clone(), - client_name: who_name.to_string(), - work_queue: VecDeque::new(), - result_tx: result_tx.clone(), - semaphore: semaphore.clone(), - priority: data.priority, - energy_add, - }, - ); - state.new_work_notification.notify_waiters(); - id - }; - - let r = tokio::try_join!( - client_read(&mut read_port, &perm_id, &semaphore, state), - client_write( - &mut write_port, - &fn_name, - &semaphore, - state, - result_rx, - &who_id - ) - ); - - state.m.lock().unwrap().permuters.remove(&perm_id); - state.new_work_notification.notify_waiters(); - r?; - Ok(()) -} diff --git a/tools/decomp-permuter/src/net/controller/src/db.rs b/tools/decomp-permuter/src/net/controller/src/db.rs deleted file mode 100644 index 8b608bd0dd..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/db.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::collections::HashMap; -use std::convert::TryInto; - -use hex::FromHex; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use serde_tuple::{Deserialize_tuple, Serialize_tuple}; -use sodiumoxide::crypto::sign; - -#[derive(Clone, Debug, Hash, Eq, PartialEq)] -pub struct ByteString([u8; SIZE]); - -impl ByteString { - fn to_hex(&self) -> String { - hex::encode(&self.0) - } - - fn from_hex(string: &str) -> Result, &'static str> { - Ok(ByteString( - Vec::from_hex(&string) - .map_err(|_| "not a valid hex string")? - .try_into() - .map_err(|_| "byte string has wrong size")?, - )) - } -} - -impl Serialize for ByteString { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.to_hex()) - } -} - -impl<'de, const SIZE: usize> Deserialize<'de> for ByteString { - fn deserialize(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - use serde::de::Error; - let string = String::deserialize(deserializer)?; - ByteString::from_hex(&string).map_err(Error::custom) - } -} - -pub type UserId = ByteString<32>; - -impl UserId { - pub fn from_pubkey(key: &sign::PublicKey) -> UserId { - ByteString(key.as_ref().try_into().unwrap()) - } - - pub fn to_pubkey(&self) -> sign::PublicKey { - sign::PublicKey::from_slice(&self.0).unwrap() - } -} - -impl ByteString { - pub fn to_seed(&self) -> sign::Seed { - sign::Seed::from_slice(&self.0).unwrap() - } -} - -#[derive(Debug, Deserialize_tuple, Serialize_tuple)] -pub struct Stats { - pub iterations: u64, - pub improvements: u64, - pub matches: u64, - pub functions: u64, -} - -impl Default for Stats { - fn default() -> Stats { - Stats { - iterations: 0, - improvements: 0, - matches: 0, - functions: 0, - } - } -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct User { - pub trusted_by: Option, - pub name: String, - pub client_stats: Stats, - pub server_stats: Stats, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct DB { - pub users: HashMap, - pub func_stats: HashMap, - pub total_stats: Stats, -} - -impl DB { - pub fn func_stat(&mut self, fn_name: String) -> &mut Stats { - self.func_stats - .entry(fn_name) - .or_insert_with(Stats::default) - } -} diff --git a/tools/decomp-permuter/src/net/controller/src/flimsy_semaphore.rs b/tools/decomp-permuter/src/net/controller/src/flimsy_semaphore.rs deleted file mode 100644 index b29b41fea2..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/flimsy_semaphore.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::convert::TryInto; -use std::sync::atomic::{AtomicIsize, Ordering}; - -use tokio::sync::Notify; - -/// An unfair semaphore that allows overdrafts. -pub struct FlimsySemaphore { - notify: Notify, - slots: AtomicIsize, -} - -impl FlimsySemaphore { - // Invariant: if `slots` has ever become non-positive, then if positive - // there will be a notify token in circulation. Taking the token - // synchronizes with a positive `slots`. - pub fn new(limit: usize) -> FlimsySemaphore { - FlimsySemaphore { - notify: Notify::new(), - slots: AtomicIsize::new(limit.try_into().unwrap()), - } - } - - pub fn acquire_ignore_limit(&self) { - self.slots.fetch_add(-1, Ordering::Acquire); - } - - pub async fn acquire(&self) { - let mut was_woken = false; - let mut val = self.slots.load(Ordering::Relaxed); - loop { - if val > 0 { - match self.slots.compare_exchange( - val, - val - 1, - Ordering::Acquire, - Ordering::Relaxed, - ) { - Ok(_) => { - if was_woken && val > 1 { - self.notify.notify_one(); - } - return; - } - Err(actually) => { - val = actually; - } - } - } else { - self.notify.notified().await; - was_woken = true; - val = self.slots.load(Ordering::Relaxed); - } - } - } - - pub fn release(&self) { - if self.slots.fetch_add(1, Ordering::Release) == 0 { - self.notify.notify_one(); - } - } -} diff --git a/tools/decomp-permuter/src/net/controller/src/main.rs b/tools/decomp-permuter/src/net/controller/src/main.rs deleted file mode 100644 index e50a91d531..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/main.rs +++ /dev/null @@ -1,418 +0,0 @@ -#![allow(clippy::try_err)] - -use std::collections::{HashMap, VecDeque}; -use std::convert::TryInto; -use std::default::Default; -use std::io::ErrorKind; -use std::sync::{Arc, Mutex}; - -use argh::FromArgs; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use slotmap::{new_key_type, SlotMap}; -use sodiumoxide::crypto::box_; -use sodiumoxide::crypto::sign; -use tokio::fs; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{ReadHalf, WriteHalf}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{mpsc, watch, Notify}; -use tokio::time; - -use crate::db::{ByteString, UserId}; -use crate::flimsy_semaphore::FlimsySemaphore; -use crate::port::{ReadPort, WritePort}; -use crate::save::SaveableDB; -use crate::util::SimpleResult; - -mod client; -mod db; -mod flimsy_semaphore; -mod port; -mod save; -mod server; -mod setup; -mod stats; -mod util; -mod vouch; - -const HEARTBEAT_TIME: time::Duration = time::Duration::from_secs(300); - -#[derive(FromArgs)] -/// The permuter@home control server. -struct CmdOpts { - #[argh(subcommand)] - sub: SubCommand, -} - -#[derive(FromArgs)] -#[argh(subcommand)] -enum SubCommand { - RunServer(RunServerOpts), - Setup(SetupOpts), -} - -#[derive(FromArgs)] -/// Run the permuter@home control server. -#[argh(subcommand, name = "run")] -struct RunServerOpts { - /// ip:port to listen on (e.g. 0.0.0.0:1234) - #[argh(option)] - listen_on: String, - - /// path to TOML configuration file - #[argh(option)] - config: String, - - /// path to JSON database - #[argh(option)] - db: String, - - /// enable debug logging - #[argh(switch)] - debug: bool, -} - -#[derive(FromArgs)] -/// Setup initial database and config for permuter@home. -#[argh(subcommand, name = "setup")] -struct SetupOpts { - /// path to JSON database - #[argh(option)] - db: String, -} - -#[derive(Deserialize)] -struct Config { - docker_image: String, - priv_seed: ByteString<32>, -} - -#[derive(Debug, Deserialize, Serialize)] -struct PermuterData { - fn_name: String, - #[serde(skip)] - compressed_source: Vec, - #[serde(skip)] - compressed_target_o_bin: Vec, - #[serde(flatten)] - more_props: HashMap, -} - -#[derive(Clone, Copy, Debug, Deserialize, Serialize)] -struct PermuterWork { - seed: u64, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum ServerUpdate { - Result { - #[serde(skip_serializing, default)] - overhead_us: i64, - #[serde(skip)] - compressed_source: Option>, - #[serde(default)] - has_source: bool, - #[serde(flatten)] - more_props: HashMap, - }, - InitDone { - hash: String, - }, - InitFailed { - reason: String, - }, - Disconnect, -} - -#[derive(Debug)] -enum PermuterResult { - NeedWork, - Result(UserId, String, ServerUpdate), -} - -type PermuterId = u64; - -struct Permuter { - data: Arc, - client_id: UserId, - client_name: String, - work_queue: VecDeque, - result_tx: mpsc::UnboundedSender, - semaphore: Arc, - priority: f64, - energy_add: f64, -} - -impl Permuter { - fn send_result(&mut self, res: PermuterResult) { - // We can't use a blocking semaphore acquire here, because we don't - // want server sends to block on random client receives. In practice, - // this is probably fine. - let _ = self.result_tx.send(res); - self.semaphore.acquire_ignore_limit(); - } -} - -new_key_type! { struct ServerId; } - -struct ConnectedServer { - min_priority: f64, - num_cores: f64, -} - -struct MutableState { - servers: SlotMap, - permuters: HashMap, - next_permuter_id: PermuterId, -} - -struct State { - docker_image: String, - debug: bool, - sign_sk: sign::SecretKey, - db: SaveableDB, - stats_tx: mpsc::Sender, - heartbeat_rx: watch::Receiver<()>, - new_work_notification: Notify, - m: Mutex, -} - -impl State { - async fn log_stats(&self, record: stats::Record) -> SimpleResult<()> { - self.stats_tx - .send(record) - .await - .map_err(|_| "stats thread died".into()) - } -} - -#[derive(Deserialize)] -#[serde(tag = "method", rename_all = "snake_case")] -enum Request { - Ping, - Vouch(vouch::VouchData), - ConnectServer(server::ConnectServerData), - ConnectClient(client::ConnectClientData), -} - -#[derive(Serialize)] -struct Load { - clients: usize, - servers: usize, - cores: f64, -} - -#[tokio::main] -async fn main() -> SimpleResult<()> { - sodiumoxide::init().map_err(|()| "Failed to initialize cryptography library")?; - - let opts: CmdOpts = argh::from_env(); - - match opts.sub { - SubCommand::RunServer(opts) => run_server(opts).await?, - SubCommand::Setup(opts) => setup::run_setup(opts)?, - } - Ok(()) -} - -async fn run_server(opts: RunServerOpts) -> SimpleResult<()> { - let config: Config = toml::from_str(&fs::read_to_string(&opts.config).await?)?; - let (_, sign_sk) = sign::keypair_from_seed(&config.priv_seed.to_seed()); - - let (save_fut, db) = SaveableDB::open(&opts.db)?; - tokio::spawn(async move { - if let Err(e) = save_fut.await { - eprintln!("Failed to save! {:?}", e); - std::process::exit(1); - } - }); - - let (stats_fut, stats_tx) = stats::stats_thread(&db); - tokio::spawn(stats_fut); - - let (heartbeat_tx, heartbeat_rx) = watch::channel(()); - - let state: &'static State = Box::leak(Box::new(State { - docker_image: config.docker_image, - debug: opts.debug, - sign_sk, - db, - stats_tx, - heartbeat_rx, - new_work_notification: Notify::new(), - m: Mutex::new(MutableState { - servers: SlotMap::with_key(), - permuters: HashMap::new(), - next_permuter_id: 0, - }), - })); - - tokio::spawn(async move { - loop { - heartbeat_tx.send(()).expect("receiver is still alive"); - time::sleep(HEARTBEAT_TIME).await; - } - }); - - let listener = TcpListener::bind(opts.listen_on).await?; - - loop { - let (socket, _) = listener.accept().await?; - tokio::spawn(async move { - let mut who = "anonymous".to_string(); - if let Err(e) = handle_connection(socket, state, &mut who).await { - if let Some(e) = e.downcast_ref::() { - if matches!( - e.kind(), - ErrorKind::UnexpectedEof - | ErrorKind::ConnectionReset - | ErrorKind::TimedOut - | ErrorKind::BrokenPipe - ) { - eprintln!("[{}] disconnected", &who); - return; - } - } - eprintln!("[{}] error: {:?}", &who, e); - } - }); - } -} - -fn concat(a: &[T], b: &[T]) -> Vec { - a.iter().chain(b).cloned().collect() -} - -fn concat3(a: &[T], b: &[T], c: &[T]) -> Vec { - a.iter().chain(b).chain(c).cloned().collect() -} - -async fn handshake<'a>( - mut rd: ReadHalf<'a>, - mut wr: WriteHalf<'a>, - sign_sk: &sign::SecretKey, -) -> SimpleResult<(ReadPort<'a>, WritePort<'a>, UserId, u32)> { - let mut buffer = [0; 4 + 32]; - rd.read_exact(&mut buffer).await?; - let (magic, their_pk) = buffer.split_at(4); - if magic != b"p@h0" { - Err("Invalid protocol version")?; - } - let their_pk = box_::PublicKey::from_slice(&their_pk).unwrap(); - - let (our_pk, our_sk) = box_::gen_keypair(); - let signed_data = concat3(b"HELLO:", their_pk.as_ref(), our_pk.as_ref()); - let signature = sign::sign_detached(&signed_data, &sign_sk); - wr.write_all(&concat(our_pk.as_ref(), signature.as_ref())) - .await?; - - let key = box_::precompute(&their_pk, &our_sk); - let mut read_port = ReadPort::new(rd, &key); - let write_port = WritePort::new(wr, &key); - - let reply = read_port.recv().await?; - if reply.len() != 32 + 64 + 4 { - Err("Failed to perform secret handshake")?; - } - let (client_ver_key, rest) = reply.split_at(32); - let (client_signature, permuter_version) = rest.split_at(64); - let client_ver_key = sign::PublicKey::from_slice(client_ver_key).unwrap(); - let client_signature = sign::Signature::from_slice(client_signature).unwrap(); - let permuter_version = u32::from_be_bytes(permuter_version.try_into().unwrap()); - let signed_data = concat(b"WORLD:", our_pk.as_ref()); - if !sign::verify_detached(&client_signature, &signed_data, &client_ver_key) { - Err("Spoofed client signature!")?; - } - - Ok(( - read_port, - write_port, - UserId::from_pubkey(&client_ver_key), - permuter_version, - )) -} - -fn current_load(state: &State, priority: Option) -> Load { - let m = state.m.lock().unwrap(); - let mut servers: usize = 0; - let mut cores: f64 = 0.0; - for server in m.servers.values() { - if priority.map_or(true, |p| p >= server.min_priority) { - servers += 1; - cores += server.num_cores; - } - } - Load { - clients: m.permuters.len(), - servers, - cores, - } -} - -async fn handle_connection( - mut socket: TcpStream, - state: &State, - out_name: &mut String, -) -> SimpleResult<()> { - let (rd, wr) = socket.split(); - let (mut read_port, mut write_port, user_id, permuter_version) = - handshake(rd, wr, &state.sign_sk).await?; - let name = match state.db.read(|db| { - let user = db.users.get(&user_id)?; - Some(user.name.clone()) - }) { - Some(name) => name, - None => { - write_port.send_error("Access denied!").await?; - Err("Unknown client!")? - } - }; - *out_name = name.clone(); - eprintln!("[{}] connected (v {})", &name, permuter_version); - if state.debug { - read_port.set_debug(&name); - write_port.set_debug(&name); - } - write_port.send_json(&json!({})).await?; - - let request = read_port.recv().await?; - let request: Request = serde_json::from_slice(&request)?; - match request { - Request::Ping => { - eprintln!("[{}] ping", &name); - let load = current_load(state, None); - write_port.send_json(&load).await?; - } - Request::Vouch(data) => { - vouch::handle_vouch(read_port, write_port, user_id, &name, state, data).await?; - } - Request::ConnectServer(data) => { - server::handle_connect_server( - read_port, - write_port, - user_id, - &name, - permuter_version, - state, - data, - ) - .await?; - } - Request::ConnectClient(data) => { - client::handle_connect_client( - read_port, - write_port, - user_id, - &name, - permuter_version, - state, - data, - ) - .await?; - } - }; - - Ok(()) -} diff --git a/tools/decomp-permuter/src/net/controller/src/port.rs b/tools/decomp-permuter/src/net/controller/src/port.rs deleted file mode 100644 index 5e9c52eebb..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/port.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::convert::TryInto; - -use chrono::Local; -use serde::Serialize; -use sodiumoxide::crypto::box_; -use sodiumoxide::crypto::box_::{Nonce, PrecomputedKey}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{ReadHalf, WriteHalf}; - -use crate::util::SimpleResult; - -fn debug_print(action: &str, who: &str, msg: &[u8]) { - let time = Local::now().format("%H:%M:%S:%f"); - if msg.len() <= 300 { - let msg = String::from_utf8( - msg.iter() - .copied() - .flat_map(std::ascii::escape_default) - .collect(), - ) - .unwrap(); - println!("{} debug: {} {}: {}", time, action, who, msg); - } else { - println!("{} debug: {} {}: {} bytes", time, action, who, msg.len()); - } -} - -pub struct ReadPort<'a> { - read_half: ReadHalf<'a>, - key: PrecomputedKey, - nonce: u64, - debug_name: Option<&'a str>, -} - -impl<'a> ReadPort<'a> { - pub fn new(read_half: ReadHalf<'a>, key: &PrecomputedKey) -> Self { - ReadPort { - read_half, - key: key.clone(), - nonce: 0, - debug_name: None, - } - } - - pub fn set_debug(&mut self, name: &'a str) { - self.debug_name = Some(name); - } - - pub async fn recv(&mut self) -> SimpleResult> { - let len = self.read_half.read_u64().await?; - if len >= (1 << 48) { - Err("Unreasonable packet length")? - } - let mut buffer = vec![0; len.try_into()?]; - self.read_half.read_exact(&mut buffer).await?; - let nonce = nonce_from_u64(self.nonce); - self.nonce += 2; - let data = - box_::open_precomputed(&buffer, &nonce, &self.key).map_err(|()| "Failed to decrypt")?; - if let Some(name) = self.debug_name { - debug_print("Receive from", name, &data); - } - Ok(data) - } -} - -pub struct WritePort<'a> { - write_half: WriteHalf<'a>, - key: PrecomputedKey, - nonce: u64, - debug_name: Option<&'a str>, -} - -impl<'a> WritePort<'a> { - pub fn new(write_half: WriteHalf<'a>, key: &PrecomputedKey) -> Self { - WritePort { - write_half, - key: key.clone(), - nonce: 1, - debug_name: None, - } - } - - pub fn set_debug(&mut self, name: &'a str) { - self.debug_name = Some(name); - } - - pub async fn send(&mut self, data: &[u8]) -> SimpleResult<()> { - if let Some(name) = self.debug_name { - debug_print("Send to", name, &data); - } - let nonce = nonce_from_u64(self.nonce); - self.nonce += 2; - let data = box_::seal_precomputed(data, &nonce, &self.key); - self.write_half.write_u64(data.len() as u64).await?; - self.write_half.write_all(&data).await?; - Ok(()) - } - - pub async fn send_json(&mut self, value: &T) -> SimpleResult<()> - where - T: Serialize, - { - self.send(&serde_json::to_vec(value)?).await - } - - pub async fn send_error(&mut self, message: &str) -> SimpleResult<()> { - self.send_json(message).await - } -} - -fn nonce_from_u64(num: u64) -> Nonce { - let nonce_bytes = [[0; 8], [0; 8], num.to_be_bytes()].concat(); - Nonce::from_slice(&nonce_bytes).unwrap() -} diff --git a/tools/decomp-permuter/src/net/controller/src/save.rs b/tools/decomp-permuter/src/net/controller/src/save.rs deleted file mode 100644 index 883637279a..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/save.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::future::Future; -use std::io::Write; -use std::path::{Path, PathBuf}; -use std::sync::{Arc, RwLock}; -use std::time::Duration; - -use tempfile::NamedTempFile; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::timeout; - -use crate::db::DB; -use crate::util::{FutureExt, SimpleResult}; - -const SAVE_INTERVAL: Duration = Duration::from_secs(30); - -enum SaveType { - Delayed, - Immediate(oneshot::Sender<()>), -} - -struct InnerSaveableDB { - db: DB, - stale: bool, - save_chan: mpsc::UnboundedSender, -} - -#[derive(Clone)] -pub struct SaveableDB(Arc>); - -async fn save_db_loop( - db: SaveableDB, - path: &Path, - mut save_channel: mpsc::UnboundedReceiver, -) -> SimpleResult<()> { - loop { - let mut done_chans = Vec::new(); - match save_channel.recv().await { - None => return Ok(()), - Some(SaveType::Immediate(chan)) => { - done_chans.push(chan); - } - Some(SaveType::Delayed) => { - // Wait for SAVE_INTERVAL or until we receive an Immediate save. - let _ = timeout(SAVE_INTERVAL, async { - loop { - match save_channel.recv().await { - None => { - break; - } - Some(SaveType::Immediate(chan)) => { - done_chans.push(chan); - break; - } - Some(SaveType::Delayed) => {} - }; - } - }) - .await; - } - }; - - // Clear the queue in case more messages have stacked up past an - // Immediate. Receiver::try_recv() is temporarily dead as of tokio 1.4 - // (https://github.com/tokio-rs/tokio/issues/3350) due to a bug where - // messages can be delayed, but in this case that doesn't matter. - loop { - match save_channel.recv().now_or_never().await { - None | Some(None) => { - break; - } - Some(Some(SaveType::Immediate(chan))) => { - done_chans.push(chan); - } - Some(Some(SaveType::Delayed)) => {} - }; - } - - // Mark the DB as non-stale, to start receiving save messages again. - db.0.write().unwrap().stale = false; - - // Actually do the save, by first serializing, then atomically saving - // the file by creating and renaming a temp file in the same directory. - let data = db.read(|db| serde_json::to_string(&db).unwrap()); - - let r: SimpleResult<()> = tokio::task::block_in_place(|| { - let parent_dir = path.parent().unwrap_or_else(|| Path::new(".")); - let mut tempf = NamedTempFile::new_in(parent_dir)?; - tempf.write_all(data.as_bytes())?; - tempf.as_file().sync_all()?; - tempf.persist(path)?; - Ok(()) - }); - r?; - - for chan in done_chans { - let _ = chan.send(()); - } - } -} - -impl SaveableDB { - pub fn open( - filename: &str, - ) -> SimpleResult<(impl Future>, SaveableDB)> { - let db_file = std::fs::File::open(filename)?; - let db: DB = serde_json::from_reader(&db_file)?; - - let (save_tx, save_rx) = mpsc::unbounded_channel(); - - let saveable_db = SaveableDB(Arc::new(RwLock::new(InnerSaveableDB { - db, - stale: false, - save_chan: save_tx, - }))); - - let path = PathBuf::from(filename); - let db2 = saveable_db.clone(); - - let fut = async move { save_db_loop(db2, &path, save_rx).await }; - Ok((fut, saveable_db)) - } - - pub fn read(&self, callback: impl FnOnce(&DB) -> T) -> T { - let inner = self.0.read().unwrap(); - callback(&inner.db) - } - - pub async fn write(&self, immediate: bool, callback: impl FnOnce(&mut DB) -> T) -> T { - let ret; - let rx2; - { - let mut inner = self.0.write().unwrap(); - ret = callback(&mut inner.db); - if immediate { - inner.stale = true; - let (tx, rx) = oneshot::channel(); - rx2 = rx; - inner - .save_chan - .send(SaveType::Immediate(tx)) - .map_err(|_| ()) - .expect("Failed to send message to save task"); - } else { - if !inner.stale { - inner.stale = true; - inner - .save_chan - .send(SaveType::Delayed) - .map_err(|_| ()) - .expect("Failed to send message to save task"); - } - return ret; - } - } - rx2.await.expect("Failed to save!"); - ret - } -} diff --git a/tools/decomp-permuter/src/net/controller/src/server.rs b/tools/decomp-permuter/src/net/controller/src/server.rs deleted file mode 100644 index 3f461adc59..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/server.rs +++ /dev/null @@ -1,500 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, Mutex}; - -use serde::{Deserialize, Serialize}; -use serde_json::json; -use tokio::sync::{mpsc, mpsc::error::TrySendError, watch, Notify}; - -use crate::db::UserId; -use crate::port::{ReadPort, WritePort}; -use crate::stats; -use crate::util::SimpleResult; -use crate::{ - ConnectedServer, MutableState, PermuterData, PermuterId, PermuterResult, PermuterWork, - ServerUpdate, State, HEARTBEAT_TIME, -}; - -const MIN_PERMUTER_VERSION: u32 = 1; - -const SERVER_WORK_QUEUE_SIZE: usize = 100; -const TIME_US_GUESS: f64 = 100_000.0; -const MIN_OVERHEAD_US: f64 = 100_000.0; -const MAX_OVERHEAD_FACTOR: i64 = 2; - -#[derive(Debug, Deserialize)] -pub(crate) struct ConnectServerData { - min_priority: f64, - num_cores: f64, -} - -#[derive(Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum ServerMessage { - NeedWork, - Update { - permuter: PermuterId, - time_us: f64, - update: ServerUpdate, - }, -} - -enum JobState { - Loading, - Loaded, - Failed, -} - -struct Job { - state: JobState, - energy: f64, - active_work: i64, -} - -struct ServerState { - min_priority: f64, - /// sum of active_work across all jobs - active_work: i64, - /// fractional part of how much work should be requested, in [0, 1) - more_work_acc: f64, - jobs: HashMap, -} - -async fn server_read( - port: &mut ReadPort<'_>, - who_id: &UserId, - who_name: &str, - server_state: &Mutex, - state: &State, - more_work_tx: mpsc::Sender<()>, - new_permuter: &Notify, -) -> SimpleResult<()> { - loop { - let msg = port.recv().await?; - let mut msg: ServerMessage = serde_json::from_slice(&msg)?; - if let ServerMessage::Update { - update: - ServerUpdate::Result { - ref mut compressed_source, - has_source: true, - .. - }, - .. - } = msg - { - *compressed_source = Some(port.recv().await?); - } - - let mut has_new = false; - let mut request_work; - - { - let mut m = state.m.lock().unwrap(); - let mut server_state = server_state.lock().unwrap(); - - let mut more_work: f64 = 1.0; - - if let ServerMessage::Update { - permuter: perm_id, - update, - time_us, - } = msg - { - // If we get back a message referring to a since-removed - // permuter, no need to do anything. Just request one more - // piece of work to make up for it. - if let Some(job) = server_state.jobs.get_mut(&perm_id) { - if let Some(perm) = m.permuters.get_mut(&perm_id) { - job.energy += perm.energy_add * time_us; - - match update { - ServerUpdate::InitDone { .. } => { - if !matches!(job.state, JobState::Loading) { - Err("Got InitDone while not in Loading state")?; - } - job.state = JobState::Loaded; - has_new = true; - } - ServerUpdate::InitFailed { .. } => { - if !matches!(job.state, JobState::Loading) { - Err("Got InitFailed while not in Loading state")?; - } - job.state = JobState::Failed; - } - ServerUpdate::Disconnect { .. } => { - if !matches!(job.state, JobState::Loaded) { - Err("Got Disconnect while not in Loaded state")?; - } - job.state = JobState::Failed; - let work = job.active_work; - job.active_work = 0; - server_state.active_work -= work; - more_work = 0.0; - } - ServerUpdate::Result { overhead_us, .. } => { - if !matches!(job.state, JobState::Loaded) { - Err("Got result while not in Loaded state")?; - } - // If the work item spent less than some given - // amount of time in queues, request more work. - // This ensures we saturate all server cores. - // On the other hand, if it spends too much time - // in queues, it's best if we reduce the amount - // of work. - // We don't need to adjust for time spent on the - // network, because we have backpressure on slow - // writes on both ends, and read continuously. - job.active_work -= 1; - server_state.active_work -= 1; - let min_overhead_us = (time_us + MIN_OVERHEAD_US) as i64; - if overhead_us == 0 { - // Legacy server, skip this logic. - } else if overhead_us > MAX_OVERHEAD_FACTOR * min_overhead_us { - more_work = 0.5; - } else if overhead_us < min_overhead_us { - more_work = 1.5; - } - } - } - perm.send_result(PermuterResult::Result( - who_id.clone(), - who_name.to_string(), - update, - )); - } - } - } - - more_work += server_state.more_work_acc; - request_work = more_work as i32; - server_state.more_work_acc = more_work - request_work as f64; - - if request_work == 0 - && server_state.active_work == 0 - && more_work_tx.capacity() == SERVER_WORK_QUEUE_SIZE - { - // Don't request 0 work if it would lead to total starvation. - request_work = 1; - } - } - - if has_new { - new_permuter.notify_waiters(); - state - .log_stats(stats::Record::ServerNewFunction { - server: who_id.clone(), - }) - .await?; - } - - for _ in 0..request_work { - // Try requesting more work by sending a message to the writer thread. - // If the queue is full (because the writer thread is blocked on a - // send), drop the request to avoid an unbounded backlog. - if let Err(TrySendError::Closed(_)) = more_work_tx.try_send(()) { - panic!("work chooser must not close except on error"); - } - } - } -} - -#[derive(Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum ToSend { - Work(PermuterWork), - Add { - client_id: UserId, - client_name: String, - data: Arc, - }, - Remove, -} - -#[derive(Serialize)] -struct OutMessage { - permuter: PermuterId, - #[serde(flatten)] - to_send: ToSend, -} - -fn try_next_work_message( - m: &mut MutableState, - server_state: &mut ServerState, -) -> Option { - let mut skip = HashSet::new(); - loop { - // If possible, send a new permuter. - if let Some((&perm_id, perm)) = m - .permuters - .iter() - .find(|(&perm_id, _)| !server_state.jobs.contains_key(&perm_id)) - { - server_state.jobs.insert( - perm_id, - Job { - state: JobState::Loading, - energy: 0.0, - active_work: 0, - }, - ); - return Some(OutMessage { - permuter: perm_id, - to_send: ToSend::Add { - client_id: perm.client_id.clone(), - client_name: perm.client_name.clone(), - data: perm.data.clone(), - }, - }); - } - - // If none, find an existing one to work on, or to remove. - let mut best_cost = 0.0; - let mut best: Option<(PermuterId, &mut Job)> = None; - let min_priority = server_state.min_priority; - for (&perm_id, job) in server_state.jobs.iter_mut() { - if let Some(perm) = m.permuters.get(&perm_id) { - let energy = - job.energy + (job.active_work as f64) * perm.energy_add * TIME_US_GUESS; - if matches!(job.state, JobState::Loaded) - && !skip.contains(&perm_id) - && perm.priority >= min_priority - && (best.is_none() || energy < best_cost) - { - best_cost = energy; - best = Some((perm_id, job)); - } - } else { - server_state.active_work -= job.active_work; - server_state.jobs.remove(&perm_id); - return Some(OutMessage { - permuter: perm_id, - to_send: ToSend::Remove, - }); - } - } - - let (perm_id, job) = match best { - None => return None, - Some(tup) => tup, - }; - - let perm = m.permuters.get_mut(&perm_id).unwrap(); - let work = match perm.work_queue.pop_front() { - None => { - // Chosen permuter is out of work. Ask it for more, and try - // again without it as a candidate. When the queue becomes - // non-empty again all sleeping writers will be notified. - perm.send_result(PermuterResult::NeedWork); - skip.insert(perm_id); - continue; - } - Some(work) => work, - }; - - perm.semaphore.release(); - - let min_energy = job.energy; - job.active_work += 1; - server_state.active_work += 1; - - // Adjust energies to be around zero, to avoid problems with float - // imprecision, and to ensure that new permuters that come in with - // energy zero will fit the schedule. - for job in server_state.jobs.values_mut() { - job.energy -= min_energy; - } - - return Some(OutMessage { - permuter: perm_id, - to_send: ToSend::Work(work), - }); - } -} - -async fn next_work_message( - server_state: &Mutex, - state: &State, - new_permuter: &Notify, -) -> OutMessage { - let mut wait_for = None; - loop { - if let Some(waiter) = wait_for { - waiter.await; - } - let mut m = state.m.lock().unwrap(); - let mut server_state = server_state.lock().unwrap(); - match try_next_work_message(&mut m, &mut server_state) { - Some(message) => return message, - None => { - // Nothing to work on! Register to be notified when something - // happens (while the lock is still held) and go to sleep. - let n1 = state.new_work_notification.notified(); - let n2 = new_permuter.notified(); - wait_for = Some(async move { - tokio::select! { - () = n1 => {} - () = n2 => {} - } - }); - } - } - } -} - -fn requires_response(work: &OutMessage) -> bool { - match work.to_send { - ToSend::Work { .. } => true, - ToSend::Add { .. } => true, - ToSend::Remove => false, - } -} - -async fn server_choose_work( - server_state: &Mutex, - state: &State, - mut more_work_rx: mpsc::Receiver<()>, - next_message_tx: mpsc::Sender, - wrote_message: &Notify, - new_permuter: &Notify, -) -> SimpleResult<()> { - loop { - let message = next_work_message(server_state, state, new_permuter).await; - let requires_response = requires_response(&message); - next_message_tx - .send(message) - .await - .map_err(|_| ()) - .expect("writer must not close except on error"); - wrote_message.notified().await; - if requires_response { - more_work_rx - .recv() - .await - .expect("reader must not close except on error"); - } - } -} - -async fn send_heartbeat(port: &mut WritePort<'_>) -> SimpleResult<()> { - port.send_json(&json!({ - "type": "heartbeat", - })) - .await -} - -async fn send_work(port: &mut WritePort<'_>, work: &OutMessage) -> SimpleResult<()> { - port.send_json(&work).await?; - if let ToSend::Add { ref data, .. } = work.to_send { - port.send(&data.compressed_source).await?; - port.send(&data.compressed_target_o_bin).await?; - } - Ok(()) -} - -async fn server_write( - port: &mut WritePort<'_>, - mut next_message_rx: mpsc::Receiver, - mut heartbeat_rx: watch::Receiver<()>, - wrote_message: &Notify, -) -> SimpleResult<()> { - loop { - tokio::select! { - work = next_message_rx.recv() => { - let work = work.expect("chooser must not close except on error"); - send_work(port, &work).await?; - wrote_message.notify_one(); - } - res = heartbeat_rx.changed() => { - res.expect("heartbeat thread panicked"); - send_heartbeat(port).await?; - } - } - } -} - -pub(crate) async fn handle_connect_server<'a>( - mut read_port: ReadPort<'a>, - mut write_port: WritePort<'a>, - who_id: UserId, - who_name: &str, - permuter_version: u32, - state: &State, - data: ConnectServerData, -) -> SimpleResult<()> { - if permuter_version < MIN_PERMUTER_VERSION { - Err("Permuter version too old!")?; - } - - eprintln!( - "[{}] start server ({}, {})", - who_name, data.min_priority, data.num_cores - ); - - write_port - .send_json(&json!({ - "docker_image": &state.docker_image, - "heartbeat_interval": HEARTBEAT_TIME.as_secs(), - })) - .await?; - - let (more_work_tx, more_work_rx) = mpsc::channel(SERVER_WORK_QUEUE_SIZE); - let (next_message_tx, next_message_rx) = mpsc::channel(1); - let wrote_message = Notify::new(); - let new_permuter = Notify::new(); - - let mut server_state = Mutex::new(ServerState { - min_priority: data.min_priority, - active_work: 0, - more_work_acc: 0.0, - jobs: HashMap::new(), - }); - - let id = state.m.lock().unwrap().servers.insert(ConnectedServer { - min_priority: data.min_priority, - num_cores: data.num_cores, - }); - - let r = tokio::try_join!( - server_read( - &mut read_port, - &who_id, - who_name, - &server_state, - state, - more_work_tx, - &new_permuter, - ), - server_choose_work( - &server_state, - state, - more_work_rx, - next_message_tx, - &wrote_message, - &new_permuter, - ), - server_write( - &mut write_port, - next_message_rx, - state.heartbeat_rx.clone(), - &wrote_message, - ) - ); - - { - let mut m = state.m.lock().unwrap(); - for (&perm_id, job) in &server_state.get_mut().unwrap().jobs { - if let JobState::Loaded = job.state { - if let Some(perm) = m.permuters.get_mut(&perm_id) { - perm.send_result(PermuterResult::Result( - who_id.clone(), - who_name.to_string(), - ServerUpdate::Disconnect, - )); - } - } - } - - m.servers.remove(id); - } - r?; - Ok(()) -} diff --git a/tools/decomp-permuter/src/net/controller/src/setup.rs b/tools/decomp-permuter/src/net/controller/src/setup.rs deleted file mode 100644 index 07fec490b3..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/setup.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::collections::HashMap; -use std::default::Default; -use std::fs::OpenOptions; - -use sodiumoxide::crypto::sign; -use sodiumoxide::randombytes::randombytes; - -use crate::db::{User, UserId, DB}; -use crate::util::SimpleResult; -use crate::SetupOpts; - -pub(crate) fn run_setup(opts: SetupOpts) -> SimpleResult<()> { - let db_file = OpenOptions::new() - .write(true) - .create_new(true) - .open(&opts.db) - .unwrap_or_else(|e| { - eprintln!("Cannot create database file {}: {}. Aborting.", &opts.db, e); - std::process::exit(1); - }); - - let server_seed = sign::Seed::from_slice(&randombytes(32)).unwrap(); - let client_seed = sign::Seed::from_slice(&randombytes(32)).unwrap(); - - let (server_pub_key, _) = sign::keypair_from_seed(&server_seed); - let (client_pub_key, _) = sign::keypair_from_seed(&client_seed); - - let root_user = User { - trusted_by: None, - name: "root".into(), - client_stats: Default::default(), - server_stats: Default::default(), - }; - let mut users_map: HashMap = HashMap::new(); - users_map.insert(UserId::from_pubkey(&client_pub_key), root_user); - let db = DB { - users: users_map, - func_stats: HashMap::new(), - total_stats: Default::default(), - }; - - serde_json::to_writer(&db_file, &db)?; - - println!( - "Setup successful!\n\n\ - Put the following in the server's config.toml:\n\n\ - priv_seed = \"{}\"\n\n\ - Put the following in the root client's pah.conf:\n\n\ - secret_key = \"{}\"\n\ - server_public_key = \"{}\"\n\ - server_address = \"server.example:port\"", - hex::encode(server_seed), - hex::encode(client_seed), - hex::encode(server_pub_key) - ); - Ok(()) -} diff --git a/tools/decomp-permuter/src/net/controller/src/stats.rs b/tools/decomp-permuter/src/net/controller/src/stats.rs deleted file mode 100644 index 079e4582fd..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/stats.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::future::Future; - -use tokio::sync::mpsc; - -use crate::db::{Stats, UserId}; -use crate::save::SaveableDB; - -const CHANNEL_CAPACITY: usize = 10000; - -#[derive(Clone, Copy)] -pub enum Outcome { - Matched, - Improved, - Unhelpful, -} - -pub enum Record { - WorkDone { - server: UserId, - client: UserId, - fn_name: String, - outcome: Outcome, - }, - ClientNewFunction { - client: UserId, - fn_name: String, - }, - ServerNewFunction { - server: UserId, - }, -} - -fn add_stats(stats: &mut Stats, outcome: Outcome) { - if matches!(outcome, Outcome::Matched) { - stats.matches += 1; - } - if matches!(outcome, Outcome::Matched | Outcome::Improved) { - stats.improvements += 1; - } - stats.iterations += 1; -} - -async fn stats_writer(db: &SaveableDB, mut rx: mpsc::Receiver) { - loop { - let record = rx.recv().await.unwrap(); - db.write(false, |db| { - match record { - Record::WorkDone { - server, - client, - fn_name, - outcome, - } => { - add_stats(&mut db.total_stats, outcome); - add_stats(db.func_stat(fn_name), outcome); - if let Some(user) = db.users.get_mut(&client) { - add_stats(&mut user.client_stats, outcome); - } - if let Some(user) = db.users.get_mut(&server) { - add_stats(&mut user.server_stats, outcome); - } - } - Record::ClientNewFunction { client, fn_name } => { - db.func_stat(fn_name).functions += 1; - if let Some(user) = db.users.get_mut(&client) { - user.client_stats.functions += 1; - } - db.total_stats.functions += 1; - } - Record::ServerNewFunction { server } => { - if let Some(user) = db.users.get_mut(&server) { - user.server_stats.functions += 1; - } - } - }; - }) - .await; - } -} - -pub fn stats_thread(db: &SaveableDB) -> (impl Future, mpsc::Sender) { - let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY); - let db = db.clone(); - let fut = async move { - stats_writer(&db, rx).await; - }; - (fut, tx) -} diff --git a/tools/decomp-permuter/src/net/controller/src/util.rs b/tools/decomp-permuter/src/net/controller/src/util.rs deleted file mode 100644 index 0194f29162..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/util.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::error::Error; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use pin_project::pin_project; - -pub type SimpleResult = Result>; - -#[pin_project] -pub struct NowOrNever { - #[pin] - inner: F, -} - -impl Future for NowOrNever { - type Output = Option; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let ret = self.project().inner.poll(cx); - Poll::Ready(match ret { - Poll::Pending => None, - Poll::Ready(val) => Some(val), - }) - } -} - -impl FutureExt for T where T: Future {} - -pub trait FutureExt: Future { - fn now_or_never(self) -> NowOrNever - where - Self: Sized, - { - NowOrNever { inner: self } - } -} diff --git a/tools/decomp-permuter/src/net/controller/src/vouch.rs b/tools/decomp-permuter/src/net/controller/src/vouch.rs deleted file mode 100644 index 19bd3580d4..0000000000 --- a/tools/decomp-permuter/src/net/controller/src/vouch.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::str; - -use hex::FromHex; -use serde::Deserialize; -use serde_json::json; -use sodiumoxide::crypto::sign; - -use crate::db::{User, UserId}; -use crate::port::{ReadPort, WritePort}; -use crate::util::SimpleResult; -use crate::{concat, State}; - -#[derive(Debug, Deserialize)] -pub(crate) struct VouchData { - who: UserId, - signed_name: String, -} - -fn verify_with_magic<'a>( - magic: &[u8], - data: &'a [u8], - key: &sign::PublicKey, -) -> SimpleResult<&'a [u8]> { - if data.len() < 64 { - Err("signature too short")?; - } - let (signature, data) = data.split_at(64); - let signed_data = concat(magic, data); - let signature = sign::Signature::from_slice(signature).unwrap(); - if !sign::verify_detached(&signature, &signed_data, key) { - Err("bad signature")?; - } - Ok(data) -} - -fn parse_signed_name(signed_name: &str, who: &UserId) -> SimpleResult { - let signed_name = Vec::from_hex(signed_name).map_err(|_| "not a valid hex string")?; - let name_bytes = verify_with_magic(b"NAME:", &signed_name, &who.to_pubkey())?; - let name = str::from_utf8(name_bytes)?; - if name.is_empty() { - Err("name is empty")?; - } - if name.chars().any(char::is_control) { - Err("name cannot contain control characters")?; - } - Ok(name.to_string()) -} - -pub(crate) async fn handle_vouch<'a>( - mut read_port: ReadPort<'a>, - mut write_port: WritePort<'a>, - who_id: UserId, - who_name: &str, - state: &State, - data: VouchData, -) -> SimpleResult<()> { - let vouchee_name = match parse_signed_name(&data.signed_name, &data.who) { - Ok(name) => name, - Err(e) => { - write_port.send_error(&format!("{}", &e)).await?; - Err(e)? - } - }; - write_port.send_json(&json!({})).await?; - read_port.recv().await?; - state - .db - .write(true, |db| { - db.users.entry(data.who).or_insert_with(|| User { - trusted_by: Some(who_id), - name: vouchee_name.clone(), - client_stats: Default::default(), - server_stats: Default::default(), - }); - }) - .await; - write_port.send_json(&json!({})).await?; - eprintln!("[{}] vouch {}", who_name, &vouchee_name); - Ok(()) -} diff --git a/tools/decomp-permuter/src/net/core.py b/tools/decomp-permuter/src/net/core.py deleted file mode 100644 index c7f497d430..0000000000 --- a/tools/decomp-permuter/src/net/core.py +++ /dev/null @@ -1,416 +0,0 @@ -import abc -from dataclasses import dataclass -import datetime -import json -import socket -import struct -import sys -import toml -import typing -from typing import BinaryIO, List, Optional, Type, TypeVar, Union - -from nacl.encoding import HexEncoder -from nacl.public import Box, PrivateKey, PublicKey -from nacl.secret import SecretBox -from nacl.signing import SigningKey, VerifyKey - -from ..error import ServerError -from ..helpers import exception_to_string - -T = TypeVar("T") -AnyBox = Union[Box, SecretBox] - -PERMUTER_VERSION = 2 - -CONFIG_FILENAME = "pah.conf" - -MIN_PRIO = 0.01 -MAX_PRIO = 2.0 - -DEBUG_MODE = False - - -def enable_debug_mode() -> None: - """Enable debug logging.""" - global DEBUG_MODE - DEBUG_MODE = True - - -def debug_print(message: str) -> None: - if DEBUG_MODE: - time = datetime.datetime.now().strftime("%H:%M:%S:%f") - print(f"\n{time} debug: {message}") - - -@dataclass(eq=False) -class CancelToken: - cancelled: bool = False - - -@dataclass -class PermuterData: - base_score: int - base_hash: str - fn_name: str - filename: str - keep_prob: float - need_profiler: bool - stack_differences: bool - compile_script: str - source: str - target_o_bin: bytes - - -def permuter_data_from_json( - obj: dict, source: str, target_o_bin: bytes -) -> PermuterData: - return PermuterData( - base_score=json_prop(obj, "base_score", int), - base_hash=json_prop(obj, "base_hash", str), - fn_name=json_prop(obj, "fn_name", str), - filename=json_prop(obj, "filename", str), - keep_prob=json_prop(obj, "keep_prob", float), - need_profiler=json_prop(obj, "need_profiler", bool), - stack_differences=json_prop(obj, "stack_differences", bool), - compile_script=json_prop(obj, "compile_script", str), - source=source, - target_o_bin=target_o_bin, - ) - - -def permuter_data_to_json(perm: PermuterData) -> dict: - return { - "base_score": perm.base_score, - "base_hash": perm.base_hash, - "fn_name": perm.fn_name, - "filename": perm.filename, - "keep_prob": perm.keep_prob, - "need_profiler": perm.need_profiler, - "stack_differences": perm.stack_differences, - "compile_script": perm.compile_script, - } - - -@dataclass -class Config: - server_address: Optional[str] = None - server_verify_key: Optional[VerifyKey] = None - signing_key: Optional[SigningKey] = None - initial_setup_nickname: Optional[str] = None - - -def read_config() -> Config: - config = Config() - try: - with open(CONFIG_FILENAME) as f: - obj = toml.load(f) - - def read(key: str, t: Type[T]) -> Optional[T]: - ret = obj.get(key) - return ret if isinstance(ret, t) else None - - temp = read("server_public_key", str) - if temp: - config.server_verify_key = VerifyKey(HexEncoder.decode(temp)) - temp = read("secret_key", str) - if temp: - config.signing_key = SigningKey(HexEncoder.decode(temp)) - config.initial_setup_nickname = read("initial_setup_nickname", str) - config.server_address = read("server_address", str) - except FileNotFoundError: - pass - except Exception: - print(f"Malformed configuration file {CONFIG_FILENAME}.\n") - raise - return config - - -def write_config(config: Config) -> None: - obj = {} - - def write(key: str, val: Union[None, str, int]) -> None: - if val is not None: - obj[key] = val - - write("initial_setup_nickname", config.initial_setup_nickname) - write("server_address", config.server_address) - - key_hex: bytes - if config.server_verify_key: - key_hex = config.server_verify_key.encode(HexEncoder) - write("server_public_key", key_hex.decode("utf-8")) - if config.signing_key: - key_hex = config.signing_key.encode(HexEncoder) - write("secret_key", key_hex.decode("utf-8")) - - with open(CONFIG_FILENAME, "w") as f: - toml.dump(obj, f) - - -def file_read_max(inf: BinaryIO, n: int) -> bytes: - try: - ret = [] - while n > 0: - data = inf.read(n) - if not data: - break - ret.append(data) - n -= len(data) - return b"".join(ret) - except Exception as e: - raise EOFError from e - - -def file_read_fixed(inf: BinaryIO, n: int) -> bytes: - ret = file_read_max(inf, n) - if len(ret) != n: - raise EOFError - return ret - - -def socket_read_max(sock: socket.socket, n: int) -> bytes: - try: - ret = [] - while n > 0: - data = sock.recv(min(n, 4096)) - if not data: - break - ret.append(data) - n -= len(data) - return b"".join(ret) - except Exception as e: - raise EOFError from e - - -def socket_read_fixed(sock: socket.socket, n: int) -> bytes: - ret = socket_read_max(sock, n) - if len(ret) != n: - raise EOFError - return ret - - -def socket_shutdown(sock: socket.socket, how: int) -> None: - try: - sock.shutdown(how) - except Exception: - pass - - -def json_prop(obj: dict, prop: str, t: Type[T]) -> T: - ret = obj.get(prop) - if not isinstance(ret, t): - if t is float and isinstance(ret, int): - return typing.cast(T, float(ret)) - found_type = type(ret).__name__ - if prop not in obj: - raise ValueError(f"Member {prop} does not exist") - raise ValueError(f"Member {prop} must have type {t.__name__}; got {found_type}") - return ret - - -def json_array(obj: list, t: Type[T]) -> List[T]: - for elem in obj: - if not isinstance(elem, t): - found_type = type(elem).__name__ - raise ValueError( - f"Array elements must have type {t.__name__}; got {found_type}" - ) - return obj - - -def sign_with_magic(magic: bytes, signing_key: SigningKey, data: bytes) -> bytes: - signature: bytes = signing_key.sign(magic + b":" + data).signature - return signature + data - - -def verify_with_magic(magic: bytes, verify_key: VerifyKey, data: bytes) -> bytes: - if len(data) < 64: - raise ValueError("String is too small to contain a signature") - signature = data[:64] - data = data[64:] - verify_key.verify(magic + b":" + data, signature) - return data - - -class Port(abc.ABC): - def __init__(self, box: AnyBox, who: str, *, is_client: bool) -> None: - self._box = box - self._who = who - self._send_nonce = 0 if is_client else 1 - self._receive_nonce = 1 if is_client else 0 - - @abc.abstractmethod - def _send(self, data: bytes) -> None: - ... - - @abc.abstractmethod - def _receive(self, length: int) -> bytes: - ... - - @abc.abstractmethod - def _receive_max(self, length: int) -> bytes: - ... - - def send(self, msg: bytes) -> None: - """Send a binary message, potentially blocking.""" - if DEBUG_MODE: - if len(msg) <= 300: - debug_print(f"Send to {self._who}: {msg!r}") - else: - debug_print(f"Send to {self._who}: {len(msg)} bytes") - nonce = struct.pack(">16xQ", self._send_nonce) - self._send_nonce += 2 - data = self._box.encrypt(msg, nonce).ciphertext - length_data = struct.pack(">Q", len(data)) - try: - self._send(length_data + data) - except BrokenPipeError: - raise EOFError from None - - def send_json(self, msg: dict) -> None: - """Send a message in the form of a JSON dict, potentially blocking.""" - self.send(json.dumps(msg).encode("utf-8")) - - def receive(self) -> bytes: - """Read a binary message, blocking.""" - length_data = self._receive(8) - if length_data[0]: - # Lengths above 2^56 are unreasonable, so if we get one someone is - # sending us bad data. Raise an exception to help debugging. - length_data += self._receive_max(1024) - raise Exception( - f"Got unexpected data from {self._who}: " + repr(length_data) - ) - length = struct.unpack(">Q", length_data)[0] - data = self._receive(length) - nonce = struct.pack(">16xQ", self._receive_nonce) - self._receive_nonce += 2 - msg: bytes = self._box.decrypt(data, nonce) - if DEBUG_MODE: - if len(msg) <= 300: - debug_print(f"Receive from {self._who}: {msg!r}") - else: - debug_print(f"Receive from {self._who}: {len(msg)} bytes") - return msg - - def receive_json(self) -> dict: - """Read a message in the form of a JSON dict, blocking.""" - ret = json.loads(self.receive()) - if isinstance(ret, str): - # Raw strings indicate errors. - raise ServerError(ret) - if not isinstance(ret, dict): - # We always pass dictionaries as messages and no other data types, - # to ensure future extensibility. (Other types are rare in - # practice, anyway.) - raise ValueError("Top-level JSON value must be a dictionary") - return ret - - -class SocketPort(Port): - def __init__( - self, sock: socket.socket, box: AnyBox, who: str, *, is_client: bool - ) -> None: - self._sock = sock - super().__init__(box, who, is_client=is_client) - - def _send(self, data: bytes) -> None: - self._sock.sendall(data) - - def _receive(self, length: int) -> bytes: - return socket_read_fixed(self._sock, length) - - def _receive_max(self, length: int) -> bytes: - return socket_read_max(self._sock, length) - - def shutdown(self, how: int = socket.SHUT_RDWR) -> None: - socket_shutdown(self._sock, how) - - def close(self) -> None: - self._sock.close() - - -class FilePort(Port): - def __init__( - self, inf: BinaryIO, outf: BinaryIO, box: AnyBox, who: str, *, is_client: bool - ) -> None: - self._inf = inf - self._outf = outf - super().__init__(box, who, is_client=is_client) - - def _send(self, data: bytes) -> None: - self._outf.write(data) - self._outf.flush() - - def _receive(self, length: int) -> bytes: - return file_read_fixed(self._inf, length) - - def _receive_max(self, length: int) -> bytes: - return file_read_max(self._inf, length) - - -def _do_connect(config: Config) -> SocketPort: - if ( - not config.server_verify_key - or not config.signing_key - or not config.server_address - ): - print( - "Using permuter@home requires someone to give you access to a central -J server.\n" - "Run `./pah.py setup` to set this up." - ) - print() - sys.exit(1) - - host, port_str = config.server_address.split(":") - try: - sock = socket.create_connection((host, int(port_str))) - except ConnectionRefusedError: - raise EOFError("connection refused") from None - except socket.gaierror as e: - raise EOFError(f"DNS lookup failed: {e}") from None - except Exception as e: - raise EOFError("unable to connect: " + exception_to_string(e)) from None - - # Send over the protocol version and an ephemeral encryption key which we - # are going to use for all communication. - ephemeral_key = PrivateKey.generate() - ephemeral_key_data = ephemeral_key.public_key.encode() - sock.sendall(b"p@h0" + ephemeral_key_data) - - # Receive the server's encryption key, plus a signature of it and our own - # ephemeral key -- this guarantees that we are talking to the server and - # aren't victim to a replay attack. Use it to set up a communication port. - msg = socket_read_fixed(sock, 32 + 64) - server_enc_key_data = msg[:32] - config.server_verify_key.verify( - b"HELLO:" + ephemeral_key_data + server_enc_key_data, msg[32:] - ) - box = Box(ephemeral_key, PublicKey(server_enc_key_data)) - port = SocketPort(sock, box, "controller", is_client=True) - - # Use the encrypted port to send over our public key, proof that we are - # able to sign new things with it, as well as permuter version. - signature: bytes = config.signing_key.sign( - b"WORLD:" + server_enc_key_data - ).signature - port.send( - config.signing_key.verify_key.encode() - + signature - + struct.pack(">I", PERMUTER_VERSION) - ) - - # Get an acknowledgement that the server wants to talk to us. - obj = port.receive_json() - - if "message" in obj: - print(obj["message"]) - - return port - - -def connect(config: Optional[Config] = None) -> SocketPort: - """Authenticate and connect to the permuter@home controller server.""" - if not config: - config = read_config() - return _do_connect(config) diff --git a/tools/decomp-permuter/src/net/evaluator.py b/tools/decomp-permuter/src/net/evaluator.py deleted file mode 100644 index e9138a6c74..0000000000 --- a/tools/decomp-permuter/src/net/evaluator.py +++ /dev/null @@ -1,401 +0,0 @@ -"""This file runs as a free-standing program within a sandbox, and processes -permutation requests. It communicates with the outside world on stdin/stdout.""" -import base64 -from dataclasses import dataclass -import math -from multiprocessing import Process, Queue -import os -import queue -import sys -from tempfile import mkstemp -import threading -import time -import traceback -from typing import Counter, Dict, List, Optional, Set, Tuple, Union -import zlib - -from nacl.secret import SecretBox - -from ..candidate import CandidateResult -from ..compiler import Compiler -from ..error import CandidateConstructionFailure -from ..helpers import exception_to_string, static_assert_unreachable -from ..permuter import EvalError, EvalResult, Permuter -from ..profiler import Profiler -from ..scorer import Scorer -from .core import ( - FilePort, - PermuterData, - Port, - json_prop, - permuter_data_from_json, -) - - -def _fix_stdout() -> None: - """Redirect stdout to stderr to make print() debugging work. This function - *must* be called at startup for each (sub)process, since we use stdout for - our own communication purposes.""" - sys.stdout = sys.stderr - - # In addition, we set stderr to flush on newlines, which does not happen by - # default when it is piped. (Requires Python 3.7, but we can assume that's - # available inside the sandbox.) - sys.stdout.reconfigure(line_buffering=True) # type: ignore - - -def _setup_port(secret: bytes) -> Port: - """Set up communication with the outside world.""" - port = FilePort( - sys.stdin.buffer, - sys.stdout.buffer, - SecretBox(secret), - "server", - is_client=False, - ) - - _fix_stdout() - - # Follow the controlling process's sanity check protocol. - magic = port.receive() - port.send(magic) - - return port - - -def _create_permuter(data: PermuterData) -> Permuter: - fd, path = mkstemp(suffix=".o", prefix="permuter", text=False) - try: - with os.fdopen(fd, "wb") as f: - f.write(data.target_o_bin) - scorer = Scorer(target_o=path, stack_differences=data.stack_differences) - finally: - os.unlink(path) - - fd, path = mkstemp(suffix=".sh", prefix="permuter", text=True) - try: - os.chmod(fd, 0o755) - with os.fdopen(fd, "w") as f2: - f2.write(data.compile_script) - compiler = Compiler(compile_cmd=path, show_errors=False) - - return Permuter( - dir="unused", - fn_name=data.fn_name, - compiler=compiler, - scorer=scorer, - source_file=data.filename, - source=data.source, - force_seed=None, - force_rng_seed=None, - keep_prob=data.keep_prob, - need_profiler=data.need_profiler, - need_all_sources=False, - show_errors=False, - better_only=False, - best_only=False, - ) - except: - os.unlink(path) - raise - - -@dataclass -class AddPermuter: - perm_id: str - data: PermuterData - - -@dataclass -class AddPermuterLocal: - perm_id: str - permuter: Permuter - - -@dataclass -class RemovePermuter: - perm_id: str - - -@dataclass -class WorkDone: - perm_id: str - id: int - time_us: int - result: EvalResult - - -@dataclass -class Work: - perm_id: str - id: int - seed: int - - -LocalWork = Tuple[Union[AddPermuterLocal, RemovePermuter], int] -GlobalWork = Tuple[Work, int] -Task = Union[AddPermuter, RemovePermuter, Work, WorkDone] - - -def _remove_permuter(perm: Permuter) -> None: - os.unlink(perm.compiler.compile_cmd) - - -def _send_result(item: WorkDone, port: Port) -> None: - obj = { - "type": "result", - "permuter": item.perm_id, - "id": item.id, - "time_us": item.time_us, - } - - res = item.result - if isinstance(res, EvalError): - obj["error"] = res.exc_str - port.send_json(obj) - return - - compressed_source = getattr(res, "compressed_source") - - obj["score"] = res.score - obj["has_source"] = compressed_source is not None - if res.hash is not None: - obj["hash"] = res.hash - if res.profiler is not None: - obj["profiler"] = { - st.name: res.profiler.time_stats[st] for st in Profiler.StatType - } - - port.send_json(obj) - - if compressed_source is not None: - port.send(compressed_source) - - -def multiprocess_worker( - worker_queue: "Queue[GlobalWork]", - local_queue: "Queue[LocalWork]", - task_queue: "Queue[Task]", -) -> None: - _fix_stdout() - - # Prevent deadlocks in case the parent process dies. - worker_queue.cancel_join_thread() - local_queue.cancel_join_thread() - task_queue.cancel_join_thread() - - permuters: Dict[str, Permuter] = {} - timestamp = 0 - - while True: - work, required_timestamp = worker_queue.get() - while True: - try: - block = timestamp < required_timestamp - task, timestamp = local_queue.get(block=block) - except queue.Empty: - break - if isinstance(task, AddPermuterLocal): - permuters[task.perm_id] = task.permuter - elif isinstance(task, RemovePermuter): - del permuters[task.perm_id] - else: - static_assert_unreachable(task) - - time_before = time.time() - - permuter = permuters[work.perm_id] - result = permuter.try_eval_candidate(work.seed) - if isinstance(result, CandidateResult) and permuter.should_output(result): - permuter.record_result(result) - - # Compress the source within the worker. (Why waste a free - # multi-threading opportunity?) - if isinstance(result, CandidateResult): - compressed_source: Optional[bytes] = None - if result.source is not None: - compressed_source = zlib.compress(result.source.encode("utf-8")) - setattr(result, "compressed_source", compressed_source) - result.source = None - - time_us = int((time.time() - time_before) * 10 ** 6) - task_queue.put( - WorkDone(perm_id=work.perm_id, id=work.id, time_us=time_us, result=result) - ) - - -def read_loop(task_queue: "Queue[Task]", port: Port) -> None: - try: - while True: - item = port.receive_json() - msg_type = json_prop(item, "type", str) - if msg_type == "add": - perm_id = json_prop(item, "permuter", str) - source = port.receive().decode("utf-8") - target_o_bin = port.receive() - data = permuter_data_from_json(item, source, target_o_bin) - task_queue.put(AddPermuter(perm_id=perm_id, data=data)) - - elif msg_type == "remove": - perm_id = json_prop(item, "permuter", str) - task_queue.put(RemovePermuter(perm_id=perm_id)) - - elif msg_type == "work": - perm_id = json_prop(item, "permuter", str) - id = json_prop(item, "id", int) - seed = json_prop(item, "seed", int) - task_queue.put(Work(perm_id=perm_id, id=id, seed=seed)) - - else: - raise Exception(f"Invalid message type {msg_type}") - - except Exception as e: - # In case the port is closed from the other side, skip writing an ugly - # error message. - if not isinstance(e, EOFError): - traceback.print_exc() - - # Exit the whole process, to improve the odds that the Docker container - # really stops and gets removed. - # - # The parent server has a "finally:" that does that, but it's not 100% - # trustworthy. In particular, pystray has a tendency to hard-crash - # (which doesn't fire "finally"s), and also reverts the signal handler - # for SIGINT to the default on Linux, making Ctrl+C not run cleanup. - # Either way, defense in depth here doesn't hurt, since leaking Docker - # containers is pretty bad. - # - # Unfortunately this still doesn't fix the problem, since we typically - # don't get a port closure signal when the parent process stops... - # TODO: listen to heartbeats as well. - sys.exit(1) - - -def main() -> None: - secret = base64.b64decode(os.environ["SECRET"]) - del os.environ["SECRET"] - os.environ["PERMUTER_IS_REMOTE"] = "1" - - port = _setup_port(secret) - - obj = port.receive_json() - num_cores = json_prop(obj, "num_cores", float) - num_threads = math.ceil(num_cores) - - worker_queue: "Queue[GlobalWork]" = Queue() - task_queue: "Queue[Task]" = Queue() - local_queues: "List[Queue[LocalWork]]" = [] - - for i in range(num_threads): - local_queue: "Queue[LocalWork]" = Queue() - p = Process( - target=multiprocess_worker, - args=(worker_queue, local_queue, task_queue), - daemon=True, - ) - p.start() - local_queues.append(local_queue) - - reader_thread = threading.Thread( - target=read_loop, args=(task_queue, port), daemon=True - ) - reader_thread.start() - - remaining_work: Counter[str] = Counter() - should_remove: Set[str] = set() - permuters: Dict[str, Permuter] = {} - timestamp = 0 - - def try_remove(perm_id: str) -> None: - nonlocal timestamp - assert perm_id in permuters - if perm_id not in should_remove or remaining_work[perm_id] != 0: - return - del remaining_work[perm_id] - should_remove.remove(perm_id) - timestamp += 1 - for q in local_queues: - q.put((RemovePermuter(perm_id=perm_id), timestamp)) - _remove_permuter(permuters[perm_id]) - del permuters[perm_id] - - while True: - item = task_queue.get() - - if isinstance(item, AddPermuter): - assert item.perm_id not in permuters - - msg: Dict[str, object] = { - "type": "init", - "permuter": item.perm_id, - } - - time_before = time.time() - try: - # Construct a permuter. This involves a compilation on the main - # thread, which isn't great but we can live with it for now. - permuter = _create_permuter(item.data) - - if permuter.base_score != item.data.base_score: - _remove_permuter(permuter) - score_str = f"{permuter.base_score} vs {item.data.base_score}" - if permuter.base_hash == item.data.base_hash: - hash_str = "same hash; different Python or permuter versions?" - else: - hash_str = "different hash; different objdump versions?" - raise CandidateConstructionFailure( - f"mismatching score: {score_str} ({hash_str})" - ) - - permuters[item.perm_id] = permuter - - msg["success"] = True - msg["base_score"] = permuter.base_score - msg["base_hash"] = permuter.base_hash - - # Tell all the workers about the new permuter. - # TODO: ideally we would also seed their Candidate lru_cache's - # to avoid all workers having to parse the source... - timestamp += 1 - for q in local_queues: - q.put( - ( - AddPermuterLocal(perm_id=item.perm_id, permuter=permuter), - timestamp, - ) - ) - except Exception as e: - # This shouldn't practically happen, since the client compiled - # the code successfully. Print a message if it does. - msg["success"] = False - msg["error"] = exception_to_string(e) - if isinstance(e, CandidateConstructionFailure): - print(e.message) - else: - traceback.print_exc() - - msg["time_us"] = int((time.time() - time_before) * 10 ** 6) - port.send_json(msg) - - elif isinstance(item, RemovePermuter): - # Silently ignore requests to remove permuters that have already - # been removed, which can occur when AddPermuter fails. - if item.perm_id in permuters: - should_remove.add(item.perm_id) - try_remove(item.perm_id) - - elif isinstance(item, WorkDone): - remaining_work[item.perm_id] -= 1 - try_remove(item.perm_id) - _send_result(item, port) - - elif isinstance(item, Work): - remaining_work[item.perm_id] += 1 - worker_queue.put((item, timestamp)) - - else: - static_assert_unreachable(item) - - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/src/net/server.py b/tools/decomp-permuter/src/net/server.py deleted file mode 100644 index 65d0c0bea3..0000000000 --- a/tools/decomp-permuter/src/net/server.py +++ /dev/null @@ -1,944 +0,0 @@ -import base64 -from dataclasses import dataclass -import pathlib -import queue -import struct -import sys -import threading -import time -import traceback -from typing import BinaryIO, Dict, Optional, Set, Tuple, Union, TYPE_CHECKING -import zlib - -if TYPE_CHECKING: - import docker - -from nacl.secret import SecretBox -import nacl.utils - -from ..helpers import exception_to_string, static_assert_unreachable -from .core import ( - CancelToken, - Config, - PermuterData, - Port, - ServerError, - SocketPort, - connect, - file_read_fixed, - json_prop, - permuter_data_from_json, - permuter_data_to_json, -) - - -_HEARTBEAT_INTERVAL_SLACK_SEC: float = 50.0 - - -@dataclass -class Client: - id: str - nickname: str - - -@dataclass -class AddPermuter: - handle: int - time_start: float - client: Client - permuter_data: PermuterData - - -@dataclass -class RemovePermuter: - handle: int - - -@dataclass -class Work: - handle: int - id: int - time_start: float - seed: int - - -@dataclass -class ImmediateDisconnect: - handle: int - client: Client - reason: str - - -@dataclass -class Disconnect: - handle: int - - -@dataclass -class PermInitFail: - perm_id: str - error: str - - -@dataclass -class PermInitSuccess: - perm_id: str - base_score: int - base_hash: str - time_us: int - - -@dataclass -class WorkDone: - perm_id: str - id: int - obj: dict - time_us: int - compressed_source: Optional[bytes] - - -class NeedMoreWork: - pass - - -@dataclass -class NetThreadDisconnected: - graceful: bool - message: Optional[str] = None - - -class Heartbeat: - pass - - -class Shutdown: - pass - - -Activity = Union[ - AddPermuter, - RemovePermuter, - Work, - ImmediateDisconnect, - Disconnect, - PermInitFail, - PermInitSuccess, - WorkDone, - NeedMoreWork, - NetThreadDisconnected, - Heartbeat, - Shutdown, -] - - -@dataclass -class OutputInitFail: - handle: int - error: str - - -@dataclass -class OutputInitSuccess: - handle: int - time_us: int - base_score: int - base_hash: str - - -@dataclass -class OutputDisconnect: - handle: int - - -@dataclass -class OutputNeedMoreWork: - pass - - -@dataclass -class OutputWork: - handle: int - time_start: float - time_us: int - obj: dict - compressed_source: Optional[bytes] - - -Output = Union[ - OutputDisconnect, - OutputInitFail, - OutputInitSuccess, - OutputNeedMoreWork, - OutputWork, - Shutdown, -] - - -@dataclass -class IoConnect: - fn_name: str - client: Client - - -@dataclass -class IoDisconnect: - reason: str - - -@dataclass -class IoImmediateDisconnect: - reason: str - client: Client - - -class IoUserRemovePermuter: - pass - - -@dataclass -class IoServerFailed: - graceful: bool - message: Optional[str] - - -class IoReconnect: - pass - - -class IoShutdown: - pass - - -@dataclass -class IoWorkDone: - score: Optional[int] - is_improvement: bool - - -PermuterHandle = Tuple[int, CancelToken] -IoMessage = Union[ - IoConnect, IoDisconnect, IoImmediateDisconnect, IoUserRemovePermuter, IoWorkDone -] -IoGlobalMessage = Union[IoReconnect, IoShutdown, IoServerFailed] -IoActivity = Tuple[ - Optional[CancelToken], Union[Tuple[PermuterHandle, IoMessage], IoGlobalMessage] -] - - -@dataclass -class ServerOptions: - num_cores: float - max_memory_gb: float - min_priority: float - - -class NetThread: - _port: Optional[SocketPort] - _main_queue: "queue.Queue[Activity]" - _controller_queue: "queue.Queue[Output]" - _read_thread: "threading.Thread" - _write_thread: "threading.Thread" - _next_work_id: int - - def __init__( - self, - port: SocketPort, - main_queue: "queue.Queue[Activity]", - ) -> None: - self._port = port - self._main_queue = main_queue - self._controller_queue = queue.Queue() - self._next_work_id = 0 - - self._read_thread = threading.Thread(target=self.read_loop, daemon=True) - self._read_thread.start() - - self._write_thread = threading.Thread(target=self.write_loop, daemon=True) - self._write_thread.start() - - def stop(self) -> None: - if self._port is None: - return - try: - self._controller_queue.put(Shutdown()) - self._port.shutdown() - self._read_thread.join() - self._write_thread.join() - self._port.close() - self._port = None - except Exception: - print("Failed to stop net thread.") - traceback.print_exc() - - def send_controller(self, msg: Output) -> None: - self._controller_queue.put(msg) - - def _read_one(self) -> Activity: - assert self._port is not None - - msg = self._port.receive_json() - time_start = time.time() - - msg_type = json_prop(msg, "type", str) - - if msg_type == "heartbeat": - return Heartbeat() - - handle = json_prop(msg, "permuter", int) - - if msg_type == "work": - seed = json_prop(msg, "seed", int) - id = self._next_work_id - self._next_work_id += 1 - return Work(handle=handle, id=id, time_start=time_start, seed=seed) - - elif msg_type == "add": - client_id = json_prop(msg, "client_id", str) - client_name = json_prop(msg, "client_name", str) - client = Client(client_id, client_name) - data = json_prop(msg, "data", dict) - compressed_source = self._port.receive() - compressed_target_o_bin = self._port.receive() - - try: - source = zlib.decompress(compressed_source).decode("utf-8") - target_o_bin = zlib.decompress(compressed_target_o_bin) - permuter = permuter_data_from_json(data, source, target_o_bin) - except Exception as e: - # Client sent something illegible. This can legitimately happen if the - # client runs another version, but it's interesting to log. - traceback.print_exc() - return ImmediateDisconnect( - handle=handle, - client=client, - reason=f"Failed to parse permuter: {exception_to_string(e)}", - ) - - return AddPermuter( - handle=handle, - time_start=time_start, - client=client, - permuter_data=permuter, - ) - - elif msg_type == "remove": - return RemovePermuter(handle=handle) - - else: - raise Exception(f"Bad message type: {msg_type}") - - def read_loop(self) -> None: - try: - while True: - msg = self._read_one() - self._main_queue.put(msg) - except EOFError: - self._main_queue.put(NetThreadDisconnected(graceful=True)) - except ServerError as e: - self._main_queue.put( - NetThreadDisconnected(graceful=False, message=e.message) - ) - except Exception: - traceback.print_exc() - self._main_queue.put(NetThreadDisconnected(graceful=False)) - - def _write_one(self, item: Output) -> None: - assert self._port is not None - - if isinstance(item, Shutdown): - # Handled by caller - pass - - elif isinstance(item, OutputInitFail): - self._port.send_json( - { - "type": "update", - "permuter": item.handle, - "time_us": 0, - "update": {"type": "init_failed", "reason": item.error}, - } - ) - - elif isinstance(item, OutputInitSuccess): - self._port.send_json( - { - "type": "update", - "permuter": item.handle, - "time_us": item.time_us, - "update": {"type": "init_done", "hash": item.base_hash}, - } - ) - - elif isinstance(item, OutputDisconnect): - self._port.send_json( - { - "type": "update", - "permuter": item.handle, - "time_us": 0, - "update": {"type": "disconnect"}, - } - ) - - elif isinstance(item, OutputNeedMoreWork): - self._port.send_json({"type": "need_work"}) - - elif isinstance(item, OutputWork): - overhead_us = int((time.time() - item.time_start) * 10 ** 6) - item.time_us - self._port.send_json( - { - "type": "update", - "permuter": item.handle, - "time_us": item.time_us, - "update": { - "type": "work", - "overhead_us": overhead_us, - **item.obj, - }, - } - ) - if item.compressed_source is not None: - self._port.send(item.compressed_source) - - else: - static_assert_unreachable(item) - - def write_loop(self) -> None: - try: - while True: - item = self._controller_queue.get() - if isinstance(item, Shutdown): - break - self._write_one(item) - except EOFError: - self._main_queue.put(NetThreadDisconnected(graceful=True)) - except Exception: - traceback.print_exc() - self._main_queue.put(NetThreadDisconnected(graceful=False)) - - -class ServerInner: - """This class represents an up-and-running server, connected to the controller and - to the evaluator.""" - - _evaluator_port: "DockerPort" - _main_queue: "queue.Queue[Activity]" - _io_queue: "queue.Queue[IoActivity]" - _net_thread: NetThread - _read_eval_thread: threading.Thread - _main_thread: threading.Thread - _heartbeat_interval: float - _last_heartbeat: float - _last_heartbeat_lock: threading.Lock - _active: Set[int] - _time_starts: Dict[int, float] - _token: CancelToken - - def __init__( - self, - net_port: SocketPort, - evaluator_port: "DockerPort", - io_queue: "queue.Queue[IoActivity]", - heartbeat_interval: float, - ) -> None: - self._evaluator_port = evaluator_port - self._main_queue = queue.Queue() - self._io_queue = io_queue - self._active = set() - self._time_starts = {} - self._token = CancelToken() - - self._net_thread = NetThread(net_port, self._main_queue) - - # Start a thread for checking heartbeats. - self._heartbeat_interval = heartbeat_interval - self._last_heartbeat = time.time() - self._last_heartbeat_lock = threading.Lock() - self._heartbeat_stop = threading.Event() - self._heartbeat_thread = threading.Thread( - target=self._heartbeat_loop, daemon=True - ) - self._heartbeat_thread.start() - - # Start a thread for reading evaluator results and sending them on to - # the main loop queue. - self._read_eval_thread = threading.Thread( - target=self._read_eval_loop, daemon=True - ) - self._read_eval_thread.start() - - # Start a thread for the main loop. - self._main_thread = threading.Thread(target=self._main_loop, daemon=True) - self._main_thread.start() - - def _send_controller(self, msg: Output) -> None: - self._net_thread.send_controller(msg) - - def _send_io(self, handle: int, io_msg: IoMessage) -> None: - self._io_queue.put((self._token, ((handle, self._token), io_msg))) - - def _send_io_global(self, io_msg: IoGlobalMessage) -> None: - self._io_queue.put((self._token, io_msg)) - - def _handle_message(self, msg: Activity) -> None: - if isinstance(msg, Shutdown): - # Handled by caller - pass - - elif isinstance(msg, Heartbeat): - with self._last_heartbeat_lock: - self._last_heartbeat = time.time() - - elif isinstance(msg, Work): - if msg.handle not in self._active: - self._need_work() - return - - self._time_starts[msg.id] = msg.time_start - self._evaluator_port.send_json( - { - "type": "work", - "permuter": str(msg.handle), - "id": msg.id, - "seed": msg.seed, - } - ) - - elif isinstance(msg, AddPermuter): - if msg.handle in self._active: - raise Exception("Repeated AddPermuter!") - - self._active.add(msg.handle) - self._send_permuter(str(msg.handle), msg.permuter_data) - fn_name = msg.permuter_data.fn_name - self._send_io(msg.handle, IoConnect(fn_name, msg.client)) - - elif isinstance(msg, RemovePermuter): - if msg.handle not in self._active: - return - - self._remove(msg.handle) - self._send_io(msg.handle, IoDisconnect("disconnected")) - - elif isinstance(msg, Disconnect): - if msg.handle not in self._active: - return - - self._remove(msg.handle) - self._send_io(msg.handle, IoDisconnect("kicked")) - self._send_controller(OutputDisconnect(handle=msg.handle)) - - elif isinstance(msg, ImmediateDisconnect): - if msg.handle in self._active: - raise Exception("ImmediateDisconnect is not immediate") - - self._send_io(msg.handle, IoImmediateDisconnect(msg.reason, msg.client)) - self._send_controller(OutputInitFail(handle=msg.handle, error=msg.reason)) - - elif isinstance(msg, PermInitFail): - handle = int(msg.perm_id) - if handle not in self._active: - self._need_work() - return - - self._active.remove(handle) - self._send_io(handle, IoDisconnect("failed to compile")) - self._send_controller( - OutputInitFail( - handle=handle, - error=msg.error, - ) - ) - - elif isinstance(msg, PermInitSuccess): - handle = int(msg.perm_id) - if handle not in self._active: - self._need_work() - return - - self._send_controller( - OutputInitSuccess( - handle=handle, - time_us=msg.time_us, - base_score=msg.base_score, - base_hash=msg.base_hash, - ) - ) - - elif isinstance(msg, WorkDone): - handle = int(msg.perm_id) - time_start = self._time_starts.pop(msg.id) - if handle not in self._active: - self._need_work() - return - - obj = msg.obj - obj["permuter"] = handle - score = json_prop(obj, "score", int) if "score" in obj else None - is_improvement = msg.compressed_source is not None - self._send_io( - handle, - IoWorkDone(score=score, is_improvement=is_improvement), - ) - self._send_controller( - OutputWork( - handle=handle, - time_start=time_start, - time_us=msg.time_us, - obj=obj, - compressed_source=msg.compressed_source, - ) - ) - - elif isinstance(msg, NeedMoreWork): - self._need_work() - - elif isinstance(msg, NetThreadDisconnected): - self._send_io_global(IoServerFailed(msg.graceful, msg.message)) - - else: - static_assert_unreachable(msg) - - def _need_work(self) -> None: - self._send_controller(OutputNeedMoreWork()) - - def _remove(self, handle: int) -> None: - self._evaluator_port.send_json({"type": "remove", "permuter": str(handle)}) - self._active.remove(handle) - - def _send_permuter(self, perm_id: str, perm: PermuterData) -> None: - self._evaluator_port.send_json( - { - "type": "add", - "permuter": perm_id, - **permuter_data_to_json(perm), - } - ) - self._evaluator_port.send(perm.source.encode("utf-8")) - self._evaluator_port.send(perm.target_o_bin) - - def _do_read_eval_loop(self) -> None: - while True: - msg = self._evaluator_port.receive_json() - msg_type = json_prop(msg, "type", str) - - if msg_type == "init": - perm_id = json_prop(msg, "permuter", str) - time_us = json_prop(msg, "time_us", int) - resp: Activity - if json_prop(msg, "success", bool): - resp = PermInitSuccess( - perm_id=perm_id, - base_score=json_prop(msg, "base_score", int), - base_hash=json_prop(msg, "base_hash", str), - time_us=time_us, - ) - else: - resp = PermInitFail( - perm_id=perm_id, - error=json_prop(msg, "error", str), - ) - self._main_queue.put(resp) - - elif msg_type == "result": - compressed_source: Optional[bytes] = None - if msg.get("has_source") == True: - compressed_source = self._evaluator_port.receive() - perm_id = json_prop(msg, "permuter", str) - id = json_prop(msg, "id", int) - time_us = json_prop(msg, "time_us", int) - del msg["permuter"] - del msg["id"] - del msg["time_us"] - self._main_queue.put( - WorkDone( - perm_id=perm_id, - id=id, - obj=msg, - time_us=time_us, - compressed_source=compressed_source, - ) - ) - - else: - raise Exception(f"Unknown message type from evaluator: {msg_type}") - - def _read_eval_loop(self) -> None: - try: - self._do_read_eval_loop() - except EOFError: - # Silence errors from shutdown. - pass - - def _main_loop(self) -> None: - while True: - msg = self._main_queue.get() - if isinstance(msg, Shutdown): - break - - self._handle_message(msg) - - def _heartbeat_loop(self) -> None: - second_attempt = False - while True: - with self._last_heartbeat_lock: - delay = ( - self._last_heartbeat - + self._heartbeat_interval - + _HEARTBEAT_INTERVAL_SLACK_SEC / 2 - - time.time() - ) - if delay <= 0: - if second_attempt: - self._main_queue.put(NetThreadDisconnected(graceful=True)) - return - # Handle clock skew or computer going to sleep by waiting a bit - # longer before giving up. - second_attempt = True - if self._heartbeat_stop.wait(_HEARTBEAT_INTERVAL_SLACK_SEC / 2): - return - else: - second_attempt = False - if self._heartbeat_stop.wait(delay): - return - - def remove_permuter(self, handle: int) -> None: - assert not self._token.cancelled - self._main_queue.put(Disconnect(handle=handle)) - - def stop(self) -> None: - assert not self._token.cancelled - self._token.cancelled = True - self._main_queue.put(Shutdown()) - self._heartbeat_stop.set() - self._net_thread.stop() - self._evaluator_port.shutdown() - self._main_thread.join() - self._heartbeat_thread.join() - - -class DockerPort(Port): - """Port for communicating with Docker. Communication is encrypted for a few - not-very-good reasons: - - it allows code reuse - - it adds error-checking - - it was fun to implement""" - - _sock: BinaryIO - _container: "docker.models.containers.Container" - _stdout_buffer: bytes - _closed: bool - - def __init__( - self, container: "docker.models.containers.Container", secret: bytes - ) -> None: - self._container = container - self._stdout_buffer = b"" - self._closed = False - - # Set up a socket for reading from stdout/stderr and writing to - # stdin for the container. The docker package does not seem to - # expose an API for writing the stdin, but we can do so directly - # by attaching a socket and poking at internal state. (See - # https://github.com/docker/docker-py/issues/983.) For stdout/ - # stderr, we use the format described at - # https://docs.docker.com/engine/api/v1.24/#attach-to-a-container. - # - # Hopefully this will keep working for at least a while... - try: - self._sock = container.attach_socket( - params={"stdout": True, "stdin": True, "stderr": True, "stream": True} - ) - self._sock._writing = True # type: ignore - except: - try: - container.remove(force=True) - except Exception: - pass - raise - - super().__init__(SecretBox(secret), "docker", is_client=True) - - def shutdown(self) -> None: - import docker - - if self._closed: - return - self._closed = True - try: - self._sock.close() - self._container.remove(force=True) - except Exception as e: - if not ( - isinstance(e, docker.errors.APIError) - and e.status_code == 409 - and "is already in progress" in str(e) - ): - print("Failed to shut down Docker") - traceback.print_exc() - - def _read_one(self) -> None: - header = file_read_fixed(self._sock, 8) - stream, length = struct.unpack(">BxxxI", header) - if stream not in [1, 2]: - raise Exception("Unexpected output from Docker: " + repr(header)) - data = file_read_fixed(self._sock, length) - if stream == 1: - self._stdout_buffer += data - else: - sys.stderr.buffer.write(b"Docker stderr: " + data) - sys.stderr.buffer.flush() - - def _receive(self, length: int) -> bytes: - while len(self._stdout_buffer) < length: - self._read_one() - ret = self._stdout_buffer[:length] - self._stdout_buffer = self._stdout_buffer[length:] - return ret - - def _receive_max(self, length: int) -> bytes: - length = min(length, len(self._stdout_buffer)) - ret = self._stdout_buffer[:length] - self._stdout_buffer = self._stdout_buffer[length:] - return ret - - def _send(self, data: bytes) -> None: - while data: - written = self._sock.write(data) - data = data[written:] - self._sock.flush() - - -def _start_evaluator(docker_image: str, options: ServerOptions) -> DockerPort: - """Spawn a docker container and set it up to evaluate permutations in, - returning a handle that we can use to communicate with it. - - We do this for a few reasons: - - enforcing a known Linux environment, all while the outside server can run - on e.g. Windows and display a systray - - enforcing resource limits - - sandboxing - - Docker does have the downside of requiring root access, so ideally we would - also have a Docker-less mode, where we leave the sandboxing to some other - tool, e.g. https://github.com/ioi/isolate/.""" - print("Starting docker...") - command = ["python3", "-m", "src.net.evaluator"] - secret = nacl.utils.random(32) - enc_secret = base64.b64encode(secret).decode("utf-8") - src_path = pathlib.Path(__file__).parent.parent.absolute() - - try: - import docker - - client = docker.from_env() - client.info() - except ModuleNotFoundError: - print( - "Running a server requires the docker Python package to be installed.\n" - "Run `python3 -m pip install --upgrade docker`." - ) - sys.exit(1) - except Exception: - traceback.print_exc() - print() - print( - "Failed to start docker. Make sure you have docker installed and " - "the docker daemon running, and either run the permuter with sudo " - 'or add yourself to the "docker" UNIX group.' - ) - sys.exit(1) - - try: - container = client.containers.run( - docker_image, - command, - detach=True, - remove=True, - stdin_open=True, - stdout=True, - environment={"SECRET": enc_secret}, - volumes={src_path: {"bind": "/src", "mode": "ro"}}, - tmpfs={"/tmp": "size=1G,exec"}, - nano_cpus=int(options.num_cores * 1e9), - mem_limit=int(options.max_memory_gb * 2 ** 30), - read_only=True, - network_disabled=True, - ) - except Exception as e: - print(f"Failed to start docker container: {e}") - sys.exit(1) - - port = DockerPort(container, secret) - - try: - # Sanity-check that the Docker container started successfully and can - # be communicated with. - magic = b"\0" * 1000 - port.send(magic) - r = port.receive() - if r != magic: - raise Exception("Failed initial sanity check.") - - port.send_json({"num_cores": options.num_cores}) - except: - port.shutdown() - raise - - print("Started.") - return port - - -class Server: - """This class represents a server that may or may not be connected to the - controller and the evaluator.""" - - _server: Optional[ServerInner] - _options: ServerOptions - _config: Config - _io_queue: "queue.Queue[IoActivity]" - - def __init__( - self, - options: ServerOptions, - config: Config, - io_queue: "queue.Queue[IoActivity]", - ) -> None: - self._server = None - self._options = options - self._config = config - self._io_queue = io_queue - - def start(self) -> None: - assert self._server is None - - net_port = connect(self._config) - net_port.send_json( - { - "method": "connect_server", - "min_priority": self._options.min_priority, - "num_cores": self._options.num_cores, - } - ) - obj = net_port.receive_json() - docker_image = json_prop(obj, "docker_image", str) - heartbeat_interval = json_prop(obj, "heartbeat_interval", float) - - evaluator_port = _start_evaluator(docker_image, self._options) - - try: - self._server = ServerInner( - net_port, evaluator_port, self._io_queue, heartbeat_interval - ) - except: - evaluator_port.shutdown() - raise - - def stop(self) -> None: - if self._server is None: - return - self._server.stop() - self._server = None - - def remove_permuter(self, handle: PermuterHandle) -> None: - if self._server is not None and not handle[1].cancelled: - self._server.remove_permuter(handle[0]) diff --git a/tools/decomp-permuter/src/objdump.py b/tools/decomp-permuter/src/objdump.py deleted file mode 100644 index 09dab5fb72..0000000000 --- a/tools/decomp-permuter/src/objdump.py +++ /dev/null @@ -1,224 +0,0 @@ -#!/usr/bin/env python3 -from dataclasses import dataclass, field -import os -import re -import string -import subprocess -import sys -from typing import List, Match, Pattern, Set, Tuple - -# Ignore registers, for cleaner output. (We don't do this right now, but it can -# be useful for debugging.) -ign_regs = False - -# Don't include branch targets in the output. Assuming our input is semantically -# equivalent skipping it shouldn't be an issue, and it makes insertions have too -# large effect. -ign_branch_targets = True - -# Skip branch-likely delay slots. (They aren't interesting on IDO.) -skip_bl_delay_slots = True - -skip_lines = 1 -re_int = re.compile(r"[0-9]+") -re_int_full = re.compile(r"\b[0-9]+\b") - - -@dataclass -class ArchSettings: - objdump: List[str] - re_comment: Pattern[str] - re_reg: Pattern[str] - re_sprel: Pattern[str] - re_includes_sp: Pattern[str] - branch_instructions: Set[str] - forbidden: Set[str] = field(default_factory=lambda: set(string.ascii_letters + "_")) - branch_likely_instructions: Set[str] = field(default_factory=set) - - -MIPS_BRANCH_LIKELY_INSTRUCTIONS = { - "beql", - "bnel", - "beqzl", - "bnezl", - "bgezl", - "bgtzl", - "blezl", - "bltzl", - "bc1tl", - "bc1fl", -} -MIPS_BRANCH_INSTRUCTIONS = { - "b", - "j", - "beq", - "bne", - "beqz", - "bnez", - "bgez", - "bgtz", - "blez", - "bltz", - "bc1t", - "bc1f", -}.union(MIPS_BRANCH_LIKELY_INSTRUCTIONS) - -MIPS_SETTINGS: ArchSettings = ArchSettings( - re_comment=re.compile(r"<.*?>"), - re_reg=re.compile( - r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra)\b" # leave out $zero - ), - re_sprel=re.compile(r"(?<=,)([0-9]+|0x[0-9a-f]+)\((sp|s8)\)"), - re_includes_sp=re.compile(r"\b(sp|s8)\b"), - objdump=["mips-linux-gnu-objdump", "-drz", "-m", "mips:4300"], - branch_likely_instructions=MIPS_BRANCH_LIKELY_INSTRUCTIONS, - branch_instructions=MIPS_BRANCH_INSTRUCTIONS, -) - - -def get_arch(o_file: str) -> ArchSettings: - # https://refspecs.linuxfoundation.org/elf/gabi4+/ch4.eheader.html - with open(o_file, "rb") as f: - f.seek(18) - arch_magic = f.read(2) - if arch_magic == b"\0\x08": - return MIPS_SETTINGS - # TODO: support PPC ("\0\x14"), ARM ("0\x28") - raise Exception("Bad ELF") - - -def parse_relocated_line(line: str) -> Tuple[str, str, str]: - try: - ind2 = line.rindex(",") - except ValueError: - ind2 = line.rindex("\t") - before = line[: ind2 + 1] - after = line[ind2 + 1 :] - ind2 = after.find("(") - if ind2 == -1: - imm, after = after, "" - else: - imm, after = after[:ind2], after[ind2:] - if imm == "0x0": - imm = "0" - return before, imm, after - - -def simplify_objdump( - input_lines: List[str], arch: ArchSettings, *, stack_differences: bool -) -> List[str]: - output_lines: List[str] = [] - nops = 0 - skip_next = False - for index, row in enumerate(input_lines): - if index < skip_lines: - continue - row = row.rstrip() - if ">:" in row or not row: - continue - if "R_MIPS_" in row: - prev = output_lines[-1] - if prev == "": - continue - before, imm, after = parse_relocated_line(prev) - repl = row.split()[-1] - # As part of ignoring branch targets, we ignore relocations for j - # instructions. The target is already lost anyway. - if imm == "": - assert ign_branch_targets - continue - # Sometimes s8 is used as a non-framepointer, but we've already lost - # the immediate value by pretending it is one. This isn't too bad, - # since it's rare and applies consistently. But we do need to handle it - # here to avoid a crash, by pretending that lost imms are zero for - # relocations. - if imm != "0" and imm != "imm" and imm != "addr": - repl += "+" + imm if int(imm, 0) > 0 else imm - if any( - reloc in row - for reloc in ["R_MIPS_LO16", "R_MIPS_LITERAL", "R_MIPS_GPREL16"] - ): - repl = f"%lo({repl})" - elif "R_MIPS_HI16" in row: - # Ideally we'd pair up R_MIPS_LO16 and R_MIPS_HI16 to generate a - # correct addend for each, but objdump doesn't give us the order of - # the relocations, so we can't find the right LO16. :( - repl = f"%hi({repl})" - else: - assert "R_MIPS_26" in row, f"unknown relocation type '{row}'" - output_lines[-1] = before + repl + after - continue - row = re.sub(arch.re_comment, "", row) - row = row.rstrip() - row = "\t".join(row.split("\t")[2:]) # [20:] - if not row: - continue - if skip_next: - skip_next = False - row = "" - if ign_regs: - row = re.sub(arch.re_reg, "", row) - row_parts = row.split("\t") - if len(row_parts) == 1: - row_parts.append("") - mnemonic, instr_args = row_parts - if not stack_differences: - if mnemonic == "addiu" and arch.re_includes_sp.search(instr_args): - row = re.sub(re_int_full, "imm", row) - if mnemonic in arch.branch_instructions: - if ign_branch_targets: - instr_parts = instr_args.split(",") - instr_parts[-1] = "" - instr_args = ",".join(instr_parts) - row = f"{mnemonic}\t{instr_args}" - # The last part is in hex, so skip the dec->hex conversion - else: - - def fn(pat: Match[str]) -> str: - full = pat.group(0) - if len(full) <= 1: - return full - start, end = pat.span() - if start and row[start - 1] in arch.forbidden: - return full - if end < len(row) and row[end] in arch.forbidden: - return full - return hex(int(full)) - - row = re.sub(re_int, fn, row) - if mnemonic in arch.branch_likely_instructions and skip_bl_delay_slots: - skip_next = True - if not stack_differences: - row = re.sub(arch.re_sprel, "addr(sp)", row) - # row = row.replace(',', ', ') - if row == "nop": - # strip trailing nops; padding is irrelevant to us - nops += 1 - else: - for _ in range(nops): - output_lines.append("nop") - nops = 0 - output_lines.append(row) - return output_lines - - -def objdump( - o_filename: str, arch: ArchSettings, *, stack_differences: bool = False -) -> List[str]: - output = subprocess.check_output(arch.objdump + [o_filename]) - lines = output.decode("utf-8").splitlines() - return simplify_objdump(lines, arch, stack_differences=stack_differences) - - -if __name__ == "__main__": - if len(sys.argv) < 2: - print(f"Usage: {sys.argv[0]} file.o", file=sys.stderr) - sys.exit(1) - - if not os.path.isfile(sys.argv[1]): - print(f"Source file {sys.argv[1]} is not readable.", file=sys.stderr) - sys.exit(1) - - lines = objdump(sys.argv[1], MIPS_SETTINGS) - for row in lines: - print(row) diff --git a/tools/decomp-permuter/src/perm/__init__.py b/tools/decomp-permuter/src/perm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/decomp-permuter/src/perm/ast.py b/tools/decomp-permuter/src/perm/ast.py deleted file mode 100644 index a4a14d6a1c..0000000000 --- a/tools/decomp-permuter/src/perm/ast.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List, Optional, Tuple - -from .. import ast_util -from .perm import EvalState, Perm -from ..ast_util import Block, Statement -from ..error import CandidateConstructionFailure - -from pycparser import c_ast as ca - - -class _Done(Exception): - pass - - -def _apply_perm(fn: ca.FuncDef, perm_id: int, perm: Perm, seed: int) -> None: - """Find and apply a single late perm macro in the AST.""" - # Currently we search for statement macros only. - wanted_pragma = f"_permuter ast_perm {perm_id}" - Loc = Tuple[List[Statement], int] - - def try_handle_block(block: ca.Node, where: Optional[Loc]) -> None: - if not isinstance(block, ca.Compound) or not block.block_items: - return - pragma = block.block_items[0] - if not isinstance(pragma, ca.Pragma) or pragma.string != wanted_pragma: - return - - args: List[Statement] = block.block_items[1:] - stmts = perm.eval_statement_ast(args, seed) - - if where: - where[0][where[1] : where[1] + 1] = stmts - else: - block.block_items = stmts - raise _Done - - def rec(block: Block) -> None: - # if (x) { _Pragma(...); inputs } -> if (x) { outputs } - try_handle_block(block, None) - stmts = ast_util.get_block_stmts(block, False) - for i, stmt in enumerate(stmts): - # { ... { _Pragma(...); inputs } ... } -> { ... outputs ... } - try_handle_block(stmt, (stmts, i)) - ast_util.for_nested_blocks(stmt, rec) - - try: - rec(fn.body) - raise CandidateConstructionFailure("Failed to find PERM macro in AST.") - except _Done: - pass - - -def apply_ast_perms(fn: ca.FuncDef, eval_state: EvalState) -> None: - """Find all late perm macros in the AST and apply them.""" - # Nested perms will have smaller IDs, so apply the perms from lowest ID to - # highest to ensure that all arguments to perms have already been evaluated. - for perm_id, (perm, seed) in enumerate(eval_state.ast_perms): - _apply_perm(fn, perm_id, perm, seed) diff --git a/tools/decomp-permuter/src/perm/eval.py b/tools/decomp-permuter/src/perm/eval.py deleted file mode 100644 index b84dcd5929..0000000000 --- a/tools/decomp-permuter/src/perm/eval.py +++ /dev/null @@ -1,36 +0,0 @@ -import random -from typing import List, Iterable, Set, Tuple - -from .perm import Perm, EvalState - - -def _gen_all_seeds(total_count: int) -> Iterable[int]: - """Generate all numbers 0..total_count-1 in random order, in expected time - O(1) per number.""" - seen: Set[int] = set() - while len(seen) < total_count // 2: - seed = random.randrange(total_count) - if seed not in seen: - seen.add(seed) - yield seed - - remaining: List[int] = [] - for seed in range(total_count): - if seed not in seen: - remaining.append(seed) - random.shuffle(remaining) - for seed in remaining: - yield seed - - -def perm_gen_all_seeds(perm: Perm) -> Iterable[int]: - while True: - for seed in _gen_all_seeds(perm.perm_count): - yield seed - if not perm.is_random(): - break - - -def perm_evaluate_one(perm: Perm) -> Tuple[str, EvalState]: - eval_state = EvalState() - return perm.evaluate(0, eval_state), eval_state diff --git a/tools/decomp-permuter/src/perm/parse.py b/tools/decomp-permuter/src/perm/parse.py deleted file mode 100644 index dd5a4838db..0000000000 --- a/tools/decomp-permuter/src/perm/parse.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import Callable, Dict, List, Tuple -import re - -from .perm import ( - CombinePerm, - GeneralPerm, - IgnorePerm, - IntPerm, - LineSwapPerm, - LineSwapAstPerm, - OncePerm, - Perm, - PretendPerm, - RandomizerPerm, - RootPerm, - TextPerm, - VarPerm, -) - - -def _split_by_comma(text: str) -> List[str]: - level = 0 - current = "" - args: List[str] = [] - for c in text: - if c == "," and level == 0: - args.append(current) - current = "" - else: - if c == "(": - level += 1 - elif c == ")": - level -= 1 - assert level >= 0, "Bad nesting" - current += c - assert level == 0, "Mismatched parentheses" - args.append(current) - return args - - -def _split_args(text: str) -> List[Perm]: - perm_args = [_rec_perm_parse(arg) for arg in _split_by_comma(text)] - return perm_args - - -def _split_args_newline(text: str) -> List[Perm]: - return [_rec_perm_parse(line) for line in text.split("\n") if line.strip()] - - -def _split_args_text(text: str) -> List[str]: - perm_list = _split_args(text) - res: List[str] = [] - for perm in perm_list: - assert isinstance(perm, TextPerm) - res.append(perm.text) - return res - - -def _make_once_perm(text: str) -> OncePerm: - args = _split_by_comma(text) - if len(args) not in [1, 2]: - raise Exception("PERM_ONCE takes 1 or 2 arguments") - key = args[0].strip() - value = _rec_perm_parse(args[-1]) - return OncePerm(key, value) - - -def _make_var_perm(text: str) -> VarPerm: - args = _split_by_comma(text) - if len(args) not in [1, 2]: - raise Exception("PERM_VAR takes 1 or 2 arguments") - var_name = _rec_perm_parse(args[0]) - value = _rec_perm_parse(args[1]) if len(args) == 2 else None - return VarPerm(var_name, value) - - -PERM_FACTORIES: Dict[str, Callable[[str], Perm]] = { - "PERM_GENERAL": lambda text: GeneralPerm(_split_args(text)), - "PERM_ONCE": lambda text: _make_once_perm(text), - "PERM_RANDOMIZE": lambda text: RandomizerPerm(_rec_perm_parse(text)), - "PERM_VAR": lambda text: _make_var_perm(text), - "PERM_LINESWAP_TEXT": lambda text: LineSwapPerm(_split_args_newline(text)), - "PERM_LINESWAP": lambda text: LineSwapAstPerm(_split_args_newline(text)), - "PERM_INT": lambda text: IntPerm(*map(int, _split_args_text(text))), - "PERM_IGNORE": lambda text: IgnorePerm(_rec_perm_parse(text)), - "PERM_PRETEND": lambda text: PretendPerm(_rec_perm_parse(text)), -} - - -def _consume_arg_parens(text: str) -> Tuple[str, str]: - level = 0 - for i, c in enumerate(text): - if c == "(": - level += 1 - elif c == ")": - level -= 1 - if level == -1: - return text[:i], text[i + 1 :] - raise Exception("Failed to find closing parenthesis when parsing PERM macro") - - -def _rec_perm_parse(text: str) -> Perm: - remain = text - macro_search = r"(PERM_.+?)\(" - - perms: List[Perm] = [] - while len(remain) > 0: - match = re.search(macro_search, remain) - - # No match found; return remaining - if match is None: - text_perm = TextPerm(remain) - perms.append(text_perm) - break - - # Get perm type and args - perm_type = match.group(1) - if perm_type not in PERM_FACTORIES: - raise Exception("Unrecognized PERM macro: " + perm_type) - between = remain[: match.start()] - args, remain = _consume_arg_parens(remain[match.end() :]) - - # Create text perm - perms.append(TextPerm(between)) - - # Create new perm - perms.append(PERM_FACTORIES[perm_type](args)) - - if len(perms) == 1: - return perms[0] - return CombinePerm(perms) - - -def perm_parse(text: str) -> Perm: - ret = _rec_perm_parse(text) - if isinstance(ret, TextPerm): - ret = RandomizerPerm(ret) - print("No perm macros found. Defaulting to randomization.") - ret = RootPerm(ret) - if not ret.is_random(): - print(f"Will run for {ret.perm_count} iterations.") - else: - print(f"Will try {ret.perm_count} different base sources.") - return ret diff --git a/tools/decomp-permuter/src/perm/perm.py b/tools/decomp-permuter/src/perm/perm.py deleted file mode 100644 index 61cc6c21c1..0000000000 --- a/tools/decomp-permuter/src/perm/perm.py +++ /dev/null @@ -1,289 +0,0 @@ -from base64 import b64encode -from collections import defaultdict -from dataclasses import dataclass, field -from typing import Dict, List, Tuple, TypeVar, Optional -import math - -from pycparser import c_ast as ca - -from ..ast_util import Statement - -T = TypeVar("T") - - -@dataclass -class PreprocessState: - once_options: Dict[str, List["Perm"]] = field( - default_factory=lambda: defaultdict(list) - ) - - -@dataclass -class EvalState: - vars: Dict[str, str] = field(default_factory=dict) - once_choices: Dict[str, "Perm"] = field(default_factory=dict) - ast_perms: List[Tuple["Perm", int]] = field(default_factory=list) - - def register_ast_perm(self, perm: "Perm", seed: int) -> int: - ret = len(self.ast_perms) - self.ast_perms.append((perm, seed)) - return ret - - def gen_ast_statement_perm( - self, perm: "Perm", seed: int, *, statements: List[str] - ) -> str: - perm_id = self.register_ast_perm(perm, seed) - lines = [ - "{", - f"#pragma _permuter ast_perm {perm_id}", - *["{" + stmt + "}" for stmt in statements], - "}", - ] - return "\n".join(lines) - - -class Perm: - """A Perm subclass generates different variations of a part of the source - code. Its evaluate method will be called with a seed between 0 and - perm_count-1, and it should return a unique string for each. - - A Perm is allowed to return different strings for the same seed, but if so, - if should override is_random to return True. This will cause permutation - to happen in an infinite loop, rather than stop after the last permutation - has been tested.""" - - perm_count: int - children: List["Perm"] - - def evaluate(self, seed: int, state: EvalState) -> str: - return "" - - def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]: - raise NotImplementedError - - def preprocess(self, state: PreprocessState) -> None: - for p in self.children: - p.preprocess(state) - - def is_random(self) -> bool: - return any(p.is_random() for p in self.children) - - -def _eval_all(seed: int, perms: List[Perm], state: EvalState) -> List[str]: - ret = [] - for p in perms: - seed, sub_seed = divmod(seed, p.perm_count) - ret.append(p.evaluate(sub_seed, state)) - assert seed == 0, "seed must be in [0, prod(counts))" - return ret - - -def _count_all(perms: List[Perm]) -> int: - res = 1 - for p in perms: - res *= p.perm_count - return res - - -def _eval_either(seed: int, perms: List[Perm], state: EvalState) -> str: - for p in perms: - if seed < p.perm_count: - return p.evaluate(seed, state) - seed -= p.perm_count - assert False, "seed must be in [0, sum(counts))" - - -def _count_either(perms: List[Perm]) -> int: - return sum(p.perm_count for p in perms) - - -def _shuffle(items: List[T], seed: int) -> List[T]: - items = items[:] - output = [] - while items: - ind = seed % len(items) - seed //= len(items) - output.append(items[ind]) - del items[ind] - return output - - -class RootPerm(Perm): - def __init__(self, inner: Perm) -> None: - self.children = [inner] - self.perm_count = inner.perm_count - self.preprocess_state = PreprocessState() - self.preprocess(self.preprocess_state) - for key, options in self.preprocess_state.once_options.items(): - if len(options) == 1: - raise Exception(f"PERM_ONCE({key}) occurs only once, possible error?") - self.perm_count *= len(options) - - def evaluate(self, seed: int, state: EvalState) -> str: - for key, options in self.preprocess_state.once_options.items(): - seed, choice = divmod(seed, len(options)) - state.once_choices[key] = options[choice] - return self.children[0].evaluate(seed, state) - - -class TextPerm(Perm): - def __init__(self, text: str) -> None: - # Comma escape sequence - text = text.replace("(,)", ",") - self.text = text - self.children = [] - self.perm_count = 1 - - def evaluate(self, seed: int, state: EvalState) -> str: - return self.text - - -class IgnorePerm(Perm): - def __init__(self, inner: Perm) -> None: - self.children = [inner] - self.perm_count = inner.perm_count - - def evaluate(self, seed: int, state: EvalState) -> str: - text = self.children[0].evaluate(seed, state) - if not text: - return "" - encoded = b64encode(text.encode("utf-8")).decode("ascii") - return "#pragma _permuter b64literal " + encoded - - -class PretendPerm(Perm): - def __init__(self, inner: Perm) -> None: - self.children = [inner] - self.perm_count = inner.perm_count - - def evaluate(self, seed: int, state: EvalState) -> str: - text = self.children[0].evaluate(seed, state) - return "\n".join( - [ - "", - "#pragma _permuter latedefine start", - text, - "#pragma _permuter latedefine end", - "", - ] - ) - - -class CombinePerm(Perm): - def __init__(self, parts: List[Perm]) -> None: - self.children = parts - self.perm_count = _count_all(parts) - - def evaluate(self, seed: int, state: EvalState) -> str: - texts = _eval_all(seed, self.children, state) - return "".join(texts) - - -class RandomizerPerm(Perm): - def __init__(self, inner: Perm) -> None: - self.children = [inner] - self.perm_count = inner.perm_count - - def evaluate(self, seed: int, state: EvalState) -> str: - text = self.children[0].evaluate(seed, state) - return "\n".join( - [ - "", - "#pragma _permuter randomizer start", - text, - "#pragma _permuter randomizer end", - "", - ] - ) - - def is_random(self) -> bool: - return True - - -class GeneralPerm(Perm): - def __init__(self, candidates: List[Perm]) -> None: - self.perm_count = _count_either(candidates) - self.children = candidates - - def evaluate(self, seed: int, state: EvalState) -> str: - return _eval_either(seed, self.children, state) - - -class OncePerm(Perm): - def __init__(self, key: str, inner: Perm) -> None: - self.key = key - self.children = [inner] - self.perm_count = inner.perm_count - - def preprocess(self, state: PreprocessState) -> None: - state.once_options[self.key].append(self) - super().preprocess(state) - - def evaluate(self, seed: int, state: EvalState) -> str: - if state.once_choices[self.key] is self: - return self.children[0].evaluate(seed, state) - return "" - - -class VarPerm(Perm): - def __init__(self, var_name: Perm, expansion: Optional[Perm]) -> None: - if expansion: - self.children = [var_name, expansion] - else: - self.children = [var_name] - self.perm_count = _count_all(self.children) - - def evaluate(self, seed: int, state: EvalState) -> str: - var_name_perm = self.children[0] - seed, sub_seed = divmod(seed, var_name_perm.perm_count) - var_name = var_name_perm.evaluate(sub_seed, state).strip() - if len(self.children) > 1: - ret = self.children[1].evaluate(seed, state) - state.vars[var_name] = ret - return "" - else: - if var_name not in state.vars: - raise Exception(f"Tried to read undefined PERM_VAR {var_name}") - return state.vars[var_name] - - -class LineSwapPerm(Perm): - def __init__(self, lines: List[Perm]) -> None: - self.children = lines - self.own_count = math.factorial(len(lines)) - self.perm_count = self.own_count * _count_all(self.children) - - def evaluate(self, seed: int, state: EvalState) -> str: - sub_seed, variation = divmod(seed, self.own_count) - texts = _eval_all(sub_seed, self.children, state) - return "\n".join(_shuffle(texts, variation)) - - -class LineSwapAstPerm(Perm): - def __init__(self, lines: List[Perm]) -> None: - self.children = lines - self.own_count = math.factorial(len(lines)) - self.perm_count = self.own_count * _count_all(self.children) - - def evaluate(self, seed: int, state: EvalState) -> str: - sub_seed, variation = divmod(seed, self.own_count) - texts = _eval_all(sub_seed, self.children, state) - return state.gen_ast_statement_perm(self, variation, statements=texts) - - def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]: - ret = [] - for item in _shuffle(args, seed): - assert isinstance(item, ca.Compound) - ret.extend(item.block_items or []) - return ret - - -class IntPerm(Perm): - def __init__(self, low: int, high: int) -> None: - assert low <= high - self.low = low - self.children = [] - self.perm_count = high - low + 1 - - def evaluate(self, seed: int, state: EvalState) -> str: - return str(self.low + seed) diff --git a/tools/decomp-permuter/src/permuter.py b/tools/decomp-permuter/src/permuter.py deleted file mode 100644 index 143b981782..0000000000 --- a/tools/decomp-permuter/src/permuter.py +++ /dev/null @@ -1,279 +0,0 @@ -from dataclasses import dataclass -import difflib -import itertools -import random -import re -import time -import traceback -from typing import ( - Any, - List, - Iterator, - Optional, - Tuple, - Union, -) - -from .candidate import Candidate, CandidateResult -from .compiler import Compiler -from .error import CandidateConstructionFailure -from .perm.perm import EvalState -from .perm.eval import perm_evaluate_one, perm_gen_all_seeds -from .perm.parse import perm_parse -from .profiler import Profiler -from .scorer import Scorer - - -@dataclass -class EvalError: - exc_str: Optional[str] - seed: Optional[Tuple[int, int]] - - -EvalResult = Union[CandidateResult, EvalError] - - -@dataclass -class Finished: - reason: Optional[str] = None - - -@dataclass -class Message: - text: str - - -class NeedMoreWork: - pass - - -class _CompileFailure(Exception): - pass - - -@dataclass -class WorkDone: - perm_index: int - result: EvalResult - - -Task = Union[Finished, Tuple[int, int]] -FeedbackItem = Union[Finished, Message, NeedMoreWork, WorkDone] -Feedback = Tuple[FeedbackItem, int, Optional[str]] - - -class Permuter: - """ - Represents a single source from which permutation candidates can be generated, - and which keeps track of good scores achieved so far. - """ - - def __init__( - self, - dir: str, - fn_name: Optional[str], - compiler: Compiler, - scorer: Scorer, - source_file: str, - source: str, - *, - force_seed: Optional[int], - force_rng_seed: Optional[int], - keep_prob: float, - need_profiler: bool, - need_all_sources: bool, - show_errors: bool, - best_only: bool, - better_only: bool, - ) -> None: - self.dir = dir - self.compiler = compiler - self.scorer = scorer - self.source_file = source_file - self.source = source - - if fn_name is None: - # Semi-legacy codepath; all functions imported through import.py have a - # function name. This would ideally be done on AST level instead of on the - # pre-macro'ed source code, but we don't care enough to make that - # refactoring. - fns = _find_fns(source) - if len(fns) == 0: - raise Exception(f"{self.source_file} does not contain any function!") - if len(fns) > 1: - raise Exception( - f"{self.source_file} must contain only one function, " - "or have a function.txt next to it with a function name." - ) - self.fn_name = fns[0] - else: - self.fn_name = fn_name - self.unique_name = self.fn_name - - self._permutations = perm_parse(source) - - self._force_seed = force_seed - self._force_rng_seed = force_rng_seed - self._cur_seed: Optional[Tuple[int, int]] = None - - self.keep_prob = keep_prob - self.need_profiler = need_profiler - self._need_all_sources = need_all_sources - self._show_errors = show_errors - self._best_only = best_only - self._better_only = better_only - - ( - self.base_score, - self.base_hash, - self.base_source, - ) = self._create_and_score_base() - self.best_score = self.base_score - self.hashes = {self.base_hash} - self._cur_cand: Optional[Candidate] = None - self._last_score: Optional[int] = None - - def _create_and_score_base(self) -> Tuple[int, str, str]: - base_source, eval_state = perm_evaluate_one(self._permutations) - base_cand = Candidate.from_source( - base_source, eval_state, self.fn_name, rng_seed=0 - ) - o_file = base_cand.compile(self.compiler, show_errors=True) - if not o_file: - raise CandidateConstructionFailure(f"Unable to compile {self.source_file}") - base_result = base_cand.score(self.scorer, o_file) - assert base_result.hash is not None - return base_result.score, base_result.hash, base_cand.get_source() - - def _need_to_send_source(self, result: CandidateResult) -> bool: - return self._need_all_sources or self.should_output(result) - - def _eval_candidate(self, seed: int) -> CandidateResult: - t0 = time.time() - - # Determine if we should keep the last candidate. - # Don't keep 0-score candidates; we'll only create new, worse, zeroes. - keep = ( - self._permutations.is_random() - and random.uniform(0, 1) < self.keep_prob - and self._last_score != 0 - and self._last_score != self.scorer.PENALTY_INF - ) or self._force_rng_seed - - self._last_score = None - - # Create a new candidate if we didn't keep the last one (or if the last one didn't exist) - # N.B. if we decide to keep the previous candidate, we will skip over the provided seed. - # This means we're not guaranteed to test all seeds, but it doesn't really matter since - # we're randomizing anyway. - if not self._cur_cand or not keep: - eval_state = EvalState() - cand_c = self._permutations.evaluate(seed, eval_state) - rng_seed = self._force_rng_seed or random.randrange(1, 10 ** 20) - self._cur_seed = (seed, rng_seed) - self._cur_cand = Candidate.from_source( - cand_c, eval_state, self.fn_name, rng_seed=rng_seed - ) - - # Randomize the candidate - if self._permutations.is_random(): - self._cur_cand.randomize_ast() - - t1 = time.time() - - self._cur_cand.get_source() - - t2 = time.time() - - o_file = self._cur_cand.compile(self.compiler) - if not o_file and self._show_errors: - raise _CompileFailure() - - t3 = time.time() - - result = self._cur_cand.score(self.scorer, o_file) - - t4 = time.time() - - if self.need_profiler: - profiler = Profiler() - profiler.add_stat(Profiler.StatType.perm, t1 - t0) - profiler.add_stat(Profiler.StatType.stringify, t2 - t1) - profiler.add_stat(Profiler.StatType.compile, t3 - t2) - profiler.add_stat(Profiler.StatType.score, t4 - t3) - result.profiler = profiler - - self._last_score = result.score - - if not self._need_to_send_source(result): - result.source = None - result.hash = None - - return result - - def should_output(self, result: CandidateResult) -> bool: - """Check whether a result should be outputted. This must be more liberal - in child processes than in parent ones, or else sources will be missing.""" - return ( - result.score <= self.base_score - and result.hash is not None - and result.source is not None - and not (result.score > self.best_score and self._best_only) - and ( - result.score < self.base_score - or (result.score == self.base_score and not self._better_only) - ) - and result.hash not in self.hashes - ) - - def record_result(self, result: CandidateResult) -> None: - """Record a new result, updating the best score and adding the hash to - the set of hashes we have already seen. No hash is recorded for score - 0, since we are interested in all score 0's, not just the first.""" - self.best_score = min(self.best_score, result.score) - if result.score != 0 and result.hash is not None: - self.hashes.add(result.hash) - - def seed_iterator(self) -> Iterator[int]: - """Create an iterator over all seeds for this permuter. The iterator - will be infinite if we are randomizing.""" - if self._force_seed is None: - return iter(perm_gen_all_seeds(self._permutations)) - if self._permutations.is_random(): - return itertools.repeat(self._force_seed) - return iter([self._force_seed]) - - def try_eval_candidate(self, seed: int) -> EvalResult: - """Evaluate a seed for the permuter.""" - try: - return self._eval_candidate(seed) - except _CompileFailure: - return EvalError(exc_str=None, seed=self._cur_seed) - except Exception: - return EvalError(exc_str=traceback.format_exc(), seed=self._cur_seed) - - def diff(self, other_source: str) -> str: - """Compute a unified white-space-ignoring diff from the (pretty-printed) - base source against another source generated from this permuter.""" - - class Line(str): - def __eq__(self, other: Any) -> bool: - return isinstance(other, str) and self.strip() == other.strip() - - def __hash__(self) -> int: - return hash(self.strip()) - - a = list(map(Line, self.base_source.split("\n"))) - b = list(map(Line, other_source.split("\n"))) - return "\n".join( - difflib.unified_diff(a, b, fromfile="before", tofile="after", lineterm="") - ) - - -def _find_fns(source: str) -> List[str]: - fns = re.findall(r"(\w+)\([^()\n]*\)\s*?{", source) - return [ - fn - for fn in fns - if not fn.startswith("PERM") and fn not in ["if", "for", "switch", "while"] - ] diff --git a/tools/decomp-permuter/src/preprocess.py b/tools/decomp-permuter/src/preprocess.py deleted file mode 100644 index 214d359c89..0000000000 --- a/tools/decomp-permuter/src/preprocess.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import List -import subprocess - - -def preprocess(filename: str, cpp_args: List[str] = []) -> str: - return subprocess.check_output( - ["cpp"] + cpp_args + ["-P", "-nostdinc", filename], - universal_newlines=True, - encoding="utf-8", - ) diff --git a/tools/decomp-permuter/src/printer.py b/tools/decomp-permuter/src/printer.py deleted file mode 100644 index 804e63c154..0000000000 --- a/tools/decomp-permuter/src/printer.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Optional - -from .permuter import Permuter - -# Number of additional characters to replace with spaces, in addition to what -# is required based on the length of the previous progress line. This could be -# set to 0, but it's nice to have some margin to deal with e.g. zero-width -# control characters. -SAFETY_PAD = 10 - - -class Printer: - _last_progress: Optional[str] = None - - def progress(self, message: str) -> None: - if self._last_progress is None: - clear = "" - else: - pad = max(len(self._last_progress) - len(message) + SAFETY_PAD, 0) - clear = "\b" * pad + " " * pad + "\r" - print(clear + message, end="", flush=True) - self._last_progress = message - - def print( - self, - message: str, - permuter: Optional[Permuter], - who: Optional[str], - *, - color: str = "", - keep_progress: bool = False, - ) -> None: - if self._last_progress is not None: - if keep_progress: - print() - else: - pad = len(self._last_progress) + SAFETY_PAD - print("\r" + " " * pad + "\r", end="") - if permuter is not None: - message = f"[{permuter.unique_name}] {message}" - if who is not None: - message = f"[{who}] {message}" - if color: - message = f"{color}{message}\u001b[0m" - print(message) - self._last_progress = None diff --git a/tools/decomp-permuter/src/profiler.py b/tools/decomp-permuter/src/profiler.py deleted file mode 100644 index 974b3b9886..0000000000 --- a/tools/decomp-permuter/src/profiler.py +++ /dev/null @@ -1,23 +0,0 @@ -from enum import Enum - - -class Profiler: - class StatType(Enum): - perm = 1 - stringify = 2 - compile = 3 - score = 4 - - def __init__(self) -> None: - self.time_stats = {x: 0.0 for x in Profiler.StatType} - - def add_stat(self, stat: StatType, time_taken: float) -> None: - self.time_stats[stat] += time_taken - - def get_str_stats(self) -> str: - total_time = sum(self.time_stats[e] for e in self.time_stats) - timings = ", ".join( - f"{round(100 * self.time_stats[e] / total_time)}% {e.name}" - for e in self.time_stats - ) - return timings diff --git a/tools/decomp-permuter/src/randomizer.py b/tools/decomp-permuter/src/randomizer.py deleted file mode 100644 index 1eb59f22bb..0000000000 --- a/tools/decomp-permuter/src/randomizer.py +++ /dev/null @@ -1,1940 +0,0 @@ -import bisect -import copy -from dataclasses import dataclass, field -from random import Random -import typing -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Set, - Tuple, - TypeVar, - Union, -) - -from pycparser import c_ast as ca - -from . import ast_util -from .ast_util import Block, Indices, Statement, Expression -from .ast_types import ( - SimpleType, - Type, - TypeMap, - allowed_basic_type, - build_typemap, - decayed_expr_type, - get_decl_type, - resolve_typedefs, - same_type, - set_decl_name, - pointer_decay, -) - -# Set to true to perform expression type detection eagerly. This can help when -# debugging crashes in the ast_types code. -DEBUG_EAGER_TYPES = False - -# Randomize the type of introduced temporary variable with this probability -PROB_RANDOMIZE_TYPE = 0.3 - -# Reuse an existing var instead of introducing a new temporary one with this probability -PROB_REUSE_VAR = 0.5 - -# When wrapping statements in a new block, use a same-line `do { ... } while(0);` -# (as opposed to non-same-line `if (1) { ... }`) with this probability. -# This matches what macros often do. -PROB_INS_BLOCK_DOWHILE = 0.5 - -# Make a pointer to a temporary expression, rather than copy it by value, with -# this probability. (This always happens for expressions of struct type, -# regardless of this probability.) -PROB_TEMP_PTR = 0.05 - -# Instead of emitting an assignment statement, assign the temporary within the -# first expression it's used in with this probability. -PROB_TEMP_ASSIGN_AT_FIRST_USE = 0.1 - -# When creating a temporary for an expression, use the temporary for all equal -# expressions with this probability. -PROB_TEMP_REPLACE_ALL = 0.2 - -# When creating a temporary for an expression, use the temporary for an interval -# with maximal endpoint with this probability. -PROB_TEMP_REPLACE_MOST = 0.2 - -# When substituting a variable by its value, substitute all instances with this -# probability, rather than just a subrange or the complement of one. -PROB_EXPAND_REPLACE_ALL = 0.3 - -# When substituting a variable by its value, keep the variable assignment with -# this probability. -PROB_KEEP_REPLACED_VAR = 0.2 - -# Change the return type of an external function to void with this probability. -PROB_RET_VOID = 0.2 - -# Number larger than any node index. (If you're trying to compile a 1 GB large -# C file to matching asm, you have bigger problems than this limit.) -MAX_INDEX = 10 ** 9 - -T = TypeVar("T") - - -class RandomizationFailure(Exception): - pass - - -def ensure(condition: Any) -> None: - """Abort the randomization pass if 'condition' fails to hold, and try - another pass instead. Don't call this after making any modifications to - the AST.""" - if not condition: - raise RandomizationFailure - - -@dataclass -class Region: - start: int - end: int - indices: Optional[Indices] = field(compare=False) - - @staticmethod - def unbounded() -> "Region": - return Region(-1, MAX_INDEX, None) - - def is_unbounded(self) -> bool: - return self.indices is None - - def contains_node(self, node: ca.Node) -> bool: - """Check whether the region contains an entire node.""" - if self.indices is None: - return True - return ( - self.start < self.indices.starts[node] - and self.indices.ends[node] < self.end - ) - - def contains_pre(self, node: ca.Node) -> bool: - """Check whether the region contains a point just before a given node.""" - if self.indices is None: - return True - return self.start < self.indices.starts[node] < self.end - - def contains_pre_index(self, index: int) -> bool: - """Check whether the region contains a point just before a given node, - as specified by its index.""" - if self.indices is None: - return True - return self.start < index < self.end - - -def reverse_start_indices(indices: Indices) -> Dict[int, ca.Node]: - ret = {} - for k, v in indices.starts.items(): - ret[v] = k - return ret - - -def get_randomization_region( - top_node: ca.Node, indices: Indices, random: Random -) -> Region: - ret: List[Region] = [] - cur_start: Optional[int] = None - - class Visitor(ca.NodeVisitor): - def visit_Pragma(self, node: ca.Pragma) -> None: - nonlocal cur_start - if node.string == "_permuter randomizer start": - if cur_start is not None: - raise Exception("nested PERM_RANDOMIZE not supported") - cur_start = indices.ends[node] - if node.string == "_permuter randomizer end": - assert cur_start is not None, "randomizer end without start" - ret.append(Region(cur_start + 1, indices.starts[node] - 1, indices)) - cur_start = None - - Visitor().visit(top_node) - assert cur_start is None, "randomizer start without end" - if not ret: - return Region.unbounded() - return random.choice(ret) - - -def get_block_expressions(block: Block, region: Region) -> List[Expression]: - """Return a list of all expressions within a block that are also within a - given region.""" - exprs: List[Expression] = [] - - def visitor(expr: Expression) -> None: - if region.contains_node(expr): - exprs.append(expr) - - replace_subexprs(block, visitor) - return exprs - - -def compute_write_locations( - top_node: ca.Node, indices: Indices -) -> Dict[str, List[int]]: - writes: Dict[str, List[int]] = {} - - def add_write(var_name: str, loc: int) -> None: - if var_name not in writes: - writes[var_name] = [] - else: - assert ( - loc > writes[var_name][-1] - ), "consistent traversal order should guarantee monotonicity here" - writes[var_name].append(loc) - - class Visitor(ca.NodeVisitor): - def visit_Decl(self, node: ca.Decl) -> None: - if node.name: - add_write(node.name, indices.starts[node]) - self.generic_visit(node) - - def visit_UnaryOp(self, node: ca.UnaryOp) -> None: - if node.op in ["p++", "p--", "++", "--"] and isinstance(node.expr, ca.ID): - add_write(node.expr.name, indices.starts[node]) - self.generic_visit(node) - - def visit_Assignment(self, node: ca.Assignment) -> None: - if isinstance(node.lvalue, ca.ID): - add_write(node.lvalue.name, indices.starts[node]) - self.generic_visit(node) - - Visitor().visit(top_node) - return writes - - -def compute_read_locations(top_node: ca.Node, indices: Indices) -> Dict[str, List[int]]: - reads: Dict[str, List[int]] = {} - for node in find_var_reads(top_node): - var_name = node.name - loc = indices.starts[node] - if var_name not in reads: - reads[var_name] = [] - else: - assert ( - loc > reads[var_name][-1] - ), "consistent traversal order should guarantee monotonicity here" - reads[var_name].append(loc) - return reads - - -def find_var_reads(top_node: ca.Node) -> List[ca.ID]: - ret = [] - - class Visitor(ca.NodeVisitor): - def visit_Decl(self, node: ca.Decl) -> None: - if node.init: - self.visit(node.init) - - def visit_ID(self, node: ca.ID) -> None: - ret.append(node) - - def visit_UnaryOp(self, node: ca.UnaryOp) -> None: - if node.op == "&" and isinstance(node.expr, ca.ID): - return - self.generic_visit(node) - - def visit_StructRef(self, node: ca.StructRef) -> None: - self.visit(node.name) - - def visit_Assignment(self, node: ca.Assignment) -> None: - if isinstance(node.lvalue, ca.ID): - return - self.generic_visit(node) - - Visitor().visit(top_node) - return ret - - -def visit_replace(top_node: ca.Node, callback: Callable[[ca.Node, bool], Any]) -> None: - def empty_statement_to_none(node: Any) -> Any: - if isinstance(node, ca.EmptyStatement): - return None - return node - - def rec(orig_node: ca.Node, toplevel: bool = False, *, lvalue: bool = False) -> Any: - node: "ca.AnyNode" = typing.cast("ca.AnyNode", orig_node) - repl = callback(node, not toplevel and not lvalue) - if repl: - return repl - if isinstance(node, ca.Assignment): - node.lvalue = rec(node.lvalue, lvalue=True) - node.rvalue = rec(node.rvalue) - elif isinstance(node, ca.StructRef): - node.name = rec(node.name, lvalue=(lvalue and node.type == ".")) - elif isinstance(node, ca.Cast): - if node.expr: - node.expr = rec(node.expr) - elif isinstance(node, (ca.Constant, ca.ID)): - pass - elif isinstance(node, ca.UnaryOp): - if node.op in ["p++", "p--", "++", "--", "&"]: - node.expr = rec(node.expr, lvalue=True) - elif node.op != "sizeof": - node.expr = rec(node.expr) - elif isinstance(node, ca.BinaryOp): - node.left = rec(node.left) - node.right = rec(node.right) - elif isinstance(node, ca.FuncCall): - # not worth replacing .name - if node.args: - rec(node.args, True) - elif isinstance(node, ca.ExprList): - for i in range(len(node.exprs)): - if not isinstance(node.exprs[i], ca.Typename): - node.exprs[i] = rec(node.exprs[i]) - elif isinstance(node, ca.ArrayRef): - node.name = rec(node.name, lvalue=lvalue) - node.subscript = rec(node.subscript) - elif isinstance(node, ca.TernaryOp): - node.cond = rec(node.cond) - node.iftrue = rec(node.iftrue, True) - node.iffalse = rec(node.iffalse, True) - elif isinstance(node, ca.Return): - if node.expr: - node.expr = rec(node.expr) - elif isinstance(node, ca.Decl): - if node.init: - node.init = rec(node.init, isinstance(node.init, ca.InitList)) - elif isinstance(node, ca.For): - if node.init: - node.init = empty_statement_to_none(rec(node.init, True)) - if node.cond: - node.cond = rec(node.cond) - if node.next: - node.next = empty_statement_to_none(rec(node.next, True)) - node.stmt = rec(node.stmt, True) - elif isinstance(node, ca.Compound): - if node.block_items: - for i, sub in enumerate(node.block_items): - node.block_items[i] = rec(sub, True) - elif isinstance(node, (ca.Case, ca.Default)): - if node.stmts: - for i, sub in enumerate(node.stmts): - node.stmts[i] = rec(sub, True) - elif isinstance(node, ca.While): - node.cond = rec(node.cond) - node.stmt = rec(node.stmt, True) - elif isinstance(node, ca.DoWhile): - node.stmt = rec(node.stmt, True) - node.cond = rec(node.cond) - elif isinstance(node, ca.Switch): - node.cond = rec(node.cond) - node.stmt = rec(node.stmt, True) - elif isinstance(node, ca.Label): - node.stmt = rec(node.stmt, True) - elif isinstance(node, ca.If): - node.cond = rec(node.cond) - node.iftrue = rec(node.iftrue, True) - if node.iffalse: - node.iffalse = rec(node.iffalse, True) - elif isinstance( - node, - ( - ca.TypeDecl, - ca.PtrDecl, - ca.ArrayDecl, - ca.Typename, - ca.IdentifierType, - ca.Struct, - ca.Union, - ca.Enum, - ca.EmptyStatement, - ca.Pragma, - ca.Break, - ca.Continue, - ca.Goto, - ca.CompoundLiteral, - ca.Typedef, - ca.FuncDecl, - ca.FuncDef, - ca.EllipsisParam, - ca.Enumerator, - ca.EnumeratorList, - ca.FileAST, - ca.InitList, - ca.NamedInitializer, - ca.ParamList, - ), - ): - pass - else: - _: None = node - assert False, f"Node with unknown type: {node}" - return node - - rec(top_node, True) - - -def replace_subexprs(top_node: ca.Node, callback: Callable[[Expression], Any]) -> None: - def expr_filter(node: ca.Node, is_expr: bool) -> Any: - if not is_expr: - return None - return callback(typing.cast(Expression, node)) - - visit_replace(top_node, expr_filter) - - -def replace_node(top_node: ca.Node, old: ca.Node, new: ca.Node) -> None: - visit_replace(top_node, lambda node, _: new if node is old else None) - - -def random_bool(random: Random, prob: float) -> bool: - return random.random() < prob - - -def random_weighted(random: Random, values: Sequence[Tuple[T, float]]) -> T: - sumprob = 0.0 - for (val, prob) in values: - assert prob >= 0, "Probabilities must be non-negative" - sumprob += prob - assert sumprob > 0, "Cannot pick randomly from empty set" - targetprob = random.uniform(0, sumprob) - sumprob = 0.0 - for (val, prob) in values: - sumprob += prob - if sumprob > targetprob: - return val - - # Float imprecision - for (val, prob) in values: - if prob > 0: - return val - assert False, "unreachable" - - -def random_type(random: Random) -> SimpleType: - new_names: List[str] = [] - if random_bool(random, 0.5): - new_names.append("unsigned") - new_names.extend( - random_weighted( - random, - [ - (["char"], 1), - (["short"], 1), - (["int"], 2), - (["long"], 0.5), - (["long", "long"], 0.5), - ], - ) - ) - idtype = ca.IdentifierType(names=new_names) - quals = [] - if random_bool(random, 0.5): - quals = ["volatile"] - return ca.TypeDecl(declname=None, quals=quals, type=idtype) - - -def randomize_type( - type: SimpleType, typemap: TypeMap, random: Random, *, ensure_changed: bool = False -) -> SimpleType: - if allowed_basic_type( - type, typemap, ["int", "char", "long", "short", "signed", "unsigned"] - ): - return random_type(random) - if ensure_changed: - raise RandomizationFailure - return type - - -def randomize_innermost_type( - type: Type, typemap: TypeMap, random: Random, *, ensure_changed: bool = False -) -> Type: - if isinstance(type, ca.TypeDecl): - return randomize_type(type, typemap, random, ensure_changed=ensure_changed) - new_type = copy.copy(type) - new_type.type = randomize_innermost_type( - type.type, typemap, random, ensure_changed=ensure_changed - ) - return new_type - - -def get_insertion_points( - fn: ca.FuncDef, region: Region, *, allow_within_decl: bool = False -) -> List[Tuple[Block, int, Optional[ca.Node]]]: - cands: List[Tuple[Block, int, Optional[ca.Node]]] = [] - - def rec(block: Block) -> None: - stmts = ast_util.get_block_stmts(block, False) - last_node: ca.Node = block - for i, stmt in enumerate(stmts): - if region.contains_pre(stmt): - cands.append((block, i, stmt)) - ast_util.for_nested_blocks(stmt, rec) - last_node = stmt - if region.contains_node(last_node): - cands.append((block, len(stmts), None)) - - rec(fn.body) - if not allow_within_decl: - cands = [c for c in cands if not isinstance(c[2], ca.Decl)] - return cands - - -def maybe_reuse_var( - var: Optional[str], - assign_before: ca.Node, - orig_expr: Expression, - type: SimpleType, - reads: Dict[str, List[int]], - writes: Dict[str, List[int]], - indices: Indices, - typemap: TypeMap, - random: Random, -) -> Optional[str]: - if not random_bool(random, PROB_REUSE_VAR) or var is None: - return None - var_type: SimpleType = decayed_expr_type(ca.ID(var), typemap) - if not same_type(var_type, type, typemap, allow_similar=True): - return None - - def find_next(list: List[int], value: int) -> Optional[int]: - ind = bisect.bisect_left(list, value) - if ind < len(list): - return list[ind] - return None - - assignment_ind = indices.starts[assign_before] - expr_ind = indices.starts[orig_expr] - write = find_next(writes.get(var, []), assignment_ind) - read = find_next(reads.get(var, []), assignment_ind) - # TODO: if write/read is within expr, search again from after it (since - # we move expr, uses within it aren't relevant). - if read is not None and (write is None or write >= read): - # We don't want to overwrite a variable which we later read, - # unless we write to it before that read - return None - if write is not None and write < expr_ind: - # Our write will be overwritten before we manage to read from it. - return None - return var - - -def perm_temp_for_expr( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Create a temporary variable for a random expression. The variable will - be assigned at another random point (nearer the expression being more - likely), possibly reuse an existing variable, possibly be of a different - size/signedness, and possibly be used for other identical expressions as - well. Only expressions within the given region may be chosen for - replacement, but the assignment and the affected identical expressions may - be outside of it.""" - Place = Tuple[Block, int, Statement] - einds: Dict[ca.Node, int] = {} - writes: Dict[str, List[int]] = compute_write_locations(fn, indices) - reads: Dict[str, List[int]] = compute_read_locations(fn, indices) - typemap = build_typemap(ast) - candidates: List[Tuple[Tuple[Place, Expression, Optional[str]], float]] = [] - - # Step 0: decide whether to make a pointer to the chosen expression, or to - # copy it by value. - should_make_ptr = random_bool(random, PROB_TEMP_PTR) - - def surrounding_writes(expr: Expression, base: Expression) -> Tuple[int, int]: - """Compute the previous and next write to a variable included in expr, - starting from base. If none, default to -1 or MAX_INDEX respectively. - If base itself writes to an included variable (e.g. if it is an - increment expression), the \"next\" write will be defined as the node - itself, while the \"previous\" will continue searching to the left.""" - sub_reads = find_var_reads(expr) - prev_write = -1 - next_write = MAX_INDEX - base_index = indices.starts[base] - for sub_read in sub_reads: - var_name = sub_read.name - if var_name not in writes: - continue - # Find the first write that is strictly before indices[expr], - # and the first write that is on or after. - wr = writes[var_name] - ind = bisect.bisect_left(wr, base_index) - if ind > 0: - prev_write = max(prev_write, wr[ind - 1]) - if ind < len(wr): - next_write = min(next_write, wr[ind]) - return prev_write, next_write - - # Step 1: assign probabilities to each place/expression - def rec(block: Block, reuse_cands: List[str]) -> None: - stmts = ast_util.get_block_stmts(block, False) - reuse_cands = reuse_cands[:] - assignment_cands: List[Place] = [] # places to insert before - past_decls = False - for index, stmt in enumerate(stmts): - if isinstance(stmt, ca.Decl): - assert stmt.name, "Anonymous declarations cannot happen in functions" - if not isinstance(stmt.type, ca.ArrayDecl): - reuse_cands.append(stmt.name) - if not isinstance(stmt.type, ca.PtrDecl): - # Make non-pointers more common - reuse_cands.append(stmt.name) - elif not isinstance(stmt, ca.Pragma): - past_decls = True - if past_decls: - assignment_cands.append((block, index, stmt)) - - ast_util.for_nested_blocks(stmt, lambda b: rec(b, reuse_cands)) - - def visitor(expr: Expression) -> None: - if DEBUG_EAGER_TYPES: - decayed_expr_type(expr, typemap) - - if not region.contains_node(expr): - return - - orig_expr = expr - if should_make_ptr: - if not ast_util.is_lvalue(expr): - return - expr = ca.UnaryOp("&", expr) - - eind = einds.get(expr, 0) - prev_write, _ = surrounding_writes(expr, orig_expr) - - for place in assignment_cands[::-1]: - # If expr contains an ID which is written to within - # [place, expr), bail out; we're trying to move the - # assignment too high up. - # TODO: also fail on moving past function calls, or - # possibly-aliasing writes. - if indices.starts[place[2]] <= prev_write: - break - - # Make far-away places less likely, and similarly for - # trivial expressions. - eind += 1 - prob = 1 / eind - if isinstance(orig_expr, (ca.ID, ca.Constant)): - prob *= 0.15 if should_make_ptr else 0.5 - reuse_cand = random.choice(reuse_cands) if reuse_cands else None - candidates.append(((place, expr, reuse_cand), prob)) - - einds[expr] = eind - - replace_subexprs(stmt, visitor) - - rec(fn.body, []) - - # Step 2: decide on a place/expression - ensure(candidates) - place: Optional[Place] - place, expr, reuse_cand = random_weighted(random, candidates) - - if random_bool(random, PROB_TEMP_ASSIGN_AT_FIRST_USE): - # Don't emit a statement for the assignment, emit an assignment - # expression at the first use instead. - place = None - - type: SimpleType = decayed_expr_type(expr, typemap) - - # Always use pointers when replacing structs - if ( - not should_make_ptr - and isinstance(type, ca.TypeDecl) - and isinstance(type.type, (ca.Struct, ca.Union)) - and ast_util.is_lvalue(expr) - ): - should_make_ptr = True - expr = ca.UnaryOp("&", expr) - type = decayed_expr_type(expr, typemap) - - if should_make_ptr: - assert isinstance(expr, ca.UnaryOp) - assert not isinstance(expr.expr, ca.Typename) - orig_expr = expr.expr - else: - orig_expr = expr - # print("replacing:", to_c(expr)) - - # Step 3: decide on a variable to hold the expression - if place is not None: - assign_before = place[2] - else: - assign_before = orig_expr - reused_var = maybe_reuse_var( - reuse_cand, - assign_before, - orig_expr, - type, - reads, - writes, - indices, - typemap, - random, - ) - if reused_var is not None: - reused = True - var = reused_var - else: - reused = False - var = "new_var" - counter = 1 - while var in writes: - counter += 1 - var = f"new_var{counter}" - - # Step 4: possibly expand the replacement to include duplicate expressions. - prev_write, next_write = surrounding_writes(expr, orig_expr) - prev_write = max(prev_write, indices.starts[assign_before] - 1) - replace_cands: List[Expression] = [] - - def find_duplicates(e: Expression) -> None: - if prev_write < indices.starts[e] <= next_write and ast_util.equal_ast( - e, orig_expr - ): - replace_cands.append(e) - - if ast_util.is_effectful(expr): - replace_cands = [orig_expr] - else: - replace_subexprs(fn.body, find_duplicates) - - assert orig_expr in replace_cands - if random_bool(random, PROB_TEMP_REPLACE_ALL): - lo_index = 0 - hi_index = len(replace_cands) - else: - index = replace_cands.index(orig_expr) - lo_index = random.randint(0, index) - hi_index = random.randint(index + 1, len(replace_cands)) - if random_bool(random, PROB_TEMP_REPLACE_MOST): - if random_bool(random, 0.5): - lo_index = 0 - else: - hi_index = len(replace_cands) - replace_cand_set = set(replace_cands[lo_index:hi_index]) - - # Step 5: replace the chosen expression - def replacer(e: Expression) -> Optional[Expression]: - if e in replace_cand_set: - ret: Expression = ca.ID(var) - if place is None and e is orig_expr: - ret = ca.Assignment("=", ret, expr) - if should_make_ptr: - ret = ca.UnaryOp("*", ret) - return ret - return None - - replace_subexprs(fn.body, replacer) - - # Step 6: insert the assignment and any new variable declaration - if place is not None: - block, index, _ = place - assignment = ca.Assignment("=", ca.ID(var), expr) - ast_util.insert_statement(block, index, assignment) - if not reused: - if random_bool(random, PROB_RANDOMIZE_TYPE): - type = randomize_type(type, typemap, random) - ast_util.insert_decl(fn, var, type, random) - - -def perm_expand_expr( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Replace a random variable by its contents.""" - all_writes: Dict[str, List[int]] = compute_write_locations(fn, indices) - all_reads: Dict[str, List[int]] = compute_read_locations(fn, indices) - - # Step 1: pick out a variable to replace - rev: Dict[int, str] = {} - for var, locs in all_reads.items(): - for index in locs: - if region.contains_pre_index(index): - rev[index] = var - ensure(rev) - index = random.choice(list(rev.keys())) - var = rev[index] - - # Step 2: find the assignment it uses - reads = all_reads[var] - writes = all_writes.get(var, []) - i = bisect.bisect_left(writes, index) - # if i == 0, there is no write to replace the read by. - ensure(i > 0) - before = writes[i - 1] - after = MAX_INDEX if i == len(writes) else writes[i] - rev_indices = reverse_start_indices(indices) - write = rev_indices[before] - if ( - isinstance(write, ca.Decl) - and write.init - and not isinstance(write.init, ca.InitList) - ): - repl_expr = write.init - elif isinstance(write, ca.Assignment) and write.op == "=": - repl_expr = write.rvalue - else: - raise RandomizationFailure - - # Step 3: pick of the range of variables to replace - repl_cands = [ - i for i in reads if before < i < after and region.contains_pre_index(i) - ] - assert repl_cands, "index is always in repl_cands" - myi = repl_cands.index(index) - if not random_bool(random, PROB_EXPAND_REPLACE_ALL) and len(repl_cands) > 1: - # Keep using the variable for a bit in the middle - side = random.randrange(3) - H = len(repl_cands) - loi = 0 if side == 0 else random.randint(0, myi) - hii = H if side == 1 else random.randint(myi + 1, H) - if loi == 0 and hii == H: - loi, hii = myi, myi + 1 - repl_cands[loi:hii] = [] - keep_var = True - else: - keep_var = random_bool(random, PROB_KEEP_REPLACED_VAR) - repl_cands_set = set(repl_cands) - - # Don't duplicate effectful expressions. - if ast_util.is_effectful(repl_expr): - ensure(len(repl_cands) == 1 and not keep_var) - - # Step 4: do the replacement - def callback(expr: ca.Node, is_expr: bool) -> Optional[ca.Node]: - if indices.starts[expr] in repl_cands_set: - return copy.deepcopy(repl_expr) - if expr == write and isinstance(write, ca.Assignment) and not keep_var: - if is_expr: - return write.lvalue - else: - return ca.EmptyStatement() - return None - - visit_replace(fn.body, callback) - if not keep_var and isinstance(write, ca.Decl): - write.init = None - - -def perm_randomize_internal_type( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Randomize types of pre-existing local variables. Function parameters - are not included -- those are handled by perm_randomize_function_type. - Only variables mentioned within the given region are affected.""" - names: Set[str] = set() - - class IdVisitor(ca.NodeVisitor): - def visit_ID(self, node: ca.ID) -> None: - if region.contains_node(node): - names.add(node.name) - - def visit_StructRef(self, node: ca.StructRef) -> None: - self.visit(node.name) - - IdVisitor().visit(fn) - - typemap = build_typemap(ast) - decls: List[ca.Decl] = [] - - class Visitor(ca.NodeVisitor): - def visit_Decl(self, decl: ca.Decl) -> None: - if isinstance(decl.type, ca.TypeDecl) and decl.name and decl.name in names: - decls.append(decl) - self.generic_visit(decl) - - Visitor().visit(fn) - - ensure(decls) - decl = random.choice(decls) - assert isinstance(decl.type, ca.TypeDecl), "checked above" - decl.type = randomize_type(decl.type, typemap, random, ensure_changed=True) - set_decl_name(decl) - - -def perm_randomize_external_type( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Randomize types of global variables. Only variables mentioned within the - given region are affected.""" - names: Set[str] = set() - - class IdVisitor(ca.NodeVisitor): - def visit_ID(self, node: ca.ID) -> None: - if region.contains_node(node): - names.add(node.name) - - def visit_StructRef(self, node: ca.StructRef) -> None: - self.visit(node.name) - - IdVisitor().visit(fn) - - ensure(names) - name = random.choice(list(names)) - decls: List[Tuple[ca.Decl, int]] = [] - - for i in range(len(ast.ext)): - item = ast.ext[i] - if isinstance(item, ca.Decl) and item.name == name: - new_decl = copy.copy(item) - decls.append((new_decl, i)) - - ensure(decls) - decl = random.choice(decls)[0] - decl_type = get_decl_type(decl) - - typemap = build_typemap(ast) - new_type = randomize_innermost_type(decl_type, typemap, random, ensure_changed=True) - - for decl, i in decls: - decl.type = copy.deepcopy(new_type) - ast.ext[i] = decl - set_decl_name(decl) - - -def perm_randomize_function_type( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Randomize types of function parameters and returns. Only functions - called within the given region are affected, plus the current function.""" - assert fn.decl.name is not None, "function definitions have names" - names: Set[str] = {fn.decl.name} - - class IdVisitor(ca.NodeVisitor): - def visit_FuncCall(self, node: ca.FuncCall) -> None: - if region.contains_node(node) and isinstance(node.name, ca.ID): - names.add(node.name.name) - self.generic_visit(node) - - IdVisitor().visit(fn) - - name = random.choice(list(names)) - - # Find the declarations of function with the given name. For performance - # reasons, the part of the AST they live in are shared between all - # randomization runs, so if we mutated them in place bad things would - # happen. Thus, we replace the AST parts we plan to change with mutable - # copies. - all_decls: List[Tuple[ca.Decl, int, "ca.ExternalDeclaration"]] = [] - main_decl: Optional[ca.Decl] = None - for i in range(len(ast.ext)): - item = ast.ext[i] - if ( - isinstance(item, ca.Decl) - and isinstance(item.type, ca.FuncDecl) - and item.name == name - ): - new_decl = copy.copy(item) - ast.ext[i] = new_decl - all_decls.append((new_decl, i, new_decl)) - if isinstance(item, ca.FuncDef) and item.decl.name == name: - assert isinstance( - item.decl.type, ca.FuncDecl - ), "function definitions have function types" - new_fndef = copy.copy(item) - new_decl = copy.copy(item.decl) - new_fndef.decl = new_decl - ast.ext[i] = new_fndef - all_decls.append((new_decl, i, new_fndef)) - main_decl = new_decl - - # Change the type within the function definition if there is one (since we - # need to keep names there), or else within an arbitrary of the (typically - # just one) declarations. We later mirror the change to all declarations. - ensure(all_decls) - if not main_decl: - main_decl = random.choice(all_decls)[0] - - typemap = build_typemap(ast) - - main_fndecl = copy.deepcopy(main_decl.type) - assert isinstance(main_fndecl, ca.FuncDecl), "checked above" - main_decl.type = main_fndecl - - if random_bool(random, 0.5): - # Replace the return type, changing integer signedness/size as well as - # switching to/from void (which we should perhaps avoid if the function - # call result is used, but eh, it's annoying to tell). - type = pointer_decay(main_fndecl.type, typemap) - if allowed_basic_type(type, typemap, ["void"]): - main_fndecl.type = random_type(random) - elif random_bool(random, PROB_RET_VOID): - idtype = ca.IdentifierType(names=["void"]) - main_fndecl.type = ca.TypeDecl(declname=None, quals=[], type=idtype) - else: - main_fndecl.type = randomize_type( - type, typemap, random, ensure_changed=True - ) - set_decl_name(main_decl) - else: - # Replace a parameter, changing integer signedness/size. - if not main_fndecl.args: - raise RandomizationFailure - ensure(main_fndecl.args.params) - ind = random.randrange(len(main_fndecl.args.params)) - arg = main_fndecl.args.params[ind] - if isinstance(arg, (ca.ID, ca.EllipsisParam)): - raise RandomizationFailure - arg_type = arg.type if isinstance(arg, ca.Typename) else get_decl_type(arg) - type = pointer_decay(arg_type, typemap) - arg.type = randomize_type(type, typemap, random, ensure_changed=True) - if isinstance(arg, ca.Decl): - set_decl_name(arg) - - # Commit the changes by writing them back to the AST, for all declarations. - for i in range(len(all_decls)): - decl, ind, new_node = all_decls[i] - ast.ext[ind] = new_node - if decl is not main_decl: - decl.type = copy.deepcopy(main_decl.type) - - -def perm_refer_to_var( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Add `if (variable) {}` or `if (struct.member) {}` in a random place. - This will get optimized away but may affect regalloc.""" - # Find expression to insert, searching within the randomization region. - cands: List[Expression] = [ - expr - for expr in get_block_expressions(fn.body, region) - if isinstance(expr, (ca.StructRef, ca.ID)) - ] - ensure(cands) - expr = random.choice(cands) - ensure(not ast_util.is_effectful(expr)) - typemap = build_typemap(ast) - type: Type = resolve_typedefs(decayed_expr_type(expr, typemap), typemap) - if isinstance(type, ca.TypeDecl) and isinstance(type.type, (ca.Struct, ca.Union)): - expr = ca.UnaryOp("&", expr) - - if random_bool(random, 0.5): - expr = ca.UnaryOp("!", expr) - - # Insert it wherever -- possibly outside the randomization region, since regalloc - # can act at a distance. (Except before a declaration.) - ins_cands = get_insertion_points(fn, Region.unbounded()) - ensure(ins_cands) - - cond = copy.deepcopy(expr) - - # Repeat the condition up to two times: if (x && x && x) {} sometimes helps. - for i in range(random.choice((0, 0, 0, 0, 0, 1, 2, 2))): - cond = ca.BinaryOp("&&", cond, copy.deepcopy(expr)) - - stmt = ca.If(cond=cond, iftrue=ca.Compound(block_items=[]), iffalse=None) - tob, toi, _ = random.choice(ins_cands) - ast_util.insert_statement(tob, toi, stmt) - - -def perm_ins_block( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Wrap a random range of statements within `if (1) { ... }` or - `do { ... } while(0)`. Control flow can have remote effects, so this - mostly ignores the region restriction.""" - cands: List[Block] = [] - - def rec(block: Block) -> None: - cands.append(block) - for stmt in ast_util.get_block_stmts(block, False): - ast_util.for_nested_blocks(stmt, rec) - - rec(fn.body) - block = random.choice(cands) - stmts = ast_util.get_block_stmts(block, True) - decl_count = 0 - for stmt in stmts: - if isinstance(stmt, (ca.Decl, ca.Pragma)): - decl_count += 1 - else: - break - lo = random.randrange(decl_count, len(stmts) + 1) - hi = random.randrange(decl_count, len(stmts) + 1) - if hi < lo: - lo, hi = hi, lo - new_block = ca.Compound(block_items=stmts[lo:hi]) - if random_bool(random, PROB_INS_BLOCK_DOWHILE) and all( - region.contains_node(n) for n in stmts[lo:hi] - ): - cond = ca.Constant(type="int", value="0") - stmts[lo:hi] = [ - ca.Pragma("_permuter sameline start"), - ca.DoWhile(cond=cond, stmt=new_block), - ca.Pragma("_permuter sameline end"), - ] - else: - cond = ca.Constant(type="int", value="1") - stmts[lo:hi] = [ca.If(cond=cond, iftrue=new_block, iffalse=None)] - - -def perm_empty_stmt( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Inserts a no-op statement, one of: - - if (1) {} (sometimes multiple of them) - - if (0) {} - - label: - - goto label; label:; - - ; - Control flow can have remote effects, so this - ignores the region restriction.""" - - # Insert the statement wherever, except before a declaration. - cands = get_insertion_points(fn, Region.unbounded()) - ensure(cands) - - label_name = f"dummy_label_{random.randint(1, 10**6)}" - - stmts: List[Statement] = [] - - kind = random.randrange(5) - if kind == 0: # if (1) or multiple if (1) - count = random.choice([1, random.randint(2, 6)]) - for _ in range(count): - cond = ca.Constant(type="int", value="1") - stmts.append(ca.If(cond=cond, iftrue=ca.Compound([]), iffalse=None)) - elif kind == 1: # if (0) - cond = ca.Constant(type="int", value="0") - stmts = [ca.If(cond=cond, iftrue=ca.Compound([]), iffalse=None)] - elif kind == 2: # label: - stmts = [ca.Label(label_name, ca.EmptyStatement())] - pass - elif kind == 3: # goto label; label: - stmts = [ - ca.Goto(label_name), - ca.Label(label_name, ca.EmptyStatement()), - ] - elif kind == 4: # ; - stmts = [ca.EmptyStatement()] - - tob, toi, _ = random.choice(cands) - stmts.insert(0, ca.Pragma("_permuter sameline start")) - stmts.append(ca.Pragma("_permuter sameline end")) - for stmt in stmts[::-1]: - ast_util.insert_statement(tob, toi, stmt) - - -def perm_sameline( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Put all statements within a random interval on the same line.""" - cands = get_insertion_points(fn, region) - n = len(cands) - ensure(n >= 3) - # Generate a small random interval - lef: float = n - 2 - for i in range(4): - lef *= random.uniform(0, 1) - le = int(lef) + 2 - i = random.randrange(n - le) - j = i + le - # Insert the second statement first, since inserting a statement may cause - # later indices to move. - ast_util.insert_statement( - cands[j][0], cands[j][1], ca.Pragma("_permuter sameline end") - ) - ast_util.insert_statement( - cands[i][0], cands[i][1], ca.Pragma("_permuter sameline start") - ) - - -def perm_associative( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Change a+b into b+a, or similar for other commutative operations.""" - cands: List[ca.BinaryOp] = [] - commutative_ops = list("+*|&^<>") + ["<=", ">=", "==", "!="] - - class Visitor(ca.NodeVisitor): - def visit_BinaryOp(self, node: ca.BinaryOp) -> None: - if node.op in commutative_ops and region.contains_node(node): - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - node = random.choice(cands) - node.left, node.right = node.right, node.left - if node.op[0] == "<": - node.op = ">" + node.op[1:] - elif node.op[0] == ">": - node.op = "<" + node.op[1:] - - -def perm_condition( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Change if(x) into if(x != 0), or vice versa. Also handles for/while/do-while.""" - cands: List[Union[ca.If, ca.While, ca.DoWhile, ca.For]] = [] - - class Visitor(ca.NodeVisitor): - def visit_If(self, node: ca.If) -> None: - cands.append(node) - self.generic_visit(node) - - def visit_While(self, node: ca.While) -> None: - cands.append(node) - self.generic_visit(node) - - def visit_DoWhile(self, node: ca.DoWhile) -> None: - cands.append(node) - self.generic_visit(node) - - def visit_For(self, node: ca.For) -> None: - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - node = random.choice(cands) - if not node.cond: - raise RandomizationFailure - - if ( - isinstance(node.cond, ca.BinaryOp) - and node.cond.op in ["==", "!=", "<", ">", "<=", ">="] - and random_bool(random, 0.9) - ): - ensure(node.cond.op in ["==", "!="]) - ensure( - isinstance(node.cond.right, ca.Constant) - and node.cond.right.value in ["0", "0U", "0.0", "0.0f"] - ) - if node.cond.op == "==": - node.cond = ca.UnaryOp("!", node.cond.left) - else: - node.cond = node.cond.left - else: - expr = node.cond - op = "!=" - if isinstance(expr, ca.UnaryOp) and expr.op == "!" and random_bool(random, 0.9): - assert not isinstance(expr.expr, ca.Typename) - expr = expr.expr - op = "==" - zero = random_weighted( - random, - [ - (ca.Constant("int", "0"), 0.8), - (ca.Constant("unsigned int", "0U"), 0.2), - (ca.Constant("float", "0.0f"), 0.05), - ], - ) - node.cond = ca.BinaryOp(op, expr, zero) - - -def perm_add_self_assignment( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Introduce a "x = x;" somewhere.""" - cands = get_insertion_points(fn, region) - vars: List[str] = [] - - class Visitor(ca.NodeVisitor): - def visit_Decl(self, decl: ca.Decl) -> None: - if decl.name: - vars.append(decl.name) - self.generic_visit(decl) - - Visitor().visit(fn.body) - ensure(vars) - ensure(cands) - var = random.choice(vars) - where = random.choice(cands) - assignment = ca.Assignment("=", ca.ID(var), ca.ID(var)) - ast_util.insert_statement(where[0], where[1], assignment) - - -def perm_dummy_comma_expr( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Change x into (0, x) for a random expression x.""" - cands = get_block_expressions(fn.body, region) - ensure(cands) - expr = random.choice(cands) - new_expr = ca.ExprList([ca.Constant("int", "0"), expr]) - replace_node(fn.body, expr, new_expr) - - -def perm_reorder_stmts( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Move a statement to another random place.""" - cands = get_insertion_points(fn, region, allow_within_decl=True) - - # Figure out candidate statements to be moved. Don't move pragmas; it can - # cause assertion failures. Don't move blocks; statements are generally not - # reordered across basic blocks, and we don't want to risk moving a block - # to inside itself. - source_inds = [] - for i, c in enumerate(cands): - stmt = c[2] - if ( - stmt is not None - and not isinstance(stmt, ca.Pragma) - and not ast_util.has_nested_block(stmt) - ): - source_inds.append(i) - - ensure(source_inds) - fromi = random.choice(source_inds) - - weighted_cands = [] - for i in range(len(cands)): - dist = max(fromi - i, i - (fromi + 1)) - if dist == 0: - continue - # Move distance 1, 2, 3, ... with probabilities - # 23%, 12%, 8%, 6%, 4%, 3%, 3%, 2%, 2%, 2%, ... - prob = (dist + 1) ** -1.5 - weighted_cands.append((i, prob)) - ensure(weighted_cands) - toi = random_weighted(random, weighted_cands) - - fromb, fromi, from_stmt = cands[fromi] - tob, toi, to_stmt = cands[toi] - - if fromb == tob: - ensure(toi != fromi and toi != fromi + 1) - - if isinstance(from_stmt, ca.Decl): - # Moving a declaration is tricky, when also preserving C89 compatibility. - # We can move it to after another declaration, or to the start of a block. - # Alternatively, if the declaration includes an initializer, and we move - # it forwards, we can split that out as an assignment. - # We don't allow moving the declaration or assignment past the next - # occurrence of the variable. - ensure(from_stmt.name) - var_name = from_stmt.name - to_index = indices.starts[to_stmt] if to_stmt else indices.ends[fromb] - uses = 0 - - class Visitor(ca.NodeVisitor): - def visit_ID(self, node: ca.ID) -> None: - nonlocal uses - if node.name == var_name and indices.starts[node] < to_index: - uses += 1 - - def visit_TypeDecl(self, node: ca.TypeDecl) -> None: - nonlocal uses - if node.declname == var_name and indices.starts[node] < to_index: - uses += 1 - - Visitor().visit(fn.body) - ensure(uses <= 1) - - to_block_stmts = ast_util.get_block_stmts(tob, False) - if toi == 0 or isinstance(to_block_stmts[toi - 1], ca.Decl): - # Fine to move - pass - elif ( - from_stmt.name - and from_stmt.init - and not isinstance(from_stmt.init, ca.InitList) - and uses > 0 - ): - assignment = ca.Assignment("=", ca.ID(from_stmt.name), from_stmt.init) - ast_util.insert_statement(tob, toi, assignment) - from_stmt.init = None - return - else: - raise RandomizationFailure - else: - # Don't put statements before declarations. - ensure(not isinstance(to_stmt, ca.Decl)) - - if fromb == tob and fromi < toi: - toi -= 1 - - stmt = ast_util.get_block_stmts(fromb, True).pop(fromi) - ast_util.insert_statement(tob, toi, stmt) - - -def perm_compound_assignment( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Convert a statement of the form `x = x op y` to `x op= y`, or vice versa.""" - cands: List[ca.Assignment] = [] - operators = ["+", "-", "*", "/", "<<", ">>", "^", "|", "&"] - - class Visitor(ca.NodeVisitor): - def visit_Assignment(self, node: ca.Assignment) -> None: - if region.contains_node(node): - if node.op != "=" or ( - isinstance(node.rvalue, ca.BinaryOp) - and ast_util.equal_ast(node.lvalue, node.rvalue.left) - and node.rvalue.op in operators - ): - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - node = random.choice(cands) - - if node.op == "=": - assert isinstance(node.rvalue, ca.BinaryOp) - node.op = node.rvalue.op + node.op - node.rvalue = node.rvalue.right - else: - operator = node.op[:-1] - node.op = "=" - node.rvalue = ca.BinaryOp(operator, copy.deepcopy(node.lvalue), node.rvalue) - - -def perm_inequalities( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Adjusts inequalities to equivalent versions that sometimes produce different code. - For example, a > b and a >= b + 1, a < b to a <= b - 1 (and vice versa)""" - cands: List[ca.BinaryOp] = [] - inequalities = ["<", ">", "<=", ">="] - - class Visitor(ca.NodeVisitor): - def visit_BinaryOp(self, node: ca.BinaryOp) -> None: - if node.op in inequalities and region.contains_node(node): - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - - node = random.choice(cands) - - # Does not simplify, 'a <= (b + 1)' becomes 'a < ((b + 1) + 1)' - - def plus1(node: ca.Node) -> ca.BinaryOp: - return ca.BinaryOp("+", node, ca.Constant("int", "1")) - - def minus1(node: ca.Node) -> ca.BinaryOp: - return ca.BinaryOp("-", node, ca.Constant("int", "1")) - - # Don't change the operator, change both operands (can produce fake matches sometimes) - # Ex: a > b -> a + 1 > b + 1 - if random.random() < 0.25: - change = random.choice([plus1, minus1]) - node.left = change(node.left) - node.right = change(node.right) - - else: - if node.op in ["<", ">="]: - node.op = {"<": "<=", ">=": ">"}[node.op] - if random_bool(random, 0.5): - node.left = plus1(node.left) - else: - node.right = minus1(node.right) - else: - node.op = {">": ">=", "<=": "<"}[node.op] - if random_bool(random, 0.5): - node.left = minus1(node.left) - else: - node.right = plus1(node.right) - - -def perm_add_mask( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Add a random amount of masks of 0xFF[FFFFFFFFFFFFFF] to a random expression of integer type. - In some cases this mask is optimized out but affects regalloc. - The regalloc change seems to cycle with slight differences every n masks.""" - typemap = build_typemap(ast) - - # Find expression to add the mask to - cands: List[Expression] = get_block_expressions(fn.body, region) - ensure(cands) - - expr = random.choice(cands) - type: SimpleType = decayed_expr_type(expr, typemap) - ensure( - allowed_basic_type( - type, typemap, ["int", "char", "long", "short", "signed", "unsigned"] - ) - ) - - # Mask as if restricting the value to 8, 16, 32, or 64-bit width. - # Sometimes use an unsigned mask like '0xFFu' - masks: List[str] = ["0xFF", "0xFFFF", "0xFFFFFFFF", "0xFFFFFFFFFFFFFFFF"] - mask = random.choice(masks) + random.choice(["", "u"]) - - new_expr = ca.BinaryOp("&", expr, ca.Constant("int", mask)) - if random_bool(random, 0.3): - for _ in range(random.randrange(12)): - new_expr = ca.BinaryOp("&", new_expr, ca.Constant("int", mask)) - - replace_node(fn.body, expr, new_expr) - - -def perm_xor_zero( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Add ^0 to a random expression of integer type, or *1 to floats.""" - typemap = build_typemap(ast) - - # Find a random expression - cands: List[Expression] = get_block_expressions(fn.body, region) - ensure(cands) - - expr = random.choice(cands) - type: SimpleType = decayed_expr_type(expr, typemap) - int_types = ["int", "char", "long", "short", "signed", "unsigned"] - - if allowed_basic_type(type, typemap, int_types): - new_expr = ca.BinaryOp("^", expr, ca.Constant("int", "0")) - elif allowed_basic_type(type, typemap, ["float"]): - new_expr = ca.BinaryOp("*", expr, ca.Constant("float", "1.0f")) - elif allowed_basic_type(type, typemap, ["double"]): - new_expr = ca.BinaryOp("*", expr, ca.Constant("double", "1.0")) - else: - raise RandomizationFailure - - replace_node(fn.body, expr, new_expr) - - -def perm_float_literal( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Converts a Float Literal""" - cands: List[ca.Constant] = [] - - class Visitor(ca.NodeVisitor): - def visit_Constant(self, node: ca.Constant) -> None: - if node.type == "float" and region.contains_node(node): - cands.append(node) - - Visitor().visit(fn.body) - ensure(cands) - - node = random.choice(cands) - - value: str = node.value.lower() - choices: List[str] = [value[:-1]] - if value.endswith(".0f"): - choices.append(value[:-3] or "0") - elif value.endswith(".f"): - choices.append(value[:-2] or "0") - if value.startswith("0."): - choices.append("." + (value[2:] or "0")) - elif value.startswith("."): - choices.append("0" + value) - if value.endswith(".0f"): - choices.append((value[:-3] or "0") + ".f") - else: - choices.append(value[:-1] + "0f") - - ensure(choices) - value = random.choice(choices) - if value.endswith("f"): - type = "float" - elif "." in value: - type = "double" - else: - type = "int" - - replace_node(fn.body, node, ca.Constant(type, value)) - - -def perm_cast_simple( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Cast a random expression to a simple type (integral or floating point only).""" - typemap = build_typemap(ast) - - # Find a random expression - cands: List[Expression] = get_block_expressions(fn.body, region) - ensure(cands) - - expr = random.choice(cands) - type: SimpleType = decayed_expr_type(expr, typemap) - ensure( - allowed_basic_type( - type, - typemap, - ["int", "char", "long", "short", "signed", "unsigned", "float", "double"], - ) - ) - - integral_type = [["int"], ["char"], ["long"], ["short"], ["long", "long"]] - floating_type = [["float"], ["double"]] - new_type: List[str] - if random_bool(random, 0.5): - # Cast to integral type, sometimes unsigned - sign: List[str] = random.choice([[], ["unsigned"]]) - new_type = sign + random.choice(integral_type) - else: - # Cast to floating point type - new_type = random.choice(floating_type) - - # Surround the original expression with a cast to the chosen type - typedecl = ca.TypeDecl(None, [], ca.IdentifierType(new_type)) - new_expr = ca.Cast(ca.Typename(None, [], typedecl), expr) - replace_node(fn.body, expr, new_expr) - - -# struct_ref # type of a # easiest conversion -################################################################ -# (a + b).c; # impossible # -# (a + b)->c; # s* # a[b].c -# (*(a + b)).c; # s* # a[b].c -# (*(a + b))->c; # s** # (*(a[b]).c -# (&(a + b)).c; # impossible # -# (&(a + b))->c; # impossible # -# (*(&(a + b))).c; # impossible # -# (*(&(a + b)))->c; # imp: a+b=rvalue # -# (&(*(a + b))).c; # impossible # -# (&(*(a + b)))->c; # s* # a[b].c (-&* req.) -################################################################ -# (a[b]).c; # s* # (a + b)->c -# (a[b])->c; # s** # (*(a + b))->c -# (*(a[b])).c; # s** # (*(a + b))->c -# (*(a[b]))->c; # s*** # (*(*(a + b)))->c -# (&(a[b])).c; # impossible # -# (&(a[b]))->c; # s* # (&(*(a + b)))->c -# (*(&(a[b]))).c; # s* # (*(&(a + b)))->c -# (*(&(a[b])))->c; # s** # (*(&(*(a + b))))->c -# (&(*(a[b]))).c; # impossible # -# (&(*(a[b])))->c; # s** # (&(*(*(a + b))))->c -################################################################ -# a.c # s # (&a)->c -# a->c # s* # (*a).c -# (*a).c # s* # a->c -# (*a)->c # s** # (*(*a)).c -# (&a).c # impossible # -# (&a)->c # s # (*(&a)).c -def perm_struct_ref( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Permute struct references: (a + b)->c, and (*(a + b)).c, a[b].c, (&a[b])->c""" - cands: List[ca.StructRef] = [] - - class Visitor(ca.NodeVisitor): - def visit_StructRef(self, node: ca.StructRef) -> None: - if region.contains_node(node): - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - - # TODO: Split into separate perm? Need a separate one for arrayrefs, (a + b)[1] to a[b + 1] - def randomize_associative_binop(left: ca.Node, right: ca.BinaryOp) -> ca.BinaryOp: - """Try moving parentheses to the left side sometimes (sadly, it seems to matter)""" - if random_bool(random, 0.5) and right.op in ["+", "-"]: - # ((a + b) - c) - return ca.BinaryOp( - right.op, ca.BinaryOp("+", left, right.left), right.right - ) - else: - # (a + (b - c)) - return ca.BinaryOp("+", left, right) - - # Conversions - def to_array(node: ca.BinaryOp) -> ca.ArrayRef: - """Change a BinaryOp, a + b, to an ArrayRef, a[b] - The operator is expected to be + or -""" - # TODO: Permute binops like to_binop() does - if node.op == "-": - # Convert to a[-b] - node.right = ca.UnaryOp("-", node.right) - return ca.ArrayRef(node.left, node.right) - - def to_binop(node: ca.ArrayRef) -> ca.BinaryOp: - """Change an ArrayRef, a[b], to a BinaryOp, a + b - If b is also BinaryOp, such as a[b - 1], sometimes change the order of operations, - ie: a + (b - 1) vs (a + b) - 1""" - if isinstance(node.subscript, ca.BinaryOp): - return randomize_associative_binop(node.name, node.subscript) - return ca.BinaryOp("+", node.name, node.subscript) - - def deref(node: Expression) -> Expression: - """Surround the given node with a dereference operator""" - if isinstance(node, ca.UnaryOp) and node.op == "&": - assert not isinstance(node.expr, ca.Typename) - return node.expr - return ca.UnaryOp("*", node) - - def addr(node: Expression) -> Expression: - """Surround the given node with an address-of operator""" - if isinstance(node, ca.UnaryOp) and node.op == "*": - assert not isinstance(node.expr, ca.Typename) - return node.expr - return ca.UnaryOp("&", node) - - def rec(node: ca.Node) -> Any: - """Recurse down the StructRef tree, finding the parent of the leaf BinaryOp/ArrayRef. - Throws RandomizationFailure when a UnaryOp other than * or & was encountered.""" - if isinstance(node, ca.UnaryOp): - ensure(node.op in ["&", "*"]) - return rec(node.expr) or node - if isinstance(node, ca.StructRef): - return rec(node.name) or node - return None - - # TODO - def apply_child( # type: ignore - parent: Union[ca.StructRef, ca.UnaryOp], func - ) -> None: - if isinstance(parent, ca.StructRef): - parent.name = func(parent.name) - elif isinstance(parent, ca.UnaryOp): - parent.expr = func(parent.expr) - - def get_child(parent: Union[ca.StructRef, ca.UnaryOp]) -> ca.Node: - if isinstance(parent, ca.StructRef): - return parent.name - elif isinstance(parent, ca.UnaryOp): - return parent.expr - - struct_ref = random.choice(cands) - parent: Union[ca.StructRef, ca.UnaryOp] - - # Step 1: Find the parent of the leaf node - parent = rec(struct_ref) - - changed = False - - # Step 2: Simplify (...)->c to (*(...)).c - if struct_ref.type == "->": - struct_ref.type = "." - # check if deref would remove the parent node - if ( - parent is struct_ref.name - and isinstance(parent, ca.UnaryOp) - and parent.op == "&" - ): - struct_ref.name = deref(struct_ref.name) - parent = struct_ref - else: - struct_ref.name = deref(struct_ref.name) - if parent is struct_ref and isinstance( - struct_ref.name, ca.UnaryOp - ): # Check to make mypy happy - parent = struct_ref.name - changed = True - - # Simple StructRefs only need their type permuted - if isinstance(get_child(parent), (ca.ArrayRef, ca.BinaryOp)): - # For binops, a lhs like &(a+b)->c is impossible, because a + b is an rvalue - - # Step 3: Simplify further by converting ArrayRef to BinaryOp - if isinstance(get_child(parent), ca.ArrayRef): - apply_child(parent, to_binop) - apply_child(parent, deref) - parent = typing.cast("Union[ca.StructRef, ca.UnaryOp]", get_child(parent)) - changed = True - - # Step 4: Convert back to ArrayRef - if random_bool(random, 0.5): - # Sanity check that there's at least one dereference - if isinstance(parent, ca.UnaryOp) and parent.op == "*": - apply_child(parent, to_array) - apply_child(parent, addr) - changed = True - - # Step 5: Convert the StructRef type back - if random_bool(random, 0.5): - struct_ref.name = addr(struct_ref.name) - struct_ref.type = "->" - changed = True - - ensure(changed) - - -def perm_split_assignment( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Split assignments of the form a = b . c . d ...; into a = b; a = a . c . d ...;, a = c . d ...; a = b . a;, etc.""" - cands = [] - # Look for assignments of the form 'var = binaryOp' (ignores op=) - class Visitor(ca.NodeVisitor): - def visit_Assignment(self, node: ca.Assignment) -> None: - if ( - node.op == "=" - and isinstance(node.rvalue, ca.BinaryOp) - and region.contains_node(node) - ): - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - - assign = random.choice(cands) - var = assign.lvalue - - ins_cands = get_insertion_points(fn, region) - - for ins_block, ins_index, node in ins_cands: - if node is assign: - break - else: - raise RandomizationFailure - - binops = [] - - def collect_binops(node: ca.BinaryOp) -> None: - if isinstance(node.left, ca.BinaryOp): - collect_binops(node.left) - binops.append(node) - if isinstance(node.right, ca.BinaryOp): - collect_binops(node.right) - - collect_binops(typing.cast(ca.BinaryOp, assign.rvalue)) - - split = random.choice(binops) - - typemap = build_typemap(ast) - vartype = decayed_expr_type(var, typemap) - - # Choose which side to move to a new assignment - if random_bool(random, 0.5): - side = split.left - sidetype = decayed_expr_type(side, typemap) - ensure(same_type(vartype, sidetype, typemap, allow_similar=True)) - split.left = copy.deepcopy(var) - else: - side = split.right - sidetype = decayed_expr_type(side, typemap) - ensure(same_type(vartype, sidetype, typemap, allow_similar=True)) - split.right = copy.deepcopy(var) - - # The assignment is always inserted before the original - new_assign = ca.Assignment("=", copy.deepcopy(var), side) - ast_util.insert_statement(ins_block, ins_index, new_assign) - - -def perm_remove_ast( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Delete parts of the function that might be unnecessary (mistakes or unnecessary changes from an improved base.c).""" - cands: List[Tuple[ca.Node, ca.Node]] = [] - - class Visitor(ca.NodeVisitor): - def visit_Cast(self, node: ca.Cast) -> None: - if region.contains_node(node): - cands.append((node, node.expr)) - self.generic_visit(node) - - # Replace (a & constant) with (a). - def visit_BinaryOp(self, node: ca.BinaryOp) -> None: - if region.contains_node(node) and node.op == "&": - if isinstance(node.left, ca.Constant): - cands.append((node, node.right)) - if isinstance(node.right, ca.Constant): - cands.append((node, node.left)) - self.generic_visit(node) - - # Remove if statements that don't have an else - def visit_If(self, node: ca.If) -> None: - if not node.iffalse and region.contains_node(node): - cands.append((node, node.iftrue)) - self.generic_visit(node) - - # Remove loops - def visit_While(self, node: ca.While) -> None: - if region.contains_node(node): - cands.append((node, node.stmt)) - self.generic_visit(node) - - def visit_DoWhile(self, node: ca.DoWhile) -> None: - if region.contains_node(node): - cands.append((node, node.stmt)) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - - cand, expr = random.choice(cands) - replace_node(fn.body, cand, expr) - - -def perm_duplicate_assignment( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Duplicate an assignment, sometimes forcing IDO to reuse a register.""" - cands = [] - - class Visitor(ca.NodeVisitor): - def visit_Assignment(self, node: ca.Assignment) -> None: - if region.contains_node(node) and node.op == "=": - cands.append(node) - self.generic_visit(node) - - Visitor().visit(fn.body) - ensure(cands) - cand = random.choice(cands) - - ins_cands = get_insertion_points(fn, Region.unbounded()) - ensure(ins_cands) - - dup = copy.deepcopy(cand) - tob, toi, _ = random.choice(ins_cands) - ast_util.insert_statement(tob, toi, dup) - - -def perm_pad_var_decl( - fn: ca.FuncDef, ast: ca.FileAST, indices: Indices, region: Region, random: Random -) -> None: - """Inserts an unused variable to adjust stack offsets. Probably only useful with --stack-diffs enabled.""" - vars: List[str] = [] - - class Visitor(ca.NodeVisitor): - def visit_Decl(self, decl: ca.Decl) -> None: - if decl.name: - vars.append(decl.name) - self.generic_visit(decl) - - Visitor().visit(fn.body) - - var = "pad" - counter = 1 - while var in vars: - counter += 1 - var = f"pad{counter}" - - type = random_type(random) - ast_util.insert_decl(fn, var, type, random) - - -class Randomizer: - def __init__(self, rng_seed: int) -> None: - self.random = Random(rng_seed) - - def randomize(self, ast: ca.FileAST, fn_index: int) -> None: - fn = ast.ext[fn_index] - assert isinstance(fn, ca.FuncDef) - indices = ast_util.compute_node_indices(fn) - region = get_randomization_region(fn, indices, self.random) - methods = [ - (perm_temp_for_expr, 100), - (perm_expand_expr, 20), - (perm_reorder_stmts, 20), - (perm_add_mask, 15), - (perm_xor_zero, 10), - (perm_cast_simple, 10), - (perm_refer_to_var, 10), - (perm_float_literal, 10), - (perm_randomize_internal_type, 10), - (perm_randomize_external_type, 5), - (perm_randomize_function_type, 5), - (perm_split_assignment, 10), - (perm_sameline, 10), - (perm_ins_block, 10), - (perm_struct_ref, 10), - (perm_empty_stmt, 10), - (perm_condition, 10), - (perm_dummy_comma_expr, 5), - (perm_add_self_assignment, 5), - (perm_associative, 5), - (perm_inequalities, 5), - (perm_compound_assignment, 5), - (perm_remove_ast, 5), - (perm_duplicate_assignment, 5), - (perm_pad_var_decl, 1), - ] - while True: - method = random_weighted(self.random, methods) - try: - method(fn, ast, indices, region, self.random) - break - except RandomizationFailure: - pass diff --git a/tools/decomp-permuter/src/scorer.py b/tools/decomp-permuter/src/scorer.py deleted file mode 100644 index 8bac38f506..0000000000 --- a/tools/decomp-permuter/src/scorer.py +++ /dev/null @@ -1,145 +0,0 @@ -from dataclasses import dataclass, field -import difflib -import hashlib -import re -from typing import Tuple, List, Optional -from collections import Counter - - -from .objdump import objdump, get_arch - - -@dataclass(init=False, unsafe_hash=True) -class DiffAsmLine: - line: str = field(compare=False) - mnemonic: str - - def __init__(self, line: str) -> None: - self.line = line - self.mnemonic = line.split("\t")[0] - - -class Scorer: - PENALTY_INF = 10 ** 9 - - PENALTY_STACKDIFF = 1 - PENALTY_REGALLOC = 5 - PENALTY_REORDERING = 60 - PENALTY_INSERTION = 100 - PENALTY_DELETION = 100 - - def __init__(self, target_o: str, *, stack_differences: bool): - self.target_o = target_o - self.arch = get_arch(target_o) - self.stack_differences = stack_differences - _, self.target_seq = self._objdump(target_o) - self.differ: difflib.SequenceMatcher[DiffAsmLine] = difflib.SequenceMatcher( - autojunk=False - ) - self.differ.set_seq2(self.target_seq) - - def _objdump(self, o_file: str) -> Tuple[str, List[DiffAsmLine]]: - ret = [] - lines = objdump(o_file, self.arch, stack_differences=self.stack_differences) - for line in lines: - ret.append(DiffAsmLine(line)) - return "\n".join(lines), ret - - def score(self, cand_o: Optional[str]) -> Tuple[int, str]: - if not cand_o: - return Scorer.PENALTY_INF, "" - - objdump_output, cand_seq = self._objdump(cand_o) - - score = 0 - deletions = [] - insertions = [] - - def lo_hi_match(old: str, new: str) -> bool: - old_lo = old.find("%lo") - old_hi = old.find("%hi") - new_lo = new.find("%lo") - new_hi = new.find("%hi") - - if old_lo != -1 and new_lo != -1: - old_idx = old_lo - new_idx = new_lo - elif old_hi != -1 and new_hi != -1: - old_idx = old_hi - new_idx = new_hi - else: - return False - - if old[:old_idx] != new[:new_idx]: - return False - - old_inner = old[old_idx + 4 : -1] - new_inner = new[new_idx + 4 : -1] - return old_inner.startswith(".") or new_inner.startswith(".") - - def diff_sameline(old: str, new: str) -> None: - nonlocal score - if old == new: - return - - if lo_hi_match(old, new): - return - - ignore_last_field = False - if self.stack_differences: - oldsp = re.search(self.arch.re_sprel, old) - newsp = re.search(self.arch.re_sprel, new) - if oldsp and newsp: - oldrel = int(oldsp.group(1) or "0", 0) - newrel = int(newsp.group(1) or "0", 0) - score += abs(oldrel - newrel) * self.PENALTY_STACKDIFF - ignore_last_field = True - - # Probably regalloc difference, or signed vs unsigned - - # Compare each field in order - newfields, oldfields = new.split(","), old.split(",") - if ignore_last_field: - newfields = newfields[:-1] - oldfields = oldfields[:-1] - for nf, of in zip(newfields, oldfields): - if nf != of: - score += self.PENALTY_REGALLOC - # Penalize any extra fields - score += abs(len(newfields) - len(oldfields)) * self.PENALTY_REGALLOC - - def diff_insert(line: str) -> None: - # Reordering or totally different codegen. - # Defer this until later when we can tell. - insertions.append(line) - - def diff_delete(line: str) -> None: - deletions.append(line) - - self.differ.set_seq1(cand_seq) - for (tag, i1, i2, j1, j2) in self.differ.get_opcodes(): - if tag == "equal": - for k in range(i2 - i1): - old = self.target_seq[j1 + k].line - new = cand_seq[i1 + k].line - diff_sameline(old, new) - if tag == "replace" or tag == "delete": - for k in range(i1, i2): - diff_insert(cand_seq[k].line) - if tag == "replace" or tag == "insert": - for k in range(j1, j2): - diff_delete(self.target_seq[k].line) - - insertions_co = Counter(insertions) - deletions_co = Counter(deletions) - for item in insertions_co + deletions_co: - ins = insertions_co[item] - dels = deletions_co[item] - common = min(ins, dels) - score += ( - (ins - common) * self.PENALTY_INSERTION - + (dels - common) * self.PENALTY_DELETION - + self.PENALTY_REORDERING * common - ) - - return (score, hashlib.sha256(objdump_output.encode()).hexdigest()) diff --git a/tools/decomp-permuter/strip_other_fns.py b/tools/decomp-permuter/strip_other_fns.py deleted file mode 100644 index af29a01df6..0000000000 --- a/tools/decomp-permuter/strip_other_fns.py +++ /dev/null @@ -1,74 +0,0 @@ -import re -import argparse -from typing import Optional -from pathlib import Path - - -def _find_bracket_end(input: str, start_index: int) -> int: - level = 1 - assert input[start_index] == "{" - i = start_index + 1 - while i < len(input): - if input[i] == "{": - level += 1 - elif input[i] == "}": - level -= 1 - if level == 0: - break - i += 1 - - assert level == 0, "unbalanced {}" - return i - - -def strip_other_fns(source: str, keep_fn_name: str) -> str: - result = "" - remain = source - while True: - fn_regex = re.compile(r"^.*\s+\**(\w+)\(.*\)\s*?{", re.M) - fn = re.search(fn_regex, remain) - if fn is None: - result += remain - remain = "" - break - - fn_name = fn.group(1) - bracket_end = _find_bracket_end(remain, fn.end() - 1) - if fn_name.startswith("PERM"): - result += remain[: bracket_end + 1] - elif fn_name == keep_fn_name: - result += "\n\n" + remain[: bracket_end + 1] + "\n\n" - else: - result += remain[: fn.end() - 1].rstrip() + ";" - - remain = remain[bracket_end + 1 :] - - return result - - -def strip_other_fns_and_write( - source: str, fn_name: str, out_filename: Optional[str] = None -) -> None: - stripped = strip_other_fns(source, fn_name) - - if out_filename is None: - print(stripped) - else: - with open(out_filename, "w", encoding="utf-8") as f: - f.write(stripped) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Remove all but a single function definition from a file." - ) - parser.add_argument("c_file", help="File containing the function.") - parser.add_argument("fn_name", help="Function name.") - args = parser.parse_args() - - source = Path(args.c_file).read_text() - strip_other_fns_and_write(source, args.fn_name, args.c_file) - - -if __name__ == "__main__": - main() diff --git a/tools/decomp-permuter/stubs/pycparser/__init__.py b/tools/decomp-permuter/stubs/pycparser/__init__.py deleted file mode 100644 index 06e6865b6c..0000000000 --- a/tools/decomp-permuter/stubs/pycparser/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -#----------------------------------------------------------------- -# pycparser: __init__.py -# -# This package file exports some convenience functions for -# interacting with pycparser -# -# Eli Bendersky [https://eli.thegreenplace.net/] -# License: BSD -#----------------------------------------------------------------- -__all__ = ['c_parser', 'c_ast'] -__version__ = '2.19' - -from typing import Any, List, Union -from . import c_ast -from .c_parser import CParser - -def preprocess_file(filename: str, cpp_path: str='cpp', cpp_args: Union[List[str], str]='') -> str: ... -def parse_file(filename: str, use_cpp: bool=False, cpp_path: str='cpp', cpp_args: str='', parser: Any=None) -> c_ast.FileAST: ... diff --git a/tools/decomp-permuter/stubs/pycparser/c_ast.py b/tools/decomp-permuter/stubs/pycparser/c_ast.py deleted file mode 100644 index 5f41f6c650..0000000000 --- a/tools/decomp-permuter/stubs/pycparser/c_ast.py +++ /dev/null @@ -1,719 +0,0 @@ -# ----------------------------------------------------------------- -# pycparser: c_ast.py -# -# AST Node classes. -# -# Eli Bendersky [https://eli.thegreenplace.net/] -# License: BSD -# ----------------------------------------------------------------- - - -from typing import TextIO, Iterable, List, Any, Optional, Union as Union_ -from .plyparser import Coord -import sys - - -class Node(object): - coord: Optional[Coord] - - def __repr__(self) -> str: - ... - - def __iter__(self) -> Iterable[Node]: - ... - - def children(self) -> Iterable[Node]: - ... - - def show( - self, - buf: TextIO = sys.stdout, - offset: int = 0, - attrnames: bool = False, - nodenames: bool = False, - showcoord: bool = False, - ) -> None: - ... - - -Expression = Union_[ - "ArrayRef", - "Assignment", - "BinaryOp", - "Cast", - "CompoundLiteral", - "Constant", - "ExprList", - "FuncCall", - "ID", - "TernaryOp", - "UnaryOp", -] -Statement = Union_[ - Expression, - "Break", - "Case", - "Compound", - "Continue", - "Decl", - "Default", - "DoWhile", - "EmptyStatement", - "For", - "Goto", - "If", - "Label", - "Return", - "Switch", - "Typedef", - "While", - "Pragma", -] -Type = Union_["PtrDecl", "ArrayDecl", "FuncDecl", "TypeDecl"] -InnerType = Union_["IdentifierType", "Struct", "Union", "Enum"] -ExternalDeclaration = Union_["FuncDef", "Decl", "Typedef", "Pragma"] -AnyNode = Union_[ - Statement, - Type, - InnerType, - "FuncDef", - "EllipsisParam", - "Enumerator", - "EnumeratorList", - "FileAST", - "InitList", - "NamedInitializer", - "ParamList", - "Typename", -] - - -class NodeVisitor: - def visit(self, node: Node) -> None: - ... - - def generic_visit(self, node: Node) -> None: - ... - - def visit_ArrayDecl(self, node: ArrayDecl) -> None: - ... - - def visit_ArrayRef(self, node: ArrayRef) -> None: - ... - - def visit_Assignment(self, node: Assignment) -> None: - ... - - def visit_BinaryOp(self, node: BinaryOp) -> None: - ... - - def visit_Break(self, node: Break) -> None: - ... - - def visit_Case(self, node: Case) -> None: - ... - - def visit_Cast(self, node: Cast) -> None: - ... - - def visit_Compound(self, node: Compound) -> None: - ... - - def visit_CompoundLiteral(self, node: CompoundLiteral) -> None: - ... - - def visit_Constant(self, node: Constant) -> None: - ... - - def visit_Continue(self, node: Continue) -> None: - ... - - def visit_Decl(self, node: Decl) -> None: - ... - - def visit_DeclList(self, node: DeclList) -> None: - ... - - def visit_Default(self, node: Default) -> None: - ... - - def visit_DoWhile(self, node: DoWhile) -> None: - ... - - def visit_EllipsisParam(self, node: EllipsisParam) -> None: - ... - - def visit_EmptyStatement(self, node: EmptyStatement) -> None: - ... - - def visit_Enum(self, node: Enum) -> None: - ... - - def visit_Enumerator(self, node: Enumerator) -> None: - ... - - def visit_EnumeratorList(self, node: EnumeratorList) -> None: - ... - - def visit_ExprList(self, node: ExprList) -> None: - ... - - def visit_FileAST(self, node: FileAST) -> None: - ... - - def visit_For(self, node: For) -> None: - ... - - def visit_FuncCall(self, node: FuncCall) -> None: - ... - - def visit_FuncDecl(self, node: FuncDecl) -> None: - ... - - def visit_FuncDef(self, node: FuncDef) -> None: - ... - - def visit_Goto(self, node: Goto) -> None: - ... - - def visit_ID(self, node: ID) -> None: - ... - - def visit_IdentifierType(self, node: IdentifierType) -> None: - ... - - def visit_If(self, node: If) -> None: - ... - - def visit_InitList(self, node: InitList) -> None: - ... - - def visit_Label(self, node: Label) -> None: - ... - - def visit_NamedInitializer(self, node: NamedInitializer) -> None: - ... - - def visit_ParamList(self, node: ParamList) -> None: - ... - - def visit_PtrDecl(self, node: PtrDecl) -> None: - ... - - def visit_Return(self, node: Return) -> None: - ... - - def visit_Struct(self, node: Struct) -> None: - ... - - def visit_StructRef(self, node: StructRef) -> None: - ... - - def visit_Switch(self, node: Switch) -> None: - ... - - def visit_TernaryOp(self, node: TernaryOp) -> None: - ... - - def visit_TypeDecl(self, node: TypeDecl) -> None: - ... - - def visit_Typedef(self, node: Typedef) -> None: - ... - - def visit_Typename(self, node: Typename) -> None: - ... - - def visit_UnaryOp(self, node: UnaryOp) -> None: - ... - - def visit_Union(self, node: Union) -> None: - ... - - def visit_While(self, node: While) -> None: - ... - - def visit_Pragma(self, node: Pragma) -> None: - ... - - -class ArrayDecl(Node): - type: Type - dim: Optional[Expression] - dim_quals: List[str] - - def __init__( - self, - type: Type, - dim: Optional[Node], - dim_quals: List[str], - coord: Optional[Coord] = None, - ): - ... - - -class ArrayRef(Node): - name: Expression - subscript: Expression - - def __init__(self, name: Node, subscript: Node, coord: Optional[Coord] = None): - ... - - -class Assignment(Node): - op: str - lvalue: Expression - rvalue: Expression - - def __init__( - self, - op: str, - lvalue: Expression, - rvalue: Expression, - coord: Optional[Coord] = None, - ): - ... - - -class BinaryOp(Node): - op: str - left: Expression - right: Expression - - def __init__(self, op: str, left: Node, right: Node, coord: Optional[Coord] = None): - ... - - -class Break(Node): - def __init__(self, coord: Optional[Coord] = None): - ... - - -class Case(Node): - expr: Expression - stmts: List[Statement] - - def __init__( - self, expr: Expression, stmts: List[Statement], coord: Optional[Coord] = None - ): - ... - - -class Cast(Node): - to_type: "Typename" - expr: Expression - - def __init__( - self, to_type: "Typename", expr: Expression, coord: Optional[Coord] = None - ): - ... - - -class Compound(Node): - block_items: Optional[List[Statement]] - - def __init__( - self, block_items: Optional[List[Statement]], coord: Optional[Coord] = None - ): - ... - - -class CompoundLiteral(Node): - type: "Typename" - init: "InitList" - - def __init__( - self, type: "Typename", init: "InitList", coord: Optional[Coord] = None - ): - ... - - -class Constant(Node): - type: str - value: str - - def __init__(self, type: str, value: str, coord: Optional[Coord] = None): - ... - - -class Continue(Node): - def __init__(self, coord: Optional[Coord] = None): - ... - - -class Decl(Node): - name: Optional[str] - quals: List[str] # e.g. const - storage: List[str] # e.g. register - funcspec: List[str] # e.g. inline - type: Union_[Type, "Struct", "Union", "Enum"] - init: Optional[Union_[Expression, "InitList"]] - bitsize: Optional[Expression] - - def __init__( - self, - name: Optional[str], - quals: List[str], - storage: List[str], - funcspec: List[str], - type: Union_[Type, "Struct", "Union", "Enum"], - init: Optional[Union_[Expression, "InitList"]], - bitsize: Optional[Expression], - coord: Optional[Coord] = None, - ): - ... - - -class DeclList(Node): - decls: List[Decl] - - def __init__(self, decls: List[Decl], coord: Optional[Coord] = None): - ... - - -class Default(Node): - stmts: List[Statement] - - def __init__(self, stmts: List[Statement], coord: Optional[Coord] = None): - ... - - -class DoWhile(Node): - cond: Expression - stmt: Statement - - def __init__( - self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None - ): - ... - - -class EllipsisParam(Node): - def __init__(self, coord: Optional[Coord] = None): - ... - - -class EmptyStatement(Node): - def __init__(self, coord: Optional[Coord] = None): - ... - - -class Enum(Node): - name: Optional[str] - values: "Optional[EnumeratorList]" - - def __init__( - self, - name: Optional[str], - values: "Optional[EnumeratorList]", - coord: Optional[Coord] = None, - ): - ... - - -class Enumerator(Node): - name: str - value: Optional[Expression] - - def __init__( - self, name: str, value: Optional[Expression], coord: Optional[Coord] = None - ): - ... - - -class EnumeratorList(Node): - enumerators: List[Enumerator] - - def __init__(self, enumerators: List[Enumerator], coord: Optional[Coord] = None): - ... - - -class ExprList(Node): - exprs: List[Union_[Expression, Typename]] # typename only for offsetof - - def __init__( - self, exprs: List[Union_[Expression, Typename]], coord: Optional[Coord] = None - ): - ... - - -class FileAST(Node): - ext: List[ExternalDeclaration] - - def __init__(self, ext: List[ExternalDeclaration], coord: Optional[Coord] = None): - ... - - -class For(Node): - init: Union_[None, Expression, DeclList] - cond: Optional[Expression] - next: Optional[Expression] - stmt: Statement - - def __init__( - self, - init: Union_[None, Expression, DeclList], - cond: Optional[Expression], - next: Optional[Expression], - stmt: Statement, - coord: Optional[Coord] = None, - ): - ... - - -class FuncCall(Node): - name: Expression - args: Optional[ExprList] - - def __init__( - self, name: Expression, args: Optional[ExprList], coord: Optional[Coord] = None - ): - ... - - -class FuncDecl(Node): - args: Optional[ParamList] - type: Type # return type - - def __init__( - self, args: Optional[ParamList], type: Type, coord: Optional[Coord] = None - ): - ... - - -class FuncDef(Node): - decl: Decl - param_decls: Optional[List[Decl]] - body: Compound - - def __init__( - self, - decl: Decl, - param_decls: Optional[List[Decl]], - body: Compound, - coord: Optional[Coord] = None, - ): - ... - - -class Goto(Node): - name: str - - def __init__(self, name: str, coord: Optional[Coord] = None): - ... - - -class ID(Node): - name: str - - def __init__(self, name: str, coord: Optional[Coord] = None): - ... - - -class IdentifierType(Node): - names: List[str] # e.g. ['long', 'int'] - - def __init__(self, names: List[str], coord: Optional[Coord] = None): - ... - - -class If(Node): - cond: Expression - iftrue: Statement - iffalse: Optional[Statement] - - def __init__( - self, - cond: Expression, - iftrue: Statement, - iffalse: Optional[Statement], - coord: Optional[Coord] = None, - ): - ... - - -class InitList(Node): - exprs: List[Union_[Expression, "NamedInitializer"]] - - def __init__( - self, - exprs: List[Union_[Expression, "NamedInitializer"]], - coord: Optional[Coord] = None, - ): - ... - - -class Label(Node): - name: str - stmt: Statement - - def __init__(self, name: str, stmt: Statement, coord: Optional[Coord] = None): - ... - - -class NamedInitializer(Node): - name: List[Expression] # [ID(x), Constant(4)] for {.x[4] = ...} - expr: Expression - - def __init__( - self, name: List[Expression], expr: Expression, coord: Optional[Coord] = None - ): - ... - - -class ParamList(Node): - params: List[Union_[Decl, ID, Typename, EllipsisParam]] - - def __init__( - self, - params: List[Union_[Decl, ID, Typename, EllipsisParam]], - coord: Optional[Coord] = None, - ): - ... - - -class PtrDecl(Node): - quals: List[str] - type: Type - - def __init__(self, quals: List[str], type: Type, coord: Optional[Coord] = None): - ... - - -class Return(Node): - expr: Optional[Expression] - - def __init__(self, expr: Optional[Expression], coord: Optional[Coord] = None): - ... - - -class Struct(Node): - name: Optional[str] - decls: Optional[List[Union_[Decl, Pragma]]] - - def __init__( - self, - name: Optional[str], - decls: Optional[List[Union_[Decl, Pragma]]], - coord: Optional[Coord] = None, - ): - ... - - -class StructRef(Node): - name: Expression - type: str - field: ID - - def __init__( - self, name: Expression, type: str, field: ID, coord: Optional[Coord] = None - ): - ... - - -class Switch(Node): - cond: Expression - stmt: Statement - - def __init__( - self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None - ): - ... - - -class TernaryOp(Node): - cond: Expression - iftrue: Expression - iffalse: Expression - - def __init__( - self, - cond: Expression, - iftrue: Expression, - iffalse: Expression, - coord: Optional[Coord] = None, - ): - ... - - -class TypeDecl(Node): - declname: Optional[str] - quals: List[str] - type: InnerType - - def __init__( - self, - declname: Optional[str], - quals: List[str], - type: InnerType, - coord: Optional[Coord] = None, - ): - ... - - -class Typedef(Node): - name: str - quals: List[str] - storage: List[str] - type: Type - - def __init__( - self, - name: str, - quals: List[str], - storage: List[str], - type: Type, - coord: Optional[Coord] = None, - ): - ... - - -class Typename(Node): - name: None - quals: List[str] - type: Type - - def __init__( - self, name: None, quals: List[str], type: Type, coord: Optional[Coord] = None - ): - ... - - -class UnaryOp(Node): - op: str - expr: Union_[Expression, Typename] - - def __init__( - self, op: str, expr: Union_[Expression, Typename], coord: Optional[Coord] = None - ): - ... - - -class Union(Node): - name: Optional[str] - decls: Optional[List[Union_[Decl, Pragma]]] - - def __init__( - self, - name: Optional[str], - decls: Optional[List[Union_[Decl, Pragma]]], - coord: Optional[Coord] = None, - ): - ... - - -class While(Node): - cond: Expression - stmt: Statement - - def __init__( - self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None - ): - ... - - -class Pragma(Node): - string: str - - def __init__(self, string: str, coord: Optional[Coord] = None): - ... diff --git a/tools/decomp-permuter/stubs/pycparser/c_generator.py b/tools/decomp-permuter/stubs/pycparser/c_generator.py deleted file mode 100644 index ea3b70d074..0000000000 --- a/tools/decomp-permuter/stubs/pycparser/c_generator.py +++ /dev/null @@ -1,13 +0,0 @@ -#------------------------------------------------------------------------------ -# pycparser: c_generator.py -# -# C code generator from pycparser AST nodes. -# -# Eli Bendersky [https://eli.thegreenplace.net/] -# License: BSD -#------------------------------------------------------------------------------ -from . import c_ast - -class CGenerator: - def __init__(self) -> None: ... - def visit(self, node: c_ast.Node) -> str: ... diff --git a/tools/decomp-permuter/stubs/pycparser/c_parser.py b/tools/decomp-permuter/stubs/pycparser/c_parser.py deleted file mode 100644 index fac2fa8388..0000000000 --- a/tools/decomp-permuter/stubs/pycparser/c_parser.py +++ /dev/null @@ -1,15 +0,0 @@ -#------------------------------------------------------------------------------ -# pycparser: c_parser.py -# -# CParser class: Parser and AST builder for the C language -# -# Eli Bendersky [https://eli.thegreenplace.net/] -# License: BSD -#------------------------------------------------------------------------------ - -from . import c_ast - -class CParser: - def __init__(self) -> None: ... - def parse(self, text: str, filename: str='', debuglevel: int=0) -> c_ast.FileAST: ... - diff --git a/tools/decomp-permuter/stubs/pycparser/plyparser.py b/tools/decomp-permuter/stubs/pycparser/plyparser.py deleted file mode 100644 index e63c53efb7..0000000000 --- a/tools/decomp-permuter/stubs/pycparser/plyparser.py +++ /dev/null @@ -1,27 +0,0 @@ -# ----------------------------------------------------------------- -# plyparser.py -# -# PLYParser class and other utilites for simplifying programming -# parsers with PLY -# -# Eli Bendersky [https://eli.thegreenplace.net/] -# License: BSD -# ----------------------------------------------------------------- - -from typing import Optional - - -class Coord: - file: str - line: int - column: Optional[int] - - def __init__(self, file: str, line: int, column: Optional[int] = None): - ... - - def __str__(self) -> str: - ... - - -class ParseError(Exception): - pass diff --git a/tools/decomp-permuter/test.py b/tools/decomp-permuter/test.py deleted file mode 100755 index aecdb14ac5..0000000000 --- a/tools/decomp-permuter/test.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -import sys - -from pycparser import parse_file, c_generator - -fname = "test.c" if len(sys.argv) < 2 else sys.argv[1] - - -# ast = c_parser.CParser().parse(src) -ast = parse_file(fname, use_cpp=True) -# ast.show() -# print(c_generator.CGenerator().visit(ast)) -print(ast) diff --git a/tools/decomp-permuter/test/__init__.py b/tools/decomp-permuter/test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/decomp-permuter/test/compile.sh b/tools/decomp-permuter/test/compile.sh deleted file mode 100755 index 7788134bf7..0000000000 --- a/tools/decomp-permuter/test/compile.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -mips-linux-gnu-gcc -O2 -fno-PIC -fno-common -ffreestanding -mno-shared -mno-abicalls -G 0 -c "$@" diff --git a/tools/decomp-permuter/test/test_perm.py b/tools/decomp-permuter/test/test_perm.py deleted file mode 100644 index f3cc520cef..0000000000 --- a/tools/decomp-permuter/test/test_perm.py +++ /dev/null @@ -1,247 +0,0 @@ -import os -from pathlib import Path -import re -import shutil -import tempfile -from typing import Any, Optional -import unittest - -from src.compiler import Compiler -from src.preprocess import preprocess -from src import main - - -class TestPermMacros(unittest.TestCase): - def go( - self, - intro: str, - outro: str, - base: str, - target: str, - *, - fn_name: Optional[str] = None, - **kwargs: Any - ) -> int: - base = intro + "\n" + base + "\n" + outro - target = intro + "\n" + target + "\n" + outro - compiler = Compiler("test/compile.sh", show_errors=True) - - # For debugging, to avoid the auto-deleted directory: - # target_dir = tempfile.mkdtemp() - with tempfile.TemporaryDirectory() as target_dir: - with open(os.path.join(target_dir, "base.c"), "w") as f: - f.write(base) - - target_o = compiler.compile(target, show_errors=True) - assert target_o is not None - shutil.move(target_o, os.path.join(target_dir, "target.o")) - - shutil.copy2("test/compile.sh", os.path.join(target_dir, "compile.sh")) - - if fn_name: - with open(os.path.join(target_dir, "function.txt"), "w") as f: - f.write(fn_name) - - opts = main.Options(directories=[target_dir], stop_on_zero=True, **kwargs) - return main.run(opts)[0] - - def test_general(self) -> None: - score = self.go( - "int test() {", - "}", - "return PERM_GENERAL(32,64);", - "return 64;", - ) - self.assertEqual(score, 0) - - def test_not_found(self) -> None: - score = self.go( - "int test() {", - "}", - "return PERM_GENERAL(32,64);", - "return 92;", - ) - self.assertNotEqual(score, 0) - - def test_multiple_functions(self) -> None: - score = self.go( - "", - "", - """ - int ignoreme() {} - int foo() { return PERM_GENERAL(32,64); } - int ignoreme2() {} - """, - "int foo() { return 64; }", - fn_name="foo", - ) - self.assertEqual(score, 0) - - def test_general_multiple(self) -> None: - score = self.go( - "int test() {", - "}", - "return PERM_GENERAL(1,2,3) + PERM_GENERAL(3,6,9);", - "return 9;", - ) - self.assertEqual(score, 0) - - def test_general_nested(self) -> None: - score = self.go( - "int test() {", - "}", - "return PERM_GENERAL(1,PERM_GENERAL(100,101),3) + PERM_GENERAL(3,6,9);", - "return 110;", - ) - self.assertEqual(score, 0) - - def test_cast(self) -> None: - score = self.go( - "int test(int a, int b) {", - "}", - "return a / PERM_GENERAL(,(unsigned int),(float)) b;", - "return a / (float) b;", - ) - self.assertEqual(score, 0) - - def test_cast_threaded(self) -> None: - score = self.go( - "int test(int a, int b) {", - "}", - "return a / PERM_GENERAL(,(unsigned int),(float)) b;", - "return a / (float) b;", - threads=2, - ) - self.assertEqual(score, 0) - - def test_ignore(self) -> None: - score = self.go( - "int test(int a, int b) {", - "}", - "PERM_IGNORE( return a / PERM_GENERAL(a, b); )", - "return a / b;", - ) - self.assertEqual(score, 0) - - def test_pretend(self) -> None: - score = self.go( - "int global;", - "", - """ - PERM_IGNORE( inline void foo() { ) - PERM_PRETEND( void foo(); void bar() { ) - PERM_RANDOMIZE( - global = 1; - ) - PERM_IGNORE( } void bar() { ) - PERM_RANDOMIZE( - global = 2; foo(); - ) - } - """, - """ - inline void foo() { global = 1; } - void bar() { foo(); global = 2; } - """, - fn_name="bar", - ) - self.assertEqual(score, 0) - - def test_once1(self) -> None: - score = self.go( - "volatile int A, B, C; void test() {", - "}", - """ - PERM_ONCE(B = 2;) - A = 1; - PERM_ONCE(B = 2;) - C = 3; - PERM_ONCE(B = 2;) - """, - "A = 1; B = 2; C = 3;", - ) - self.assertEqual(score, 0) - - def test_once2(self) -> None: - score = self.go( - "volatile int A, B, C; void test() {", - "}", - """ - PERM_VAR(emit,) - PERM_VAR(bademit,) - PERM_ONCE(1, PERM_VAR(bademit, A = 7;) A = 2;) - PERM_ONCE(1, PERM_VAR(emit, A = 1;)) - PERM_VAR(emit) - PERM_VAR(bademit) - PERM_ONCE(2, B = 2;) - PERM_ONCE(2, B = 1;) - PERM_ONCE(2,) - PERM_ONCE(3, PERM_VAR(bademit, A = 9)) - PERM_ONCE(3, PERM_VAR(bademit, A = 9)) - C = 3; - """, - "A = 1; B = 2; C = 3;", - ) - self.assertEqual(score, 0) - - def test_lineswap(self) -> None: - score = self.go( - "void a(); void b(); void c(); void test(void) {", - "}", - """ - PERM_LINESWAP( - a(); - b(); - c(); - ) - """, - "b(); a(); c();", - ) - self.assertEqual(score, 0) - - def test_lineswap_text(self) -> None: - score = self.go( - "void a(); void b(); void c(); void test(void) {", - "}", - """ - PERM_LINESWAP_TEXT( - a(); - b(); - c(); - ) - """, - "b(); a(); c();", - ) - self.assertEqual(score, 0) - - def test_randomizer(self) -> None: - score = self.go( - "void foo(); void bar(); void test(void) {", - "}", - "PERM_RANDOMIZE(bar(); foo();)", - "foo(); bar();", - ) - self.assertEqual(score, 0) - - def test_auto_randomizer(self) -> None: - score = self.go( - "void foo(); void bar(); void test(void) {", - "}", - "bar(); foo();", - "foo(); bar();", - ) - self.assertEqual(score, 0) - - def test_randomizer_threaded(self) -> None: - score = self.go( - "void foo(); void bar(); void test(void) {", - "}", - "PERM_RANDOMIZE(bar(); foo();)", - "foo(); bar();", - threads=2, - ) - self.assertEqual(score, 0) - - -if __name__ == "__main__": - unittest.main()