Remove decomp-permuter (#447)

* Nuke decomp permuter

* Add decomp permuter and mips2c to gitignore
This commit is contained in:
Anghelo Carvajal
2021-11-11 12:46:18 -03:00
committed by GitHub
parent 2e5c142f3b
commit c018e83d36
88 changed files with 2 additions and 12846 deletions
+2
View File
@@ -33,6 +33,8 @@ tools/ido_recomp/* binary
ctx.c
graphs/
*.c.m2c
tools/decomp-permuter/
tools/mips_to_c/
# Assets
*.png
-45
View File
@@ -1,45 +0,0 @@
name: Systray
on:
push:
branches: [ main ]
paths:
- 'src/net/cmd/systray/*'
pull_request:
branches: [ main ]
paths:
- 'src/net/cmd/systray/*'
jobs:
build:
name: Build on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- os: ubuntu-16.04
binary: permuter-systray-linux
- os: windows-latest
binary: permuter-systray.exe
- os: macos-latest
binary: permuter-systray-macos
steps:
- uses: actions/checkout@main
- name: Install gtk3
if: ${{ matrix.os == 'ubuntu-16.04' }}
run: sudo apt-get install libgtk-3-dev libappindicator3-dev
- name: Setup Go environment
uses: actions/setup-go@v2.1.3
- name: Build
run: go build -o ${{ matrix.binary }} -ldflags "-s -w" tray.go
working-directory: src/net/cmd/systray/
- name: Upload artifact
uses: actions/upload-artifact@v2
with:
name: ${{ matrix.binary }}
path: src/net/cmd/systray/${{ matrix.binary }}
-11
View File
@@ -1,11 +0,0 @@
*.o
*.s
*.c
*.py[cod]
.mypy_cache/
.cache/
__pycache__/
!test/*.c
/nonmatchings
.vscode/
pah.conf
-12
View File
@@ -1,12 +0,0 @@
; DO NOT EDIT (unless you know what you are doing)
;
; This subdirectory is a git "subrepo", and this file is maintained by the
; git-subrepo command. See https://github.com/git-commands/git-subrepo#readme
;
[subrepo]
remote = https://github.com/simonlindholm/decomp-permuter.git
branch = main
commit = a20bac9422b6d8adbf7c06473c2ae3c3fee16be5
parent = 2668eec556c01fa2f4c16a203c93c208dc03e639
method = merge
cmdver = 0.4.3
@@ -1,6 +0,0 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
language_version: python3.6
-21
View File
@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2019 Simon Lindholm and contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-120
View File
@@ -1,120 +0,0 @@
# Decomp permuter
Automatically permutes C files to better match a target binary. The permuter has two modes of operation:
- Random: purely at random, introduce temporary variables for values, change types, put statements on the same line...
- Manual: test all combinations of user-specified variations, using macros like `PERM_GENERAL(a = b ? c : d;, if (b) a = c; else a = d;)` to try both specified alternatives.
The modes can also be combined, by using the `PERM_RANDOMIZE` macro.
[<img src="https://asciinema.org/a/232846.svg" height="300">](https://asciinema.org/a/232846)
The main target for the tool is MIPS code compiled by old compilers (IDO, possibly GCC).
Getting it to work on other architectures shouldn't be too hard, however.
https://github.com/laqieer/decomp-permuter-arm has an ARM port.
## Usage
`./permuter.py directory/` runs the permuter; see below for the meaning of the directory.
Pass `-h` to see possible flags. `-j` is suggested (enables multi-threaded mode).
You'll first need to install a couple of prerequisites: `python3 -m pip install pycparser pynacl toml` (also `dataclasses` if on Python 3.6 or below)
The permuter expects as input one or more directory containing:
- a .c file with a single function,
- a .o file to match,
- a .sh file that compiles the .c file.
For projects with a properly configured makefile, you should be able to set these up by running
```
./import.py <path/to/file.c> <path/to/file.s>
```
where file.c contains the function to be permuted, and file.s is its assembly in a self-contained file.
Otherwise, see USAGE.md for more details.
For projects using Ninja instead of Make, add a `permuter_settings.toml` in the root or `tools/` directory of the project:
```toml
build_system = "ninja"
```
Then `import.py` should work as expected if `build.ninja` is at the root of the project.
The .c file may be modified with any of the following macros which affect manual permutation:
- `PERM_GENERAL(a, b, ...)` expands to any of `a`, `b`, ...
- `PERM_VAR(a, b)` sets the meta-variable `a` to `b`, `PERM_VAR(a)` expands to the meta-variable `a`.
- `PERM_RANDOMIZE(code)` expands to `code`, but allows randomization within that region. Multiple regions may be specified.
- `PERM_LINESWAP(lines)` expands to a permutation of the ordered set of non-whitespace lines (split by `\n`). Each line must contain zero or more complete C statements. (For incomplete statements use `PERM_LINESWAP_TEXT`, which is slower because it has to repeatedly parse C code.)
- `PERM_INT(lo, hi)` expands to an integer between `lo` and `hi` (which must be constants).
- `PERM_IGNORE(code)` expands to `code`, without passing it through the C parser library (pycparser)/randomizer. This can be used to avoid parse errors for non-standard C, e.g. `asm` blocks.
- `PERM_PRETEND(code)` expands to `code` for the purpose of the C parser/randomizer, but gets removed afterwards. This can be used together with `PERM_IGNORE` to enable the permuter to deal with input it isn't designed for (e.g. inline functions, C++, non-code).
- `PERM_ONCE([key,] code)` expands to either `code` or to nothing, such that each unique key gets expanded exactly once. `key` defaults to `code`. For example, `PERM_ONCE(a;) b; PERM_ONCE(a;)` expands to either `a; b;` or `b; a;`.
Arguments are split by a commas, exluding commas inside parenthesis. `(,)` is a special escape sequence that resolves to `,`.
Nested macros are allowed, so e.g.
```
PERM_VAR(delayed, )
PERM_GENERAL(stmt;, PERM_VAR(delayed, stmt;))
...
PERM_VAR(delayed)
```
is an alternative way of writing `PERM_ONCE`.
## permuter@home
The permuter supports a distributed mode, where people can donate processor power to your permuter runs to speed them up.
To use this, pass `-J` when running `permuter.py` and follow the instructions.
You will need to be granted access by someone who is already connected to a permuter network.
To allow others to use your computer for permuter runs, do the following:
- install Docker (used for sandboxing and to ensure a consistent environment)
- if on Linux, add yourself to the Docker group: `sudo usermod -aG docker $USER`
- install required packages: `python3 -m pip install docker`
- open a terminal, and run `./pah.py run-server` to start the server.
There are a few required arguments (e.g. how many cores to use), see `--help` for more details.
Please be aware that being in the Docker group implies (password-less) sudo rights.
You can avoid that for your personal account by running the permuter under a separate user.
Unfortunately, there is currently no way to run a sandboxed permuter server without sudo rights. 😢
Anyone who is granted access to permuter@home can run a server.
To set up a new permuter network, see [src/net/controller/README.md](./src/net/controller/README.md).
## FAQ
**What do the scores mean?** The scores are computed by taking diffs of objdump'd .o
files, and giving different penalties for lines that are the same/use the same
instruction/are reordered/don't match at all. 0 means the function matches fully.
Stack positions are ignored unless --stack-diffs is passed (but beware that the
permuter is currently quite bad at resolving stack differences). For more details,
see scorer.py. It's far from a perfect system, and should probably be tweaked to
look at e.g. the register diff graph.
**What sort of non-matchings are the permuter good at?** It's generally best towards
the end, when mostly regalloc changes remain. If there are reorderings or functional
changes, it's often easy to resolve those by hand, and neither the scorer nor the
randomizer tends to play well with them.
**Should I use this instead of trying to match code by hand?** No, but it can be a good
complement. PERM macros can be used to quickly test lots of variations of a function at
once, in cases where there are interactions between several parts of a function.
The randomization mode often finds lots of nonsensical changes that improve regalloc
"by accident"; it's up to you to pick out the ones that look sensible. If none do,
it can still be useful to know which parts of the function need to be changed to get the
code nearer to matching. Having made one of the improvements, and the function can then be
permuted again, to find further possible improvements.
## Helping out
There's tons of room for helping out with the permuter!
Many more randomization passes could be added, the scoring function is far from optimal,
the permuter could be made easier to use, etc. etc. The GitHub Issues list has some ideas.
Ideally, `mypy permuter.py` and `./run-tests.sh` should succeed with no errors, and files
formatted with `black`. To setup a pre-commit hook for black, run:
```
pip install pre-commit black
pre-commit install
```
PRs that skip this are still welcome, however.
-25
View File
@@ -1,25 +0,0 @@
This file describes how to manually set up a directory for use with the permuter.
**You probably don't need to do this!** In normal circumstances, `./import.py`
does all this for you. See README.md for more details.
* create a directory that will contain all of the input files for the invokation
* put a compile command into `<dir>/compile.sh` (see e.g. `compile_example.sh`; it will be invoked as `./compile.sh input.c -o output.o`)
* `gcc -E -P -I header_dir -D'__attribute__(x)=' orig_c_file.c > <dir>/base.c`
* `python3 strip_other_fns.py <dir>/base.c func_name`
* put asm for `func_name` into `<dir>/target.s`, with the following header:
```asm
.set noat
.set noreorder
.set gp=64
.macro glabel label
.global \label
.type \label, @function
\label:
.endm
```
* `mips-linux-gnu-as -march=vr4300 -mabi=32 <dir>/target.s -o <dir>/target.o`
* optional sanity checks:
- `<dir>/compile.sh <dir>/base.c -o <dir>/base.o`
- `./diff.sh <dir>/target.o <dir>/base.o`
* `./permuter.py <dir>`
-2
View File
@@ -1,2 +0,0 @@
#!/bin/bash
mips-linux-gnu-gcc -O2 "$@"
-17
View File
@@ -1,17 +0,0 @@
#!/bin/bash
if [[ $# < 2 ]]; then
echo "Usage: $0 orig.o new.o [flags]"
exit 1
fi
if [ ! -f $1 -o ! -f $2 ]; then
echo Source files not readable
exit 1
fi
INPUT1="$1"
INPUT2="$2"
shift
shift
wdiff -n <(python3 ./src/objdump.py "$INPUT1" "$@") <(python3 ./src/objdump.py "$INPUT2" "$@") | colordiff | less -Ric
-808
View File
@@ -1,808 +0,0 @@
#!/usr/bin/env python3
# usage: ./import.py path/to/file.c path/to/asm.s [make flags]
import argparse
from collections import defaultdict
import json
import os
import platform
import re
import shlex
import shutil
import subprocess
import sys
import toml
from typing import Callable, Dict, List, Match, Mapping, Optional, Pattern, Set, Tuple
import urllib.request
import urllib.parse
from src import ast_util
from src.compiler import Compiler
from src.error import CandidateConstructionFailure
is_macos = platform.system() == "Darwin"
def homebrew_gcc_cpp() -> str:
lookup_paths = ["/usr/local/bin", "/opt/homebrew/bin"]
for lookup_path in lookup_paths:
try:
return max(f for f in os.listdir(lookup_path) if f.startswith("cpp-"))
except ValueError:
pass
print(
"Error while looking up in " + ":".join(lookup_paths) + " for cpp- executable"
)
sys.exit(1)
cpp_cmd = homebrew_gcc_cpp() if is_macos else "cpp"
make_cmd = "gmake" if is_macos else "make"
ASM_PRELUDE: str = """
.set noat
.set noreorder
.set gp=64
.macro glabel label
.global \label
.type \label, @function
\label:
.endm
"""
DEFAULT_AS_CMDLINE: List[str] = ["mips-linux-gnu-as", "-march=vr4300", "-mabi=32"]
CPP: List[str] = [cpp_cmd, "-P", "-undef"]
STUB_FN_MACROS: List[str] = [
"-D_Static_assert(x, y)=",
"-D__attribute__(x)=",
"-DGLOBAL_ASM(...)=",
]
SETTINGS_FILES = ["permuter_settings.toml", "tools/permuter_settings.toml"]
def formatcmd(cmdline: List[str]) -> str:
return " ".join(shlex.quote(arg) for arg in cmdline)
def parse_asm(asm_file: str) -> Tuple[str, str]:
func_name = None
asm_lines = []
try:
with open(asm_file, encoding="utf-8") as f:
cur_section = ".text"
for line in f:
if line.strip().startswith(".section"):
cur_section = line.split()[1]
elif line.strip() in [
".text",
".rdata",
".rodata",
".late_rodata",
".bss",
".data",
]:
cur_section = line.strip()
if cur_section == ".text":
if func_name is None and line.strip().startswith("glabel "):
func_name = line.split()[1]
asm_lines.append(line)
except OSError as e:
print("Could not open assembly file:", e, file=sys.stderr)
sys.exit(1)
if func_name is None:
print(
"Missing function name in assembly file! The file should start with 'glabel function_name'.",
file=sys.stderr,
)
sys.exit(1)
if not re.fullmatch(r"[a-zA-Z0-9_$]+", func_name):
print(f"Bad function name: {func_name}", file=sys.stderr)
sys.exit(1)
return func_name, "".join(asm_lines)
def create_directory(func_name: str) -> str:
os.makedirs(f"nonmatchings/", exist_ok=True)
ctr = 0
while True:
ctr += 1
dirname = f"{func_name}-{ctr}" if ctr > 1 else func_name
dirname = f"nonmatchings/{dirname}"
try:
os.mkdir(dirname)
return dirname
except FileExistsError:
pass
def find_root_dir(filename: str, pattern: List[str]) -> Optional[str]:
old_dirname = None
dirname = os.path.abspath(os.path.dirname(filename))
while dirname and (not old_dirname or len(dirname) < len(old_dirname)):
for fname in pattern:
if os.path.isfile(os.path.join(dirname, fname)):
return dirname
old_dirname = dirname
dirname = os.path.dirname(dirname)
return None
def fixup_build_command(
parts: List[str], ignore_part: str
) -> Tuple[List[str], Optional[List[str]]]:
res = []
skip_count = 0
assembler = None
for part in parts:
if skip_count > 0:
skip_count -= 1
continue
if part in ["-MF", "-o"]:
skip_count = 1
continue
if part == ignore_part:
continue
res.append(part)
try:
ind0 = min(
i
for i, arg in enumerate(res)
if any(
cmd in arg
for cmd in ["asm_processor", "asm-processor", "preprocess.py"]
)
)
ind1 = res.index("--", ind0 + 1)
ind2 = res.index("--", ind1 + 1)
assembler = res[ind1 + 1 : ind2]
res = res[ind0 + 1 : ind1] + res[ind2 + 1 :]
except ValueError:
pass
return res, assembler
def find_build_command_line(
root_dir: str, c_file: str, make_flags: List[str], build_system: str
) -> Tuple[List[str], List[str]]:
if build_system == "make":
build_invocation = [
make_cmd,
"--always-make",
"--dry-run",
"--debug=j",
"PERMUTER=1",
] + make_flags
elif build_system == "ninja":
build_invocation = ["ninja", "-t", "commands"] + make_flags
else:
print("Unknown build system '" + build_system + "'.")
sys.exit(1)
rel_c_file = os.path.relpath(c_file, root_dir)
debug_output = (
subprocess.check_output(build_invocation, cwd=root_dir)
.decode("utf-8")
.split("\n")
)
output = []
close_match = False
assembler = DEFAULT_AS_CMDLINE
for line in debug_output:
while "//" in line:
line = line.replace("//", "/")
while "/./" in line:
line = line.replace("/./", "/")
if rel_c_file not in line:
continue
close_match = True
parts = shlex.split(line)
# extract actual command from 'bash -c "..."'
if parts[0] == "bash" and "-c" in parts:
for part in parts:
if rel_c_file in part:
parts = shlex.split(part)
break
if rel_c_file not in parts:
continue
if "-o" not in parts:
continue
if "-fsyntax-only" in parts:
continue
cmdline, asmproc_assembler = fixup_build_command(parts, rel_c_file)
if asmproc_assembler:
assembler = asmproc_assembler
output.append(cmdline)
if not output:
close_extra = (
"\n(Found one possible candidate, but didn't match due to "
"either spaces in paths, having -fsyntax-only, or missing an -o flag.)"
if close_match
else ""
)
print(
"Failed to find compile command from build script output. "
f"Please ensure running '{' '.join(build_invocation)}' "
f"contains a line with the string '{rel_c_file}'.{close_extra}",
file=sys.stderr,
)
sys.exit(1)
if len(output) > 1:
output_lines = "\n".join(map(formatcmd, output))
print(
f"Error: found multiple compile commands for {rel_c_file}:\n{output_lines}\n"
f"Please modify the build script such that '{' '.join(build_invocation)}' "
"produces a single compile command.",
file=sys.stderr,
)
sys.exit(1)
return output[0], assembler
PreserveMacros = Tuple[Pattern[str], Callable[[str], str]]
def build_preserve_macros(
cwd: str, preserve_regex: Optional[str], settings: Mapping[str, object]
) -> Optional[PreserveMacros]:
subdata = settings.get("preserve_macros", {})
assert isinstance(subdata, dict)
regexes = []
for regex, value in subdata.items():
assert isinstance(value, str)
regexes.append((re.compile(f"^(?:{regex})$"), value))
if preserve_regex == "" or (preserve_regex is None and not regexes):
return None
if preserve_regex is None:
global_regex_text = "(?:" + ")|(?:".join(subdata.keys()) + ")"
else:
global_regex_text = preserve_regex
global_regex = re.compile(f"^(?:{global_regex_text})$")
def type_fn(macro: str) -> str:
for regex, value in regexes:
if regex.match(macro):
return value
return "int"
return global_regex, type_fn
def preprocess_c_with_macros(
cpp_command: List[str], cwd: str, preserve_macros: PreserveMacros
) -> Tuple[str, List[str]]:
"""Import C file, preserving function macros. Subroutine of import_c_file.
Returns the source code and a list of preserved macros."""
preserve_regex, preserve_type_fn = preserve_macros
# Start by running 'cpp' in a mode that just processes ifdefs and includes.
source = subprocess.check_output(
cpp_command + ["-dD", "-fdirectives-only"], cwd=cwd, encoding="utf-8"
)
# Modify function macros that match preserved names so the preprocessor
# doesn't touch them, and at the same time normalize their syntax. Some
# of these instances may be in comments, but that's fine.
def repl(match: Match[str]) -> str:
name = match.group(1)
after = "(" if match.group(2) == "(" else " "
if preserve_regex.match(name):
return f"_permuter define {name}{after}"
else:
return f"#define {name}{after}"
source = re.sub(
r"^\s*#\s*define\s+([a-zA-Z0-9_]+)([ \t\(]|$)",
repl,
source,
flags=re.MULTILINE,
)
# Get rid of auto-inserted macros which the second cpp invocation will
# warn about.
source = re.sub(r"^#define __STDC_.*\n", "", source, flags=re.MULTILINE)
# Now, run the preprocessor again for real.
source = subprocess.check_output(
CPP + STUB_FN_MACROS, cwd=cwd, encoding="utf-8", input=source
)
# Finally, find all function-like defines that we hid (some might have
# been comments, so we couldn't do this before), and construct fake
# function declarations for them in a specially demarcated section of
# the file. When the compiler runs, this section will be replaced by
# the real defines and the preprocessor invoked once more.
late_defines = []
lines = []
graph = defaultdict(set)
reg_token = re.compile(r"[a-zA-Z0-9_]+")
for line in source.splitlines():
is_macro = line.startswith("_permuter define ")
params = []
if is_macro:
ind1 = line.find("(")
ind2 = line.find(" ", len("_permuter define "))
ind = min(ind1, ind2)
if ind == -1:
ind = len(line) if ind1 == ind2 == -1 else max(ind1, ind2)
before = line[:ind]
after = line[ind:]
name = before.split()[2]
late_defines.append((name, after))
if after.startswith("("):
params = [w.strip() for w in after[1 : after.find(")")].split(",")]
else:
lines.append(line)
name = ""
for m in reg_token.finditer(line):
name2 = m.group(0)
has_wildcard = False
if is_macro and name2 not in params:
wcbefore = line[: m.start()].rstrip().endswith("##")
wcafter = line[m.end() :].lstrip().startswith("##")
if wcbefore or wcafter:
graph[name].add(name2 + "*")
has_wildcard = True
if not has_wildcard:
graph[name].add(name2)
# Prune away (recursively) unused macros, for cleanliness.
used_anywhere = set()
used_by_nonmacro = graph[""]
queue = [""]
while queue:
name = queue.pop()
if name not in used_anywhere:
used_anywhere.add(name)
if name.endswith("*"):
wildcard = name[:-1]
for name2 in graph:
if wildcard in name2:
queue.extend(graph[name2])
else:
queue.extend(graph[name])
def get_decl(name: str, after: str) -> str:
typ = preserve_type_fn(name)
if after.startswith("("):
return f"{typ} {name}();"
else:
return f"extern {typ} {name};"
used_macros = [name for (name, after) in late_defines if name in used_by_nonmacro]
return (
"\n".join(
["#pragma _permuter latedefine start"]
+ [
f"#pragma _permuter define {name}{after}"
for (name, after) in late_defines
if name in used_anywhere
]
+ [
get_decl(name, after)
for (name, after) in late_defines
if name in used_by_nonmacro
]
+ ["#pragma _permuter latedefine end"]
+ lines
+ [""]
),
used_macros,
)
def import_c_file(
compiler: List[str],
cwd: str,
in_file: str,
preserve_macros: Optional[PreserveMacros],
) -> str:
"""Preprocess a C file into permuter-usable source.
Prints preserved macros as a side effect.
Returns source for base.c and compilable (macro-expanded) source."""
in_file = os.path.relpath(in_file, cwd)
include_next = 0
cpp_command = CPP + [in_file, "-D__sgi", "-D_LANGUAGE_C", "-DNON_MATCHING"]
for arg in compiler:
if include_next > 0:
include_next -= 1
cpp_command.append(arg)
continue
if arg in ["-D", "-U", "-I"]:
cpp_command.append(arg)
include_next = 1
continue
if (
arg.startswith("-D")
or arg.startswith("-U")
or arg.startswith("-I")
or arg in ["-nostdinc"]
):
cpp_command.append(arg)
try:
if preserve_macros is None:
# Simple codepath, should work even if the more complex one breaks.
source = subprocess.check_output(
cpp_command + STUB_FN_MACROS, cwd=cwd, encoding="utf-8"
)
macros: List[str] = []
else:
source, macros = preprocess_c_with_macros(cpp_command, cwd, preserve_macros)
except subprocess.CalledProcessError as e:
print(
"Failed to preprocess input file, when running command:\n"
+ formatcmd(e.cmd),
file=sys.stderr,
)
sys.exit(1)
if macros:
macro_str = "macros: " + ", ".join(macros)
else:
macro_str = "no macros"
print(f"Preserving {macro_str}. Use --preserve-macros='<regex>' to override.")
return source
def prune_source(
source: str, should_prune: bool, func_name: str
) -> Tuple[str, Optional[str]]:
"""Normalize the source by round-tripping it through pycparser, and
optionally reduce it to a smaller version that includes only the imported
function and functions/struct/variables that it uses.
Returns (source, compilable_source)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, _ = ast_util.extract_fn(ast, func_name)
if should_prune:
try:
ast_util.prune_ast(orig_fn, ast)
source = ast_util.to_c_raw(ast)
except Exception:
print(
"Source minimization failed! "
"You could try --no-prune as a workaround."
)
raise
return source, ast_util.to_c(ast, from_import=True)
except CandidateConstructionFailure as e:
print(e.message)
if should_prune and "PERM_" in source:
print(
"Please put in PERM macros after import, otherwise source "
"minimization does not work."
)
else:
print("Proceeding anyway, but expect errors when permuting!")
return source, None
def prune_and_separate_context(
source: str, should_prune: bool, func_name: str
) -> Tuple[str, str]:
"""Normalize the source by round-tripping it through pycparser, optionally
reduce it to a smaller version that includes only the imported function and
functions/struct/variables that it uses, and split the result into source
for the function itself, and the rest of the file (the "context").
Returns (source, context)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, ind = ast_util.extract_fn(ast, func_name)
if should_prune:
try:
ind = ast_util.prune_ast(orig_fn, ast)
except Exception:
print(
"Source minimization failed! "
"You could try --no-prune as a workaround."
)
raise
del ast.ext[ind]
source = ast_util.to_c(orig_fn, from_import=True)
context = ast_util.to_c(ast, from_import=True)
return source, context
except CandidateConstructionFailure as e:
print(e.message)
print("Unable to split context from source.")
print("Proceeding anyway, but expect compile errors!")
return ast_util.process_pragmas(source), ""
def get_decompme_compiler_name(
compiler: List[str], settings: Mapping[str, object], api_base: str
) -> str:
decompme_settings = settings.get("decompme", {})
assert isinstance(decompme_settings, dict)
compiler_mappings = decompme_settings.get("compilers", {})
assert isinstance(compiler_mappings, dict)
compiler_path = compiler[0]
for path, compiler_name in compiler_mappings.items():
assert isinstance(compiler_name, str)
if path == compiler_path:
return compiler_name
try:
with urllib.request.urlopen(f"{api_base}/api/compilers") as f:
json_data = json.load(f)
available = json_data["compiler_ids"]
if not isinstance(available, list):
raise Exception("compiler_ids must be a list")
if not all(isinstance(name, str) for name in available):
raise Exception("compiler_ids must be a list of strings")
except Exception as e:
print(f"Failed to request available compilers from decomp.me:\n{e}")
print()
print(
f'Unable to map compiler path "{compiler_path}" to something '
"decomp.me understands."
)
trail = "permuter_settings.toml, where ... is one of: " + ", ".join(available)
if compiler_mappings:
print(
"Please add an entry:\n\n"
f'"{compiler_path}" = "..."\n\n'
f"to the [decompme.compilers] section of {trail}"
)
else:
print(
"Please add an section:\n\n"
"[decompme.compilers]\n"
f'"{compiler_path}" = "..."\n\n'
f"to {trail}"
)
sys.exit(1)
def finalize_compile_command(cmdline: List[str]) -> str:
quoted = [arg if arg == "|" else shlex.quote(arg) for arg in cmdline]
ind = (quoted + ["|"]).index("|")
return " ".join(quoted[:ind] + ['"$INPUT"'] + quoted[ind:] + ["-o", '"$OUTPUT"'])
def get_compiler_flags(cmdline: List[str]) -> str:
flags = [b for a, b in zip(cmdline, cmdline[1:]) if a != "|" and b != "|"]
return " ".join(shlex.quote(flag) for flag in flags)
def write_compile_command(compiler: List[str], cwd: str, out_file: str) -> None:
with open(out_file, "w", encoding="utf-8") as f:
f.write("#!/usr/bin/env bash\n")
f.write('INPUT="$(realpath "$1")"\n')
f.write('OUTPUT="$(realpath "$3")"\n')
f.write(f"cd {shlex.quote(cwd)}\n")
f.write(finalize_compile_command(compiler))
os.chmod(out_file, 0o755)
def write_asm(asm_cont: str, out_file: str) -> None:
with open(out_file, "w", encoding="utf-8") as f:
f.write(ASM_PRELUDE)
f.write(asm_cont)
def compile_asm(assembler: List[str], cwd: str, in_file: str, out_file: str) -> None:
in_file = os.path.abspath(in_file)
out_file = os.path.abspath(out_file)
cmdline = assembler + [in_file, "-o", out_file]
try:
subprocess.check_call(cmdline, cwd=cwd)
except subprocess.CalledProcessError:
print(
f"Failed to assemble .s file, command line:\n{formatcmd(cmdline)}",
file=sys.stderr,
)
sys.exit(1)
def compile_base(compile_script: str, source: str, c_file: str, out_file: str) -> None:
if "PERM_" in source:
print(
"Cannot test-compile imported code because it contains PERM macros. "
"It is recommended to put in PERM macros after import."
)
return
escaped_c_file = json.dumps(c_file)
source = "#line 1 " + escaped_c_file + "\n" + source
compiler = Compiler(compile_script, show_errors=True)
o_file = compiler.compile(source)
if o_file:
shutil.move(o_file, out_file)
else:
print("Warning: failed to compile .c file.")
def write_to_file(cont: str, filename: str) -> None:
with open(filename, "w", encoding="utf-8") as f:
f.write(cont)
def main() -> None:
parser = argparse.ArgumentParser(
description="""Import a function for use with the permuter.
Will create a new directory nonmatchings/<funcname>-<id>/."""
)
parser.add_argument(
"c_file",
help="""File containing the function.
Assumes that the file can be built with 'make' to create an .o file.""",
)
parser.add_argument(
"asm_file",
help="""File containing assembly for the function.
Must start with 'glabel <function_name>' and contain no other functions.""",
)
parser.add_argument(
"make_flags",
nargs="*",
help="Arguments to pass to 'make'. PERMUTER=1 will always be passed.",
)
parser.add_argument(
"--keep", action="store_true", help="Keep the directory on error."
)
settings_files = ", ".join(SETTINGS_FILES[:-1]) + " or " + SETTINGS_FILES[-1]
parser.add_argument(
"--preserve-macros",
metavar="REGEX",
dest="preserve_macros_regex",
help=f"""Regex for which macros to preserve, or empty string for no macros.
By default, this is read from {settings_files} in a parent directory of
the imported file. Type information is also read from this file.""",
)
parser.add_argument(
"--no-prune",
dest="prune",
action="store_false",
help="""Don't minimize the source to keep only the imported function and
functions/struct/variables that it uses. Normally this behavior is
useful to make the permuter faster, but in cases where unrelated code
affects the generated assembly asm it can be necessary to turn off.
Note that regardless of this setting the permuter always removes all
other functions by replacing them with declarations.""",
)
parser.add_argument(
"--decompme",
dest="decompme",
action="store_true",
help="""Upload the function to decomp.me to share with other people,
instead of importing.""",
)
args = parser.parse_args()
root_dir = find_root_dir(
args.c_file, SETTINGS_FILES + ["Makefile", "makefile", "build.ninja"]
)
if not root_dir:
print(f"Can't find root dir of project!", file=sys.stderr)
sys.exit(1)
settings: Mapping[str, object] = {}
for filename in SETTINGS_FILES:
filename = os.path.join(root_dir, filename)
if os.path.exists(filename):
with open(filename) as f:
settings = toml.load(f)
break
build_system = settings.get("build_system", "make")
compiler = settings.get("compiler_command")
assembler = settings.get("assembler_command")
make_flags = args.make_flags
func_name, asm_cont = parse_asm(args.asm_file)
print(f"Function name: {func_name}")
if compiler or assembler:
assert isinstance(compiler, str)
assert isinstance(assembler, str)
assert settings.get("build_system") is None
compiler = shlex.split(compiler)
assembler = shlex.split(assembler)
else:
assert isinstance(build_system, str)
compiler, assembler = find_build_command_line(
root_dir, args.c_file, make_flags, build_system
)
print(f"Compiler: {formatcmd(compiler)} {{input}} -o {{output}}")
print(f"Assembler: {formatcmd(assembler)} {{input}} -o {{output}}")
preserve_macros = build_preserve_macros(
root_dir, args.preserve_macros_regex, settings
)
source = import_c_file(compiler, root_dir, args.c_file, preserve_macros)
if args.decompme:
api_base = os.environ.get("DECOMPME_API_BASE", "https://decomp.me")
compiler_name = get_decompme_compiler_name(compiler, settings, api_base)
source, context = prune_and_separate_context(source, args.prune, func_name)
print("Uploading...")
try:
post_data = urllib.parse.urlencode(
{
"target_asm": asm_cont,
"context": context,
"source_code": source,
"compiler": compiler_name,
"compiler_flags": get_compiler_flags(compiler),
}
).encode("ascii")
with urllib.request.urlopen(f"{api_base}/api/scratch", post_data) as f:
resp = f.read()
json_data: Dict[str, str] = json.loads(resp)
if "slug" in json_data:
slug = json_data["slug"]
print(f"https://decomp.me/scratch/{slug}")
else:
error = json_data.get("error", resp)
print(f"Server error: {error}")
except Exception as e:
print(e)
return
source, compilable_source = prune_source(source, args.prune, func_name)
dirname = create_directory(func_name)
base_c_file = f"{dirname}/base.c"
base_o_file = f"{dirname}/base.o"
target_s_file = f"{dirname}/target.s"
target_o_file = f"{dirname}/target.o"
compile_script = f"{dirname}/compile.sh"
func_name_file = f"{dirname}/function.txt"
try:
write_to_file(source, base_c_file)
write_to_file(func_name, func_name_file)
write_compile_command(compiler, root_dir, compile_script)
write_asm(asm_cont, target_s_file)
compile_asm(assembler, root_dir, target_s_file, target_o_file)
if compilable_source is not None:
compile_base(compile_script, compilable_source, base_c_file, base_o_file)
except:
if not args.keep:
print(f"\nDeleting directory {dirname} (run with --keep to preserve it).")
shutil.rmtree(dirname)
raise
print(f"\nDone. Imported into {dirname}")
if __name__ == "__main__":
main()
-27
View File
@@ -1,27 +0,0 @@
[mypy]
check_untyped_defs = True
disallow_any_generics = False
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
no_implicit_optional = True
warn_redundant_casts = True
warn_return_any = True
warn_unused_ignores = True
mypy_path = stubs
python_version = 3.7
files = import.py, pah.py, permuter.py, src/net/evaluator.py
[mypy-nacl.*]
ignore_missing_imports = True
[mypy-pystray.*]
ignore_missing_imports = True
[mypy-docker.*]
ignore_missing_imports = True
[mypy-PIL.*]
ignore_missing_imports = True
-4
View File
@@ -1,4 +0,0 @@
#!/usr/bin/env python3
from src.net.cmd.main import main
main()
-5
View File
@@ -1,5 +0,0 @@
#!/usr/bin/env python3
from src.main import main
if __name__ == "__main__":
main()
@@ -1,9 +0,0 @@
# Optional configuration file for import.py. Put it in the root or in tools/
# of the repo you are importing from.
build_system = "ninja"
[preserve_macros]
"g[DS]P.*" = "void"
"gDma.*" = "void"
"_SHIFTL" = "unsigned int"
-3
View File
@@ -1,3 +0,0 @@
#!/bin/sh
python3 -m unittest discover -s test/
# python3 -m pytest test/
-13
View File
@@ -1,13 +0,0 @@
#!/bin/bash
if [[ $1 < 2 ]]; then
echo "Usage: $0 output_dir"
echo "Ex: $0 nonmatchings/func_80000000"
exit 1
fi
if [[ ! -d $1 ]]; then
echo "Argument must be a directory"
exit 1
fi
find $1 -name score.txt -exec echo -n {}\ \; -exec cat {} \; | sort -rnk2
-303
View File
@@ -1,303 +0,0 @@
"""Functions and classes for dealing with types in a C AST.
They make a number of simplifying assumptions:
- const and volatile doesn't matter.
- arithmetic promotes all int-like types to 'int'.
- no two variables can have the same name, even across functions.
For the purposes of the randomizer these restrictions are acceptable."""
from dataclasses import dataclass, field
from typing import Union, Dict, Set, List
from pycparser import c_ast
from pycparser.c_ast import ArrayDecl, TypeDecl, PtrDecl, FuncDecl, IdentifierType
Type = Union[PtrDecl, ArrayDecl, TypeDecl, FuncDecl]
SimpleType = Union[PtrDecl, TypeDecl]
StructUnion = Union[c_ast.Struct, c_ast.Union]
@dataclass
class TypeMap:
typedefs: Dict[str, Type] = field(default_factory=dict)
fn_ret_types: Dict[str, Type] = field(default_factory=dict)
var_types: Dict[str, Type] = field(default_factory=dict)
struct_defs: Dict[str, StructUnion] = field(default_factory=dict)
def basic_type(name: Union[str, List[str]]) -> TypeDecl:
names = [name] if isinstance(name, str) else name
idtype = IdentifierType(names=names)
return TypeDecl(declname=None, quals=[], type=idtype)
def pointer(type: Type) -> Type:
return PtrDecl(quals=[], type=type)
def resolve_typedefs(type: Type, typemap: TypeMap) -> Type:
while (
isinstance(type, TypeDecl)
and isinstance(type.type, IdentifierType)
and len(type.type.names) == 1
and type.type.names[0] in typemap.typedefs
):
type = typemap.typedefs[type.type.names[0]]
return type
def pointer_decay(type: Type, typemap: TypeMap) -> SimpleType:
real_type = resolve_typedefs(type, typemap)
if isinstance(real_type, ArrayDecl):
return PtrDecl(quals=[], type=real_type.type)
if isinstance(real_type, FuncDecl):
return PtrDecl(quals=[], type=type)
if isinstance(real_type, TypeDecl) and isinstance(real_type.type, c_ast.Enum):
return basic_type("int")
assert not isinstance(
type, (ArrayDecl, FuncDecl)
), "resolve_typedefs can't hide arrays/functions"
return type
def get_decl_type(decl: c_ast.Decl) -> Type:
"""For a Decl that declares a variable (and not just a struct/union/enum),
return its type."""
assert decl.name is not None
assert isinstance(decl.type, (PtrDecl, ArrayDecl, FuncDecl, TypeDecl))
return decl.type
def deref_type(type: Type, typemap: TypeMap) -> Type:
type = resolve_typedefs(type, typemap)
assert isinstance(type, (ArrayDecl, PtrDecl)), "dereferencing non-pointer"
return type.type
def struct_member_type(struct: StructUnion, field_name: str, typemap: TypeMap) -> Type:
if not struct.decls:
assert (
struct.name in typemap.struct_defs
), f"Accessing field {field_name} of undefined struct {struct.name}"
struct = typemap.struct_defs[struct.name]
assert struct.decls, "struct_defs never points to an incomplete type"
for decl in struct.decls:
if isinstance(decl, c_ast.Decl):
if decl.name == field_name:
return get_decl_type(decl)
if decl.name == None and isinstance(decl.type, (c_ast.Struct, c_ast.Union)):
try:
return struct_member_type(decl.type, field_name, typemap)
except AssertionError:
pass
assert False, f"No field {field_name} in struct {struct.name}"
def expr_type(node: c_ast.Node, typemap: TypeMap) -> Type:
def rec(sub_expr: c_ast.Node) -> Type:
return expr_type(sub_expr, typemap)
if isinstance(node, c_ast.Assignment):
return rec(node.lvalue)
if isinstance(node, c_ast.StructRef):
lhs_type = rec(node.name)
if node.type == "->":
lhs_type = deref_type(lhs_type, typemap)
struct_type = resolve_typedefs(lhs_type, typemap)
assert isinstance(struct_type, TypeDecl)
assert isinstance(
struct_type.type, (c_ast.Struct, c_ast.Union)
), f"struct deref of non-struct {struct_type.declname}"
return struct_member_type(struct_type.type, node.field.name, typemap)
if isinstance(node, c_ast.Cast):
return node.to_type.type
if isinstance(node, c_ast.Constant):
if node.type == "string":
return pointer(basic_type("char"))
if node.type == "char":
return basic_type("int")
return basic_type(node.type.split(" "))
if isinstance(node, c_ast.ID):
return typemap.var_types[node.name]
if isinstance(node, c_ast.UnaryOp):
if node.op in ["p++", "p--", "++", "--"]:
return rec(node.expr)
if node.op == "&":
return pointer(rec(node.expr))
if node.op == "*":
subtype = rec(node.expr)
return deref_type(subtype, typemap)
if node.op in ["-", "+"]:
subtype = pointer_decay(rec(node.expr), typemap)
if allowed_basic_type(subtype, typemap, ["double"]):
return basic_type("double")
if allowed_basic_type(subtype, typemap, ["float"]):
return basic_type("float")
if node.op in ["sizeof", "-", "+", "~", "!"]:
return basic_type("int")
assert False, f"unknown unary op {node.op}"
if isinstance(node, c_ast.BinaryOp):
lhs_type = pointer_decay(rec(node.left), typemap)
rhs_type = pointer_decay(rec(node.right), typemap)
if node.op in [">>", "<<"]:
return lhs_type
if node.op in ["<", "<=", ">", ">=", "==", "!=", "&&", "||"]:
return basic_type("int")
if node.op in "&|^%":
return basic_type("int")
real_lhs = resolve_typedefs(lhs_type, typemap)
real_rhs = resolve_typedefs(rhs_type, typemap)
if node.op in "+-":
lptr = isinstance(real_lhs, PtrDecl)
rptr = isinstance(real_rhs, PtrDecl)
if lptr or rptr:
if lptr and rptr:
assert node.op != "+", "pointer + pointer"
return basic_type("int")
if lptr:
return lhs_type
assert node.op == "+", "int - pointer"
return rhs_type
if node.op in "*/+-":
assert isinstance(real_lhs, TypeDecl)
assert isinstance(real_rhs, TypeDecl)
assert isinstance(real_lhs.type, IdentifierType)
assert isinstance(real_rhs.type, IdentifierType)
if "double" in real_lhs.type.names + real_rhs.type.names:
return basic_type("double")
if "float" in real_lhs.type.names + real_rhs.type.names:
return basic_type("float")
return basic_type("int")
if isinstance(node, c_ast.FuncCall):
expr = node.name
if isinstance(expr, c_ast.ID):
if expr.name not in typemap.fn_ret_types:
raise Exception(f"Called function {expr.name} is missing a prototype")
return typemap.fn_ret_types[expr.name]
else:
fptr_type = resolve_typedefs(rec(expr), typemap)
if isinstance(fptr_type, PtrDecl):
fptr_type = fptr_type.type
fptr_type = resolve_typedefs(fptr_type, typemap)
assert isinstance(fptr_type, FuncDecl), "call to non-function"
return fptr_type.type
if isinstance(node, c_ast.ExprList):
return rec(node.exprs[-1])
if isinstance(node, c_ast.ArrayRef):
subtype = rec(node.name)
return deref_type(subtype, typemap)
if isinstance(node, c_ast.TernaryOp):
return rec(node.iftrue)
assert False, f"Unknown expression node type: {node}"
def decayed_expr_type(expr: c_ast.Node, typemap: TypeMap) -> SimpleType:
return pointer_decay(expr_type(expr, typemap), typemap)
def same_type(
type1: Type, type2: Type, typemap: TypeMap, allow_similar: bool = False
) -> bool:
while True:
type1 = resolve_typedefs(type1, typemap)
type2 = resolve_typedefs(type2, typemap)
if isinstance(type1, ArrayDecl) and isinstance(type2, ArrayDecl):
type1 = type1.type
type2 = type2.type
continue
if isinstance(type1, PtrDecl) and isinstance(type2, PtrDecl):
type1 = type1.type
type2 = type2.type
continue
if isinstance(type1, TypeDecl) and isinstance(type2, TypeDecl):
sub1 = type1.type
sub2 = type2.type
if isinstance(sub1, c_ast.Struct) and isinstance(sub2, c_ast.Struct):
return sub1.name == sub2.name
if isinstance(sub1, c_ast.Union) and isinstance(sub2, c_ast.Union):
return sub1.name == sub2.name
if (
allow_similar
and isinstance(sub1, (IdentifierType, c_ast.Enum))
and isinstance(sub2, (IdentifierType, c_ast.Enum))
):
# All int-ish types are similar (except void, but whatever)
return True
if isinstance(sub1, c_ast.Enum) and isinstance(sub2, c_ast.Enum):
return sub1.name == sub2.name
if isinstance(sub1, IdentifierType) and isinstance(sub2, IdentifierType):
return sorted(sub1.names) == sorted(sub2.names)
return False
def allowed_basic_type(
type: SimpleType, typemap: TypeMap, allowed_types: List[str]
) -> bool:
"""Check if a type resolves to a basic type with one of the allowed_types
keywords in it."""
base_type = resolve_typedefs(type, typemap)
if not isinstance(base_type, c_ast.TypeDecl):
return False
if not isinstance(base_type.type, c_ast.IdentifierType):
return False
if all(x not in base_type.type.names for x in allowed_types):
return False
return True
def build_typemap(ast: c_ast.FileAST) -> TypeMap:
ret = TypeMap()
for item in ast.ext:
if isinstance(item, c_ast.Typedef):
ret.typedefs[item.name] = item.type
if isinstance(item, c_ast.FuncDef):
assert item.decl.name is not None, "cannot define anonymous function"
assert isinstance(item.decl.type, FuncDecl)
ret.fn_ret_types[item.decl.name] = item.decl.type.type
if isinstance(item, c_ast.Decl) and isinstance(item.type, FuncDecl):
assert item.name is not None, "cannot define anonymous function"
ret.fn_ret_types[item.name] = item.type.type
defined_function_decls: Set[c_ast.Decl] = set()
class Visitor(c_ast.NodeVisitor):
def visit_Struct(self, struct: c_ast.Struct) -> None:
if struct.decls and struct.name is not None:
ret.struct_defs[struct.name] = struct
# Do not visit decls of this struct
def visit_Union(self, union: c_ast.Union) -> None:
if union.decls and union.name is not None:
ret.struct_defs[union.name] = union
# Do not visit decls of this union
def visit_Decl(self, decl: c_ast.Decl) -> None:
if decl.name is not None:
ret.var_types[decl.name] = get_decl_type(decl)
if not isinstance(decl.type, FuncDecl) or decl in defined_function_decls:
# Do not visit declarations in parameter lists of functions
# other than our own.
self.visit(decl.type)
def visit_Enumerator(self, enumerator: c_ast.Enumerator) -> None:
ret.var_types[enumerator.name] = basic_type("int")
def visit_FuncDef(self, fn: c_ast.FuncDef) -> None:
if fn.decl.name is not None:
ret.var_types[fn.decl.name] = get_decl_type(fn.decl)
defined_function_decls.add(fn.decl)
self.generic_visit(fn)
Visitor().visit(ast)
return ret
def set_decl_name(decl: c_ast.Decl) -> None:
name = decl.name
assert name is not None
type = get_decl_type(decl)
while not isinstance(type, TypeDecl):
type = type.type
type.declname = name
-505
View File
@@ -1,505 +0,0 @@
from base64 import b64decode
from collections import defaultdict
import copy
from dataclasses import dataclass
from random import Random
import re
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from pycparser import CParser, c_ast as ca, c_generator
from pycparser.plyparser import ParseError
from .error import CandidateConstructionFailure
from .ast_types import SimpleType, set_decl_name
@dataclass
class Indices:
starts: Dict[ca.Node, int]
ends: Dict[ca.Node, int]
Block = Union[ca.Compound, ca.Case, ca.Default]
if TYPE_CHECKING:
# ca.Expression and ca.Statement don't actually exist, they live only in
# the stubs file.
Expression = ca.Expression
Statement = ca.Statement
else:
Expression = Statement = None
def to_c_raw(node: ca.Node) -> str:
source: str = c_generator.CGenerator().visit(node)
return source
def to_c(node: ca.Node, *, from_import: bool = False) -> str:
source = to_c_raw(node) if from_import else PatchedCGenerator().visit(node)
return process_pragmas(source)
def process_pragmas(source: str) -> str:
if "#pragma" not in source:
return source
lines = source.split("\n")
out: List[str] = []
same_line = 0
ignore = 0
for line in lines:
stripped = line.strip()
if stripped.startswith("#pragma _permuter "):
# Expand permuter pragmas to nothing, by default. Still, keep one
# output line per input line to preserve line numbers for import.py
# error messages.
line = ""
stripped = stripped[len("#pragma _permuter ") :]
if stripped == "sameline start":
same_line += 1
elif stripped == "sameline end":
same_line -= 1
elif stripped == "latedefine start":
ignore += 1
elif stripped == "latedefine end":
assert ignore > 0, "mismatched ignore pragmas"
ignore -= 1
elif stripped.startswith("define "):
assert ignore > 0, "define pragma must be within latedefine block"
line = "#" + stripped
elif stripped.startswith("b64literal "):
line = b64decode(stripped.split(" ", 1)[1]).decode("utf-8")
elif ignore > 0:
# Ignore non-pragma lines within latedefine section
line = ""
if not same_line:
line += "\n"
elif line and out and not out[-1].endswith("\n"):
line = " " + line.lstrip()
out.append(line)
assert same_line == 0
assert ignore == 0, "unbalanced ignore pragmas"
return "".join(out).rstrip() + "\n"
class PatchedCGenerator(c_generator.CGenerator):
"""Like a CGenerator, except it keeps else if's prettier despite
the terrible things we've done to them in normalize_ast."""
def visit_If(self, n: ca.If) -> str:
n2 = n
if (
n.iffalse
and isinstance(n.iffalse, ca.Compound)
and n.iffalse.block_items
and len(n.iffalse.block_items) == 1
and isinstance(n.iffalse.block_items[0], ca.If)
):
n2 = ca.If(cond=n.cond, iftrue=n.iftrue, iffalse=n.iffalse.block_items[0])
return super().visit_If(n2) # type: ignore
def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]:
ret = []
for i, node in enumerate(ast.ext):
if isinstance(node, ca.FuncDef):
if node.decl.name == fn_name:
ret.append((node, i))
else:
node = node.decl
ast.ext[i] = node
if isinstance(node, ca.Decl) and isinstance(node.type, ca.FuncDecl):
node.funcspec = [spec for spec in node.funcspec if spec != "static"]
if len(ret) == 0:
raise CandidateConstructionFailure(f"Function {fn_name} not found in base.c.")
if len(ret) > 1:
raise CandidateConstructionFailure(
f"Found multiple copies of function {fn_name} in base.c."
)
return ret[0]
def parse_c(source: str, *, from_import: bool = False) -> ca.FileAST:
try:
parser = CParser()
return parser.parse(source, "<source>")
except ParseError as e:
msg = str(e)
position, msg = msg.split(": ", 1)
parts = position.split(":")
if len(parts) >= 2:
lineno = int(parts[1])
posstr = f" at approximately line {lineno}"
if len(parts) >= 3:
posstr += f", column {parts[2]}"
if not from_import:
posstr += " (after PERM expansion)"
try:
line = source.split("\n")[lineno - 1].rstrip()
posstr += "\n\n" + line
except IndexError:
posstr += "(out of bounds?)"
else:
posstr = ""
raise CandidateConstructionFailure(
f"Syntax error in base.c.\n{msg}{posstr}"
) from None
def compute_node_indices(top_node: ca.Node) -> Indices:
starts: Dict[ca.Node, int] = {}
ends: Dict[ca.Node, int] = {}
cur_index = 1
class Visitor(ca.NodeVisitor):
def generic_visit(self, node: ca.Node) -> None:
nonlocal cur_index
assert node not in starts, "nodes should only appear once in AST"
starts[node] = cur_index
cur_index += 2
super().generic_visit(node)
ends[node] = cur_index
cur_index += 2
Visitor().visit(top_node)
return Indices(starts, ends)
def equal_ast(a: ca.Node, b: ca.Node) -> bool:
def equal(a: Any, b: Any) -> bool:
if type(a) != type(b):
return False
if a is None:
return b is None
if isinstance(a, list):
assert isinstance(b, list)
if len(a) != len(b):
return False
for i in range(len(a)):
if not equal(a[i], b[i]):
return False
return True
if isinstance(a, (int, str)):
return bool(a == b)
assert isinstance(a, ca.Node)
for name in a.__slots__[:-2]: # type: ignore
if not equal(getattr(a, name), getattr(b, name)):
return False
return True
return equal(a, b)
def is_lvalue(expr: Expression) -> bool:
if isinstance(expr, (ca.ID, ca.StructRef, ca.ArrayRef)):
return True
if isinstance(expr, ca.UnaryOp):
return expr.op == "*"
return False
def is_effectful(expr: Expression) -> bool:
found = False
class Visitor(ca.NodeVisitor):
def visit_UnaryOp(self, node: ca.UnaryOp) -> None:
nonlocal found
if node.op in ["p++", "p--", "++", "--"]:
found = True
else:
self.generic_visit(node.expr)
def visit_FuncCall(self, _: ca.Node) -> None:
nonlocal found
found = True
def visit_Assignment(self, _: ca.Node) -> None:
nonlocal found
found = True
Visitor().visit(expr)
return found
def get_block_stmts(block: Block, force: bool) -> List[Statement]:
if isinstance(block, ca.Compound):
ret = block.block_items or []
if force and not block.block_items:
block.block_items = ret
else:
ret = block.stmts or []
if force and not block.stmts:
block.stmts = ret
return ret
def insert_decl(
fn: ca.FuncDef, var: str, type: SimpleType, random: Optional[Random] = None
) -> None:
type = copy.deepcopy(type)
decl = ca.Decl(
name=var, quals=[], storage=[], funcspec=[], type=type, init=None, bitsize=None
)
set_decl_name(decl)
assert fn.body.block_items, "Non-empty function"
for index, stmt in enumerate(fn.body.block_items):
if not isinstance(stmt, ca.Decl):
break
else:
index = len(fn.body.block_items)
if random:
index = random.randint(0, index)
fn.body.block_items[index:index] = [decl]
def insert_statement(block: Block, index: int, stmt: Statement) -> None:
stmts = get_block_stmts(block, True)
stmts[index:index] = [stmt]
def brace_nested_blocks(stmt: Statement) -> None:
def brace(stmt: Statement) -> Block:
if isinstance(stmt, (ca.Compound, ca.Case, ca.Default)):
return stmt
return ca.Compound([stmt])
if isinstance(stmt, (ca.For, ca.While, ca.DoWhile)):
stmt.stmt = brace(stmt.stmt)
elif isinstance(stmt, ca.If):
stmt.iftrue = brace(stmt.iftrue)
if stmt.iffalse:
stmt.iffalse = brace(stmt.iffalse)
elif isinstance(stmt, ca.Switch):
stmt.stmt = brace(stmt.stmt)
elif isinstance(stmt, ca.Label):
brace_nested_blocks(stmt.stmt)
def has_nested_block(node: ca.Node) -> bool:
return isinstance(
node,
(
ca.Compound,
ca.For,
ca.While,
ca.DoWhile,
ca.If,
ca.Switch,
ca.Case,
ca.Default,
),
)
def for_nested_blocks(stmt: Statement, callback: Callable[[Block], None]) -> None:
def invoke(stmt: Statement) -> None:
assert isinstance(
stmt, (ca.Compound, ca.Case, ca.Default)
), "brace_nested_blocks should have turned nested statements into blocks"
callback(stmt)
if isinstance(stmt, ca.Compound):
invoke(stmt)
elif isinstance(stmt, (ca.For, ca.While, ca.DoWhile)):
invoke(stmt.stmt)
elif isinstance(stmt, ca.If):
if stmt.iftrue:
invoke(stmt.iftrue)
if stmt.iffalse:
invoke(stmt.iffalse)
elif isinstance(stmt, ca.Switch):
invoke(stmt.stmt)
elif isinstance(stmt, (ca.Case, ca.Default)):
invoke(stmt)
elif isinstance(stmt, ca.Label):
for_nested_blocks(stmt.stmt, callback)
def normalize_ast(fn: ca.FuncDef, ast: ca.FileAST) -> None:
"""Add braces to all ifs/fors/etc., to make it easier to insert statements."""
def rec(block: Block) -> None:
stmts = get_block_stmts(block, False)
for stmt in stmts:
brace_nested_blocks(stmt)
for_nested_blocks(stmt, rec)
rec(fn.body)
def prune_ast(fn: ca.FuncDef, ast: ca.FileAST) -> int:
"""Prune away unnecessary parts of the AST, to reduce overhead from serialization
and from the compiler's C parser."""
# Create a GC graph that maps names of declarations and enumerators to indices
# in ast.ext, as well an initial list of GC roots, consisting of everything
# that isn't a Decl and or Typedef.
edges: Dict[str, List[int]] = defaultdict(list)
gc_roots: List[int] = []
can_fwd_declare_typedef: Set[str] = set()
can_fwd_declare_tagged: Set[str] = set()
def add_type_edges(
tp: Union["ca.Type", ca.Struct, ca.Union, ca.Enum], i: int
) -> None:
while isinstance(tp, (ca.PtrDecl, ca.ArrayDecl)):
tp = tp.type
if isinstance(tp, ca.FuncDecl):
return
inner_type = tp.type if isinstance(tp, ca.TypeDecl) else tp
if isinstance(inner_type, ca.IdentifierType):
return
if inner_type.name:
edges[inner_type.name].append(i)
if isinstance(inner_type, ca.Enum) and inner_type.values:
for value in inner_type.values.enumerators:
edges[value.name].append(i)
if isinstance(inner_type, (ca.Struct, ca.Union)) and inner_type.decls:
for decl in inner_type.decls:
if isinstance(decl, ca.Decl):
add_type_edges(decl.type, i)
for i in range(len(ast.ext)):
item = ast.ext[i]
if isinstance(item, ca.Decl) and not item.init:
# (Exclude declarations with initializers, since taking function
# pointers can affect regalloc on IDO.)
if item.name:
edges[item.name].append(i)
if isinstance(item.type, (ca.Struct, ca.Union, ca.Enum)) and item.type.name:
can_fwd_declare_tagged.add(item.type.name)
add_type_edges(item.type, i)
elif isinstance(item, ca.Typedef):
edges[item.name].append(i)
if isinstance(item.type, ca.TypeDecl) and isinstance(
item.type.type, (ca.Struct, ca.Union, ca.Enum)
):
can_fwd_declare_typedef.add(item.name)
add_type_edges(item.type, i)
elif isinstance(item, ca.Pragma) and "GLOBAL_ASM" in item.string:
pass
else:
gc_roots.append(i)
mentioned_ids: Set[str] = set()
class IdVisitor(ca.NodeVisitor):
def visit_Pragma(self, node: ca.Pragma) -> None:
for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
mentioned_ids.add(token)
def visit_ID(self, node: ca.ID) -> None:
mentioned_ids.add(node.name)
IdVisitor().visit(ast)
# Do the GC as a DFS traversal of the graph. Visiting a node searches its
# AST for all kinds of mentioned IDs, and adds more nodes to the stack
# using the edges we found before.
gc_todo: List[int] = gc_roots
need_fwd_decl_typedef: List[str] = []
need_fwd_decl_tagged: List[str] = []
def add_name(name: str) -> None:
if name in edges:
gc_todo.extend(edges[name])
del edges[name]
class Visitor(ca.NodeVisitor):
def visit_Pragma(self, node: ca.Pragma) -> None:
for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
add_name(token)
def visit_ID(self, node: ca.ID) -> None:
add_name(node.name)
def visit_IdentifierType(self, node: ca.IdentifierType) -> None:
for name in node.names:
add_name(name)
def visit_Enum(self, node: ca.Enum) -> None:
if node.name and not node.values:
add_name(node.name)
self.generic_visit(node)
def visit_Struct(self, node: ca.Struct) -> None:
if node.name and not node.decls:
add_name(node.name)
self.generic_visit(node)
def visit_Union(self, node: ca.Union) -> None:
if node.name and not node.decls:
add_name(node.name)
self.generic_visit(node)
def visit_PtrDecl(self, node: ca.PtrDecl) -> None:
# For pointer declarations which haven't been accessed, forward
# declarations suffice.
if (
isinstance(node.type, ca.TypeDecl)
and node.type.declname
and node.type.declname not in mentioned_ids
):
tp = node.type.type
if isinstance(tp, ca.IdentifierType):
if all(name in can_fwd_declare_typedef for name in tp.names):
need_fwd_decl_typedef.extend(tp.names)
return
elif tp.name and tp.name in can_fwd_declare_tagged:
if not (tp.values if isinstance(tp, ca.Enum) else tp.decls):
need_fwd_decl_tagged.append(tp.name)
return
self.generic_visit(node)
def visit_TypeDecl(self, node: ca.TypeDecl) -> None:
if node.declname:
add_name(node.declname)
self.generic_visit(node)
keep_exts: Set[int] = set()
while gc_todo:
i = gc_todo.pop()
if i not in keep_exts:
keep_exts.add(i)
Visitor().visit(ast.ext[i])
temp_id = 0
def fwd_declare(tp: Union[ca.Struct, ca.Union, ca.Enum]) -> None:
nonlocal temp_id
if not tp.name:
temp_id += 1
tp.name = f"_PermuterTemp{temp_id}"
if isinstance(tp, (ca.Struct, ca.Union)):
tp.decls = None
elif isinstance(tp, ca.Enum):
tp.values = None
else:
assert False
new_ext = []
for i, item in enumerate(ast.ext):
if i in keep_exts:
pass
elif isinstance(item, ca.Typedef) and item.name in need_fwd_decl_typedef:
assert item.name in can_fwd_declare_typedef
assert isinstance(item.type, ca.TypeDecl)
assert isinstance(item.type.type, (ca.Struct, ca.Union, ca.Enum))
fwd_declare(item.type.type)
elif (
isinstance(item, ca.Decl)
and isinstance(item.type, (ca.Struct, ca.Union, ca.Enum))
and item.type.name
and item.type.name in need_fwd_decl_tagged
):
assert item.type.name in can_fwd_declare_tagged
fwd_declare(item.type)
else:
continue
new_ext.append(item)
ast.ext = new_ext
return ast.ext.index(fn)
-99
View File
@@ -1,99 +0,0 @@
import copy
from dataclasses import dataclass, field
import functools
from typing import Optional, Tuple
from pycparser import c_ast as ca
from .compiler import Compiler
from .randomizer import Randomizer
from .scorer import Scorer
from .perm.perm import EvalState
from .perm.ast import apply_ast_perms
from .helpers import try_remove
from .profiler import Profiler
from . import ast_util
@dataclass
class CandidateResult:
"""Represents the result of scoring a candidate, and is sent from child to
parent processes, or server to client with p@h."""
score: int
hash: Optional[str]
source: Optional[str]
profiler: Optional[Profiler] = None
@dataclass
class Candidate:
"""
Represents a AST candidate created from a source which can be randomized
(possibly multiple times), compiled, and scored.
"""
ast: ca.FileAST
fn_index: int
rng_seed: int
randomizer: Randomizer
score_value: Optional[int] = field(init=False, default=None)
score_hash: Optional[str] = field(init=False, default=None)
_cache_source: Optional[str] = field(init=False, default=None)
@staticmethod
@functools.lru_cache(maxsize=16)
def _cached_shared_ast(
source: str, fn_name: str
) -> Tuple[ca.FuncDef, int, ca.FileAST]:
ast = ast_util.parse_c(source)
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name)
ast_util.normalize_ast(orig_fn, ast)
return orig_fn, fn_index, ast
@staticmethod
def from_source(
source: str, eval_state: EvalState, fn_name: str, rng_seed: int
) -> "Candidate":
# Use the same AST for all instances of the same original source, but
# with the target function deeply copied. Since we never change the
# AST outside of the target function, this is fine, and it saves us
# performance (deepcopy is really slow).
orig_fn, fn_index, ast = Candidate._cached_shared_ast(source, fn_name)
ast = copy.copy(ast)
ast.ext = copy.copy(ast.ext)
fn_copy = copy.deepcopy(orig_fn)
ast.ext[fn_index] = fn_copy
apply_ast_perms(fn_copy, eval_state)
return Candidate(
ast=ast,
fn_index=fn_index,
rng_seed=rng_seed,
randomizer=Randomizer(rng_seed),
)
def randomize_ast(self) -> None:
self.randomizer.randomize(self.ast, self.fn_index)
self._cache_source = None
def get_source(self) -> str:
if self._cache_source is None:
self._cache_source = ast_util.to_c(self.ast)
return self._cache_source
def compile(self, compiler: Compiler, show_errors: bool = False) -> Optional[str]:
source: str = self.get_source()
return compiler.compile(source, show_errors=show_errors)
def score(self, scorer: Scorer, o_file: Optional[str]) -> CandidateResult:
self.score_value = None
self.score_hash = None
try:
self.score_value, self.score_hash = scorer.score(o_file)
finally:
if o_file:
try_remove(o_file)
return CandidateResult(
score=self.score_value, hash=self.score_hash, source=self.get_source()
)
-48
View File
@@ -1,48 +0,0 @@
from typing import Optional
import tempfile
import subprocess
from .helpers import try_remove
class Compiler:
def __init__(self, compile_cmd: str, *, show_errors: bool) -> None:
self.compile_cmd = compile_cmd
self.show_errors = show_errors
def compile(self, source: str, *, show_errors: bool = False) -> Optional[str]:
"""Try to compile a piece of C code. Returns the filename of the resulting .o
temp file if it succeeds."""
show_errors = show_errors or self.show_errors
with tempfile.NamedTemporaryFile(
prefix="permuter", suffix=".c", mode="w", delete=False
) as f:
c_name = f.name
f.write(source)
with tempfile.NamedTemporaryFile(
prefix="permuter", suffix=".o", delete=False
) as f2:
o_name = f2.name
try:
stderr = 2 if show_errors else subprocess.DEVNULL
subprocess.check_call(
[self.compile_cmd, c_name, "-o", o_name],
stdout=stderr,
stderr=stderr,
)
except subprocess.CalledProcessError:
if not show_errors:
try_remove(c_name)
try_remove(o_name)
return None
except KeyboardInterrupt:
# If Ctrl+C happens during this call, make a best effort in
# removing the .c and .o files. This is totally racy, but oh well...
try_remove(c_name)
try_remove(o_name)
raise
try_remove(c_name)
return o_name
-11
View File
@@ -1,11 +0,0 @@
from dataclasses import dataclass
@dataclass
class ServerError(Exception):
message: str
@dataclass
class CandidateConstructionFailure(Exception):
message: str
-22
View File
@@ -1,22 +0,0 @@
import os
from typing import NoReturn
def plural(n: int, noun: str) -> str:
s = "s" if n != 1 else ""
return f"{n} {noun}{s}"
def exception_to_string(e: object) -> str:
return str(e) or e.__class__.__name__
def static_assert_unreachable(x: NoReturn) -> NoReturn:
raise Exception("Unreachable! " + repr(x))
def try_remove(path: str) -> None:
try:
os.remove(path)
except FileNotFoundError:
pass
-654
View File
@@ -1,654 +0,0 @@
import argparse
from dataclasses import dataclass, field
import itertools
import multiprocessing
from multiprocessing import Queue
import os
import queue
import sys
import threading
import time
from typing import (
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
)
from .candidate import CandidateResult
from .compiler import Compiler
from .error import CandidateConstructionFailure
from .helpers import plural, static_assert_unreachable
from .net.client import start_client
from .net.core import ServerError, connect, enable_debug_mode, MAX_PRIO, MIN_PRIO
from .permuter import (
EvalError,
EvalResult,
Feedback,
Finished,
Message,
NeedMoreWork,
Permuter,
Task,
WorkDone,
)
from .preprocess import preprocess
from .printer import Printer
from .profiler import Profiler
from .scorer import Scorer
# The probability that the randomizer continues transforming the output it
# generated last time.
DEFAULT_RAND_KEEP_PROB = 0.6
@dataclass
class Options:
directories: List[str]
show_errors: bool = False
show_timings: bool = False
print_diffs: bool = False
stack_differences: bool = False
abort_exceptions: bool = False
better_only: bool = False
best_only: bool = False
quiet: bool = False
stop_on_zero: bool = False
keep_prob: float = DEFAULT_RAND_KEEP_PROB
force_seed: Optional[str] = None
threads: int = 1
use_network: bool = False
network_debug: bool = False
network_priority: float = 1.0
def restricted_float(lo: float, hi: float) -> Callable[[str], float]:
def convert(x: str) -> float:
try:
ret = float(x)
except ValueError:
raise argparse.ArgumentTypeError(f"invalid float value: '{x}'")
if ret < lo or ret > hi:
raise argparse.ArgumentTypeError(
f"value {x} is out of range (must be between {lo} and {hi})"
)
return ret
return convert
@dataclass
class EvalContext:
options: Options
printer: Printer = field(default_factory=Printer)
iteration: int = 0
errors: int = 0
overall_profiler: Profiler = field(default_factory=Profiler)
permuters: List[Permuter] = field(default_factory=list)
def write_candidate(perm: Permuter, result: CandidateResult) -> None:
"""Write the candidate's C source and score to the next output directory"""
ctr = 0
while True:
ctr += 1
try:
output_dir = os.path.join(perm.dir, f"output-{result.score}-{ctr}")
os.mkdir(output_dir)
break
except FileExistsError:
pass
source = result.source
assert source is not None, "Permuter._need_to_send_source is wrong!"
with open(os.path.join(output_dir, "source.c"), "x", encoding="utf-8") as f:
f.write(source)
with open(os.path.join(output_dir, "score.txt"), "x", encoding="utf-8") as f:
f.write(f"{result.score}\n")
with open(os.path.join(output_dir, "diff.txt"), "x", encoding="utf-8") as f:
f.write(perm.diff(source) + "\n")
print(f"wrote to {output_dir}")
def post_score(
context: EvalContext, permuter: Permuter, result: EvalResult, who: Optional[str]
) -> bool:
if isinstance(result, EvalError):
if result.exc_str is not None:
context.printer.print(
"internal permuter failure.", permuter, who, keep_progress=True
)
print(result.exc_str)
if result.seed is not None:
seed_str = str(result.seed[1])
if result.seed[0] != 0:
seed_str = f"{result.seed[0]},{seed_str}"
print(f"To reproduce the failure, rerun with: --seed {seed_str}")
if context.options.abort_exceptions:
sys.exit(1)
else:
return False
if context.options.print_diffs:
assert result.source is not None, "Permuter._need_to_send_source is wrong"
print()
print(permuter.diff(result.source))
input("Press any key to continue...")
profiler = result.profiler
score_value = result.score
if profiler is not None:
for stattype in profiler.time_stats:
context.overall_profiler.add_stat(stattype, profiler.time_stats[stattype])
context.iteration += 1
if score_value == permuter.scorer.PENALTY_INF:
disp_score = "inf"
context.errors += 1
else:
disp_score = str(score_value)
timings = ""
if context.options.show_timings:
timings = " \t" + context.overall_profiler.get_str_stats()
status_line = f"iteration {context.iteration}, {context.errors} errors, score = {disp_score}{timings}"
if permuter.should_output(result):
former_best = permuter.best_score
permuter.record_result(result)
if score_value < former_best:
color = "\u001b[32;1m"
msg = f"found new best score! ({score_value} vs {permuter.base_score})"
elif score_value == former_best:
color = "\u001b[32;1m"
msg = f"tied best score! ({score_value} vs {permuter.base_score})"
elif score_value < permuter.base_score:
color = "\u001b[33m"
msg = f"found a better score! ({score_value} vs {permuter.base_score})"
else:
color = "\u001b[33m"
msg = f"found different asm with same score ({score_value})"
context.printer.print(msg, permuter, who, color=color)
write_candidate(permuter, result)
if not context.options.quiet:
context.printer.progress(status_line)
return score_value == 0
def cycle_seeds(permuters: List[Permuter]) -> Iterable[Tuple[int, int]]:
"""
Return all possible (permuter index, seed) pairs, cycling over permuters.
If a permuter is randomized, it will keep repeating seeds infinitely.
"""
iterators: List[Iterator[Tuple[int, int]]] = []
for perm_ind, permuter in enumerate(permuters):
it = permuter.seed_iterator()
iterators.append(zip(itertools.repeat(perm_ind), it))
i = 0
while iterators:
i %= len(iterators)
item = next(iterators[i], None)
if item is None:
del iterators[i]
i -= 1
else:
yield item
i += 1
def multiprocess_worker(
permuters: List[Permuter],
input_queue: "Queue[Task]",
output_queue: "Queue[Feedback]",
) -> None:
try:
while True:
# Read a work item from the queue. If none is immediately available,
# tell the main thread to fill the queues more, and then block on
# the queue.
queue_item: Task
try:
queue_item = input_queue.get(block=False)
except queue.Empty:
output_queue.put((NeedMoreWork(), -1, None))
queue_item = input_queue.get()
if isinstance(queue_item, Finished):
output_queue.put((queue_item, -1, None))
output_queue.close()
break
permuter_index, seed = queue_item
permuter = permuters[permuter_index]
result = permuter.try_eval_candidate(seed)
if isinstance(result, CandidateResult) and permuter.should_output(result):
permuter.record_result(result)
output_queue.put((WorkDone(permuter_index, result), -1, None))
output_queue.put((NeedMoreWork(), -1, None))
except KeyboardInterrupt:
# Don't clutter the output with stack traces; Ctrl+C is the expected
# way to quit and sends KeyboardInterrupt to all processes.
# A heartbeat thing here would be good but is too complex.
# Don't join the queue background thread -- thread joins in relation
# to KeyboardInterrupt usually result in deadlocks.
input_queue.cancel_join_thread()
output_queue.cancel_join_thread()
def run(options: Options) -> List[int]:
last_time = time.time()
try:
def heartbeat() -> None:
nonlocal last_time
last_time = time.time()
return run_inner(options, heartbeat)
except KeyboardInterrupt:
if time.time() - last_time > 5:
print()
print("Aborting stuck process.")
raise
print()
print("Exiting.")
sys.exit(0)
def run_inner(options: Options, heartbeat: Callable[[], None]) -> List[int]:
print("Loading...")
context = EvalContext(options)
force_seed: Optional[int] = None
force_rng_seed: Optional[int] = None
if options.force_seed:
seed_parts = list(map(int, options.force_seed.split(",")))
force_rng_seed = seed_parts[-1]
force_seed = 0 if len(seed_parts) == 1 else seed_parts[0]
name_counts: Dict[str, int] = {}
for i, d in enumerate(options.directories):
heartbeat()
compile_cmd = os.path.join(d, "compile.sh")
target_o = os.path.join(d, "target.o")
base_c = os.path.join(d, "base.c")
for fname in [compile_cmd, target_o, base_c]:
if not os.path.isfile(fname):
print(f"Missing file {fname}", file=sys.stderr)
sys.exit(1)
if not os.stat(compile_cmd).st_mode & 0o100:
print(f"{compile_cmd} must be marked executable.", file=sys.stderr)
sys.exit(1)
fn_name: Optional[str] = None
try:
with open(os.path.join(d, "function.txt"), encoding="utf-8") as f:
fn_name = f.read().strip()
except FileNotFoundError:
pass
if fn_name:
print(f"{base_c} ({fn_name})")
else:
print(base_c)
compiler = Compiler(compile_cmd, show_errors=options.show_errors)
scorer = Scorer(target_o, stack_differences=options.stack_differences)
c_source = preprocess(base_c)
try:
permuter = Permuter(
d,
fn_name,
compiler,
scorer,
base_c,
c_source,
force_seed=force_seed,
force_rng_seed=force_rng_seed,
keep_prob=options.keep_prob,
need_profiler=options.show_timings,
need_all_sources=options.print_diffs,
show_errors=options.show_errors,
best_only=options.best_only,
better_only=options.better_only,
)
except CandidateConstructionFailure as e:
print(e.message, file=sys.stderr)
sys.exit(1)
context.permuters.append(permuter)
name_counts[permuter.fn_name] = name_counts.get(permuter.fn_name, 0) + 1
print()
if not context.permuters:
print("No permuters!")
return []
for permuter in context.permuters:
if name_counts[permuter.fn_name] > 1:
permuter.unique_name += f" ({permuter.dir})"
print(f"[{permuter.unique_name}] base score = {permuter.best_score}")
found_zero = False
if options.threads == 1 and not options.use_network:
# Simple single-threaded mode. This is not technically needed, but
# makes the permuter easier to debug.
for permuter_index, seed in cycle_seeds(context.permuters):
heartbeat()
permuter = context.permuters[permuter_index]
result = permuter.try_eval_candidate(seed)
if post_score(context, permuter, result, None):
found_zero = True
if options.stop_on_zero:
break
else:
seed_iterators: List[Optional[Iterator[int]]] = [
permuter.seed_iterator()
for perm_ind, permuter in enumerate(context.permuters)
]
seed_iterators_remaining = len(seed_iterators)
next_iterator_index = 0
# Create queues.
worker_task_queue: "Queue[Task]" = Queue()
feedback_queue: "Queue[Feedback]" = Queue()
# Connect to network and create client threads and queues.
net_conns: "List[Tuple[threading.Thread, Queue[Task]]]" = []
if options.use_network:
print("Connecting to permuter@home...")
if options.network_debug:
enable_debug_mode()
first_stats: Optional[Tuple[int, int, float]] = None
for perm_index in range(len(context.permuters)):
try:
port = connect()
except (EOFError, ServerError) as e:
print("Error:", e)
sys.exit(1)
thread, queue, stats = start_client(
port,
context.permuters[perm_index],
perm_index,
feedback_queue,
options.network_priority,
)
net_conns.append((thread, queue))
if first_stats is None:
first_stats = stats
assert first_stats is not None, "has at least one permuter"
clients_str = plural(first_stats[0], "other client")
servers_str = plural(first_stats[1], "server")
cores_str = plural(int(first_stats[2]), "core")
print(f"Connected! {servers_str} online ({cores_str}, {clients_str})")
# Start local worker threads
processes: List[multiprocessing.Process] = []
for i in range(options.threads):
p = multiprocessing.Process(
target=multiprocess_worker,
args=(context.permuters, worker_task_queue, feedback_queue),
)
p.start()
processes.append(p)
active_workers = len(processes)
if not active_workers and not net_conns:
print("No workers available! Exiting.")
sys.exit(1)
def process_finish(finish: Finished, source: int) -> None:
nonlocal active_workers
if finish.reason:
permuter: Optional[Permuter] = None
if source != -1 and len(context.permuters) > 1:
permuter = context.permuters[source]
context.printer.print(finish.reason, permuter, None, keep_progress=True)
if source == -1:
active_workers -= 1
def process_result(work: WorkDone, who: Optional[str]) -> bool:
permuter = context.permuters[work.perm_index]
return post_score(context, permuter, work.result, who)
def get_task(perm_index: int) -> Optional[Tuple[int, int]]:
nonlocal next_iterator_index, seed_iterators_remaining
if perm_index == -1:
while seed_iterators_remaining > 0:
task = get_task(next_iterator_index)
next_iterator_index += 1
next_iterator_index %= len(seed_iterators)
if task is not None:
return task
else:
it = seed_iterators[perm_index]
if it is not None:
seed = next(it, None)
if seed is None:
seed_iterators[perm_index] = None
seed_iterators_remaining -= 1
else:
return (perm_index, seed)
return None
# Feed the task queue with work and read from results queue.
# We generally match these up one-by-one to avoid overfilling queues,
# but workers can ask us to add more tasks into the system if they run
# out of work. (This will happen e.g. at the very beginning, when the
# queues are empty.)
while seed_iterators_remaining > 0:
heartbeat()
feedback, source, who = feedback_queue.get()
if isinstance(feedback, Finished):
process_finish(feedback, source)
elif isinstance(feedback, Message):
context.printer.print(feedback.text, None, who, keep_progress=True)
elif isinstance(feedback, WorkDone):
if process_result(feedback, who):
# Found score 0!
found_zero = True
if options.stop_on_zero:
break
elif isinstance(feedback, NeedMoreWork):
task = get_task(source)
if task is not None:
if source == -1:
worker_task_queue.put(task)
else:
net_conns[source][1].put(task)
else:
static_assert_unreachable(feedback)
# Signal workers to stop.
for i in range(active_workers):
worker_task_queue.put(Finished())
for conn in net_conns:
conn[1].put(Finished())
# Await final results.
while active_workers > 0 or net_conns:
heartbeat()
feedback, source, who = feedback_queue.get()
if isinstance(feedback, Finished):
process_finish(feedback, source)
elif isinstance(feedback, Message):
context.printer.print(feedback.text, None, who, keep_progress=True)
elif isinstance(feedback, WorkDone):
if not (options.stop_on_zero and found_zero):
if process_result(feedback, who):
found_zero = True
elif isinstance(feedback, NeedMoreWork):
pass
else:
static_assert_unreachable(feedback)
# Wait for workers to finish.
for p in processes:
p.join()
# Wait for network connections to close (currently does not happen).
for conn in net_conns:
conn[0].join()
if found_zero:
print("\nFound zero score! Exiting.")
return [permuter.best_score for permuter in context.permuters]
def main() -> None:
multiprocessing.freeze_support()
sys.setrecursionlimit(10000)
# Ideally we would do:
# multiprocessing.set_start_method("spawn")
# here, to make multiprocessing behave the same across operating systems.
# However, that means that arguments to Process are passed across using
# pickling, which mysteriously breaks with pycparser...
# (AttributeError: 'CParser' object has no attribute 'p_abstract_declarator_opt')
# So, for now we live with the defaults, which make multiprocessing work on Linux,
# where it uses fork and doesn't pickle arguments, and break on Windows. Sigh.
parser = argparse.ArgumentParser(
description="Randomly permute C files to better match a target binary."
)
parser.add_argument(
"directories",
nargs="+",
metavar="directory",
help="Directory containing base.c, target.o and compile.sh. Multiple directories may be given.",
)
parser.add_argument(
"--show-errors",
dest="show_errors",
action="store_true",
help="Display compiler error/warning messages, and keep .c files for failed compiles.",
)
parser.add_argument(
"--show-timings",
dest="show_timings",
action="store_true",
help="Display the time taken by permuting vs. compiling vs. scoring.",
)
parser.add_argument(
"--print-diffs",
dest="print_diffs",
action="store_true",
help="Instead of compiling generated sources, display diffs against a base version.",
)
parser.add_argument(
"--abort-exceptions",
dest="abort_exceptions",
action="store_true",
help="Stop execution when an internal permuter exception occurs.",
)
parser.add_argument(
"--better-only",
dest="better_only",
action="store_true",
help="Only report scores better than the base.",
)
parser.add_argument(
"--best-only",
dest="best_only",
action="store_true",
help="Only report ties or new high scores.",
)
parser.add_argument(
"--stop-on-zero",
dest="stop_on_zero",
action="store_true",
help="Stop after producing an output with score 0.",
)
parser.add_argument(
"--quiet",
dest="quiet",
action="store_true",
help="Don't print a status line with the number of iterations.",
)
parser.add_argument(
"--stack-diffs",
dest="stack_differences",
action="store_true",
help="Take stack differences into account when computing the score.",
)
parser.add_argument(
"--keep-prob",
dest="keep_prob",
metavar="PROB",
type=restricted_float(0.0, 1.0),
default=DEFAULT_RAND_KEEP_PROB,
help="""Continue randomizing the previous output with the given probability
(float in 0..1, default %(default)s).""",
)
parser.add_argument("--seed", dest="force_seed", type=str, help=argparse.SUPPRESS)
parser.add_argument(
"-j",
dest="threads",
type=int,
default=0,
help="Number of own threads to use (default: 1 without -J, 0 with -J).",
)
parser.add_argument(
"-J",
dest="use_network",
action="store_true",
help="Harness extra compute power through cyberspace (permuter@home).",
)
parser.add_argument(
"--pah-debug",
dest="network_debug",
action="store_true",
help="Enable debug prints for permuter@home.",
)
parser.add_argument(
"--priority",
dest="network_priority",
metavar="PRIORITY",
type=restricted_float(MIN_PRIO, MAX_PRIO),
default=1.0,
help=f"""Proportion of server resources to use when multiple people
are using -J at the same time.
Defaults to 1.0, meaning resources are split equally, but can be
set to any value within [{MIN_PRIO}, {MAX_PRIO}].
Each server runs with a priority threshold, which defaults to 0.1,
below which they will not run permuter jobs at all.""",
)
args = parser.parse_args()
threads = args.threads
if not threads and not args.use_network:
threads = 1
options = Options(
directories=args.directories,
show_errors=args.show_errors,
show_timings=args.show_timings,
print_diffs=args.print_diffs,
abort_exceptions=args.abort_exceptions,
better_only=args.better_only,
best_only=args.best_only,
quiet=args.quiet,
stack_differences=args.stack_differences,
stop_on_zero=args.stop_on_zero,
keep_prob=args.keep_prob,
force_seed=args.force_seed,
threads=threads,
use_network=args.use_network,
network_debug=args.network_debug,
network_priority=args.network_priority,
)
run(options)
if __name__ == "__main__":
main()
-272
View File
@@ -1,272 +0,0 @@
from multiprocessing import Queue
import re
import threading
from typing import Optional, Tuple
import zlib
from ..candidate import CandidateResult
from ..helpers import exception_to_string
from ..permuter import (
EvalError,
EvalResult,
Feedback,
FeedbackItem,
Finished,
Message,
NeedMoreWork,
Permuter,
Task,
WorkDone,
)
from ..profiler import Profiler
from .core import (
PermuterData,
SocketPort,
json_prop,
permuter_data_to_json,
)
def _profiler_from_json(obj: dict) -> Profiler:
ret = Profiler()
for key in obj:
assert isinstance(key, str), "json properties are strings"
stat = Profiler.StatType[key]
time = json_prop(obj, key, float)
ret.add_stat(stat, time)
return ret
def _result_from_json(obj: dict, source: Optional[str]) -> EvalResult:
if "error" in obj:
return EvalError(exc_str=json_prop(obj, "error", str), seed=None)
profiler: Optional[Profiler] = None
if "profiler" in obj:
profiler = _profiler_from_json(json_prop(obj, "profiler", dict))
return CandidateResult(
score=json_prop(obj, "score", int),
hash=json_prop(obj, "hash", str) if "hash" in obj else None,
source=source,
profiler=profiler,
)
def _make_script_portable(source: str) -> str:
"""Parse a shell script and get rid of the machine-specific parts that
import.py introduces. The resulting script must be run in an environment
that has the right binaries in its $PATH, and with a current working
directory similar to where import.py found its target's make root."""
lines = []
for line in source.split("\n"):
if re.match("cd '?/", line):
# Skip cd's to absolute directory paths. Note that shlex quotes
# its argument with ' if it contains spaces/single quotes.
continue
if re.match("'?/", line):
quote = "'" if line[0] == "'" else ""
ind = line.find(quote + " ")
if ind == -1:
ind = len(line)
else:
ind += len(quote)
lastind = line.rfind("/", 0, ind)
assert lastind != -1
# Emit a call to "which" as the first part, to ensure the called
# binary still sees an absolute path. qemu-irix requires this,
# for some reason.
line = "$(which " + quote + line[lastind + 1 : ind] + ")" + line[ind:]
lines.append(line)
return "\n".join(lines)
def make_portable_permuter(permuter: Permuter) -> PermuterData:
with open(permuter.scorer.target_o, "rb") as f:
target_o_bin = f.read()
with open(permuter.compiler.compile_cmd, "r") as f2:
compile_script = _make_script_portable(f2.read())
return PermuterData(
base_score=permuter.base_score,
base_hash=permuter.base_hash,
fn_name=permuter.fn_name,
filename=permuter.source_file,
keep_prob=permuter.keep_prob,
need_profiler=permuter.need_profiler,
stack_differences=permuter.scorer.stack_differences,
compile_script=compile_script,
source=permuter.source,
target_o_bin=target_o_bin,
)
class Connection:
_port: SocketPort
_permuter_data: PermuterData
_perm_index: int
_task_queue: "Queue[Task]"
_feedback_queue: "Queue[Feedback]"
def __init__(
self,
port: SocketPort,
permuter_data: PermuterData,
perm_index: int,
task_queue: "Queue[Task]",
feedback_queue: "Queue[Feedback]",
) -> None:
self._port = port
self._permuter_data = permuter_data
self._perm_index = perm_index
self._task_queue = task_queue
self._feedback_queue = feedback_queue
def _send_permuter(self) -> None:
data = self._permuter_data
self._port.send_json(permuter_data_to_json(data))
self._port.send(zlib.compress(data.source.encode("utf-8")))
self._port.send(zlib.compress(data.target_o_bin))
def _feedback(self, feedback: FeedbackItem, server_nick: Optional[str]) -> None:
self._feedback_queue.put((feedback, self._perm_index, server_nick))
def _receive_one(self) -> bool:
"""Receive a result/progress message and send it on. Returns true if
more work should be requested."""
msg = self._port.receive_json()
msg_type = json_prop(msg, "type", str)
if msg_type == "need_work":
return True
server_nick = json_prop(msg, "server", str)
if msg_type == "init_done":
base_hash = json_prop(msg, "hash", str)
my_base_hash = self._permuter_data.base_hash
text = "connected"
if base_hash != my_base_hash:
text += " (note: mismatching hash)"
self._feedback(Message(text), server_nick)
return True
if msg_type == "init_failed":
text = "failed to initialize: " + json_prop(msg, "reason", str)
self._feedback(Message(text), server_nick)
return False
if msg_type == "disconnect":
self._feedback(Message("disconnected"), server_nick)
return False
if msg_type == "result":
source: Optional[str] = None
if msg.get("has_source") == True:
# Source is sent separately, compressed, since it can be
# large (hundreds of kilobytes is not uncommon).
compressed_source = self._port.receive()
try:
source = zlib.decompress(compressed_source).decode("utf-8")
except Exception as e:
text = "failed to decompress: " + exception_to_string(e)
self._feedback(Message(text), server_nick)
return True
try:
result = _result_from_json(msg, source)
self._feedback(WorkDone(self._perm_index, result), server_nick)
except Exception as e:
text = "failed to parse result message: " + exception_to_string(e)
self._feedback(Message(text), server_nick)
return True
raise ValueError(f"Invalid message type {msg_type}")
def run(self) -> None:
finish_reason: Optional[str] = None
try:
self._send_permuter()
self._port.receive_json()
finished = False
# Main loop: send messages from the queue on to the server, and
# vice versa. Currently we are being lazy and alternate between
# sending and receiving; this is nicely simple and keeps us on a
# single thread, however it could cause deadlocks if the server
# receiver stops reading because we aren't reading fast enough.
while True:
if not self._receive_one():
continue
self._feedback(NeedMoreWork(), None)
# Read a task and send it on, unless there are no more tasks.
if not finished:
task = self._task_queue.get()
if isinstance(task, Finished):
# We don't have a way of indicating to the server that
# all is done: the server currently doesn't track
# outstanding work so it doesn't know when to close
# the connection. (Even with this fixed we'll have the
# problem that servers may disconnect, losing work, so
# the task never truly finishes. But it might work well
# enough in practice.)
finished = True
else:
work = {
"type": "work",
"work": {
"seed": task[1],
},
}
self._port.send_json(work)
except EOFError:
finish_reason = "disconnected from permuter@home"
except Exception as e:
errmsg = exception_to_string(e)
finish_reason = f"permuter@home error: {errmsg}"
finally:
self._feedback(Finished(reason=finish_reason), None)
self._port.shutdown()
self._port.close()
def start_client(
port: SocketPort,
permuter: Permuter,
perm_index: int,
feedback_queue: "Queue[Feedback]",
priority: float,
) -> "Tuple[threading.Thread, Queue[Task], Tuple[int, int, float]]":
port.send_json(
{
"method": "connect_client",
"priority": priority,
}
)
obj = port.receive_json()
if "error" in obj:
err = json_prop(obj, "error", str)
# TODO use another exception type
raise Exception(f"Failed to connect: {err}")
num_servers = json_prop(obj, "servers", int)
num_clients = json_prop(obj, "clients", int)
num_cores = json_prop(obj, "cores", float)
permuter_data = make_portable_permuter(permuter)
task_queue: "Queue[Task]" = Queue()
conn = Connection(
port,
permuter_data,
perm_index,
task_queue,
feedback_queue,
)
thread = threading.Thread(target=conn.run, daemon=True)
thread.start()
stats = (num_clients, num_servers, num_cores)
return thread, task_queue, stats
-17
View File
@@ -1,17 +0,0 @@
import abc
from argparse import ArgumentParser, Namespace
class Command(abc.ABC):
command: str
help: str
@staticmethod
@abc.abstractmethod
def add_arguments(parser: ArgumentParser) -> None:
...
@staticmethod
@abc.abstractmethod
def run(args: Namespace) -> None:
...
Binary file not shown.

Before

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 110 KiB

-70
View File
@@ -1,70 +0,0 @@
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import sys
from ..core import ServerError, enable_debug_mode
from .run_server import RunServerCommand
from .setup import SetupCommand
from .ping import PingCommand
from .vouch import VouchCommand
def main() -> None:
try:
# We currently sometimes log stuff to stdout, so it's preferable if it's
# line-buffered even when redirected to a non-tty (e.g. when running a
# permuter server as a systemd service). This is supported by Python 3.7
# and up.
sys.stdout.reconfigure(line_buffering=True) # type: ignore
except Exception:
pass
parser = ArgumentParser(
description="permuter@home - run the permuter across the Internet!\n\n"
"To use p@h as a client, just pass -J when running the permuter. "
"This script is\nonly necessary for configuration or when running a server.",
formatter_class=RawDescriptionHelpFormatter,
)
commands = [
PingCommand,
RunServerCommand,
SetupCommand,
VouchCommand,
]
parser.add_argument(
"--debug",
dest="debug",
action="store_true",
help="Enable debug logging.",
)
subparsers = parser.add_subparsers(metavar="<command>")
for command in commands:
subparser = subparsers.add_parser(
command.command,
help=command.help,
description=command.help,
)
command.add_arguments(subparser)
subparser.set_defaults(subcommand_handler=command.run)
args = parser.parse_args()
if args.debug:
enable_debug_mode()
if "subcommand_handler" in args:
try:
args.subcommand_handler(args)
except EOFError as e:
print("Network error:", e)
sys.exit(1)
except ServerError as e:
print("Error:", e.message)
sys.exit(1)
else:
parser.print_help()
if __name__ == "__main__":
main()
-32
View File
@@ -1,32 +0,0 @@
from argparse import ArgumentParser, Namespace
import time
from ...helpers import plural
from ..core import connect, json_prop
from .base import Command
class PingCommand(Command):
command = "ping"
help = "Check server connectivity."
@staticmethod
def add_arguments(parser: ArgumentParser) -> None:
pass
@staticmethod
def run(args: Namespace) -> None:
run_ping()
def run_ping() -> None:
port = connect()
t0 = time.time()
port.send_json({"method": "ping"})
msg = port.receive_json()
rtt = (time.time() - t0) * 1000
print(f"Connected successfully! Round-trip time: {rtt:.1f} ms")
servers_str = plural(json_prop(msg, "servers", int), "server")
clients_str = plural(json_prop(msg, "clients", int), "client")
cores_str = plural(int(json_prop(msg, "cores", float)), "core")
print(f"{servers_str} online ({cores_str}, {clients_str})")
@@ -1,616 +0,0 @@
from argparse import ArgumentParser, Namespace
import base64
from dataclasses import dataclass
from enum import Enum
from functools import partial
import json
import os
import platform
import queue
import random
import shutil
from subprocess import Popen, PIPE
import sys
import time
import threading
import traceback
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from ...helpers import static_assert_unreachable
from ..core import CancelToken, ServerError, read_config
from ..server import (
Client,
Config,
IoActivity,
IoConnect,
IoDisconnect,
IoImmediateDisconnect,
IoReconnect,
IoServerFailed,
IoShutdown,
IoUserRemovePermuter,
IoWorkDone,
PermuterHandle,
Server,
ServerOptions,
)
from .base import Command
from .util import ask
class RunServerCommand(Command):
command = "run-server"
help = """Run a permuter server, allowing anyone with access to the central
server to run sandboxed permuter jobs on your machine. Requires docker."""
@staticmethod
def add_arguments(parser: ArgumentParser) -> None:
parser.add_argument(
"--cores",
dest="num_cores",
metavar="CORES",
type=float,
required=True,
help="Number of cores to use (float).",
)
parser.add_argument(
"--memory",
dest="max_memory_gb",
metavar="MEMORY_GB",
type=float,
required=True,
help="""Restrict the sandboxed process to the given amount of memory in
gigabytes (float). If this limit is hit, the permuter will crash
horribly, but at least your system won't lock up.""",
)
parser.add_argument(
"--systray",
dest="systray",
action="store_true",
help="""Make the server controllable through the system tray.""",
)
parser.add_argument(
"--min-priority",
dest="min_priority",
metavar="PRIORITY",
type=float,
default=0.1,
help="""Only accept jobs from clients who pass --priority with a number
higher or equal to this value. (default: %(default)s)""",
)
@staticmethod
def run(args: Namespace) -> None:
options = ServerOptions(
num_cores=args.num_cores,
max_memory_gb=args.max_memory_gb,
min_priority=args.min_priority,
)
server_main(options, args.systray)
class SystrayState:
def server_reconnecting(self) -> None:
pass
def server_connected(self) -> None:
pass
def server_failed(self, graceful: bool, message: Optional[str] = None) -> None:
pass
def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None:
pass
def disconnect(self, handle: PermuterHandle) -> None:
pass
def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None:
pass
def stop(self) -> None:
pass
@dataclass
class Permuter:
nickname: str
fn_name: str
iterations: int = 0
improvements: int = 0
last_systray_update: float = 0.0
slot: "Optional[ClientSlot]" = None
@dataclass
class ClientSlot:
menu_id: int
iterations_id: int
improvements_id: int
stop_id: int
permuter: Optional[PermuterHandle] = None
class SystrayStatus(Enum):
CONNECTING = 0
CONNECTED = 1
FAILED = 2
RECONNECTING = 3
class RealSystrayState(SystrayState):
_CLIENT_SLOTS = 10
_UPDATE_INTERVAL = 2.0
_MENU_TOOLTIP = "permuter@home"
_permuters: Dict[PermuterHandle, Permuter]
_onclick: Dict[int, Callable[[], None]]
_client_slots: List[ClientSlot]
def __init__(
self,
config: Config,
io_queue: "queue.Queue[IoActivity]",
) -> None:
self._io_queue = io_queue
self._permuters = {}
self._onclick = {}
self._status = SystrayStatus.CONNECTING
self._fail_message: Optional[str] = None
def load_icon(fname: str) -> str:
path = os.path.join(os.path.dirname(__file__), "icons", fname)
with open(path, "rb") as f:
data = f.read()
return base64.b64encode(data).decode("ascii")
self._icons = {
"working": load_icon("okthink.ico"),
"passive": load_icon("ok.ico"),
"fail": load_icon("notok.ico"),
}
self._current_icon = "working"
next_id = 100
def add_item(
menu: List[dict],
title: str,
onclick: Optional[Callable[[], None]] = None,
*,
submenu: Optional[List[dict]] = None,
hidden: bool = False,
) -> int:
nonlocal next_id
next_id += 1
obj = {
"title": title,
"enabled": onclick is not None or submenu is not None,
"hidden": hidden,
"__id": next_id,
}
if onclick is not None:
self._onclick[next_id] = onclick
if submenu is not None:
obj["items"] = submenu
menu.append(obj)
return next_id
menu: List[dict] = []
self._status_id = add_item(menu, "Connecting...")
self._client_slots = []
for i in range(self._CLIENT_SLOTS):
submenu: List[dict] = []
remove_cb = partial(self._remove_permuter, i)
self._client_slots.append(
ClientSlot(
iterations_id=add_item(submenu, ""),
improvements_id=add_item(submenu, ""),
stop_id=add_item(submenu, "Stop", remove_cb),
menu_id=add_item(menu, "", submenu=submenu, hidden=True),
)
)
self._more_id = add_item(menu, "", hidden=True)
add_item(menu, "Quit", self._quit)
try:
path = self._setup_helper()
self._proc = Popen(
[path],
stdout=PIPE,
stdin=PIPE,
universal_newlines=True,
)
assert self._proc.stdout is not None
self._proc_stdout = self._proc.stdout
assert self._proc.stdin is not None
self._proc_stdin = self._proc.stdin
self._send(
{
"icon": self._icons[self._current_icon],
"tooltip": self._MENU_TOOLTIP,
"items": menu,
}
)
resp_str = self._proc_stdout.readline()
assert resp_str
resp = json.loads(resp_str)
assert isinstance(resp, dict)
assert resp.get("type") == "ready"
except Exception:
print("Failed to initialize systray!")
print()
print("See src/net/cmd/systray/README.md for details on how to set it up.")
traceback.print_exc()
sys.exit(1)
self._read_thread = threading.Thread(target=self._read_loop, daemon=True)
self._read_thread.start()
@staticmethod
def _setup_helper() -> str:
fname = "permuter-systray"
suffix = ""
osname = sys.platform.replace("darwin", "macos")
arch = platform.machine().replace("AMD64", "x86_64")
if (
osname in ("win32", "msys", "cygwin")
or "microsoft" in platform.uname().release.lower()
):
osname = "win"
suffix = ".exe"
dir = os.path.join(os.path.dirname(__file__), "systray")
target_binary = os.path.join(dir, fname + suffix)
if os.path.exists(target_binary):
return target_binary
prebuilt_file = f"{fname}-{osname}-{arch}{suffix}"
prebuilt_file = os.path.join(dir, "prebuilt", prebuilt_file)
print("An external helper binary is required for systray support.")
print(
"To build it from source (requires Go), see src/net/cmd/systray/README.md."
)
if os.path.exists(prebuilt_file):
print("Alternatively, a pre-built binary can be used.")
if ask("Use pre-built binary?", default=False):
shutil.copy(prebuilt_file, target_binary)
os.chmod(target_binary, 0o755)
return target_binary
print("Aborting.")
sys.exit(1)
def _send(self, msg: dict) -> None:
data = json.dumps(msg)
self._proc_stdin.write(data + "\n")
self._proc_stdin.flush()
def _update_item(
self, id: int, title: str, *, hidden: bool = False, enabled: bool = False
) -> None:
self._send(
{
"type": "update-item",
"item": {
"title": title,
"enabled": enabled,
"hidden": hidden,
"__id": id,
},
"seq_id": -1,
}
)
def _remove_permuter(self, slot_index: int) -> None:
slot = self._client_slots[slot_index]
if not slot.permuter:
return
handle = slot.permuter
self._io_queue.put((None, (handle, IoUserRemovePermuter())))
def _quit(self) -> None:
self._io_queue.put((None, IoShutdown()))
def _read_loop(self) -> None:
while True:
resp_str = self._proc_stdout.readline()
if not resp_str:
break
try:
resp = json.loads(resp_str)
except Exception:
raise Exception(f"Failed to parse systray JSON: {resp_str}") from None
if resp["type"] == "clicked":
id = resp["__id"]
if id in self._onclick:
self._onclick[id]()
def _permuter_slot(self, perm: Permuter) -> Optional[ClientSlot]:
for slot in self._client_slots:
if slot.permuter is not None and self._permuters[slot.permuter] is perm:
return slot
return None
def _update_permuter(self, perm: Permuter, slot: ClientSlot) -> None:
self._update_item(
slot.iterations_id,
f"Iterations: {perm.iterations}",
)
self._update_item(
slot.improvements_id,
f"Improvements found: {perm.improvements}",
)
def _update_status(self) -> None:
if self._status == SystrayStatus.CONNECTING:
status = "Reconnecting..."
icon = "working"
elif self._status == SystrayStatus.RECONNECTING:
status = "Disconnected, will reconnect..."
icon = "fail"
elif self._status == SystrayStatus.CONNECTED:
if self._permuters:
status = "Currently permuting:"
icon = "working"
else:
status = "Not running"
icon = "passive"
elif self._status == SystrayStatus.FAILED:
if self._fail_message:
status = f"Error: {self._fail_message}"
else:
status = "Error occurred"
icon = "fail"
else:
assert False, f"bad status {self._status}"
self._update_item(self._status_id, status)
if self._current_icon != icon:
self._current_icon = icon
self._send(
{
"type": "update-menu",
"menu": {
"tooltip": self._MENU_TOOLTIP,
"icon": self._icons[icon],
},
}
)
def _fill_slots(self) -> None:
has_more = False
while True:
key = next((k for k, p in self._permuters.items() if p.slot is None), None)
if key is None:
break
chosen_slot: Optional[ClientSlot] = None
for i in range(self._CLIENT_SLOTS - 1, -1, -1):
slot = self._client_slots[i]
if slot.permuter is None:
chosen_slot = slot
elif chosen_slot is not None:
break
if chosen_slot is None:
has_more = True
break
perm = self._permuters[key]
perm.slot = chosen_slot
chosen_slot.permuter = key
self._update_permuter(perm, chosen_slot)
self._update_item(
chosen_slot.menu_id, f"{perm.fn_name} ({perm.nickname})", enabled=True
)
self._update_item(self._more_id, "More...", hidden=not has_more)
def _hide_slot(self, slot: ClientSlot) -> None:
if slot.permuter is not None:
self._update_item(slot.menu_id, "", hidden=True)
slot.permuter = None
def server_reconnecting(self) -> None:
self._status = SystrayStatus.CONNECTING
self._update_status()
def server_connected(self) -> None:
self._status = SystrayStatus.CONNECTED
self._update_status()
def server_failed(self, graceful: bool, message: Optional[str] = None) -> None:
self._status = SystrayStatus.RECONNECTING if graceful else SystrayStatus.FAILED
self._fail_message = message
self._permuters = {}
self._update_status()
for slot in self._client_slots:
self._hide_slot(slot)
self._fill_slots()
def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None:
perm = Permuter(nickname, fn_name)
self._permuters[handle] = perm
self._fill_slots()
self._update_status()
def disconnect(self, handle: PermuterHandle) -> None:
slot = self._permuters[handle].slot
del self._permuters[handle]
self._update_status()
if slot:
self._hide_slot(slot)
self._fill_slots()
def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None:
perm = self._permuters[handle]
perm.iterations += 1
if is_improvement:
perm.improvements += 1
if perm.slot and time.time() > perm.last_systray_update + self._UPDATE_INTERVAL:
perm.last_systray_update = time.time()
self._update_permuter(perm, perm.slot)
def stop(self) -> None:
try:
self._send({"type": "exit"})
except BrokenPipeError:
# The systray process may have been killed by Ctrl+C.
pass
self._proc.wait()
self._read_thread.join()
class Reconnector:
_RESET_BACKOFF_AFTER_UPTIME: float = 60.0
_RANDOM_ADDEND_MAX: float = 60.0
_BACKOFF_MULTIPLIER: float = 2.0
_INITIAL_DELAY: float = 5.0
_io_queue: "queue.Queue[IoActivity]"
_reconnect_token: CancelToken
_reconnect_delay: float
_reconnect_timer: Optional[threading.Timer]
_start_time: float
_stop_time: float
def __init__(self, io_queue: "queue.Queue[IoActivity]") -> None:
self._io_queue = io_queue
self._reconnect_token = CancelToken()
self._reconnect_delay = self._INITIAL_DELAY
self._reconnect_timer = None
self._start_time = self._stop_time = time.time()
def mark_start(self) -> None:
self._start_time = time.time()
def mark_stop(self) -> None:
self._stop_time = time.time()
def stop(self) -> None:
self._reconnect_token.cancelled = True
if self._reconnect_timer is not None:
self._reconnect_timer.cancel()
self._reconnect_timer.join()
self._reconnect_timer = None
def reconnect_eventually(self) -> int:
if self._stop_time - self._start_time > self._RESET_BACKOFF_AFTER_UPTIME:
delay = self._reconnect_delay = self._INITIAL_DELAY
else:
delay = self._reconnect_delay
self._reconnect_delay = (
self._reconnect_delay * self._BACKOFF_MULTIPLIER
+ random.uniform(1.0, self._RANDOM_ADDEND_MAX)
)
token = CancelToken()
self._reconnect_token = token
self._reconnect_timer = threading.Timer(
delay, lambda: self._io_queue.put((token, IoReconnect()))
)
self._reconnect_timer.daemon = True
self._reconnect_timer.start()
return int(delay)
def main_loop(
io_queue: "queue.Queue[IoActivity]",
server: Server,
systray: SystrayState,
) -> None:
reconnector = Reconnector(io_queue)
handle_clients: Dict[PermuterHandle, Client] = {}
while True:
token, activity = io_queue.get()
if token and token.cancelled:
continue
if not isinstance(activity, tuple):
if isinstance(activity, IoShutdown):
break
elif isinstance(activity, IoReconnect):
print("reconnecting...")
try:
systray.server_reconnecting()
reconnector.mark_start()
server.start()
systray.server_connected()
except EOFError:
delay = reconnector.reconnect_eventually()
print(f"failed again, reconnecting in {delay} seconds...")
systray.server_failed(True)
except ServerError as e:
print("failed!", e.message)
systray.server_failed(False, e.message)
except Exception:
print("failed!")
traceback.print_exc()
systray.server_failed(False)
elif isinstance(activity, IoServerFailed):
if activity.message:
print("Server error:", activity.message)
print("disconnected from permuter@home")
server.stop()
reconnector.mark_stop()
systray.server_failed(activity.graceful, activity.message)
if activity.graceful:
delay = reconnector.reconnect_eventually()
print(f"will reconnect in {delay} seconds...")
else:
static_assert_unreachable(activity)
else:
handle, msg = activity
if isinstance(msg, IoConnect):
client = msg.client
handle_clients[handle] = client
systray.connect(handle, client.nickname, msg.fn_name)
print(f"[{client.nickname}] connected ({msg.fn_name})")
elif isinstance(msg, IoDisconnect):
systray.disconnect(handle)
nickname = handle_clients[handle].nickname
del handle_clients[handle]
print(f"[{nickname}] {msg.reason}")
elif isinstance(msg, IoImmediateDisconnect):
print(f"[{msg.client.nickname}] {msg.reason}")
elif isinstance(msg, IoWorkDone):
# TODO: statistics
systray.work_done(handle, msg.is_improvement)
elif isinstance(msg, IoUserRemovePermuter):
server.remove_permuter(handle)
else:
static_assert_unreachable(msg)
def server_main(options: ServerOptions, use_systray: bool) -> None:
io_queue: "queue.Queue[IoActivity]" = queue.Queue()
config = read_config()
systray: SystrayState
if use_systray:
systray = RealSystrayState(config, io_queue)
else:
systray = SystrayState()
try:
server = Server(options, config, io_queue)
server.start()
try:
systray.server_connected()
main_loop(io_queue, server, systray)
finally:
server.stop()
finally:
systray.stop()
@@ -1,86 +0,0 @@
from argparse import ArgumentParser, Namespace
import base64
import os
import random
import string
import sys
from typing import Optional
from nacl.public import SealedBox
from nacl.signing import SigningKey, VerifyKey
from .base import Command
from ..core import connect, read_config, sign_with_magic, write_config
from .util import ask
class SetupCommand(Command):
command = "setup"
help = """Set up permuter@home. This will require someone else to grant you
access to the central server."""
@staticmethod
def add_arguments(parser: ArgumentParser) -> None:
pass
@staticmethod
def run(args: Namespace) -> None:
_run_initial_setup()
def _random_name() -> str:
return "".join(random.choice(string.ascii_lowercase) for _ in range(5))
def _run_initial_setup() -> None:
config = read_config()
signing_key: Optional[SigningKey] = config.signing_key
if not signing_key or not ask("Keep previous secret key", default=True):
signing_key = SigningKey.generate()
config.signing_key = signing_key
write_config(config)
verify_key = signing_key.verify_key
nickname: Optional[str] = config.initial_setup_nickname
if not nickname or not ask(f"Keep previous nickname [{nickname}]", default=True):
default_nickname = os.environ.get("USER") or _random_name()
nickname = (
input(f"Nickname [default: {default_nickname}]: ") or default_nickname
)
config.initial_setup_nickname = nickname
write_config(config)
signed_nickname = sign_with_magic(b"NAME", signing_key, nickname.encode("utf-8"))
vouch_data = verify_key.encode() + signed_nickname
vouch_text = base64.b64encode(vouch_data).decode("utf-8")
print("Ask someone to run the following command:")
print(f"./pah.py vouch {vouch_text}")
print()
print("They should give you a token back in return. Paste that here:")
inp = input().strip()
try:
token = base64.b64decode(inp.encode("utf-8"))
data = SealedBox(signing_key.to_curve25519_private_key()).decrypt(token)
config.server_address = data[32:].decode("utf-8")
config.server_verify_key = VerifyKey(data[:32])
config.initial_setup_nickname = None
except Exception:
print("Invalid token!")
sys.exit(1)
print(f"Server: {config.server_address}")
print("Testing connection...")
port = connect(config)
port.send_json({"method": "ping"})
port.receive_json()
try:
write_config(config)
except Exception as e:
print("Failed to write config:", e)
sys.exit(1)
print("permuter@home successfully set up!")
@@ -1,2 +0,0 @@
permuter-systray
permuter-systray.exe
@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2017 Zack Young
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,13 +0,0 @@
# systray
This directory contains a Go application that shows a system tray, which the Python code interacts with.
It is a fork of https://github.com/felixhao28/systray-portable.
To build it:
- install Go
- if on Linux, install dependencies: `libgtk-3-dev`, `libappindicator3-dev`
- run `go build`
If on Windows, this needs to be done *outside* of WSL.
@@ -1,7 +0,0 @@
module permuter-systray
go 1.15
require github.com/getlantern/systray v1.1.0
replace github.com/getlantern/systray v1.1.0 => github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56
@@ -1,4 +0,0 @@
github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56 h1:UZcM1HdV25CQhhJD340jxRLRGl0V11V0wIoUDKTOZMI=
github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56/go.mod h1:N5dpnnWiJhCxh+gXuNgDS2p5MjgcVR/TGwWuaDc4gLk=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 h1:YTzHMGlqJu67/uEo1lBv0n3wBXhXNeUbB1XfN2vmTm0=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -1,287 +0,0 @@
package main
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"os/signal"
"reflect"
"strings"
"syscall"
"github.com/getlantern/systray"
)
func main() {
systray.Run(onReady, onExit)
}
func onExit() {
os.Exit(0)
}
// Item represents an item in the menu
type Item struct {
Icon string `json:"icon"`
Title string `json:"title"`
Tooltip string `json:"tooltip"`
Enabled bool `json:"enabled"`
Checked bool `json:"checked"`
Hidden bool `json:"hidden"`
Items []Item `json:"items"`
InternalID int `json:"__id"`
}
// Menu has an icon, title and list of items
type Menu struct {
Icon string `json:"icon"`
Title string `json:"title"`
Tooltip string `json:"tooltip"`
Items []Item `json:"items"`
}
// Action for an item?..
type Action struct {
Type string `json:"type"`
Item Item `json:"item"`
Menu Menu `json:"menu"`
}
// ClickEvent for an click event
type ClickEvent struct {
Type string `json:"type"`
InternalID int `json:"__id"`
}
func readJSON(reader *bufio.Reader, v interface{}) error {
input, err := reader.ReadString('\n')
if err != nil {
return err
}
if len(input) < 1 {
return fmt.Errorf("Empty line")
}
lineReader := strings.NewReader(input[0 : len(input)-1])
if err := json.NewDecoder(lineReader).Decode(v); err != nil {
return err
}
return nil
}
func addMenuItem(items *[]*systray.MenuItem, seqID2InternalID *[]int, internalID2SeqID *map[int]int, item *Item, parent *systray.MenuItem) {
if item.Title == "<SEPARATOR>" {
systray.AddSeparator()
*items = append(*items, nil)
} else {
var menuItem *systray.MenuItem
if parent == nil {
menuItem = systray.AddMenuItem(item.Title, item.Tooltip)
} else {
menuItem = parent.AddSubMenuItem(item.Title, item.Tooltip)
}
if item.Checked {
menuItem.Check()
} else {
menuItem.Uncheck()
}
if item.Enabled {
menuItem.Enable()
} else {
menuItem.Disable()
}
if len(item.Icon) > 0 {
icon, err := base64.StdEncoding.DecodeString(item.Icon)
if err != nil {
fmt.Fprintln(os.Stderr, err)
} else {
menuItem.SetIcon(icon)
}
}
for i := 0; i < len(item.Items); i++ {
subitem := item.Items[i]
addMenuItem(items, seqID2InternalID, internalID2SeqID, &subitem, menuItem)
}
if item.Hidden {
menuItem.Hide()
}
*items = append(*items, menuItem)
}
seqID := len(*items) - 1
(*internalID2SeqID)[item.InternalID] = seqID
*seqID2InternalID = append(*seqID2InternalID, item.InternalID)
}
func onReady() {
signalChannel := make(chan os.Signal, 2)
signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM)
go func() {
for sig := range signalChannel {
switch sig {
case os.Interrupt, syscall.SIGTERM:
// handle SIGINT, SIGTERM
fmt.Fprintln(os.Stderr, "Quit")
systray.Quit()
default:
fmt.Fprintln(os.Stderr, "Unhandled signal:", sig)
}
}
}()
items := make([]*systray.MenuItem, 0)
seqID2InternalID := make([]int, 0)
internalID2SeqID := make(map[int]int)
fmt.Println(`{"type": "ready"}`)
reader := bufio.NewReader(os.Stdin)
var menu Menu
if err := readJSON(reader, &menu); err != nil {
fmt.Fprintln(os.Stderr, err)
systray.Quit()
return
}
icon, err := base64.StdEncoding.DecodeString(menu.Icon)
if err != nil {
fmt.Fprintln(os.Stderr, err)
systray.Quit()
return
}
systray.SetIcon(icon)
systray.SetTitle(menu.Title)
systray.SetTooltip(menu.Tooltip)
updateItem := func(action Action) {
item := action.Item
seqID := internalID2SeqID[action.Item.InternalID]
menuItem := items[seqID]
if menuItem == nil {
return
}
if item.Hidden {
menuItem.Hide()
} else {
if item.Checked {
menuItem.Check()
} else {
menuItem.Uncheck()
}
if item.Enabled {
menuItem.Enable()
} else {
menuItem.Disable()
}
menuItem.SetTitle(item.Title)
menuItem.SetTooltip(item.Tooltip)
if len(item.Icon) > 0 {
icon, err := base64.StdEncoding.DecodeString(item.Icon)
if err != nil {
fmt.Fprintln(os.Stderr, err)
} else {
menuItem.SetIcon(icon)
}
}
menuItem.Show()
for _, child := range item.Items {
seqID = internalID2SeqID[child.InternalID]
items[seqID].Show()
}
}
}
updateMenu := func(action Action) {
m := action.Menu
if menu.Title != m.Title {
menu.Title = m.Title
systray.SetTitle(menu.Title)
}
if menu.Icon != m.Icon && m.Icon != "" {
menu.Icon = m.Icon
icon, err := base64.StdEncoding.DecodeString(menu.Icon)
if err != nil {
fmt.Fprintln(os.Stderr, err)
} else {
systray.SetIcon(icon)
}
}
if menu.Tooltip != m.Tooltip {
menu.Tooltip = m.Tooltip
systray.SetTooltip(menu.Tooltip)
}
}
update := func(action Action) {
switch action.Type {
case "update-item":
updateItem(action)
case "update-menu":
updateMenu(action)
case "update-item-and-menu":
updateItem(action)
updateMenu(action)
case "exit":
systray.Quit()
}
}
for i := 0; i < len(menu.Items); i++ {
item := menu.Items[i]
addMenuItem(&items, &seqID2InternalID, &internalID2SeqID, &item, nil)
}
go func(reader *bufio.Reader) {
for {
var action Action
if err := readJSON(reader, &action); err != nil {
fmt.Fprintln(os.Stderr, err)
systray.Quit()
break
}
update(action)
}
}(reader)
stdoutEnc := json.NewEncoder(os.Stdout)
for {
itemsCnt := 0
for _, ch := range items {
if ch != nil {
itemsCnt++
}
}
cases := make([]reflect.SelectCase, itemsCnt)
caseCnt2SeqID := make([]int, len(items))
itemsCnt = 0
for i, ch := range items {
if ch == nil {
continue
}
cases[itemsCnt] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.ClickedCh)}
caseCnt2SeqID[itemsCnt] = i
itemsCnt++
}
remaining := len(cases)
for remaining > 0 {
chosen, _, ok := reflect.Select(cases)
if !ok {
// The chosen channel has been closed, so zero out the channel to disable the case
cases[chosen].Chan = reflect.ValueOf(nil)
remaining--
continue
}
seqID := caseCnt2SeqID[chosen]
err := stdoutEnc.Encode(ClickEvent{
Type: "clicked",
InternalID: seqID2InternalID[seqID],
})
if err != nil {
fmt.Fprintln(os.Stderr, err)
}
}
}
}
-15
View File
@@ -1,15 +0,0 @@
import sys
def ask(msg: str, *, default: bool) -> bool:
if default:
msg += " (Y/n)? "
else:
msg += " (y/N)? "
res = input(msg).strip().lower()
if not res:
return default
if res in ["y", "yes", "n", "no"]:
return res[0] == "y"
print("Bad response!")
sys.exit(1)
@@ -1,73 +0,0 @@
from argparse import ArgumentParser, Namespace
import base64
import sys
from nacl.encoding import HexEncoder
from nacl.public import SealedBox
from nacl.signing import VerifyKey
from ..core import connect, read_config, verify_with_magic
from .base import Command
from .util import ask
class VouchCommand(Command):
command = "vouch"
help = "Give someone access to the central server."
@staticmethod
def add_arguments(parser: ArgumentParser) -> None:
parser.add_argument(
"magic",
help="Opaque hex string generated by 'setup'.",
)
@staticmethod
def run(args: Namespace) -> None:
run_vouch(args.magic)
def run_vouch(magic: str) -> None:
try:
vouch_data = base64.b64decode(magic.encode("utf-8"))
verify_key = VerifyKey(vouch_data[:32])
signed_nickname = vouch_data[32:]
msg = verify_with_magic(b"NAME", verify_key, signed_nickname)
nickname = msg.decode("utf-8")
except Exception:
print("Could not parse data!")
sys.exit(1)
try:
config = read_config()
port = connect(config)
port.send_json(
{
"method": "vouch",
"who": verify_key.encode(HexEncoder).decode("utf-8"),
"signed_name": HexEncoder.encode(signed_nickname).decode("utf-8"),
}
)
port.receive_json()
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if not ask(f"Grant permuter server access to {nickname}", default=True):
return
try:
port.send_json({})
port.receive_json()
except Exception as e:
print(f"Failed to grant access: {e}")
sys.exit(1)
assert config.server_address, "checked by connect"
assert config.server_verify_key, "checked by connect"
data = config.server_verify_key.encode() + config.server_address.encode("utf-8")
token = SealedBox(verify_key.to_curve25519_public_key()).encrypt(data)
print("Granted!")
print()
print("Send them the following token:")
print(base64.b64encode(token).decode("utf-8"))
@@ -1,3 +0,0 @@
target/
config.toml
*.json
-607
View File
@@ -1,607 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "argh"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91792f088f87cdc7a2cfb1d617fa5ea18d7f1dc22ef0e1b5f82f3157cdc522be"
dependencies = [
"argh_derive",
"argh_shared",
]
[[package]]
name = "argh_derive"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4eb0c0c120ad477412dc95a4ce31e38f2113e46bd13511253f79196ca68b067"
dependencies = [
"argh_shared",
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "argh_shared"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "781f336cc9826dbaddb9754cb5db61e64cab4f69668bd19dcc4a0394a86f4cb1"
[[package]]
name = "autocfg"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "bitflags"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "bytes"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040"
[[package]]
name = "cc"
version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73"
dependencies = [
"libc",
"num-integer",
"num-traits",
"time",
"winapi",
]
[[package]]
name = "getrandom"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9495705279e7140bf035dde1f6e750c162df8b625267cd52cc44e0b156732c8"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "heck"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cbf45460356b7deeb5e3415b5563308c0a9b057c85e12b06ad551f98d0a6ac"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "hermit-abi"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c"
dependencies = [
"libc",
]
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "instant"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec"
dependencies = [
"cfg-if",
]
[[package]]
name = "itoa"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
[[package]]
name = "libc"
version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41"
[[package]]
name = "libsodium-sys"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a685b64f837b339074115f2e7f7b431ac73681d08d75b389db7498b8892b8a58"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "lock_api"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3c91c24eae6777794bb1997ad98bbb87daf92890acab859f7eaa4320333176"
dependencies = [
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
dependencies = [
"cfg-if",
]
[[package]]
name = "memchr"
version = "2.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
[[package]]
name = "mio"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf80d3e903b34e0bd7282b218398aec54e082c840d9baf8339e0080a0c542956"
dependencies = [
"libc",
"log",
"miow",
"ntapi",
"winapi",
]
[[package]]
name = "miow"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21"
dependencies = [
"winapi",
]
[[package]]
name = "ntapi"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44"
dependencies = [
"winapi",
]
[[package]]
name = "num-integer"
version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "once_cell"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3"
[[package]]
name = "pahserver"
version = "0.0.1"
dependencies = [
"argh",
"chrono",
"hex",
"pin-project",
"serde",
"serde_json",
"serde_tuple",
"slotmap",
"sodiumoxide",
"tempfile",
"tokio",
"toml",
]
[[package]]
name = "parking_lot"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb"
dependencies = [
"instant",
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall",
"smallvec",
"winapi",
]
[[package]]
name = "pin-project"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905"
[[package]]
name = "pkg-config"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c"
[[package]]
name = "ppv-lite86"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro2"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a152013215dca273577e18d2bf00fa862b89b24169fb78c4c95aeb07992c9cec"
dependencies = [
"unicode-xid",
]
[[package]]
name = "quote"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_hc",
]
[[package]]
name = "rand_chacha"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7"
dependencies = [
"getrandom",
]
[[package]]
name = "rand_hc"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73"
dependencies = [
"rand_core",
]
[[package]]
name = "redox_syscall"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8270314b5ccceb518e7e578952f0b72b88222d02e8f77f5ecf7abbb673539041"
dependencies = [
"bitflags",
]
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "ryu"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "558dc50e1a5a5fa7112ca2ce4effcb321b0300c0d4ccf0776a9f60cd89031171"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_tuple"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4f025b91216f15a2a32aa39669329a475733590a015835d1783549a56d09427"
dependencies = [
"serde",
"serde_tuple_macros",
]
[[package]]
name = "serde_tuple_macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4076151d1a2b688e25aaf236997933c66e18b870d0369f8b248b8ab2be630d7e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "signal-hook-registry"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6"
dependencies = [
"libc",
]
[[package]]
name = "slotmap"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "585cd5dffe4e9e06f6dfdf66708b70aca3f781bed561f4f667b2d9c0d4559e36"
dependencies = [
"version_check",
]
[[package]]
name = "smallvec"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
[[package]]
name = "sodiumoxide"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7038b67c941e23501573cb7242ffb08709abe9b11eb74bceff875bbda024a6a8"
dependencies = [
"libc",
"libsodium-sys",
"serde",
]
[[package]]
name = "syn"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48fe99c6bd8b1cc636890bcc071842de909d902c81ac7dab53ba33c421ab8ffb"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "tempfile"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
dependencies = [
"cfg-if",
"libc",
"rand",
"redox_syscall",
"remove_dir_all",
"winapi",
]
[[package]]
name = "time"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "tokio"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83f0c8e7c0addab50b663055baf787d0af7f413a46e6e7fb9559a4e4db7137a5"
dependencies = [
"autocfg",
"bytes",
"libc",
"memchr",
"mio",
"num_cpus",
"once_cell",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"tokio-macros",
"winapi",
]
[[package]]
name = "tokio-macros"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "toml"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa"
dependencies = [
"serde",
]
[[package]]
name = "unicode-segmentation"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796"
[[package]]
name = "unicode-xid"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
[[package]]
name = "version_check"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
[[package]]
name = "wasi"
version = "0.10.2+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
@@ -1,19 +0,0 @@
[package]
name = "pahserver"
version = "0.0.1"
edition = "2018"
resolver = "2"
[dependencies]
tokio = { version = "1", features = ["full"] }
sodiumoxide = "0.2"
toml = "0.5"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
serde_tuple = "0.5"
hex = "0.4"
argh = "0.1"
tempfile = "3"
slotmap = "1"
pin-project = "1"
chrono = "*"
@@ -1,23 +0,0 @@
# controller
This directory contains code for the central permuter@home controller server,
written in Rust. All p@h traffic passes through here.
If you just want to run a regular p@h server, you don't need to care about this.
To setup your own copy of the controller server:
- Install Rust and (for the libsodium dependency) GCC.
- Run `cargo build --release`.
- Run `./target/release/pahserver setup --db path/to/database.json` and follow
the instructions there. This will set the `priv_seed` part of `config.toml`, and
set up an initial trusted client. The rest of `config.toml` can be copied from
`config_example.toml`.
- Set up a reverse proxy that forwards HTTPS traffic from an external port or route
to HTTP for a port of your choice, e.g. using Nginx or Traefik.
If applicable, configure your firewall to let the external port through.
- Start the server with:
```
./target/release/pahserver run --listen-on 0.0.0.0:<port> --config config.toml --db path/to/database.json
```
and configure the system to run this at startup.
@@ -1,2 +0,0 @@
docker_image = ""
priv_seed = "0000000000000000000000000000000000000000000000000000000000000000"
@@ -1,205 +0,0 @@
use std::collections::VecDeque;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::mpsc;
use crate::db::UserId;
use crate::flimsy_semaphore::FlimsySemaphore;
use crate::port::{ReadPort, WritePort};
use crate::stats;
use crate::util::SimpleResult;
use crate::{
current_load, Permuter, PermuterData, PermuterId, PermuterResult, PermuterWork, ServerUpdate,
State,
};
const MIN_PERMUTER_VERSION: u32 = 1;
const CLIENT_MAX_QUEUES_SIZE: usize = 100;
const MIN_PRIORITY: f64 = 0.001;
const MAX_PRIORITY: f64 = 10.0;
#[derive(Debug, Deserialize)]
pub(crate) struct ConnectClientData {
priority: f64,
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage {
Work { work: PermuterWork },
}
#[derive(Serialize)]
struct PermuterResultMessage<'a> {
server: String,
#[serde(flatten)]
update: &'a ServerUpdate,
}
async fn client_read(
port: &mut ReadPort<'_>,
perm_id: &PermuterId,
semaphore: &FlimsySemaphore,
state: &State,
) -> SimpleResult<()> {
loop {
let msg = port.recv().await?;
let msg: ClientMessage = serde_json::from_slice(&msg)?;
let ClientMessage::Work { work } = msg;
// Avoid the work and result queues growing indefinitely by restricting
// their combined size with a semaphore.
semaphore.acquire().await;
let mut m = state.m.lock().unwrap();
let perm = m.permuters.get_mut(perm_id).unwrap();
if perm.work_queue.is_empty() {
state.new_work_notification.notify_waiters();
}
perm.work_queue.push_back(work);
}
}
async fn client_write(
port: &mut WritePort<'_>,
fn_name: &str,
semaphore: &FlimsySemaphore,
state: &State,
mut result_rx: mpsc::UnboundedReceiver<PermuterResult>,
client_id: &UserId,
) -> SimpleResult<()> {
loop {
let res = result_rx.recv().await.unwrap();
semaphore.release();
match res {
PermuterResult::NeedWork => {
port.send_json(&json!({
"type": "need_work",
}))
.await?;
}
PermuterResult::Result(server_id, server_name, server_update) => {
port.send_json(&PermuterResultMessage {
server: server_name,
update: &server_update,
})
.await?;
if let ServerUpdate::Result {
compressed_source,
ref more_props,
..
} = server_update
{
if let Some(ref data) = compressed_source {
port.send(data).await?;
}
let score = more_props.get("score").and_then(|score| score.as_i64());
let outcome = if compressed_source.is_none() {
stats::Outcome::Unhelpful
} else if matches!(score, Some(0)) {
stats::Outcome::Matched
} else {
stats::Outcome::Improved
};
state
.log_stats(stats::Record::WorkDone {
server: server_id,
client: client_id.clone(),
fn_name: fn_name.to_string(),
outcome,
})
.await?;
}
}
}
}
}
pub(crate) async fn handle_connect_client<'a>(
mut read_port: ReadPort<'a>,
mut write_port: WritePort<'a>,
who_id: UserId,
who_name: &str,
permuter_version: u32,
state: &State,
data: ConnectClientData,
) -> SimpleResult<()> {
if permuter_version < MIN_PERMUTER_VERSION {
Err("Permuter version too old!")?;
}
if !(MIN_PRIORITY <= data.priority && data.priority <= MAX_PRIORITY) {
Err("Priority out of range")?;
}
let load = current_load(state, Some(data.priority));
write_port.send_json(&load).await?;
let permuter_data = read_port.recv().await?;
let mut permuter_data: PermuterData = serde_json::from_slice(&permuter_data)?;
permuter_data.compressed_source = read_port.recv().await?;
permuter_data.compressed_target_o_bin = read_port.recv().await?;
write_port.send_json(&json!({})).await?;
eprintln!(
"[{}] start client ({}, {})",
&who_name, &permuter_data.fn_name, data.priority
);
state
.log_stats(stats::Record::ClientNewFunction {
client: who_id.clone(),
fn_name: permuter_data.fn_name.clone(),
})
.await?;
let energy_add = 1.0 / data.priority;
let fn_name = permuter_data.fn_name.clone();
let (result_tx, result_rx) = mpsc::unbounded_channel();
let semaphore = Arc::new(FlimsySemaphore::new(CLIENT_MAX_QUEUES_SIZE));
let perm_id = {
let mut m = state.m.lock().unwrap();
let id = m.next_permuter_id;
m.next_permuter_id += 1;
m.permuters.insert(
id,
Permuter {
data: permuter_data.into(),
client_id: who_id.clone(),
client_name: who_name.to_string(),
work_queue: VecDeque::new(),
result_tx: result_tx.clone(),
semaphore: semaphore.clone(),
priority: data.priority,
energy_add,
},
);
state.new_work_notification.notify_waiters();
id
};
let r = tokio::try_join!(
client_read(&mut read_port, &perm_id, &semaphore, state),
client_write(
&mut write_port,
&fn_name,
&semaphore,
state,
result_rx,
&who_id
)
);
state.m.lock().unwrap().permuters.remove(&perm_id);
state.new_work_notification.notify_waiters();
r?;
Ok(())
}
@@ -1,105 +0,0 @@
use std::collections::HashMap;
use std::convert::TryInto;
use hex::FromHex;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use sodiumoxide::crypto::sign;
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ByteString<const SIZE: usize>([u8; SIZE]);
impl<const SIZE: usize> ByteString<SIZE> {
fn to_hex(&self) -> String {
hex::encode(&self.0)
}
fn from_hex(string: &str) -> Result<ByteString<SIZE>, &'static str> {
Ok(ByteString(
Vec::from_hex(&string)
.map_err(|_| "not a valid hex string")?
.try_into()
.map_err(|_| "byte string has wrong size")?,
))
}
}
impl<const SIZE: usize> Serialize for ByteString<SIZE> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_hex())
}
}
impl<'de, const SIZE: usize> Deserialize<'de> for ByteString<SIZE> {
fn deserialize<D>(deserializer: D) -> Result<ByteString<SIZE>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let string = String::deserialize(deserializer)?;
ByteString::from_hex(&string).map_err(Error::custom)
}
}
pub type UserId = ByteString<32>;
impl UserId {
pub fn from_pubkey(key: &sign::PublicKey) -> UserId {
ByteString(key.as_ref().try_into().unwrap())
}
pub fn to_pubkey(&self) -> sign::PublicKey {
sign::PublicKey::from_slice(&self.0).unwrap()
}
}
impl<const SIZE: usize> ByteString<SIZE> {
pub fn to_seed(&self) -> sign::Seed {
sign::Seed::from_slice(&self.0).unwrap()
}
}
#[derive(Debug, Deserialize_tuple, Serialize_tuple)]
pub struct Stats {
pub iterations: u64,
pub improvements: u64,
pub matches: u64,
pub functions: u64,
}
impl Default for Stats {
fn default() -> Stats {
Stats {
iterations: 0,
improvements: 0,
matches: 0,
functions: 0,
}
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct User {
pub trusted_by: Option<UserId>,
pub name: String,
pub client_stats: Stats,
pub server_stats: Stats,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct DB {
pub users: HashMap<UserId, User>,
pub func_stats: HashMap<String, Stats>,
pub total_stats: Stats,
}
impl DB {
pub fn func_stat(&mut self, fn_name: String) -> &mut Stats {
self.func_stats
.entry(fn_name)
.or_insert_with(Stats::default)
}
}
@@ -1,61 +0,0 @@
use std::convert::TryInto;
use std::sync::atomic::{AtomicIsize, Ordering};
use tokio::sync::Notify;
/// An unfair semaphore that allows overdrafts.
pub struct FlimsySemaphore {
notify: Notify,
slots: AtomicIsize,
}
impl FlimsySemaphore {
// Invariant: if `slots` has ever become non-positive, then if positive
// there will be a notify token in circulation. Taking the token
// synchronizes with a positive `slots`.
pub fn new(limit: usize) -> FlimsySemaphore {
FlimsySemaphore {
notify: Notify::new(),
slots: AtomicIsize::new(limit.try_into().unwrap()),
}
}
pub fn acquire_ignore_limit(&self) {
self.slots.fetch_add(-1, Ordering::Acquire);
}
pub async fn acquire(&self) {
let mut was_woken = false;
let mut val = self.slots.load(Ordering::Relaxed);
loop {
if val > 0 {
match self.slots.compare_exchange(
val,
val - 1,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => {
if was_woken && val > 1 {
self.notify.notify_one();
}
return;
}
Err(actually) => {
val = actually;
}
}
} else {
self.notify.notified().await;
was_woken = true;
val = self.slots.load(Ordering::Relaxed);
}
}
}
pub fn release(&self) {
if self.slots.fetch_add(1, Ordering::Release) == 0 {
self.notify.notify_one();
}
}
}
@@ -1,418 +0,0 @@
#![allow(clippy::try_err)]
use std::collections::{HashMap, VecDeque};
use std::convert::TryInto;
use std::default::Default;
use std::io::ErrorKind;
use std::sync::{Arc, Mutex};
use argh::FromArgs;
use serde::{Deserialize, Serialize};
use serde_json::json;
use slotmap::{new_key_type, SlotMap};
use sodiumoxide::crypto::box_;
use sodiumoxide::crypto::sign;
use tokio::fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, watch, Notify};
use tokio::time;
use crate::db::{ByteString, UserId};
use crate::flimsy_semaphore::FlimsySemaphore;
use crate::port::{ReadPort, WritePort};
use crate::save::SaveableDB;
use crate::util::SimpleResult;
mod client;
mod db;
mod flimsy_semaphore;
mod port;
mod save;
mod server;
mod setup;
mod stats;
mod util;
mod vouch;
const HEARTBEAT_TIME: time::Duration = time::Duration::from_secs(300);
#[derive(FromArgs)]
/// The permuter@home control server.
struct CmdOpts {
#[argh(subcommand)]
sub: SubCommand,
}
#[derive(FromArgs)]
#[argh(subcommand)]
enum SubCommand {
RunServer(RunServerOpts),
Setup(SetupOpts),
}
#[derive(FromArgs)]
/// Run the permuter@home control server.
#[argh(subcommand, name = "run")]
struct RunServerOpts {
/// ip:port to listen on (e.g. 0.0.0.0:1234)
#[argh(option)]
listen_on: String,
/// path to TOML configuration file
#[argh(option)]
config: String,
/// path to JSON database
#[argh(option)]
db: String,
/// enable debug logging
#[argh(switch)]
debug: bool,
}
#[derive(FromArgs)]
/// Setup initial database and config for permuter@home.
#[argh(subcommand, name = "setup")]
struct SetupOpts {
/// path to JSON database
#[argh(option)]
db: String,
}
#[derive(Deserialize)]
struct Config {
docker_image: String,
priv_seed: ByteString<32>,
}
#[derive(Debug, Deserialize, Serialize)]
struct PermuterData {
fn_name: String,
#[serde(skip)]
compressed_source: Vec<u8>,
#[serde(skip)]
compressed_target_o_bin: Vec<u8>,
#[serde(flatten)]
more_props: HashMap<String, serde_json::Value>,
}
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
struct PermuterWork {
seed: u64,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerUpdate {
Result {
#[serde(skip_serializing, default)]
overhead_us: i64,
#[serde(skip)]
compressed_source: Option<Vec<u8>>,
#[serde(default)]
has_source: bool,
#[serde(flatten)]
more_props: HashMap<String, serde_json::Value>,
},
InitDone {
hash: String,
},
InitFailed {
reason: String,
},
Disconnect,
}
#[derive(Debug)]
enum PermuterResult {
NeedWork,
Result(UserId, String, ServerUpdate),
}
type PermuterId = u64;
struct Permuter {
data: Arc<PermuterData>,
client_id: UserId,
client_name: String,
work_queue: VecDeque<PermuterWork>,
result_tx: mpsc::UnboundedSender<PermuterResult>,
semaphore: Arc<FlimsySemaphore>,
priority: f64,
energy_add: f64,
}
impl Permuter {
fn send_result(&mut self, res: PermuterResult) {
// We can't use a blocking semaphore acquire here, because we don't
// want server sends to block on random client receives. In practice,
// this is probably fine.
let _ = self.result_tx.send(res);
self.semaphore.acquire_ignore_limit();
}
}
new_key_type! { struct ServerId; }
struct ConnectedServer {
min_priority: f64,
num_cores: f64,
}
struct MutableState {
servers: SlotMap<ServerId, ConnectedServer>,
permuters: HashMap<PermuterId, Permuter>,
next_permuter_id: PermuterId,
}
struct State {
docker_image: String,
debug: bool,
sign_sk: sign::SecretKey,
db: SaveableDB,
stats_tx: mpsc::Sender<stats::Record>,
heartbeat_rx: watch::Receiver<()>,
new_work_notification: Notify,
m: Mutex<MutableState>,
}
impl State {
async fn log_stats(&self, record: stats::Record) -> SimpleResult<()> {
self.stats_tx
.send(record)
.await
.map_err(|_| "stats thread died".into())
}
}
#[derive(Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
enum Request {
Ping,
Vouch(vouch::VouchData),
ConnectServer(server::ConnectServerData),
ConnectClient(client::ConnectClientData),
}
#[derive(Serialize)]
struct Load {
clients: usize,
servers: usize,
cores: f64,
}
#[tokio::main]
async fn main() -> SimpleResult<()> {
sodiumoxide::init().map_err(|()| "Failed to initialize cryptography library")?;
let opts: CmdOpts = argh::from_env();
match opts.sub {
SubCommand::RunServer(opts) => run_server(opts).await?,
SubCommand::Setup(opts) => setup::run_setup(opts)?,
}
Ok(())
}
async fn run_server(opts: RunServerOpts) -> SimpleResult<()> {
let config: Config = toml::from_str(&fs::read_to_string(&opts.config).await?)?;
let (_, sign_sk) = sign::keypair_from_seed(&config.priv_seed.to_seed());
let (save_fut, db) = SaveableDB::open(&opts.db)?;
tokio::spawn(async move {
if let Err(e) = save_fut.await {
eprintln!("Failed to save! {:?}", e);
std::process::exit(1);
}
});
let (stats_fut, stats_tx) = stats::stats_thread(&db);
tokio::spawn(stats_fut);
let (heartbeat_tx, heartbeat_rx) = watch::channel(());
let state: &'static State = Box::leak(Box::new(State {
docker_image: config.docker_image,
debug: opts.debug,
sign_sk,
db,
stats_tx,
heartbeat_rx,
new_work_notification: Notify::new(),
m: Mutex::new(MutableState {
servers: SlotMap::with_key(),
permuters: HashMap::new(),
next_permuter_id: 0,
}),
}));
tokio::spawn(async move {
loop {
heartbeat_tx.send(()).expect("receiver is still alive");
time::sleep(HEARTBEAT_TIME).await;
}
});
let listener = TcpListener::bind(opts.listen_on).await?;
loop {
let (socket, _) = listener.accept().await?;
tokio::spawn(async move {
let mut who = "anonymous".to_string();
if let Err(e) = handle_connection(socket, state, &mut who).await {
if let Some(e) = e.downcast_ref::<std::io::Error>() {
if matches!(
e.kind(),
ErrorKind::UnexpectedEof
| ErrorKind::ConnectionReset
| ErrorKind::TimedOut
| ErrorKind::BrokenPipe
) {
eprintln!("[{}] disconnected", &who);
return;
}
}
eprintln!("[{}] error: {:?}", &who, e);
}
});
}
}
fn concat<T: Clone>(a: &[T], b: &[T]) -> Vec<T> {
a.iter().chain(b).cloned().collect()
}
fn concat3<T: Clone>(a: &[T], b: &[T], c: &[T]) -> Vec<T> {
a.iter().chain(b).chain(c).cloned().collect()
}
async fn handshake<'a>(
mut rd: ReadHalf<'a>,
mut wr: WriteHalf<'a>,
sign_sk: &sign::SecretKey,
) -> SimpleResult<(ReadPort<'a>, WritePort<'a>, UserId, u32)> {
let mut buffer = [0; 4 + 32];
rd.read_exact(&mut buffer).await?;
let (magic, their_pk) = buffer.split_at(4);
if magic != b"p@h0" {
Err("Invalid protocol version")?;
}
let their_pk = box_::PublicKey::from_slice(&their_pk).unwrap();
let (our_pk, our_sk) = box_::gen_keypair();
let signed_data = concat3(b"HELLO:", their_pk.as_ref(), our_pk.as_ref());
let signature = sign::sign_detached(&signed_data, &sign_sk);
wr.write_all(&concat(our_pk.as_ref(), signature.as_ref()))
.await?;
let key = box_::precompute(&their_pk, &our_sk);
let mut read_port = ReadPort::new(rd, &key);
let write_port = WritePort::new(wr, &key);
let reply = read_port.recv().await?;
if reply.len() != 32 + 64 + 4 {
Err("Failed to perform secret handshake")?;
}
let (client_ver_key, rest) = reply.split_at(32);
let (client_signature, permuter_version) = rest.split_at(64);
let client_ver_key = sign::PublicKey::from_slice(client_ver_key).unwrap();
let client_signature = sign::Signature::from_slice(client_signature).unwrap();
let permuter_version = u32::from_be_bytes(permuter_version.try_into().unwrap());
let signed_data = concat(b"WORLD:", our_pk.as_ref());
if !sign::verify_detached(&client_signature, &signed_data, &client_ver_key) {
Err("Spoofed client signature!")?;
}
Ok((
read_port,
write_port,
UserId::from_pubkey(&client_ver_key),
permuter_version,
))
}
fn current_load(state: &State, priority: Option<f64>) -> Load {
let m = state.m.lock().unwrap();
let mut servers: usize = 0;
let mut cores: f64 = 0.0;
for server in m.servers.values() {
if priority.map_or(true, |p| p >= server.min_priority) {
servers += 1;
cores += server.num_cores;
}
}
Load {
clients: m.permuters.len(),
servers,
cores,
}
}
async fn handle_connection(
mut socket: TcpStream,
state: &State,
out_name: &mut String,
) -> SimpleResult<()> {
let (rd, wr) = socket.split();
let (mut read_port, mut write_port, user_id, permuter_version) =
handshake(rd, wr, &state.sign_sk).await?;
let name = match state.db.read(|db| {
let user = db.users.get(&user_id)?;
Some(user.name.clone())
}) {
Some(name) => name,
None => {
write_port.send_error("Access denied!").await?;
Err("Unknown client!")?
}
};
*out_name = name.clone();
eprintln!("[{}] connected (v {})", &name, permuter_version);
if state.debug {
read_port.set_debug(&name);
write_port.set_debug(&name);
}
write_port.send_json(&json!({})).await?;
let request = read_port.recv().await?;
let request: Request = serde_json::from_slice(&request)?;
match request {
Request::Ping => {
eprintln!("[{}] ping", &name);
let load = current_load(state, None);
write_port.send_json(&load).await?;
}
Request::Vouch(data) => {
vouch::handle_vouch(read_port, write_port, user_id, &name, state, data).await?;
}
Request::ConnectServer(data) => {
server::handle_connect_server(
read_port,
write_port,
user_id,
&name,
permuter_version,
state,
data,
)
.await?;
}
Request::ConnectClient(data) => {
client::handle_connect_client(
read_port,
write_port,
user_id,
&name,
permuter_version,
state,
data,
)
.await?;
}
};
Ok(())
}
@@ -1,115 +0,0 @@
use std::convert::TryInto;
use chrono::Local;
use serde::Serialize;
use sodiumoxide::crypto::box_;
use sodiumoxide::crypto::box_::{Nonce, PrecomputedKey};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use crate::util::SimpleResult;
fn debug_print(action: &str, who: &str, msg: &[u8]) {
let time = Local::now().format("%H:%M:%S:%f");
if msg.len() <= 300 {
let msg = String::from_utf8(
msg.iter()
.copied()
.flat_map(std::ascii::escape_default)
.collect(),
)
.unwrap();
println!("{} debug: {} {}: {}", time, action, who, msg);
} else {
println!("{} debug: {} {}: {} bytes", time, action, who, msg.len());
}
}
pub struct ReadPort<'a> {
read_half: ReadHalf<'a>,
key: PrecomputedKey,
nonce: u64,
debug_name: Option<&'a str>,
}
impl<'a> ReadPort<'a> {
pub fn new(read_half: ReadHalf<'a>, key: &PrecomputedKey) -> Self {
ReadPort {
read_half,
key: key.clone(),
nonce: 0,
debug_name: None,
}
}
pub fn set_debug(&mut self, name: &'a str) {
self.debug_name = Some(name);
}
pub async fn recv(&mut self) -> SimpleResult<Vec<u8>> {
let len = self.read_half.read_u64().await?;
if len >= (1 << 48) {
Err("Unreasonable packet length")?
}
let mut buffer = vec![0; len.try_into()?];
self.read_half.read_exact(&mut buffer).await?;
let nonce = nonce_from_u64(self.nonce);
self.nonce += 2;
let data =
box_::open_precomputed(&buffer, &nonce, &self.key).map_err(|()| "Failed to decrypt")?;
if let Some(name) = self.debug_name {
debug_print("Receive from", name, &data);
}
Ok(data)
}
}
pub struct WritePort<'a> {
write_half: WriteHalf<'a>,
key: PrecomputedKey,
nonce: u64,
debug_name: Option<&'a str>,
}
impl<'a> WritePort<'a> {
pub fn new(write_half: WriteHalf<'a>, key: &PrecomputedKey) -> Self {
WritePort {
write_half,
key: key.clone(),
nonce: 1,
debug_name: None,
}
}
pub fn set_debug(&mut self, name: &'a str) {
self.debug_name = Some(name);
}
pub async fn send(&mut self, data: &[u8]) -> SimpleResult<()> {
if let Some(name) = self.debug_name {
debug_print("Send to", name, &data);
}
let nonce = nonce_from_u64(self.nonce);
self.nonce += 2;
let data = box_::seal_precomputed(data, &nonce, &self.key);
self.write_half.write_u64(data.len() as u64).await?;
self.write_half.write_all(&data).await?;
Ok(())
}
pub async fn send_json<T: ?Sized>(&mut self, value: &T) -> SimpleResult<()>
where
T: Serialize,
{
self.send(&serde_json::to_vec(value)?).await
}
pub async fn send_error(&mut self, message: &str) -> SimpleResult<()> {
self.send_json(message).await
}
}
fn nonce_from_u64(num: u64) -> Nonce {
let nonce_bytes = [[0; 8], [0; 8], num.to_be_bytes()].concat();
Nonce::from_slice(&nonce_bytes).unwrap()
}
@@ -1,158 +0,0 @@
use std::future::Future;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tempfile::NamedTempFile;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use crate::db::DB;
use crate::util::{FutureExt, SimpleResult};
const SAVE_INTERVAL: Duration = Duration::from_secs(30);
enum SaveType {
Delayed,
Immediate(oneshot::Sender<()>),
}
struct InnerSaveableDB {
db: DB,
stale: bool,
save_chan: mpsc::UnboundedSender<SaveType>,
}
#[derive(Clone)]
pub struct SaveableDB(Arc<RwLock<InnerSaveableDB>>);
async fn save_db_loop(
db: SaveableDB,
path: &Path,
mut save_channel: mpsc::UnboundedReceiver<SaveType>,
) -> SimpleResult<()> {
loop {
let mut done_chans = Vec::new();
match save_channel.recv().await {
None => return Ok(()),
Some(SaveType::Immediate(chan)) => {
done_chans.push(chan);
}
Some(SaveType::Delayed) => {
// Wait for SAVE_INTERVAL or until we receive an Immediate save.
let _ = timeout(SAVE_INTERVAL, async {
loop {
match save_channel.recv().await {
None => {
break;
}
Some(SaveType::Immediate(chan)) => {
done_chans.push(chan);
break;
}
Some(SaveType::Delayed) => {}
};
}
})
.await;
}
};
// Clear the queue in case more messages have stacked up past an
// Immediate. Receiver::try_recv() is temporarily dead as of tokio 1.4
// (https://github.com/tokio-rs/tokio/issues/3350) due to a bug where
// messages can be delayed, but in this case that doesn't matter.
loop {
match save_channel.recv().now_or_never().await {
None | Some(None) => {
break;
}
Some(Some(SaveType::Immediate(chan))) => {
done_chans.push(chan);
}
Some(Some(SaveType::Delayed)) => {}
};
}
// Mark the DB as non-stale, to start receiving save messages again.
db.0.write().unwrap().stale = false;
// Actually do the save, by first serializing, then atomically saving
// the file by creating and renaming a temp file in the same directory.
let data = db.read(|db| serde_json::to_string(&db).unwrap());
let r: SimpleResult<()> = tokio::task::block_in_place(|| {
let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
let mut tempf = NamedTempFile::new_in(parent_dir)?;
tempf.write_all(data.as_bytes())?;
tempf.as_file().sync_all()?;
tempf.persist(path)?;
Ok(())
});
r?;
for chan in done_chans {
let _ = chan.send(());
}
}
}
impl SaveableDB {
pub fn open(
filename: &str,
) -> SimpleResult<(impl Future<Output = SimpleResult<()>>, SaveableDB)> {
let db_file = std::fs::File::open(filename)?;
let db: DB = serde_json::from_reader(&db_file)?;
let (save_tx, save_rx) = mpsc::unbounded_channel();
let saveable_db = SaveableDB(Arc::new(RwLock::new(InnerSaveableDB {
db,
stale: false,
save_chan: save_tx,
})));
let path = PathBuf::from(filename);
let db2 = saveable_db.clone();
let fut = async move { save_db_loop(db2, &path, save_rx).await };
Ok((fut, saveable_db))
}
pub fn read<T>(&self, callback: impl FnOnce(&DB) -> T) -> T {
let inner = self.0.read().unwrap();
callback(&inner.db)
}
pub async fn write<T>(&self, immediate: bool, callback: impl FnOnce(&mut DB) -> T) -> T {
let ret;
let rx2;
{
let mut inner = self.0.write().unwrap();
ret = callback(&mut inner.db);
if immediate {
inner.stale = true;
let (tx, rx) = oneshot::channel();
rx2 = rx;
inner
.save_chan
.send(SaveType::Immediate(tx))
.map_err(|_| ())
.expect("Failed to send message to save task");
} else {
if !inner.stale {
inner.stale = true;
inner
.save_chan
.send(SaveType::Delayed)
.map_err(|_| ())
.expect("Failed to send message to save task");
}
return ret;
}
}
rx2.await.expect("Failed to save!");
ret
}
}
@@ -1,500 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::{mpsc, mpsc::error::TrySendError, watch, Notify};
use crate::db::UserId;
use crate::port::{ReadPort, WritePort};
use crate::stats;
use crate::util::SimpleResult;
use crate::{
ConnectedServer, MutableState, PermuterData, PermuterId, PermuterResult, PermuterWork,
ServerUpdate, State, HEARTBEAT_TIME,
};
const MIN_PERMUTER_VERSION: u32 = 1;
const SERVER_WORK_QUEUE_SIZE: usize = 100;
const TIME_US_GUESS: f64 = 100_000.0;
const MIN_OVERHEAD_US: f64 = 100_000.0;
const MAX_OVERHEAD_FACTOR: i64 = 2;
#[derive(Debug, Deserialize)]
pub(crate) struct ConnectServerData {
min_priority: f64,
num_cores: f64,
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage {
NeedWork,
Update {
permuter: PermuterId,
time_us: f64,
update: ServerUpdate,
},
}
enum JobState {
Loading,
Loaded,
Failed,
}
struct Job {
state: JobState,
energy: f64,
active_work: i64,
}
struct ServerState {
min_priority: f64,
/// sum of active_work across all jobs
active_work: i64,
/// fractional part of how much work should be requested, in [0, 1)
more_work_acc: f64,
jobs: HashMap<PermuterId, Job>,
}
async fn server_read(
port: &mut ReadPort<'_>,
who_id: &UserId,
who_name: &str,
server_state: &Mutex<ServerState>,
state: &State,
more_work_tx: mpsc::Sender<()>,
new_permuter: &Notify,
) -> SimpleResult<()> {
loop {
let msg = port.recv().await?;
let mut msg: ServerMessage = serde_json::from_slice(&msg)?;
if let ServerMessage::Update {
update:
ServerUpdate::Result {
ref mut compressed_source,
has_source: true,
..
},
..
} = msg
{
*compressed_source = Some(port.recv().await?);
}
let mut has_new = false;
let mut request_work;
{
let mut m = state.m.lock().unwrap();
let mut server_state = server_state.lock().unwrap();
let mut more_work: f64 = 1.0;
if let ServerMessage::Update {
permuter: perm_id,
update,
time_us,
} = msg
{
// If we get back a message referring to a since-removed
// permuter, no need to do anything. Just request one more
// piece of work to make up for it.
if let Some(job) = server_state.jobs.get_mut(&perm_id) {
if let Some(perm) = m.permuters.get_mut(&perm_id) {
job.energy += perm.energy_add * time_us;
match update {
ServerUpdate::InitDone { .. } => {
if !matches!(job.state, JobState::Loading) {
Err("Got InitDone while not in Loading state")?;
}
job.state = JobState::Loaded;
has_new = true;
}
ServerUpdate::InitFailed { .. } => {
if !matches!(job.state, JobState::Loading) {
Err("Got InitFailed while not in Loading state")?;
}
job.state = JobState::Failed;
}
ServerUpdate::Disconnect { .. } => {
if !matches!(job.state, JobState::Loaded) {
Err("Got Disconnect while not in Loaded state")?;
}
job.state = JobState::Failed;
let work = job.active_work;
job.active_work = 0;
server_state.active_work -= work;
more_work = 0.0;
}
ServerUpdate::Result { overhead_us, .. } => {
if !matches!(job.state, JobState::Loaded) {
Err("Got result while not in Loaded state")?;
}
// If the work item spent less than some given
// amount of time in queues, request more work.
// This ensures we saturate all server cores.
// On the other hand, if it spends too much time
// in queues, it's best if we reduce the amount
// of work.
// We don't need to adjust for time spent on the
// network, because we have backpressure on slow
// writes on both ends, and read continuously.
job.active_work -= 1;
server_state.active_work -= 1;
let min_overhead_us = (time_us + MIN_OVERHEAD_US) as i64;
if overhead_us == 0 {
// Legacy server, skip this logic.
} else if overhead_us > MAX_OVERHEAD_FACTOR * min_overhead_us {
more_work = 0.5;
} else if overhead_us < min_overhead_us {
more_work = 1.5;
}
}
}
perm.send_result(PermuterResult::Result(
who_id.clone(),
who_name.to_string(),
update,
));
}
}
}
more_work += server_state.more_work_acc;
request_work = more_work as i32;
server_state.more_work_acc = more_work - request_work as f64;
if request_work == 0
&& server_state.active_work == 0
&& more_work_tx.capacity() == SERVER_WORK_QUEUE_SIZE
{
// Don't request 0 work if it would lead to total starvation.
request_work = 1;
}
}
if has_new {
new_permuter.notify_waiters();
state
.log_stats(stats::Record::ServerNewFunction {
server: who_id.clone(),
})
.await?;
}
for _ in 0..request_work {
// Try requesting more work by sending a message to the writer thread.
// If the queue is full (because the writer thread is blocked on a
// send), drop the request to avoid an unbounded backlog.
if let Err(TrySendError::Closed(_)) = more_work_tx.try_send(()) {
panic!("work chooser must not close except on error");
}
}
}
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToSend {
Work(PermuterWork),
Add {
client_id: UserId,
client_name: String,
data: Arc<PermuterData>,
},
Remove,
}
#[derive(Serialize)]
struct OutMessage {
permuter: PermuterId,
#[serde(flatten)]
to_send: ToSend,
}
fn try_next_work_message(
m: &mut MutableState,
server_state: &mut ServerState,
) -> Option<OutMessage> {
let mut skip = HashSet::new();
loop {
// If possible, send a new permuter.
if let Some((&perm_id, perm)) = m
.permuters
.iter()
.find(|(&perm_id, _)| !server_state.jobs.contains_key(&perm_id))
{
server_state.jobs.insert(
perm_id,
Job {
state: JobState::Loading,
energy: 0.0,
active_work: 0,
},
);
return Some(OutMessage {
permuter: perm_id,
to_send: ToSend::Add {
client_id: perm.client_id.clone(),
client_name: perm.client_name.clone(),
data: perm.data.clone(),
},
});
}
// If none, find an existing one to work on, or to remove.
let mut best_cost = 0.0;
let mut best: Option<(PermuterId, &mut Job)> = None;
let min_priority = server_state.min_priority;
for (&perm_id, job) in server_state.jobs.iter_mut() {
if let Some(perm) = m.permuters.get(&perm_id) {
let energy =
job.energy + (job.active_work as f64) * perm.energy_add * TIME_US_GUESS;
if matches!(job.state, JobState::Loaded)
&& !skip.contains(&perm_id)
&& perm.priority >= min_priority
&& (best.is_none() || energy < best_cost)
{
best_cost = energy;
best = Some((perm_id, job));
}
} else {
server_state.active_work -= job.active_work;
server_state.jobs.remove(&perm_id);
return Some(OutMessage {
permuter: perm_id,
to_send: ToSend::Remove,
});
}
}
let (perm_id, job) = match best {
None => return None,
Some(tup) => tup,
};
let perm = m.permuters.get_mut(&perm_id).unwrap();
let work = match perm.work_queue.pop_front() {
None => {
// Chosen permuter is out of work. Ask it for more, and try
// again without it as a candidate. When the queue becomes
// non-empty again all sleeping writers will be notified.
perm.send_result(PermuterResult::NeedWork);
skip.insert(perm_id);
continue;
}
Some(work) => work,
};
perm.semaphore.release();
let min_energy = job.energy;
job.active_work += 1;
server_state.active_work += 1;
// Adjust energies to be around zero, to avoid problems with float
// imprecision, and to ensure that new permuters that come in with
// energy zero will fit the schedule.
for job in server_state.jobs.values_mut() {
job.energy -= min_energy;
}
return Some(OutMessage {
permuter: perm_id,
to_send: ToSend::Work(work),
});
}
}
async fn next_work_message(
server_state: &Mutex<ServerState>,
state: &State,
new_permuter: &Notify,
) -> OutMessage {
let mut wait_for = None;
loop {
if let Some(waiter) = wait_for {
waiter.await;
}
let mut m = state.m.lock().unwrap();
let mut server_state = server_state.lock().unwrap();
match try_next_work_message(&mut m, &mut server_state) {
Some(message) => return message,
None => {
// Nothing to work on! Register to be notified when something
// happens (while the lock is still held) and go to sleep.
let n1 = state.new_work_notification.notified();
let n2 = new_permuter.notified();
wait_for = Some(async move {
tokio::select! {
() = n1 => {}
() = n2 => {}
}
});
}
}
}
}
fn requires_response(work: &OutMessage) -> bool {
match work.to_send {
ToSend::Work { .. } => true,
ToSend::Add { .. } => true,
ToSend::Remove => false,
}
}
async fn server_choose_work(
server_state: &Mutex<ServerState>,
state: &State,
mut more_work_rx: mpsc::Receiver<()>,
next_message_tx: mpsc::Sender<OutMessage>,
wrote_message: &Notify,
new_permuter: &Notify,
) -> SimpleResult<()> {
loop {
let message = next_work_message(server_state, state, new_permuter).await;
let requires_response = requires_response(&message);
next_message_tx
.send(message)
.await
.map_err(|_| ())
.expect("writer must not close except on error");
wrote_message.notified().await;
if requires_response {
more_work_rx
.recv()
.await
.expect("reader must not close except on error");
}
}
}
async fn send_heartbeat(port: &mut WritePort<'_>) -> SimpleResult<()> {
port.send_json(&json!({
"type": "heartbeat",
}))
.await
}
async fn send_work(port: &mut WritePort<'_>, work: &OutMessage) -> SimpleResult<()> {
port.send_json(&work).await?;
if let ToSend::Add { ref data, .. } = work.to_send {
port.send(&data.compressed_source).await?;
port.send(&data.compressed_target_o_bin).await?;
}
Ok(())
}
async fn server_write(
port: &mut WritePort<'_>,
mut next_message_rx: mpsc::Receiver<OutMessage>,
mut heartbeat_rx: watch::Receiver<()>,
wrote_message: &Notify,
) -> SimpleResult<()> {
loop {
tokio::select! {
work = next_message_rx.recv() => {
let work = work.expect("chooser must not close except on error");
send_work(port, &work).await?;
wrote_message.notify_one();
}
res = heartbeat_rx.changed() => {
res.expect("heartbeat thread panicked");
send_heartbeat(port).await?;
}
}
}
}
pub(crate) async fn handle_connect_server<'a>(
mut read_port: ReadPort<'a>,
mut write_port: WritePort<'a>,
who_id: UserId,
who_name: &str,
permuter_version: u32,
state: &State,
data: ConnectServerData,
) -> SimpleResult<()> {
if permuter_version < MIN_PERMUTER_VERSION {
Err("Permuter version too old!")?;
}
eprintln!(
"[{}] start server ({}, {})",
who_name, data.min_priority, data.num_cores
);
write_port
.send_json(&json!({
"docker_image": &state.docker_image,
"heartbeat_interval": HEARTBEAT_TIME.as_secs(),
}))
.await?;
let (more_work_tx, more_work_rx) = mpsc::channel(SERVER_WORK_QUEUE_SIZE);
let (next_message_tx, next_message_rx) = mpsc::channel(1);
let wrote_message = Notify::new();
let new_permuter = Notify::new();
let mut server_state = Mutex::new(ServerState {
min_priority: data.min_priority,
active_work: 0,
more_work_acc: 0.0,
jobs: HashMap::new(),
});
let id = state.m.lock().unwrap().servers.insert(ConnectedServer {
min_priority: data.min_priority,
num_cores: data.num_cores,
});
let r = tokio::try_join!(
server_read(
&mut read_port,
&who_id,
who_name,
&server_state,
state,
more_work_tx,
&new_permuter,
),
server_choose_work(
&server_state,
state,
more_work_rx,
next_message_tx,
&wrote_message,
&new_permuter,
),
server_write(
&mut write_port,
next_message_rx,
state.heartbeat_rx.clone(),
&wrote_message,
)
);
{
let mut m = state.m.lock().unwrap();
for (&perm_id, job) in &server_state.get_mut().unwrap().jobs {
if let JobState::Loaded = job.state {
if let Some(perm) = m.permuters.get_mut(&perm_id) {
perm.send_result(PermuterResult::Result(
who_id.clone(),
who_name.to_string(),
ServerUpdate::Disconnect,
));
}
}
}
m.servers.remove(id);
}
r?;
Ok(())
}
@@ -1,57 +0,0 @@
use std::collections::HashMap;
use std::default::Default;
use std::fs::OpenOptions;
use sodiumoxide::crypto::sign;
use sodiumoxide::randombytes::randombytes;
use crate::db::{User, UserId, DB};
use crate::util::SimpleResult;
use crate::SetupOpts;
pub(crate) fn run_setup(opts: SetupOpts) -> SimpleResult<()> {
let db_file = OpenOptions::new()
.write(true)
.create_new(true)
.open(&opts.db)
.unwrap_or_else(|e| {
eprintln!("Cannot create database file {}: {}. Aborting.", &opts.db, e);
std::process::exit(1);
});
let server_seed = sign::Seed::from_slice(&randombytes(32)).unwrap();
let client_seed = sign::Seed::from_slice(&randombytes(32)).unwrap();
let (server_pub_key, _) = sign::keypair_from_seed(&server_seed);
let (client_pub_key, _) = sign::keypair_from_seed(&client_seed);
let root_user = User {
trusted_by: None,
name: "root".into(),
client_stats: Default::default(),
server_stats: Default::default(),
};
let mut users_map: HashMap<UserId, User> = HashMap::new();
users_map.insert(UserId::from_pubkey(&client_pub_key), root_user);
let db = DB {
users: users_map,
func_stats: HashMap::new(),
total_stats: Default::default(),
};
serde_json::to_writer(&db_file, &db)?;
println!(
"Setup successful!\n\n\
Put the following in the server's config.toml:\n\n\
priv_seed = \"{}\"\n\n\
Put the following in the root client's pah.conf:\n\n\
secret_key = \"{}\"\n\
server_public_key = \"{}\"\n\
server_address = \"server.example:port\"",
hex::encode(server_seed),
hex::encode(client_seed),
hex::encode(server_pub_key)
);
Ok(())
}
@@ -1,88 +0,0 @@
use std::future::Future;
use tokio::sync::mpsc;
use crate::db::{Stats, UserId};
use crate::save::SaveableDB;
const CHANNEL_CAPACITY: usize = 10000;
#[derive(Clone, Copy)]
pub enum Outcome {
Matched,
Improved,
Unhelpful,
}
pub enum Record {
WorkDone {
server: UserId,
client: UserId,
fn_name: String,
outcome: Outcome,
},
ClientNewFunction {
client: UserId,
fn_name: String,
},
ServerNewFunction {
server: UserId,
},
}
fn add_stats(stats: &mut Stats, outcome: Outcome) {
if matches!(outcome, Outcome::Matched) {
stats.matches += 1;
}
if matches!(outcome, Outcome::Matched | Outcome::Improved) {
stats.improvements += 1;
}
stats.iterations += 1;
}
async fn stats_writer(db: &SaveableDB, mut rx: mpsc::Receiver<Record>) {
loop {
let record = rx.recv().await.unwrap();
db.write(false, |db| {
match record {
Record::WorkDone {
server,
client,
fn_name,
outcome,
} => {
add_stats(&mut db.total_stats, outcome);
add_stats(db.func_stat(fn_name), outcome);
if let Some(user) = db.users.get_mut(&client) {
add_stats(&mut user.client_stats, outcome);
}
if let Some(user) = db.users.get_mut(&server) {
add_stats(&mut user.server_stats, outcome);
}
}
Record::ClientNewFunction { client, fn_name } => {
db.func_stat(fn_name).functions += 1;
if let Some(user) = db.users.get_mut(&client) {
user.client_stats.functions += 1;
}
db.total_stats.functions += 1;
}
Record::ServerNewFunction { server } => {
if let Some(user) = db.users.get_mut(&server) {
user.server_stats.functions += 1;
}
}
};
})
.await;
}
}
pub fn stats_thread(db: &SaveableDB) -> (impl Future<Output = ()>, mpsc::Sender<Record>) {
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
let db = db.clone();
let fut = async move {
stats_writer(&db, rx).await;
};
(fut, tx)
}
@@ -1,37 +0,0 @@
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project::pin_project;
pub type SimpleResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
#[pin_project]
pub struct NowOrNever<F: Future> {
#[pin]
inner: F,
}
impl<F: Future> Future for NowOrNever<F> {
type Output = Option<F::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let ret = self.project().inner.poll(cx);
Poll::Ready(match ret {
Poll::Pending => None,
Poll::Ready(val) => Some(val),
})
}
}
impl<T> FutureExt for T where T: Future {}
pub trait FutureExt: Future {
fn now_or_never(self) -> NowOrNever<Self>
where
Self: Sized,
{
NowOrNever { inner: self }
}
}
@@ -1,80 +0,0 @@
use std::str;
use hex::FromHex;
use serde::Deserialize;
use serde_json::json;
use sodiumoxide::crypto::sign;
use crate::db::{User, UserId};
use crate::port::{ReadPort, WritePort};
use crate::util::SimpleResult;
use crate::{concat, State};
#[derive(Debug, Deserialize)]
pub(crate) struct VouchData {
who: UserId,
signed_name: String,
}
fn verify_with_magic<'a>(
magic: &[u8],
data: &'a [u8],
key: &sign::PublicKey,
) -> SimpleResult<&'a [u8]> {
if data.len() < 64 {
Err("signature too short")?;
}
let (signature, data) = data.split_at(64);
let signed_data = concat(magic, data);
let signature = sign::Signature::from_slice(signature).unwrap();
if !sign::verify_detached(&signature, &signed_data, key) {
Err("bad signature")?;
}
Ok(data)
}
fn parse_signed_name(signed_name: &str, who: &UserId) -> SimpleResult<String> {
let signed_name = Vec::from_hex(signed_name).map_err(|_| "not a valid hex string")?;
let name_bytes = verify_with_magic(b"NAME:", &signed_name, &who.to_pubkey())?;
let name = str::from_utf8(name_bytes)?;
if name.is_empty() {
Err("name is empty")?;
}
if name.chars().any(char::is_control) {
Err("name cannot contain control characters")?;
}
Ok(name.to_string())
}
pub(crate) async fn handle_vouch<'a>(
mut read_port: ReadPort<'a>,
mut write_port: WritePort<'a>,
who_id: UserId,
who_name: &str,
state: &State,
data: VouchData,
) -> SimpleResult<()> {
let vouchee_name = match parse_signed_name(&data.signed_name, &data.who) {
Ok(name) => name,
Err(e) => {
write_port.send_error(&format!("{}", &e)).await?;
Err(e)?
}
};
write_port.send_json(&json!({})).await?;
read_port.recv().await?;
state
.db
.write(true, |db| {
db.users.entry(data.who).or_insert_with(|| User {
trusted_by: Some(who_id),
name: vouchee_name.clone(),
client_stats: Default::default(),
server_stats: Default::default(),
});
})
.await;
write_port.send_json(&json!({})).await?;
eprintln!("[{}] vouch {}", who_name, &vouchee_name);
Ok(())
}
-416
View File
@@ -1,416 +0,0 @@
import abc
from dataclasses import dataclass
import datetime
import json
import socket
import struct
import sys
import toml
import typing
from typing import BinaryIO, List, Optional, Type, TypeVar, Union
from nacl.encoding import HexEncoder
from nacl.public import Box, PrivateKey, PublicKey
from nacl.secret import SecretBox
from nacl.signing import SigningKey, VerifyKey
from ..error import ServerError
from ..helpers import exception_to_string
T = TypeVar("T")
AnyBox = Union[Box, SecretBox]
PERMUTER_VERSION = 2
CONFIG_FILENAME = "pah.conf"
MIN_PRIO = 0.01
MAX_PRIO = 2.0
DEBUG_MODE = False
def enable_debug_mode() -> None:
"""Enable debug logging."""
global DEBUG_MODE
DEBUG_MODE = True
def debug_print(message: str) -> None:
if DEBUG_MODE:
time = datetime.datetime.now().strftime("%H:%M:%S:%f")
print(f"\n{time} debug: {message}")
@dataclass(eq=False)
class CancelToken:
cancelled: bool = False
@dataclass
class PermuterData:
base_score: int
base_hash: str
fn_name: str
filename: str
keep_prob: float
need_profiler: bool
stack_differences: bool
compile_script: str
source: str
target_o_bin: bytes
def permuter_data_from_json(
obj: dict, source: str, target_o_bin: bytes
) -> PermuterData:
return PermuterData(
base_score=json_prop(obj, "base_score", int),
base_hash=json_prop(obj, "base_hash", str),
fn_name=json_prop(obj, "fn_name", str),
filename=json_prop(obj, "filename", str),
keep_prob=json_prop(obj, "keep_prob", float),
need_profiler=json_prop(obj, "need_profiler", bool),
stack_differences=json_prop(obj, "stack_differences", bool),
compile_script=json_prop(obj, "compile_script", str),
source=source,
target_o_bin=target_o_bin,
)
def permuter_data_to_json(perm: PermuterData) -> dict:
return {
"base_score": perm.base_score,
"base_hash": perm.base_hash,
"fn_name": perm.fn_name,
"filename": perm.filename,
"keep_prob": perm.keep_prob,
"need_profiler": perm.need_profiler,
"stack_differences": perm.stack_differences,
"compile_script": perm.compile_script,
}
@dataclass
class Config:
server_address: Optional[str] = None
server_verify_key: Optional[VerifyKey] = None
signing_key: Optional[SigningKey] = None
initial_setup_nickname: Optional[str] = None
def read_config() -> Config:
config = Config()
try:
with open(CONFIG_FILENAME) as f:
obj = toml.load(f)
def read(key: str, t: Type[T]) -> Optional[T]:
ret = obj.get(key)
return ret if isinstance(ret, t) else None
temp = read("server_public_key", str)
if temp:
config.server_verify_key = VerifyKey(HexEncoder.decode(temp))
temp = read("secret_key", str)
if temp:
config.signing_key = SigningKey(HexEncoder.decode(temp))
config.initial_setup_nickname = read("initial_setup_nickname", str)
config.server_address = read("server_address", str)
except FileNotFoundError:
pass
except Exception:
print(f"Malformed configuration file {CONFIG_FILENAME}.\n")
raise
return config
def write_config(config: Config) -> None:
obj = {}
def write(key: str, val: Union[None, str, int]) -> None:
if val is not None:
obj[key] = val
write("initial_setup_nickname", config.initial_setup_nickname)
write("server_address", config.server_address)
key_hex: bytes
if config.server_verify_key:
key_hex = config.server_verify_key.encode(HexEncoder)
write("server_public_key", key_hex.decode("utf-8"))
if config.signing_key:
key_hex = config.signing_key.encode(HexEncoder)
write("secret_key", key_hex.decode("utf-8"))
with open(CONFIG_FILENAME, "w") as f:
toml.dump(obj, f)
def file_read_max(inf: BinaryIO, n: int) -> bytes:
try:
ret = []
while n > 0:
data = inf.read(n)
if not data:
break
ret.append(data)
n -= len(data)
return b"".join(ret)
except Exception as e:
raise EOFError from e
def file_read_fixed(inf: BinaryIO, n: int) -> bytes:
ret = file_read_max(inf, n)
if len(ret) != n:
raise EOFError
return ret
def socket_read_max(sock: socket.socket, n: int) -> bytes:
try:
ret = []
while n > 0:
data = sock.recv(min(n, 4096))
if not data:
break
ret.append(data)
n -= len(data)
return b"".join(ret)
except Exception as e:
raise EOFError from e
def socket_read_fixed(sock: socket.socket, n: int) -> bytes:
ret = socket_read_max(sock, n)
if len(ret) != n:
raise EOFError
return ret
def socket_shutdown(sock: socket.socket, how: int) -> None:
try:
sock.shutdown(how)
except Exception:
pass
def json_prop(obj: dict, prop: str, t: Type[T]) -> T:
ret = obj.get(prop)
if not isinstance(ret, t):
if t is float and isinstance(ret, int):
return typing.cast(T, float(ret))
found_type = type(ret).__name__
if prop not in obj:
raise ValueError(f"Member {prop} does not exist")
raise ValueError(f"Member {prop} must have type {t.__name__}; got {found_type}")
return ret
def json_array(obj: list, t: Type[T]) -> List[T]:
for elem in obj:
if not isinstance(elem, t):
found_type = type(elem).__name__
raise ValueError(
f"Array elements must have type {t.__name__}; got {found_type}"
)
return obj
def sign_with_magic(magic: bytes, signing_key: SigningKey, data: bytes) -> bytes:
signature: bytes = signing_key.sign(magic + b":" + data).signature
return signature + data
def verify_with_magic(magic: bytes, verify_key: VerifyKey, data: bytes) -> bytes:
if len(data) < 64:
raise ValueError("String is too small to contain a signature")
signature = data[:64]
data = data[64:]
verify_key.verify(magic + b":" + data, signature)
return data
class Port(abc.ABC):
def __init__(self, box: AnyBox, who: str, *, is_client: bool) -> None:
self._box = box
self._who = who
self._send_nonce = 0 if is_client else 1
self._receive_nonce = 1 if is_client else 0
@abc.abstractmethod
def _send(self, data: bytes) -> None:
...
@abc.abstractmethod
def _receive(self, length: int) -> bytes:
...
@abc.abstractmethod
def _receive_max(self, length: int) -> bytes:
...
def send(self, msg: bytes) -> None:
"""Send a binary message, potentially blocking."""
if DEBUG_MODE:
if len(msg) <= 300:
debug_print(f"Send to {self._who}: {msg!r}")
else:
debug_print(f"Send to {self._who}: {len(msg)} bytes")
nonce = struct.pack(">16xQ", self._send_nonce)
self._send_nonce += 2
data = self._box.encrypt(msg, nonce).ciphertext
length_data = struct.pack(">Q", len(data))
try:
self._send(length_data + data)
except BrokenPipeError:
raise EOFError from None
def send_json(self, msg: dict) -> None:
"""Send a message in the form of a JSON dict, potentially blocking."""
self.send(json.dumps(msg).encode("utf-8"))
def receive(self) -> bytes:
"""Read a binary message, blocking."""
length_data = self._receive(8)
if length_data[0]:
# Lengths above 2^56 are unreasonable, so if we get one someone is
# sending us bad data. Raise an exception to help debugging.
length_data += self._receive_max(1024)
raise Exception(
f"Got unexpected data from {self._who}: " + repr(length_data)
)
length = struct.unpack(">Q", length_data)[0]
data = self._receive(length)
nonce = struct.pack(">16xQ", self._receive_nonce)
self._receive_nonce += 2
msg: bytes = self._box.decrypt(data, nonce)
if DEBUG_MODE:
if len(msg) <= 300:
debug_print(f"Receive from {self._who}: {msg!r}")
else:
debug_print(f"Receive from {self._who}: {len(msg)} bytes")
return msg
def receive_json(self) -> dict:
"""Read a message in the form of a JSON dict, blocking."""
ret = json.loads(self.receive())
if isinstance(ret, str):
# Raw strings indicate errors.
raise ServerError(ret)
if not isinstance(ret, dict):
# We always pass dictionaries as messages and no other data types,
# to ensure future extensibility. (Other types are rare in
# practice, anyway.)
raise ValueError("Top-level JSON value must be a dictionary")
return ret
class SocketPort(Port):
def __init__(
self, sock: socket.socket, box: AnyBox, who: str, *, is_client: bool
) -> None:
self._sock = sock
super().__init__(box, who, is_client=is_client)
def _send(self, data: bytes) -> None:
self._sock.sendall(data)
def _receive(self, length: int) -> bytes:
return socket_read_fixed(self._sock, length)
def _receive_max(self, length: int) -> bytes:
return socket_read_max(self._sock, length)
def shutdown(self, how: int = socket.SHUT_RDWR) -> None:
socket_shutdown(self._sock, how)
def close(self) -> None:
self._sock.close()
class FilePort(Port):
def __init__(
self, inf: BinaryIO, outf: BinaryIO, box: AnyBox, who: str, *, is_client: bool
) -> None:
self._inf = inf
self._outf = outf
super().__init__(box, who, is_client=is_client)
def _send(self, data: bytes) -> None:
self._outf.write(data)
self._outf.flush()
def _receive(self, length: int) -> bytes:
return file_read_fixed(self._inf, length)
def _receive_max(self, length: int) -> bytes:
return file_read_max(self._inf, length)
def _do_connect(config: Config) -> SocketPort:
if (
not config.server_verify_key
or not config.signing_key
or not config.server_address
):
print(
"Using permuter@home requires someone to give you access to a central -J server.\n"
"Run `./pah.py setup` to set this up."
)
print()
sys.exit(1)
host, port_str = config.server_address.split(":")
try:
sock = socket.create_connection((host, int(port_str)))
except ConnectionRefusedError:
raise EOFError("connection refused") from None
except socket.gaierror as e:
raise EOFError(f"DNS lookup failed: {e}") from None
except Exception as e:
raise EOFError("unable to connect: " + exception_to_string(e)) from None
# Send over the protocol version and an ephemeral encryption key which we
# are going to use for all communication.
ephemeral_key = PrivateKey.generate()
ephemeral_key_data = ephemeral_key.public_key.encode()
sock.sendall(b"p@h0" + ephemeral_key_data)
# Receive the server's encryption key, plus a signature of it and our own
# ephemeral key -- this guarantees that we are talking to the server and
# aren't victim to a replay attack. Use it to set up a communication port.
msg = socket_read_fixed(sock, 32 + 64)
server_enc_key_data = msg[:32]
config.server_verify_key.verify(
b"HELLO:" + ephemeral_key_data + server_enc_key_data, msg[32:]
)
box = Box(ephemeral_key, PublicKey(server_enc_key_data))
port = SocketPort(sock, box, "controller", is_client=True)
# Use the encrypted port to send over our public key, proof that we are
# able to sign new things with it, as well as permuter version.
signature: bytes = config.signing_key.sign(
b"WORLD:" + server_enc_key_data
).signature
port.send(
config.signing_key.verify_key.encode()
+ signature
+ struct.pack(">I", PERMUTER_VERSION)
)
# Get an acknowledgement that the server wants to talk to us.
obj = port.receive_json()
if "message" in obj:
print(obj["message"])
return port
def connect(config: Optional[Config] = None) -> SocketPort:
"""Authenticate and connect to the permuter@home controller server."""
if not config:
config = read_config()
return _do_connect(config)
-401
View File
@@ -1,401 +0,0 @@
"""This file runs as a free-standing program within a sandbox, and processes
permutation requests. It communicates with the outside world on stdin/stdout."""
import base64
from dataclasses import dataclass
import math
from multiprocessing import Process, Queue
import os
import queue
import sys
from tempfile import mkstemp
import threading
import time
import traceback
from typing import Counter, Dict, List, Optional, Set, Tuple, Union
import zlib
from nacl.secret import SecretBox
from ..candidate import CandidateResult
from ..compiler import Compiler
from ..error import CandidateConstructionFailure
from ..helpers import exception_to_string, static_assert_unreachable
from ..permuter import EvalError, EvalResult, Permuter
from ..profiler import Profiler
from ..scorer import Scorer
from .core import (
FilePort,
PermuterData,
Port,
json_prop,
permuter_data_from_json,
)
def _fix_stdout() -> None:
"""Redirect stdout to stderr to make print() debugging work. This function
*must* be called at startup for each (sub)process, since we use stdout for
our own communication purposes."""
sys.stdout = sys.stderr
# In addition, we set stderr to flush on newlines, which does not happen by
# default when it is piped. (Requires Python 3.7, but we can assume that's
# available inside the sandbox.)
sys.stdout.reconfigure(line_buffering=True) # type: ignore
def _setup_port(secret: bytes) -> Port:
"""Set up communication with the outside world."""
port = FilePort(
sys.stdin.buffer,
sys.stdout.buffer,
SecretBox(secret),
"server",
is_client=False,
)
_fix_stdout()
# Follow the controlling process's sanity check protocol.
magic = port.receive()
port.send(magic)
return port
def _create_permuter(data: PermuterData) -> Permuter:
fd, path = mkstemp(suffix=".o", prefix="permuter", text=False)
try:
with os.fdopen(fd, "wb") as f:
f.write(data.target_o_bin)
scorer = Scorer(target_o=path, stack_differences=data.stack_differences)
finally:
os.unlink(path)
fd, path = mkstemp(suffix=".sh", prefix="permuter", text=True)
try:
os.chmod(fd, 0o755)
with os.fdopen(fd, "w") as f2:
f2.write(data.compile_script)
compiler = Compiler(compile_cmd=path, show_errors=False)
return Permuter(
dir="unused",
fn_name=data.fn_name,
compiler=compiler,
scorer=scorer,
source_file=data.filename,
source=data.source,
force_seed=None,
force_rng_seed=None,
keep_prob=data.keep_prob,
need_profiler=data.need_profiler,
need_all_sources=False,
show_errors=False,
better_only=False,
best_only=False,
)
except:
os.unlink(path)
raise
@dataclass
class AddPermuter:
perm_id: str
data: PermuterData
@dataclass
class AddPermuterLocal:
perm_id: str
permuter: Permuter
@dataclass
class RemovePermuter:
perm_id: str
@dataclass
class WorkDone:
perm_id: str
id: int
time_us: int
result: EvalResult
@dataclass
class Work:
perm_id: str
id: int
seed: int
LocalWork = Tuple[Union[AddPermuterLocal, RemovePermuter], int]
GlobalWork = Tuple[Work, int]
Task = Union[AddPermuter, RemovePermuter, Work, WorkDone]
def _remove_permuter(perm: Permuter) -> None:
os.unlink(perm.compiler.compile_cmd)
def _send_result(item: WorkDone, port: Port) -> None:
obj = {
"type": "result",
"permuter": item.perm_id,
"id": item.id,
"time_us": item.time_us,
}
res = item.result
if isinstance(res, EvalError):
obj["error"] = res.exc_str
port.send_json(obj)
return
compressed_source = getattr(res, "compressed_source")
obj["score"] = res.score
obj["has_source"] = compressed_source is not None
if res.hash is not None:
obj["hash"] = res.hash
if res.profiler is not None:
obj["profiler"] = {
st.name: res.profiler.time_stats[st] for st in Profiler.StatType
}
port.send_json(obj)
if compressed_source is not None:
port.send(compressed_source)
def multiprocess_worker(
worker_queue: "Queue[GlobalWork]",
local_queue: "Queue[LocalWork]",
task_queue: "Queue[Task]",
) -> None:
_fix_stdout()
# Prevent deadlocks in case the parent process dies.
worker_queue.cancel_join_thread()
local_queue.cancel_join_thread()
task_queue.cancel_join_thread()
permuters: Dict[str, Permuter] = {}
timestamp = 0
while True:
work, required_timestamp = worker_queue.get()
while True:
try:
block = timestamp < required_timestamp
task, timestamp = local_queue.get(block=block)
except queue.Empty:
break
if isinstance(task, AddPermuterLocal):
permuters[task.perm_id] = task.permuter
elif isinstance(task, RemovePermuter):
del permuters[task.perm_id]
else:
static_assert_unreachable(task)
time_before = time.time()
permuter = permuters[work.perm_id]
result = permuter.try_eval_candidate(work.seed)
if isinstance(result, CandidateResult) and permuter.should_output(result):
permuter.record_result(result)
# Compress the source within the worker. (Why waste a free
# multi-threading opportunity?)
if isinstance(result, CandidateResult):
compressed_source: Optional[bytes] = None
if result.source is not None:
compressed_source = zlib.compress(result.source.encode("utf-8"))
setattr(result, "compressed_source", compressed_source)
result.source = None
time_us = int((time.time() - time_before) * 10 ** 6)
task_queue.put(
WorkDone(perm_id=work.perm_id, id=work.id, time_us=time_us, result=result)
)
def read_loop(task_queue: "Queue[Task]", port: Port) -> None:
try:
while True:
item = port.receive_json()
msg_type = json_prop(item, "type", str)
if msg_type == "add":
perm_id = json_prop(item, "permuter", str)
source = port.receive().decode("utf-8")
target_o_bin = port.receive()
data = permuter_data_from_json(item, source, target_o_bin)
task_queue.put(AddPermuter(perm_id=perm_id, data=data))
elif msg_type == "remove":
perm_id = json_prop(item, "permuter", str)
task_queue.put(RemovePermuter(perm_id=perm_id))
elif msg_type == "work":
perm_id = json_prop(item, "permuter", str)
id = json_prop(item, "id", int)
seed = json_prop(item, "seed", int)
task_queue.put(Work(perm_id=perm_id, id=id, seed=seed))
else:
raise Exception(f"Invalid message type {msg_type}")
except Exception as e:
# In case the port is closed from the other side, skip writing an ugly
# error message.
if not isinstance(e, EOFError):
traceback.print_exc()
# Exit the whole process, to improve the odds that the Docker container
# really stops and gets removed.
#
# The parent server has a "finally:" that does that, but it's not 100%
# trustworthy. In particular, pystray has a tendency to hard-crash
# (which doesn't fire "finally"s), and also reverts the signal handler
# for SIGINT to the default on Linux, making Ctrl+C not run cleanup.
# Either way, defense in depth here doesn't hurt, since leaking Docker
# containers is pretty bad.
#
# Unfortunately this still doesn't fix the problem, since we typically
# don't get a port closure signal when the parent process stops...
# TODO: listen to heartbeats as well.
sys.exit(1)
def main() -> None:
secret = base64.b64decode(os.environ["SECRET"])
del os.environ["SECRET"]
os.environ["PERMUTER_IS_REMOTE"] = "1"
port = _setup_port(secret)
obj = port.receive_json()
num_cores = json_prop(obj, "num_cores", float)
num_threads = math.ceil(num_cores)
worker_queue: "Queue[GlobalWork]" = Queue()
task_queue: "Queue[Task]" = Queue()
local_queues: "List[Queue[LocalWork]]" = []
for i in range(num_threads):
local_queue: "Queue[LocalWork]" = Queue()
p = Process(
target=multiprocess_worker,
args=(worker_queue, local_queue, task_queue),
daemon=True,
)
p.start()
local_queues.append(local_queue)
reader_thread = threading.Thread(
target=read_loop, args=(task_queue, port), daemon=True
)
reader_thread.start()
remaining_work: Counter[str] = Counter()
should_remove: Set[str] = set()
permuters: Dict[str, Permuter] = {}
timestamp = 0
def try_remove(perm_id: str) -> None:
nonlocal timestamp
assert perm_id in permuters
if perm_id not in should_remove or remaining_work[perm_id] != 0:
return
del remaining_work[perm_id]
should_remove.remove(perm_id)
timestamp += 1
for q in local_queues:
q.put((RemovePermuter(perm_id=perm_id), timestamp))
_remove_permuter(permuters[perm_id])
del permuters[perm_id]
while True:
item = task_queue.get()
if isinstance(item, AddPermuter):
assert item.perm_id not in permuters
msg: Dict[str, object] = {
"type": "init",
"permuter": item.perm_id,
}
time_before = time.time()
try:
# Construct a permuter. This involves a compilation on the main
# thread, which isn't great but we can live with it for now.
permuter = _create_permuter(item.data)
if permuter.base_score != item.data.base_score:
_remove_permuter(permuter)
score_str = f"{permuter.base_score} vs {item.data.base_score}"
if permuter.base_hash == item.data.base_hash:
hash_str = "same hash; different Python or permuter versions?"
else:
hash_str = "different hash; different objdump versions?"
raise CandidateConstructionFailure(
f"mismatching score: {score_str} ({hash_str})"
)
permuters[item.perm_id] = permuter
msg["success"] = True
msg["base_score"] = permuter.base_score
msg["base_hash"] = permuter.base_hash
# Tell all the workers about the new permuter.
# TODO: ideally we would also seed their Candidate lru_cache's
# to avoid all workers having to parse the source...
timestamp += 1
for q in local_queues:
q.put(
(
AddPermuterLocal(perm_id=item.perm_id, permuter=permuter),
timestamp,
)
)
except Exception as e:
# This shouldn't practically happen, since the client compiled
# the code successfully. Print a message if it does.
msg["success"] = False
msg["error"] = exception_to_string(e)
if isinstance(e, CandidateConstructionFailure):
print(e.message)
else:
traceback.print_exc()
msg["time_us"] = int((time.time() - time_before) * 10 ** 6)
port.send_json(msg)
elif isinstance(item, RemovePermuter):
# Silently ignore requests to remove permuters that have already
# been removed, which can occur when AddPermuter fails.
if item.perm_id in permuters:
should_remove.add(item.perm_id)
try_remove(item.perm_id)
elif isinstance(item, WorkDone):
remaining_work[item.perm_id] -= 1
try_remove(item.perm_id)
_send_result(item, port)
elif isinstance(item, Work):
remaining_work[item.perm_id] += 1
worker_queue.put((item, timestamp))
else:
static_assert_unreachable(item)
if __name__ == "__main__":
main()
-944
View File
@@ -1,944 +0,0 @@
import base64
from dataclasses import dataclass
import pathlib
import queue
import struct
import sys
import threading
import time
import traceback
from typing import BinaryIO, Dict, Optional, Set, Tuple, Union, TYPE_CHECKING
import zlib
if TYPE_CHECKING:
import docker
from nacl.secret import SecretBox
import nacl.utils
from ..helpers import exception_to_string, static_assert_unreachable
from .core import (
CancelToken,
Config,
PermuterData,
Port,
ServerError,
SocketPort,
connect,
file_read_fixed,
json_prop,
permuter_data_from_json,
permuter_data_to_json,
)
_HEARTBEAT_INTERVAL_SLACK_SEC: float = 50.0
@dataclass
class Client:
id: str
nickname: str
@dataclass
class AddPermuter:
handle: int
time_start: float
client: Client
permuter_data: PermuterData
@dataclass
class RemovePermuter:
handle: int
@dataclass
class Work:
handle: int
id: int
time_start: float
seed: int
@dataclass
class ImmediateDisconnect:
handle: int
client: Client
reason: str
@dataclass
class Disconnect:
handle: int
@dataclass
class PermInitFail:
perm_id: str
error: str
@dataclass
class PermInitSuccess:
perm_id: str
base_score: int
base_hash: str
time_us: int
@dataclass
class WorkDone:
perm_id: str
id: int
obj: dict
time_us: int
compressed_source: Optional[bytes]
class NeedMoreWork:
pass
@dataclass
class NetThreadDisconnected:
graceful: bool
message: Optional[str] = None
class Heartbeat:
pass
class Shutdown:
pass
Activity = Union[
AddPermuter,
RemovePermuter,
Work,
ImmediateDisconnect,
Disconnect,
PermInitFail,
PermInitSuccess,
WorkDone,
NeedMoreWork,
NetThreadDisconnected,
Heartbeat,
Shutdown,
]
@dataclass
class OutputInitFail:
handle: int
error: str
@dataclass
class OutputInitSuccess:
handle: int
time_us: int
base_score: int
base_hash: str
@dataclass
class OutputDisconnect:
handle: int
@dataclass
class OutputNeedMoreWork:
pass
@dataclass
class OutputWork:
handle: int
time_start: float
time_us: int
obj: dict
compressed_source: Optional[bytes]
Output = Union[
OutputDisconnect,
OutputInitFail,
OutputInitSuccess,
OutputNeedMoreWork,
OutputWork,
Shutdown,
]
@dataclass
class IoConnect:
fn_name: str
client: Client
@dataclass
class IoDisconnect:
reason: str
@dataclass
class IoImmediateDisconnect:
reason: str
client: Client
class IoUserRemovePermuter:
pass
@dataclass
class IoServerFailed:
graceful: bool
message: Optional[str]
class IoReconnect:
pass
class IoShutdown:
pass
@dataclass
class IoWorkDone:
score: Optional[int]
is_improvement: bool
PermuterHandle = Tuple[int, CancelToken]
IoMessage = Union[
IoConnect, IoDisconnect, IoImmediateDisconnect, IoUserRemovePermuter, IoWorkDone
]
IoGlobalMessage = Union[IoReconnect, IoShutdown, IoServerFailed]
IoActivity = Tuple[
Optional[CancelToken], Union[Tuple[PermuterHandle, IoMessage], IoGlobalMessage]
]
@dataclass
class ServerOptions:
num_cores: float
max_memory_gb: float
min_priority: float
class NetThread:
_port: Optional[SocketPort]
_main_queue: "queue.Queue[Activity]"
_controller_queue: "queue.Queue[Output]"
_read_thread: "threading.Thread"
_write_thread: "threading.Thread"
_next_work_id: int
def __init__(
self,
port: SocketPort,
main_queue: "queue.Queue[Activity]",
) -> None:
self._port = port
self._main_queue = main_queue
self._controller_queue = queue.Queue()
self._next_work_id = 0
self._read_thread = threading.Thread(target=self.read_loop, daemon=True)
self._read_thread.start()
self._write_thread = threading.Thread(target=self.write_loop, daemon=True)
self._write_thread.start()
def stop(self) -> None:
if self._port is None:
return
try:
self._controller_queue.put(Shutdown())
self._port.shutdown()
self._read_thread.join()
self._write_thread.join()
self._port.close()
self._port = None
except Exception:
print("Failed to stop net thread.")
traceback.print_exc()
def send_controller(self, msg: Output) -> None:
self._controller_queue.put(msg)
def _read_one(self) -> Activity:
assert self._port is not None
msg = self._port.receive_json()
time_start = time.time()
msg_type = json_prop(msg, "type", str)
if msg_type == "heartbeat":
return Heartbeat()
handle = json_prop(msg, "permuter", int)
if msg_type == "work":
seed = json_prop(msg, "seed", int)
id = self._next_work_id
self._next_work_id += 1
return Work(handle=handle, id=id, time_start=time_start, seed=seed)
elif msg_type == "add":
client_id = json_prop(msg, "client_id", str)
client_name = json_prop(msg, "client_name", str)
client = Client(client_id, client_name)
data = json_prop(msg, "data", dict)
compressed_source = self._port.receive()
compressed_target_o_bin = self._port.receive()
try:
source = zlib.decompress(compressed_source).decode("utf-8")
target_o_bin = zlib.decompress(compressed_target_o_bin)
permuter = permuter_data_from_json(data, source, target_o_bin)
except Exception as e:
# Client sent something illegible. This can legitimately happen if the
# client runs another version, but it's interesting to log.
traceback.print_exc()
return ImmediateDisconnect(
handle=handle,
client=client,
reason=f"Failed to parse permuter: {exception_to_string(e)}",
)
return AddPermuter(
handle=handle,
time_start=time_start,
client=client,
permuter_data=permuter,
)
elif msg_type == "remove":
return RemovePermuter(handle=handle)
else:
raise Exception(f"Bad message type: {msg_type}")
def read_loop(self) -> None:
try:
while True:
msg = self._read_one()
self._main_queue.put(msg)
except EOFError:
self._main_queue.put(NetThreadDisconnected(graceful=True))
except ServerError as e:
self._main_queue.put(
NetThreadDisconnected(graceful=False, message=e.message)
)
except Exception:
traceback.print_exc()
self._main_queue.put(NetThreadDisconnected(graceful=False))
def _write_one(self, item: Output) -> None:
assert self._port is not None
if isinstance(item, Shutdown):
# Handled by caller
pass
elif isinstance(item, OutputInitFail):
self._port.send_json(
{
"type": "update",
"permuter": item.handle,
"time_us": 0,
"update": {"type": "init_failed", "reason": item.error},
}
)
elif isinstance(item, OutputInitSuccess):
self._port.send_json(
{
"type": "update",
"permuter": item.handle,
"time_us": item.time_us,
"update": {"type": "init_done", "hash": item.base_hash},
}
)
elif isinstance(item, OutputDisconnect):
self._port.send_json(
{
"type": "update",
"permuter": item.handle,
"time_us": 0,
"update": {"type": "disconnect"},
}
)
elif isinstance(item, OutputNeedMoreWork):
self._port.send_json({"type": "need_work"})
elif isinstance(item, OutputWork):
overhead_us = int((time.time() - item.time_start) * 10 ** 6) - item.time_us
self._port.send_json(
{
"type": "update",
"permuter": item.handle,
"time_us": item.time_us,
"update": {
"type": "work",
"overhead_us": overhead_us,
**item.obj,
},
}
)
if item.compressed_source is not None:
self._port.send(item.compressed_source)
else:
static_assert_unreachable(item)
def write_loop(self) -> None:
try:
while True:
item = self._controller_queue.get()
if isinstance(item, Shutdown):
break
self._write_one(item)
except EOFError:
self._main_queue.put(NetThreadDisconnected(graceful=True))
except Exception:
traceback.print_exc()
self._main_queue.put(NetThreadDisconnected(graceful=False))
class ServerInner:
"""This class represents an up-and-running server, connected to the controller and
to the evaluator."""
_evaluator_port: "DockerPort"
_main_queue: "queue.Queue[Activity]"
_io_queue: "queue.Queue[IoActivity]"
_net_thread: NetThread
_read_eval_thread: threading.Thread
_main_thread: threading.Thread
_heartbeat_interval: float
_last_heartbeat: float
_last_heartbeat_lock: threading.Lock
_active: Set[int]
_time_starts: Dict[int, float]
_token: CancelToken
def __init__(
self,
net_port: SocketPort,
evaluator_port: "DockerPort",
io_queue: "queue.Queue[IoActivity]",
heartbeat_interval: float,
) -> None:
self._evaluator_port = evaluator_port
self._main_queue = queue.Queue()
self._io_queue = io_queue
self._active = set()
self._time_starts = {}
self._token = CancelToken()
self._net_thread = NetThread(net_port, self._main_queue)
# Start a thread for checking heartbeats.
self._heartbeat_interval = heartbeat_interval
self._last_heartbeat = time.time()
self._last_heartbeat_lock = threading.Lock()
self._heartbeat_stop = threading.Event()
self._heartbeat_thread = threading.Thread(
target=self._heartbeat_loop, daemon=True
)
self._heartbeat_thread.start()
# Start a thread for reading evaluator results and sending them on to
# the main loop queue.
self._read_eval_thread = threading.Thread(
target=self._read_eval_loop, daemon=True
)
self._read_eval_thread.start()
# Start a thread for the main loop.
self._main_thread = threading.Thread(target=self._main_loop, daemon=True)
self._main_thread.start()
def _send_controller(self, msg: Output) -> None:
self._net_thread.send_controller(msg)
def _send_io(self, handle: int, io_msg: IoMessage) -> None:
self._io_queue.put((self._token, ((handle, self._token), io_msg)))
def _send_io_global(self, io_msg: IoGlobalMessage) -> None:
self._io_queue.put((self._token, io_msg))
def _handle_message(self, msg: Activity) -> None:
if isinstance(msg, Shutdown):
# Handled by caller
pass
elif isinstance(msg, Heartbeat):
with self._last_heartbeat_lock:
self._last_heartbeat = time.time()
elif isinstance(msg, Work):
if msg.handle not in self._active:
self._need_work()
return
self._time_starts[msg.id] = msg.time_start
self._evaluator_port.send_json(
{
"type": "work",
"permuter": str(msg.handle),
"id": msg.id,
"seed": msg.seed,
}
)
elif isinstance(msg, AddPermuter):
if msg.handle in self._active:
raise Exception("Repeated AddPermuter!")
self._active.add(msg.handle)
self._send_permuter(str(msg.handle), msg.permuter_data)
fn_name = msg.permuter_data.fn_name
self._send_io(msg.handle, IoConnect(fn_name, msg.client))
elif isinstance(msg, RemovePermuter):
if msg.handle not in self._active:
return
self._remove(msg.handle)
self._send_io(msg.handle, IoDisconnect("disconnected"))
elif isinstance(msg, Disconnect):
if msg.handle not in self._active:
return
self._remove(msg.handle)
self._send_io(msg.handle, IoDisconnect("kicked"))
self._send_controller(OutputDisconnect(handle=msg.handle))
elif isinstance(msg, ImmediateDisconnect):
if msg.handle in self._active:
raise Exception("ImmediateDisconnect is not immediate")
self._send_io(msg.handle, IoImmediateDisconnect(msg.reason, msg.client))
self._send_controller(OutputInitFail(handle=msg.handle, error=msg.reason))
elif isinstance(msg, PermInitFail):
handle = int(msg.perm_id)
if handle not in self._active:
self._need_work()
return
self._active.remove(handle)
self._send_io(handle, IoDisconnect("failed to compile"))
self._send_controller(
OutputInitFail(
handle=handle,
error=msg.error,
)
)
elif isinstance(msg, PermInitSuccess):
handle = int(msg.perm_id)
if handle not in self._active:
self._need_work()
return
self._send_controller(
OutputInitSuccess(
handle=handle,
time_us=msg.time_us,
base_score=msg.base_score,
base_hash=msg.base_hash,
)
)
elif isinstance(msg, WorkDone):
handle = int(msg.perm_id)
time_start = self._time_starts.pop(msg.id)
if handle not in self._active:
self._need_work()
return
obj = msg.obj
obj["permuter"] = handle
score = json_prop(obj, "score", int) if "score" in obj else None
is_improvement = msg.compressed_source is not None
self._send_io(
handle,
IoWorkDone(score=score, is_improvement=is_improvement),
)
self._send_controller(
OutputWork(
handle=handle,
time_start=time_start,
time_us=msg.time_us,
obj=obj,
compressed_source=msg.compressed_source,
)
)
elif isinstance(msg, NeedMoreWork):
self._need_work()
elif isinstance(msg, NetThreadDisconnected):
self._send_io_global(IoServerFailed(msg.graceful, msg.message))
else:
static_assert_unreachable(msg)
def _need_work(self) -> None:
self._send_controller(OutputNeedMoreWork())
def _remove(self, handle: int) -> None:
self._evaluator_port.send_json({"type": "remove", "permuter": str(handle)})
self._active.remove(handle)
def _send_permuter(self, perm_id: str, perm: PermuterData) -> None:
self._evaluator_port.send_json(
{
"type": "add",
"permuter": perm_id,
**permuter_data_to_json(perm),
}
)
self._evaluator_port.send(perm.source.encode("utf-8"))
self._evaluator_port.send(perm.target_o_bin)
def _do_read_eval_loop(self) -> None:
while True:
msg = self._evaluator_port.receive_json()
msg_type = json_prop(msg, "type", str)
if msg_type == "init":
perm_id = json_prop(msg, "permuter", str)
time_us = json_prop(msg, "time_us", int)
resp: Activity
if json_prop(msg, "success", bool):
resp = PermInitSuccess(
perm_id=perm_id,
base_score=json_prop(msg, "base_score", int),
base_hash=json_prop(msg, "base_hash", str),
time_us=time_us,
)
else:
resp = PermInitFail(
perm_id=perm_id,
error=json_prop(msg, "error", str),
)
self._main_queue.put(resp)
elif msg_type == "result":
compressed_source: Optional[bytes] = None
if msg.get("has_source") == True:
compressed_source = self._evaluator_port.receive()
perm_id = json_prop(msg, "permuter", str)
id = json_prop(msg, "id", int)
time_us = json_prop(msg, "time_us", int)
del msg["permuter"]
del msg["id"]
del msg["time_us"]
self._main_queue.put(
WorkDone(
perm_id=perm_id,
id=id,
obj=msg,
time_us=time_us,
compressed_source=compressed_source,
)
)
else:
raise Exception(f"Unknown message type from evaluator: {msg_type}")
def _read_eval_loop(self) -> None:
try:
self._do_read_eval_loop()
except EOFError:
# Silence errors from shutdown.
pass
def _main_loop(self) -> None:
while True:
msg = self._main_queue.get()
if isinstance(msg, Shutdown):
break
self._handle_message(msg)
def _heartbeat_loop(self) -> None:
second_attempt = False
while True:
with self._last_heartbeat_lock:
delay = (
self._last_heartbeat
+ self._heartbeat_interval
+ _HEARTBEAT_INTERVAL_SLACK_SEC / 2
- time.time()
)
if delay <= 0:
if second_attempt:
self._main_queue.put(NetThreadDisconnected(graceful=True))
return
# Handle clock skew or computer going to sleep by waiting a bit
# longer before giving up.
second_attempt = True
if self._heartbeat_stop.wait(_HEARTBEAT_INTERVAL_SLACK_SEC / 2):
return
else:
second_attempt = False
if self._heartbeat_stop.wait(delay):
return
def remove_permuter(self, handle: int) -> None:
assert not self._token.cancelled
self._main_queue.put(Disconnect(handle=handle))
def stop(self) -> None:
assert not self._token.cancelled
self._token.cancelled = True
self._main_queue.put(Shutdown())
self._heartbeat_stop.set()
self._net_thread.stop()
self._evaluator_port.shutdown()
self._main_thread.join()
self._heartbeat_thread.join()
class DockerPort(Port):
"""Port for communicating with Docker. Communication is encrypted for a few
not-very-good reasons:
- it allows code reuse
- it adds error-checking
- it was fun to implement"""
_sock: BinaryIO
_container: "docker.models.containers.Container"
_stdout_buffer: bytes
_closed: bool
def __init__(
self, container: "docker.models.containers.Container", secret: bytes
) -> None:
self._container = container
self._stdout_buffer = b""
self._closed = False
# Set up a socket for reading from stdout/stderr and writing to
# stdin for the container. The docker package does not seem to
# expose an API for writing the stdin, but we can do so directly
# by attaching a socket and poking at internal state. (See
# https://github.com/docker/docker-py/issues/983.) For stdout/
# stderr, we use the format described at
# https://docs.docker.com/engine/api/v1.24/#attach-to-a-container.
#
# Hopefully this will keep working for at least a while...
try:
self._sock = container.attach_socket(
params={"stdout": True, "stdin": True, "stderr": True, "stream": True}
)
self._sock._writing = True # type: ignore
except:
try:
container.remove(force=True)
except Exception:
pass
raise
super().__init__(SecretBox(secret), "docker", is_client=True)
def shutdown(self) -> None:
import docker
if self._closed:
return
self._closed = True
try:
self._sock.close()
self._container.remove(force=True)
except Exception as e:
if not (
isinstance(e, docker.errors.APIError)
and e.status_code == 409
and "is already in progress" in str(e)
):
print("Failed to shut down Docker")
traceback.print_exc()
def _read_one(self) -> None:
header = file_read_fixed(self._sock, 8)
stream, length = struct.unpack(">BxxxI", header)
if stream not in [1, 2]:
raise Exception("Unexpected output from Docker: " + repr(header))
data = file_read_fixed(self._sock, length)
if stream == 1:
self._stdout_buffer += data
else:
sys.stderr.buffer.write(b"Docker stderr: " + data)
sys.stderr.buffer.flush()
def _receive(self, length: int) -> bytes:
while len(self._stdout_buffer) < length:
self._read_one()
ret = self._stdout_buffer[:length]
self._stdout_buffer = self._stdout_buffer[length:]
return ret
def _receive_max(self, length: int) -> bytes:
length = min(length, len(self._stdout_buffer))
ret = self._stdout_buffer[:length]
self._stdout_buffer = self._stdout_buffer[length:]
return ret
def _send(self, data: bytes) -> None:
while data:
written = self._sock.write(data)
data = data[written:]
self._sock.flush()
def _start_evaluator(docker_image: str, options: ServerOptions) -> DockerPort:
"""Spawn a docker container and set it up to evaluate permutations in,
returning a handle that we can use to communicate with it.
We do this for a few reasons:
- enforcing a known Linux environment, all while the outside server can run
on e.g. Windows and display a systray
- enforcing resource limits
- sandboxing
Docker does have the downside of requiring root access, so ideally we would
also have a Docker-less mode, where we leave the sandboxing to some other
tool, e.g. https://github.com/ioi/isolate/."""
print("Starting docker...")
command = ["python3", "-m", "src.net.evaluator"]
secret = nacl.utils.random(32)
enc_secret = base64.b64encode(secret).decode("utf-8")
src_path = pathlib.Path(__file__).parent.parent.absolute()
try:
import docker
client = docker.from_env()
client.info()
except ModuleNotFoundError:
print(
"Running a server requires the docker Python package to be installed.\n"
"Run `python3 -m pip install --upgrade docker`."
)
sys.exit(1)
except Exception:
traceback.print_exc()
print()
print(
"Failed to start docker. Make sure you have docker installed and "
"the docker daemon running, and either run the permuter with sudo "
'or add yourself to the "docker" UNIX group.'
)
sys.exit(1)
try:
container = client.containers.run(
docker_image,
command,
detach=True,
remove=True,
stdin_open=True,
stdout=True,
environment={"SECRET": enc_secret},
volumes={src_path: {"bind": "/src", "mode": "ro"}},
tmpfs={"/tmp": "size=1G,exec"},
nano_cpus=int(options.num_cores * 1e9),
mem_limit=int(options.max_memory_gb * 2 ** 30),
read_only=True,
network_disabled=True,
)
except Exception as e:
print(f"Failed to start docker container: {e}")
sys.exit(1)
port = DockerPort(container, secret)
try:
# Sanity-check that the Docker container started successfully and can
# be communicated with.
magic = b"\0" * 1000
port.send(magic)
r = port.receive()
if r != magic:
raise Exception("Failed initial sanity check.")
port.send_json({"num_cores": options.num_cores})
except:
port.shutdown()
raise
print("Started.")
return port
class Server:
"""This class represents a server that may or may not be connected to the
controller and the evaluator."""
_server: Optional[ServerInner]
_options: ServerOptions
_config: Config
_io_queue: "queue.Queue[IoActivity]"
def __init__(
self,
options: ServerOptions,
config: Config,
io_queue: "queue.Queue[IoActivity]",
) -> None:
self._server = None
self._options = options
self._config = config
self._io_queue = io_queue
def start(self) -> None:
assert self._server is None
net_port = connect(self._config)
net_port.send_json(
{
"method": "connect_server",
"min_priority": self._options.min_priority,
"num_cores": self._options.num_cores,
}
)
obj = net_port.receive_json()
docker_image = json_prop(obj, "docker_image", str)
heartbeat_interval = json_prop(obj, "heartbeat_interval", float)
evaluator_port = _start_evaluator(docker_image, self._options)
try:
self._server = ServerInner(
net_port, evaluator_port, self._io_queue, heartbeat_interval
)
except:
evaluator_port.shutdown()
raise
def stop(self) -> None:
if self._server is None:
return
self._server.stop()
self._server = None
def remove_permuter(self, handle: PermuterHandle) -> None:
if self._server is not None and not handle[1].cancelled:
self._server.remove_permuter(handle[0])
-224
View File
@@ -1,224 +0,0 @@
#!/usr/bin/env python3
from dataclasses import dataclass, field
import os
import re
import string
import subprocess
import sys
from typing import List, Match, Pattern, Set, Tuple
# Ignore registers, for cleaner output. (We don't do this right now, but it can
# be useful for debugging.)
ign_regs = False
# Don't include branch targets in the output. Assuming our input is semantically
# equivalent skipping it shouldn't be an issue, and it makes insertions have too
# large effect.
ign_branch_targets = True
# Skip branch-likely delay slots. (They aren't interesting on IDO.)
skip_bl_delay_slots = True
skip_lines = 1
re_int = re.compile(r"[0-9]+")
re_int_full = re.compile(r"\b[0-9]+\b")
@dataclass
class ArchSettings:
objdump: List[str]
re_comment: Pattern[str]
re_reg: Pattern[str]
re_sprel: Pattern[str]
re_includes_sp: Pattern[str]
branch_instructions: Set[str]
forbidden: Set[str] = field(default_factory=lambda: set(string.ascii_letters + "_"))
branch_likely_instructions: Set[str] = field(default_factory=set)
MIPS_BRANCH_LIKELY_INSTRUCTIONS = {
"beql",
"bnel",
"beqzl",
"bnezl",
"bgezl",
"bgtzl",
"blezl",
"bltzl",
"bc1tl",
"bc1fl",
}
MIPS_BRANCH_INSTRUCTIONS = {
"b",
"j",
"beq",
"bne",
"beqz",
"bnez",
"bgez",
"bgtz",
"blez",
"bltz",
"bc1t",
"bc1f",
}.union(MIPS_BRANCH_LIKELY_INSTRUCTIONS)
MIPS_SETTINGS: ArchSettings = ArchSettings(
re_comment=re.compile(r"<.*?>"),
re_reg=re.compile(
r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra)\b" # leave out $zero
),
re_sprel=re.compile(r"(?<=,)([0-9]+|0x[0-9a-f]+)\((sp|s8)\)"),
re_includes_sp=re.compile(r"\b(sp|s8)\b"),
objdump=["mips-linux-gnu-objdump", "-drz", "-m", "mips:4300"],
branch_likely_instructions=MIPS_BRANCH_LIKELY_INSTRUCTIONS,
branch_instructions=MIPS_BRANCH_INSTRUCTIONS,
)
def get_arch(o_file: str) -> ArchSettings:
# https://refspecs.linuxfoundation.org/elf/gabi4+/ch4.eheader.html
with open(o_file, "rb") as f:
f.seek(18)
arch_magic = f.read(2)
if arch_magic == b"\0\x08":
return MIPS_SETTINGS
# TODO: support PPC ("\0\x14"), ARM ("0\x28")
raise Exception("Bad ELF")
def parse_relocated_line(line: str) -> Tuple[str, str, str]:
try:
ind2 = line.rindex(",")
except ValueError:
ind2 = line.rindex("\t")
before = line[: ind2 + 1]
after = line[ind2 + 1 :]
ind2 = after.find("(")
if ind2 == -1:
imm, after = after, ""
else:
imm, after = after[:ind2], after[ind2:]
if imm == "0x0":
imm = "0"
return before, imm, after
def simplify_objdump(
input_lines: List[str], arch: ArchSettings, *, stack_differences: bool
) -> List[str]:
output_lines: List[str] = []
nops = 0
skip_next = False
for index, row in enumerate(input_lines):
if index < skip_lines:
continue
row = row.rstrip()
if ">:" in row or not row:
continue
if "R_MIPS_" in row:
prev = output_lines[-1]
if prev == "<skipped>":
continue
before, imm, after = parse_relocated_line(prev)
repl = row.split()[-1]
# As part of ignoring branch targets, we ignore relocations for j
# instructions. The target is already lost anyway.
if imm == "<target>":
assert ign_branch_targets
continue
# Sometimes s8 is used as a non-framepointer, but we've already lost
# the immediate value by pretending it is one. This isn't too bad,
# since it's rare and applies consistently. But we do need to handle it
# here to avoid a crash, by pretending that lost imms are zero for
# relocations.
if imm != "0" and imm != "imm" and imm != "addr":
repl += "+" + imm if int(imm, 0) > 0 else imm
if any(
reloc in row
for reloc in ["R_MIPS_LO16", "R_MIPS_LITERAL", "R_MIPS_GPREL16"]
):
repl = f"%lo({repl})"
elif "R_MIPS_HI16" in row:
# Ideally we'd pair up R_MIPS_LO16 and R_MIPS_HI16 to generate a
# correct addend for each, but objdump doesn't give us the order of
# the relocations, so we can't find the right LO16. :(
repl = f"%hi({repl})"
else:
assert "R_MIPS_26" in row, f"unknown relocation type '{row}'"
output_lines[-1] = before + repl + after
continue
row = re.sub(arch.re_comment, "", row)
row = row.rstrip()
row = "\t".join(row.split("\t")[2:]) # [20:]
if not row:
continue
if skip_next:
skip_next = False
row = "<skipped>"
if ign_regs:
row = re.sub(arch.re_reg, "<reg>", row)
row_parts = row.split("\t")
if len(row_parts) == 1:
row_parts.append("")
mnemonic, instr_args = row_parts
if not stack_differences:
if mnemonic == "addiu" and arch.re_includes_sp.search(instr_args):
row = re.sub(re_int_full, "imm", row)
if mnemonic in arch.branch_instructions:
if ign_branch_targets:
instr_parts = instr_args.split(",")
instr_parts[-1] = "<target>"
instr_args = ",".join(instr_parts)
row = f"{mnemonic}\t{instr_args}"
# The last part is in hex, so skip the dec->hex conversion
else:
def fn(pat: Match[str]) -> str:
full = pat.group(0)
if len(full) <= 1:
return full
start, end = pat.span()
if start and row[start - 1] in arch.forbidden:
return full
if end < len(row) and row[end] in arch.forbidden:
return full
return hex(int(full))
row = re.sub(re_int, fn, row)
if mnemonic in arch.branch_likely_instructions and skip_bl_delay_slots:
skip_next = True
if not stack_differences:
row = re.sub(arch.re_sprel, "addr(sp)", row)
# row = row.replace(',', ', ')
if row == "nop":
# strip trailing nops; padding is irrelevant to us
nops += 1
else:
for _ in range(nops):
output_lines.append("nop")
nops = 0
output_lines.append(row)
return output_lines
def objdump(
o_filename: str, arch: ArchSettings, *, stack_differences: bool = False
) -> List[str]:
output = subprocess.check_output(arch.objdump + [o_filename])
lines = output.decode("utf-8").splitlines()
return simplify_objdump(lines, arch, stack_differences=stack_differences)
if __name__ == "__main__":
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} file.o", file=sys.stderr)
sys.exit(1)
if not os.path.isfile(sys.argv[1]):
print(f"Source file {sys.argv[1]} is not readable.", file=sys.stderr)
sys.exit(1)
lines = objdump(sys.argv[1], MIPS_SETTINGS)
for row in lines:
print(row)
-58
View File
@@ -1,58 +0,0 @@
from typing import List, Optional, Tuple
from .. import ast_util
from .perm import EvalState, Perm
from ..ast_util import Block, Statement
from ..error import CandidateConstructionFailure
from pycparser import c_ast as ca
class _Done(Exception):
pass
def _apply_perm(fn: ca.FuncDef, perm_id: int, perm: Perm, seed: int) -> None:
"""Find and apply a single late perm macro in the AST."""
# Currently we search for statement macros only.
wanted_pragma = f"_permuter ast_perm {perm_id}"
Loc = Tuple[List[Statement], int]
def try_handle_block(block: ca.Node, where: Optional[Loc]) -> None:
if not isinstance(block, ca.Compound) or not block.block_items:
return
pragma = block.block_items[0]
if not isinstance(pragma, ca.Pragma) or pragma.string != wanted_pragma:
return
args: List[Statement] = block.block_items[1:]
stmts = perm.eval_statement_ast(args, seed)
if where:
where[0][where[1] : where[1] + 1] = stmts
else:
block.block_items = stmts
raise _Done
def rec(block: Block) -> None:
# if (x) { _Pragma(...); inputs } -> if (x) { outputs }
try_handle_block(block, None)
stmts = ast_util.get_block_stmts(block, False)
for i, stmt in enumerate(stmts):
# { ... { _Pragma(...); inputs } ... } -> { ... outputs ... }
try_handle_block(stmt, (stmts, i))
ast_util.for_nested_blocks(stmt, rec)
try:
rec(fn.body)
raise CandidateConstructionFailure("Failed to find PERM macro in AST.")
except _Done:
pass
def apply_ast_perms(fn: ca.FuncDef, eval_state: EvalState) -> None:
"""Find all late perm macros in the AST and apply them."""
# Nested perms will have smaller IDs, so apply the perms from lowest ID to
# highest to ensure that all arguments to perms have already been evaluated.
for perm_id, (perm, seed) in enumerate(eval_state.ast_perms):
_apply_perm(fn, perm_id, perm, seed)
-36
View File
@@ -1,36 +0,0 @@
import random
from typing import List, Iterable, Set, Tuple
from .perm import Perm, EvalState
def _gen_all_seeds(total_count: int) -> Iterable[int]:
"""Generate all numbers 0..total_count-1 in random order, in expected time
O(1) per number."""
seen: Set[int] = set()
while len(seen) < total_count // 2:
seed = random.randrange(total_count)
if seed not in seen:
seen.add(seed)
yield seed
remaining: List[int] = []
for seed in range(total_count):
if seed not in seen:
remaining.append(seed)
random.shuffle(remaining)
for seed in remaining:
yield seed
def perm_gen_all_seeds(perm: Perm) -> Iterable[int]:
while True:
for seed in _gen_all_seeds(perm.perm_count):
yield seed
if not perm.is_random():
break
def perm_evaluate_one(perm: Perm) -> Tuple[str, EvalState]:
eval_state = EvalState()
return perm.evaluate(0, eval_state), eval_state
-144
View File
@@ -1,144 +0,0 @@
from typing import Callable, Dict, List, Tuple
import re
from .perm import (
CombinePerm,
GeneralPerm,
IgnorePerm,
IntPerm,
LineSwapPerm,
LineSwapAstPerm,
OncePerm,
Perm,
PretendPerm,
RandomizerPerm,
RootPerm,
TextPerm,
VarPerm,
)
def _split_by_comma(text: str) -> List[str]:
level = 0
current = ""
args: List[str] = []
for c in text:
if c == "," and level == 0:
args.append(current)
current = ""
else:
if c == "(":
level += 1
elif c == ")":
level -= 1
assert level >= 0, "Bad nesting"
current += c
assert level == 0, "Mismatched parentheses"
args.append(current)
return args
def _split_args(text: str) -> List[Perm]:
perm_args = [_rec_perm_parse(arg) for arg in _split_by_comma(text)]
return perm_args
def _split_args_newline(text: str) -> List[Perm]:
return [_rec_perm_parse(line) for line in text.split("\n") if line.strip()]
def _split_args_text(text: str) -> List[str]:
perm_list = _split_args(text)
res: List[str] = []
for perm in perm_list:
assert isinstance(perm, TextPerm)
res.append(perm.text)
return res
def _make_once_perm(text: str) -> OncePerm:
args = _split_by_comma(text)
if len(args) not in [1, 2]:
raise Exception("PERM_ONCE takes 1 or 2 arguments")
key = args[0].strip()
value = _rec_perm_parse(args[-1])
return OncePerm(key, value)
def _make_var_perm(text: str) -> VarPerm:
args = _split_by_comma(text)
if len(args) not in [1, 2]:
raise Exception("PERM_VAR takes 1 or 2 arguments")
var_name = _rec_perm_parse(args[0])
value = _rec_perm_parse(args[1]) if len(args) == 2 else None
return VarPerm(var_name, value)
PERM_FACTORIES: Dict[str, Callable[[str], Perm]] = {
"PERM_GENERAL": lambda text: GeneralPerm(_split_args(text)),
"PERM_ONCE": lambda text: _make_once_perm(text),
"PERM_RANDOMIZE": lambda text: RandomizerPerm(_rec_perm_parse(text)),
"PERM_VAR": lambda text: _make_var_perm(text),
"PERM_LINESWAP_TEXT": lambda text: LineSwapPerm(_split_args_newline(text)),
"PERM_LINESWAP": lambda text: LineSwapAstPerm(_split_args_newline(text)),
"PERM_INT": lambda text: IntPerm(*map(int, _split_args_text(text))),
"PERM_IGNORE": lambda text: IgnorePerm(_rec_perm_parse(text)),
"PERM_PRETEND": lambda text: PretendPerm(_rec_perm_parse(text)),
}
def _consume_arg_parens(text: str) -> Tuple[str, str]:
level = 0
for i, c in enumerate(text):
if c == "(":
level += 1
elif c == ")":
level -= 1
if level == -1:
return text[:i], text[i + 1 :]
raise Exception("Failed to find closing parenthesis when parsing PERM macro")
def _rec_perm_parse(text: str) -> Perm:
remain = text
macro_search = r"(PERM_.+?)\("
perms: List[Perm] = []
while len(remain) > 0:
match = re.search(macro_search, remain)
# No match found; return remaining
if match is None:
text_perm = TextPerm(remain)
perms.append(text_perm)
break
# Get perm type and args
perm_type = match.group(1)
if perm_type not in PERM_FACTORIES:
raise Exception("Unrecognized PERM macro: " + perm_type)
between = remain[: match.start()]
args, remain = _consume_arg_parens(remain[match.end() :])
# Create text perm
perms.append(TextPerm(between))
# Create new perm
perms.append(PERM_FACTORIES[perm_type](args))
if len(perms) == 1:
return perms[0]
return CombinePerm(perms)
def perm_parse(text: str) -> Perm:
ret = _rec_perm_parse(text)
if isinstance(ret, TextPerm):
ret = RandomizerPerm(ret)
print("No perm macros found. Defaulting to randomization.")
ret = RootPerm(ret)
if not ret.is_random():
print(f"Will run for {ret.perm_count} iterations.")
else:
print(f"Will try {ret.perm_count} different base sources.")
return ret
-289
View File
@@ -1,289 +0,0 @@
from base64 import b64encode
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, TypeVar, Optional
import math
from pycparser import c_ast as ca
from ..ast_util import Statement
T = TypeVar("T")
@dataclass
class PreprocessState:
once_options: Dict[str, List["Perm"]] = field(
default_factory=lambda: defaultdict(list)
)
@dataclass
class EvalState:
vars: Dict[str, str] = field(default_factory=dict)
once_choices: Dict[str, "Perm"] = field(default_factory=dict)
ast_perms: List[Tuple["Perm", int]] = field(default_factory=list)
def register_ast_perm(self, perm: "Perm", seed: int) -> int:
ret = len(self.ast_perms)
self.ast_perms.append((perm, seed))
return ret
def gen_ast_statement_perm(
self, perm: "Perm", seed: int, *, statements: List[str]
) -> str:
perm_id = self.register_ast_perm(perm, seed)
lines = [
"{",
f"#pragma _permuter ast_perm {perm_id}",
*["{" + stmt + "}" for stmt in statements],
"}",
]
return "\n".join(lines)
class Perm:
"""A Perm subclass generates different variations of a part of the source
code. Its evaluate method will be called with a seed between 0 and
perm_count-1, and it should return a unique string for each.
A Perm is allowed to return different strings for the same seed, but if so,
if should override is_random to return True. This will cause permutation
to happen in an infinite loop, rather than stop after the last permutation
has been tested."""
perm_count: int
children: List["Perm"]
def evaluate(self, seed: int, state: EvalState) -> str:
return ""
def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]:
raise NotImplementedError
def preprocess(self, state: PreprocessState) -> None:
for p in self.children:
p.preprocess(state)
def is_random(self) -> bool:
return any(p.is_random() for p in self.children)
def _eval_all(seed: int, perms: List[Perm], state: EvalState) -> List[str]:
ret = []
for p in perms:
seed, sub_seed = divmod(seed, p.perm_count)
ret.append(p.evaluate(sub_seed, state))
assert seed == 0, "seed must be in [0, prod(counts))"
return ret
def _count_all(perms: List[Perm]) -> int:
res = 1
for p in perms:
res *= p.perm_count
return res
def _eval_either(seed: int, perms: List[Perm], state: EvalState) -> str:
for p in perms:
if seed < p.perm_count:
return p.evaluate(seed, state)
seed -= p.perm_count
assert False, "seed must be in [0, sum(counts))"
def _count_either(perms: List[Perm]) -> int:
return sum(p.perm_count for p in perms)
def _shuffle(items: List[T], seed: int) -> List[T]:
items = items[:]
output = []
while items:
ind = seed % len(items)
seed //= len(items)
output.append(items[ind])
del items[ind]
return output
class RootPerm(Perm):
def __init__(self, inner: Perm) -> None:
self.children = [inner]
self.perm_count = inner.perm_count
self.preprocess_state = PreprocessState()
self.preprocess(self.preprocess_state)
for key, options in self.preprocess_state.once_options.items():
if len(options) == 1:
raise Exception(f"PERM_ONCE({key}) occurs only once, possible error?")
self.perm_count *= len(options)
def evaluate(self, seed: int, state: EvalState) -> str:
for key, options in self.preprocess_state.once_options.items():
seed, choice = divmod(seed, len(options))
state.once_choices[key] = options[choice]
return self.children[0].evaluate(seed, state)
class TextPerm(Perm):
def __init__(self, text: str) -> None:
# Comma escape sequence
text = text.replace("(,)", ",")
self.text = text
self.children = []
self.perm_count = 1
def evaluate(self, seed: int, state: EvalState) -> str:
return self.text
class IgnorePerm(Perm):
def __init__(self, inner: Perm) -> None:
self.children = [inner]
self.perm_count = inner.perm_count
def evaluate(self, seed: int, state: EvalState) -> str:
text = self.children[0].evaluate(seed, state)
if not text:
return ""
encoded = b64encode(text.encode("utf-8")).decode("ascii")
return "#pragma _permuter b64literal " + encoded
class PretendPerm(Perm):
def __init__(self, inner: Perm) -> None:
self.children = [inner]
self.perm_count = inner.perm_count
def evaluate(self, seed: int, state: EvalState) -> str:
text = self.children[0].evaluate(seed, state)
return "\n".join(
[
"",
"#pragma _permuter latedefine start",
text,
"#pragma _permuter latedefine end",
"",
]
)
class CombinePerm(Perm):
def __init__(self, parts: List[Perm]) -> None:
self.children = parts
self.perm_count = _count_all(parts)
def evaluate(self, seed: int, state: EvalState) -> str:
texts = _eval_all(seed, self.children, state)
return "".join(texts)
class RandomizerPerm(Perm):
def __init__(self, inner: Perm) -> None:
self.children = [inner]
self.perm_count = inner.perm_count
def evaluate(self, seed: int, state: EvalState) -> str:
text = self.children[0].evaluate(seed, state)
return "\n".join(
[
"",
"#pragma _permuter randomizer start",
text,
"#pragma _permuter randomizer end",
"",
]
)
def is_random(self) -> bool:
return True
class GeneralPerm(Perm):
def __init__(self, candidates: List[Perm]) -> None:
self.perm_count = _count_either(candidates)
self.children = candidates
def evaluate(self, seed: int, state: EvalState) -> str:
return _eval_either(seed, self.children, state)
class OncePerm(Perm):
def __init__(self, key: str, inner: Perm) -> None:
self.key = key
self.children = [inner]
self.perm_count = inner.perm_count
def preprocess(self, state: PreprocessState) -> None:
state.once_options[self.key].append(self)
super().preprocess(state)
def evaluate(self, seed: int, state: EvalState) -> str:
if state.once_choices[self.key] is self:
return self.children[0].evaluate(seed, state)
return ""
class VarPerm(Perm):
def __init__(self, var_name: Perm, expansion: Optional[Perm]) -> None:
if expansion:
self.children = [var_name, expansion]
else:
self.children = [var_name]
self.perm_count = _count_all(self.children)
def evaluate(self, seed: int, state: EvalState) -> str:
var_name_perm = self.children[0]
seed, sub_seed = divmod(seed, var_name_perm.perm_count)
var_name = var_name_perm.evaluate(sub_seed, state).strip()
if len(self.children) > 1:
ret = self.children[1].evaluate(seed, state)
state.vars[var_name] = ret
return ""
else:
if var_name not in state.vars:
raise Exception(f"Tried to read undefined PERM_VAR {var_name}")
return state.vars[var_name]
class LineSwapPerm(Perm):
def __init__(self, lines: List[Perm]) -> None:
self.children = lines
self.own_count = math.factorial(len(lines))
self.perm_count = self.own_count * _count_all(self.children)
def evaluate(self, seed: int, state: EvalState) -> str:
sub_seed, variation = divmod(seed, self.own_count)
texts = _eval_all(sub_seed, self.children, state)
return "\n".join(_shuffle(texts, variation))
class LineSwapAstPerm(Perm):
def __init__(self, lines: List[Perm]) -> None:
self.children = lines
self.own_count = math.factorial(len(lines))
self.perm_count = self.own_count * _count_all(self.children)
def evaluate(self, seed: int, state: EvalState) -> str:
sub_seed, variation = divmod(seed, self.own_count)
texts = _eval_all(sub_seed, self.children, state)
return state.gen_ast_statement_perm(self, variation, statements=texts)
def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]:
ret = []
for item in _shuffle(args, seed):
assert isinstance(item, ca.Compound)
ret.extend(item.block_items or [])
return ret
class IntPerm(Perm):
def __init__(self, low: int, high: int) -> None:
assert low <= high
self.low = low
self.children = []
self.perm_count = high - low + 1
def evaluate(self, seed: int, state: EvalState) -> str:
return str(self.low + seed)
-279
View File
@@ -1,279 +0,0 @@
from dataclasses import dataclass
import difflib
import itertools
import random
import re
import time
import traceback
from typing import (
Any,
List,
Iterator,
Optional,
Tuple,
Union,
)
from .candidate import Candidate, CandidateResult
from .compiler import Compiler
from .error import CandidateConstructionFailure
from .perm.perm import EvalState
from .perm.eval import perm_evaluate_one, perm_gen_all_seeds
from .perm.parse import perm_parse
from .profiler import Profiler
from .scorer import Scorer
@dataclass
class EvalError:
exc_str: Optional[str]
seed: Optional[Tuple[int, int]]
EvalResult = Union[CandidateResult, EvalError]
@dataclass
class Finished:
reason: Optional[str] = None
@dataclass
class Message:
text: str
class NeedMoreWork:
pass
class _CompileFailure(Exception):
pass
@dataclass
class WorkDone:
perm_index: int
result: EvalResult
Task = Union[Finished, Tuple[int, int]]
FeedbackItem = Union[Finished, Message, NeedMoreWork, WorkDone]
Feedback = Tuple[FeedbackItem, int, Optional[str]]
class Permuter:
"""
Represents a single source from which permutation candidates can be generated,
and which keeps track of good scores achieved so far.
"""
def __init__(
self,
dir: str,
fn_name: Optional[str],
compiler: Compiler,
scorer: Scorer,
source_file: str,
source: str,
*,
force_seed: Optional[int],
force_rng_seed: Optional[int],
keep_prob: float,
need_profiler: bool,
need_all_sources: bool,
show_errors: bool,
best_only: bool,
better_only: bool,
) -> None:
self.dir = dir
self.compiler = compiler
self.scorer = scorer
self.source_file = source_file
self.source = source
if fn_name is None:
# Semi-legacy codepath; all functions imported through import.py have a
# function name. This would ideally be done on AST level instead of on the
# pre-macro'ed source code, but we don't care enough to make that
# refactoring.
fns = _find_fns(source)
if len(fns) == 0:
raise Exception(f"{self.source_file} does not contain any function!")
if len(fns) > 1:
raise Exception(
f"{self.source_file} must contain only one function, "
"or have a function.txt next to it with a function name."
)
self.fn_name = fns[0]
else:
self.fn_name = fn_name
self.unique_name = self.fn_name
self._permutations = perm_parse(source)
self._force_seed = force_seed
self._force_rng_seed = force_rng_seed
self._cur_seed: Optional[Tuple[int, int]] = None
self.keep_prob = keep_prob
self.need_profiler = need_profiler
self._need_all_sources = need_all_sources
self._show_errors = show_errors
self._best_only = best_only
self._better_only = better_only
(
self.base_score,
self.base_hash,
self.base_source,
) = self._create_and_score_base()
self.best_score = self.base_score
self.hashes = {self.base_hash}
self._cur_cand: Optional[Candidate] = None
self._last_score: Optional[int] = None
def _create_and_score_base(self) -> Tuple[int, str, str]:
base_source, eval_state = perm_evaluate_one(self._permutations)
base_cand = Candidate.from_source(
base_source, eval_state, self.fn_name, rng_seed=0
)
o_file = base_cand.compile(self.compiler, show_errors=True)
if not o_file:
raise CandidateConstructionFailure(f"Unable to compile {self.source_file}")
base_result = base_cand.score(self.scorer, o_file)
assert base_result.hash is not None
return base_result.score, base_result.hash, base_cand.get_source()
def _need_to_send_source(self, result: CandidateResult) -> bool:
return self._need_all_sources or self.should_output(result)
def _eval_candidate(self, seed: int) -> CandidateResult:
t0 = time.time()
# Determine if we should keep the last candidate.
# Don't keep 0-score candidates; we'll only create new, worse, zeroes.
keep = (
self._permutations.is_random()
and random.uniform(0, 1) < self.keep_prob
and self._last_score != 0
and self._last_score != self.scorer.PENALTY_INF
) or self._force_rng_seed
self._last_score = None
# Create a new candidate if we didn't keep the last one (or if the last one didn't exist)
# N.B. if we decide to keep the previous candidate, we will skip over the provided seed.
# This means we're not guaranteed to test all seeds, but it doesn't really matter since
# we're randomizing anyway.
if not self._cur_cand or not keep:
eval_state = EvalState()
cand_c = self._permutations.evaluate(seed, eval_state)
rng_seed = self._force_rng_seed or random.randrange(1, 10 ** 20)
self._cur_seed = (seed, rng_seed)
self._cur_cand = Candidate.from_source(
cand_c, eval_state, self.fn_name, rng_seed=rng_seed
)
# Randomize the candidate
if self._permutations.is_random():
self._cur_cand.randomize_ast()
t1 = time.time()
self._cur_cand.get_source()
t2 = time.time()
o_file = self._cur_cand.compile(self.compiler)
if not o_file and self._show_errors:
raise _CompileFailure()
t3 = time.time()
result = self._cur_cand.score(self.scorer, o_file)
t4 = time.time()
if self.need_profiler:
profiler = Profiler()
profiler.add_stat(Profiler.StatType.perm, t1 - t0)
profiler.add_stat(Profiler.StatType.stringify, t2 - t1)
profiler.add_stat(Profiler.StatType.compile, t3 - t2)
profiler.add_stat(Profiler.StatType.score, t4 - t3)
result.profiler = profiler
self._last_score = result.score
if not self._need_to_send_source(result):
result.source = None
result.hash = None
return result
def should_output(self, result: CandidateResult) -> bool:
"""Check whether a result should be outputted. This must be more liberal
in child processes than in parent ones, or else sources will be missing."""
return (
result.score <= self.base_score
and result.hash is not None
and result.source is not None
and not (result.score > self.best_score and self._best_only)
and (
result.score < self.base_score
or (result.score == self.base_score and not self._better_only)
)
and result.hash not in self.hashes
)
def record_result(self, result: CandidateResult) -> None:
"""Record a new result, updating the best score and adding the hash to
the set of hashes we have already seen. No hash is recorded for score
0, since we are interested in all score 0's, not just the first."""
self.best_score = min(self.best_score, result.score)
if result.score != 0 and result.hash is not None:
self.hashes.add(result.hash)
def seed_iterator(self) -> Iterator[int]:
"""Create an iterator over all seeds for this permuter. The iterator
will be infinite if we are randomizing."""
if self._force_seed is None:
return iter(perm_gen_all_seeds(self._permutations))
if self._permutations.is_random():
return itertools.repeat(self._force_seed)
return iter([self._force_seed])
def try_eval_candidate(self, seed: int) -> EvalResult:
"""Evaluate a seed for the permuter."""
try:
return self._eval_candidate(seed)
except _CompileFailure:
return EvalError(exc_str=None, seed=self._cur_seed)
except Exception:
return EvalError(exc_str=traceback.format_exc(), seed=self._cur_seed)
def diff(self, other_source: str) -> str:
"""Compute a unified white-space-ignoring diff from the (pretty-printed)
base source against another source generated from this permuter."""
class Line(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str) and self.strip() == other.strip()
def __hash__(self) -> int:
return hash(self.strip())
a = list(map(Line, self.base_source.split("\n")))
b = list(map(Line, other_source.split("\n")))
return "\n".join(
difflib.unified_diff(a, b, fromfile="before", tofile="after", lineterm="")
)
def _find_fns(source: str) -> List[str]:
fns = re.findall(r"(\w+)\([^()\n]*\)\s*?{", source)
return [
fn
for fn in fns
if not fn.startswith("PERM") and fn not in ["if", "for", "switch", "while"]
]
-10
View File
@@ -1,10 +0,0 @@
from typing import List
import subprocess
def preprocess(filename: str, cpp_args: List[str] = []) -> str:
return subprocess.check_output(
["cpp"] + cpp_args + ["-P", "-nostdinc", filename],
universal_newlines=True,
encoding="utf-8",
)
-46
View File
@@ -1,46 +0,0 @@
from typing import Optional
from .permuter import Permuter
# Number of additional characters to replace with spaces, in addition to what
# is required based on the length of the previous progress line. This could be
# set to 0, but it's nice to have some margin to deal with e.g. zero-width
# control characters.
SAFETY_PAD = 10
class Printer:
_last_progress: Optional[str] = None
def progress(self, message: str) -> None:
if self._last_progress is None:
clear = ""
else:
pad = max(len(self._last_progress) - len(message) + SAFETY_PAD, 0)
clear = "\b" * pad + " " * pad + "\r"
print(clear + message, end="", flush=True)
self._last_progress = message
def print(
self,
message: str,
permuter: Optional[Permuter],
who: Optional[str],
*,
color: str = "",
keep_progress: bool = False,
) -> None:
if self._last_progress is not None:
if keep_progress:
print()
else:
pad = len(self._last_progress) + SAFETY_PAD
print("\r" + " " * pad + "\r", end="")
if permuter is not None:
message = f"[{permuter.unique_name}] {message}"
if who is not None:
message = f"[{who}] {message}"
if color:
message = f"{color}{message}\u001b[0m"
print(message)
self._last_progress = None
-23
View File
@@ -1,23 +0,0 @@
from enum import Enum
class Profiler:
class StatType(Enum):
perm = 1
stringify = 2
compile = 3
score = 4
def __init__(self) -> None:
self.time_stats = {x: 0.0 for x in Profiler.StatType}
def add_stat(self, stat: StatType, time_taken: float) -> None:
self.time_stats[stat] += time_taken
def get_str_stats(self) -> str:
total_time = sum(self.time_stats[e] for e in self.time_stats)
timings = ", ".join(
f"{round(100 * self.time_stats[e] / total_time)}% {e.name}"
for e in self.time_stats
)
return timings
File diff suppressed because it is too large Load Diff
-145
View File
@@ -1,145 +0,0 @@
from dataclasses import dataclass, field
import difflib
import hashlib
import re
from typing import Tuple, List, Optional
from collections import Counter
from .objdump import objdump, get_arch
@dataclass(init=False, unsafe_hash=True)
class DiffAsmLine:
line: str = field(compare=False)
mnemonic: str
def __init__(self, line: str) -> None:
self.line = line
self.mnemonic = line.split("\t")[0]
class Scorer:
PENALTY_INF = 10 ** 9
PENALTY_STACKDIFF = 1
PENALTY_REGALLOC = 5
PENALTY_REORDERING = 60
PENALTY_INSERTION = 100
PENALTY_DELETION = 100
def __init__(self, target_o: str, *, stack_differences: bool):
self.target_o = target_o
self.arch = get_arch(target_o)
self.stack_differences = stack_differences
_, self.target_seq = self._objdump(target_o)
self.differ: difflib.SequenceMatcher[DiffAsmLine] = difflib.SequenceMatcher(
autojunk=False
)
self.differ.set_seq2(self.target_seq)
def _objdump(self, o_file: str) -> Tuple[str, List[DiffAsmLine]]:
ret = []
lines = objdump(o_file, self.arch, stack_differences=self.stack_differences)
for line in lines:
ret.append(DiffAsmLine(line))
return "\n".join(lines), ret
def score(self, cand_o: Optional[str]) -> Tuple[int, str]:
if not cand_o:
return Scorer.PENALTY_INF, ""
objdump_output, cand_seq = self._objdump(cand_o)
score = 0
deletions = []
insertions = []
def lo_hi_match(old: str, new: str) -> bool:
old_lo = old.find("%lo")
old_hi = old.find("%hi")
new_lo = new.find("%lo")
new_hi = new.find("%hi")
if old_lo != -1 and new_lo != -1:
old_idx = old_lo
new_idx = new_lo
elif old_hi != -1 and new_hi != -1:
old_idx = old_hi
new_idx = new_hi
else:
return False
if old[:old_idx] != new[:new_idx]:
return False
old_inner = old[old_idx + 4 : -1]
new_inner = new[new_idx + 4 : -1]
return old_inner.startswith(".") or new_inner.startswith(".")
def diff_sameline(old: str, new: str) -> None:
nonlocal score
if old == new:
return
if lo_hi_match(old, new):
return
ignore_last_field = False
if self.stack_differences:
oldsp = re.search(self.arch.re_sprel, old)
newsp = re.search(self.arch.re_sprel, new)
if oldsp and newsp:
oldrel = int(oldsp.group(1) or "0", 0)
newrel = int(newsp.group(1) or "0", 0)
score += abs(oldrel - newrel) * self.PENALTY_STACKDIFF
ignore_last_field = True
# Probably regalloc difference, or signed vs unsigned
# Compare each field in order
newfields, oldfields = new.split(","), old.split(",")
if ignore_last_field:
newfields = newfields[:-1]
oldfields = oldfields[:-1]
for nf, of in zip(newfields, oldfields):
if nf != of:
score += self.PENALTY_REGALLOC
# Penalize any extra fields
score += abs(len(newfields) - len(oldfields)) * self.PENALTY_REGALLOC
def diff_insert(line: str) -> None:
# Reordering or totally different codegen.
# Defer this until later when we can tell.
insertions.append(line)
def diff_delete(line: str) -> None:
deletions.append(line)
self.differ.set_seq1(cand_seq)
for (tag, i1, i2, j1, j2) in self.differ.get_opcodes():
if tag == "equal":
for k in range(i2 - i1):
old = self.target_seq[j1 + k].line
new = cand_seq[i1 + k].line
diff_sameline(old, new)
if tag == "replace" or tag == "delete":
for k in range(i1, i2):
diff_insert(cand_seq[k].line)
if tag == "replace" or tag == "insert":
for k in range(j1, j2):
diff_delete(self.target_seq[k].line)
insertions_co = Counter(insertions)
deletions_co = Counter(deletions)
for item in insertions_co + deletions_co:
ins = insertions_co[item]
dels = deletions_co[item]
common = min(ins, dels)
score += (
(ins - common) * self.PENALTY_INSERTION
+ (dels - common) * self.PENALTY_DELETION
+ self.PENALTY_REORDERING * common
)
return (score, hashlib.sha256(objdump_output.encode()).hexdigest())
-74
View File
@@ -1,74 +0,0 @@
import re
import argparse
from typing import Optional
from pathlib import Path
def _find_bracket_end(input: str, start_index: int) -> int:
level = 1
assert input[start_index] == "{"
i = start_index + 1
while i < len(input):
if input[i] == "{":
level += 1
elif input[i] == "}":
level -= 1
if level == 0:
break
i += 1
assert level == 0, "unbalanced {}"
return i
def strip_other_fns(source: str, keep_fn_name: str) -> str:
result = ""
remain = source
while True:
fn_regex = re.compile(r"^.*\s+\**(\w+)\(.*\)\s*?{", re.M)
fn = re.search(fn_regex, remain)
if fn is None:
result += remain
remain = ""
break
fn_name = fn.group(1)
bracket_end = _find_bracket_end(remain, fn.end() - 1)
if fn_name.startswith("PERM"):
result += remain[: bracket_end + 1]
elif fn_name == keep_fn_name:
result += "\n\n" + remain[: bracket_end + 1] + "\n\n"
else:
result += remain[: fn.end() - 1].rstrip() + ";"
remain = remain[bracket_end + 1 :]
return result
def strip_other_fns_and_write(
source: str, fn_name: str, out_filename: Optional[str] = None
) -> None:
stripped = strip_other_fns(source, fn_name)
if out_filename is None:
print(stripped)
else:
with open(out_filename, "w", encoding="utf-8") as f:
f.write(stripped)
def main() -> None:
parser = argparse.ArgumentParser(
description="Remove all but a single function definition from a file."
)
parser.add_argument("c_file", help="File containing the function.")
parser.add_argument("fn_name", help="Function name.")
args = parser.parse_args()
source = Path(args.c_file).read_text()
strip_other_fns_and_write(source, args.fn_name, args.c_file)
if __name__ == "__main__":
main()
@@ -1,18 +0,0 @@
#-----------------------------------------------------------------
# pycparser: __init__.py
#
# This package file exports some convenience functions for
# interacting with pycparser
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#-----------------------------------------------------------------
__all__ = ['c_parser', 'c_ast']
__version__ = '2.19'
from typing import Any, List, Union
from . import c_ast
from .c_parser import CParser
def preprocess_file(filename: str, cpp_path: str='cpp', cpp_args: Union[List[str], str]='') -> str: ...
def parse_file(filename: str, use_cpp: bool=False, cpp_path: str='cpp', cpp_args: str='', parser: Any=None) -> c_ast.FileAST: ...
@@ -1,719 +0,0 @@
# -----------------------------------------------------------------
# pycparser: c_ast.py
#
# AST Node classes.
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
# -----------------------------------------------------------------
from typing import TextIO, Iterable, List, Any, Optional, Union as Union_
from .plyparser import Coord
import sys
class Node(object):
coord: Optional[Coord]
def __repr__(self) -> str:
...
def __iter__(self) -> Iterable[Node]:
...
def children(self) -> Iterable[Node]:
...
def show(
self,
buf: TextIO = sys.stdout,
offset: int = 0,
attrnames: bool = False,
nodenames: bool = False,
showcoord: bool = False,
) -> None:
...
Expression = Union_[
"ArrayRef",
"Assignment",
"BinaryOp",
"Cast",
"CompoundLiteral",
"Constant",
"ExprList",
"FuncCall",
"ID",
"TernaryOp",
"UnaryOp",
]
Statement = Union_[
Expression,
"Break",
"Case",
"Compound",
"Continue",
"Decl",
"Default",
"DoWhile",
"EmptyStatement",
"For",
"Goto",
"If",
"Label",
"Return",
"Switch",
"Typedef",
"While",
"Pragma",
]
Type = Union_["PtrDecl", "ArrayDecl", "FuncDecl", "TypeDecl"]
InnerType = Union_["IdentifierType", "Struct", "Union", "Enum"]
ExternalDeclaration = Union_["FuncDef", "Decl", "Typedef", "Pragma"]
AnyNode = Union_[
Statement,
Type,
InnerType,
"FuncDef",
"EllipsisParam",
"Enumerator",
"EnumeratorList",
"FileAST",
"InitList",
"NamedInitializer",
"ParamList",
"Typename",
]
class NodeVisitor:
def visit(self, node: Node) -> None:
...
def generic_visit(self, node: Node) -> None:
...
def visit_ArrayDecl(self, node: ArrayDecl) -> None:
...
def visit_ArrayRef(self, node: ArrayRef) -> None:
...
def visit_Assignment(self, node: Assignment) -> None:
...
def visit_BinaryOp(self, node: BinaryOp) -> None:
...
def visit_Break(self, node: Break) -> None:
...
def visit_Case(self, node: Case) -> None:
...
def visit_Cast(self, node: Cast) -> None:
...
def visit_Compound(self, node: Compound) -> None:
...
def visit_CompoundLiteral(self, node: CompoundLiteral) -> None:
...
def visit_Constant(self, node: Constant) -> None:
...
def visit_Continue(self, node: Continue) -> None:
...
def visit_Decl(self, node: Decl) -> None:
...
def visit_DeclList(self, node: DeclList) -> None:
...
def visit_Default(self, node: Default) -> None:
...
def visit_DoWhile(self, node: DoWhile) -> None:
...
def visit_EllipsisParam(self, node: EllipsisParam) -> None:
...
def visit_EmptyStatement(self, node: EmptyStatement) -> None:
...
def visit_Enum(self, node: Enum) -> None:
...
def visit_Enumerator(self, node: Enumerator) -> None:
...
def visit_EnumeratorList(self, node: EnumeratorList) -> None:
...
def visit_ExprList(self, node: ExprList) -> None:
...
def visit_FileAST(self, node: FileAST) -> None:
...
def visit_For(self, node: For) -> None:
...
def visit_FuncCall(self, node: FuncCall) -> None:
...
def visit_FuncDecl(self, node: FuncDecl) -> None:
...
def visit_FuncDef(self, node: FuncDef) -> None:
...
def visit_Goto(self, node: Goto) -> None:
...
def visit_ID(self, node: ID) -> None:
...
def visit_IdentifierType(self, node: IdentifierType) -> None:
...
def visit_If(self, node: If) -> None:
...
def visit_InitList(self, node: InitList) -> None:
...
def visit_Label(self, node: Label) -> None:
...
def visit_NamedInitializer(self, node: NamedInitializer) -> None:
...
def visit_ParamList(self, node: ParamList) -> None:
...
def visit_PtrDecl(self, node: PtrDecl) -> None:
...
def visit_Return(self, node: Return) -> None:
...
def visit_Struct(self, node: Struct) -> None:
...
def visit_StructRef(self, node: StructRef) -> None:
...
def visit_Switch(self, node: Switch) -> None:
...
def visit_TernaryOp(self, node: TernaryOp) -> None:
...
def visit_TypeDecl(self, node: TypeDecl) -> None:
...
def visit_Typedef(self, node: Typedef) -> None:
...
def visit_Typename(self, node: Typename) -> None:
...
def visit_UnaryOp(self, node: UnaryOp) -> None:
...
def visit_Union(self, node: Union) -> None:
...
def visit_While(self, node: While) -> None:
...
def visit_Pragma(self, node: Pragma) -> None:
...
class ArrayDecl(Node):
type: Type
dim: Optional[Expression]
dim_quals: List[str]
def __init__(
self,
type: Type,
dim: Optional[Node],
dim_quals: List[str],
coord: Optional[Coord] = None,
):
...
class ArrayRef(Node):
name: Expression
subscript: Expression
def __init__(self, name: Node, subscript: Node, coord: Optional[Coord] = None):
...
class Assignment(Node):
op: str
lvalue: Expression
rvalue: Expression
def __init__(
self,
op: str,
lvalue: Expression,
rvalue: Expression,
coord: Optional[Coord] = None,
):
...
class BinaryOp(Node):
op: str
left: Expression
right: Expression
def __init__(self, op: str, left: Node, right: Node, coord: Optional[Coord] = None):
...
class Break(Node):
def __init__(self, coord: Optional[Coord] = None):
...
class Case(Node):
expr: Expression
stmts: List[Statement]
def __init__(
self, expr: Expression, stmts: List[Statement], coord: Optional[Coord] = None
):
...
class Cast(Node):
to_type: "Typename"
expr: Expression
def __init__(
self, to_type: "Typename", expr: Expression, coord: Optional[Coord] = None
):
...
class Compound(Node):
block_items: Optional[List[Statement]]
def __init__(
self, block_items: Optional[List[Statement]], coord: Optional[Coord] = None
):
...
class CompoundLiteral(Node):
type: "Typename"
init: "InitList"
def __init__(
self, type: "Typename", init: "InitList", coord: Optional[Coord] = None
):
...
class Constant(Node):
type: str
value: str
def __init__(self, type: str, value: str, coord: Optional[Coord] = None):
...
class Continue(Node):
def __init__(self, coord: Optional[Coord] = None):
...
class Decl(Node):
name: Optional[str]
quals: List[str] # e.g. const
storage: List[str] # e.g. register
funcspec: List[str] # e.g. inline
type: Union_[Type, "Struct", "Union", "Enum"]
init: Optional[Union_[Expression, "InitList"]]
bitsize: Optional[Expression]
def __init__(
self,
name: Optional[str],
quals: List[str],
storage: List[str],
funcspec: List[str],
type: Union_[Type, "Struct", "Union", "Enum"],
init: Optional[Union_[Expression, "InitList"]],
bitsize: Optional[Expression],
coord: Optional[Coord] = None,
):
...
class DeclList(Node):
decls: List[Decl]
def __init__(self, decls: List[Decl], coord: Optional[Coord] = None):
...
class Default(Node):
stmts: List[Statement]
def __init__(self, stmts: List[Statement], coord: Optional[Coord] = None):
...
class DoWhile(Node):
cond: Expression
stmt: Statement
def __init__(
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
):
...
class EllipsisParam(Node):
def __init__(self, coord: Optional[Coord] = None):
...
class EmptyStatement(Node):
def __init__(self, coord: Optional[Coord] = None):
...
class Enum(Node):
name: Optional[str]
values: "Optional[EnumeratorList]"
def __init__(
self,
name: Optional[str],
values: "Optional[EnumeratorList]",
coord: Optional[Coord] = None,
):
...
class Enumerator(Node):
name: str
value: Optional[Expression]
def __init__(
self, name: str, value: Optional[Expression], coord: Optional[Coord] = None
):
...
class EnumeratorList(Node):
enumerators: List[Enumerator]
def __init__(self, enumerators: List[Enumerator], coord: Optional[Coord] = None):
...
class ExprList(Node):
exprs: List[Union_[Expression, Typename]] # typename only for offsetof
def __init__(
self, exprs: List[Union_[Expression, Typename]], coord: Optional[Coord] = None
):
...
class FileAST(Node):
ext: List[ExternalDeclaration]
def __init__(self, ext: List[ExternalDeclaration], coord: Optional[Coord] = None):
...
class For(Node):
init: Union_[None, Expression, DeclList]
cond: Optional[Expression]
next: Optional[Expression]
stmt: Statement
def __init__(
self,
init: Union_[None, Expression, DeclList],
cond: Optional[Expression],
next: Optional[Expression],
stmt: Statement,
coord: Optional[Coord] = None,
):
...
class FuncCall(Node):
name: Expression
args: Optional[ExprList]
def __init__(
self, name: Expression, args: Optional[ExprList], coord: Optional[Coord] = None
):
...
class FuncDecl(Node):
args: Optional[ParamList]
type: Type # return type
def __init__(
self, args: Optional[ParamList], type: Type, coord: Optional[Coord] = None
):
...
class FuncDef(Node):
decl: Decl
param_decls: Optional[List[Decl]]
body: Compound
def __init__(
self,
decl: Decl,
param_decls: Optional[List[Decl]],
body: Compound,
coord: Optional[Coord] = None,
):
...
class Goto(Node):
name: str
def __init__(self, name: str, coord: Optional[Coord] = None):
...
class ID(Node):
name: str
def __init__(self, name: str, coord: Optional[Coord] = None):
...
class IdentifierType(Node):
names: List[str] # e.g. ['long', 'int']
def __init__(self, names: List[str], coord: Optional[Coord] = None):
...
class If(Node):
cond: Expression
iftrue: Statement
iffalse: Optional[Statement]
def __init__(
self,
cond: Expression,
iftrue: Statement,
iffalse: Optional[Statement],
coord: Optional[Coord] = None,
):
...
class InitList(Node):
exprs: List[Union_[Expression, "NamedInitializer"]]
def __init__(
self,
exprs: List[Union_[Expression, "NamedInitializer"]],
coord: Optional[Coord] = None,
):
...
class Label(Node):
name: str
stmt: Statement
def __init__(self, name: str, stmt: Statement, coord: Optional[Coord] = None):
...
class NamedInitializer(Node):
name: List[Expression] # [ID(x), Constant(4)] for {.x[4] = ...}
expr: Expression
def __init__(
self, name: List[Expression], expr: Expression, coord: Optional[Coord] = None
):
...
class ParamList(Node):
params: List[Union_[Decl, ID, Typename, EllipsisParam]]
def __init__(
self,
params: List[Union_[Decl, ID, Typename, EllipsisParam]],
coord: Optional[Coord] = None,
):
...
class PtrDecl(Node):
quals: List[str]
type: Type
def __init__(self, quals: List[str], type: Type, coord: Optional[Coord] = None):
...
class Return(Node):
expr: Optional[Expression]
def __init__(self, expr: Optional[Expression], coord: Optional[Coord] = None):
...
class Struct(Node):
name: Optional[str]
decls: Optional[List[Union_[Decl, Pragma]]]
def __init__(
self,
name: Optional[str],
decls: Optional[List[Union_[Decl, Pragma]]],
coord: Optional[Coord] = None,
):
...
class StructRef(Node):
name: Expression
type: str
field: ID
def __init__(
self, name: Expression, type: str, field: ID, coord: Optional[Coord] = None
):
...
class Switch(Node):
cond: Expression
stmt: Statement
def __init__(
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
):
...
class TernaryOp(Node):
cond: Expression
iftrue: Expression
iffalse: Expression
def __init__(
self,
cond: Expression,
iftrue: Expression,
iffalse: Expression,
coord: Optional[Coord] = None,
):
...
class TypeDecl(Node):
declname: Optional[str]
quals: List[str]
type: InnerType
def __init__(
self,
declname: Optional[str],
quals: List[str],
type: InnerType,
coord: Optional[Coord] = None,
):
...
class Typedef(Node):
name: str
quals: List[str]
storage: List[str]
type: Type
def __init__(
self,
name: str,
quals: List[str],
storage: List[str],
type: Type,
coord: Optional[Coord] = None,
):
...
class Typename(Node):
name: None
quals: List[str]
type: Type
def __init__(
self, name: None, quals: List[str], type: Type, coord: Optional[Coord] = None
):
...
class UnaryOp(Node):
op: str
expr: Union_[Expression, Typename]
def __init__(
self, op: str, expr: Union_[Expression, Typename], coord: Optional[Coord] = None
):
...
class Union(Node):
name: Optional[str]
decls: Optional[List[Union_[Decl, Pragma]]]
def __init__(
self,
name: Optional[str],
decls: Optional[List[Union_[Decl, Pragma]]],
coord: Optional[Coord] = None,
):
...
class While(Node):
cond: Expression
stmt: Statement
def __init__(
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
):
...
class Pragma(Node):
string: str
def __init__(self, string: str, coord: Optional[Coord] = None):
...
@@ -1,13 +0,0 @@
#------------------------------------------------------------------------------
# pycparser: c_generator.py
#
# C code generator from pycparser AST nodes.
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#------------------------------------------------------------------------------
from . import c_ast
class CGenerator:
def __init__(self) -> None: ...
def visit(self, node: c_ast.Node) -> str: ...
@@ -1,15 +0,0 @@
#------------------------------------------------------------------------------
# pycparser: c_parser.py
#
# CParser class: Parser and AST builder for the C language
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
#------------------------------------------------------------------------------
from . import c_ast
class CParser:
def __init__(self) -> None: ...
def parse(self, text: str, filename: str='', debuglevel: int=0) -> c_ast.FileAST: ...
@@ -1,27 +0,0 @@
# -----------------------------------------------------------------
# plyparser.py
#
# PLYParser class and other utilites for simplifying programming
# parsers with PLY
#
# Eli Bendersky [https://eli.thegreenplace.net/]
# License: BSD
# -----------------------------------------------------------------
from typing import Optional
class Coord:
file: str
line: int
column: Optional[int]
def __init__(self, file: str, line: int, column: Optional[int] = None):
...
def __str__(self) -> str:
...
class ParseError(Exception):
pass
-13
View File
@@ -1,13 +0,0 @@
#!/usr/bin/env python3
import sys
from pycparser import parse_file, c_generator
fname = "test.c" if len(sys.argv) < 2 else sys.argv[1]
# ast = c_parser.CParser().parse(src)
ast = parse_file(fname, use_cpp=True)
# ast.show()
# print(c_generator.CGenerator().visit(ast))
print(ast)
-2
View File
@@ -1,2 +0,0 @@
#!/bin/bash
mips-linux-gnu-gcc -O2 -fno-PIC -fno-common -ffreestanding -mno-shared -mno-abicalls -G 0 -c "$@"
-247
View File
@@ -1,247 +0,0 @@
import os
from pathlib import Path
import re
import shutil
import tempfile
from typing import Any, Optional
import unittest
from src.compiler import Compiler
from src.preprocess import preprocess
from src import main
class TestPermMacros(unittest.TestCase):
def go(
self,
intro: str,
outro: str,
base: str,
target: str,
*,
fn_name: Optional[str] = None,
**kwargs: Any
) -> int:
base = intro + "\n" + base + "\n" + outro
target = intro + "\n" + target + "\n" + outro
compiler = Compiler("test/compile.sh", show_errors=True)
# For debugging, to avoid the auto-deleted directory:
# target_dir = tempfile.mkdtemp()
with tempfile.TemporaryDirectory() as target_dir:
with open(os.path.join(target_dir, "base.c"), "w") as f:
f.write(base)
target_o = compiler.compile(target, show_errors=True)
assert target_o is not None
shutil.move(target_o, os.path.join(target_dir, "target.o"))
shutil.copy2("test/compile.sh", os.path.join(target_dir, "compile.sh"))
if fn_name:
with open(os.path.join(target_dir, "function.txt"), "w") as f:
f.write(fn_name)
opts = main.Options(directories=[target_dir], stop_on_zero=True, **kwargs)
return main.run(opts)[0]
def test_general(self) -> None:
score = self.go(
"int test() {",
"}",
"return PERM_GENERAL(32,64);",
"return 64;",
)
self.assertEqual(score, 0)
def test_not_found(self) -> None:
score = self.go(
"int test() {",
"}",
"return PERM_GENERAL(32,64);",
"return 92;",
)
self.assertNotEqual(score, 0)
def test_multiple_functions(self) -> None:
score = self.go(
"",
"",
"""
int ignoreme() {}
int foo() { return PERM_GENERAL(32,64); }
int ignoreme2() {}
""",
"int foo() { return 64; }",
fn_name="foo",
)
self.assertEqual(score, 0)
def test_general_multiple(self) -> None:
score = self.go(
"int test() {",
"}",
"return PERM_GENERAL(1,2,3) + PERM_GENERAL(3,6,9);",
"return 9;",
)
self.assertEqual(score, 0)
def test_general_nested(self) -> None:
score = self.go(
"int test() {",
"}",
"return PERM_GENERAL(1,PERM_GENERAL(100,101),3) + PERM_GENERAL(3,6,9);",
"return 110;",
)
self.assertEqual(score, 0)
def test_cast(self) -> None:
score = self.go(
"int test(int a, int b) {",
"}",
"return a / PERM_GENERAL(,(unsigned int),(float)) b;",
"return a / (float) b;",
)
self.assertEqual(score, 0)
def test_cast_threaded(self) -> None:
score = self.go(
"int test(int a, int b) {",
"}",
"return a / PERM_GENERAL(,(unsigned int),(float)) b;",
"return a / (float) b;",
threads=2,
)
self.assertEqual(score, 0)
def test_ignore(self) -> None:
score = self.go(
"int test(int a, int b) {",
"}",
"PERM_IGNORE( return a / PERM_GENERAL(a, b); )",
"return a / b;",
)
self.assertEqual(score, 0)
def test_pretend(self) -> None:
score = self.go(
"int global;",
"",
"""
PERM_IGNORE( inline void foo() { )
PERM_PRETEND( void foo(); void bar() { )
PERM_RANDOMIZE(
global = 1;
)
PERM_IGNORE( } void bar() { )
PERM_RANDOMIZE(
global = 2; foo();
)
}
""",
"""
inline void foo() { global = 1; }
void bar() { foo(); global = 2; }
""",
fn_name="bar",
)
self.assertEqual(score, 0)
def test_once1(self) -> None:
score = self.go(
"volatile int A, B, C; void test() {",
"}",
"""
PERM_ONCE(B = 2;)
A = 1;
PERM_ONCE(B = 2;)
C = 3;
PERM_ONCE(B = 2;)
""",
"A = 1; B = 2; C = 3;",
)
self.assertEqual(score, 0)
def test_once2(self) -> None:
score = self.go(
"volatile int A, B, C; void test() {",
"}",
"""
PERM_VAR(emit,)
PERM_VAR(bademit,)
PERM_ONCE(1, PERM_VAR(bademit, A = 7;) A = 2;)
PERM_ONCE(1, PERM_VAR(emit, A = 1;))
PERM_VAR(emit)
PERM_VAR(bademit)
PERM_ONCE(2, B = 2;)
PERM_ONCE(2, B = 1;)
PERM_ONCE(2,)
PERM_ONCE(3, PERM_VAR(bademit, A = 9))
PERM_ONCE(3, PERM_VAR(bademit, A = 9))
C = 3;
""",
"A = 1; B = 2; C = 3;",
)
self.assertEqual(score, 0)
def test_lineswap(self) -> None:
score = self.go(
"void a(); void b(); void c(); void test(void) {",
"}",
"""
PERM_LINESWAP(
a();
b();
c();
)
""",
"b(); a(); c();",
)
self.assertEqual(score, 0)
def test_lineswap_text(self) -> None:
score = self.go(
"void a(); void b(); void c(); void test(void) {",
"}",
"""
PERM_LINESWAP_TEXT(
a();
b();
c();
)
""",
"b(); a(); c();",
)
self.assertEqual(score, 0)
def test_randomizer(self) -> None:
score = self.go(
"void foo(); void bar(); void test(void) {",
"}",
"PERM_RANDOMIZE(bar(); foo();)",
"foo(); bar();",
)
self.assertEqual(score, 0)
def test_auto_randomizer(self) -> None:
score = self.go(
"void foo(); void bar(); void test(void) {",
"}",
"bar(); foo();",
"foo(); bar();",
)
self.assertEqual(score, 0)
def test_randomizer_threaded(self) -> None:
score = self.go(
"void foo(); void bar(); void test(void) {",
"}",
"PERM_RANDOMIZE(bar(); foo();)",
"foo(); bar();",
threads=2,
)
self.assertEqual(score, 0)
if __name__ == "__main__":
unittest.main()