mirror of
https://github.com/zeldaret/mm.git
synced 2026-05-23 06:54:14 -04:00
Remove decomp-permuter (#447)
* Nuke decomp permuter * Add decomp permuter and mips2c to gitignore
This commit is contained in:
@@ -33,6 +33,8 @@ tools/ido_recomp/* binary
|
||||
ctx.c
|
||||
graphs/
|
||||
*.c.m2c
|
||||
tools/decomp-permuter/
|
||||
tools/mips_to_c/
|
||||
|
||||
# Assets
|
||||
*.png
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
name: Systray
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'src/net/cmd/systray/*'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'src/net/cmd/systray/*'
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
name: Build on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-16.04
|
||||
binary: permuter-systray-linux
|
||||
- os: windows-latest
|
||||
binary: permuter-systray.exe
|
||||
- os: macos-latest
|
||||
binary: permuter-systray-macos
|
||||
steps:
|
||||
- uses: actions/checkout@main
|
||||
|
||||
- name: Install gtk3
|
||||
if: ${{ matrix.os == 'ubuntu-16.04' }}
|
||||
run: sudo apt-get install libgtk-3-dev libappindicator3-dev
|
||||
|
||||
- name: Setup Go environment
|
||||
uses: actions/setup-go@v2.1.3
|
||||
|
||||
- name: Build
|
||||
run: go build -o ${{ matrix.binary }} -ldflags "-s -w" tray.go
|
||||
working-directory: src/net/cmd/systray/
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: ${{ matrix.binary }}
|
||||
path: src/net/cmd/systray/${{ matrix.binary }}
|
||||
@@ -1,11 +0,0 @@
|
||||
*.o
|
||||
*.s
|
||||
*.c
|
||||
*.py[cod]
|
||||
.mypy_cache/
|
||||
.cache/
|
||||
__pycache__/
|
||||
!test/*.c
|
||||
/nonmatchings
|
||||
.vscode/
|
||||
pah.conf
|
||||
@@ -1,12 +0,0 @@
|
||||
; DO NOT EDIT (unless you know what you are doing)
|
||||
;
|
||||
; This subdirectory is a git "subrepo", and this file is maintained by the
|
||||
; git-subrepo command. See https://github.com/git-commands/git-subrepo#readme
|
||||
;
|
||||
[subrepo]
|
||||
remote = https://github.com/simonlindholm/decomp-permuter.git
|
||||
branch = main
|
||||
commit = a20bac9422b6d8adbf7c06473c2ae3c3fee16be5
|
||||
parent = 2668eec556c01fa2f4c16a203c93c208dc03e639
|
||||
method = merge
|
||||
cmdver = 0.4.3
|
||||
@@ -1,6 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 20.8b1
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.6
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Simon Lindholm and contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,120 +0,0 @@
|
||||
# Decomp permuter
|
||||
|
||||
Automatically permutes C files to better match a target binary. The permuter has two modes of operation:
|
||||
- Random: purely at random, introduce temporary variables for values, change types, put statements on the same line...
|
||||
- Manual: test all combinations of user-specified variations, using macros like `PERM_GENERAL(a = b ? c : d;, if (b) a = c; else a = d;)` to try both specified alternatives.
|
||||
|
||||
The modes can also be combined, by using the `PERM_RANDOMIZE` macro.
|
||||
|
||||
[<img src="https://asciinema.org/a/232846.svg" height="300">](https://asciinema.org/a/232846)
|
||||
|
||||
The main target for the tool is MIPS code compiled by old compilers (IDO, possibly GCC).
|
||||
Getting it to work on other architectures shouldn't be too hard, however.
|
||||
https://github.com/laqieer/decomp-permuter-arm has an ARM port.
|
||||
|
||||
## Usage
|
||||
|
||||
`./permuter.py directory/` runs the permuter; see below for the meaning of the directory.
|
||||
Pass `-h` to see possible flags. `-j` is suggested (enables multi-threaded mode).
|
||||
|
||||
You'll first need to install a couple of prerequisites: `python3 -m pip install pycparser pynacl toml` (also `dataclasses` if on Python 3.6 or below)
|
||||
|
||||
The permuter expects as input one or more directory containing:
|
||||
- a .c file with a single function,
|
||||
- a .o file to match,
|
||||
- a .sh file that compiles the .c file.
|
||||
|
||||
For projects with a properly configured makefile, you should be able to set these up by running
|
||||
```
|
||||
./import.py <path/to/file.c> <path/to/file.s>
|
||||
```
|
||||
where file.c contains the function to be permuted, and file.s is its assembly in a self-contained file.
|
||||
Otherwise, see USAGE.md for more details.
|
||||
|
||||
For projects using Ninja instead of Make, add a `permuter_settings.toml` in the root or `tools/` directory of the project:
|
||||
```toml
|
||||
build_system = "ninja"
|
||||
```
|
||||
Then `import.py` should work as expected if `build.ninja` is at the root of the project.
|
||||
|
||||
The .c file may be modified with any of the following macros which affect manual permutation:
|
||||
|
||||
- `PERM_GENERAL(a, b, ...)` expands to any of `a`, `b`, ...
|
||||
- `PERM_VAR(a, b)` sets the meta-variable `a` to `b`, `PERM_VAR(a)` expands to the meta-variable `a`.
|
||||
- `PERM_RANDOMIZE(code)` expands to `code`, but allows randomization within that region. Multiple regions may be specified.
|
||||
- `PERM_LINESWAP(lines)` expands to a permutation of the ordered set of non-whitespace lines (split by `\n`). Each line must contain zero or more complete C statements. (For incomplete statements use `PERM_LINESWAP_TEXT`, which is slower because it has to repeatedly parse C code.)
|
||||
- `PERM_INT(lo, hi)` expands to an integer between `lo` and `hi` (which must be constants).
|
||||
- `PERM_IGNORE(code)` expands to `code`, without passing it through the C parser library (pycparser)/randomizer. This can be used to avoid parse errors for non-standard C, e.g. `asm` blocks.
|
||||
- `PERM_PRETEND(code)` expands to `code` for the purpose of the C parser/randomizer, but gets removed afterwards. This can be used together with `PERM_IGNORE` to enable the permuter to deal with input it isn't designed for (e.g. inline functions, C++, non-code).
|
||||
- `PERM_ONCE([key,] code)` expands to either `code` or to nothing, such that each unique key gets expanded exactly once. `key` defaults to `code`. For example, `PERM_ONCE(a;) b; PERM_ONCE(a;)` expands to either `a; b;` or `b; a;`.
|
||||
|
||||
Arguments are split by a commas, exluding commas inside parenthesis. `(,)` is a special escape sequence that resolves to `,`.
|
||||
|
||||
Nested macros are allowed, so e.g.
|
||||
```
|
||||
PERM_VAR(delayed, )
|
||||
PERM_GENERAL(stmt;, PERM_VAR(delayed, stmt;))
|
||||
...
|
||||
PERM_VAR(delayed)
|
||||
```
|
||||
is an alternative way of writing `PERM_ONCE`.
|
||||
|
||||
## permuter@home
|
||||
|
||||
The permuter supports a distributed mode, where people can donate processor power to your permuter runs to speed them up.
|
||||
To use this, pass `-J` when running `permuter.py` and follow the instructions.
|
||||
You will need to be granted access by someone who is already connected to a permuter network.
|
||||
|
||||
To allow others to use your computer for permuter runs, do the following:
|
||||
|
||||
- install Docker (used for sandboxing and to ensure a consistent environment)
|
||||
- if on Linux, add yourself to the Docker group: `sudo usermod -aG docker $USER`
|
||||
- install required packages: `python3 -m pip install docker`
|
||||
- open a terminal, and run `./pah.py run-server` to start the server.
|
||||
There are a few required arguments (e.g. how many cores to use), see `--help` for more details.
|
||||
|
||||
Please be aware that being in the Docker group implies (password-less) sudo rights.
|
||||
You can avoid that for your personal account by running the permuter under a separate user.
|
||||
Unfortunately, there is currently no way to run a sandboxed permuter server without sudo rights. 😢
|
||||
|
||||
Anyone who is granted access to permuter@home can run a server.
|
||||
|
||||
To set up a new permuter network, see [src/net/controller/README.md](./src/net/controller/README.md).
|
||||
|
||||
## FAQ
|
||||
|
||||
**What do the scores mean?** The scores are computed by taking diffs of objdump'd .o
|
||||
files, and giving different penalties for lines that are the same/use the same
|
||||
instruction/are reordered/don't match at all. 0 means the function matches fully.
|
||||
Stack positions are ignored unless --stack-diffs is passed (but beware that the
|
||||
permuter is currently quite bad at resolving stack differences). For more details,
|
||||
see scorer.py. It's far from a perfect system, and should probably be tweaked to
|
||||
look at e.g. the register diff graph.
|
||||
|
||||
**What sort of non-matchings are the permuter good at?** It's generally best towards
|
||||
the end, when mostly regalloc changes remain. If there are reorderings or functional
|
||||
changes, it's often easy to resolve those by hand, and neither the scorer nor the
|
||||
randomizer tends to play well with them.
|
||||
|
||||
**Should I use this instead of trying to match code by hand?** No, but it can be a good
|
||||
complement. PERM macros can be used to quickly test lots of variations of a function at
|
||||
once, in cases where there are interactions between several parts of a function.
|
||||
The randomization mode often finds lots of nonsensical changes that improve regalloc
|
||||
"by accident"; it's up to you to pick out the ones that look sensible. If none do,
|
||||
it can still be useful to know which parts of the function need to be changed to get the
|
||||
code nearer to matching. Having made one of the improvements, and the function can then be
|
||||
permuted again, to find further possible improvements.
|
||||
|
||||
## Helping out
|
||||
|
||||
There's tons of room for helping out with the permuter!
|
||||
Many more randomization passes could be added, the scoring function is far from optimal,
|
||||
the permuter could be made easier to use, etc. etc. The GitHub Issues list has some ideas.
|
||||
|
||||
Ideally, `mypy permuter.py` and `./run-tests.sh` should succeed with no errors, and files
|
||||
formatted with `black`. To setup a pre-commit hook for black, run:
|
||||
```
|
||||
pip install pre-commit black
|
||||
pre-commit install
|
||||
```
|
||||
PRs that skip this are still welcome, however.
|
||||
@@ -1,25 +0,0 @@
|
||||
This file describes how to manually set up a directory for use with the permuter.
|
||||
**You probably don't need to do this!** In normal circumstances, `./import.py`
|
||||
does all this for you. See README.md for more details.
|
||||
|
||||
* create a directory that will contain all of the input files for the invokation
|
||||
* put a compile command into `<dir>/compile.sh` (see e.g. `compile_example.sh`; it will be invoked as `./compile.sh input.c -o output.o`)
|
||||
* `gcc -E -P -I header_dir -D'__attribute__(x)=' orig_c_file.c > <dir>/base.c`
|
||||
* `python3 strip_other_fns.py <dir>/base.c func_name`
|
||||
* put asm for `func_name` into `<dir>/target.s`, with the following header:
|
||||
|
||||
```asm
|
||||
.set noat
|
||||
.set noreorder
|
||||
.set gp=64
|
||||
.macro glabel label
|
||||
.global \label
|
||||
.type \label, @function
|
||||
\label:
|
||||
.endm
|
||||
```
|
||||
* `mips-linux-gnu-as -march=vr4300 -mabi=32 <dir>/target.s -o <dir>/target.o`
|
||||
* optional sanity checks:
|
||||
- `<dir>/compile.sh <dir>/base.c -o <dir>/base.o`
|
||||
- `./diff.sh <dir>/target.o <dir>/base.o`
|
||||
* `./permuter.py <dir>`
|
||||
@@ -1,2 +0,0 @@
|
||||
#!/bin/bash
|
||||
mips-linux-gnu-gcc -O2 "$@"
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [[ $# < 2 ]]; then
|
||||
echo "Usage: $0 orig.o new.o [flags]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $1 -o ! -f $2 ]; then
|
||||
echo Source files not readable
|
||||
exit 1
|
||||
fi
|
||||
|
||||
INPUT1="$1"
|
||||
INPUT2="$2"
|
||||
shift
|
||||
shift
|
||||
wdiff -n <(python3 ./src/objdump.py "$INPUT1" "$@") <(python3 ./src/objdump.py "$INPUT2" "$@") | colordiff | less -Ric
|
||||
@@ -1,808 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# usage: ./import.py path/to/file.c path/to/asm.s [make flags]
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import toml
|
||||
from typing import Callable, Dict, List, Match, Mapping, Optional, Pattern, Set, Tuple
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
|
||||
from src import ast_util
|
||||
from src.compiler import Compiler
|
||||
from src.error import CandidateConstructionFailure
|
||||
|
||||
is_macos = platform.system() == "Darwin"
|
||||
|
||||
|
||||
def homebrew_gcc_cpp() -> str:
|
||||
lookup_paths = ["/usr/local/bin", "/opt/homebrew/bin"]
|
||||
|
||||
for lookup_path in lookup_paths:
|
||||
try:
|
||||
return max(f for f in os.listdir(lookup_path) if f.startswith("cpp-"))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
print(
|
||||
"Error while looking up in " + ":".join(lookup_paths) + " for cpp- executable"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
cpp_cmd = homebrew_gcc_cpp() if is_macos else "cpp"
|
||||
make_cmd = "gmake" if is_macos else "make"
|
||||
|
||||
ASM_PRELUDE: str = """
|
||||
.set noat
|
||||
.set noreorder
|
||||
.set gp=64
|
||||
.macro glabel label
|
||||
.global \label
|
||||
.type \label, @function
|
||||
\label:
|
||||
.endm
|
||||
"""
|
||||
|
||||
DEFAULT_AS_CMDLINE: List[str] = ["mips-linux-gnu-as", "-march=vr4300", "-mabi=32"]
|
||||
|
||||
CPP: List[str] = [cpp_cmd, "-P", "-undef"]
|
||||
|
||||
STUB_FN_MACROS: List[str] = [
|
||||
"-D_Static_assert(x, y)=",
|
||||
"-D__attribute__(x)=",
|
||||
"-DGLOBAL_ASM(...)=",
|
||||
]
|
||||
|
||||
SETTINGS_FILES = ["permuter_settings.toml", "tools/permuter_settings.toml"]
|
||||
|
||||
|
||||
def formatcmd(cmdline: List[str]) -> str:
|
||||
return " ".join(shlex.quote(arg) for arg in cmdline)
|
||||
|
||||
|
||||
def parse_asm(asm_file: str) -> Tuple[str, str]:
|
||||
func_name = None
|
||||
asm_lines = []
|
||||
try:
|
||||
with open(asm_file, encoding="utf-8") as f:
|
||||
cur_section = ".text"
|
||||
for line in f:
|
||||
if line.strip().startswith(".section"):
|
||||
cur_section = line.split()[1]
|
||||
elif line.strip() in [
|
||||
".text",
|
||||
".rdata",
|
||||
".rodata",
|
||||
".late_rodata",
|
||||
".bss",
|
||||
".data",
|
||||
]:
|
||||
cur_section = line.strip()
|
||||
if cur_section == ".text":
|
||||
if func_name is None and line.strip().startswith("glabel "):
|
||||
func_name = line.split()[1]
|
||||
asm_lines.append(line)
|
||||
except OSError as e:
|
||||
print("Could not open assembly file:", e, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if func_name is None:
|
||||
print(
|
||||
"Missing function name in assembly file! The file should start with 'glabel function_name'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not re.fullmatch(r"[a-zA-Z0-9_$]+", func_name):
|
||||
print(f"Bad function name: {func_name}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return func_name, "".join(asm_lines)
|
||||
|
||||
|
||||
def create_directory(func_name: str) -> str:
|
||||
os.makedirs(f"nonmatchings/", exist_ok=True)
|
||||
ctr = 0
|
||||
while True:
|
||||
ctr += 1
|
||||
dirname = f"{func_name}-{ctr}" if ctr > 1 else func_name
|
||||
dirname = f"nonmatchings/{dirname}"
|
||||
try:
|
||||
os.mkdir(dirname)
|
||||
return dirname
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
|
||||
def find_root_dir(filename: str, pattern: List[str]) -> Optional[str]:
|
||||
old_dirname = None
|
||||
dirname = os.path.abspath(os.path.dirname(filename))
|
||||
|
||||
while dirname and (not old_dirname or len(dirname) < len(old_dirname)):
|
||||
for fname in pattern:
|
||||
if os.path.isfile(os.path.join(dirname, fname)):
|
||||
return dirname
|
||||
old_dirname = dirname
|
||||
dirname = os.path.dirname(dirname)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def fixup_build_command(
|
||||
parts: List[str], ignore_part: str
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
res = []
|
||||
skip_count = 0
|
||||
assembler = None
|
||||
for part in parts:
|
||||
if skip_count > 0:
|
||||
skip_count -= 1
|
||||
continue
|
||||
if part in ["-MF", "-o"]:
|
||||
skip_count = 1
|
||||
continue
|
||||
if part == ignore_part:
|
||||
continue
|
||||
res.append(part)
|
||||
|
||||
try:
|
||||
ind0 = min(
|
||||
i
|
||||
for i, arg in enumerate(res)
|
||||
if any(
|
||||
cmd in arg
|
||||
for cmd in ["asm_processor", "asm-processor", "preprocess.py"]
|
||||
)
|
||||
)
|
||||
ind1 = res.index("--", ind0 + 1)
|
||||
ind2 = res.index("--", ind1 + 1)
|
||||
assembler = res[ind1 + 1 : ind2]
|
||||
res = res[ind0 + 1 : ind1] + res[ind2 + 1 :]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return res, assembler
|
||||
|
||||
|
||||
def find_build_command_line(
|
||||
root_dir: str, c_file: str, make_flags: List[str], build_system: str
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
if build_system == "make":
|
||||
build_invocation = [
|
||||
make_cmd,
|
||||
"--always-make",
|
||||
"--dry-run",
|
||||
"--debug=j",
|
||||
"PERMUTER=1",
|
||||
] + make_flags
|
||||
elif build_system == "ninja":
|
||||
build_invocation = ["ninja", "-t", "commands"] + make_flags
|
||||
else:
|
||||
print("Unknown build system '" + build_system + "'.")
|
||||
sys.exit(1)
|
||||
|
||||
rel_c_file = os.path.relpath(c_file, root_dir)
|
||||
debug_output = (
|
||||
subprocess.check_output(build_invocation, cwd=root_dir)
|
||||
.decode("utf-8")
|
||||
.split("\n")
|
||||
)
|
||||
|
||||
output = []
|
||||
close_match = False
|
||||
|
||||
assembler = DEFAULT_AS_CMDLINE
|
||||
for line in debug_output:
|
||||
while "//" in line:
|
||||
line = line.replace("//", "/")
|
||||
while "/./" in line:
|
||||
line = line.replace("/./", "/")
|
||||
if rel_c_file not in line:
|
||||
continue
|
||||
|
||||
close_match = True
|
||||
parts = shlex.split(line)
|
||||
|
||||
# extract actual command from 'bash -c "..."'
|
||||
if parts[0] == "bash" and "-c" in parts:
|
||||
for part in parts:
|
||||
if rel_c_file in part:
|
||||
parts = shlex.split(part)
|
||||
break
|
||||
|
||||
if rel_c_file not in parts:
|
||||
continue
|
||||
if "-o" not in parts:
|
||||
continue
|
||||
if "-fsyntax-only" in parts:
|
||||
continue
|
||||
cmdline, asmproc_assembler = fixup_build_command(parts, rel_c_file)
|
||||
if asmproc_assembler:
|
||||
assembler = asmproc_assembler
|
||||
output.append(cmdline)
|
||||
|
||||
if not output:
|
||||
close_extra = (
|
||||
"\n(Found one possible candidate, but didn't match due to "
|
||||
"either spaces in paths, having -fsyntax-only, or missing an -o flag.)"
|
||||
if close_match
|
||||
else ""
|
||||
)
|
||||
print(
|
||||
"Failed to find compile command from build script output. "
|
||||
f"Please ensure running '{' '.join(build_invocation)}' "
|
||||
f"contains a line with the string '{rel_c_file}'.{close_extra}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if len(output) > 1:
|
||||
output_lines = "\n".join(map(formatcmd, output))
|
||||
print(
|
||||
f"Error: found multiple compile commands for {rel_c_file}:\n{output_lines}\n"
|
||||
f"Please modify the build script such that '{' '.join(build_invocation)}' "
|
||||
"produces a single compile command.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
return output[0], assembler
|
||||
|
||||
|
||||
PreserveMacros = Tuple[Pattern[str], Callable[[str], str]]
|
||||
|
||||
|
||||
def build_preserve_macros(
|
||||
cwd: str, preserve_regex: Optional[str], settings: Mapping[str, object]
|
||||
) -> Optional[PreserveMacros]:
|
||||
|
||||
subdata = settings.get("preserve_macros", {})
|
||||
assert isinstance(subdata, dict)
|
||||
regexes = []
|
||||
for regex, value in subdata.items():
|
||||
assert isinstance(value, str)
|
||||
regexes.append((re.compile(f"^(?:{regex})$"), value))
|
||||
|
||||
if preserve_regex == "" or (preserve_regex is None and not regexes):
|
||||
return None
|
||||
|
||||
if preserve_regex is None:
|
||||
global_regex_text = "(?:" + ")|(?:".join(subdata.keys()) + ")"
|
||||
else:
|
||||
global_regex_text = preserve_regex
|
||||
global_regex = re.compile(f"^(?:{global_regex_text})$")
|
||||
|
||||
def type_fn(macro: str) -> str:
|
||||
for regex, value in regexes:
|
||||
if regex.match(macro):
|
||||
return value
|
||||
return "int"
|
||||
|
||||
return global_regex, type_fn
|
||||
|
||||
|
||||
def preprocess_c_with_macros(
|
||||
cpp_command: List[str], cwd: str, preserve_macros: PreserveMacros
|
||||
) -> Tuple[str, List[str]]:
|
||||
"""Import C file, preserving function macros. Subroutine of import_c_file.
|
||||
|
||||
Returns the source code and a list of preserved macros."""
|
||||
|
||||
preserve_regex, preserve_type_fn = preserve_macros
|
||||
|
||||
# Start by running 'cpp' in a mode that just processes ifdefs and includes.
|
||||
source = subprocess.check_output(
|
||||
cpp_command + ["-dD", "-fdirectives-only"], cwd=cwd, encoding="utf-8"
|
||||
)
|
||||
|
||||
# Modify function macros that match preserved names so the preprocessor
|
||||
# doesn't touch them, and at the same time normalize their syntax. Some
|
||||
# of these instances may be in comments, but that's fine.
|
||||
def repl(match: Match[str]) -> str:
|
||||
name = match.group(1)
|
||||
after = "(" if match.group(2) == "(" else " "
|
||||
if preserve_regex.match(name):
|
||||
return f"_permuter define {name}{after}"
|
||||
else:
|
||||
return f"#define {name}{after}"
|
||||
|
||||
source = re.sub(
|
||||
r"^\s*#\s*define\s+([a-zA-Z0-9_]+)([ \t\(]|$)",
|
||||
repl,
|
||||
source,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
# Get rid of auto-inserted macros which the second cpp invocation will
|
||||
# warn about.
|
||||
source = re.sub(r"^#define __STDC_.*\n", "", source, flags=re.MULTILINE)
|
||||
|
||||
# Now, run the preprocessor again for real.
|
||||
source = subprocess.check_output(
|
||||
CPP + STUB_FN_MACROS, cwd=cwd, encoding="utf-8", input=source
|
||||
)
|
||||
|
||||
# Finally, find all function-like defines that we hid (some might have
|
||||
# been comments, so we couldn't do this before), and construct fake
|
||||
# function declarations for them in a specially demarcated section of
|
||||
# the file. When the compiler runs, this section will be replaced by
|
||||
# the real defines and the preprocessor invoked once more.
|
||||
late_defines = []
|
||||
lines = []
|
||||
graph = defaultdict(set)
|
||||
reg_token = re.compile(r"[a-zA-Z0-9_]+")
|
||||
for line in source.splitlines():
|
||||
is_macro = line.startswith("_permuter define ")
|
||||
params = []
|
||||
if is_macro:
|
||||
ind1 = line.find("(")
|
||||
ind2 = line.find(" ", len("_permuter define "))
|
||||
ind = min(ind1, ind2)
|
||||
if ind == -1:
|
||||
ind = len(line) if ind1 == ind2 == -1 else max(ind1, ind2)
|
||||
before = line[:ind]
|
||||
after = line[ind:]
|
||||
name = before.split()[2]
|
||||
late_defines.append((name, after))
|
||||
if after.startswith("("):
|
||||
params = [w.strip() for w in after[1 : after.find(")")].split(",")]
|
||||
else:
|
||||
lines.append(line)
|
||||
name = ""
|
||||
for m in reg_token.finditer(line):
|
||||
name2 = m.group(0)
|
||||
has_wildcard = False
|
||||
if is_macro and name2 not in params:
|
||||
wcbefore = line[: m.start()].rstrip().endswith("##")
|
||||
wcafter = line[m.end() :].lstrip().startswith("##")
|
||||
if wcbefore or wcafter:
|
||||
graph[name].add(name2 + "*")
|
||||
has_wildcard = True
|
||||
if not has_wildcard:
|
||||
graph[name].add(name2)
|
||||
|
||||
# Prune away (recursively) unused macros, for cleanliness.
|
||||
used_anywhere = set()
|
||||
used_by_nonmacro = graph[""]
|
||||
queue = [""]
|
||||
while queue:
|
||||
name = queue.pop()
|
||||
if name not in used_anywhere:
|
||||
used_anywhere.add(name)
|
||||
if name.endswith("*"):
|
||||
wildcard = name[:-1]
|
||||
for name2 in graph:
|
||||
if wildcard in name2:
|
||||
queue.extend(graph[name2])
|
||||
else:
|
||||
queue.extend(graph[name])
|
||||
|
||||
def get_decl(name: str, after: str) -> str:
|
||||
typ = preserve_type_fn(name)
|
||||
if after.startswith("("):
|
||||
return f"{typ} {name}();"
|
||||
else:
|
||||
return f"extern {typ} {name};"
|
||||
|
||||
used_macros = [name for (name, after) in late_defines if name in used_by_nonmacro]
|
||||
|
||||
return (
|
||||
"\n".join(
|
||||
["#pragma _permuter latedefine start"]
|
||||
+ [
|
||||
f"#pragma _permuter define {name}{after}"
|
||||
for (name, after) in late_defines
|
||||
if name in used_anywhere
|
||||
]
|
||||
+ [
|
||||
get_decl(name, after)
|
||||
for (name, after) in late_defines
|
||||
if name in used_by_nonmacro
|
||||
]
|
||||
+ ["#pragma _permuter latedefine end"]
|
||||
+ lines
|
||||
+ [""]
|
||||
),
|
||||
used_macros,
|
||||
)
|
||||
|
||||
|
||||
def import_c_file(
|
||||
compiler: List[str],
|
||||
cwd: str,
|
||||
in_file: str,
|
||||
preserve_macros: Optional[PreserveMacros],
|
||||
) -> str:
|
||||
"""Preprocess a C file into permuter-usable source.
|
||||
|
||||
Prints preserved macros as a side effect.
|
||||
|
||||
Returns source for base.c and compilable (macro-expanded) source."""
|
||||
in_file = os.path.relpath(in_file, cwd)
|
||||
include_next = 0
|
||||
cpp_command = CPP + [in_file, "-D__sgi", "-D_LANGUAGE_C", "-DNON_MATCHING"]
|
||||
|
||||
for arg in compiler:
|
||||
if include_next > 0:
|
||||
include_next -= 1
|
||||
cpp_command.append(arg)
|
||||
continue
|
||||
if arg in ["-D", "-U", "-I"]:
|
||||
cpp_command.append(arg)
|
||||
include_next = 1
|
||||
continue
|
||||
if (
|
||||
arg.startswith("-D")
|
||||
or arg.startswith("-U")
|
||||
or arg.startswith("-I")
|
||||
or arg in ["-nostdinc"]
|
||||
):
|
||||
cpp_command.append(arg)
|
||||
|
||||
try:
|
||||
if preserve_macros is None:
|
||||
# Simple codepath, should work even if the more complex one breaks.
|
||||
source = subprocess.check_output(
|
||||
cpp_command + STUB_FN_MACROS, cwd=cwd, encoding="utf-8"
|
||||
)
|
||||
macros: List[str] = []
|
||||
else:
|
||||
source, macros = preprocess_c_with_macros(cpp_command, cwd, preserve_macros)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(
|
||||
"Failed to preprocess input file, when running command:\n"
|
||||
+ formatcmd(e.cmd),
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if macros:
|
||||
macro_str = "macros: " + ", ".join(macros)
|
||||
else:
|
||||
macro_str = "no macros"
|
||||
print(f"Preserving {macro_str}. Use --preserve-macros='<regex>' to override.")
|
||||
|
||||
return source
|
||||
|
||||
|
||||
def prune_source(
|
||||
source: str, should_prune: bool, func_name: str
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Normalize the source by round-tripping it through pycparser, and
|
||||
optionally reduce it to a smaller version that includes only the imported
|
||||
function and functions/struct/variables that it uses.
|
||||
|
||||
Returns (source, compilable_source)."""
|
||||
try:
|
||||
ast = ast_util.parse_c(source, from_import=True)
|
||||
orig_fn, _ = ast_util.extract_fn(ast, func_name)
|
||||
if should_prune:
|
||||
try:
|
||||
ast_util.prune_ast(orig_fn, ast)
|
||||
source = ast_util.to_c_raw(ast)
|
||||
except Exception:
|
||||
print(
|
||||
"Source minimization failed! "
|
||||
"You could try --no-prune as a workaround."
|
||||
)
|
||||
raise
|
||||
return source, ast_util.to_c(ast, from_import=True)
|
||||
except CandidateConstructionFailure as e:
|
||||
print(e.message)
|
||||
if should_prune and "PERM_" in source:
|
||||
print(
|
||||
"Please put in PERM macros after import, otherwise source "
|
||||
"minimization does not work."
|
||||
)
|
||||
else:
|
||||
print("Proceeding anyway, but expect errors when permuting!")
|
||||
return source, None
|
||||
|
||||
|
||||
def prune_and_separate_context(
|
||||
source: str, should_prune: bool, func_name: str
|
||||
) -> Tuple[str, str]:
|
||||
"""Normalize the source by round-tripping it through pycparser, optionally
|
||||
reduce it to a smaller version that includes only the imported function and
|
||||
functions/struct/variables that it uses, and split the result into source
|
||||
for the function itself, and the rest of the file (the "context").
|
||||
|
||||
Returns (source, context)."""
|
||||
try:
|
||||
ast = ast_util.parse_c(source, from_import=True)
|
||||
orig_fn, ind = ast_util.extract_fn(ast, func_name)
|
||||
if should_prune:
|
||||
try:
|
||||
ind = ast_util.prune_ast(orig_fn, ast)
|
||||
except Exception:
|
||||
print(
|
||||
"Source minimization failed! "
|
||||
"You could try --no-prune as a workaround."
|
||||
)
|
||||
raise
|
||||
del ast.ext[ind]
|
||||
source = ast_util.to_c(orig_fn, from_import=True)
|
||||
context = ast_util.to_c(ast, from_import=True)
|
||||
return source, context
|
||||
except CandidateConstructionFailure as e:
|
||||
print(e.message)
|
||||
print("Unable to split context from source.")
|
||||
print("Proceeding anyway, but expect compile errors!")
|
||||
return ast_util.process_pragmas(source), ""
|
||||
|
||||
|
||||
def get_decompme_compiler_name(
|
||||
compiler: List[str], settings: Mapping[str, object], api_base: str
|
||||
) -> str:
|
||||
decompme_settings = settings.get("decompme", {})
|
||||
assert isinstance(decompme_settings, dict)
|
||||
compiler_mappings = decompme_settings.get("compilers", {})
|
||||
assert isinstance(compiler_mappings, dict)
|
||||
|
||||
compiler_path = compiler[0]
|
||||
|
||||
for path, compiler_name in compiler_mappings.items():
|
||||
assert isinstance(compiler_name, str)
|
||||
if path == compiler_path:
|
||||
return compiler_name
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(f"{api_base}/api/compilers") as f:
|
||||
json_data = json.load(f)
|
||||
available = json_data["compiler_ids"]
|
||||
if not isinstance(available, list):
|
||||
raise Exception("compiler_ids must be a list")
|
||||
if not all(isinstance(name, str) for name in available):
|
||||
raise Exception("compiler_ids must be a list of strings")
|
||||
except Exception as e:
|
||||
print(f"Failed to request available compilers from decomp.me:\n{e}")
|
||||
|
||||
print()
|
||||
print(
|
||||
f'Unable to map compiler path "{compiler_path}" to something '
|
||||
"decomp.me understands."
|
||||
)
|
||||
trail = "permuter_settings.toml, where ... is one of: " + ", ".join(available)
|
||||
if compiler_mappings:
|
||||
print(
|
||||
"Please add an entry:\n\n"
|
||||
f'"{compiler_path}" = "..."\n\n'
|
||||
f"to the [decompme.compilers] section of {trail}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Please add an section:\n\n"
|
||||
"[decompme.compilers]\n"
|
||||
f'"{compiler_path}" = "..."\n\n'
|
||||
f"to {trail}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def finalize_compile_command(cmdline: List[str]) -> str:
|
||||
quoted = [arg if arg == "|" else shlex.quote(arg) for arg in cmdline]
|
||||
ind = (quoted + ["|"]).index("|")
|
||||
return " ".join(quoted[:ind] + ['"$INPUT"'] + quoted[ind:] + ["-o", '"$OUTPUT"'])
|
||||
|
||||
|
||||
def get_compiler_flags(cmdline: List[str]) -> str:
|
||||
flags = [b for a, b in zip(cmdline, cmdline[1:]) if a != "|" and b != "|"]
|
||||
return " ".join(shlex.quote(flag) for flag in flags)
|
||||
|
||||
|
||||
def write_compile_command(compiler: List[str], cwd: str, out_file: str) -> None:
|
||||
|
||||
with open(out_file, "w", encoding="utf-8") as f:
|
||||
f.write("#!/usr/bin/env bash\n")
|
||||
f.write('INPUT="$(realpath "$1")"\n')
|
||||
f.write('OUTPUT="$(realpath "$3")"\n')
|
||||
f.write(f"cd {shlex.quote(cwd)}\n")
|
||||
f.write(finalize_compile_command(compiler))
|
||||
os.chmod(out_file, 0o755)
|
||||
|
||||
|
||||
def write_asm(asm_cont: str, out_file: str) -> None:
|
||||
with open(out_file, "w", encoding="utf-8") as f:
|
||||
f.write(ASM_PRELUDE)
|
||||
f.write(asm_cont)
|
||||
|
||||
|
||||
def compile_asm(assembler: List[str], cwd: str, in_file: str, out_file: str) -> None:
|
||||
in_file = os.path.abspath(in_file)
|
||||
out_file = os.path.abspath(out_file)
|
||||
cmdline = assembler + [in_file, "-o", out_file]
|
||||
try:
|
||||
subprocess.check_call(cmdline, cwd=cwd)
|
||||
except subprocess.CalledProcessError:
|
||||
print(
|
||||
f"Failed to assemble .s file, command line:\n{formatcmd(cmdline)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def compile_base(compile_script: str, source: str, c_file: str, out_file: str) -> None:
|
||||
if "PERM_" in source:
|
||||
print(
|
||||
"Cannot test-compile imported code because it contains PERM macros. "
|
||||
"It is recommended to put in PERM macros after import."
|
||||
)
|
||||
return
|
||||
escaped_c_file = json.dumps(c_file)
|
||||
source = "#line 1 " + escaped_c_file + "\n" + source
|
||||
compiler = Compiler(compile_script, show_errors=True)
|
||||
o_file = compiler.compile(source)
|
||||
if o_file:
|
||||
shutil.move(o_file, out_file)
|
||||
else:
|
||||
print("Warning: failed to compile .c file.")
|
||||
|
||||
|
||||
def write_to_file(cont: str, filename: str) -> None:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(cont)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""Import a function for use with the permuter.
|
||||
Will create a new directory nonmatchings/<funcname>-<id>/."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"c_file",
|
||||
help="""File containing the function.
|
||||
Assumes that the file can be built with 'make' to create an .o file.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"asm_file",
|
||||
help="""File containing assembly for the function.
|
||||
Must start with 'glabel <function_name>' and contain no other functions.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"make_flags",
|
||||
nargs="*",
|
||||
help="Arguments to pass to 'make'. PERMUTER=1 will always be passed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep", action="store_true", help="Keep the directory on error."
|
||||
)
|
||||
settings_files = ", ".join(SETTINGS_FILES[:-1]) + " or " + SETTINGS_FILES[-1]
|
||||
parser.add_argument(
|
||||
"--preserve-macros",
|
||||
metavar="REGEX",
|
||||
dest="preserve_macros_regex",
|
||||
help=f"""Regex for which macros to preserve, or empty string for no macros.
|
||||
By default, this is read from {settings_files} in a parent directory of
|
||||
the imported file. Type information is also read from this file.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-prune",
|
||||
dest="prune",
|
||||
action="store_false",
|
||||
help="""Don't minimize the source to keep only the imported function and
|
||||
functions/struct/variables that it uses. Normally this behavior is
|
||||
useful to make the permuter faster, but in cases where unrelated code
|
||||
affects the generated assembly asm it can be necessary to turn off.
|
||||
Note that regardless of this setting the permuter always removes all
|
||||
other functions by replacing them with declarations.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decompme",
|
||||
dest="decompme",
|
||||
action="store_true",
|
||||
help="""Upload the function to decomp.me to share with other people,
|
||||
instead of importing.""",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
root_dir = find_root_dir(
|
||||
args.c_file, SETTINGS_FILES + ["Makefile", "makefile", "build.ninja"]
|
||||
)
|
||||
|
||||
if not root_dir:
|
||||
print(f"Can't find root dir of project!", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
settings: Mapping[str, object] = {}
|
||||
for filename in SETTINGS_FILES:
|
||||
filename = os.path.join(root_dir, filename)
|
||||
if os.path.exists(filename):
|
||||
with open(filename) as f:
|
||||
settings = toml.load(f)
|
||||
break
|
||||
|
||||
build_system = settings.get("build_system", "make")
|
||||
compiler = settings.get("compiler_command")
|
||||
assembler = settings.get("assembler_command")
|
||||
make_flags = args.make_flags
|
||||
|
||||
func_name, asm_cont = parse_asm(args.asm_file)
|
||||
print(f"Function name: {func_name}")
|
||||
|
||||
if compiler or assembler:
|
||||
assert isinstance(compiler, str)
|
||||
assert isinstance(assembler, str)
|
||||
assert settings.get("build_system") is None
|
||||
|
||||
compiler = shlex.split(compiler)
|
||||
assembler = shlex.split(assembler)
|
||||
else:
|
||||
assert isinstance(build_system, str)
|
||||
compiler, assembler = find_build_command_line(
|
||||
root_dir, args.c_file, make_flags, build_system
|
||||
)
|
||||
|
||||
print(f"Compiler: {formatcmd(compiler)} {{input}} -o {{output}}")
|
||||
print(f"Assembler: {formatcmd(assembler)} {{input}} -o {{output}}")
|
||||
|
||||
preserve_macros = build_preserve_macros(
|
||||
root_dir, args.preserve_macros_regex, settings
|
||||
)
|
||||
source = import_c_file(compiler, root_dir, args.c_file, preserve_macros)
|
||||
|
||||
if args.decompme:
|
||||
api_base = os.environ.get("DECOMPME_API_BASE", "https://decomp.me")
|
||||
compiler_name = get_decompme_compiler_name(compiler, settings, api_base)
|
||||
source, context = prune_and_separate_context(source, args.prune, func_name)
|
||||
print("Uploading...")
|
||||
try:
|
||||
post_data = urllib.parse.urlencode(
|
||||
{
|
||||
"target_asm": asm_cont,
|
||||
"context": context,
|
||||
"source_code": source,
|
||||
"compiler": compiler_name,
|
||||
"compiler_flags": get_compiler_flags(compiler),
|
||||
}
|
||||
).encode("ascii")
|
||||
with urllib.request.urlopen(f"{api_base}/api/scratch", post_data) as f:
|
||||
resp = f.read()
|
||||
json_data: Dict[str, str] = json.loads(resp)
|
||||
if "slug" in json_data:
|
||||
slug = json_data["slug"]
|
||||
print(f"https://decomp.me/scratch/{slug}")
|
||||
else:
|
||||
error = json_data.get("error", resp)
|
||||
print(f"Server error: {error}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return
|
||||
|
||||
source, compilable_source = prune_source(source, args.prune, func_name)
|
||||
|
||||
dirname = create_directory(func_name)
|
||||
base_c_file = f"{dirname}/base.c"
|
||||
base_o_file = f"{dirname}/base.o"
|
||||
target_s_file = f"{dirname}/target.s"
|
||||
target_o_file = f"{dirname}/target.o"
|
||||
compile_script = f"{dirname}/compile.sh"
|
||||
func_name_file = f"{dirname}/function.txt"
|
||||
|
||||
try:
|
||||
write_to_file(source, base_c_file)
|
||||
write_to_file(func_name, func_name_file)
|
||||
write_compile_command(compiler, root_dir, compile_script)
|
||||
write_asm(asm_cont, target_s_file)
|
||||
compile_asm(assembler, root_dir, target_s_file, target_o_file)
|
||||
if compilable_source is not None:
|
||||
compile_base(compile_script, compilable_source, base_c_file, base_o_file)
|
||||
except:
|
||||
if not args.keep:
|
||||
print(f"\nDeleting directory {dirname} (run with --keep to preserve it).")
|
||||
shutil.rmtree(dirname)
|
||||
raise
|
||||
|
||||
print(f"\nDone. Imported into {dirname}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
[mypy]
|
||||
check_untyped_defs = True
|
||||
disallow_any_generics = False
|
||||
disallow_incomplete_defs = True
|
||||
disallow_subclassing_any = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_untyped_decorators = True
|
||||
disallow_untyped_defs = True
|
||||
no_implicit_optional = True
|
||||
warn_redundant_casts = True
|
||||
warn_return_any = True
|
||||
warn_unused_ignores = True
|
||||
mypy_path = stubs
|
||||
python_version = 3.7
|
||||
files = import.py, pah.py, permuter.py, src/net/evaluator.py
|
||||
|
||||
[mypy-nacl.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-pystray.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-docker.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-PIL.*]
|
||||
ignore_missing_imports = True
|
||||
@@ -1,4 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from src.net.cmd.main import main
|
||||
|
||||
main()
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from src.main import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +0,0 @@
|
||||
# Optional configuration file for import.py. Put it in the root or in tools/
|
||||
# of the repo you are importing from.
|
||||
|
||||
build_system = "ninja"
|
||||
|
||||
[preserve_macros]
|
||||
"g[DS]P.*" = "void"
|
||||
"gDma.*" = "void"
|
||||
"_SHIFTL" = "unsigned int"
|
||||
@@ -1,3 +0,0 @@
|
||||
#!/bin/sh
|
||||
python3 -m unittest discover -s test/
|
||||
# python3 -m pytest test/
|
||||
@@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [[ $1 < 2 ]]; then
|
||||
echo "Usage: $0 output_dir"
|
||||
echo "Ex: $0 nonmatchings/func_80000000"
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -d $1 ]]; then
|
||||
echo "Argument must be a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
find $1 -name score.txt -exec echo -n {}\ \; -exec cat {} \; | sort -rnk2
|
||||
@@ -1,303 +0,0 @@
|
||||
"""Functions and classes for dealing with types in a C AST.
|
||||
|
||||
They make a number of simplifying assumptions:
|
||||
- const and volatile doesn't matter.
|
||||
- arithmetic promotes all int-like types to 'int'.
|
||||
- no two variables can have the same name, even across functions.
|
||||
|
||||
For the purposes of the randomizer these restrictions are acceptable."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Union, Dict, Set, List
|
||||
|
||||
from pycparser import c_ast
|
||||
from pycparser.c_ast import ArrayDecl, TypeDecl, PtrDecl, FuncDecl, IdentifierType
|
||||
|
||||
Type = Union[PtrDecl, ArrayDecl, TypeDecl, FuncDecl]
|
||||
SimpleType = Union[PtrDecl, TypeDecl]
|
||||
|
||||
StructUnion = Union[c_ast.Struct, c_ast.Union]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeMap:
|
||||
typedefs: Dict[str, Type] = field(default_factory=dict)
|
||||
fn_ret_types: Dict[str, Type] = field(default_factory=dict)
|
||||
var_types: Dict[str, Type] = field(default_factory=dict)
|
||||
struct_defs: Dict[str, StructUnion] = field(default_factory=dict)
|
||||
|
||||
|
||||
def basic_type(name: Union[str, List[str]]) -> TypeDecl:
|
||||
names = [name] if isinstance(name, str) else name
|
||||
idtype = IdentifierType(names=names)
|
||||
return TypeDecl(declname=None, quals=[], type=idtype)
|
||||
|
||||
|
||||
def pointer(type: Type) -> Type:
|
||||
return PtrDecl(quals=[], type=type)
|
||||
|
||||
|
||||
def resolve_typedefs(type: Type, typemap: TypeMap) -> Type:
|
||||
while (
|
||||
isinstance(type, TypeDecl)
|
||||
and isinstance(type.type, IdentifierType)
|
||||
and len(type.type.names) == 1
|
||||
and type.type.names[0] in typemap.typedefs
|
||||
):
|
||||
type = typemap.typedefs[type.type.names[0]]
|
||||
return type
|
||||
|
||||
|
||||
def pointer_decay(type: Type, typemap: TypeMap) -> SimpleType:
|
||||
real_type = resolve_typedefs(type, typemap)
|
||||
if isinstance(real_type, ArrayDecl):
|
||||
return PtrDecl(quals=[], type=real_type.type)
|
||||
if isinstance(real_type, FuncDecl):
|
||||
return PtrDecl(quals=[], type=type)
|
||||
if isinstance(real_type, TypeDecl) and isinstance(real_type.type, c_ast.Enum):
|
||||
return basic_type("int")
|
||||
assert not isinstance(
|
||||
type, (ArrayDecl, FuncDecl)
|
||||
), "resolve_typedefs can't hide arrays/functions"
|
||||
return type
|
||||
|
||||
|
||||
def get_decl_type(decl: c_ast.Decl) -> Type:
|
||||
"""For a Decl that declares a variable (and not just a struct/union/enum),
|
||||
return its type."""
|
||||
assert decl.name is not None
|
||||
assert isinstance(decl.type, (PtrDecl, ArrayDecl, FuncDecl, TypeDecl))
|
||||
return decl.type
|
||||
|
||||
|
||||
def deref_type(type: Type, typemap: TypeMap) -> Type:
|
||||
type = resolve_typedefs(type, typemap)
|
||||
assert isinstance(type, (ArrayDecl, PtrDecl)), "dereferencing non-pointer"
|
||||
return type.type
|
||||
|
||||
|
||||
def struct_member_type(struct: StructUnion, field_name: str, typemap: TypeMap) -> Type:
|
||||
if not struct.decls:
|
||||
assert (
|
||||
struct.name in typemap.struct_defs
|
||||
), f"Accessing field {field_name} of undefined struct {struct.name}"
|
||||
struct = typemap.struct_defs[struct.name]
|
||||
assert struct.decls, "struct_defs never points to an incomplete type"
|
||||
for decl in struct.decls:
|
||||
if isinstance(decl, c_ast.Decl):
|
||||
if decl.name == field_name:
|
||||
return get_decl_type(decl)
|
||||
if decl.name == None and isinstance(decl.type, (c_ast.Struct, c_ast.Union)):
|
||||
try:
|
||||
return struct_member_type(decl.type, field_name, typemap)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
assert False, f"No field {field_name} in struct {struct.name}"
|
||||
|
||||
|
||||
def expr_type(node: c_ast.Node, typemap: TypeMap) -> Type:
|
||||
def rec(sub_expr: c_ast.Node) -> Type:
|
||||
return expr_type(sub_expr, typemap)
|
||||
|
||||
if isinstance(node, c_ast.Assignment):
|
||||
return rec(node.lvalue)
|
||||
if isinstance(node, c_ast.StructRef):
|
||||
lhs_type = rec(node.name)
|
||||
if node.type == "->":
|
||||
lhs_type = deref_type(lhs_type, typemap)
|
||||
struct_type = resolve_typedefs(lhs_type, typemap)
|
||||
assert isinstance(struct_type, TypeDecl)
|
||||
assert isinstance(
|
||||
struct_type.type, (c_ast.Struct, c_ast.Union)
|
||||
), f"struct deref of non-struct {struct_type.declname}"
|
||||
return struct_member_type(struct_type.type, node.field.name, typemap)
|
||||
if isinstance(node, c_ast.Cast):
|
||||
return node.to_type.type
|
||||
if isinstance(node, c_ast.Constant):
|
||||
if node.type == "string":
|
||||
return pointer(basic_type("char"))
|
||||
if node.type == "char":
|
||||
return basic_type("int")
|
||||
return basic_type(node.type.split(" "))
|
||||
if isinstance(node, c_ast.ID):
|
||||
return typemap.var_types[node.name]
|
||||
if isinstance(node, c_ast.UnaryOp):
|
||||
if node.op in ["p++", "p--", "++", "--"]:
|
||||
return rec(node.expr)
|
||||
if node.op == "&":
|
||||
return pointer(rec(node.expr))
|
||||
if node.op == "*":
|
||||
subtype = rec(node.expr)
|
||||
return deref_type(subtype, typemap)
|
||||
if node.op in ["-", "+"]:
|
||||
subtype = pointer_decay(rec(node.expr), typemap)
|
||||
if allowed_basic_type(subtype, typemap, ["double"]):
|
||||
return basic_type("double")
|
||||
if allowed_basic_type(subtype, typemap, ["float"]):
|
||||
return basic_type("float")
|
||||
if node.op in ["sizeof", "-", "+", "~", "!"]:
|
||||
return basic_type("int")
|
||||
assert False, f"unknown unary op {node.op}"
|
||||
if isinstance(node, c_ast.BinaryOp):
|
||||
lhs_type = pointer_decay(rec(node.left), typemap)
|
||||
rhs_type = pointer_decay(rec(node.right), typemap)
|
||||
if node.op in [">>", "<<"]:
|
||||
return lhs_type
|
||||
if node.op in ["<", "<=", ">", ">=", "==", "!=", "&&", "||"]:
|
||||
return basic_type("int")
|
||||
if node.op in "&|^%":
|
||||
return basic_type("int")
|
||||
real_lhs = resolve_typedefs(lhs_type, typemap)
|
||||
real_rhs = resolve_typedefs(rhs_type, typemap)
|
||||
if node.op in "+-":
|
||||
lptr = isinstance(real_lhs, PtrDecl)
|
||||
rptr = isinstance(real_rhs, PtrDecl)
|
||||
if lptr or rptr:
|
||||
if lptr and rptr:
|
||||
assert node.op != "+", "pointer + pointer"
|
||||
return basic_type("int")
|
||||
if lptr:
|
||||
return lhs_type
|
||||
assert node.op == "+", "int - pointer"
|
||||
return rhs_type
|
||||
if node.op in "*/+-":
|
||||
assert isinstance(real_lhs, TypeDecl)
|
||||
assert isinstance(real_rhs, TypeDecl)
|
||||
assert isinstance(real_lhs.type, IdentifierType)
|
||||
assert isinstance(real_rhs.type, IdentifierType)
|
||||
if "double" in real_lhs.type.names + real_rhs.type.names:
|
||||
return basic_type("double")
|
||||
if "float" in real_lhs.type.names + real_rhs.type.names:
|
||||
return basic_type("float")
|
||||
return basic_type("int")
|
||||
if isinstance(node, c_ast.FuncCall):
|
||||
expr = node.name
|
||||
if isinstance(expr, c_ast.ID):
|
||||
if expr.name not in typemap.fn_ret_types:
|
||||
raise Exception(f"Called function {expr.name} is missing a prototype")
|
||||
return typemap.fn_ret_types[expr.name]
|
||||
else:
|
||||
fptr_type = resolve_typedefs(rec(expr), typemap)
|
||||
if isinstance(fptr_type, PtrDecl):
|
||||
fptr_type = fptr_type.type
|
||||
fptr_type = resolve_typedefs(fptr_type, typemap)
|
||||
assert isinstance(fptr_type, FuncDecl), "call to non-function"
|
||||
return fptr_type.type
|
||||
if isinstance(node, c_ast.ExprList):
|
||||
return rec(node.exprs[-1])
|
||||
if isinstance(node, c_ast.ArrayRef):
|
||||
subtype = rec(node.name)
|
||||
return deref_type(subtype, typemap)
|
||||
if isinstance(node, c_ast.TernaryOp):
|
||||
return rec(node.iftrue)
|
||||
assert False, f"Unknown expression node type: {node}"
|
||||
|
||||
|
||||
def decayed_expr_type(expr: c_ast.Node, typemap: TypeMap) -> SimpleType:
|
||||
return pointer_decay(expr_type(expr, typemap), typemap)
|
||||
|
||||
|
||||
def same_type(
|
||||
type1: Type, type2: Type, typemap: TypeMap, allow_similar: bool = False
|
||||
) -> bool:
|
||||
while True:
|
||||
type1 = resolve_typedefs(type1, typemap)
|
||||
type2 = resolve_typedefs(type2, typemap)
|
||||
if isinstance(type1, ArrayDecl) and isinstance(type2, ArrayDecl):
|
||||
type1 = type1.type
|
||||
type2 = type2.type
|
||||
continue
|
||||
if isinstance(type1, PtrDecl) and isinstance(type2, PtrDecl):
|
||||
type1 = type1.type
|
||||
type2 = type2.type
|
||||
continue
|
||||
if isinstance(type1, TypeDecl) and isinstance(type2, TypeDecl):
|
||||
sub1 = type1.type
|
||||
sub2 = type2.type
|
||||
if isinstance(sub1, c_ast.Struct) and isinstance(sub2, c_ast.Struct):
|
||||
return sub1.name == sub2.name
|
||||
if isinstance(sub1, c_ast.Union) and isinstance(sub2, c_ast.Union):
|
||||
return sub1.name == sub2.name
|
||||
if (
|
||||
allow_similar
|
||||
and isinstance(sub1, (IdentifierType, c_ast.Enum))
|
||||
and isinstance(sub2, (IdentifierType, c_ast.Enum))
|
||||
):
|
||||
# All int-ish types are similar (except void, but whatever)
|
||||
return True
|
||||
if isinstance(sub1, c_ast.Enum) and isinstance(sub2, c_ast.Enum):
|
||||
return sub1.name == sub2.name
|
||||
if isinstance(sub1, IdentifierType) and isinstance(sub2, IdentifierType):
|
||||
return sorted(sub1.names) == sorted(sub2.names)
|
||||
return False
|
||||
|
||||
|
||||
def allowed_basic_type(
|
||||
type: SimpleType, typemap: TypeMap, allowed_types: List[str]
|
||||
) -> bool:
|
||||
"""Check if a type resolves to a basic type with one of the allowed_types
|
||||
keywords in it."""
|
||||
base_type = resolve_typedefs(type, typemap)
|
||||
if not isinstance(base_type, c_ast.TypeDecl):
|
||||
return False
|
||||
if not isinstance(base_type.type, c_ast.IdentifierType):
|
||||
return False
|
||||
if all(x not in base_type.type.names for x in allowed_types):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def build_typemap(ast: c_ast.FileAST) -> TypeMap:
|
||||
ret = TypeMap()
|
||||
for item in ast.ext:
|
||||
if isinstance(item, c_ast.Typedef):
|
||||
ret.typedefs[item.name] = item.type
|
||||
if isinstance(item, c_ast.FuncDef):
|
||||
assert item.decl.name is not None, "cannot define anonymous function"
|
||||
assert isinstance(item.decl.type, FuncDecl)
|
||||
ret.fn_ret_types[item.decl.name] = item.decl.type.type
|
||||
if isinstance(item, c_ast.Decl) and isinstance(item.type, FuncDecl):
|
||||
assert item.name is not None, "cannot define anonymous function"
|
||||
ret.fn_ret_types[item.name] = item.type.type
|
||||
defined_function_decls: Set[c_ast.Decl] = set()
|
||||
|
||||
class Visitor(c_ast.NodeVisitor):
|
||||
def visit_Struct(self, struct: c_ast.Struct) -> None:
|
||||
if struct.decls and struct.name is not None:
|
||||
ret.struct_defs[struct.name] = struct
|
||||
# Do not visit decls of this struct
|
||||
|
||||
def visit_Union(self, union: c_ast.Union) -> None:
|
||||
if union.decls and union.name is not None:
|
||||
ret.struct_defs[union.name] = union
|
||||
# Do not visit decls of this union
|
||||
|
||||
def visit_Decl(self, decl: c_ast.Decl) -> None:
|
||||
if decl.name is not None:
|
||||
ret.var_types[decl.name] = get_decl_type(decl)
|
||||
if not isinstance(decl.type, FuncDecl) or decl in defined_function_decls:
|
||||
# Do not visit declarations in parameter lists of functions
|
||||
# other than our own.
|
||||
self.visit(decl.type)
|
||||
|
||||
def visit_Enumerator(self, enumerator: c_ast.Enumerator) -> None:
|
||||
ret.var_types[enumerator.name] = basic_type("int")
|
||||
|
||||
def visit_FuncDef(self, fn: c_ast.FuncDef) -> None:
|
||||
if fn.decl.name is not None:
|
||||
ret.var_types[fn.decl.name] = get_decl_type(fn.decl)
|
||||
defined_function_decls.add(fn.decl)
|
||||
self.generic_visit(fn)
|
||||
|
||||
Visitor().visit(ast)
|
||||
return ret
|
||||
|
||||
|
||||
def set_decl_name(decl: c_ast.Decl) -> None:
|
||||
name = decl.name
|
||||
assert name is not None
|
||||
type = get_decl_type(decl)
|
||||
while not isinstance(type, TypeDecl):
|
||||
type = type.type
|
||||
type.declname = name
|
||||
@@ -1,505 +0,0 @@
|
||||
from base64 import b64decode
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from pycparser import CParser, c_ast as ca, c_generator
|
||||
from pycparser.plyparser import ParseError
|
||||
|
||||
from .error import CandidateConstructionFailure
|
||||
from .ast_types import SimpleType, set_decl_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Indices:
|
||||
starts: Dict[ca.Node, int]
|
||||
ends: Dict[ca.Node, int]
|
||||
|
||||
|
||||
Block = Union[ca.Compound, ca.Case, ca.Default]
|
||||
if TYPE_CHECKING:
|
||||
# ca.Expression and ca.Statement don't actually exist, they live only in
|
||||
# the stubs file.
|
||||
Expression = ca.Expression
|
||||
Statement = ca.Statement
|
||||
else:
|
||||
Expression = Statement = None
|
||||
|
||||
|
||||
def to_c_raw(node: ca.Node) -> str:
|
||||
source: str = c_generator.CGenerator().visit(node)
|
||||
return source
|
||||
|
||||
|
||||
def to_c(node: ca.Node, *, from_import: bool = False) -> str:
|
||||
source = to_c_raw(node) if from_import else PatchedCGenerator().visit(node)
|
||||
return process_pragmas(source)
|
||||
|
||||
|
||||
def process_pragmas(source: str) -> str:
|
||||
if "#pragma" not in source:
|
||||
return source
|
||||
lines = source.split("\n")
|
||||
out: List[str] = []
|
||||
same_line = 0
|
||||
ignore = 0
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("#pragma _permuter "):
|
||||
# Expand permuter pragmas to nothing, by default. Still, keep one
|
||||
# output line per input line to preserve line numbers for import.py
|
||||
# error messages.
|
||||
line = ""
|
||||
|
||||
stripped = stripped[len("#pragma _permuter ") :]
|
||||
if stripped == "sameline start":
|
||||
same_line += 1
|
||||
elif stripped == "sameline end":
|
||||
same_line -= 1
|
||||
elif stripped == "latedefine start":
|
||||
ignore += 1
|
||||
elif stripped == "latedefine end":
|
||||
assert ignore > 0, "mismatched ignore pragmas"
|
||||
ignore -= 1
|
||||
elif stripped.startswith("define "):
|
||||
assert ignore > 0, "define pragma must be within latedefine block"
|
||||
line = "#" + stripped
|
||||
elif stripped.startswith("b64literal "):
|
||||
line = b64decode(stripped.split(" ", 1)[1]).decode("utf-8")
|
||||
elif ignore > 0:
|
||||
# Ignore non-pragma lines within latedefine section
|
||||
line = ""
|
||||
|
||||
if not same_line:
|
||||
line += "\n"
|
||||
elif line and out and not out[-1].endswith("\n"):
|
||||
line = " " + line.lstrip()
|
||||
out.append(line)
|
||||
assert same_line == 0
|
||||
assert ignore == 0, "unbalanced ignore pragmas"
|
||||
return "".join(out).rstrip() + "\n"
|
||||
|
||||
|
||||
class PatchedCGenerator(c_generator.CGenerator):
|
||||
"""Like a CGenerator, except it keeps else if's prettier despite
|
||||
the terrible things we've done to them in normalize_ast."""
|
||||
|
||||
def visit_If(self, n: ca.If) -> str:
|
||||
n2 = n
|
||||
if (
|
||||
n.iffalse
|
||||
and isinstance(n.iffalse, ca.Compound)
|
||||
and n.iffalse.block_items
|
||||
and len(n.iffalse.block_items) == 1
|
||||
and isinstance(n.iffalse.block_items[0], ca.If)
|
||||
):
|
||||
n2 = ca.If(cond=n.cond, iftrue=n.iftrue, iffalse=n.iffalse.block_items[0])
|
||||
return super().visit_If(n2) # type: ignore
|
||||
|
||||
|
||||
def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]:
|
||||
ret = []
|
||||
for i, node in enumerate(ast.ext):
|
||||
if isinstance(node, ca.FuncDef):
|
||||
if node.decl.name == fn_name:
|
||||
ret.append((node, i))
|
||||
else:
|
||||
node = node.decl
|
||||
ast.ext[i] = node
|
||||
if isinstance(node, ca.Decl) and isinstance(node.type, ca.FuncDecl):
|
||||
node.funcspec = [spec for spec in node.funcspec if spec != "static"]
|
||||
if len(ret) == 0:
|
||||
raise CandidateConstructionFailure(f"Function {fn_name} not found in base.c.")
|
||||
if len(ret) > 1:
|
||||
raise CandidateConstructionFailure(
|
||||
f"Found multiple copies of function {fn_name} in base.c."
|
||||
)
|
||||
return ret[0]
|
||||
|
||||
|
||||
def parse_c(source: str, *, from_import: bool = False) -> ca.FileAST:
|
||||
try:
|
||||
parser = CParser()
|
||||
return parser.parse(source, "<source>")
|
||||
except ParseError as e:
|
||||
msg = str(e)
|
||||
position, msg = msg.split(": ", 1)
|
||||
parts = position.split(":")
|
||||
if len(parts) >= 2:
|
||||
lineno = int(parts[1])
|
||||
posstr = f" at approximately line {lineno}"
|
||||
if len(parts) >= 3:
|
||||
posstr += f", column {parts[2]}"
|
||||
if not from_import:
|
||||
posstr += " (after PERM expansion)"
|
||||
try:
|
||||
line = source.split("\n")[lineno - 1].rstrip()
|
||||
posstr += "\n\n" + line
|
||||
except IndexError:
|
||||
posstr += "(out of bounds?)"
|
||||
else:
|
||||
posstr = ""
|
||||
raise CandidateConstructionFailure(
|
||||
f"Syntax error in base.c.\n{msg}{posstr}"
|
||||
) from None
|
||||
|
||||
|
||||
def compute_node_indices(top_node: ca.Node) -> Indices:
|
||||
starts: Dict[ca.Node, int] = {}
|
||||
ends: Dict[ca.Node, int] = {}
|
||||
cur_index = 1
|
||||
|
||||
class Visitor(ca.NodeVisitor):
|
||||
def generic_visit(self, node: ca.Node) -> None:
|
||||
nonlocal cur_index
|
||||
assert node not in starts, "nodes should only appear once in AST"
|
||||
starts[node] = cur_index
|
||||
cur_index += 2
|
||||
super().generic_visit(node)
|
||||
ends[node] = cur_index
|
||||
cur_index += 2
|
||||
|
||||
Visitor().visit(top_node)
|
||||
return Indices(starts, ends)
|
||||
|
||||
|
||||
def equal_ast(a: ca.Node, b: ca.Node) -> bool:
|
||||
def equal(a: Any, b: Any) -> bool:
|
||||
if type(a) != type(b):
|
||||
return False
|
||||
if a is None:
|
||||
return b is None
|
||||
if isinstance(a, list):
|
||||
assert isinstance(b, list)
|
||||
if len(a) != len(b):
|
||||
return False
|
||||
for i in range(len(a)):
|
||||
if not equal(a[i], b[i]):
|
||||
return False
|
||||
return True
|
||||
if isinstance(a, (int, str)):
|
||||
return bool(a == b)
|
||||
assert isinstance(a, ca.Node)
|
||||
for name in a.__slots__[:-2]: # type: ignore
|
||||
if not equal(getattr(a, name), getattr(b, name)):
|
||||
return False
|
||||
return True
|
||||
|
||||
return equal(a, b)
|
||||
|
||||
|
||||
def is_lvalue(expr: Expression) -> bool:
|
||||
if isinstance(expr, (ca.ID, ca.StructRef, ca.ArrayRef)):
|
||||
return True
|
||||
if isinstance(expr, ca.UnaryOp):
|
||||
return expr.op == "*"
|
||||
return False
|
||||
|
||||
|
||||
def is_effectful(expr: Expression) -> bool:
|
||||
found = False
|
||||
|
||||
class Visitor(ca.NodeVisitor):
|
||||
def visit_UnaryOp(self, node: ca.UnaryOp) -> None:
|
||||
nonlocal found
|
||||
if node.op in ["p++", "p--", "++", "--"]:
|
||||
found = True
|
||||
else:
|
||||
self.generic_visit(node.expr)
|
||||
|
||||
def visit_FuncCall(self, _: ca.Node) -> None:
|
||||
nonlocal found
|
||||
found = True
|
||||
|
||||
def visit_Assignment(self, _: ca.Node) -> None:
|
||||
nonlocal found
|
||||
found = True
|
||||
|
||||
Visitor().visit(expr)
|
||||
return found
|
||||
|
||||
|
||||
def get_block_stmts(block: Block, force: bool) -> List[Statement]:
|
||||
if isinstance(block, ca.Compound):
|
||||
ret = block.block_items or []
|
||||
if force and not block.block_items:
|
||||
block.block_items = ret
|
||||
else:
|
||||
ret = block.stmts or []
|
||||
if force and not block.stmts:
|
||||
block.stmts = ret
|
||||
return ret
|
||||
|
||||
|
||||
def insert_decl(
|
||||
fn: ca.FuncDef, var: str, type: SimpleType, random: Optional[Random] = None
|
||||
) -> None:
|
||||
type = copy.deepcopy(type)
|
||||
decl = ca.Decl(
|
||||
name=var, quals=[], storage=[], funcspec=[], type=type, init=None, bitsize=None
|
||||
)
|
||||
set_decl_name(decl)
|
||||
assert fn.body.block_items, "Non-empty function"
|
||||
for index, stmt in enumerate(fn.body.block_items):
|
||||
if not isinstance(stmt, ca.Decl):
|
||||
break
|
||||
else:
|
||||
index = len(fn.body.block_items)
|
||||
|
||||
if random:
|
||||
index = random.randint(0, index)
|
||||
fn.body.block_items[index:index] = [decl]
|
||||
|
||||
|
||||
def insert_statement(block: Block, index: int, stmt: Statement) -> None:
|
||||
stmts = get_block_stmts(block, True)
|
||||
stmts[index:index] = [stmt]
|
||||
|
||||
|
||||
def brace_nested_blocks(stmt: Statement) -> None:
|
||||
def brace(stmt: Statement) -> Block:
|
||||
if isinstance(stmt, (ca.Compound, ca.Case, ca.Default)):
|
||||
return stmt
|
||||
return ca.Compound([stmt])
|
||||
|
||||
if isinstance(stmt, (ca.For, ca.While, ca.DoWhile)):
|
||||
stmt.stmt = brace(stmt.stmt)
|
||||
elif isinstance(stmt, ca.If):
|
||||
stmt.iftrue = brace(stmt.iftrue)
|
||||
if stmt.iffalse:
|
||||
stmt.iffalse = brace(stmt.iffalse)
|
||||
elif isinstance(stmt, ca.Switch):
|
||||
stmt.stmt = brace(stmt.stmt)
|
||||
elif isinstance(stmt, ca.Label):
|
||||
brace_nested_blocks(stmt.stmt)
|
||||
|
||||
|
||||
def has_nested_block(node: ca.Node) -> bool:
|
||||
return isinstance(
|
||||
node,
|
||||
(
|
||||
ca.Compound,
|
||||
ca.For,
|
||||
ca.While,
|
||||
ca.DoWhile,
|
||||
ca.If,
|
||||
ca.Switch,
|
||||
ca.Case,
|
||||
ca.Default,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def for_nested_blocks(stmt: Statement, callback: Callable[[Block], None]) -> None:
|
||||
def invoke(stmt: Statement) -> None:
|
||||
assert isinstance(
|
||||
stmt, (ca.Compound, ca.Case, ca.Default)
|
||||
), "brace_nested_blocks should have turned nested statements into blocks"
|
||||
callback(stmt)
|
||||
|
||||
if isinstance(stmt, ca.Compound):
|
||||
invoke(stmt)
|
||||
elif isinstance(stmt, (ca.For, ca.While, ca.DoWhile)):
|
||||
invoke(stmt.stmt)
|
||||
elif isinstance(stmt, ca.If):
|
||||
if stmt.iftrue:
|
||||
invoke(stmt.iftrue)
|
||||
if stmt.iffalse:
|
||||
invoke(stmt.iffalse)
|
||||
elif isinstance(stmt, ca.Switch):
|
||||
invoke(stmt.stmt)
|
||||
elif isinstance(stmt, (ca.Case, ca.Default)):
|
||||
invoke(stmt)
|
||||
elif isinstance(stmt, ca.Label):
|
||||
for_nested_blocks(stmt.stmt, callback)
|
||||
|
||||
|
||||
def normalize_ast(fn: ca.FuncDef, ast: ca.FileAST) -> None:
|
||||
"""Add braces to all ifs/fors/etc., to make it easier to insert statements."""
|
||||
|
||||
def rec(block: Block) -> None:
|
||||
stmts = get_block_stmts(block, False)
|
||||
for stmt in stmts:
|
||||
brace_nested_blocks(stmt)
|
||||
for_nested_blocks(stmt, rec)
|
||||
|
||||
rec(fn.body)
|
||||
|
||||
|
||||
def prune_ast(fn: ca.FuncDef, ast: ca.FileAST) -> int:
|
||||
"""Prune away unnecessary parts of the AST, to reduce overhead from serialization
|
||||
and from the compiler's C parser."""
|
||||
|
||||
# Create a GC graph that maps names of declarations and enumerators to indices
|
||||
# in ast.ext, as well an initial list of GC roots, consisting of everything
|
||||
# that isn't a Decl and or Typedef.
|
||||
edges: Dict[str, List[int]] = defaultdict(list)
|
||||
gc_roots: List[int] = []
|
||||
can_fwd_declare_typedef: Set[str] = set()
|
||||
can_fwd_declare_tagged: Set[str] = set()
|
||||
|
||||
def add_type_edges(
|
||||
tp: Union["ca.Type", ca.Struct, ca.Union, ca.Enum], i: int
|
||||
) -> None:
|
||||
while isinstance(tp, (ca.PtrDecl, ca.ArrayDecl)):
|
||||
tp = tp.type
|
||||
if isinstance(tp, ca.FuncDecl):
|
||||
return
|
||||
inner_type = tp.type if isinstance(tp, ca.TypeDecl) else tp
|
||||
if isinstance(inner_type, ca.IdentifierType):
|
||||
return
|
||||
if inner_type.name:
|
||||
edges[inner_type.name].append(i)
|
||||
if isinstance(inner_type, ca.Enum) and inner_type.values:
|
||||
for value in inner_type.values.enumerators:
|
||||
edges[value.name].append(i)
|
||||
if isinstance(inner_type, (ca.Struct, ca.Union)) and inner_type.decls:
|
||||
for decl in inner_type.decls:
|
||||
if isinstance(decl, ca.Decl):
|
||||
add_type_edges(decl.type, i)
|
||||
|
||||
for i in range(len(ast.ext)):
|
||||
item = ast.ext[i]
|
||||
if isinstance(item, ca.Decl) and not item.init:
|
||||
# (Exclude declarations with initializers, since taking function
|
||||
# pointers can affect regalloc on IDO.)
|
||||
if item.name:
|
||||
edges[item.name].append(i)
|
||||
if isinstance(item.type, (ca.Struct, ca.Union, ca.Enum)) and item.type.name:
|
||||
can_fwd_declare_tagged.add(item.type.name)
|
||||
add_type_edges(item.type, i)
|
||||
elif isinstance(item, ca.Typedef):
|
||||
edges[item.name].append(i)
|
||||
if isinstance(item.type, ca.TypeDecl) and isinstance(
|
||||
item.type.type, (ca.Struct, ca.Union, ca.Enum)
|
||||
):
|
||||
can_fwd_declare_typedef.add(item.name)
|
||||
add_type_edges(item.type, i)
|
||||
elif isinstance(item, ca.Pragma) and "GLOBAL_ASM" in item.string:
|
||||
pass
|
||||
else:
|
||||
gc_roots.append(i)
|
||||
|
||||
mentioned_ids: Set[str] = set()
|
||||
|
||||
class IdVisitor(ca.NodeVisitor):
|
||||
def visit_Pragma(self, node: ca.Pragma) -> None:
|
||||
for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
|
||||
mentioned_ids.add(token)
|
||||
|
||||
def visit_ID(self, node: ca.ID) -> None:
|
||||
mentioned_ids.add(node.name)
|
||||
|
||||
IdVisitor().visit(ast)
|
||||
|
||||
# Do the GC as a DFS traversal of the graph. Visiting a node searches its
|
||||
# AST for all kinds of mentioned IDs, and adds more nodes to the stack
|
||||
# using the edges we found before.
|
||||
gc_todo: List[int] = gc_roots
|
||||
need_fwd_decl_typedef: List[str] = []
|
||||
need_fwd_decl_tagged: List[str] = []
|
||||
|
||||
def add_name(name: str) -> None:
|
||||
if name in edges:
|
||||
gc_todo.extend(edges[name])
|
||||
del edges[name]
|
||||
|
||||
class Visitor(ca.NodeVisitor):
|
||||
def visit_Pragma(self, node: ca.Pragma) -> None:
|
||||
for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
|
||||
add_name(token)
|
||||
|
||||
def visit_ID(self, node: ca.ID) -> None:
|
||||
add_name(node.name)
|
||||
|
||||
def visit_IdentifierType(self, node: ca.IdentifierType) -> None:
|
||||
for name in node.names:
|
||||
add_name(name)
|
||||
|
||||
def visit_Enum(self, node: ca.Enum) -> None:
|
||||
if node.name and not node.values:
|
||||
add_name(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Struct(self, node: ca.Struct) -> None:
|
||||
if node.name and not node.decls:
|
||||
add_name(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Union(self, node: ca.Union) -> None:
|
||||
if node.name and not node.decls:
|
||||
add_name(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_PtrDecl(self, node: ca.PtrDecl) -> None:
|
||||
# For pointer declarations which haven't been accessed, forward
|
||||
# declarations suffice.
|
||||
if (
|
||||
isinstance(node.type, ca.TypeDecl)
|
||||
and node.type.declname
|
||||
and node.type.declname not in mentioned_ids
|
||||
):
|
||||
tp = node.type.type
|
||||
if isinstance(tp, ca.IdentifierType):
|
||||
if all(name in can_fwd_declare_typedef for name in tp.names):
|
||||
need_fwd_decl_typedef.extend(tp.names)
|
||||
return
|
||||
elif tp.name and tp.name in can_fwd_declare_tagged:
|
||||
if not (tp.values if isinstance(tp, ca.Enum) else tp.decls):
|
||||
need_fwd_decl_tagged.append(tp.name)
|
||||
return
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_TypeDecl(self, node: ca.TypeDecl) -> None:
|
||||
if node.declname:
|
||||
add_name(node.declname)
|
||||
self.generic_visit(node)
|
||||
|
||||
keep_exts: Set[int] = set()
|
||||
while gc_todo:
|
||||
i = gc_todo.pop()
|
||||
if i not in keep_exts:
|
||||
keep_exts.add(i)
|
||||
Visitor().visit(ast.ext[i])
|
||||
|
||||
temp_id = 0
|
||||
|
||||
def fwd_declare(tp: Union[ca.Struct, ca.Union, ca.Enum]) -> None:
|
||||
nonlocal temp_id
|
||||
if not tp.name:
|
||||
temp_id += 1
|
||||
tp.name = f"_PermuterTemp{temp_id}"
|
||||
if isinstance(tp, (ca.Struct, ca.Union)):
|
||||
tp.decls = None
|
||||
elif isinstance(tp, ca.Enum):
|
||||
tp.values = None
|
||||
else:
|
||||
assert False
|
||||
|
||||
new_ext = []
|
||||
|
||||
for i, item in enumerate(ast.ext):
|
||||
if i in keep_exts:
|
||||
pass
|
||||
elif isinstance(item, ca.Typedef) and item.name in need_fwd_decl_typedef:
|
||||
assert item.name in can_fwd_declare_typedef
|
||||
assert isinstance(item.type, ca.TypeDecl)
|
||||
assert isinstance(item.type.type, (ca.Struct, ca.Union, ca.Enum))
|
||||
fwd_declare(item.type.type)
|
||||
elif (
|
||||
isinstance(item, ca.Decl)
|
||||
and isinstance(item.type, (ca.Struct, ca.Union, ca.Enum))
|
||||
and item.type.name
|
||||
and item.type.name in need_fwd_decl_tagged
|
||||
):
|
||||
assert item.type.name in can_fwd_declare_tagged
|
||||
fwd_declare(item.type)
|
||||
else:
|
||||
continue
|
||||
new_ext.append(item)
|
||||
|
||||
ast.ext = new_ext
|
||||
return ast.ext.index(fn)
|
||||
@@ -1,99 +0,0 @@
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import functools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from pycparser import c_ast as ca
|
||||
|
||||
from .compiler import Compiler
|
||||
from .randomizer import Randomizer
|
||||
from .scorer import Scorer
|
||||
from .perm.perm import EvalState
|
||||
from .perm.ast import apply_ast_perms
|
||||
from .helpers import try_remove
|
||||
from .profiler import Profiler
|
||||
from . import ast_util
|
||||
|
||||
|
||||
@dataclass
|
||||
class CandidateResult:
|
||||
"""Represents the result of scoring a candidate, and is sent from child to
|
||||
parent processes, or server to client with p@h."""
|
||||
|
||||
score: int
|
||||
hash: Optional[str]
|
||||
source: Optional[str]
|
||||
profiler: Optional[Profiler] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Candidate:
|
||||
"""
|
||||
Represents a AST candidate created from a source which can be randomized
|
||||
(possibly multiple times), compiled, and scored.
|
||||
"""
|
||||
|
||||
ast: ca.FileAST
|
||||
|
||||
fn_index: int
|
||||
rng_seed: int
|
||||
randomizer: Randomizer
|
||||
score_value: Optional[int] = field(init=False, default=None)
|
||||
score_hash: Optional[str] = field(init=False, default=None)
|
||||
_cache_source: Optional[str] = field(init=False, default=None)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def _cached_shared_ast(
|
||||
source: str, fn_name: str
|
||||
) -> Tuple[ca.FuncDef, int, ca.FileAST]:
|
||||
ast = ast_util.parse_c(source)
|
||||
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name)
|
||||
ast_util.normalize_ast(orig_fn, ast)
|
||||
return orig_fn, fn_index, ast
|
||||
|
||||
@staticmethod
|
||||
def from_source(
|
||||
source: str, eval_state: EvalState, fn_name: str, rng_seed: int
|
||||
) -> "Candidate":
|
||||
# Use the same AST for all instances of the same original source, but
|
||||
# with the target function deeply copied. Since we never change the
|
||||
# AST outside of the target function, this is fine, and it saves us
|
||||
# performance (deepcopy is really slow).
|
||||
orig_fn, fn_index, ast = Candidate._cached_shared_ast(source, fn_name)
|
||||
ast = copy.copy(ast)
|
||||
ast.ext = copy.copy(ast.ext)
|
||||
fn_copy = copy.deepcopy(orig_fn)
|
||||
ast.ext[fn_index] = fn_copy
|
||||
apply_ast_perms(fn_copy, eval_state)
|
||||
return Candidate(
|
||||
ast=ast,
|
||||
fn_index=fn_index,
|
||||
rng_seed=rng_seed,
|
||||
randomizer=Randomizer(rng_seed),
|
||||
)
|
||||
|
||||
def randomize_ast(self) -> None:
|
||||
self.randomizer.randomize(self.ast, self.fn_index)
|
||||
self._cache_source = None
|
||||
|
||||
def get_source(self) -> str:
|
||||
if self._cache_source is None:
|
||||
self._cache_source = ast_util.to_c(self.ast)
|
||||
return self._cache_source
|
||||
|
||||
def compile(self, compiler: Compiler, show_errors: bool = False) -> Optional[str]:
|
||||
source: str = self.get_source()
|
||||
return compiler.compile(source, show_errors=show_errors)
|
||||
|
||||
def score(self, scorer: Scorer, o_file: Optional[str]) -> CandidateResult:
|
||||
self.score_value = None
|
||||
self.score_hash = None
|
||||
try:
|
||||
self.score_value, self.score_hash = scorer.score(o_file)
|
||||
finally:
|
||||
if o_file:
|
||||
try_remove(o_file)
|
||||
return CandidateResult(
|
||||
score=self.score_value, hash=self.score_hash, source=self.get_source()
|
||||
)
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Optional
|
||||
import tempfile
|
||||
import subprocess
|
||||
|
||||
from .helpers import try_remove
|
||||
|
||||
|
||||
class Compiler:
|
||||
def __init__(self, compile_cmd: str, *, show_errors: bool) -> None:
|
||||
self.compile_cmd = compile_cmd
|
||||
self.show_errors = show_errors
|
||||
|
||||
def compile(self, source: str, *, show_errors: bool = False) -> Optional[str]:
|
||||
"""Try to compile a piece of C code. Returns the filename of the resulting .o
|
||||
temp file if it succeeds."""
|
||||
show_errors = show_errors or self.show_errors
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix="permuter", suffix=".c", mode="w", delete=False
|
||||
) as f:
|
||||
c_name = f.name
|
||||
f.write(source)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix="permuter", suffix=".o", delete=False
|
||||
) as f2:
|
||||
o_name = f2.name
|
||||
|
||||
try:
|
||||
stderr = 2 if show_errors else subprocess.DEVNULL
|
||||
subprocess.check_call(
|
||||
[self.compile_cmd, c_name, "-o", o_name],
|
||||
stdout=stderr,
|
||||
stderr=stderr,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
if not show_errors:
|
||||
try_remove(c_name)
|
||||
try_remove(o_name)
|
||||
return None
|
||||
except KeyboardInterrupt:
|
||||
# If Ctrl+C happens during this call, make a best effort in
|
||||
# removing the .c and .o files. This is totally racy, but oh well...
|
||||
try_remove(c_name)
|
||||
try_remove(o_name)
|
||||
raise
|
||||
|
||||
try_remove(c_name)
|
||||
return o_name
|
||||
@@ -1,11 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerError(Exception):
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CandidateConstructionFailure(Exception):
|
||||
message: str
|
||||
@@ -1,22 +0,0 @@
|
||||
import os
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
def plural(n: int, noun: str) -> str:
|
||||
s = "s" if n != 1 else ""
|
||||
return f"{n} {noun}{s}"
|
||||
|
||||
|
||||
def exception_to_string(e: object) -> str:
|
||||
return str(e) or e.__class__.__name__
|
||||
|
||||
|
||||
def static_assert_unreachable(x: NoReturn) -> NoReturn:
|
||||
raise Exception("Unreachable! " + repr(x))
|
||||
|
||||
|
||||
def try_remove(path: str) -> None:
|
||||
try:
|
||||
os.remove(path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -1,654 +0,0 @@
|
||||
import argparse
|
||||
from dataclasses import dataclass, field
|
||||
import itertools
|
||||
import multiprocessing
|
||||
from multiprocessing import Queue
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from .candidate import CandidateResult
|
||||
from .compiler import Compiler
|
||||
from .error import CandidateConstructionFailure
|
||||
from .helpers import plural, static_assert_unreachable
|
||||
from .net.client import start_client
|
||||
from .net.core import ServerError, connect, enable_debug_mode, MAX_PRIO, MIN_PRIO
|
||||
from .permuter import (
|
||||
EvalError,
|
||||
EvalResult,
|
||||
Feedback,
|
||||
Finished,
|
||||
Message,
|
||||
NeedMoreWork,
|
||||
Permuter,
|
||||
Task,
|
||||
WorkDone,
|
||||
)
|
||||
from .preprocess import preprocess
|
||||
from .printer import Printer
|
||||
from .profiler import Profiler
|
||||
from .scorer import Scorer
|
||||
|
||||
# The probability that the randomizer continues transforming the output it
|
||||
# generated last time.
|
||||
DEFAULT_RAND_KEEP_PROB = 0.6
|
||||
|
||||
|
||||
@dataclass
|
||||
class Options:
|
||||
directories: List[str]
|
||||
show_errors: bool = False
|
||||
show_timings: bool = False
|
||||
print_diffs: bool = False
|
||||
stack_differences: bool = False
|
||||
abort_exceptions: bool = False
|
||||
better_only: bool = False
|
||||
best_only: bool = False
|
||||
quiet: bool = False
|
||||
stop_on_zero: bool = False
|
||||
keep_prob: float = DEFAULT_RAND_KEEP_PROB
|
||||
force_seed: Optional[str] = None
|
||||
threads: int = 1
|
||||
use_network: bool = False
|
||||
network_debug: bool = False
|
||||
network_priority: float = 1.0
|
||||
|
||||
|
||||
def restricted_float(lo: float, hi: float) -> Callable[[str], float]:
|
||||
def convert(x: str) -> float:
|
||||
try:
|
||||
ret = float(x)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"invalid float value: '{x}'")
|
||||
|
||||
if ret < lo or ret > hi:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"value {x} is out of range (must be between {lo} and {hi})"
|
||||
)
|
||||
return ret
|
||||
|
||||
return convert
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalContext:
|
||||
options: Options
|
||||
printer: Printer = field(default_factory=Printer)
|
||||
iteration: int = 0
|
||||
errors: int = 0
|
||||
overall_profiler: Profiler = field(default_factory=Profiler)
|
||||
permuters: List[Permuter] = field(default_factory=list)
|
||||
|
||||
|
||||
def write_candidate(perm: Permuter, result: CandidateResult) -> None:
|
||||
"""Write the candidate's C source and score to the next output directory"""
|
||||
ctr = 0
|
||||
while True:
|
||||
ctr += 1
|
||||
try:
|
||||
output_dir = os.path.join(perm.dir, f"output-{result.score}-{ctr}")
|
||||
os.mkdir(output_dir)
|
||||
break
|
||||
except FileExistsError:
|
||||
pass
|
||||
source = result.source
|
||||
assert source is not None, "Permuter._need_to_send_source is wrong!"
|
||||
with open(os.path.join(output_dir, "source.c"), "x", encoding="utf-8") as f:
|
||||
f.write(source)
|
||||
with open(os.path.join(output_dir, "score.txt"), "x", encoding="utf-8") as f:
|
||||
f.write(f"{result.score}\n")
|
||||
with open(os.path.join(output_dir, "diff.txt"), "x", encoding="utf-8") as f:
|
||||
f.write(perm.diff(source) + "\n")
|
||||
print(f"wrote to {output_dir}")
|
||||
|
||||
|
||||
def post_score(
|
||||
context: EvalContext, permuter: Permuter, result: EvalResult, who: Optional[str]
|
||||
) -> bool:
|
||||
if isinstance(result, EvalError):
|
||||
if result.exc_str is not None:
|
||||
context.printer.print(
|
||||
"internal permuter failure.", permuter, who, keep_progress=True
|
||||
)
|
||||
print(result.exc_str)
|
||||
if result.seed is not None:
|
||||
seed_str = str(result.seed[1])
|
||||
if result.seed[0] != 0:
|
||||
seed_str = f"{result.seed[0]},{seed_str}"
|
||||
print(f"To reproduce the failure, rerun with: --seed {seed_str}")
|
||||
if context.options.abort_exceptions:
|
||||
sys.exit(1)
|
||||
else:
|
||||
return False
|
||||
|
||||
if context.options.print_diffs:
|
||||
assert result.source is not None, "Permuter._need_to_send_source is wrong"
|
||||
print()
|
||||
print(permuter.diff(result.source))
|
||||
input("Press any key to continue...")
|
||||
|
||||
profiler = result.profiler
|
||||
score_value = result.score
|
||||
|
||||
if profiler is not None:
|
||||
for stattype in profiler.time_stats:
|
||||
context.overall_profiler.add_stat(stattype, profiler.time_stats[stattype])
|
||||
|
||||
context.iteration += 1
|
||||
if score_value == permuter.scorer.PENALTY_INF:
|
||||
disp_score = "inf"
|
||||
context.errors += 1
|
||||
else:
|
||||
disp_score = str(score_value)
|
||||
timings = ""
|
||||
if context.options.show_timings:
|
||||
timings = " \t" + context.overall_profiler.get_str_stats()
|
||||
status_line = f"iteration {context.iteration}, {context.errors} errors, score = {disp_score}{timings}"
|
||||
|
||||
if permuter.should_output(result):
|
||||
former_best = permuter.best_score
|
||||
permuter.record_result(result)
|
||||
if score_value < former_best:
|
||||
color = "\u001b[32;1m"
|
||||
msg = f"found new best score! ({score_value} vs {permuter.base_score})"
|
||||
elif score_value == former_best:
|
||||
color = "\u001b[32;1m"
|
||||
msg = f"tied best score! ({score_value} vs {permuter.base_score})"
|
||||
elif score_value < permuter.base_score:
|
||||
color = "\u001b[33m"
|
||||
msg = f"found a better score! ({score_value} vs {permuter.base_score})"
|
||||
else:
|
||||
color = "\u001b[33m"
|
||||
msg = f"found different asm with same score ({score_value})"
|
||||
context.printer.print(msg, permuter, who, color=color)
|
||||
|
||||
write_candidate(permuter, result)
|
||||
if not context.options.quiet:
|
||||
context.printer.progress(status_line)
|
||||
return score_value == 0
|
||||
|
||||
|
||||
def cycle_seeds(permuters: List[Permuter]) -> Iterable[Tuple[int, int]]:
|
||||
"""
|
||||
Return all possible (permuter index, seed) pairs, cycling over permuters.
|
||||
If a permuter is randomized, it will keep repeating seeds infinitely.
|
||||
"""
|
||||
iterators: List[Iterator[Tuple[int, int]]] = []
|
||||
for perm_ind, permuter in enumerate(permuters):
|
||||
it = permuter.seed_iterator()
|
||||
iterators.append(zip(itertools.repeat(perm_ind), it))
|
||||
|
||||
i = 0
|
||||
while iterators:
|
||||
i %= len(iterators)
|
||||
item = next(iterators[i], None)
|
||||
if item is None:
|
||||
del iterators[i]
|
||||
i -= 1
|
||||
else:
|
||||
yield item
|
||||
i += 1
|
||||
|
||||
|
||||
def multiprocess_worker(
|
||||
permuters: List[Permuter],
|
||||
input_queue: "Queue[Task]",
|
||||
output_queue: "Queue[Feedback]",
|
||||
) -> None:
|
||||
try:
|
||||
while True:
|
||||
# Read a work item from the queue. If none is immediately available,
|
||||
# tell the main thread to fill the queues more, and then block on
|
||||
# the queue.
|
||||
queue_item: Task
|
||||
try:
|
||||
queue_item = input_queue.get(block=False)
|
||||
except queue.Empty:
|
||||
output_queue.put((NeedMoreWork(), -1, None))
|
||||
queue_item = input_queue.get()
|
||||
if isinstance(queue_item, Finished):
|
||||
output_queue.put((queue_item, -1, None))
|
||||
output_queue.close()
|
||||
break
|
||||
permuter_index, seed = queue_item
|
||||
permuter = permuters[permuter_index]
|
||||
result = permuter.try_eval_candidate(seed)
|
||||
if isinstance(result, CandidateResult) and permuter.should_output(result):
|
||||
permuter.record_result(result)
|
||||
output_queue.put((WorkDone(permuter_index, result), -1, None))
|
||||
output_queue.put((NeedMoreWork(), -1, None))
|
||||
except KeyboardInterrupt:
|
||||
# Don't clutter the output with stack traces; Ctrl+C is the expected
|
||||
# way to quit and sends KeyboardInterrupt to all processes.
|
||||
# A heartbeat thing here would be good but is too complex.
|
||||
# Don't join the queue background thread -- thread joins in relation
|
||||
# to KeyboardInterrupt usually result in deadlocks.
|
||||
input_queue.cancel_join_thread()
|
||||
output_queue.cancel_join_thread()
|
||||
|
||||
|
||||
def run(options: Options) -> List[int]:
|
||||
last_time = time.time()
|
||||
try:
|
||||
|
||||
def heartbeat() -> None:
|
||||
nonlocal last_time
|
||||
last_time = time.time()
|
||||
|
||||
return run_inner(options, heartbeat)
|
||||
except KeyboardInterrupt:
|
||||
if time.time() - last_time > 5:
|
||||
print()
|
||||
print("Aborting stuck process.")
|
||||
raise
|
||||
print()
|
||||
print("Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def run_inner(options: Options, heartbeat: Callable[[], None]) -> List[int]:
|
||||
print("Loading...")
|
||||
|
||||
context = EvalContext(options)
|
||||
|
||||
force_seed: Optional[int] = None
|
||||
force_rng_seed: Optional[int] = None
|
||||
if options.force_seed:
|
||||
seed_parts = list(map(int, options.force_seed.split(",")))
|
||||
force_rng_seed = seed_parts[-1]
|
||||
force_seed = 0 if len(seed_parts) == 1 else seed_parts[0]
|
||||
|
||||
name_counts: Dict[str, int] = {}
|
||||
for i, d in enumerate(options.directories):
|
||||
heartbeat()
|
||||
compile_cmd = os.path.join(d, "compile.sh")
|
||||
target_o = os.path.join(d, "target.o")
|
||||
base_c = os.path.join(d, "base.c")
|
||||
for fname in [compile_cmd, target_o, base_c]:
|
||||
if not os.path.isfile(fname):
|
||||
print(f"Missing file {fname}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if not os.stat(compile_cmd).st_mode & 0o100:
|
||||
print(f"{compile_cmd} must be marked executable.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
fn_name: Optional[str] = None
|
||||
try:
|
||||
with open(os.path.join(d, "function.txt"), encoding="utf-8") as f:
|
||||
fn_name = f.read().strip()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
if fn_name:
|
||||
print(f"{base_c} ({fn_name})")
|
||||
else:
|
||||
print(base_c)
|
||||
|
||||
compiler = Compiler(compile_cmd, show_errors=options.show_errors)
|
||||
scorer = Scorer(target_o, stack_differences=options.stack_differences)
|
||||
c_source = preprocess(base_c)
|
||||
|
||||
try:
|
||||
permuter = Permuter(
|
||||
d,
|
||||
fn_name,
|
||||
compiler,
|
||||
scorer,
|
||||
base_c,
|
||||
c_source,
|
||||
force_seed=force_seed,
|
||||
force_rng_seed=force_rng_seed,
|
||||
keep_prob=options.keep_prob,
|
||||
need_profiler=options.show_timings,
|
||||
need_all_sources=options.print_diffs,
|
||||
show_errors=options.show_errors,
|
||||
best_only=options.best_only,
|
||||
better_only=options.better_only,
|
||||
)
|
||||
except CandidateConstructionFailure as e:
|
||||
print(e.message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
context.permuters.append(permuter)
|
||||
name_counts[permuter.fn_name] = name_counts.get(permuter.fn_name, 0) + 1
|
||||
print()
|
||||
|
||||
if not context.permuters:
|
||||
print("No permuters!")
|
||||
return []
|
||||
|
||||
for permuter in context.permuters:
|
||||
if name_counts[permuter.fn_name] > 1:
|
||||
permuter.unique_name += f" ({permuter.dir})"
|
||||
print(f"[{permuter.unique_name}] base score = {permuter.best_score}")
|
||||
|
||||
found_zero = False
|
||||
if options.threads == 1 and not options.use_network:
|
||||
# Simple single-threaded mode. This is not technically needed, but
|
||||
# makes the permuter easier to debug.
|
||||
for permuter_index, seed in cycle_seeds(context.permuters):
|
||||
heartbeat()
|
||||
permuter = context.permuters[permuter_index]
|
||||
result = permuter.try_eval_candidate(seed)
|
||||
if post_score(context, permuter, result, None):
|
||||
found_zero = True
|
||||
if options.stop_on_zero:
|
||||
break
|
||||
else:
|
||||
seed_iterators: List[Optional[Iterator[int]]] = [
|
||||
permuter.seed_iterator()
|
||||
for perm_ind, permuter in enumerate(context.permuters)
|
||||
]
|
||||
seed_iterators_remaining = len(seed_iterators)
|
||||
next_iterator_index = 0
|
||||
|
||||
# Create queues.
|
||||
worker_task_queue: "Queue[Task]" = Queue()
|
||||
feedback_queue: "Queue[Feedback]" = Queue()
|
||||
|
||||
# Connect to network and create client threads and queues.
|
||||
net_conns: "List[Tuple[threading.Thread, Queue[Task]]]" = []
|
||||
if options.use_network:
|
||||
print("Connecting to permuter@home...")
|
||||
if options.network_debug:
|
||||
enable_debug_mode()
|
||||
first_stats: Optional[Tuple[int, int, float]] = None
|
||||
for perm_index in range(len(context.permuters)):
|
||||
try:
|
||||
port = connect()
|
||||
except (EOFError, ServerError) as e:
|
||||
print("Error:", e)
|
||||
sys.exit(1)
|
||||
thread, queue, stats = start_client(
|
||||
port,
|
||||
context.permuters[perm_index],
|
||||
perm_index,
|
||||
feedback_queue,
|
||||
options.network_priority,
|
||||
)
|
||||
net_conns.append((thread, queue))
|
||||
if first_stats is None:
|
||||
first_stats = stats
|
||||
assert first_stats is not None, "has at least one permuter"
|
||||
clients_str = plural(first_stats[0], "other client")
|
||||
servers_str = plural(first_stats[1], "server")
|
||||
cores_str = plural(int(first_stats[2]), "core")
|
||||
print(f"Connected! {servers_str} online ({cores_str}, {clients_str})")
|
||||
|
||||
# Start local worker threads
|
||||
processes: List[multiprocessing.Process] = []
|
||||
for i in range(options.threads):
|
||||
p = multiprocessing.Process(
|
||||
target=multiprocess_worker,
|
||||
args=(context.permuters, worker_task_queue, feedback_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
active_workers = len(processes)
|
||||
|
||||
if not active_workers and not net_conns:
|
||||
print("No workers available! Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
def process_finish(finish: Finished, source: int) -> None:
|
||||
nonlocal active_workers
|
||||
|
||||
if finish.reason:
|
||||
permuter: Optional[Permuter] = None
|
||||
if source != -1 and len(context.permuters) > 1:
|
||||
permuter = context.permuters[source]
|
||||
context.printer.print(finish.reason, permuter, None, keep_progress=True)
|
||||
|
||||
if source == -1:
|
||||
active_workers -= 1
|
||||
|
||||
def process_result(work: WorkDone, who: Optional[str]) -> bool:
|
||||
permuter = context.permuters[work.perm_index]
|
||||
return post_score(context, permuter, work.result, who)
|
||||
|
||||
def get_task(perm_index: int) -> Optional[Tuple[int, int]]:
|
||||
nonlocal next_iterator_index, seed_iterators_remaining
|
||||
if perm_index == -1:
|
||||
while seed_iterators_remaining > 0:
|
||||
task = get_task(next_iterator_index)
|
||||
next_iterator_index += 1
|
||||
next_iterator_index %= len(seed_iterators)
|
||||
if task is not None:
|
||||
return task
|
||||
else:
|
||||
it = seed_iterators[perm_index]
|
||||
if it is not None:
|
||||
seed = next(it, None)
|
||||
if seed is None:
|
||||
seed_iterators[perm_index] = None
|
||||
seed_iterators_remaining -= 1
|
||||
else:
|
||||
return (perm_index, seed)
|
||||
return None
|
||||
|
||||
# Feed the task queue with work and read from results queue.
|
||||
# We generally match these up one-by-one to avoid overfilling queues,
|
||||
# but workers can ask us to add more tasks into the system if they run
|
||||
# out of work. (This will happen e.g. at the very beginning, when the
|
||||
# queues are empty.)
|
||||
while seed_iterators_remaining > 0:
|
||||
heartbeat()
|
||||
feedback, source, who = feedback_queue.get()
|
||||
if isinstance(feedback, Finished):
|
||||
process_finish(feedback, source)
|
||||
elif isinstance(feedback, Message):
|
||||
context.printer.print(feedback.text, None, who, keep_progress=True)
|
||||
elif isinstance(feedback, WorkDone):
|
||||
if process_result(feedback, who):
|
||||
# Found score 0!
|
||||
found_zero = True
|
||||
if options.stop_on_zero:
|
||||
break
|
||||
elif isinstance(feedback, NeedMoreWork):
|
||||
task = get_task(source)
|
||||
if task is not None:
|
||||
if source == -1:
|
||||
worker_task_queue.put(task)
|
||||
else:
|
||||
net_conns[source][1].put(task)
|
||||
else:
|
||||
static_assert_unreachable(feedback)
|
||||
|
||||
# Signal workers to stop.
|
||||
for i in range(active_workers):
|
||||
worker_task_queue.put(Finished())
|
||||
|
||||
for conn in net_conns:
|
||||
conn[1].put(Finished())
|
||||
|
||||
# Await final results.
|
||||
while active_workers > 0 or net_conns:
|
||||
heartbeat()
|
||||
feedback, source, who = feedback_queue.get()
|
||||
if isinstance(feedback, Finished):
|
||||
process_finish(feedback, source)
|
||||
elif isinstance(feedback, Message):
|
||||
context.printer.print(feedback.text, None, who, keep_progress=True)
|
||||
elif isinstance(feedback, WorkDone):
|
||||
if not (options.stop_on_zero and found_zero):
|
||||
if process_result(feedback, who):
|
||||
found_zero = True
|
||||
elif isinstance(feedback, NeedMoreWork):
|
||||
pass
|
||||
else:
|
||||
static_assert_unreachable(feedback)
|
||||
|
||||
# Wait for workers to finish.
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# Wait for network connections to close (currently does not happen).
|
||||
for conn in net_conns:
|
||||
conn[0].join()
|
||||
|
||||
if found_zero:
|
||||
print("\nFound zero score! Exiting.")
|
||||
return [permuter.best_score for permuter in context.permuters]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
multiprocessing.freeze_support()
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
# Ideally we would do:
|
||||
# multiprocessing.set_start_method("spawn")
|
||||
# here, to make multiprocessing behave the same across operating systems.
|
||||
# However, that means that arguments to Process are passed across using
|
||||
# pickling, which mysteriously breaks with pycparser...
|
||||
# (AttributeError: 'CParser' object has no attribute 'p_abstract_declarator_opt')
|
||||
# So, for now we live with the defaults, which make multiprocessing work on Linux,
|
||||
# where it uses fork and doesn't pickle arguments, and break on Windows. Sigh.
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Randomly permute C files to better match a target binary."
|
||||
)
|
||||
parser.add_argument(
|
||||
"directories",
|
||||
nargs="+",
|
||||
metavar="directory",
|
||||
help="Directory containing base.c, target.o and compile.sh. Multiple directories may be given.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-errors",
|
||||
dest="show_errors",
|
||||
action="store_true",
|
||||
help="Display compiler error/warning messages, and keep .c files for failed compiles.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-timings",
|
||||
dest="show_timings",
|
||||
action="store_true",
|
||||
help="Display the time taken by permuting vs. compiling vs. scoring.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-diffs",
|
||||
dest="print_diffs",
|
||||
action="store_true",
|
||||
help="Instead of compiling generated sources, display diffs against a base version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--abort-exceptions",
|
||||
dest="abort_exceptions",
|
||||
action="store_true",
|
||||
help="Stop execution when an internal permuter exception occurs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--better-only",
|
||||
dest="better_only",
|
||||
action="store_true",
|
||||
help="Only report scores better than the base.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-only",
|
||||
dest="best_only",
|
||||
action="store_true",
|
||||
help="Only report ties or new high scores.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop-on-zero",
|
||||
dest="stop_on_zero",
|
||||
action="store_true",
|
||||
help="Stop after producing an output with score 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quiet",
|
||||
dest="quiet",
|
||||
action="store_true",
|
||||
help="Don't print a status line with the number of iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stack-diffs",
|
||||
dest="stack_differences",
|
||||
action="store_true",
|
||||
help="Take stack differences into account when computing the score.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-prob",
|
||||
dest="keep_prob",
|
||||
metavar="PROB",
|
||||
type=restricted_float(0.0, 1.0),
|
||||
default=DEFAULT_RAND_KEEP_PROB,
|
||||
help="""Continue randomizing the previous output with the given probability
|
||||
(float in 0..1, default %(default)s).""",
|
||||
)
|
||||
parser.add_argument("--seed", dest="force_seed", type=str, help=argparse.SUPPRESS)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
dest="threads",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of own threads to use (default: 1 without -J, 0 with -J).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-J",
|
||||
dest="use_network",
|
||||
action="store_true",
|
||||
help="Harness extra compute power through cyberspace (permuter@home).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pah-debug",
|
||||
dest="network_debug",
|
||||
action="store_true",
|
||||
help="Enable debug prints for permuter@home.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--priority",
|
||||
dest="network_priority",
|
||||
metavar="PRIORITY",
|
||||
type=restricted_float(MIN_PRIO, MAX_PRIO),
|
||||
default=1.0,
|
||||
help=f"""Proportion of server resources to use when multiple people
|
||||
are using -J at the same time.
|
||||
Defaults to 1.0, meaning resources are split equally, but can be
|
||||
set to any value within [{MIN_PRIO}, {MAX_PRIO}].
|
||||
Each server runs with a priority threshold, which defaults to 0.1,
|
||||
below which they will not run permuter jobs at all.""",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
threads = args.threads
|
||||
if not threads and not args.use_network:
|
||||
threads = 1
|
||||
|
||||
options = Options(
|
||||
directories=args.directories,
|
||||
show_errors=args.show_errors,
|
||||
show_timings=args.show_timings,
|
||||
print_diffs=args.print_diffs,
|
||||
abort_exceptions=args.abort_exceptions,
|
||||
better_only=args.better_only,
|
||||
best_only=args.best_only,
|
||||
quiet=args.quiet,
|
||||
stack_differences=args.stack_differences,
|
||||
stop_on_zero=args.stop_on_zero,
|
||||
keep_prob=args.keep_prob,
|
||||
force_seed=args.force_seed,
|
||||
threads=threads,
|
||||
use_network=args.use_network,
|
||||
network_debug=args.network_debug,
|
||||
network_priority=args.network_priority,
|
||||
)
|
||||
|
||||
run(options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,272 +0,0 @@
|
||||
from multiprocessing import Queue
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
import zlib
|
||||
|
||||
from ..candidate import CandidateResult
|
||||
from ..helpers import exception_to_string
|
||||
from ..permuter import (
|
||||
EvalError,
|
||||
EvalResult,
|
||||
Feedback,
|
||||
FeedbackItem,
|
||||
Finished,
|
||||
Message,
|
||||
NeedMoreWork,
|
||||
Permuter,
|
||||
Task,
|
||||
WorkDone,
|
||||
)
|
||||
from ..profiler import Profiler
|
||||
from .core import (
|
||||
PermuterData,
|
||||
SocketPort,
|
||||
json_prop,
|
||||
permuter_data_to_json,
|
||||
)
|
||||
|
||||
|
||||
def _profiler_from_json(obj: dict) -> Profiler:
|
||||
ret = Profiler()
|
||||
for key in obj:
|
||||
assert isinstance(key, str), "json properties are strings"
|
||||
stat = Profiler.StatType[key]
|
||||
time = json_prop(obj, key, float)
|
||||
ret.add_stat(stat, time)
|
||||
return ret
|
||||
|
||||
|
||||
def _result_from_json(obj: dict, source: Optional[str]) -> EvalResult:
|
||||
if "error" in obj:
|
||||
return EvalError(exc_str=json_prop(obj, "error", str), seed=None)
|
||||
|
||||
profiler: Optional[Profiler] = None
|
||||
if "profiler" in obj:
|
||||
profiler = _profiler_from_json(json_prop(obj, "profiler", dict))
|
||||
return CandidateResult(
|
||||
score=json_prop(obj, "score", int),
|
||||
hash=json_prop(obj, "hash", str) if "hash" in obj else None,
|
||||
source=source,
|
||||
profiler=profiler,
|
||||
)
|
||||
|
||||
|
||||
def _make_script_portable(source: str) -> str:
|
||||
"""Parse a shell script and get rid of the machine-specific parts that
|
||||
import.py introduces. The resulting script must be run in an environment
|
||||
that has the right binaries in its $PATH, and with a current working
|
||||
directory similar to where import.py found its target's make root."""
|
||||
lines = []
|
||||
for line in source.split("\n"):
|
||||
if re.match("cd '?/", line):
|
||||
# Skip cd's to absolute directory paths. Note that shlex quotes
|
||||
# its argument with ' if it contains spaces/single quotes.
|
||||
continue
|
||||
if re.match("'?/", line):
|
||||
quote = "'" if line[0] == "'" else ""
|
||||
ind = line.find(quote + " ")
|
||||
if ind == -1:
|
||||
ind = len(line)
|
||||
else:
|
||||
ind += len(quote)
|
||||
lastind = line.rfind("/", 0, ind)
|
||||
assert lastind != -1
|
||||
# Emit a call to "which" as the first part, to ensure the called
|
||||
# binary still sees an absolute path. qemu-irix requires this,
|
||||
# for some reason.
|
||||
line = "$(which " + quote + line[lastind + 1 : ind] + ")" + line[ind:]
|
||||
lines.append(line)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def make_portable_permuter(permuter: Permuter) -> PermuterData:
|
||||
with open(permuter.scorer.target_o, "rb") as f:
|
||||
target_o_bin = f.read()
|
||||
|
||||
with open(permuter.compiler.compile_cmd, "r") as f2:
|
||||
compile_script = _make_script_portable(f2.read())
|
||||
|
||||
return PermuterData(
|
||||
base_score=permuter.base_score,
|
||||
base_hash=permuter.base_hash,
|
||||
fn_name=permuter.fn_name,
|
||||
filename=permuter.source_file,
|
||||
keep_prob=permuter.keep_prob,
|
||||
need_profiler=permuter.need_profiler,
|
||||
stack_differences=permuter.scorer.stack_differences,
|
||||
compile_script=compile_script,
|
||||
source=permuter.source,
|
||||
target_o_bin=target_o_bin,
|
||||
)
|
||||
|
||||
|
||||
class Connection:
|
||||
_port: SocketPort
|
||||
_permuter_data: PermuterData
|
||||
_perm_index: int
|
||||
_task_queue: "Queue[Task]"
|
||||
_feedback_queue: "Queue[Feedback]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: SocketPort,
|
||||
permuter_data: PermuterData,
|
||||
perm_index: int,
|
||||
task_queue: "Queue[Task]",
|
||||
feedback_queue: "Queue[Feedback]",
|
||||
) -> None:
|
||||
self._port = port
|
||||
self._permuter_data = permuter_data
|
||||
self._perm_index = perm_index
|
||||
self._task_queue = task_queue
|
||||
self._feedback_queue = feedback_queue
|
||||
|
||||
def _send_permuter(self) -> None:
|
||||
data = self._permuter_data
|
||||
self._port.send_json(permuter_data_to_json(data))
|
||||
self._port.send(zlib.compress(data.source.encode("utf-8")))
|
||||
self._port.send(zlib.compress(data.target_o_bin))
|
||||
|
||||
def _feedback(self, feedback: FeedbackItem, server_nick: Optional[str]) -> None:
|
||||
self._feedback_queue.put((feedback, self._perm_index, server_nick))
|
||||
|
||||
def _receive_one(self) -> bool:
|
||||
"""Receive a result/progress message and send it on. Returns true if
|
||||
more work should be requested."""
|
||||
msg = self._port.receive_json()
|
||||
msg_type = json_prop(msg, "type", str)
|
||||
if msg_type == "need_work":
|
||||
return True
|
||||
|
||||
server_nick = json_prop(msg, "server", str)
|
||||
if msg_type == "init_done":
|
||||
base_hash = json_prop(msg, "hash", str)
|
||||
my_base_hash = self._permuter_data.base_hash
|
||||
text = "connected"
|
||||
if base_hash != my_base_hash:
|
||||
text += " (note: mismatching hash)"
|
||||
self._feedback(Message(text), server_nick)
|
||||
return True
|
||||
|
||||
if msg_type == "init_failed":
|
||||
text = "failed to initialize: " + json_prop(msg, "reason", str)
|
||||
self._feedback(Message(text), server_nick)
|
||||
return False
|
||||
|
||||
if msg_type == "disconnect":
|
||||
self._feedback(Message("disconnected"), server_nick)
|
||||
return False
|
||||
|
||||
if msg_type == "result":
|
||||
source: Optional[str] = None
|
||||
if msg.get("has_source") == True:
|
||||
# Source is sent separately, compressed, since it can be
|
||||
# large (hundreds of kilobytes is not uncommon).
|
||||
compressed_source = self._port.receive()
|
||||
try:
|
||||
source = zlib.decompress(compressed_source).decode("utf-8")
|
||||
except Exception as e:
|
||||
text = "failed to decompress: " + exception_to_string(e)
|
||||
self._feedback(Message(text), server_nick)
|
||||
return True
|
||||
try:
|
||||
result = _result_from_json(msg, source)
|
||||
self._feedback(WorkDone(self._perm_index, result), server_nick)
|
||||
except Exception as e:
|
||||
text = "failed to parse result message: " + exception_to_string(e)
|
||||
self._feedback(Message(text), server_nick)
|
||||
return True
|
||||
|
||||
raise ValueError(f"Invalid message type {msg_type}")
|
||||
|
||||
def run(self) -> None:
|
||||
finish_reason: Optional[str] = None
|
||||
try:
|
||||
self._send_permuter()
|
||||
self._port.receive_json()
|
||||
|
||||
finished = False
|
||||
|
||||
# Main loop: send messages from the queue on to the server, and
|
||||
# vice versa. Currently we are being lazy and alternate between
|
||||
# sending and receiving; this is nicely simple and keeps us on a
|
||||
# single thread, however it could cause deadlocks if the server
|
||||
# receiver stops reading because we aren't reading fast enough.
|
||||
while True:
|
||||
if not self._receive_one():
|
||||
continue
|
||||
self._feedback(NeedMoreWork(), None)
|
||||
|
||||
# Read a task and send it on, unless there are no more tasks.
|
||||
if not finished:
|
||||
task = self._task_queue.get()
|
||||
if isinstance(task, Finished):
|
||||
# We don't have a way of indicating to the server that
|
||||
# all is done: the server currently doesn't track
|
||||
# outstanding work so it doesn't know when to close
|
||||
# the connection. (Even with this fixed we'll have the
|
||||
# problem that servers may disconnect, losing work, so
|
||||
# the task never truly finishes. But it might work well
|
||||
# enough in practice.)
|
||||
finished = True
|
||||
else:
|
||||
work = {
|
||||
"type": "work",
|
||||
"work": {
|
||||
"seed": task[1],
|
||||
},
|
||||
}
|
||||
self._port.send_json(work)
|
||||
|
||||
except EOFError:
|
||||
finish_reason = "disconnected from permuter@home"
|
||||
|
||||
except Exception as e:
|
||||
errmsg = exception_to_string(e)
|
||||
finish_reason = f"permuter@home error: {errmsg}"
|
||||
|
||||
finally:
|
||||
self._feedback(Finished(reason=finish_reason), None)
|
||||
self._port.shutdown()
|
||||
self._port.close()
|
||||
|
||||
|
||||
def start_client(
|
||||
port: SocketPort,
|
||||
permuter: Permuter,
|
||||
perm_index: int,
|
||||
feedback_queue: "Queue[Feedback]",
|
||||
priority: float,
|
||||
) -> "Tuple[threading.Thread, Queue[Task], Tuple[int, int, float]]":
|
||||
port.send_json(
|
||||
{
|
||||
"method": "connect_client",
|
||||
"priority": priority,
|
||||
}
|
||||
)
|
||||
obj = port.receive_json()
|
||||
if "error" in obj:
|
||||
err = json_prop(obj, "error", str)
|
||||
# TODO use another exception type
|
||||
raise Exception(f"Failed to connect: {err}")
|
||||
num_servers = json_prop(obj, "servers", int)
|
||||
num_clients = json_prop(obj, "clients", int)
|
||||
num_cores = json_prop(obj, "cores", float)
|
||||
permuter_data = make_portable_permuter(permuter)
|
||||
task_queue: "Queue[Task]" = Queue()
|
||||
|
||||
conn = Connection(
|
||||
port,
|
||||
permuter_data,
|
||||
perm_index,
|
||||
task_queue,
|
||||
feedback_queue,
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=conn.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
stats = (num_clients, num_servers, num_cores)
|
||||
|
||||
return thread, task_queue, stats
|
||||
@@ -1,17 +0,0 @@
|
||||
import abc
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
|
||||
class Command(abc.ABC):
|
||||
command: str
|
||||
help: str
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def add_arguments(parser: ArgumentParser) -> None:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def run(args: Namespace) -> None:
|
||||
...
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 112 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 101 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 110 KiB |
@@ -1,70 +0,0 @@
|
||||
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||||
import sys
|
||||
|
||||
from ..core import ServerError, enable_debug_mode
|
||||
from .run_server import RunServerCommand
|
||||
from .setup import SetupCommand
|
||||
from .ping import PingCommand
|
||||
from .vouch import VouchCommand
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
# We currently sometimes log stuff to stdout, so it's preferable if it's
|
||||
# line-buffered even when redirected to a non-tty (e.g. when running a
|
||||
# permuter server as a systemd service). This is supported by Python 3.7
|
||||
# and up.
|
||||
sys.stdout.reconfigure(line_buffering=True) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
parser = ArgumentParser(
|
||||
description="permuter@home - run the permuter across the Internet!\n\n"
|
||||
"To use p@h as a client, just pass -J when running the permuter. "
|
||||
"This script is\nonly necessary for configuration or when running a server.",
|
||||
formatter_class=RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
commands = [
|
||||
PingCommand,
|
||||
RunServerCommand,
|
||||
SetupCommand,
|
||||
VouchCommand,
|
||||
]
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
dest="debug",
|
||||
action="store_true",
|
||||
help="Enable debug logging.",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(metavar="<command>")
|
||||
for command in commands:
|
||||
subparser = subparsers.add_parser(
|
||||
command.command,
|
||||
help=command.help,
|
||||
description=command.help,
|
||||
)
|
||||
command.add_arguments(subparser)
|
||||
subparser.set_defaults(subcommand_handler=command.run)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.debug:
|
||||
enable_debug_mode()
|
||||
|
||||
if "subcommand_handler" in args:
|
||||
try:
|
||||
args.subcommand_handler(args)
|
||||
except EOFError as e:
|
||||
print("Network error:", e)
|
||||
sys.exit(1)
|
||||
except ServerError as e:
|
||||
print("Error:", e.message)
|
||||
sys.exit(1)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,32 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import time
|
||||
|
||||
from ...helpers import plural
|
||||
from ..core import connect, json_prop
|
||||
from .base import Command
|
||||
|
||||
|
||||
class PingCommand(Command):
|
||||
command = "ping"
|
||||
help = "Check server connectivity."
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser: ArgumentParser) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def run(args: Namespace) -> None:
|
||||
run_ping()
|
||||
|
||||
|
||||
def run_ping() -> None:
|
||||
port = connect()
|
||||
t0 = time.time()
|
||||
port.send_json({"method": "ping"})
|
||||
msg = port.receive_json()
|
||||
rtt = (time.time() - t0) * 1000
|
||||
print(f"Connected successfully! Round-trip time: {rtt:.1f} ms")
|
||||
servers_str = plural(json_prop(msg, "servers", int), "server")
|
||||
clients_str = plural(json_prop(msg, "clients", int), "client")
|
||||
cores_str = plural(int(json_prop(msg, "cores", float)), "core")
|
||||
print(f"{servers_str} online ({cores_str}, {clients_str})")
|
||||
@@ -1,616 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import queue
|
||||
import random
|
||||
import shutil
|
||||
from subprocess import Popen, PIPE
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ...helpers import static_assert_unreachable
|
||||
from ..core import CancelToken, ServerError, read_config
|
||||
from ..server import (
|
||||
Client,
|
||||
Config,
|
||||
IoActivity,
|
||||
IoConnect,
|
||||
IoDisconnect,
|
||||
IoImmediateDisconnect,
|
||||
IoReconnect,
|
||||
IoServerFailed,
|
||||
IoShutdown,
|
||||
IoUserRemovePermuter,
|
||||
IoWorkDone,
|
||||
PermuterHandle,
|
||||
Server,
|
||||
ServerOptions,
|
||||
)
|
||||
from .base import Command
|
||||
from .util import ask
|
||||
|
||||
|
||||
class RunServerCommand(Command):
|
||||
command = "run-server"
|
||||
help = """Run a permuter server, allowing anyone with access to the central
|
||||
server to run sandboxed permuter jobs on your machine. Requires docker."""
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--cores",
|
||||
dest="num_cores",
|
||||
metavar="CORES",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Number of cores to use (float).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--memory",
|
||||
dest="max_memory_gb",
|
||||
metavar="MEMORY_GB",
|
||||
type=float,
|
||||
required=True,
|
||||
help="""Restrict the sandboxed process to the given amount of memory in
|
||||
gigabytes (float). If this limit is hit, the permuter will crash
|
||||
horribly, but at least your system won't lock up.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--systray",
|
||||
dest="systray",
|
||||
action="store_true",
|
||||
help="""Make the server controllable through the system tray.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-priority",
|
||||
dest="min_priority",
|
||||
metavar="PRIORITY",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="""Only accept jobs from clients who pass --priority with a number
|
||||
higher or equal to this value. (default: %(default)s)""",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def run(args: Namespace) -> None:
|
||||
options = ServerOptions(
|
||||
num_cores=args.num_cores,
|
||||
max_memory_gb=args.max_memory_gb,
|
||||
min_priority=args.min_priority,
|
||||
)
|
||||
|
||||
server_main(options, args.systray)
|
||||
|
||||
|
||||
class SystrayState:
|
||||
def server_reconnecting(self) -> None:
|
||||
pass
|
||||
|
||||
def server_connected(self) -> None:
|
||||
pass
|
||||
|
||||
def server_failed(self, graceful: bool, message: Optional[str] = None) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None:
|
||||
pass
|
||||
|
||||
def disconnect(self, handle: PermuterHandle) -> None:
|
||||
pass
|
||||
|
||||
def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Permuter:
|
||||
nickname: str
|
||||
fn_name: str
|
||||
iterations: int = 0
|
||||
improvements: int = 0
|
||||
last_systray_update: float = 0.0
|
||||
slot: "Optional[ClientSlot]" = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientSlot:
|
||||
menu_id: int
|
||||
iterations_id: int
|
||||
improvements_id: int
|
||||
stop_id: int
|
||||
permuter: Optional[PermuterHandle] = None
|
||||
|
||||
|
||||
class SystrayStatus(Enum):
|
||||
CONNECTING = 0
|
||||
CONNECTED = 1
|
||||
FAILED = 2
|
||||
RECONNECTING = 3
|
||||
|
||||
|
||||
class RealSystrayState(SystrayState):
|
||||
_CLIENT_SLOTS = 10
|
||||
_UPDATE_INTERVAL = 2.0
|
||||
_MENU_TOOLTIP = "permuter@home"
|
||||
_permuters: Dict[PermuterHandle, Permuter]
|
||||
_onclick: Dict[int, Callable[[], None]]
|
||||
_client_slots: List[ClientSlot]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
io_queue: "queue.Queue[IoActivity]",
|
||||
) -> None:
|
||||
self._io_queue = io_queue
|
||||
self._permuters = {}
|
||||
self._onclick = {}
|
||||
self._status = SystrayStatus.CONNECTING
|
||||
self._fail_message: Optional[str] = None
|
||||
|
||||
def load_icon(fname: str) -> str:
|
||||
path = os.path.join(os.path.dirname(__file__), "icons", fname)
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
return base64.b64encode(data).decode("ascii")
|
||||
|
||||
self._icons = {
|
||||
"working": load_icon("okthink.ico"),
|
||||
"passive": load_icon("ok.ico"),
|
||||
"fail": load_icon("notok.ico"),
|
||||
}
|
||||
self._current_icon = "working"
|
||||
|
||||
next_id = 100
|
||||
|
||||
def add_item(
|
||||
menu: List[dict],
|
||||
title: str,
|
||||
onclick: Optional[Callable[[], None]] = None,
|
||||
*,
|
||||
submenu: Optional[List[dict]] = None,
|
||||
hidden: bool = False,
|
||||
) -> int:
|
||||
nonlocal next_id
|
||||
next_id += 1
|
||||
obj = {
|
||||
"title": title,
|
||||
"enabled": onclick is not None or submenu is not None,
|
||||
"hidden": hidden,
|
||||
"__id": next_id,
|
||||
}
|
||||
if onclick is not None:
|
||||
self._onclick[next_id] = onclick
|
||||
if submenu is not None:
|
||||
obj["items"] = submenu
|
||||
menu.append(obj)
|
||||
return next_id
|
||||
|
||||
menu: List[dict] = []
|
||||
self._status_id = add_item(menu, "Connecting...")
|
||||
self._client_slots = []
|
||||
for i in range(self._CLIENT_SLOTS):
|
||||
submenu: List[dict] = []
|
||||
remove_cb = partial(self._remove_permuter, i)
|
||||
self._client_slots.append(
|
||||
ClientSlot(
|
||||
iterations_id=add_item(submenu, ""),
|
||||
improvements_id=add_item(submenu, ""),
|
||||
stop_id=add_item(submenu, "Stop", remove_cb),
|
||||
menu_id=add_item(menu, "", submenu=submenu, hidden=True),
|
||||
)
|
||||
)
|
||||
self._more_id = add_item(menu, "", hidden=True)
|
||||
add_item(menu, "Quit", self._quit)
|
||||
|
||||
try:
|
||||
path = self._setup_helper()
|
||||
self._proc = Popen(
|
||||
[path],
|
||||
stdout=PIPE,
|
||||
stdin=PIPE,
|
||||
universal_newlines=True,
|
||||
)
|
||||
assert self._proc.stdout is not None
|
||||
self._proc_stdout = self._proc.stdout
|
||||
assert self._proc.stdin is not None
|
||||
self._proc_stdin = self._proc.stdin
|
||||
|
||||
self._send(
|
||||
{
|
||||
"icon": self._icons[self._current_icon],
|
||||
"tooltip": self._MENU_TOOLTIP,
|
||||
"items": menu,
|
||||
}
|
||||
)
|
||||
|
||||
resp_str = self._proc_stdout.readline()
|
||||
assert resp_str
|
||||
resp = json.loads(resp_str)
|
||||
assert isinstance(resp, dict)
|
||||
assert resp.get("type") == "ready"
|
||||
except Exception:
|
||||
print("Failed to initialize systray!")
|
||||
print()
|
||||
print("See src/net/cmd/systray/README.md for details on how to set it up.")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
self._read_thread = threading.Thread(target=self._read_loop, daemon=True)
|
||||
self._read_thread.start()
|
||||
|
||||
@staticmethod
|
||||
def _setup_helper() -> str:
|
||||
fname = "permuter-systray"
|
||||
suffix = ""
|
||||
osname = sys.platform.replace("darwin", "macos")
|
||||
arch = platform.machine().replace("AMD64", "x86_64")
|
||||
if (
|
||||
osname in ("win32", "msys", "cygwin")
|
||||
or "microsoft" in platform.uname().release.lower()
|
||||
):
|
||||
osname = "win"
|
||||
suffix = ".exe"
|
||||
|
||||
dir = os.path.join(os.path.dirname(__file__), "systray")
|
||||
target_binary = os.path.join(dir, fname + suffix)
|
||||
if os.path.exists(target_binary):
|
||||
return target_binary
|
||||
|
||||
prebuilt_file = f"{fname}-{osname}-{arch}{suffix}"
|
||||
prebuilt_file = os.path.join(dir, "prebuilt", prebuilt_file)
|
||||
|
||||
print("An external helper binary is required for systray support.")
|
||||
print(
|
||||
"To build it from source (requires Go), see src/net/cmd/systray/README.md."
|
||||
)
|
||||
|
||||
if os.path.exists(prebuilt_file):
|
||||
print("Alternatively, a pre-built binary can be used.")
|
||||
if ask("Use pre-built binary?", default=False):
|
||||
shutil.copy(prebuilt_file, target_binary)
|
||||
os.chmod(target_binary, 0o755)
|
||||
return target_binary
|
||||
|
||||
print("Aborting.")
|
||||
sys.exit(1)
|
||||
|
||||
def _send(self, msg: dict) -> None:
|
||||
data = json.dumps(msg)
|
||||
self._proc_stdin.write(data + "\n")
|
||||
self._proc_stdin.flush()
|
||||
|
||||
def _update_item(
|
||||
self, id: int, title: str, *, hidden: bool = False, enabled: bool = False
|
||||
) -> None:
|
||||
self._send(
|
||||
{
|
||||
"type": "update-item",
|
||||
"item": {
|
||||
"title": title,
|
||||
"enabled": enabled,
|
||||
"hidden": hidden,
|
||||
"__id": id,
|
||||
},
|
||||
"seq_id": -1,
|
||||
}
|
||||
)
|
||||
|
||||
def _remove_permuter(self, slot_index: int) -> None:
|
||||
slot = self._client_slots[slot_index]
|
||||
if not slot.permuter:
|
||||
return
|
||||
handle = slot.permuter
|
||||
self._io_queue.put((None, (handle, IoUserRemovePermuter())))
|
||||
|
||||
def _quit(self) -> None:
|
||||
self._io_queue.put((None, IoShutdown()))
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
while True:
|
||||
resp_str = self._proc_stdout.readline()
|
||||
if not resp_str:
|
||||
break
|
||||
try:
|
||||
resp = json.loads(resp_str)
|
||||
except Exception:
|
||||
raise Exception(f"Failed to parse systray JSON: {resp_str}") from None
|
||||
if resp["type"] == "clicked":
|
||||
id = resp["__id"]
|
||||
if id in self._onclick:
|
||||
self._onclick[id]()
|
||||
|
||||
def _permuter_slot(self, perm: Permuter) -> Optional[ClientSlot]:
|
||||
for slot in self._client_slots:
|
||||
if slot.permuter is not None and self._permuters[slot.permuter] is perm:
|
||||
return slot
|
||||
return None
|
||||
|
||||
def _update_permuter(self, perm: Permuter, slot: ClientSlot) -> None:
|
||||
self._update_item(
|
||||
slot.iterations_id,
|
||||
f"Iterations: {perm.iterations}",
|
||||
)
|
||||
self._update_item(
|
||||
slot.improvements_id,
|
||||
f"Improvements found: {perm.improvements}",
|
||||
)
|
||||
|
||||
def _update_status(self) -> None:
|
||||
if self._status == SystrayStatus.CONNECTING:
|
||||
status = "Reconnecting..."
|
||||
icon = "working"
|
||||
elif self._status == SystrayStatus.RECONNECTING:
|
||||
status = "Disconnected, will reconnect..."
|
||||
icon = "fail"
|
||||
elif self._status == SystrayStatus.CONNECTED:
|
||||
if self._permuters:
|
||||
status = "Currently permuting:"
|
||||
icon = "working"
|
||||
else:
|
||||
status = "Not running"
|
||||
icon = "passive"
|
||||
elif self._status == SystrayStatus.FAILED:
|
||||
if self._fail_message:
|
||||
status = f"Error: {self._fail_message}"
|
||||
else:
|
||||
status = "Error occurred"
|
||||
icon = "fail"
|
||||
else:
|
||||
assert False, f"bad status {self._status}"
|
||||
|
||||
self._update_item(self._status_id, status)
|
||||
if self._current_icon != icon:
|
||||
self._current_icon = icon
|
||||
self._send(
|
||||
{
|
||||
"type": "update-menu",
|
||||
"menu": {
|
||||
"tooltip": self._MENU_TOOLTIP,
|
||||
"icon": self._icons[icon],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def _fill_slots(self) -> None:
|
||||
has_more = False
|
||||
while True:
|
||||
key = next((k for k, p in self._permuters.items() if p.slot is None), None)
|
||||
if key is None:
|
||||
break
|
||||
chosen_slot: Optional[ClientSlot] = None
|
||||
for i in range(self._CLIENT_SLOTS - 1, -1, -1):
|
||||
slot = self._client_slots[i]
|
||||
if slot.permuter is None:
|
||||
chosen_slot = slot
|
||||
elif chosen_slot is not None:
|
||||
break
|
||||
if chosen_slot is None:
|
||||
has_more = True
|
||||
break
|
||||
perm = self._permuters[key]
|
||||
perm.slot = chosen_slot
|
||||
chosen_slot.permuter = key
|
||||
self._update_permuter(perm, chosen_slot)
|
||||
self._update_item(
|
||||
chosen_slot.menu_id, f"{perm.fn_name} ({perm.nickname})", enabled=True
|
||||
)
|
||||
self._update_item(self._more_id, "More...", hidden=not has_more)
|
||||
|
||||
def _hide_slot(self, slot: ClientSlot) -> None:
|
||||
if slot.permuter is not None:
|
||||
self._update_item(slot.menu_id, "", hidden=True)
|
||||
slot.permuter = None
|
||||
|
||||
def server_reconnecting(self) -> None:
|
||||
self._status = SystrayStatus.CONNECTING
|
||||
self._update_status()
|
||||
|
||||
def server_connected(self) -> None:
|
||||
self._status = SystrayStatus.CONNECTED
|
||||
self._update_status()
|
||||
|
||||
def server_failed(self, graceful: bool, message: Optional[str] = None) -> None:
|
||||
self._status = SystrayStatus.RECONNECTING if graceful else SystrayStatus.FAILED
|
||||
self._fail_message = message
|
||||
self._permuters = {}
|
||||
self._update_status()
|
||||
for slot in self._client_slots:
|
||||
self._hide_slot(slot)
|
||||
self._fill_slots()
|
||||
|
||||
def connect(self, handle: PermuterHandle, nickname: str, fn_name: str) -> None:
|
||||
perm = Permuter(nickname, fn_name)
|
||||
self._permuters[handle] = perm
|
||||
self._fill_slots()
|
||||
self._update_status()
|
||||
|
||||
def disconnect(self, handle: PermuterHandle) -> None:
|
||||
slot = self._permuters[handle].slot
|
||||
del self._permuters[handle]
|
||||
self._update_status()
|
||||
if slot:
|
||||
self._hide_slot(slot)
|
||||
self._fill_slots()
|
||||
|
||||
def work_done(self, handle: PermuterHandle, is_improvement: bool) -> None:
|
||||
perm = self._permuters[handle]
|
||||
perm.iterations += 1
|
||||
if is_improvement:
|
||||
perm.improvements += 1
|
||||
if perm.slot and time.time() > perm.last_systray_update + self._UPDATE_INTERVAL:
|
||||
perm.last_systray_update = time.time()
|
||||
self._update_permuter(perm, perm.slot)
|
||||
|
||||
def stop(self) -> None:
|
||||
try:
|
||||
self._send({"type": "exit"})
|
||||
except BrokenPipeError:
|
||||
# The systray process may have been killed by Ctrl+C.
|
||||
pass
|
||||
self._proc.wait()
|
||||
self._read_thread.join()
|
||||
|
||||
|
||||
class Reconnector:
|
||||
_RESET_BACKOFF_AFTER_UPTIME: float = 60.0
|
||||
_RANDOM_ADDEND_MAX: float = 60.0
|
||||
_BACKOFF_MULTIPLIER: float = 2.0
|
||||
_INITIAL_DELAY: float = 5.0
|
||||
|
||||
_io_queue: "queue.Queue[IoActivity]"
|
||||
_reconnect_token: CancelToken
|
||||
_reconnect_delay: float
|
||||
_reconnect_timer: Optional[threading.Timer]
|
||||
_start_time: float
|
||||
_stop_time: float
|
||||
|
||||
def __init__(self, io_queue: "queue.Queue[IoActivity]") -> None:
|
||||
self._io_queue = io_queue
|
||||
self._reconnect_token = CancelToken()
|
||||
self._reconnect_delay = self._INITIAL_DELAY
|
||||
self._reconnect_timer = None
|
||||
self._start_time = self._stop_time = time.time()
|
||||
|
||||
def mark_start(self) -> None:
|
||||
self._start_time = time.time()
|
||||
|
||||
def mark_stop(self) -> None:
|
||||
self._stop_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._reconnect_token.cancelled = True
|
||||
if self._reconnect_timer is not None:
|
||||
self._reconnect_timer.cancel()
|
||||
self._reconnect_timer.join()
|
||||
self._reconnect_timer = None
|
||||
|
||||
def reconnect_eventually(self) -> int:
|
||||
if self._stop_time - self._start_time > self._RESET_BACKOFF_AFTER_UPTIME:
|
||||
delay = self._reconnect_delay = self._INITIAL_DELAY
|
||||
else:
|
||||
delay = self._reconnect_delay
|
||||
self._reconnect_delay = (
|
||||
self._reconnect_delay * self._BACKOFF_MULTIPLIER
|
||||
+ random.uniform(1.0, self._RANDOM_ADDEND_MAX)
|
||||
)
|
||||
token = CancelToken()
|
||||
self._reconnect_token = token
|
||||
self._reconnect_timer = threading.Timer(
|
||||
delay, lambda: self._io_queue.put((token, IoReconnect()))
|
||||
)
|
||||
self._reconnect_timer.daemon = True
|
||||
self._reconnect_timer.start()
|
||||
return int(delay)
|
||||
|
||||
|
||||
def main_loop(
|
||||
io_queue: "queue.Queue[IoActivity]",
|
||||
server: Server,
|
||||
systray: SystrayState,
|
||||
) -> None:
|
||||
reconnector = Reconnector(io_queue)
|
||||
handle_clients: Dict[PermuterHandle, Client] = {}
|
||||
while True:
|
||||
token, activity = io_queue.get()
|
||||
if token and token.cancelled:
|
||||
continue
|
||||
|
||||
if not isinstance(activity, tuple):
|
||||
if isinstance(activity, IoShutdown):
|
||||
break
|
||||
|
||||
elif isinstance(activity, IoReconnect):
|
||||
print("reconnecting...")
|
||||
try:
|
||||
systray.server_reconnecting()
|
||||
reconnector.mark_start()
|
||||
server.start()
|
||||
systray.server_connected()
|
||||
except EOFError:
|
||||
delay = reconnector.reconnect_eventually()
|
||||
print(f"failed again, reconnecting in {delay} seconds...")
|
||||
systray.server_failed(True)
|
||||
except ServerError as e:
|
||||
print("failed!", e.message)
|
||||
systray.server_failed(False, e.message)
|
||||
except Exception:
|
||||
print("failed!")
|
||||
traceback.print_exc()
|
||||
systray.server_failed(False)
|
||||
|
||||
elif isinstance(activity, IoServerFailed):
|
||||
if activity.message:
|
||||
print("Server error:", activity.message)
|
||||
print("disconnected from permuter@home")
|
||||
server.stop()
|
||||
reconnector.mark_stop()
|
||||
systray.server_failed(activity.graceful, activity.message)
|
||||
|
||||
if activity.graceful:
|
||||
delay = reconnector.reconnect_eventually()
|
||||
print(f"will reconnect in {delay} seconds...")
|
||||
|
||||
else:
|
||||
static_assert_unreachable(activity)
|
||||
|
||||
else:
|
||||
handle, msg = activity
|
||||
|
||||
if isinstance(msg, IoConnect):
|
||||
client = msg.client
|
||||
handle_clients[handle] = client
|
||||
systray.connect(handle, client.nickname, msg.fn_name)
|
||||
print(f"[{client.nickname}] connected ({msg.fn_name})")
|
||||
|
||||
elif isinstance(msg, IoDisconnect):
|
||||
systray.disconnect(handle)
|
||||
nickname = handle_clients[handle].nickname
|
||||
del handle_clients[handle]
|
||||
print(f"[{nickname}] {msg.reason}")
|
||||
|
||||
elif isinstance(msg, IoImmediateDisconnect):
|
||||
print(f"[{msg.client.nickname}] {msg.reason}")
|
||||
|
||||
elif isinstance(msg, IoWorkDone):
|
||||
# TODO: statistics
|
||||
systray.work_done(handle, msg.is_improvement)
|
||||
|
||||
elif isinstance(msg, IoUserRemovePermuter):
|
||||
server.remove_permuter(handle)
|
||||
|
||||
else:
|
||||
static_assert_unreachable(msg)
|
||||
|
||||
|
||||
def server_main(options: ServerOptions, use_systray: bool) -> None:
|
||||
io_queue: "queue.Queue[IoActivity]" = queue.Queue()
|
||||
config = read_config()
|
||||
|
||||
systray: SystrayState
|
||||
if use_systray:
|
||||
systray = RealSystrayState(config, io_queue)
|
||||
else:
|
||||
systray = SystrayState()
|
||||
|
||||
try:
|
||||
server = Server(options, config, io_queue)
|
||||
server.start()
|
||||
|
||||
try:
|
||||
systray.server_connected()
|
||||
main_loop(io_queue, server, systray)
|
||||
finally:
|
||||
server.stop()
|
||||
finally:
|
||||
systray.stop()
|
||||
@@ -1,86 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from nacl.public import SealedBox
|
||||
from nacl.signing import SigningKey, VerifyKey
|
||||
|
||||
from .base import Command
|
||||
from ..core import connect, read_config, sign_with_magic, write_config
|
||||
from .util import ask
|
||||
|
||||
|
||||
class SetupCommand(Command):
|
||||
command = "setup"
|
||||
help = """Set up permuter@home. This will require someone else to grant you
|
||||
access to the central server."""
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser: ArgumentParser) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def run(args: Namespace) -> None:
|
||||
_run_initial_setup()
|
||||
|
||||
|
||||
def _random_name() -> str:
|
||||
return "".join(random.choice(string.ascii_lowercase) for _ in range(5))
|
||||
|
||||
|
||||
def _run_initial_setup() -> None:
|
||||
config = read_config()
|
||||
signing_key: Optional[SigningKey] = config.signing_key
|
||||
if not signing_key or not ask("Keep previous secret key", default=True):
|
||||
signing_key = SigningKey.generate()
|
||||
config.signing_key = signing_key
|
||||
write_config(config)
|
||||
verify_key = signing_key.verify_key
|
||||
|
||||
nickname: Optional[str] = config.initial_setup_nickname
|
||||
if not nickname or not ask(f"Keep previous nickname [{nickname}]", default=True):
|
||||
default_nickname = os.environ.get("USER") or _random_name()
|
||||
nickname = (
|
||||
input(f"Nickname [default: {default_nickname}]: ") or default_nickname
|
||||
)
|
||||
config.initial_setup_nickname = nickname
|
||||
write_config(config)
|
||||
|
||||
signed_nickname = sign_with_magic(b"NAME", signing_key, nickname.encode("utf-8"))
|
||||
|
||||
vouch_data = verify_key.encode() + signed_nickname
|
||||
vouch_text = base64.b64encode(vouch_data).decode("utf-8")
|
||||
print("Ask someone to run the following command:")
|
||||
print(f"./pah.py vouch {vouch_text}")
|
||||
print()
|
||||
print("They should give you a token back in return. Paste that here:")
|
||||
inp = input().strip()
|
||||
|
||||
try:
|
||||
token = base64.b64decode(inp.encode("utf-8"))
|
||||
data = SealedBox(signing_key.to_curve25519_private_key()).decrypt(token)
|
||||
config.server_address = data[32:].decode("utf-8")
|
||||
config.server_verify_key = VerifyKey(data[:32])
|
||||
config.initial_setup_nickname = None
|
||||
except Exception:
|
||||
print("Invalid token!")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Server: {config.server_address}")
|
||||
print("Testing connection...")
|
||||
|
||||
port = connect(config)
|
||||
port.send_json({"method": "ping"})
|
||||
port.receive_json()
|
||||
|
||||
try:
|
||||
write_config(config)
|
||||
except Exception as e:
|
||||
print("Failed to write config:", e)
|
||||
sys.exit(1)
|
||||
|
||||
print("permuter@home successfully set up!")
|
||||
@@ -1,2 +0,0 @@
|
||||
permuter-systray
|
||||
permuter-systray.exe
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2017 Zack Young
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,13 +0,0 @@
|
||||
# systray
|
||||
|
||||
This directory contains a Go application that shows a system tray, which the Python code interacts with.
|
||||
|
||||
It is a fork of https://github.com/felixhao28/systray-portable.
|
||||
|
||||
To build it:
|
||||
|
||||
- install Go
|
||||
- if on Linux, install dependencies: `libgtk-3-dev`, `libappindicator3-dev`
|
||||
- run `go build`
|
||||
|
||||
If on Windows, this needs to be done *outside* of WSL.
|
||||
@@ -1,7 +0,0 @@
|
||||
module permuter-systray
|
||||
|
||||
go 1.15
|
||||
|
||||
require github.com/getlantern/systray v1.1.0
|
||||
|
||||
replace github.com/getlantern/systray v1.1.0 => github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56
|
||||
@@ -1,4 +0,0 @@
|
||||
github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56 h1:UZcM1HdV25CQhhJD340jxRLRGl0V11V0wIoUDKTOZMI=
|
||||
github.com/simonlindholm/systray v1.1.1-0.20210502122945-b7c77212cd56/go.mod h1:N5dpnnWiJhCxh+gXuNgDS2p5MjgcVR/TGwWuaDc4gLk=
|
||||
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 h1:YTzHMGlqJu67/uEo1lBv0n3wBXhXNeUbB1XfN2vmTm0=
|
||||
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,287 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"reflect"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
)
|
||||
|
||||
func main() {
|
||||
systray.Run(onReady, onExit)
|
||||
}
|
||||
|
||||
func onExit() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Item represents an item in the menu
|
||||
type Item struct {
|
||||
Icon string `json:"icon"`
|
||||
Title string `json:"title"`
|
||||
Tooltip string `json:"tooltip"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Checked bool `json:"checked"`
|
||||
Hidden bool `json:"hidden"`
|
||||
Items []Item `json:"items"`
|
||||
InternalID int `json:"__id"`
|
||||
}
|
||||
|
||||
// Menu has an icon, title and list of items
|
||||
type Menu struct {
|
||||
Icon string `json:"icon"`
|
||||
Title string `json:"title"`
|
||||
Tooltip string `json:"tooltip"`
|
||||
Items []Item `json:"items"`
|
||||
}
|
||||
|
||||
// Action for an item?..
|
||||
type Action struct {
|
||||
Type string `json:"type"`
|
||||
Item Item `json:"item"`
|
||||
Menu Menu `json:"menu"`
|
||||
}
|
||||
|
||||
// ClickEvent for an click event
|
||||
type ClickEvent struct {
|
||||
Type string `json:"type"`
|
||||
InternalID int `json:"__id"`
|
||||
}
|
||||
|
||||
func readJSON(reader *bufio.Reader, v interface{}) error {
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(input) < 1 {
|
||||
return fmt.Errorf("Empty line")
|
||||
}
|
||||
|
||||
lineReader := strings.NewReader(input[0 : len(input)-1])
|
||||
if err := json.NewDecoder(lineReader).Decode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addMenuItem(items *[]*systray.MenuItem, seqID2InternalID *[]int, internalID2SeqID *map[int]int, item *Item, parent *systray.MenuItem) {
|
||||
if item.Title == "<SEPARATOR>" {
|
||||
systray.AddSeparator()
|
||||
*items = append(*items, nil)
|
||||
} else {
|
||||
var menuItem *systray.MenuItem
|
||||
if parent == nil {
|
||||
menuItem = systray.AddMenuItem(item.Title, item.Tooltip)
|
||||
} else {
|
||||
menuItem = parent.AddSubMenuItem(item.Title, item.Tooltip)
|
||||
}
|
||||
if item.Checked {
|
||||
menuItem.Check()
|
||||
} else {
|
||||
menuItem.Uncheck()
|
||||
}
|
||||
if item.Enabled {
|
||||
menuItem.Enable()
|
||||
} else {
|
||||
menuItem.Disable()
|
||||
}
|
||||
if len(item.Icon) > 0 {
|
||||
icon, err := base64.StdEncoding.DecodeString(item.Icon)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
} else {
|
||||
menuItem.SetIcon(icon)
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(item.Items); i++ {
|
||||
subitem := item.Items[i]
|
||||
addMenuItem(items, seqID2InternalID, internalID2SeqID, &subitem, menuItem)
|
||||
}
|
||||
if item.Hidden {
|
||||
menuItem.Hide()
|
||||
}
|
||||
*items = append(*items, menuItem)
|
||||
}
|
||||
seqID := len(*items) - 1
|
||||
(*internalID2SeqID)[item.InternalID] = seqID
|
||||
*seqID2InternalID = append(*seqID2InternalID, item.InternalID)
|
||||
}
|
||||
|
||||
func onReady() {
|
||||
signalChannel := make(chan os.Signal, 2)
|
||||
signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
for sig := range signalChannel {
|
||||
switch sig {
|
||||
case os.Interrupt, syscall.SIGTERM:
|
||||
// handle SIGINT, SIGTERM
|
||||
fmt.Fprintln(os.Stderr, "Quit")
|
||||
systray.Quit()
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "Unhandled signal:", sig)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
items := make([]*systray.MenuItem, 0)
|
||||
seqID2InternalID := make([]int, 0)
|
||||
internalID2SeqID := make(map[int]int)
|
||||
fmt.Println(`{"type": "ready"}`)
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
var menu Menu
|
||||
if err := readJSON(reader, &menu); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
|
||||
icon, err := base64.StdEncoding.DecodeString(menu.Icon)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTitle(menu.Title)
|
||||
systray.SetTooltip(menu.Tooltip)
|
||||
|
||||
updateItem := func(action Action) {
|
||||
item := action.Item
|
||||
seqID := internalID2SeqID[action.Item.InternalID]
|
||||
menuItem := items[seqID]
|
||||
if menuItem == nil {
|
||||
return
|
||||
}
|
||||
if item.Hidden {
|
||||
menuItem.Hide()
|
||||
} else {
|
||||
if item.Checked {
|
||||
menuItem.Check()
|
||||
} else {
|
||||
menuItem.Uncheck()
|
||||
}
|
||||
if item.Enabled {
|
||||
menuItem.Enable()
|
||||
} else {
|
||||
menuItem.Disable()
|
||||
}
|
||||
menuItem.SetTitle(item.Title)
|
||||
menuItem.SetTooltip(item.Tooltip)
|
||||
if len(item.Icon) > 0 {
|
||||
icon, err := base64.StdEncoding.DecodeString(item.Icon)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
} else {
|
||||
menuItem.SetIcon(icon)
|
||||
}
|
||||
}
|
||||
menuItem.Show()
|
||||
for _, child := range item.Items {
|
||||
seqID = internalID2SeqID[child.InternalID]
|
||||
items[seqID].Show()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateMenu := func(action Action) {
|
||||
m := action.Menu
|
||||
if menu.Title != m.Title {
|
||||
menu.Title = m.Title
|
||||
systray.SetTitle(menu.Title)
|
||||
}
|
||||
if menu.Icon != m.Icon && m.Icon != "" {
|
||||
menu.Icon = m.Icon
|
||||
icon, err := base64.StdEncoding.DecodeString(menu.Icon)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
} else {
|
||||
systray.SetIcon(icon)
|
||||
}
|
||||
}
|
||||
if menu.Tooltip != m.Tooltip {
|
||||
menu.Tooltip = m.Tooltip
|
||||
systray.SetTooltip(menu.Tooltip)
|
||||
}
|
||||
}
|
||||
|
||||
update := func(action Action) {
|
||||
switch action.Type {
|
||||
case "update-item":
|
||||
updateItem(action)
|
||||
case "update-menu":
|
||||
updateMenu(action)
|
||||
case "update-item-and-menu":
|
||||
updateItem(action)
|
||||
updateMenu(action)
|
||||
case "exit":
|
||||
systray.Quit()
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(menu.Items); i++ {
|
||||
item := menu.Items[i]
|
||||
addMenuItem(&items, &seqID2InternalID, &internalID2SeqID, &item, nil)
|
||||
}
|
||||
|
||||
go func(reader *bufio.Reader) {
|
||||
for {
|
||||
var action Action
|
||||
if err := readJSON(reader, &action); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
systray.Quit()
|
||||
break
|
||||
}
|
||||
update(action)
|
||||
}
|
||||
}(reader)
|
||||
|
||||
stdoutEnc := json.NewEncoder(os.Stdout)
|
||||
for {
|
||||
itemsCnt := 0
|
||||
for _, ch := range items {
|
||||
if ch != nil {
|
||||
itemsCnt++
|
||||
}
|
||||
}
|
||||
cases := make([]reflect.SelectCase, itemsCnt)
|
||||
caseCnt2SeqID := make([]int, len(items))
|
||||
itemsCnt = 0
|
||||
for i, ch := range items {
|
||||
if ch == nil {
|
||||
continue
|
||||
}
|
||||
cases[itemsCnt] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.ClickedCh)}
|
||||
caseCnt2SeqID[itemsCnt] = i
|
||||
itemsCnt++
|
||||
}
|
||||
|
||||
remaining := len(cases)
|
||||
for remaining > 0 {
|
||||
chosen, _, ok := reflect.Select(cases)
|
||||
if !ok {
|
||||
// The chosen channel has been closed, so zero out the channel to disable the case
|
||||
cases[chosen].Chan = reflect.ValueOf(nil)
|
||||
remaining--
|
||||
continue
|
||||
}
|
||||
seqID := caseCnt2SeqID[chosen]
|
||||
err := stdoutEnc.Encode(ClickEvent{
|
||||
Type: "clicked",
|
||||
InternalID: seqID2InternalID[seqID],
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
import sys
|
||||
|
||||
|
||||
def ask(msg: str, *, default: bool) -> bool:
|
||||
if default:
|
||||
msg += " (Y/n)? "
|
||||
else:
|
||||
msg += " (y/N)? "
|
||||
res = input(msg).strip().lower()
|
||||
if not res:
|
||||
return default
|
||||
if res in ["y", "yes", "n", "no"]:
|
||||
return res[0] == "y"
|
||||
print("Bad response!")
|
||||
sys.exit(1)
|
||||
@@ -1,73 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import base64
|
||||
import sys
|
||||
|
||||
from nacl.encoding import HexEncoder
|
||||
from nacl.public import SealedBox
|
||||
from nacl.signing import VerifyKey
|
||||
|
||||
from ..core import connect, read_config, verify_with_magic
|
||||
from .base import Command
|
||||
from .util import ask
|
||||
|
||||
|
||||
class VouchCommand(Command):
|
||||
command = "vouch"
|
||||
help = "Give someone access to the central server."
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"magic",
|
||||
help="Opaque hex string generated by 'setup'.",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def run(args: Namespace) -> None:
|
||||
run_vouch(args.magic)
|
||||
|
||||
|
||||
def run_vouch(magic: str) -> None:
|
||||
try:
|
||||
vouch_data = base64.b64decode(magic.encode("utf-8"))
|
||||
verify_key = VerifyKey(vouch_data[:32])
|
||||
signed_nickname = vouch_data[32:]
|
||||
msg = verify_with_magic(b"NAME", verify_key, signed_nickname)
|
||||
nickname = msg.decode("utf-8")
|
||||
except Exception:
|
||||
print("Could not parse data!")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
config = read_config()
|
||||
port = connect(config)
|
||||
port.send_json(
|
||||
{
|
||||
"method": "vouch",
|
||||
"who": verify_key.encode(HexEncoder).decode("utf-8"),
|
||||
"signed_name": HexEncoder.encode(signed_nickname).decode("utf-8"),
|
||||
}
|
||||
)
|
||||
port.receive_json()
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not ask(f"Grant permuter server access to {nickname}", default=True):
|
||||
return
|
||||
|
||||
try:
|
||||
port.send_json({})
|
||||
port.receive_json()
|
||||
except Exception as e:
|
||||
print(f"Failed to grant access: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.server_address, "checked by connect"
|
||||
assert config.server_verify_key, "checked by connect"
|
||||
data = config.server_verify_key.encode() + config.server_address.encode("utf-8")
|
||||
token = SealedBox(verify_key.to_curve25519_public_key()).encrypt(data)
|
||||
print("Granted!")
|
||||
print()
|
||||
print("Send them the following token:")
|
||||
print(base64.b64encode(token).decode("utf-8"))
|
||||
@@ -1,3 +0,0 @@
|
||||
target/
|
||||
config.toml
|
||||
*.json
|
||||
-607
@@ -1,607 +0,0 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "argh"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91792f088f87cdc7a2cfb1d617fa5ea18d7f1dc22ef0e1b5f82f3157cdc522be"
|
||||
dependencies = [
|
||||
"argh_derive",
|
||||
"argh_shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "argh_derive"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4eb0c0c120ad477412dc95a4ce31e38f2113e46bd13511253f79196ca68b067"
|
||||
dependencies = [
|
||||
"argh_shared",
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "argh_shared"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "781f336cc9826dbaddb9754cb5db61e64cab4f69668bd19dcc4a0394a86f4cb1"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.67"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"time",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9495705279e7140bf035dde1f6e750c162df8b625267cd52cc44e0b156732c8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87cbf45460356b7deeb5e3415b5563308c0a9b057c85e12b06ad551f98d0a6ac"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.93"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41"
|
||||
|
||||
[[package]]
|
||||
name = "libsodium-sys"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a685b64f837b339074115f2e7f7b431ac73681d08d75b389db7498b8892b8a58"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a3c91c24eae6777794bb1997ad98bbb87daf92890acab859f7eaa4320333176"
|
||||
dependencies = [
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf80d3e903b34e0bd7282b218398aec54e082c840d9baf8339e0080a0c542956"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"miow",
|
||||
"ntapi",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miow"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ntapi"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3"
|
||||
|
||||
[[package]]
|
||||
name = "pahserver"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"argh",
|
||||
"chrono",
|
||||
"hex",
|
||||
"pin-project",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_tuple",
|
||||
"slotmap",
|
||||
"sodiumoxide",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb"
|
||||
dependencies = [
|
||||
"instant",
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"instant",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905"
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a152013215dca273577e18d2bf00fa862b89b24169fb78c4c95aeb07992c9cec"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8270314b5ccceb518e7e578952f0b72b88222d02e8f77f5ecf7abbb673539041"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "remove_dir_all"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.125"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "558dc50e1a5a5fa7112ca2ce4effcb321b0300c0d4ccf0776a9f60cd89031171"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.125"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_tuple"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4f025b91216f15a2a32aa39669329a475733590a015835d1783549a56d09427"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_tuple_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_tuple_macros"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4076151d1a2b688e25aaf236997933c66e18b870d0369f8b248b8ab2be630d7e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slotmap"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "585cd5dffe4e9e06f6dfdf66708b70aca3f781bed561f4f667b2d9c0d4559e36"
|
||||
dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
|
||||
|
||||
[[package]]
|
||||
name = "sodiumoxide"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7038b67c941e23501573cb7242ffb08709abe9b11eb74bceff875bbda024a6a8"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"libsodium-sys",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48fe99c6bd8b1cc636890bcc071842de909d902c81ac7dab53ba33c421ab8ffb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"rand",
|
||||
"redox_syscall",
|
||||
"remove_dir_all",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83f0c8e7c0addab50b663055baf787d0af7f413a46e6e7fb9559a4e4db7137a5"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bytes",
|
||||
"libc",
|
||||
"memchr",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"tokio-macros",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-macros"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.5.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.10.2+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
@@ -1,19 +0,0 @@
|
||||
[package]
|
||||
name = "pahserver"
|
||||
version = "0.0.1"
|
||||
edition = "2018"
|
||||
resolver = "2"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
sodiumoxide = "0.2"
|
||||
toml = "0.5"
|
||||
serde = { version = "1.0", features = ["derive", "rc"] }
|
||||
serde_json = "1.0"
|
||||
serde_tuple = "0.5"
|
||||
hex = "0.4"
|
||||
argh = "0.1"
|
||||
tempfile = "3"
|
||||
slotmap = "1"
|
||||
pin-project = "1"
|
||||
chrono = "*"
|
||||
@@ -1,23 +0,0 @@
|
||||
# controller
|
||||
|
||||
This directory contains code for the central permuter@home controller server,
|
||||
written in Rust. All p@h traffic passes through here.
|
||||
|
||||
If you just want to run a regular p@h server, you don't need to care about this.
|
||||
|
||||
To setup your own copy of the controller server:
|
||||
|
||||
- Install Rust and (for the libsodium dependency) GCC.
|
||||
- Run `cargo build --release`.
|
||||
- Run `./target/release/pahserver setup --db path/to/database.json` and follow
|
||||
the instructions there. This will set the `priv_seed` part of `config.toml`, and
|
||||
set up an initial trusted client. The rest of `config.toml` can be copied from
|
||||
`config_example.toml`.
|
||||
- Set up a reverse proxy that forwards HTTPS traffic from an external port or route
|
||||
to HTTP for a port of your choice, e.g. using Nginx or Traefik.
|
||||
If applicable, configure your firewall to let the external port through.
|
||||
- Start the server with:
|
||||
```
|
||||
./target/release/pahserver run --listen-on 0.0.0.0:<port> --config config.toml --db path/to/database.json
|
||||
```
|
||||
and configure the system to run this at startup.
|
||||
@@ -1,2 +0,0 @@
|
||||
docker_image = ""
|
||||
priv_seed = "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
@@ -1,205 +0,0 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::db::UserId;
|
||||
use crate::flimsy_semaphore::FlimsySemaphore;
|
||||
use crate::port::{ReadPort, WritePort};
|
||||
use crate::stats;
|
||||
use crate::util::SimpleResult;
|
||||
use crate::{
|
||||
current_load, Permuter, PermuterData, PermuterId, PermuterResult, PermuterWork, ServerUpdate,
|
||||
State,
|
||||
};
|
||||
|
||||
const MIN_PERMUTER_VERSION: u32 = 1;
|
||||
|
||||
const CLIENT_MAX_QUEUES_SIZE: usize = 100;
|
||||
const MIN_PRIORITY: f64 = 0.001;
|
||||
const MAX_PRIORITY: f64 = 10.0;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ConnectClientData {
|
||||
priority: f64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ClientMessage {
|
||||
Work { work: PermuterWork },
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct PermuterResultMessage<'a> {
|
||||
server: String,
|
||||
#[serde(flatten)]
|
||||
update: &'a ServerUpdate,
|
||||
}
|
||||
|
||||
async fn client_read(
|
||||
port: &mut ReadPort<'_>,
|
||||
perm_id: &PermuterId,
|
||||
semaphore: &FlimsySemaphore,
|
||||
state: &State,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
let msg = port.recv().await?;
|
||||
let msg: ClientMessage = serde_json::from_slice(&msg)?;
|
||||
let ClientMessage::Work { work } = msg;
|
||||
|
||||
// Avoid the work and result queues growing indefinitely by restricting
|
||||
// their combined size with a semaphore.
|
||||
semaphore.acquire().await;
|
||||
|
||||
let mut m = state.m.lock().unwrap();
|
||||
let perm = m.permuters.get_mut(perm_id).unwrap();
|
||||
if perm.work_queue.is_empty() {
|
||||
state.new_work_notification.notify_waiters();
|
||||
}
|
||||
perm.work_queue.push_back(work);
|
||||
}
|
||||
}
|
||||
|
||||
async fn client_write(
|
||||
port: &mut WritePort<'_>,
|
||||
fn_name: &str,
|
||||
semaphore: &FlimsySemaphore,
|
||||
state: &State,
|
||||
mut result_rx: mpsc::UnboundedReceiver<PermuterResult>,
|
||||
client_id: &UserId,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
let res = result_rx.recv().await.unwrap();
|
||||
semaphore.release();
|
||||
|
||||
match res {
|
||||
PermuterResult::NeedWork => {
|
||||
port.send_json(&json!({
|
||||
"type": "need_work",
|
||||
}))
|
||||
.await?;
|
||||
}
|
||||
PermuterResult::Result(server_id, server_name, server_update) => {
|
||||
port.send_json(&PermuterResultMessage {
|
||||
server: server_name,
|
||||
update: &server_update,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if let ServerUpdate::Result {
|
||||
compressed_source,
|
||||
ref more_props,
|
||||
..
|
||||
} = server_update
|
||||
{
|
||||
if let Some(ref data) = compressed_source {
|
||||
port.send(data).await?;
|
||||
}
|
||||
|
||||
let score = more_props.get("score").and_then(|score| score.as_i64());
|
||||
let outcome = if compressed_source.is_none() {
|
||||
stats::Outcome::Unhelpful
|
||||
} else if matches!(score, Some(0)) {
|
||||
stats::Outcome::Matched
|
||||
} else {
|
||||
stats::Outcome::Improved
|
||||
};
|
||||
state
|
||||
.log_stats(stats::Record::WorkDone {
|
||||
server: server_id,
|
||||
client: client_id.clone(),
|
||||
fn_name: fn_name.to_string(),
|
||||
outcome,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_connect_client<'a>(
|
||||
mut read_port: ReadPort<'a>,
|
||||
mut write_port: WritePort<'a>,
|
||||
who_id: UserId,
|
||||
who_name: &str,
|
||||
permuter_version: u32,
|
||||
state: &State,
|
||||
data: ConnectClientData,
|
||||
) -> SimpleResult<()> {
|
||||
if permuter_version < MIN_PERMUTER_VERSION {
|
||||
Err("Permuter version too old!")?;
|
||||
}
|
||||
|
||||
if !(MIN_PRIORITY <= data.priority && data.priority <= MAX_PRIORITY) {
|
||||
Err("Priority out of range")?;
|
||||
}
|
||||
|
||||
let load = current_load(state, Some(data.priority));
|
||||
write_port.send_json(&load).await?;
|
||||
|
||||
let permuter_data = read_port.recv().await?;
|
||||
let mut permuter_data: PermuterData = serde_json::from_slice(&permuter_data)?;
|
||||
permuter_data.compressed_source = read_port.recv().await?;
|
||||
permuter_data.compressed_target_o_bin = read_port.recv().await?;
|
||||
write_port.send_json(&json!({})).await?;
|
||||
|
||||
eprintln!(
|
||||
"[{}] start client ({}, {})",
|
||||
&who_name, &permuter_data.fn_name, data.priority
|
||||
);
|
||||
|
||||
state
|
||||
.log_stats(stats::Record::ClientNewFunction {
|
||||
client: who_id.clone(),
|
||||
fn_name: permuter_data.fn_name.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let energy_add = 1.0 / data.priority;
|
||||
let fn_name = permuter_data.fn_name.clone();
|
||||
|
||||
let (result_tx, result_rx) = mpsc::unbounded_channel();
|
||||
let semaphore = Arc::new(FlimsySemaphore::new(CLIENT_MAX_QUEUES_SIZE));
|
||||
|
||||
let perm_id = {
|
||||
let mut m = state.m.lock().unwrap();
|
||||
let id = m.next_permuter_id;
|
||||
m.next_permuter_id += 1;
|
||||
m.permuters.insert(
|
||||
id,
|
||||
Permuter {
|
||||
data: permuter_data.into(),
|
||||
client_id: who_id.clone(),
|
||||
client_name: who_name.to_string(),
|
||||
work_queue: VecDeque::new(),
|
||||
result_tx: result_tx.clone(),
|
||||
semaphore: semaphore.clone(),
|
||||
priority: data.priority,
|
||||
energy_add,
|
||||
},
|
||||
);
|
||||
state.new_work_notification.notify_waiters();
|
||||
id
|
||||
};
|
||||
|
||||
let r = tokio::try_join!(
|
||||
client_read(&mut read_port, &perm_id, &semaphore, state),
|
||||
client_write(
|
||||
&mut write_port,
|
||||
&fn_name,
|
||||
&semaphore,
|
||||
state,
|
||||
result_rx,
|
||||
&who_id
|
||||
)
|
||||
);
|
||||
|
||||
state.m.lock().unwrap().permuters.remove(&perm_id);
|
||||
state.new_work_notification.notify_waiters();
|
||||
r?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use hex::FromHex;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
|
||||
use sodiumoxide::crypto::sign;
|
||||
|
||||
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
||||
pub struct ByteString<const SIZE: usize>([u8; SIZE]);
|
||||
|
||||
impl<const SIZE: usize> ByteString<SIZE> {
|
||||
fn to_hex(&self) -> String {
|
||||
hex::encode(&self.0)
|
||||
}
|
||||
|
||||
fn from_hex(string: &str) -> Result<ByteString<SIZE>, &'static str> {
|
||||
Ok(ByteString(
|
||||
Vec::from_hex(&string)
|
||||
.map_err(|_| "not a valid hex string")?
|
||||
.try_into()
|
||||
.map_err(|_| "byte string has wrong size")?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> Serialize for ByteString<SIZE> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_hex())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, const SIZE: usize> Deserialize<'de> for ByteString<SIZE> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<ByteString<SIZE>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
let string = String::deserialize(deserializer)?;
|
||||
ByteString::from_hex(&string).map_err(Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
pub type UserId = ByteString<32>;
|
||||
|
||||
impl UserId {
|
||||
pub fn from_pubkey(key: &sign::PublicKey) -> UserId {
|
||||
ByteString(key.as_ref().try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn to_pubkey(&self) -> sign::PublicKey {
|
||||
sign::PublicKey::from_slice(&self.0).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const SIZE: usize> ByteString<SIZE> {
|
||||
pub fn to_seed(&self) -> sign::Seed {
|
||||
sign::Seed::from_slice(&self.0).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize_tuple, Serialize_tuple)]
|
||||
pub struct Stats {
|
||||
pub iterations: u64,
|
||||
pub improvements: u64,
|
||||
pub matches: u64,
|
||||
pub functions: u64,
|
||||
}
|
||||
|
||||
impl Default for Stats {
|
||||
fn default() -> Stats {
|
||||
Stats {
|
||||
iterations: 0,
|
||||
improvements: 0,
|
||||
matches: 0,
|
||||
functions: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct User {
|
||||
pub trusted_by: Option<UserId>,
|
||||
pub name: String,
|
||||
pub client_stats: Stats,
|
||||
pub server_stats: Stats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct DB {
|
||||
pub users: HashMap<UserId, User>,
|
||||
pub func_stats: HashMap<String, Stats>,
|
||||
pub total_stats: Stats,
|
||||
}
|
||||
|
||||
impl DB {
|
||||
pub fn func_stat(&mut self, fn_name: String) -> &mut Stats {
|
||||
self.func_stats
|
||||
.entry(fn_name)
|
||||
.or_insert_with(Stats::default)
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
use std::convert::TryInto;
|
||||
use std::sync::atomic::{AtomicIsize, Ordering};
|
||||
|
||||
use tokio::sync::Notify;
|
||||
|
||||
/// An unfair semaphore that allows overdrafts.
|
||||
pub struct FlimsySemaphore {
|
||||
notify: Notify,
|
||||
slots: AtomicIsize,
|
||||
}
|
||||
|
||||
impl FlimsySemaphore {
|
||||
// Invariant: if `slots` has ever become non-positive, then if positive
|
||||
// there will be a notify token in circulation. Taking the token
|
||||
// synchronizes with a positive `slots`.
|
||||
pub fn new(limit: usize) -> FlimsySemaphore {
|
||||
FlimsySemaphore {
|
||||
notify: Notify::new(),
|
||||
slots: AtomicIsize::new(limit.try_into().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn acquire_ignore_limit(&self) {
|
||||
self.slots.fetch_add(-1, Ordering::Acquire);
|
||||
}
|
||||
|
||||
pub async fn acquire(&self) {
|
||||
let mut was_woken = false;
|
||||
let mut val = self.slots.load(Ordering::Relaxed);
|
||||
loop {
|
||||
if val > 0 {
|
||||
match self.slots.compare_exchange(
|
||||
val,
|
||||
val - 1,
|
||||
Ordering::Acquire,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => {
|
||||
if was_woken && val > 1 {
|
||||
self.notify.notify_one();
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(actually) => {
|
||||
val = actually;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.notify.notified().await;
|
||||
was_woken = true;
|
||||
val = self.slots.load(Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn release(&self) {
|
||||
if self.slots.fetch_add(1, Ordering::Release) == 0 {
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,418 +0,0 @@
|
||||
#![allow(clippy::try_err)]
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::convert::TryInto;
|
||||
use std::default::Default;
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use argh::FromArgs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use slotmap::{new_key_type, SlotMap};
|
||||
use sodiumoxide::crypto::box_;
|
||||
use sodiumoxide::crypto::sign;
|
||||
use tokio::fs;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{ReadHalf, WriteHalf};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::{mpsc, watch, Notify};
|
||||
use tokio::time;
|
||||
|
||||
use crate::db::{ByteString, UserId};
|
||||
use crate::flimsy_semaphore::FlimsySemaphore;
|
||||
use crate::port::{ReadPort, WritePort};
|
||||
use crate::save::SaveableDB;
|
||||
use crate::util::SimpleResult;
|
||||
|
||||
mod client;
|
||||
mod db;
|
||||
mod flimsy_semaphore;
|
||||
mod port;
|
||||
mod save;
|
||||
mod server;
|
||||
mod setup;
|
||||
mod stats;
|
||||
mod util;
|
||||
mod vouch;
|
||||
|
||||
const HEARTBEAT_TIME: time::Duration = time::Duration::from_secs(300);
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// The permuter@home control server.
|
||||
struct CmdOpts {
|
||||
#[argh(subcommand)]
|
||||
sub: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
#[argh(subcommand)]
|
||||
enum SubCommand {
|
||||
RunServer(RunServerOpts),
|
||||
Setup(SetupOpts),
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// Run the permuter@home control server.
|
||||
#[argh(subcommand, name = "run")]
|
||||
struct RunServerOpts {
|
||||
/// ip:port to listen on (e.g. 0.0.0.0:1234)
|
||||
#[argh(option)]
|
||||
listen_on: String,
|
||||
|
||||
/// path to TOML configuration file
|
||||
#[argh(option)]
|
||||
config: String,
|
||||
|
||||
/// path to JSON database
|
||||
#[argh(option)]
|
||||
db: String,
|
||||
|
||||
/// enable debug logging
|
||||
#[argh(switch)]
|
||||
debug: bool,
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// Setup initial database and config for permuter@home.
|
||||
#[argh(subcommand, name = "setup")]
|
||||
struct SetupOpts {
|
||||
/// path to JSON database
|
||||
#[argh(option)]
|
||||
db: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
docker_image: String,
|
||||
priv_seed: ByteString<32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct PermuterData {
|
||||
fn_name: String,
|
||||
#[serde(skip)]
|
||||
compressed_source: Vec<u8>,
|
||||
#[serde(skip)]
|
||||
compressed_target_o_bin: Vec<u8>,
|
||||
#[serde(flatten)]
|
||||
more_props: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
|
||||
struct PermuterWork {
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ServerUpdate {
|
||||
Result {
|
||||
#[serde(skip_serializing, default)]
|
||||
overhead_us: i64,
|
||||
#[serde(skip)]
|
||||
compressed_source: Option<Vec<u8>>,
|
||||
#[serde(default)]
|
||||
has_source: bool,
|
||||
#[serde(flatten)]
|
||||
more_props: HashMap<String, serde_json::Value>,
|
||||
},
|
||||
InitDone {
|
||||
hash: String,
|
||||
},
|
||||
InitFailed {
|
||||
reason: String,
|
||||
},
|
||||
Disconnect,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PermuterResult {
|
||||
NeedWork,
|
||||
Result(UserId, String, ServerUpdate),
|
||||
}
|
||||
|
||||
type PermuterId = u64;
|
||||
|
||||
struct Permuter {
|
||||
data: Arc<PermuterData>,
|
||||
client_id: UserId,
|
||||
client_name: String,
|
||||
work_queue: VecDeque<PermuterWork>,
|
||||
result_tx: mpsc::UnboundedSender<PermuterResult>,
|
||||
semaphore: Arc<FlimsySemaphore>,
|
||||
priority: f64,
|
||||
energy_add: f64,
|
||||
}
|
||||
|
||||
impl Permuter {
|
||||
fn send_result(&mut self, res: PermuterResult) {
|
||||
// We can't use a blocking semaphore acquire here, because we don't
|
||||
// want server sends to block on random client receives. In practice,
|
||||
// this is probably fine.
|
||||
let _ = self.result_tx.send(res);
|
||||
self.semaphore.acquire_ignore_limit();
|
||||
}
|
||||
}
|
||||
|
||||
new_key_type! { struct ServerId; }
|
||||
|
||||
struct ConnectedServer {
|
||||
min_priority: f64,
|
||||
num_cores: f64,
|
||||
}
|
||||
|
||||
struct MutableState {
|
||||
servers: SlotMap<ServerId, ConnectedServer>,
|
||||
permuters: HashMap<PermuterId, Permuter>,
|
||||
next_permuter_id: PermuterId,
|
||||
}
|
||||
|
||||
struct State {
|
||||
docker_image: String,
|
||||
debug: bool,
|
||||
sign_sk: sign::SecretKey,
|
||||
db: SaveableDB,
|
||||
stats_tx: mpsc::Sender<stats::Record>,
|
||||
heartbeat_rx: watch::Receiver<()>,
|
||||
new_work_notification: Notify,
|
||||
m: Mutex<MutableState>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
async fn log_stats(&self, record: stats::Record) -> SimpleResult<()> {
|
||||
self.stats_tx
|
||||
.send(record)
|
||||
.await
|
||||
.map_err(|_| "stats thread died".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "method", rename_all = "snake_case")]
|
||||
enum Request {
|
||||
Ping,
|
||||
Vouch(vouch::VouchData),
|
||||
ConnectServer(server::ConnectServerData),
|
||||
ConnectClient(client::ConnectClientData),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Load {
|
||||
clients: usize,
|
||||
servers: usize,
|
||||
cores: f64,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> SimpleResult<()> {
|
||||
sodiumoxide::init().map_err(|()| "Failed to initialize cryptography library")?;
|
||||
|
||||
let opts: CmdOpts = argh::from_env();
|
||||
|
||||
match opts.sub {
|
||||
SubCommand::RunServer(opts) => run_server(opts).await?,
|
||||
SubCommand::Setup(opts) => setup::run_setup(opts)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(opts: RunServerOpts) -> SimpleResult<()> {
|
||||
let config: Config = toml::from_str(&fs::read_to_string(&opts.config).await?)?;
|
||||
let (_, sign_sk) = sign::keypair_from_seed(&config.priv_seed.to_seed());
|
||||
|
||||
let (save_fut, db) = SaveableDB::open(&opts.db)?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = save_fut.await {
|
||||
eprintln!("Failed to save! {:?}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
|
||||
let (stats_fut, stats_tx) = stats::stats_thread(&db);
|
||||
tokio::spawn(stats_fut);
|
||||
|
||||
let (heartbeat_tx, heartbeat_rx) = watch::channel(());
|
||||
|
||||
let state: &'static State = Box::leak(Box::new(State {
|
||||
docker_image: config.docker_image,
|
||||
debug: opts.debug,
|
||||
sign_sk,
|
||||
db,
|
||||
stats_tx,
|
||||
heartbeat_rx,
|
||||
new_work_notification: Notify::new(),
|
||||
m: Mutex::new(MutableState {
|
||||
servers: SlotMap::with_key(),
|
||||
permuters: HashMap::new(),
|
||||
next_permuter_id: 0,
|
||||
}),
|
||||
}));
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
heartbeat_tx.send(()).expect("receiver is still alive");
|
||||
time::sleep(HEARTBEAT_TIME).await;
|
||||
}
|
||||
});
|
||||
|
||||
let listener = TcpListener::bind(opts.listen_on).await?;
|
||||
|
||||
loop {
|
||||
let (socket, _) = listener.accept().await?;
|
||||
tokio::spawn(async move {
|
||||
let mut who = "anonymous".to_string();
|
||||
if let Err(e) = handle_connection(socket, state, &mut who).await {
|
||||
if let Some(e) = e.downcast_ref::<std::io::Error>() {
|
||||
if matches!(
|
||||
e.kind(),
|
||||
ErrorKind::UnexpectedEof
|
||||
| ErrorKind::ConnectionReset
|
||||
| ErrorKind::TimedOut
|
||||
| ErrorKind::BrokenPipe
|
||||
) {
|
||||
eprintln!("[{}] disconnected", &who);
|
||||
return;
|
||||
}
|
||||
}
|
||||
eprintln!("[{}] error: {:?}", &who, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn concat<T: Clone>(a: &[T], b: &[T]) -> Vec<T> {
|
||||
a.iter().chain(b).cloned().collect()
|
||||
}
|
||||
|
||||
fn concat3<T: Clone>(a: &[T], b: &[T], c: &[T]) -> Vec<T> {
|
||||
a.iter().chain(b).chain(c).cloned().collect()
|
||||
}
|
||||
|
||||
async fn handshake<'a>(
|
||||
mut rd: ReadHalf<'a>,
|
||||
mut wr: WriteHalf<'a>,
|
||||
sign_sk: &sign::SecretKey,
|
||||
) -> SimpleResult<(ReadPort<'a>, WritePort<'a>, UserId, u32)> {
|
||||
let mut buffer = [0; 4 + 32];
|
||||
rd.read_exact(&mut buffer).await?;
|
||||
let (magic, their_pk) = buffer.split_at(4);
|
||||
if magic != b"p@h0" {
|
||||
Err("Invalid protocol version")?;
|
||||
}
|
||||
let their_pk = box_::PublicKey::from_slice(&their_pk).unwrap();
|
||||
|
||||
let (our_pk, our_sk) = box_::gen_keypair();
|
||||
let signed_data = concat3(b"HELLO:", their_pk.as_ref(), our_pk.as_ref());
|
||||
let signature = sign::sign_detached(&signed_data, &sign_sk);
|
||||
wr.write_all(&concat(our_pk.as_ref(), signature.as_ref()))
|
||||
.await?;
|
||||
|
||||
let key = box_::precompute(&their_pk, &our_sk);
|
||||
let mut read_port = ReadPort::new(rd, &key);
|
||||
let write_port = WritePort::new(wr, &key);
|
||||
|
||||
let reply = read_port.recv().await?;
|
||||
if reply.len() != 32 + 64 + 4 {
|
||||
Err("Failed to perform secret handshake")?;
|
||||
}
|
||||
let (client_ver_key, rest) = reply.split_at(32);
|
||||
let (client_signature, permuter_version) = rest.split_at(64);
|
||||
let client_ver_key = sign::PublicKey::from_slice(client_ver_key).unwrap();
|
||||
let client_signature = sign::Signature::from_slice(client_signature).unwrap();
|
||||
let permuter_version = u32::from_be_bytes(permuter_version.try_into().unwrap());
|
||||
let signed_data = concat(b"WORLD:", our_pk.as_ref());
|
||||
if !sign::verify_detached(&client_signature, &signed_data, &client_ver_key) {
|
||||
Err("Spoofed client signature!")?;
|
||||
}
|
||||
|
||||
Ok((
|
||||
read_port,
|
||||
write_port,
|
||||
UserId::from_pubkey(&client_ver_key),
|
||||
permuter_version,
|
||||
))
|
||||
}
|
||||
|
||||
fn current_load(state: &State, priority: Option<f64>) -> Load {
|
||||
let m = state.m.lock().unwrap();
|
||||
let mut servers: usize = 0;
|
||||
let mut cores: f64 = 0.0;
|
||||
for server in m.servers.values() {
|
||||
if priority.map_or(true, |p| p >= server.min_priority) {
|
||||
servers += 1;
|
||||
cores += server.num_cores;
|
||||
}
|
||||
}
|
||||
Load {
|
||||
clients: m.permuters.len(),
|
||||
servers,
|
||||
cores,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection(
|
||||
mut socket: TcpStream,
|
||||
state: &State,
|
||||
out_name: &mut String,
|
||||
) -> SimpleResult<()> {
|
||||
let (rd, wr) = socket.split();
|
||||
let (mut read_port, mut write_port, user_id, permuter_version) =
|
||||
handshake(rd, wr, &state.sign_sk).await?;
|
||||
let name = match state.db.read(|db| {
|
||||
let user = db.users.get(&user_id)?;
|
||||
Some(user.name.clone())
|
||||
}) {
|
||||
Some(name) => name,
|
||||
None => {
|
||||
write_port.send_error("Access denied!").await?;
|
||||
Err("Unknown client!")?
|
||||
}
|
||||
};
|
||||
*out_name = name.clone();
|
||||
eprintln!("[{}] connected (v {})", &name, permuter_version);
|
||||
if state.debug {
|
||||
read_port.set_debug(&name);
|
||||
write_port.set_debug(&name);
|
||||
}
|
||||
write_port.send_json(&json!({})).await?;
|
||||
|
||||
let request = read_port.recv().await?;
|
||||
let request: Request = serde_json::from_slice(&request)?;
|
||||
match request {
|
||||
Request::Ping => {
|
||||
eprintln!("[{}] ping", &name);
|
||||
let load = current_load(state, None);
|
||||
write_port.send_json(&load).await?;
|
||||
}
|
||||
Request::Vouch(data) => {
|
||||
vouch::handle_vouch(read_port, write_port, user_id, &name, state, data).await?;
|
||||
}
|
||||
Request::ConnectServer(data) => {
|
||||
server::handle_connect_server(
|
||||
read_port,
|
||||
write_port,
|
||||
user_id,
|
||||
&name,
|
||||
permuter_version,
|
||||
state,
|
||||
data,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Request::ConnectClient(data) => {
|
||||
client::handle_connect_client(
|
||||
read_port,
|
||||
write_port,
|
||||
user_id,
|
||||
&name,
|
||||
permuter_version,
|
||||
state,
|
||||
data,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
use std::convert::TryInto;
|
||||
|
||||
use chrono::Local;
|
||||
use serde::Serialize;
|
||||
use sodiumoxide::crypto::box_;
|
||||
use sodiumoxide::crypto::box_::{Nonce, PrecomputedKey};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{ReadHalf, WriteHalf};
|
||||
|
||||
use crate::util::SimpleResult;
|
||||
|
||||
fn debug_print(action: &str, who: &str, msg: &[u8]) {
|
||||
let time = Local::now().format("%H:%M:%S:%f");
|
||||
if msg.len() <= 300 {
|
||||
let msg = String::from_utf8(
|
||||
msg.iter()
|
||||
.copied()
|
||||
.flat_map(std::ascii::escape_default)
|
||||
.collect(),
|
||||
)
|
||||
.unwrap();
|
||||
println!("{} debug: {} {}: {}", time, action, who, msg);
|
||||
} else {
|
||||
println!("{} debug: {} {}: {} bytes", time, action, who, msg.len());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReadPort<'a> {
|
||||
read_half: ReadHalf<'a>,
|
||||
key: PrecomputedKey,
|
||||
nonce: u64,
|
||||
debug_name: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> ReadPort<'a> {
|
||||
pub fn new(read_half: ReadHalf<'a>, key: &PrecomputedKey) -> Self {
|
||||
ReadPort {
|
||||
read_half,
|
||||
key: key.clone(),
|
||||
nonce: 0,
|
||||
debug_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_debug(&mut self, name: &'a str) {
|
||||
self.debug_name = Some(name);
|
||||
}
|
||||
|
||||
pub async fn recv(&mut self) -> SimpleResult<Vec<u8>> {
|
||||
let len = self.read_half.read_u64().await?;
|
||||
if len >= (1 << 48) {
|
||||
Err("Unreasonable packet length")?
|
||||
}
|
||||
let mut buffer = vec![0; len.try_into()?];
|
||||
self.read_half.read_exact(&mut buffer).await?;
|
||||
let nonce = nonce_from_u64(self.nonce);
|
||||
self.nonce += 2;
|
||||
let data =
|
||||
box_::open_precomputed(&buffer, &nonce, &self.key).map_err(|()| "Failed to decrypt")?;
|
||||
if let Some(name) = self.debug_name {
|
||||
debug_print("Receive from", name, &data);
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WritePort<'a> {
|
||||
write_half: WriteHalf<'a>,
|
||||
key: PrecomputedKey,
|
||||
nonce: u64,
|
||||
debug_name: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> WritePort<'a> {
|
||||
pub fn new(write_half: WriteHalf<'a>, key: &PrecomputedKey) -> Self {
|
||||
WritePort {
|
||||
write_half,
|
||||
key: key.clone(),
|
||||
nonce: 1,
|
||||
debug_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_debug(&mut self, name: &'a str) {
|
||||
self.debug_name = Some(name);
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, data: &[u8]) -> SimpleResult<()> {
|
||||
if let Some(name) = self.debug_name {
|
||||
debug_print("Send to", name, &data);
|
||||
}
|
||||
let nonce = nonce_from_u64(self.nonce);
|
||||
self.nonce += 2;
|
||||
let data = box_::seal_precomputed(data, &nonce, &self.key);
|
||||
self.write_half.write_u64(data.len() as u64).await?;
|
||||
self.write_half.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_json<T: ?Sized>(&mut self, value: &T) -> SimpleResult<()>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
self.send(&serde_json::to_vec(value)?).await
|
||||
}
|
||||
|
||||
pub async fn send_error(&mut self, message: &str) -> SimpleResult<()> {
|
||||
self.send_json(message).await
|
||||
}
|
||||
}
|
||||
|
||||
fn nonce_from_u64(num: u64) -> Nonce {
|
||||
let nonce_bytes = [[0; 8], [0; 8], num.to_be_bytes()].concat();
|
||||
Nonce::from_slice(&nonce_bytes).unwrap()
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
use std::future::Future;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::db::DB;
|
||||
use crate::util::{FutureExt, SimpleResult};
|
||||
|
||||
const SAVE_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
enum SaveType {
|
||||
Delayed,
|
||||
Immediate(oneshot::Sender<()>),
|
||||
}
|
||||
|
||||
struct InnerSaveableDB {
|
||||
db: DB,
|
||||
stale: bool,
|
||||
save_chan: mpsc::UnboundedSender<SaveType>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SaveableDB(Arc<RwLock<InnerSaveableDB>>);
|
||||
|
||||
async fn save_db_loop(
|
||||
db: SaveableDB,
|
||||
path: &Path,
|
||||
mut save_channel: mpsc::UnboundedReceiver<SaveType>,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
let mut done_chans = Vec::new();
|
||||
match save_channel.recv().await {
|
||||
None => return Ok(()),
|
||||
Some(SaveType::Immediate(chan)) => {
|
||||
done_chans.push(chan);
|
||||
}
|
||||
Some(SaveType::Delayed) => {
|
||||
// Wait for SAVE_INTERVAL or until we receive an Immediate save.
|
||||
let _ = timeout(SAVE_INTERVAL, async {
|
||||
loop {
|
||||
match save_channel.recv().await {
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
Some(SaveType::Immediate(chan)) => {
|
||||
done_chans.push(chan);
|
||||
break;
|
||||
}
|
||||
Some(SaveType::Delayed) => {}
|
||||
};
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
};
|
||||
|
||||
// Clear the queue in case more messages have stacked up past an
|
||||
// Immediate. Receiver::try_recv() is temporarily dead as of tokio 1.4
|
||||
// (https://github.com/tokio-rs/tokio/issues/3350) due to a bug where
|
||||
// messages can be delayed, but in this case that doesn't matter.
|
||||
loop {
|
||||
match save_channel.recv().now_or_never().await {
|
||||
None | Some(None) => {
|
||||
break;
|
||||
}
|
||||
Some(Some(SaveType::Immediate(chan))) => {
|
||||
done_chans.push(chan);
|
||||
}
|
||||
Some(Some(SaveType::Delayed)) => {}
|
||||
};
|
||||
}
|
||||
|
||||
// Mark the DB as non-stale, to start receiving save messages again.
|
||||
db.0.write().unwrap().stale = false;
|
||||
|
||||
// Actually do the save, by first serializing, then atomically saving
|
||||
// the file by creating and renaming a temp file in the same directory.
|
||||
let data = db.read(|db| serde_json::to_string(&db).unwrap());
|
||||
|
||||
let r: SimpleResult<()> = tokio::task::block_in_place(|| {
|
||||
let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
|
||||
let mut tempf = NamedTempFile::new_in(parent_dir)?;
|
||||
tempf.write_all(data.as_bytes())?;
|
||||
tempf.as_file().sync_all()?;
|
||||
tempf.persist(path)?;
|
||||
Ok(())
|
||||
});
|
||||
r?;
|
||||
|
||||
for chan in done_chans {
|
||||
let _ = chan.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SaveableDB {
|
||||
pub fn open(
|
||||
filename: &str,
|
||||
) -> SimpleResult<(impl Future<Output = SimpleResult<()>>, SaveableDB)> {
|
||||
let db_file = std::fs::File::open(filename)?;
|
||||
let db: DB = serde_json::from_reader(&db_file)?;
|
||||
|
||||
let (save_tx, save_rx) = mpsc::unbounded_channel();
|
||||
|
||||
let saveable_db = SaveableDB(Arc::new(RwLock::new(InnerSaveableDB {
|
||||
db,
|
||||
stale: false,
|
||||
save_chan: save_tx,
|
||||
})));
|
||||
|
||||
let path = PathBuf::from(filename);
|
||||
let db2 = saveable_db.clone();
|
||||
|
||||
let fut = async move { save_db_loop(db2, &path, save_rx).await };
|
||||
Ok((fut, saveable_db))
|
||||
}
|
||||
|
||||
pub fn read<T>(&self, callback: impl FnOnce(&DB) -> T) -> T {
|
||||
let inner = self.0.read().unwrap();
|
||||
callback(&inner.db)
|
||||
}
|
||||
|
||||
pub async fn write<T>(&self, immediate: bool, callback: impl FnOnce(&mut DB) -> T) -> T {
|
||||
let ret;
|
||||
let rx2;
|
||||
{
|
||||
let mut inner = self.0.write().unwrap();
|
||||
ret = callback(&mut inner.db);
|
||||
if immediate {
|
||||
inner.stale = true;
|
||||
let (tx, rx) = oneshot::channel();
|
||||
rx2 = rx;
|
||||
inner
|
||||
.save_chan
|
||||
.send(SaveType::Immediate(tx))
|
||||
.map_err(|_| ())
|
||||
.expect("Failed to send message to save task");
|
||||
} else {
|
||||
if !inner.stale {
|
||||
inner.stale = true;
|
||||
inner
|
||||
.save_chan
|
||||
.send(SaveType::Delayed)
|
||||
.map_err(|_| ())
|
||||
.expect("Failed to send message to save task");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
rx2.await.expect("Failed to save!");
|
||||
ret
|
||||
}
|
||||
}
|
||||
@@ -1,500 +0,0 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::sync::{mpsc, mpsc::error::TrySendError, watch, Notify};
|
||||
|
||||
use crate::db::UserId;
|
||||
use crate::port::{ReadPort, WritePort};
|
||||
use crate::stats;
|
||||
use crate::util::SimpleResult;
|
||||
use crate::{
|
||||
ConnectedServer, MutableState, PermuterData, PermuterId, PermuterResult, PermuterWork,
|
||||
ServerUpdate, State, HEARTBEAT_TIME,
|
||||
};
|
||||
|
||||
const MIN_PERMUTER_VERSION: u32 = 1;
|
||||
|
||||
const SERVER_WORK_QUEUE_SIZE: usize = 100;
|
||||
const TIME_US_GUESS: f64 = 100_000.0;
|
||||
const MIN_OVERHEAD_US: f64 = 100_000.0;
|
||||
const MAX_OVERHEAD_FACTOR: i64 = 2;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ConnectServerData {
|
||||
min_priority: f64,
|
||||
num_cores: f64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ServerMessage {
|
||||
NeedWork,
|
||||
Update {
|
||||
permuter: PermuterId,
|
||||
time_us: f64,
|
||||
update: ServerUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
enum JobState {
|
||||
Loading,
|
||||
Loaded,
|
||||
Failed,
|
||||
}
|
||||
|
||||
struct Job {
|
||||
state: JobState,
|
||||
energy: f64,
|
||||
active_work: i64,
|
||||
}
|
||||
|
||||
struct ServerState {
|
||||
min_priority: f64,
|
||||
/// sum of active_work across all jobs
|
||||
active_work: i64,
|
||||
/// fractional part of how much work should be requested, in [0, 1)
|
||||
more_work_acc: f64,
|
||||
jobs: HashMap<PermuterId, Job>,
|
||||
}
|
||||
|
||||
async fn server_read(
|
||||
port: &mut ReadPort<'_>,
|
||||
who_id: &UserId,
|
||||
who_name: &str,
|
||||
server_state: &Mutex<ServerState>,
|
||||
state: &State,
|
||||
more_work_tx: mpsc::Sender<()>,
|
||||
new_permuter: &Notify,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
let msg = port.recv().await?;
|
||||
let mut msg: ServerMessage = serde_json::from_slice(&msg)?;
|
||||
if let ServerMessage::Update {
|
||||
update:
|
||||
ServerUpdate::Result {
|
||||
ref mut compressed_source,
|
||||
has_source: true,
|
||||
..
|
||||
},
|
||||
..
|
||||
} = msg
|
||||
{
|
||||
*compressed_source = Some(port.recv().await?);
|
||||
}
|
||||
|
||||
let mut has_new = false;
|
||||
let mut request_work;
|
||||
|
||||
{
|
||||
let mut m = state.m.lock().unwrap();
|
||||
let mut server_state = server_state.lock().unwrap();
|
||||
|
||||
let mut more_work: f64 = 1.0;
|
||||
|
||||
if let ServerMessage::Update {
|
||||
permuter: perm_id,
|
||||
update,
|
||||
time_us,
|
||||
} = msg
|
||||
{
|
||||
// If we get back a message referring to a since-removed
|
||||
// permuter, no need to do anything. Just request one more
|
||||
// piece of work to make up for it.
|
||||
if let Some(job) = server_state.jobs.get_mut(&perm_id) {
|
||||
if let Some(perm) = m.permuters.get_mut(&perm_id) {
|
||||
job.energy += perm.energy_add * time_us;
|
||||
|
||||
match update {
|
||||
ServerUpdate::InitDone { .. } => {
|
||||
if !matches!(job.state, JobState::Loading) {
|
||||
Err("Got InitDone while not in Loading state")?;
|
||||
}
|
||||
job.state = JobState::Loaded;
|
||||
has_new = true;
|
||||
}
|
||||
ServerUpdate::InitFailed { .. } => {
|
||||
if !matches!(job.state, JobState::Loading) {
|
||||
Err("Got InitFailed while not in Loading state")?;
|
||||
}
|
||||
job.state = JobState::Failed;
|
||||
}
|
||||
ServerUpdate::Disconnect { .. } => {
|
||||
if !matches!(job.state, JobState::Loaded) {
|
||||
Err("Got Disconnect while not in Loaded state")?;
|
||||
}
|
||||
job.state = JobState::Failed;
|
||||
let work = job.active_work;
|
||||
job.active_work = 0;
|
||||
server_state.active_work -= work;
|
||||
more_work = 0.0;
|
||||
}
|
||||
ServerUpdate::Result { overhead_us, .. } => {
|
||||
if !matches!(job.state, JobState::Loaded) {
|
||||
Err("Got result while not in Loaded state")?;
|
||||
}
|
||||
// If the work item spent less than some given
|
||||
// amount of time in queues, request more work.
|
||||
// This ensures we saturate all server cores.
|
||||
// On the other hand, if it spends too much time
|
||||
// in queues, it's best if we reduce the amount
|
||||
// of work.
|
||||
// We don't need to adjust for time spent on the
|
||||
// network, because we have backpressure on slow
|
||||
// writes on both ends, and read continuously.
|
||||
job.active_work -= 1;
|
||||
server_state.active_work -= 1;
|
||||
let min_overhead_us = (time_us + MIN_OVERHEAD_US) as i64;
|
||||
if overhead_us == 0 {
|
||||
// Legacy server, skip this logic.
|
||||
} else if overhead_us > MAX_OVERHEAD_FACTOR * min_overhead_us {
|
||||
more_work = 0.5;
|
||||
} else if overhead_us < min_overhead_us {
|
||||
more_work = 1.5;
|
||||
}
|
||||
}
|
||||
}
|
||||
perm.send_result(PermuterResult::Result(
|
||||
who_id.clone(),
|
||||
who_name.to_string(),
|
||||
update,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
more_work += server_state.more_work_acc;
|
||||
request_work = more_work as i32;
|
||||
server_state.more_work_acc = more_work - request_work as f64;
|
||||
|
||||
if request_work == 0
|
||||
&& server_state.active_work == 0
|
||||
&& more_work_tx.capacity() == SERVER_WORK_QUEUE_SIZE
|
||||
{
|
||||
// Don't request 0 work if it would lead to total starvation.
|
||||
request_work = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if has_new {
|
||||
new_permuter.notify_waiters();
|
||||
state
|
||||
.log_stats(stats::Record::ServerNewFunction {
|
||||
server: who_id.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
for _ in 0..request_work {
|
||||
// Try requesting more work by sending a message to the writer thread.
|
||||
// If the queue is full (because the writer thread is blocked on a
|
||||
// send), drop the request to avoid an unbounded backlog.
|
||||
if let Err(TrySendError::Closed(_)) = more_work_tx.try_send(()) {
|
||||
panic!("work chooser must not close except on error");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ToSend {
|
||||
Work(PermuterWork),
|
||||
Add {
|
||||
client_id: UserId,
|
||||
client_name: String,
|
||||
data: Arc<PermuterData>,
|
||||
},
|
||||
Remove,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OutMessage {
|
||||
permuter: PermuterId,
|
||||
#[serde(flatten)]
|
||||
to_send: ToSend,
|
||||
}
|
||||
|
||||
fn try_next_work_message(
|
||||
m: &mut MutableState,
|
||||
server_state: &mut ServerState,
|
||||
) -> Option<OutMessage> {
|
||||
let mut skip = HashSet::new();
|
||||
loop {
|
||||
// If possible, send a new permuter.
|
||||
if let Some((&perm_id, perm)) = m
|
||||
.permuters
|
||||
.iter()
|
||||
.find(|(&perm_id, _)| !server_state.jobs.contains_key(&perm_id))
|
||||
{
|
||||
server_state.jobs.insert(
|
||||
perm_id,
|
||||
Job {
|
||||
state: JobState::Loading,
|
||||
energy: 0.0,
|
||||
active_work: 0,
|
||||
},
|
||||
);
|
||||
return Some(OutMessage {
|
||||
permuter: perm_id,
|
||||
to_send: ToSend::Add {
|
||||
client_id: perm.client_id.clone(),
|
||||
client_name: perm.client_name.clone(),
|
||||
data: perm.data.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// If none, find an existing one to work on, or to remove.
|
||||
let mut best_cost = 0.0;
|
||||
let mut best: Option<(PermuterId, &mut Job)> = None;
|
||||
let min_priority = server_state.min_priority;
|
||||
for (&perm_id, job) in server_state.jobs.iter_mut() {
|
||||
if let Some(perm) = m.permuters.get(&perm_id) {
|
||||
let energy =
|
||||
job.energy + (job.active_work as f64) * perm.energy_add * TIME_US_GUESS;
|
||||
if matches!(job.state, JobState::Loaded)
|
||||
&& !skip.contains(&perm_id)
|
||||
&& perm.priority >= min_priority
|
||||
&& (best.is_none() || energy < best_cost)
|
||||
{
|
||||
best_cost = energy;
|
||||
best = Some((perm_id, job));
|
||||
}
|
||||
} else {
|
||||
server_state.active_work -= job.active_work;
|
||||
server_state.jobs.remove(&perm_id);
|
||||
return Some(OutMessage {
|
||||
permuter: perm_id,
|
||||
to_send: ToSend::Remove,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let (perm_id, job) = match best {
|
||||
None => return None,
|
||||
Some(tup) => tup,
|
||||
};
|
||||
|
||||
let perm = m.permuters.get_mut(&perm_id).unwrap();
|
||||
let work = match perm.work_queue.pop_front() {
|
||||
None => {
|
||||
// Chosen permuter is out of work. Ask it for more, and try
|
||||
// again without it as a candidate. When the queue becomes
|
||||
// non-empty again all sleeping writers will be notified.
|
||||
perm.send_result(PermuterResult::NeedWork);
|
||||
skip.insert(perm_id);
|
||||
continue;
|
||||
}
|
||||
Some(work) => work,
|
||||
};
|
||||
|
||||
perm.semaphore.release();
|
||||
|
||||
let min_energy = job.energy;
|
||||
job.active_work += 1;
|
||||
server_state.active_work += 1;
|
||||
|
||||
// Adjust energies to be around zero, to avoid problems with float
|
||||
// imprecision, and to ensure that new permuters that come in with
|
||||
// energy zero will fit the schedule.
|
||||
for job in server_state.jobs.values_mut() {
|
||||
job.energy -= min_energy;
|
||||
}
|
||||
|
||||
return Some(OutMessage {
|
||||
permuter: perm_id,
|
||||
to_send: ToSend::Work(work),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn next_work_message(
|
||||
server_state: &Mutex<ServerState>,
|
||||
state: &State,
|
||||
new_permuter: &Notify,
|
||||
) -> OutMessage {
|
||||
let mut wait_for = None;
|
||||
loop {
|
||||
if let Some(waiter) = wait_for {
|
||||
waiter.await;
|
||||
}
|
||||
let mut m = state.m.lock().unwrap();
|
||||
let mut server_state = server_state.lock().unwrap();
|
||||
match try_next_work_message(&mut m, &mut server_state) {
|
||||
Some(message) => return message,
|
||||
None => {
|
||||
// Nothing to work on! Register to be notified when something
|
||||
// happens (while the lock is still held) and go to sleep.
|
||||
let n1 = state.new_work_notification.notified();
|
||||
let n2 = new_permuter.notified();
|
||||
wait_for = Some(async move {
|
||||
tokio::select! {
|
||||
() = n1 => {}
|
||||
() = n2 => {}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_response(work: &OutMessage) -> bool {
|
||||
match work.to_send {
|
||||
ToSend::Work { .. } => true,
|
||||
ToSend::Add { .. } => true,
|
||||
ToSend::Remove => false,
|
||||
}
|
||||
}
|
||||
|
||||
async fn server_choose_work(
|
||||
server_state: &Mutex<ServerState>,
|
||||
state: &State,
|
||||
mut more_work_rx: mpsc::Receiver<()>,
|
||||
next_message_tx: mpsc::Sender<OutMessage>,
|
||||
wrote_message: &Notify,
|
||||
new_permuter: &Notify,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
let message = next_work_message(server_state, state, new_permuter).await;
|
||||
let requires_response = requires_response(&message);
|
||||
next_message_tx
|
||||
.send(message)
|
||||
.await
|
||||
.map_err(|_| ())
|
||||
.expect("writer must not close except on error");
|
||||
wrote_message.notified().await;
|
||||
if requires_response {
|
||||
more_work_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("reader must not close except on error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_heartbeat(port: &mut WritePort<'_>) -> SimpleResult<()> {
|
||||
port.send_json(&json!({
|
||||
"type": "heartbeat",
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_work(port: &mut WritePort<'_>, work: &OutMessage) -> SimpleResult<()> {
|
||||
port.send_json(&work).await?;
|
||||
if let ToSend::Add { ref data, .. } = work.to_send {
|
||||
port.send(&data.compressed_source).await?;
|
||||
port.send(&data.compressed_target_o_bin).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_write(
|
||||
port: &mut WritePort<'_>,
|
||||
mut next_message_rx: mpsc::Receiver<OutMessage>,
|
||||
mut heartbeat_rx: watch::Receiver<()>,
|
||||
wrote_message: &Notify,
|
||||
) -> SimpleResult<()> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
work = next_message_rx.recv() => {
|
||||
let work = work.expect("chooser must not close except on error");
|
||||
send_work(port, &work).await?;
|
||||
wrote_message.notify_one();
|
||||
}
|
||||
res = heartbeat_rx.changed() => {
|
||||
res.expect("heartbeat thread panicked");
|
||||
send_heartbeat(port).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_connect_server<'a>(
|
||||
mut read_port: ReadPort<'a>,
|
||||
mut write_port: WritePort<'a>,
|
||||
who_id: UserId,
|
||||
who_name: &str,
|
||||
permuter_version: u32,
|
||||
state: &State,
|
||||
data: ConnectServerData,
|
||||
) -> SimpleResult<()> {
|
||||
if permuter_version < MIN_PERMUTER_VERSION {
|
||||
Err("Permuter version too old!")?;
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"[{}] start server ({}, {})",
|
||||
who_name, data.min_priority, data.num_cores
|
||||
);
|
||||
|
||||
write_port
|
||||
.send_json(&json!({
|
||||
"docker_image": &state.docker_image,
|
||||
"heartbeat_interval": HEARTBEAT_TIME.as_secs(),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let (more_work_tx, more_work_rx) = mpsc::channel(SERVER_WORK_QUEUE_SIZE);
|
||||
let (next_message_tx, next_message_rx) = mpsc::channel(1);
|
||||
let wrote_message = Notify::new();
|
||||
let new_permuter = Notify::new();
|
||||
|
||||
let mut server_state = Mutex::new(ServerState {
|
||||
min_priority: data.min_priority,
|
||||
active_work: 0,
|
||||
more_work_acc: 0.0,
|
||||
jobs: HashMap::new(),
|
||||
});
|
||||
|
||||
let id = state.m.lock().unwrap().servers.insert(ConnectedServer {
|
||||
min_priority: data.min_priority,
|
||||
num_cores: data.num_cores,
|
||||
});
|
||||
|
||||
let r = tokio::try_join!(
|
||||
server_read(
|
||||
&mut read_port,
|
||||
&who_id,
|
||||
who_name,
|
||||
&server_state,
|
||||
state,
|
||||
more_work_tx,
|
||||
&new_permuter,
|
||||
),
|
||||
server_choose_work(
|
||||
&server_state,
|
||||
state,
|
||||
more_work_rx,
|
||||
next_message_tx,
|
||||
&wrote_message,
|
||||
&new_permuter,
|
||||
),
|
||||
server_write(
|
||||
&mut write_port,
|
||||
next_message_rx,
|
||||
state.heartbeat_rx.clone(),
|
||||
&wrote_message,
|
||||
)
|
||||
);
|
||||
|
||||
{
|
||||
let mut m = state.m.lock().unwrap();
|
||||
for (&perm_id, job) in &server_state.get_mut().unwrap().jobs {
|
||||
if let JobState::Loaded = job.state {
|
||||
if let Some(perm) = m.permuters.get_mut(&perm_id) {
|
||||
perm.send_result(PermuterResult::Result(
|
||||
who_id.clone(),
|
||||
who_name.to_string(),
|
||||
ServerUpdate::Disconnect,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.servers.remove(id);
|
||||
}
|
||||
r?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::default::Default;
|
||||
use std::fs::OpenOptions;
|
||||
|
||||
use sodiumoxide::crypto::sign;
|
||||
use sodiumoxide::randombytes::randombytes;
|
||||
|
||||
use crate::db::{User, UserId, DB};
|
||||
use crate::util::SimpleResult;
|
||||
use crate::SetupOpts;
|
||||
|
||||
pub(crate) fn run_setup(opts: SetupOpts) -> SimpleResult<()> {
|
||||
let db_file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.open(&opts.db)
|
||||
.unwrap_or_else(|e| {
|
||||
eprintln!("Cannot create database file {}: {}. Aborting.", &opts.db, e);
|
||||
std::process::exit(1);
|
||||
});
|
||||
|
||||
let server_seed = sign::Seed::from_slice(&randombytes(32)).unwrap();
|
||||
let client_seed = sign::Seed::from_slice(&randombytes(32)).unwrap();
|
||||
|
||||
let (server_pub_key, _) = sign::keypair_from_seed(&server_seed);
|
||||
let (client_pub_key, _) = sign::keypair_from_seed(&client_seed);
|
||||
|
||||
let root_user = User {
|
||||
trusted_by: None,
|
||||
name: "root".into(),
|
||||
client_stats: Default::default(),
|
||||
server_stats: Default::default(),
|
||||
};
|
||||
let mut users_map: HashMap<UserId, User> = HashMap::new();
|
||||
users_map.insert(UserId::from_pubkey(&client_pub_key), root_user);
|
||||
let db = DB {
|
||||
users: users_map,
|
||||
func_stats: HashMap::new(),
|
||||
total_stats: Default::default(),
|
||||
};
|
||||
|
||||
serde_json::to_writer(&db_file, &db)?;
|
||||
|
||||
println!(
|
||||
"Setup successful!\n\n\
|
||||
Put the following in the server's config.toml:\n\n\
|
||||
priv_seed = \"{}\"\n\n\
|
||||
Put the following in the root client's pah.conf:\n\n\
|
||||
secret_key = \"{}\"\n\
|
||||
server_public_key = \"{}\"\n\
|
||||
server_address = \"server.example:port\"",
|
||||
hex::encode(server_seed),
|
||||
hex::encode(client_seed),
|
||||
hex::encode(server_pub_key)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
use std::future::Future;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::db::{Stats, UserId};
|
||||
use crate::save::SaveableDB;
|
||||
|
||||
const CHANNEL_CAPACITY: usize = 10000;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Outcome {
|
||||
Matched,
|
||||
Improved,
|
||||
Unhelpful,
|
||||
}
|
||||
|
||||
pub enum Record {
|
||||
WorkDone {
|
||||
server: UserId,
|
||||
client: UserId,
|
||||
fn_name: String,
|
||||
outcome: Outcome,
|
||||
},
|
||||
ClientNewFunction {
|
||||
client: UserId,
|
||||
fn_name: String,
|
||||
},
|
||||
ServerNewFunction {
|
||||
server: UserId,
|
||||
},
|
||||
}
|
||||
|
||||
fn add_stats(stats: &mut Stats, outcome: Outcome) {
|
||||
if matches!(outcome, Outcome::Matched) {
|
||||
stats.matches += 1;
|
||||
}
|
||||
if matches!(outcome, Outcome::Matched | Outcome::Improved) {
|
||||
stats.improvements += 1;
|
||||
}
|
||||
stats.iterations += 1;
|
||||
}
|
||||
|
||||
async fn stats_writer(db: &SaveableDB, mut rx: mpsc::Receiver<Record>) {
|
||||
loop {
|
||||
let record = rx.recv().await.unwrap();
|
||||
db.write(false, |db| {
|
||||
match record {
|
||||
Record::WorkDone {
|
||||
server,
|
||||
client,
|
||||
fn_name,
|
||||
outcome,
|
||||
} => {
|
||||
add_stats(&mut db.total_stats, outcome);
|
||||
add_stats(db.func_stat(fn_name), outcome);
|
||||
if let Some(user) = db.users.get_mut(&client) {
|
||||
add_stats(&mut user.client_stats, outcome);
|
||||
}
|
||||
if let Some(user) = db.users.get_mut(&server) {
|
||||
add_stats(&mut user.server_stats, outcome);
|
||||
}
|
||||
}
|
||||
Record::ClientNewFunction { client, fn_name } => {
|
||||
db.func_stat(fn_name).functions += 1;
|
||||
if let Some(user) = db.users.get_mut(&client) {
|
||||
user.client_stats.functions += 1;
|
||||
}
|
||||
db.total_stats.functions += 1;
|
||||
}
|
||||
Record::ServerNewFunction { server } => {
|
||||
if let Some(user) = db.users.get_mut(&server) {
|
||||
user.server_stats.functions += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stats_thread(db: &SaveableDB) -> (impl Future<Output = ()>, mpsc::Sender<Record>) {
|
||||
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let db = db.clone();
|
||||
let fut = async move {
|
||||
stats_writer(&db, rx).await;
|
||||
};
|
||||
(fut, tx)
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
use std::error::Error;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
pub type SimpleResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
|
||||
|
||||
#[pin_project]
|
||||
pub struct NowOrNever<F: Future> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
}
|
||||
|
||||
impl<F: Future> Future for NowOrNever<F> {
|
||||
type Output = Option<F::Output>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let ret = self.project().inner.poll(cx);
|
||||
Poll::Ready(match ret {
|
||||
Poll::Pending => None,
|
||||
Poll::Ready(val) => Some(val),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FutureExt for T where T: Future {}
|
||||
|
||||
pub trait FutureExt: Future {
|
||||
fn now_or_never(self) -> NowOrNever<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
NowOrNever { inner: self }
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
use std::str;
|
||||
|
||||
use hex::FromHex;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use sodiumoxide::crypto::sign;
|
||||
|
||||
use crate::db::{User, UserId};
|
||||
use crate::port::{ReadPort, WritePort};
|
||||
use crate::util::SimpleResult;
|
||||
use crate::{concat, State};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct VouchData {
|
||||
who: UserId,
|
||||
signed_name: String,
|
||||
}
|
||||
|
||||
fn verify_with_magic<'a>(
|
||||
magic: &[u8],
|
||||
data: &'a [u8],
|
||||
key: &sign::PublicKey,
|
||||
) -> SimpleResult<&'a [u8]> {
|
||||
if data.len() < 64 {
|
||||
Err("signature too short")?;
|
||||
}
|
||||
let (signature, data) = data.split_at(64);
|
||||
let signed_data = concat(magic, data);
|
||||
let signature = sign::Signature::from_slice(signature).unwrap();
|
||||
if !sign::verify_detached(&signature, &signed_data, key) {
|
||||
Err("bad signature")?;
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn parse_signed_name(signed_name: &str, who: &UserId) -> SimpleResult<String> {
|
||||
let signed_name = Vec::from_hex(signed_name).map_err(|_| "not a valid hex string")?;
|
||||
let name_bytes = verify_with_magic(b"NAME:", &signed_name, &who.to_pubkey())?;
|
||||
let name = str::from_utf8(name_bytes)?;
|
||||
if name.is_empty() {
|
||||
Err("name is empty")?;
|
||||
}
|
||||
if name.chars().any(char::is_control) {
|
||||
Err("name cannot contain control characters")?;
|
||||
}
|
||||
Ok(name.to_string())
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_vouch<'a>(
|
||||
mut read_port: ReadPort<'a>,
|
||||
mut write_port: WritePort<'a>,
|
||||
who_id: UserId,
|
||||
who_name: &str,
|
||||
state: &State,
|
||||
data: VouchData,
|
||||
) -> SimpleResult<()> {
|
||||
let vouchee_name = match parse_signed_name(&data.signed_name, &data.who) {
|
||||
Ok(name) => name,
|
||||
Err(e) => {
|
||||
write_port.send_error(&format!("{}", &e)).await?;
|
||||
Err(e)?
|
||||
}
|
||||
};
|
||||
write_port.send_json(&json!({})).await?;
|
||||
read_port.recv().await?;
|
||||
state
|
||||
.db
|
||||
.write(true, |db| {
|
||||
db.users.entry(data.who).or_insert_with(|| User {
|
||||
trusted_by: Some(who_id),
|
||||
name: vouchee_name.clone(),
|
||||
client_stats: Default::default(),
|
||||
server_stats: Default::default(),
|
||||
});
|
||||
})
|
||||
.await;
|
||||
write_port.send_json(&json!({})).await?;
|
||||
eprintln!("[{}] vouch {}", who_name, &vouchee_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
import datetime
|
||||
import json
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import toml
|
||||
import typing
|
||||
from typing import BinaryIO, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from nacl.encoding import HexEncoder
|
||||
from nacl.public import Box, PrivateKey, PublicKey
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.signing import SigningKey, VerifyKey
|
||||
|
||||
from ..error import ServerError
|
||||
from ..helpers import exception_to_string
|
||||
|
||||
T = TypeVar("T")
|
||||
AnyBox = Union[Box, SecretBox]
|
||||
|
||||
PERMUTER_VERSION = 2
|
||||
|
||||
CONFIG_FILENAME = "pah.conf"
|
||||
|
||||
MIN_PRIO = 0.01
|
||||
MAX_PRIO = 2.0
|
||||
|
||||
DEBUG_MODE = False
|
||||
|
||||
|
||||
def enable_debug_mode() -> None:
|
||||
"""Enable debug logging."""
|
||||
global DEBUG_MODE
|
||||
DEBUG_MODE = True
|
||||
|
||||
|
||||
def debug_print(message: str) -> None:
|
||||
if DEBUG_MODE:
|
||||
time = datetime.datetime.now().strftime("%H:%M:%S:%f")
|
||||
print(f"\n{time} debug: {message}")
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class CancelToken:
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermuterData:
|
||||
base_score: int
|
||||
base_hash: str
|
||||
fn_name: str
|
||||
filename: str
|
||||
keep_prob: float
|
||||
need_profiler: bool
|
||||
stack_differences: bool
|
||||
compile_script: str
|
||||
source: str
|
||||
target_o_bin: bytes
|
||||
|
||||
|
||||
def permuter_data_from_json(
|
||||
obj: dict, source: str, target_o_bin: bytes
|
||||
) -> PermuterData:
|
||||
return PermuterData(
|
||||
base_score=json_prop(obj, "base_score", int),
|
||||
base_hash=json_prop(obj, "base_hash", str),
|
||||
fn_name=json_prop(obj, "fn_name", str),
|
||||
filename=json_prop(obj, "filename", str),
|
||||
keep_prob=json_prop(obj, "keep_prob", float),
|
||||
need_profiler=json_prop(obj, "need_profiler", bool),
|
||||
stack_differences=json_prop(obj, "stack_differences", bool),
|
||||
compile_script=json_prop(obj, "compile_script", str),
|
||||
source=source,
|
||||
target_o_bin=target_o_bin,
|
||||
)
|
||||
|
||||
|
||||
def permuter_data_to_json(perm: PermuterData) -> dict:
|
||||
return {
|
||||
"base_score": perm.base_score,
|
||||
"base_hash": perm.base_hash,
|
||||
"fn_name": perm.fn_name,
|
||||
"filename": perm.filename,
|
||||
"keep_prob": perm.keep_prob,
|
||||
"need_profiler": perm.need_profiler,
|
||||
"stack_differences": perm.stack_differences,
|
||||
"compile_script": perm.compile_script,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
server_address: Optional[str] = None
|
||||
server_verify_key: Optional[VerifyKey] = None
|
||||
signing_key: Optional[SigningKey] = None
|
||||
initial_setup_nickname: Optional[str] = None
|
||||
|
||||
|
||||
def read_config() -> Config:
|
||||
config = Config()
|
||||
try:
|
||||
with open(CONFIG_FILENAME) as f:
|
||||
obj = toml.load(f)
|
||||
|
||||
def read(key: str, t: Type[T]) -> Optional[T]:
|
||||
ret = obj.get(key)
|
||||
return ret if isinstance(ret, t) else None
|
||||
|
||||
temp = read("server_public_key", str)
|
||||
if temp:
|
||||
config.server_verify_key = VerifyKey(HexEncoder.decode(temp))
|
||||
temp = read("secret_key", str)
|
||||
if temp:
|
||||
config.signing_key = SigningKey(HexEncoder.decode(temp))
|
||||
config.initial_setup_nickname = read("initial_setup_nickname", str)
|
||||
config.server_address = read("server_address", str)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
print(f"Malformed configuration file {CONFIG_FILENAME}.\n")
|
||||
raise
|
||||
return config
|
||||
|
||||
|
||||
def write_config(config: Config) -> None:
|
||||
obj = {}
|
||||
|
||||
def write(key: str, val: Union[None, str, int]) -> None:
|
||||
if val is not None:
|
||||
obj[key] = val
|
||||
|
||||
write("initial_setup_nickname", config.initial_setup_nickname)
|
||||
write("server_address", config.server_address)
|
||||
|
||||
key_hex: bytes
|
||||
if config.server_verify_key:
|
||||
key_hex = config.server_verify_key.encode(HexEncoder)
|
||||
write("server_public_key", key_hex.decode("utf-8"))
|
||||
if config.signing_key:
|
||||
key_hex = config.signing_key.encode(HexEncoder)
|
||||
write("secret_key", key_hex.decode("utf-8"))
|
||||
|
||||
with open(CONFIG_FILENAME, "w") as f:
|
||||
toml.dump(obj, f)
|
||||
|
||||
|
||||
def file_read_max(inf: BinaryIO, n: int) -> bytes:
|
||||
try:
|
||||
ret = []
|
||||
while n > 0:
|
||||
data = inf.read(n)
|
||||
if not data:
|
||||
break
|
||||
ret.append(data)
|
||||
n -= len(data)
|
||||
return b"".join(ret)
|
||||
except Exception as e:
|
||||
raise EOFError from e
|
||||
|
||||
|
||||
def file_read_fixed(inf: BinaryIO, n: int) -> bytes:
|
||||
ret = file_read_max(inf, n)
|
||||
if len(ret) != n:
|
||||
raise EOFError
|
||||
return ret
|
||||
|
||||
|
||||
def socket_read_max(sock: socket.socket, n: int) -> bytes:
|
||||
try:
|
||||
ret = []
|
||||
while n > 0:
|
||||
data = sock.recv(min(n, 4096))
|
||||
if not data:
|
||||
break
|
||||
ret.append(data)
|
||||
n -= len(data)
|
||||
return b"".join(ret)
|
||||
except Exception as e:
|
||||
raise EOFError from e
|
||||
|
||||
|
||||
def socket_read_fixed(sock: socket.socket, n: int) -> bytes:
|
||||
ret = socket_read_max(sock, n)
|
||||
if len(ret) != n:
|
||||
raise EOFError
|
||||
return ret
|
||||
|
||||
|
||||
def socket_shutdown(sock: socket.socket, how: int) -> None:
|
||||
try:
|
||||
sock.shutdown(how)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def json_prop(obj: dict, prop: str, t: Type[T]) -> T:
|
||||
ret = obj.get(prop)
|
||||
if not isinstance(ret, t):
|
||||
if t is float and isinstance(ret, int):
|
||||
return typing.cast(T, float(ret))
|
||||
found_type = type(ret).__name__
|
||||
if prop not in obj:
|
||||
raise ValueError(f"Member {prop} does not exist")
|
||||
raise ValueError(f"Member {prop} must have type {t.__name__}; got {found_type}")
|
||||
return ret
|
||||
|
||||
|
||||
def json_array(obj: list, t: Type[T]) -> List[T]:
|
||||
for elem in obj:
|
||||
if not isinstance(elem, t):
|
||||
found_type = type(elem).__name__
|
||||
raise ValueError(
|
||||
f"Array elements must have type {t.__name__}; got {found_type}"
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
def sign_with_magic(magic: bytes, signing_key: SigningKey, data: bytes) -> bytes:
|
||||
signature: bytes = signing_key.sign(magic + b":" + data).signature
|
||||
return signature + data
|
||||
|
||||
|
||||
def verify_with_magic(magic: bytes, verify_key: VerifyKey, data: bytes) -> bytes:
|
||||
if len(data) < 64:
|
||||
raise ValueError("String is too small to contain a signature")
|
||||
signature = data[:64]
|
||||
data = data[64:]
|
||||
verify_key.verify(magic + b":" + data, signature)
|
||||
return data
|
||||
|
||||
|
||||
class Port(abc.ABC):
|
||||
def __init__(self, box: AnyBox, who: str, *, is_client: bool) -> None:
|
||||
self._box = box
|
||||
self._who = who
|
||||
self._send_nonce = 0 if is_client else 1
|
||||
self._receive_nonce = 1 if is_client else 0
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send(self, data: bytes) -> None:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def _receive(self, length: int) -> bytes:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def _receive_max(self, length: int) -> bytes:
|
||||
...
|
||||
|
||||
def send(self, msg: bytes) -> None:
|
||||
"""Send a binary message, potentially blocking."""
|
||||
if DEBUG_MODE:
|
||||
if len(msg) <= 300:
|
||||
debug_print(f"Send to {self._who}: {msg!r}")
|
||||
else:
|
||||
debug_print(f"Send to {self._who}: {len(msg)} bytes")
|
||||
nonce = struct.pack(">16xQ", self._send_nonce)
|
||||
self._send_nonce += 2
|
||||
data = self._box.encrypt(msg, nonce).ciphertext
|
||||
length_data = struct.pack(">Q", len(data))
|
||||
try:
|
||||
self._send(length_data + data)
|
||||
except BrokenPipeError:
|
||||
raise EOFError from None
|
||||
|
||||
def send_json(self, msg: dict) -> None:
|
||||
"""Send a message in the form of a JSON dict, potentially blocking."""
|
||||
self.send(json.dumps(msg).encode("utf-8"))
|
||||
|
||||
def receive(self) -> bytes:
|
||||
"""Read a binary message, blocking."""
|
||||
length_data = self._receive(8)
|
||||
if length_data[0]:
|
||||
# Lengths above 2^56 are unreasonable, so if we get one someone is
|
||||
# sending us bad data. Raise an exception to help debugging.
|
||||
length_data += self._receive_max(1024)
|
||||
raise Exception(
|
||||
f"Got unexpected data from {self._who}: " + repr(length_data)
|
||||
)
|
||||
length = struct.unpack(">Q", length_data)[0]
|
||||
data = self._receive(length)
|
||||
nonce = struct.pack(">16xQ", self._receive_nonce)
|
||||
self._receive_nonce += 2
|
||||
msg: bytes = self._box.decrypt(data, nonce)
|
||||
if DEBUG_MODE:
|
||||
if len(msg) <= 300:
|
||||
debug_print(f"Receive from {self._who}: {msg!r}")
|
||||
else:
|
||||
debug_print(f"Receive from {self._who}: {len(msg)} bytes")
|
||||
return msg
|
||||
|
||||
def receive_json(self) -> dict:
|
||||
"""Read a message in the form of a JSON dict, blocking."""
|
||||
ret = json.loads(self.receive())
|
||||
if isinstance(ret, str):
|
||||
# Raw strings indicate errors.
|
||||
raise ServerError(ret)
|
||||
if not isinstance(ret, dict):
|
||||
# We always pass dictionaries as messages and no other data types,
|
||||
# to ensure future extensibility. (Other types are rare in
|
||||
# practice, anyway.)
|
||||
raise ValueError("Top-level JSON value must be a dictionary")
|
||||
return ret
|
||||
|
||||
|
||||
class SocketPort(Port):
|
||||
def __init__(
|
||||
self, sock: socket.socket, box: AnyBox, who: str, *, is_client: bool
|
||||
) -> None:
|
||||
self._sock = sock
|
||||
super().__init__(box, who, is_client=is_client)
|
||||
|
||||
def _send(self, data: bytes) -> None:
|
||||
self._sock.sendall(data)
|
||||
|
||||
def _receive(self, length: int) -> bytes:
|
||||
return socket_read_fixed(self._sock, length)
|
||||
|
||||
def _receive_max(self, length: int) -> bytes:
|
||||
return socket_read_max(self._sock, length)
|
||||
|
||||
def shutdown(self, how: int = socket.SHUT_RDWR) -> None:
|
||||
socket_shutdown(self._sock, how)
|
||||
|
||||
def close(self) -> None:
|
||||
self._sock.close()
|
||||
|
||||
|
||||
class FilePort(Port):
|
||||
def __init__(
|
||||
self, inf: BinaryIO, outf: BinaryIO, box: AnyBox, who: str, *, is_client: bool
|
||||
) -> None:
|
||||
self._inf = inf
|
||||
self._outf = outf
|
||||
super().__init__(box, who, is_client=is_client)
|
||||
|
||||
def _send(self, data: bytes) -> None:
|
||||
self._outf.write(data)
|
||||
self._outf.flush()
|
||||
|
||||
def _receive(self, length: int) -> bytes:
|
||||
return file_read_fixed(self._inf, length)
|
||||
|
||||
def _receive_max(self, length: int) -> bytes:
|
||||
return file_read_max(self._inf, length)
|
||||
|
||||
|
||||
def _do_connect(config: Config) -> SocketPort:
|
||||
if (
|
||||
not config.server_verify_key
|
||||
or not config.signing_key
|
||||
or not config.server_address
|
||||
):
|
||||
print(
|
||||
"Using permuter@home requires someone to give you access to a central -J server.\n"
|
||||
"Run `./pah.py setup` to set this up."
|
||||
)
|
||||
print()
|
||||
sys.exit(1)
|
||||
|
||||
host, port_str = config.server_address.split(":")
|
||||
try:
|
||||
sock = socket.create_connection((host, int(port_str)))
|
||||
except ConnectionRefusedError:
|
||||
raise EOFError("connection refused") from None
|
||||
except socket.gaierror as e:
|
||||
raise EOFError(f"DNS lookup failed: {e}") from None
|
||||
except Exception as e:
|
||||
raise EOFError("unable to connect: " + exception_to_string(e)) from None
|
||||
|
||||
# Send over the protocol version and an ephemeral encryption key which we
|
||||
# are going to use for all communication.
|
||||
ephemeral_key = PrivateKey.generate()
|
||||
ephemeral_key_data = ephemeral_key.public_key.encode()
|
||||
sock.sendall(b"p@h0" + ephemeral_key_data)
|
||||
|
||||
# Receive the server's encryption key, plus a signature of it and our own
|
||||
# ephemeral key -- this guarantees that we are talking to the server and
|
||||
# aren't victim to a replay attack. Use it to set up a communication port.
|
||||
msg = socket_read_fixed(sock, 32 + 64)
|
||||
server_enc_key_data = msg[:32]
|
||||
config.server_verify_key.verify(
|
||||
b"HELLO:" + ephemeral_key_data + server_enc_key_data, msg[32:]
|
||||
)
|
||||
box = Box(ephemeral_key, PublicKey(server_enc_key_data))
|
||||
port = SocketPort(sock, box, "controller", is_client=True)
|
||||
|
||||
# Use the encrypted port to send over our public key, proof that we are
|
||||
# able to sign new things with it, as well as permuter version.
|
||||
signature: bytes = config.signing_key.sign(
|
||||
b"WORLD:" + server_enc_key_data
|
||||
).signature
|
||||
port.send(
|
||||
config.signing_key.verify_key.encode()
|
||||
+ signature
|
||||
+ struct.pack(">I", PERMUTER_VERSION)
|
||||
)
|
||||
|
||||
# Get an acknowledgement that the server wants to talk to us.
|
||||
obj = port.receive_json()
|
||||
|
||||
if "message" in obj:
|
||||
print(obj["message"])
|
||||
|
||||
return port
|
||||
|
||||
|
||||
def connect(config: Optional[Config] = None) -> SocketPort:
|
||||
"""Authenticate and connect to the permuter@home controller server."""
|
||||
if not config:
|
||||
config = read_config()
|
||||
return _do_connect(config)
|
||||
@@ -1,401 +0,0 @@
|
||||
"""This file runs as a free-standing program within a sandbox, and processes
|
||||
permutation requests. It communicates with the outside world on stdin/stdout."""
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from multiprocessing import Process, Queue
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
from tempfile import mkstemp
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Counter, Dict, List, Optional, Set, Tuple, Union
|
||||
import zlib
|
||||
|
||||
from nacl.secret import SecretBox
|
||||
|
||||
from ..candidate import CandidateResult
|
||||
from ..compiler import Compiler
|
||||
from ..error import CandidateConstructionFailure
|
||||
from ..helpers import exception_to_string, static_assert_unreachable
|
||||
from ..permuter import EvalError, EvalResult, Permuter
|
||||
from ..profiler import Profiler
|
||||
from ..scorer import Scorer
|
||||
from .core import (
|
||||
FilePort,
|
||||
PermuterData,
|
||||
Port,
|
||||
json_prop,
|
||||
permuter_data_from_json,
|
||||
)
|
||||
|
||||
|
||||
def _fix_stdout() -> None:
|
||||
"""Redirect stdout to stderr to make print() debugging work. This function
|
||||
*must* be called at startup for each (sub)process, since we use stdout for
|
||||
our own communication purposes."""
|
||||
sys.stdout = sys.stderr
|
||||
|
||||
# In addition, we set stderr to flush on newlines, which does not happen by
|
||||
# default when it is piped. (Requires Python 3.7, but we can assume that's
|
||||
# available inside the sandbox.)
|
||||
sys.stdout.reconfigure(line_buffering=True) # type: ignore
|
||||
|
||||
|
||||
def _setup_port(secret: bytes) -> Port:
|
||||
"""Set up communication with the outside world."""
|
||||
port = FilePort(
|
||||
sys.stdin.buffer,
|
||||
sys.stdout.buffer,
|
||||
SecretBox(secret),
|
||||
"server",
|
||||
is_client=False,
|
||||
)
|
||||
|
||||
_fix_stdout()
|
||||
|
||||
# Follow the controlling process's sanity check protocol.
|
||||
magic = port.receive()
|
||||
port.send(magic)
|
||||
|
||||
return port
|
||||
|
||||
|
||||
def _create_permuter(data: PermuterData) -> Permuter:
|
||||
fd, path = mkstemp(suffix=".o", prefix="permuter", text=False)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(data.target_o_bin)
|
||||
scorer = Scorer(target_o=path, stack_differences=data.stack_differences)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
fd, path = mkstemp(suffix=".sh", prefix="permuter", text=True)
|
||||
try:
|
||||
os.chmod(fd, 0o755)
|
||||
with os.fdopen(fd, "w") as f2:
|
||||
f2.write(data.compile_script)
|
||||
compiler = Compiler(compile_cmd=path, show_errors=False)
|
||||
|
||||
return Permuter(
|
||||
dir="unused",
|
||||
fn_name=data.fn_name,
|
||||
compiler=compiler,
|
||||
scorer=scorer,
|
||||
source_file=data.filename,
|
||||
source=data.source,
|
||||
force_seed=None,
|
||||
force_rng_seed=None,
|
||||
keep_prob=data.keep_prob,
|
||||
need_profiler=data.need_profiler,
|
||||
need_all_sources=False,
|
||||
show_errors=False,
|
||||
better_only=False,
|
||||
best_only=False,
|
||||
)
|
||||
except:
|
||||
os.unlink(path)
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddPermuter:
|
||||
perm_id: str
|
||||
data: PermuterData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddPermuterLocal:
|
||||
perm_id: str
|
||||
permuter: Permuter
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemovePermuter:
|
||||
perm_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkDone:
|
||||
perm_id: str
|
||||
id: int
|
||||
time_us: int
|
||||
result: EvalResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class Work:
|
||||
perm_id: str
|
||||
id: int
|
||||
seed: int
|
||||
|
||||
|
||||
LocalWork = Tuple[Union[AddPermuterLocal, RemovePermuter], int]
|
||||
GlobalWork = Tuple[Work, int]
|
||||
Task = Union[AddPermuter, RemovePermuter, Work, WorkDone]
|
||||
|
||||
|
||||
def _remove_permuter(perm: Permuter) -> None:
|
||||
os.unlink(perm.compiler.compile_cmd)
|
||||
|
||||
|
||||
def _send_result(item: WorkDone, port: Port) -> None:
|
||||
obj = {
|
||||
"type": "result",
|
||||
"permuter": item.perm_id,
|
||||
"id": item.id,
|
||||
"time_us": item.time_us,
|
||||
}
|
||||
|
||||
res = item.result
|
||||
if isinstance(res, EvalError):
|
||||
obj["error"] = res.exc_str
|
||||
port.send_json(obj)
|
||||
return
|
||||
|
||||
compressed_source = getattr(res, "compressed_source")
|
||||
|
||||
obj["score"] = res.score
|
||||
obj["has_source"] = compressed_source is not None
|
||||
if res.hash is not None:
|
||||
obj["hash"] = res.hash
|
||||
if res.profiler is not None:
|
||||
obj["profiler"] = {
|
||||
st.name: res.profiler.time_stats[st] for st in Profiler.StatType
|
||||
}
|
||||
|
||||
port.send_json(obj)
|
||||
|
||||
if compressed_source is not None:
|
||||
port.send(compressed_source)
|
||||
|
||||
|
||||
def multiprocess_worker(
|
||||
worker_queue: "Queue[GlobalWork]",
|
||||
local_queue: "Queue[LocalWork]",
|
||||
task_queue: "Queue[Task]",
|
||||
) -> None:
|
||||
_fix_stdout()
|
||||
|
||||
# Prevent deadlocks in case the parent process dies.
|
||||
worker_queue.cancel_join_thread()
|
||||
local_queue.cancel_join_thread()
|
||||
task_queue.cancel_join_thread()
|
||||
|
||||
permuters: Dict[str, Permuter] = {}
|
||||
timestamp = 0
|
||||
|
||||
while True:
|
||||
work, required_timestamp = worker_queue.get()
|
||||
while True:
|
||||
try:
|
||||
block = timestamp < required_timestamp
|
||||
task, timestamp = local_queue.get(block=block)
|
||||
except queue.Empty:
|
||||
break
|
||||
if isinstance(task, AddPermuterLocal):
|
||||
permuters[task.perm_id] = task.permuter
|
||||
elif isinstance(task, RemovePermuter):
|
||||
del permuters[task.perm_id]
|
||||
else:
|
||||
static_assert_unreachable(task)
|
||||
|
||||
time_before = time.time()
|
||||
|
||||
permuter = permuters[work.perm_id]
|
||||
result = permuter.try_eval_candidate(work.seed)
|
||||
if isinstance(result, CandidateResult) and permuter.should_output(result):
|
||||
permuter.record_result(result)
|
||||
|
||||
# Compress the source within the worker. (Why waste a free
|
||||
# multi-threading opportunity?)
|
||||
if isinstance(result, CandidateResult):
|
||||
compressed_source: Optional[bytes] = None
|
||||
if result.source is not None:
|
||||
compressed_source = zlib.compress(result.source.encode("utf-8"))
|
||||
setattr(result, "compressed_source", compressed_source)
|
||||
result.source = None
|
||||
|
||||
time_us = int((time.time() - time_before) * 10 ** 6)
|
||||
task_queue.put(
|
||||
WorkDone(perm_id=work.perm_id, id=work.id, time_us=time_us, result=result)
|
||||
)
|
||||
|
||||
|
||||
def read_loop(task_queue: "Queue[Task]", port: Port) -> None:
|
||||
try:
|
||||
while True:
|
||||
item = port.receive_json()
|
||||
msg_type = json_prop(item, "type", str)
|
||||
if msg_type == "add":
|
||||
perm_id = json_prop(item, "permuter", str)
|
||||
source = port.receive().decode("utf-8")
|
||||
target_o_bin = port.receive()
|
||||
data = permuter_data_from_json(item, source, target_o_bin)
|
||||
task_queue.put(AddPermuter(perm_id=perm_id, data=data))
|
||||
|
||||
elif msg_type == "remove":
|
||||
perm_id = json_prop(item, "permuter", str)
|
||||
task_queue.put(RemovePermuter(perm_id=perm_id))
|
||||
|
||||
elif msg_type == "work":
|
||||
perm_id = json_prop(item, "permuter", str)
|
||||
id = json_prop(item, "id", int)
|
||||
seed = json_prop(item, "seed", int)
|
||||
task_queue.put(Work(perm_id=perm_id, id=id, seed=seed))
|
||||
|
||||
else:
|
||||
raise Exception(f"Invalid message type {msg_type}")
|
||||
|
||||
except Exception as e:
|
||||
# In case the port is closed from the other side, skip writing an ugly
|
||||
# error message.
|
||||
if not isinstance(e, EOFError):
|
||||
traceback.print_exc()
|
||||
|
||||
# Exit the whole process, to improve the odds that the Docker container
|
||||
# really stops and gets removed.
|
||||
#
|
||||
# The parent server has a "finally:" that does that, but it's not 100%
|
||||
# trustworthy. In particular, pystray has a tendency to hard-crash
|
||||
# (which doesn't fire "finally"s), and also reverts the signal handler
|
||||
# for SIGINT to the default on Linux, making Ctrl+C not run cleanup.
|
||||
# Either way, defense in depth here doesn't hurt, since leaking Docker
|
||||
# containers is pretty bad.
|
||||
#
|
||||
# Unfortunately this still doesn't fix the problem, since we typically
|
||||
# don't get a port closure signal when the parent process stops...
|
||||
# TODO: listen to heartbeats as well.
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
secret = base64.b64decode(os.environ["SECRET"])
|
||||
del os.environ["SECRET"]
|
||||
os.environ["PERMUTER_IS_REMOTE"] = "1"
|
||||
|
||||
port = _setup_port(secret)
|
||||
|
||||
obj = port.receive_json()
|
||||
num_cores = json_prop(obj, "num_cores", float)
|
||||
num_threads = math.ceil(num_cores)
|
||||
|
||||
worker_queue: "Queue[GlobalWork]" = Queue()
|
||||
task_queue: "Queue[Task]" = Queue()
|
||||
local_queues: "List[Queue[LocalWork]]" = []
|
||||
|
||||
for i in range(num_threads):
|
||||
local_queue: "Queue[LocalWork]" = Queue()
|
||||
p = Process(
|
||||
target=multiprocess_worker,
|
||||
args=(worker_queue, local_queue, task_queue),
|
||||
daemon=True,
|
||||
)
|
||||
p.start()
|
||||
local_queues.append(local_queue)
|
||||
|
||||
reader_thread = threading.Thread(
|
||||
target=read_loop, args=(task_queue, port), daemon=True
|
||||
)
|
||||
reader_thread.start()
|
||||
|
||||
remaining_work: Counter[str] = Counter()
|
||||
should_remove: Set[str] = set()
|
||||
permuters: Dict[str, Permuter] = {}
|
||||
timestamp = 0
|
||||
|
||||
def try_remove(perm_id: str) -> None:
|
||||
nonlocal timestamp
|
||||
assert perm_id in permuters
|
||||
if perm_id not in should_remove or remaining_work[perm_id] != 0:
|
||||
return
|
||||
del remaining_work[perm_id]
|
||||
should_remove.remove(perm_id)
|
||||
timestamp += 1
|
||||
for q in local_queues:
|
||||
q.put((RemovePermuter(perm_id=perm_id), timestamp))
|
||||
_remove_permuter(permuters[perm_id])
|
||||
del permuters[perm_id]
|
||||
|
||||
while True:
|
||||
item = task_queue.get()
|
||||
|
||||
if isinstance(item, AddPermuter):
|
||||
assert item.perm_id not in permuters
|
||||
|
||||
msg: Dict[str, object] = {
|
||||
"type": "init",
|
||||
"permuter": item.perm_id,
|
||||
}
|
||||
|
||||
time_before = time.time()
|
||||
try:
|
||||
# Construct a permuter. This involves a compilation on the main
|
||||
# thread, which isn't great but we can live with it for now.
|
||||
permuter = _create_permuter(item.data)
|
||||
|
||||
if permuter.base_score != item.data.base_score:
|
||||
_remove_permuter(permuter)
|
||||
score_str = f"{permuter.base_score} vs {item.data.base_score}"
|
||||
if permuter.base_hash == item.data.base_hash:
|
||||
hash_str = "same hash; different Python or permuter versions?"
|
||||
else:
|
||||
hash_str = "different hash; different objdump versions?"
|
||||
raise CandidateConstructionFailure(
|
||||
f"mismatching score: {score_str} ({hash_str})"
|
||||
)
|
||||
|
||||
permuters[item.perm_id] = permuter
|
||||
|
||||
msg["success"] = True
|
||||
msg["base_score"] = permuter.base_score
|
||||
msg["base_hash"] = permuter.base_hash
|
||||
|
||||
# Tell all the workers about the new permuter.
|
||||
# TODO: ideally we would also seed their Candidate lru_cache's
|
||||
# to avoid all workers having to parse the source...
|
||||
timestamp += 1
|
||||
for q in local_queues:
|
||||
q.put(
|
||||
(
|
||||
AddPermuterLocal(perm_id=item.perm_id, permuter=permuter),
|
||||
timestamp,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# This shouldn't practically happen, since the client compiled
|
||||
# the code successfully. Print a message if it does.
|
||||
msg["success"] = False
|
||||
msg["error"] = exception_to_string(e)
|
||||
if isinstance(e, CandidateConstructionFailure):
|
||||
print(e.message)
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
msg["time_us"] = int((time.time() - time_before) * 10 ** 6)
|
||||
port.send_json(msg)
|
||||
|
||||
elif isinstance(item, RemovePermuter):
|
||||
# Silently ignore requests to remove permuters that have already
|
||||
# been removed, which can occur when AddPermuter fails.
|
||||
if item.perm_id in permuters:
|
||||
should_remove.add(item.perm_id)
|
||||
try_remove(item.perm_id)
|
||||
|
||||
elif isinstance(item, WorkDone):
|
||||
remaining_work[item.perm_id] -= 1
|
||||
try_remove(item.perm_id)
|
||||
_send_result(item, port)
|
||||
|
||||
elif isinstance(item, Work):
|
||||
remaining_work[item.perm_id] += 1
|
||||
worker_queue.put((item, timestamp))
|
||||
|
||||
else:
|
||||
static_assert_unreachable(item)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,944 +0,0 @@
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
import pathlib
|
||||
import queue
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import BinaryIO, Dict, Optional, Set, Tuple, Union, TYPE_CHECKING
|
||||
import zlib
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import docker
|
||||
|
||||
from nacl.secret import SecretBox
|
||||
import nacl.utils
|
||||
|
||||
from ..helpers import exception_to_string, static_assert_unreachable
|
||||
from .core import (
|
||||
CancelToken,
|
||||
Config,
|
||||
PermuterData,
|
||||
Port,
|
||||
ServerError,
|
||||
SocketPort,
|
||||
connect,
|
||||
file_read_fixed,
|
||||
json_prop,
|
||||
permuter_data_from_json,
|
||||
permuter_data_to_json,
|
||||
)
|
||||
|
||||
|
||||
_HEARTBEAT_INTERVAL_SLACK_SEC: float = 50.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client:
|
||||
id: str
|
||||
nickname: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddPermuter:
|
||||
handle: int
|
||||
time_start: float
|
||||
client: Client
|
||||
permuter_data: PermuterData
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemovePermuter:
|
||||
handle: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Work:
|
||||
handle: int
|
||||
id: int
|
||||
time_start: float
|
||||
seed: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImmediateDisconnect:
|
||||
handle: int
|
||||
client: Client
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Disconnect:
|
||||
handle: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermInitFail:
|
||||
perm_id: str
|
||||
error: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermInitSuccess:
|
||||
perm_id: str
|
||||
base_score: int
|
||||
base_hash: str
|
||||
time_us: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkDone:
|
||||
perm_id: str
|
||||
id: int
|
||||
obj: dict
|
||||
time_us: int
|
||||
compressed_source: Optional[bytes]
|
||||
|
||||
|
||||
class NeedMoreWork:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetThreadDisconnected:
|
||||
graceful: bool
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class Heartbeat:
|
||||
pass
|
||||
|
||||
|
||||
class Shutdown:
|
||||
pass
|
||||
|
||||
|
||||
Activity = Union[
|
||||
AddPermuter,
|
||||
RemovePermuter,
|
||||
Work,
|
||||
ImmediateDisconnect,
|
||||
Disconnect,
|
||||
PermInitFail,
|
||||
PermInitSuccess,
|
||||
WorkDone,
|
||||
NeedMoreWork,
|
||||
NetThreadDisconnected,
|
||||
Heartbeat,
|
||||
Shutdown,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputInitFail:
|
||||
handle: int
|
||||
error: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputInitSuccess:
|
||||
handle: int
|
||||
time_us: int
|
||||
base_score: int
|
||||
base_hash: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputDisconnect:
|
||||
handle: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputNeedMoreWork:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputWork:
|
||||
handle: int
|
||||
time_start: float
|
||||
time_us: int
|
||||
obj: dict
|
||||
compressed_source: Optional[bytes]
|
||||
|
||||
|
||||
Output = Union[
|
||||
OutputDisconnect,
|
||||
OutputInitFail,
|
||||
OutputInitSuccess,
|
||||
OutputNeedMoreWork,
|
||||
OutputWork,
|
||||
Shutdown,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IoConnect:
|
||||
fn_name: str
|
||||
client: Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class IoDisconnect:
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class IoImmediateDisconnect:
|
||||
reason: str
|
||||
client: Client
|
||||
|
||||
|
||||
class IoUserRemovePermuter:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IoServerFailed:
|
||||
graceful: bool
|
||||
message: Optional[str]
|
||||
|
||||
|
||||
class IoReconnect:
|
||||
pass
|
||||
|
||||
|
||||
class IoShutdown:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IoWorkDone:
|
||||
score: Optional[int]
|
||||
is_improvement: bool
|
||||
|
||||
|
||||
PermuterHandle = Tuple[int, CancelToken]
|
||||
IoMessage = Union[
|
||||
IoConnect, IoDisconnect, IoImmediateDisconnect, IoUserRemovePermuter, IoWorkDone
|
||||
]
|
||||
IoGlobalMessage = Union[IoReconnect, IoShutdown, IoServerFailed]
|
||||
IoActivity = Tuple[
|
||||
Optional[CancelToken], Union[Tuple[PermuterHandle, IoMessage], IoGlobalMessage]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerOptions:
|
||||
num_cores: float
|
||||
max_memory_gb: float
|
||||
min_priority: float
|
||||
|
||||
|
||||
class NetThread:
|
||||
_port: Optional[SocketPort]
|
||||
_main_queue: "queue.Queue[Activity]"
|
||||
_controller_queue: "queue.Queue[Output]"
|
||||
_read_thread: "threading.Thread"
|
||||
_write_thread: "threading.Thread"
|
||||
_next_work_id: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: SocketPort,
|
||||
main_queue: "queue.Queue[Activity]",
|
||||
) -> None:
|
||||
self._port = port
|
||||
self._main_queue = main_queue
|
||||
self._controller_queue = queue.Queue()
|
||||
self._next_work_id = 0
|
||||
|
||||
self._read_thread = threading.Thread(target=self.read_loop, daemon=True)
|
||||
self._read_thread.start()
|
||||
|
||||
self._write_thread = threading.Thread(target=self.write_loop, daemon=True)
|
||||
self._write_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._port is None:
|
||||
return
|
||||
try:
|
||||
self._controller_queue.put(Shutdown())
|
||||
self._port.shutdown()
|
||||
self._read_thread.join()
|
||||
self._write_thread.join()
|
||||
self._port.close()
|
||||
self._port = None
|
||||
except Exception:
|
||||
print("Failed to stop net thread.")
|
||||
traceback.print_exc()
|
||||
|
||||
def send_controller(self, msg: Output) -> None:
|
||||
self._controller_queue.put(msg)
|
||||
|
||||
def _read_one(self) -> Activity:
|
||||
assert self._port is not None
|
||||
|
||||
msg = self._port.receive_json()
|
||||
time_start = time.time()
|
||||
|
||||
msg_type = json_prop(msg, "type", str)
|
||||
|
||||
if msg_type == "heartbeat":
|
||||
return Heartbeat()
|
||||
|
||||
handle = json_prop(msg, "permuter", int)
|
||||
|
||||
if msg_type == "work":
|
||||
seed = json_prop(msg, "seed", int)
|
||||
id = self._next_work_id
|
||||
self._next_work_id += 1
|
||||
return Work(handle=handle, id=id, time_start=time_start, seed=seed)
|
||||
|
||||
elif msg_type == "add":
|
||||
client_id = json_prop(msg, "client_id", str)
|
||||
client_name = json_prop(msg, "client_name", str)
|
||||
client = Client(client_id, client_name)
|
||||
data = json_prop(msg, "data", dict)
|
||||
compressed_source = self._port.receive()
|
||||
compressed_target_o_bin = self._port.receive()
|
||||
|
||||
try:
|
||||
source = zlib.decompress(compressed_source).decode("utf-8")
|
||||
target_o_bin = zlib.decompress(compressed_target_o_bin)
|
||||
permuter = permuter_data_from_json(data, source, target_o_bin)
|
||||
except Exception as e:
|
||||
# Client sent something illegible. This can legitimately happen if the
|
||||
# client runs another version, but it's interesting to log.
|
||||
traceback.print_exc()
|
||||
return ImmediateDisconnect(
|
||||
handle=handle,
|
||||
client=client,
|
||||
reason=f"Failed to parse permuter: {exception_to_string(e)}",
|
||||
)
|
||||
|
||||
return AddPermuter(
|
||||
handle=handle,
|
||||
time_start=time_start,
|
||||
client=client,
|
||||
permuter_data=permuter,
|
||||
)
|
||||
|
||||
elif msg_type == "remove":
|
||||
return RemovePermuter(handle=handle)
|
||||
|
||||
else:
|
||||
raise Exception(f"Bad message type: {msg_type}")
|
||||
|
||||
def read_loop(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_one()
|
||||
self._main_queue.put(msg)
|
||||
except EOFError:
|
||||
self._main_queue.put(NetThreadDisconnected(graceful=True))
|
||||
except ServerError as e:
|
||||
self._main_queue.put(
|
||||
NetThreadDisconnected(graceful=False, message=e.message)
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
self._main_queue.put(NetThreadDisconnected(graceful=False))
|
||||
|
||||
def _write_one(self, item: Output) -> None:
|
||||
assert self._port is not None
|
||||
|
||||
if isinstance(item, Shutdown):
|
||||
# Handled by caller
|
||||
pass
|
||||
|
||||
elif isinstance(item, OutputInitFail):
|
||||
self._port.send_json(
|
||||
{
|
||||
"type": "update",
|
||||
"permuter": item.handle,
|
||||
"time_us": 0,
|
||||
"update": {"type": "init_failed", "reason": item.error},
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(item, OutputInitSuccess):
|
||||
self._port.send_json(
|
||||
{
|
||||
"type": "update",
|
||||
"permuter": item.handle,
|
||||
"time_us": item.time_us,
|
||||
"update": {"type": "init_done", "hash": item.base_hash},
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(item, OutputDisconnect):
|
||||
self._port.send_json(
|
||||
{
|
||||
"type": "update",
|
||||
"permuter": item.handle,
|
||||
"time_us": 0,
|
||||
"update": {"type": "disconnect"},
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(item, OutputNeedMoreWork):
|
||||
self._port.send_json({"type": "need_work"})
|
||||
|
||||
elif isinstance(item, OutputWork):
|
||||
overhead_us = int((time.time() - item.time_start) * 10 ** 6) - item.time_us
|
||||
self._port.send_json(
|
||||
{
|
||||
"type": "update",
|
||||
"permuter": item.handle,
|
||||
"time_us": item.time_us,
|
||||
"update": {
|
||||
"type": "work",
|
||||
"overhead_us": overhead_us,
|
||||
**item.obj,
|
||||
},
|
||||
}
|
||||
)
|
||||
if item.compressed_source is not None:
|
||||
self._port.send(item.compressed_source)
|
||||
|
||||
else:
|
||||
static_assert_unreachable(item)
|
||||
|
||||
def write_loop(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
item = self._controller_queue.get()
|
||||
if isinstance(item, Shutdown):
|
||||
break
|
||||
self._write_one(item)
|
||||
except EOFError:
|
||||
self._main_queue.put(NetThreadDisconnected(graceful=True))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
self._main_queue.put(NetThreadDisconnected(graceful=False))
|
||||
|
||||
|
||||
class ServerInner:
|
||||
"""This class represents an up-and-running server, connected to the controller and
|
||||
to the evaluator."""
|
||||
|
||||
_evaluator_port: "DockerPort"
|
||||
_main_queue: "queue.Queue[Activity]"
|
||||
_io_queue: "queue.Queue[IoActivity]"
|
||||
_net_thread: NetThread
|
||||
_read_eval_thread: threading.Thread
|
||||
_main_thread: threading.Thread
|
||||
_heartbeat_interval: float
|
||||
_last_heartbeat: float
|
||||
_last_heartbeat_lock: threading.Lock
|
||||
_active: Set[int]
|
||||
_time_starts: Dict[int, float]
|
||||
_token: CancelToken
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
net_port: SocketPort,
|
||||
evaluator_port: "DockerPort",
|
||||
io_queue: "queue.Queue[IoActivity]",
|
||||
heartbeat_interval: float,
|
||||
) -> None:
|
||||
self._evaluator_port = evaluator_port
|
||||
self._main_queue = queue.Queue()
|
||||
self._io_queue = io_queue
|
||||
self._active = set()
|
||||
self._time_starts = {}
|
||||
self._token = CancelToken()
|
||||
|
||||
self._net_thread = NetThread(net_port, self._main_queue)
|
||||
|
||||
# Start a thread for checking heartbeats.
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
self._last_heartbeat = time.time()
|
||||
self._last_heartbeat_lock = threading.Lock()
|
||||
self._heartbeat_stop = threading.Event()
|
||||
self._heartbeat_thread = threading.Thread(
|
||||
target=self._heartbeat_loop, daemon=True
|
||||
)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# Start a thread for reading evaluator results and sending them on to
|
||||
# the main loop queue.
|
||||
self._read_eval_thread = threading.Thread(
|
||||
target=self._read_eval_loop, daemon=True
|
||||
)
|
||||
self._read_eval_thread.start()
|
||||
|
||||
# Start a thread for the main loop.
|
||||
self._main_thread = threading.Thread(target=self._main_loop, daemon=True)
|
||||
self._main_thread.start()
|
||||
|
||||
def _send_controller(self, msg: Output) -> None:
|
||||
self._net_thread.send_controller(msg)
|
||||
|
||||
def _send_io(self, handle: int, io_msg: IoMessage) -> None:
|
||||
self._io_queue.put((self._token, ((handle, self._token), io_msg)))
|
||||
|
||||
def _send_io_global(self, io_msg: IoGlobalMessage) -> None:
|
||||
self._io_queue.put((self._token, io_msg))
|
||||
|
||||
def _handle_message(self, msg: Activity) -> None:
|
||||
if isinstance(msg, Shutdown):
|
||||
# Handled by caller
|
||||
pass
|
||||
|
||||
elif isinstance(msg, Heartbeat):
|
||||
with self._last_heartbeat_lock:
|
||||
self._last_heartbeat = time.time()
|
||||
|
||||
elif isinstance(msg, Work):
|
||||
if msg.handle not in self._active:
|
||||
self._need_work()
|
||||
return
|
||||
|
||||
self._time_starts[msg.id] = msg.time_start
|
||||
self._evaluator_port.send_json(
|
||||
{
|
||||
"type": "work",
|
||||
"permuter": str(msg.handle),
|
||||
"id": msg.id,
|
||||
"seed": msg.seed,
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(msg, AddPermuter):
|
||||
if msg.handle in self._active:
|
||||
raise Exception("Repeated AddPermuter!")
|
||||
|
||||
self._active.add(msg.handle)
|
||||
self._send_permuter(str(msg.handle), msg.permuter_data)
|
||||
fn_name = msg.permuter_data.fn_name
|
||||
self._send_io(msg.handle, IoConnect(fn_name, msg.client))
|
||||
|
||||
elif isinstance(msg, RemovePermuter):
|
||||
if msg.handle not in self._active:
|
||||
return
|
||||
|
||||
self._remove(msg.handle)
|
||||
self._send_io(msg.handle, IoDisconnect("disconnected"))
|
||||
|
||||
elif isinstance(msg, Disconnect):
|
||||
if msg.handle not in self._active:
|
||||
return
|
||||
|
||||
self._remove(msg.handle)
|
||||
self._send_io(msg.handle, IoDisconnect("kicked"))
|
||||
self._send_controller(OutputDisconnect(handle=msg.handle))
|
||||
|
||||
elif isinstance(msg, ImmediateDisconnect):
|
||||
if msg.handle in self._active:
|
||||
raise Exception("ImmediateDisconnect is not immediate")
|
||||
|
||||
self._send_io(msg.handle, IoImmediateDisconnect(msg.reason, msg.client))
|
||||
self._send_controller(OutputInitFail(handle=msg.handle, error=msg.reason))
|
||||
|
||||
elif isinstance(msg, PermInitFail):
|
||||
handle = int(msg.perm_id)
|
||||
if handle not in self._active:
|
||||
self._need_work()
|
||||
return
|
||||
|
||||
self._active.remove(handle)
|
||||
self._send_io(handle, IoDisconnect("failed to compile"))
|
||||
self._send_controller(
|
||||
OutputInitFail(
|
||||
handle=handle,
|
||||
error=msg.error,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(msg, PermInitSuccess):
|
||||
handle = int(msg.perm_id)
|
||||
if handle not in self._active:
|
||||
self._need_work()
|
||||
return
|
||||
|
||||
self._send_controller(
|
||||
OutputInitSuccess(
|
||||
handle=handle,
|
||||
time_us=msg.time_us,
|
||||
base_score=msg.base_score,
|
||||
base_hash=msg.base_hash,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(msg, WorkDone):
|
||||
handle = int(msg.perm_id)
|
||||
time_start = self._time_starts.pop(msg.id)
|
||||
if handle not in self._active:
|
||||
self._need_work()
|
||||
return
|
||||
|
||||
obj = msg.obj
|
||||
obj["permuter"] = handle
|
||||
score = json_prop(obj, "score", int) if "score" in obj else None
|
||||
is_improvement = msg.compressed_source is not None
|
||||
self._send_io(
|
||||
handle,
|
||||
IoWorkDone(score=score, is_improvement=is_improvement),
|
||||
)
|
||||
self._send_controller(
|
||||
OutputWork(
|
||||
handle=handle,
|
||||
time_start=time_start,
|
||||
time_us=msg.time_us,
|
||||
obj=obj,
|
||||
compressed_source=msg.compressed_source,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(msg, NeedMoreWork):
|
||||
self._need_work()
|
||||
|
||||
elif isinstance(msg, NetThreadDisconnected):
|
||||
self._send_io_global(IoServerFailed(msg.graceful, msg.message))
|
||||
|
||||
else:
|
||||
static_assert_unreachable(msg)
|
||||
|
||||
def _need_work(self) -> None:
|
||||
self._send_controller(OutputNeedMoreWork())
|
||||
|
||||
def _remove(self, handle: int) -> None:
|
||||
self._evaluator_port.send_json({"type": "remove", "permuter": str(handle)})
|
||||
self._active.remove(handle)
|
||||
|
||||
def _send_permuter(self, perm_id: str, perm: PermuterData) -> None:
|
||||
self._evaluator_port.send_json(
|
||||
{
|
||||
"type": "add",
|
||||
"permuter": perm_id,
|
||||
**permuter_data_to_json(perm),
|
||||
}
|
||||
)
|
||||
self._evaluator_port.send(perm.source.encode("utf-8"))
|
||||
self._evaluator_port.send(perm.target_o_bin)
|
||||
|
||||
def _do_read_eval_loop(self) -> None:
|
||||
while True:
|
||||
msg = self._evaluator_port.receive_json()
|
||||
msg_type = json_prop(msg, "type", str)
|
||||
|
||||
if msg_type == "init":
|
||||
perm_id = json_prop(msg, "permuter", str)
|
||||
time_us = json_prop(msg, "time_us", int)
|
||||
resp: Activity
|
||||
if json_prop(msg, "success", bool):
|
||||
resp = PermInitSuccess(
|
||||
perm_id=perm_id,
|
||||
base_score=json_prop(msg, "base_score", int),
|
||||
base_hash=json_prop(msg, "base_hash", str),
|
||||
time_us=time_us,
|
||||
)
|
||||
else:
|
||||
resp = PermInitFail(
|
||||
perm_id=perm_id,
|
||||
error=json_prop(msg, "error", str),
|
||||
)
|
||||
self._main_queue.put(resp)
|
||||
|
||||
elif msg_type == "result":
|
||||
compressed_source: Optional[bytes] = None
|
||||
if msg.get("has_source") == True:
|
||||
compressed_source = self._evaluator_port.receive()
|
||||
perm_id = json_prop(msg, "permuter", str)
|
||||
id = json_prop(msg, "id", int)
|
||||
time_us = json_prop(msg, "time_us", int)
|
||||
del msg["permuter"]
|
||||
del msg["id"]
|
||||
del msg["time_us"]
|
||||
self._main_queue.put(
|
||||
WorkDone(
|
||||
perm_id=perm_id,
|
||||
id=id,
|
||||
obj=msg,
|
||||
time_us=time_us,
|
||||
compressed_source=compressed_source,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown message type from evaluator: {msg_type}")
|
||||
|
||||
def _read_eval_loop(self) -> None:
|
||||
try:
|
||||
self._do_read_eval_loop()
|
||||
except EOFError:
|
||||
# Silence errors from shutdown.
|
||||
pass
|
||||
|
||||
def _main_loop(self) -> None:
|
||||
while True:
|
||||
msg = self._main_queue.get()
|
||||
if isinstance(msg, Shutdown):
|
||||
break
|
||||
|
||||
self._handle_message(msg)
|
||||
|
||||
def _heartbeat_loop(self) -> None:
|
||||
second_attempt = False
|
||||
while True:
|
||||
with self._last_heartbeat_lock:
|
||||
delay = (
|
||||
self._last_heartbeat
|
||||
+ self._heartbeat_interval
|
||||
+ _HEARTBEAT_INTERVAL_SLACK_SEC / 2
|
||||
- time.time()
|
||||
)
|
||||
if delay <= 0:
|
||||
if second_attempt:
|
||||
self._main_queue.put(NetThreadDisconnected(graceful=True))
|
||||
return
|
||||
# Handle clock skew or computer going to sleep by waiting a bit
|
||||
# longer before giving up.
|
||||
second_attempt = True
|
||||
if self._heartbeat_stop.wait(_HEARTBEAT_INTERVAL_SLACK_SEC / 2):
|
||||
return
|
||||
else:
|
||||
second_attempt = False
|
||||
if self._heartbeat_stop.wait(delay):
|
||||
return
|
||||
|
||||
def remove_permuter(self, handle: int) -> None:
|
||||
assert not self._token.cancelled
|
||||
self._main_queue.put(Disconnect(handle=handle))
|
||||
|
||||
def stop(self) -> None:
|
||||
assert not self._token.cancelled
|
||||
self._token.cancelled = True
|
||||
self._main_queue.put(Shutdown())
|
||||
self._heartbeat_stop.set()
|
||||
self._net_thread.stop()
|
||||
self._evaluator_port.shutdown()
|
||||
self._main_thread.join()
|
||||
self._heartbeat_thread.join()
|
||||
|
||||
|
||||
class DockerPort(Port):
|
||||
"""Port for communicating with Docker. Communication is encrypted for a few
|
||||
not-very-good reasons:
|
||||
- it allows code reuse
|
||||
- it adds error-checking
|
||||
- it was fun to implement"""
|
||||
|
||||
_sock: BinaryIO
|
||||
_container: "docker.models.containers.Container"
|
||||
_stdout_buffer: bytes
|
||||
_closed: bool
|
||||
|
||||
def __init__(
|
||||
self, container: "docker.models.containers.Container", secret: bytes
|
||||
) -> None:
|
||||
self._container = container
|
||||
self._stdout_buffer = b""
|
||||
self._closed = False
|
||||
|
||||
# Set up a socket for reading from stdout/stderr and writing to
|
||||
# stdin for the container. The docker package does not seem to
|
||||
# expose an API for writing the stdin, but we can do so directly
|
||||
# by attaching a socket and poking at internal state. (See
|
||||
# https://github.com/docker/docker-py/issues/983.) For stdout/
|
||||
# stderr, we use the format described at
|
||||
# https://docs.docker.com/engine/api/v1.24/#attach-to-a-container.
|
||||
#
|
||||
# Hopefully this will keep working for at least a while...
|
||||
try:
|
||||
self._sock = container.attach_socket(
|
||||
params={"stdout": True, "stdin": True, "stderr": True, "stream": True}
|
||||
)
|
||||
self._sock._writing = True # type: ignore
|
||||
except:
|
||||
try:
|
||||
container.remove(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
super().__init__(SecretBox(secret), "docker", is_client=True)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
import docker
|
||||
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
self._sock.close()
|
||||
self._container.remove(force=True)
|
||||
except Exception as e:
|
||||
if not (
|
||||
isinstance(e, docker.errors.APIError)
|
||||
and e.status_code == 409
|
||||
and "is already in progress" in str(e)
|
||||
):
|
||||
print("Failed to shut down Docker")
|
||||
traceback.print_exc()
|
||||
|
||||
def _read_one(self) -> None:
|
||||
header = file_read_fixed(self._sock, 8)
|
||||
stream, length = struct.unpack(">BxxxI", header)
|
||||
if stream not in [1, 2]:
|
||||
raise Exception("Unexpected output from Docker: " + repr(header))
|
||||
data = file_read_fixed(self._sock, length)
|
||||
if stream == 1:
|
||||
self._stdout_buffer += data
|
||||
else:
|
||||
sys.stderr.buffer.write(b"Docker stderr: " + data)
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
def _receive(self, length: int) -> bytes:
|
||||
while len(self._stdout_buffer) < length:
|
||||
self._read_one()
|
||||
ret = self._stdout_buffer[:length]
|
||||
self._stdout_buffer = self._stdout_buffer[length:]
|
||||
return ret
|
||||
|
||||
def _receive_max(self, length: int) -> bytes:
|
||||
length = min(length, len(self._stdout_buffer))
|
||||
ret = self._stdout_buffer[:length]
|
||||
self._stdout_buffer = self._stdout_buffer[length:]
|
||||
return ret
|
||||
|
||||
def _send(self, data: bytes) -> None:
|
||||
while data:
|
||||
written = self._sock.write(data)
|
||||
data = data[written:]
|
||||
self._sock.flush()
|
||||
|
||||
|
||||
def _start_evaluator(docker_image: str, options: ServerOptions) -> DockerPort:
|
||||
"""Spawn a docker container and set it up to evaluate permutations in,
|
||||
returning a handle that we can use to communicate with it.
|
||||
|
||||
We do this for a few reasons:
|
||||
- enforcing a known Linux environment, all while the outside server can run
|
||||
on e.g. Windows and display a systray
|
||||
- enforcing resource limits
|
||||
- sandboxing
|
||||
|
||||
Docker does have the downside of requiring root access, so ideally we would
|
||||
also have a Docker-less mode, where we leave the sandboxing to some other
|
||||
tool, e.g. https://github.com/ioi/isolate/."""
|
||||
print("Starting docker...")
|
||||
command = ["python3", "-m", "src.net.evaluator"]
|
||||
secret = nacl.utils.random(32)
|
||||
enc_secret = base64.b64encode(secret).decode("utf-8")
|
||||
src_path = pathlib.Path(__file__).parent.parent.absolute()
|
||||
|
||||
try:
|
||||
import docker
|
||||
|
||||
client = docker.from_env()
|
||||
client.info()
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Running a server requires the docker Python package to be installed.\n"
|
||||
"Run `python3 -m pip install --upgrade docker`."
|
||||
)
|
||||
sys.exit(1)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print()
|
||||
print(
|
||||
"Failed to start docker. Make sure you have docker installed and "
|
||||
"the docker daemon running, and either run the permuter with sudo "
|
||||
'or add yourself to the "docker" UNIX group.'
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
container = client.containers.run(
|
||||
docker_image,
|
||||
command,
|
||||
detach=True,
|
||||
remove=True,
|
||||
stdin_open=True,
|
||||
stdout=True,
|
||||
environment={"SECRET": enc_secret},
|
||||
volumes={src_path: {"bind": "/src", "mode": "ro"}},
|
||||
tmpfs={"/tmp": "size=1G,exec"},
|
||||
nano_cpus=int(options.num_cores * 1e9),
|
||||
mem_limit=int(options.max_memory_gb * 2 ** 30),
|
||||
read_only=True,
|
||||
network_disabled=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to start docker container: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
port = DockerPort(container, secret)
|
||||
|
||||
try:
|
||||
# Sanity-check that the Docker container started successfully and can
|
||||
# be communicated with.
|
||||
magic = b"\0" * 1000
|
||||
port.send(magic)
|
||||
r = port.receive()
|
||||
if r != magic:
|
||||
raise Exception("Failed initial sanity check.")
|
||||
|
||||
port.send_json({"num_cores": options.num_cores})
|
||||
except:
|
||||
port.shutdown()
|
||||
raise
|
||||
|
||||
print("Started.")
|
||||
return port
|
||||
|
||||
|
||||
class Server:
|
||||
"""This class represents a server that may or may not be connected to the
|
||||
controller and the evaluator."""
|
||||
|
||||
_server: Optional[ServerInner]
|
||||
_options: ServerOptions
|
||||
_config: Config
|
||||
_io_queue: "queue.Queue[IoActivity]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
options: ServerOptions,
|
||||
config: Config,
|
||||
io_queue: "queue.Queue[IoActivity]",
|
||||
) -> None:
|
||||
self._server = None
|
||||
self._options = options
|
||||
self._config = config
|
||||
self._io_queue = io_queue
|
||||
|
||||
def start(self) -> None:
|
||||
assert self._server is None
|
||||
|
||||
net_port = connect(self._config)
|
||||
net_port.send_json(
|
||||
{
|
||||
"method": "connect_server",
|
||||
"min_priority": self._options.min_priority,
|
||||
"num_cores": self._options.num_cores,
|
||||
}
|
||||
)
|
||||
obj = net_port.receive_json()
|
||||
docker_image = json_prop(obj, "docker_image", str)
|
||||
heartbeat_interval = json_prop(obj, "heartbeat_interval", float)
|
||||
|
||||
evaluator_port = _start_evaluator(docker_image, self._options)
|
||||
|
||||
try:
|
||||
self._server = ServerInner(
|
||||
net_port, evaluator_port, self._io_queue, heartbeat_interval
|
||||
)
|
||||
except:
|
||||
evaluator_port.shutdown()
|
||||
raise
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._server is None:
|
||||
return
|
||||
self._server.stop()
|
||||
self._server = None
|
||||
|
||||
def remove_permuter(self, handle: PermuterHandle) -> None:
|
||||
if self._server is not None and not handle[1].cancelled:
|
||||
self._server.remove_permuter(handle[0])
|
||||
@@ -1,224 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List, Match, Pattern, Set, Tuple
|
||||
|
||||
# Ignore registers, for cleaner output. (We don't do this right now, but it can
|
||||
# be useful for debugging.)
|
||||
ign_regs = False
|
||||
|
||||
# Don't include branch targets in the output. Assuming our input is semantically
|
||||
# equivalent skipping it shouldn't be an issue, and it makes insertions have too
|
||||
# large effect.
|
||||
ign_branch_targets = True
|
||||
|
||||
# Skip branch-likely delay slots. (They aren't interesting on IDO.)
|
||||
skip_bl_delay_slots = True
|
||||
|
||||
skip_lines = 1
|
||||
re_int = re.compile(r"[0-9]+")
|
||||
re_int_full = re.compile(r"\b[0-9]+\b")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchSettings:
|
||||
objdump: List[str]
|
||||
re_comment: Pattern[str]
|
||||
re_reg: Pattern[str]
|
||||
re_sprel: Pattern[str]
|
||||
re_includes_sp: Pattern[str]
|
||||
branch_instructions: Set[str]
|
||||
forbidden: Set[str] = field(default_factory=lambda: set(string.ascii_letters + "_"))
|
||||
branch_likely_instructions: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
MIPS_BRANCH_LIKELY_INSTRUCTIONS = {
|
||||
"beql",
|
||||
"bnel",
|
||||
"beqzl",
|
||||
"bnezl",
|
||||
"bgezl",
|
||||
"bgtzl",
|
||||
"blezl",
|
||||
"bltzl",
|
||||
"bc1tl",
|
||||
"bc1fl",
|
||||
}
|
||||
MIPS_BRANCH_INSTRUCTIONS = {
|
||||
"b",
|
||||
"j",
|
||||
"beq",
|
||||
"bne",
|
||||
"beqz",
|
||||
"bnez",
|
||||
"bgez",
|
||||
"bgtz",
|
||||
"blez",
|
||||
"bltz",
|
||||
"bc1t",
|
||||
"bc1f",
|
||||
}.union(MIPS_BRANCH_LIKELY_INSTRUCTIONS)
|
||||
|
||||
MIPS_SETTINGS: ArchSettings = ArchSettings(
|
||||
re_comment=re.compile(r"<.*?>"),
|
||||
re_reg=re.compile(
|
||||
r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra)\b" # leave out $zero
|
||||
),
|
||||
re_sprel=re.compile(r"(?<=,)([0-9]+|0x[0-9a-f]+)\((sp|s8)\)"),
|
||||
re_includes_sp=re.compile(r"\b(sp|s8)\b"),
|
||||
objdump=["mips-linux-gnu-objdump", "-drz", "-m", "mips:4300"],
|
||||
branch_likely_instructions=MIPS_BRANCH_LIKELY_INSTRUCTIONS,
|
||||
branch_instructions=MIPS_BRANCH_INSTRUCTIONS,
|
||||
)
|
||||
|
||||
|
||||
def get_arch(o_file: str) -> ArchSettings:
|
||||
# https://refspecs.linuxfoundation.org/elf/gabi4+/ch4.eheader.html
|
||||
with open(o_file, "rb") as f:
|
||||
f.seek(18)
|
||||
arch_magic = f.read(2)
|
||||
if arch_magic == b"\0\x08":
|
||||
return MIPS_SETTINGS
|
||||
# TODO: support PPC ("\0\x14"), ARM ("0\x28")
|
||||
raise Exception("Bad ELF")
|
||||
|
||||
|
||||
def parse_relocated_line(line: str) -> Tuple[str, str, str]:
|
||||
try:
|
||||
ind2 = line.rindex(",")
|
||||
except ValueError:
|
||||
ind2 = line.rindex("\t")
|
||||
before = line[: ind2 + 1]
|
||||
after = line[ind2 + 1 :]
|
||||
ind2 = after.find("(")
|
||||
if ind2 == -1:
|
||||
imm, after = after, ""
|
||||
else:
|
||||
imm, after = after[:ind2], after[ind2:]
|
||||
if imm == "0x0":
|
||||
imm = "0"
|
||||
return before, imm, after
|
||||
|
||||
|
||||
def simplify_objdump(
|
||||
input_lines: List[str], arch: ArchSettings, *, stack_differences: bool
|
||||
) -> List[str]:
|
||||
output_lines: List[str] = []
|
||||
nops = 0
|
||||
skip_next = False
|
||||
for index, row in enumerate(input_lines):
|
||||
if index < skip_lines:
|
||||
continue
|
||||
row = row.rstrip()
|
||||
if ">:" in row or not row:
|
||||
continue
|
||||
if "R_MIPS_" in row:
|
||||
prev = output_lines[-1]
|
||||
if prev == "<skipped>":
|
||||
continue
|
||||
before, imm, after = parse_relocated_line(prev)
|
||||
repl = row.split()[-1]
|
||||
# As part of ignoring branch targets, we ignore relocations for j
|
||||
# instructions. The target is already lost anyway.
|
||||
if imm == "<target>":
|
||||
assert ign_branch_targets
|
||||
continue
|
||||
# Sometimes s8 is used as a non-framepointer, but we've already lost
|
||||
# the immediate value by pretending it is one. This isn't too bad,
|
||||
# since it's rare and applies consistently. But we do need to handle it
|
||||
# here to avoid a crash, by pretending that lost imms are zero for
|
||||
# relocations.
|
||||
if imm != "0" and imm != "imm" and imm != "addr":
|
||||
repl += "+" + imm if int(imm, 0) > 0 else imm
|
||||
if any(
|
||||
reloc in row
|
||||
for reloc in ["R_MIPS_LO16", "R_MIPS_LITERAL", "R_MIPS_GPREL16"]
|
||||
):
|
||||
repl = f"%lo({repl})"
|
||||
elif "R_MIPS_HI16" in row:
|
||||
# Ideally we'd pair up R_MIPS_LO16 and R_MIPS_HI16 to generate a
|
||||
# correct addend for each, but objdump doesn't give us the order of
|
||||
# the relocations, so we can't find the right LO16. :(
|
||||
repl = f"%hi({repl})"
|
||||
else:
|
||||
assert "R_MIPS_26" in row, f"unknown relocation type '{row}'"
|
||||
output_lines[-1] = before + repl + after
|
||||
continue
|
||||
row = re.sub(arch.re_comment, "", row)
|
||||
row = row.rstrip()
|
||||
row = "\t".join(row.split("\t")[2:]) # [20:]
|
||||
if not row:
|
||||
continue
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
row = "<skipped>"
|
||||
if ign_regs:
|
||||
row = re.sub(arch.re_reg, "<reg>", row)
|
||||
row_parts = row.split("\t")
|
||||
if len(row_parts) == 1:
|
||||
row_parts.append("")
|
||||
mnemonic, instr_args = row_parts
|
||||
if not stack_differences:
|
||||
if mnemonic == "addiu" and arch.re_includes_sp.search(instr_args):
|
||||
row = re.sub(re_int_full, "imm", row)
|
||||
if mnemonic in arch.branch_instructions:
|
||||
if ign_branch_targets:
|
||||
instr_parts = instr_args.split(",")
|
||||
instr_parts[-1] = "<target>"
|
||||
instr_args = ",".join(instr_parts)
|
||||
row = f"{mnemonic}\t{instr_args}"
|
||||
# The last part is in hex, so skip the dec->hex conversion
|
||||
else:
|
||||
|
||||
def fn(pat: Match[str]) -> str:
|
||||
full = pat.group(0)
|
||||
if len(full) <= 1:
|
||||
return full
|
||||
start, end = pat.span()
|
||||
if start and row[start - 1] in arch.forbidden:
|
||||
return full
|
||||
if end < len(row) and row[end] in arch.forbidden:
|
||||
return full
|
||||
return hex(int(full))
|
||||
|
||||
row = re.sub(re_int, fn, row)
|
||||
if mnemonic in arch.branch_likely_instructions and skip_bl_delay_slots:
|
||||
skip_next = True
|
||||
if not stack_differences:
|
||||
row = re.sub(arch.re_sprel, "addr(sp)", row)
|
||||
# row = row.replace(',', ', ')
|
||||
if row == "nop":
|
||||
# strip trailing nops; padding is irrelevant to us
|
||||
nops += 1
|
||||
else:
|
||||
for _ in range(nops):
|
||||
output_lines.append("nop")
|
||||
nops = 0
|
||||
output_lines.append(row)
|
||||
return output_lines
|
||||
|
||||
|
||||
def objdump(
|
||||
o_filename: str, arch: ArchSettings, *, stack_differences: bool = False
|
||||
) -> List[str]:
|
||||
output = subprocess.check_output(arch.objdump + [o_filename])
|
||||
lines = output.decode("utf-8").splitlines()
|
||||
return simplify_objdump(lines, arch, stack_differences=stack_differences)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} file.o", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.isfile(sys.argv[1]):
|
||||
print(f"Source file {sys.argv[1]} is not readable.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
lines = objdump(sys.argv[1], MIPS_SETTINGS)
|
||||
for row in lines:
|
||||
print(row)
|
||||
@@ -1,58 +0,0 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .. import ast_util
|
||||
from .perm import EvalState, Perm
|
||||
from ..ast_util import Block, Statement
|
||||
from ..error import CandidateConstructionFailure
|
||||
|
||||
from pycparser import c_ast as ca
|
||||
|
||||
|
||||
class _Done(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _apply_perm(fn: ca.FuncDef, perm_id: int, perm: Perm, seed: int) -> None:
|
||||
"""Find and apply a single late perm macro in the AST."""
|
||||
# Currently we search for statement macros only.
|
||||
wanted_pragma = f"_permuter ast_perm {perm_id}"
|
||||
Loc = Tuple[List[Statement], int]
|
||||
|
||||
def try_handle_block(block: ca.Node, where: Optional[Loc]) -> None:
|
||||
if not isinstance(block, ca.Compound) or not block.block_items:
|
||||
return
|
||||
pragma = block.block_items[0]
|
||||
if not isinstance(pragma, ca.Pragma) or pragma.string != wanted_pragma:
|
||||
return
|
||||
|
||||
args: List[Statement] = block.block_items[1:]
|
||||
stmts = perm.eval_statement_ast(args, seed)
|
||||
|
||||
if where:
|
||||
where[0][where[1] : where[1] + 1] = stmts
|
||||
else:
|
||||
block.block_items = stmts
|
||||
raise _Done
|
||||
|
||||
def rec(block: Block) -> None:
|
||||
# if (x) { _Pragma(...); inputs } -> if (x) { outputs }
|
||||
try_handle_block(block, None)
|
||||
stmts = ast_util.get_block_stmts(block, False)
|
||||
for i, stmt in enumerate(stmts):
|
||||
# { ... { _Pragma(...); inputs } ... } -> { ... outputs ... }
|
||||
try_handle_block(stmt, (stmts, i))
|
||||
ast_util.for_nested_blocks(stmt, rec)
|
||||
|
||||
try:
|
||||
rec(fn.body)
|
||||
raise CandidateConstructionFailure("Failed to find PERM macro in AST.")
|
||||
except _Done:
|
||||
pass
|
||||
|
||||
|
||||
def apply_ast_perms(fn: ca.FuncDef, eval_state: EvalState) -> None:
|
||||
"""Find all late perm macros in the AST and apply them."""
|
||||
# Nested perms will have smaller IDs, so apply the perms from lowest ID to
|
||||
# highest to ensure that all arguments to perms have already been evaluated.
|
||||
for perm_id, (perm, seed) in enumerate(eval_state.ast_perms):
|
||||
_apply_perm(fn, perm_id, perm, seed)
|
||||
@@ -1,36 +0,0 @@
|
||||
import random
|
||||
from typing import List, Iterable, Set, Tuple
|
||||
|
||||
from .perm import Perm, EvalState
|
||||
|
||||
|
||||
def _gen_all_seeds(total_count: int) -> Iterable[int]:
|
||||
"""Generate all numbers 0..total_count-1 in random order, in expected time
|
||||
O(1) per number."""
|
||||
seen: Set[int] = set()
|
||||
while len(seen) < total_count // 2:
|
||||
seed = random.randrange(total_count)
|
||||
if seed not in seen:
|
||||
seen.add(seed)
|
||||
yield seed
|
||||
|
||||
remaining: List[int] = []
|
||||
for seed in range(total_count):
|
||||
if seed not in seen:
|
||||
remaining.append(seed)
|
||||
random.shuffle(remaining)
|
||||
for seed in remaining:
|
||||
yield seed
|
||||
|
||||
|
||||
def perm_gen_all_seeds(perm: Perm) -> Iterable[int]:
|
||||
while True:
|
||||
for seed in _gen_all_seeds(perm.perm_count):
|
||||
yield seed
|
||||
if not perm.is_random():
|
||||
break
|
||||
|
||||
|
||||
def perm_evaluate_one(perm: Perm) -> Tuple[str, EvalState]:
|
||||
eval_state = EvalState()
|
||||
return perm.evaluate(0, eval_state), eval_state
|
||||
@@ -1,144 +0,0 @@
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
import re
|
||||
|
||||
from .perm import (
|
||||
CombinePerm,
|
||||
GeneralPerm,
|
||||
IgnorePerm,
|
||||
IntPerm,
|
||||
LineSwapPerm,
|
||||
LineSwapAstPerm,
|
||||
OncePerm,
|
||||
Perm,
|
||||
PretendPerm,
|
||||
RandomizerPerm,
|
||||
RootPerm,
|
||||
TextPerm,
|
||||
VarPerm,
|
||||
)
|
||||
|
||||
|
||||
def _split_by_comma(text: str) -> List[str]:
|
||||
level = 0
|
||||
current = ""
|
||||
args: List[str] = []
|
||||
for c in text:
|
||||
if c == "," and level == 0:
|
||||
args.append(current)
|
||||
current = ""
|
||||
else:
|
||||
if c == "(":
|
||||
level += 1
|
||||
elif c == ")":
|
||||
level -= 1
|
||||
assert level >= 0, "Bad nesting"
|
||||
current += c
|
||||
assert level == 0, "Mismatched parentheses"
|
||||
args.append(current)
|
||||
return args
|
||||
|
||||
|
||||
def _split_args(text: str) -> List[Perm]:
|
||||
perm_args = [_rec_perm_parse(arg) for arg in _split_by_comma(text)]
|
||||
return perm_args
|
||||
|
||||
|
||||
def _split_args_newline(text: str) -> List[Perm]:
|
||||
return [_rec_perm_parse(line) for line in text.split("\n") if line.strip()]
|
||||
|
||||
|
||||
def _split_args_text(text: str) -> List[str]:
|
||||
perm_list = _split_args(text)
|
||||
res: List[str] = []
|
||||
for perm in perm_list:
|
||||
assert isinstance(perm, TextPerm)
|
||||
res.append(perm.text)
|
||||
return res
|
||||
|
||||
|
||||
def _make_once_perm(text: str) -> OncePerm:
|
||||
args = _split_by_comma(text)
|
||||
if len(args) not in [1, 2]:
|
||||
raise Exception("PERM_ONCE takes 1 or 2 arguments")
|
||||
key = args[0].strip()
|
||||
value = _rec_perm_parse(args[-1])
|
||||
return OncePerm(key, value)
|
||||
|
||||
|
||||
def _make_var_perm(text: str) -> VarPerm:
|
||||
args = _split_by_comma(text)
|
||||
if len(args) not in [1, 2]:
|
||||
raise Exception("PERM_VAR takes 1 or 2 arguments")
|
||||
var_name = _rec_perm_parse(args[0])
|
||||
value = _rec_perm_parse(args[1]) if len(args) == 2 else None
|
||||
return VarPerm(var_name, value)
|
||||
|
||||
|
||||
PERM_FACTORIES: Dict[str, Callable[[str], Perm]] = {
|
||||
"PERM_GENERAL": lambda text: GeneralPerm(_split_args(text)),
|
||||
"PERM_ONCE": lambda text: _make_once_perm(text),
|
||||
"PERM_RANDOMIZE": lambda text: RandomizerPerm(_rec_perm_parse(text)),
|
||||
"PERM_VAR": lambda text: _make_var_perm(text),
|
||||
"PERM_LINESWAP_TEXT": lambda text: LineSwapPerm(_split_args_newline(text)),
|
||||
"PERM_LINESWAP": lambda text: LineSwapAstPerm(_split_args_newline(text)),
|
||||
"PERM_INT": lambda text: IntPerm(*map(int, _split_args_text(text))),
|
||||
"PERM_IGNORE": lambda text: IgnorePerm(_rec_perm_parse(text)),
|
||||
"PERM_PRETEND": lambda text: PretendPerm(_rec_perm_parse(text)),
|
||||
}
|
||||
|
||||
|
||||
def _consume_arg_parens(text: str) -> Tuple[str, str]:
|
||||
level = 0
|
||||
for i, c in enumerate(text):
|
||||
if c == "(":
|
||||
level += 1
|
||||
elif c == ")":
|
||||
level -= 1
|
||||
if level == -1:
|
||||
return text[:i], text[i + 1 :]
|
||||
raise Exception("Failed to find closing parenthesis when parsing PERM macro")
|
||||
|
||||
|
||||
def _rec_perm_parse(text: str) -> Perm:
|
||||
remain = text
|
||||
macro_search = r"(PERM_.+?)\("
|
||||
|
||||
perms: List[Perm] = []
|
||||
while len(remain) > 0:
|
||||
match = re.search(macro_search, remain)
|
||||
|
||||
# No match found; return remaining
|
||||
if match is None:
|
||||
text_perm = TextPerm(remain)
|
||||
perms.append(text_perm)
|
||||
break
|
||||
|
||||
# Get perm type and args
|
||||
perm_type = match.group(1)
|
||||
if perm_type not in PERM_FACTORIES:
|
||||
raise Exception("Unrecognized PERM macro: " + perm_type)
|
||||
between = remain[: match.start()]
|
||||
args, remain = _consume_arg_parens(remain[match.end() :])
|
||||
|
||||
# Create text perm
|
||||
perms.append(TextPerm(between))
|
||||
|
||||
# Create new perm
|
||||
perms.append(PERM_FACTORIES[perm_type](args))
|
||||
|
||||
if len(perms) == 1:
|
||||
return perms[0]
|
||||
return CombinePerm(perms)
|
||||
|
||||
|
||||
def perm_parse(text: str) -> Perm:
|
||||
ret = _rec_perm_parse(text)
|
||||
if isinstance(ret, TextPerm):
|
||||
ret = RandomizerPerm(ret)
|
||||
print("No perm macros found. Defaulting to randomization.")
|
||||
ret = RootPerm(ret)
|
||||
if not ret.is_random():
|
||||
print(f"Will run for {ret.perm_count} iterations.")
|
||||
else:
|
||||
print(f"Will try {ret.perm_count} different base sources.")
|
||||
return ret
|
||||
@@ -1,289 +0,0 @@
|
||||
from base64 import b64encode
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, TypeVar, Optional
|
||||
import math
|
||||
|
||||
from pycparser import c_ast as ca
|
||||
|
||||
from ..ast_util import Statement
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessState:
|
||||
once_options: Dict[str, List["Perm"]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
vars: Dict[str, str] = field(default_factory=dict)
|
||||
once_choices: Dict[str, "Perm"] = field(default_factory=dict)
|
||||
ast_perms: List[Tuple["Perm", int]] = field(default_factory=list)
|
||||
|
||||
def register_ast_perm(self, perm: "Perm", seed: int) -> int:
|
||||
ret = len(self.ast_perms)
|
||||
self.ast_perms.append((perm, seed))
|
||||
return ret
|
||||
|
||||
def gen_ast_statement_perm(
|
||||
self, perm: "Perm", seed: int, *, statements: List[str]
|
||||
) -> str:
|
||||
perm_id = self.register_ast_perm(perm, seed)
|
||||
lines = [
|
||||
"{",
|
||||
f"#pragma _permuter ast_perm {perm_id}",
|
||||
*["{" + stmt + "}" for stmt in statements],
|
||||
"}",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class Perm:
|
||||
"""A Perm subclass generates different variations of a part of the source
|
||||
code. Its evaluate method will be called with a seed between 0 and
|
||||
perm_count-1, and it should return a unique string for each.
|
||||
|
||||
A Perm is allowed to return different strings for the same seed, but if so,
|
||||
if should override is_random to return True. This will cause permutation
|
||||
to happen in an infinite loop, rather than stop after the last permutation
|
||||
has been tested."""
|
||||
|
||||
perm_count: int
|
||||
children: List["Perm"]
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
return ""
|
||||
|
||||
def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]:
|
||||
raise NotImplementedError
|
||||
|
||||
def preprocess(self, state: PreprocessState) -> None:
|
||||
for p in self.children:
|
||||
p.preprocess(state)
|
||||
|
||||
def is_random(self) -> bool:
|
||||
return any(p.is_random() for p in self.children)
|
||||
|
||||
|
||||
def _eval_all(seed: int, perms: List[Perm], state: EvalState) -> List[str]:
|
||||
ret = []
|
||||
for p in perms:
|
||||
seed, sub_seed = divmod(seed, p.perm_count)
|
||||
ret.append(p.evaluate(sub_seed, state))
|
||||
assert seed == 0, "seed must be in [0, prod(counts))"
|
||||
return ret
|
||||
|
||||
|
||||
def _count_all(perms: List[Perm]) -> int:
|
||||
res = 1
|
||||
for p in perms:
|
||||
res *= p.perm_count
|
||||
return res
|
||||
|
||||
|
||||
def _eval_either(seed: int, perms: List[Perm], state: EvalState) -> str:
|
||||
for p in perms:
|
||||
if seed < p.perm_count:
|
||||
return p.evaluate(seed, state)
|
||||
seed -= p.perm_count
|
||||
assert False, "seed must be in [0, sum(counts))"
|
||||
|
||||
|
||||
def _count_either(perms: List[Perm]) -> int:
|
||||
return sum(p.perm_count for p in perms)
|
||||
|
||||
|
||||
def _shuffle(items: List[T], seed: int) -> List[T]:
|
||||
items = items[:]
|
||||
output = []
|
||||
while items:
|
||||
ind = seed % len(items)
|
||||
seed //= len(items)
|
||||
output.append(items[ind])
|
||||
del items[ind]
|
||||
return output
|
||||
|
||||
|
||||
class RootPerm(Perm):
|
||||
def __init__(self, inner: Perm) -> None:
|
||||
self.children = [inner]
|
||||
self.perm_count = inner.perm_count
|
||||
self.preprocess_state = PreprocessState()
|
||||
self.preprocess(self.preprocess_state)
|
||||
for key, options in self.preprocess_state.once_options.items():
|
||||
if len(options) == 1:
|
||||
raise Exception(f"PERM_ONCE({key}) occurs only once, possible error?")
|
||||
self.perm_count *= len(options)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
for key, options in self.preprocess_state.once_options.items():
|
||||
seed, choice = divmod(seed, len(options))
|
||||
state.once_choices[key] = options[choice]
|
||||
return self.children[0].evaluate(seed, state)
|
||||
|
||||
|
||||
class TextPerm(Perm):
|
||||
def __init__(self, text: str) -> None:
|
||||
# Comma escape sequence
|
||||
text = text.replace("(,)", ",")
|
||||
self.text = text
|
||||
self.children = []
|
||||
self.perm_count = 1
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
class IgnorePerm(Perm):
|
||||
def __init__(self, inner: Perm) -> None:
|
||||
self.children = [inner]
|
||||
self.perm_count = inner.perm_count
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
text = self.children[0].evaluate(seed, state)
|
||||
if not text:
|
||||
return ""
|
||||
encoded = b64encode(text.encode("utf-8")).decode("ascii")
|
||||
return "#pragma _permuter b64literal " + encoded
|
||||
|
||||
|
||||
class PretendPerm(Perm):
|
||||
def __init__(self, inner: Perm) -> None:
|
||||
self.children = [inner]
|
||||
self.perm_count = inner.perm_count
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
text = self.children[0].evaluate(seed, state)
|
||||
return "\n".join(
|
||||
[
|
||||
"",
|
||||
"#pragma _permuter latedefine start",
|
||||
text,
|
||||
"#pragma _permuter latedefine end",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class CombinePerm(Perm):
|
||||
def __init__(self, parts: List[Perm]) -> None:
|
||||
self.children = parts
|
||||
self.perm_count = _count_all(parts)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
texts = _eval_all(seed, self.children, state)
|
||||
return "".join(texts)
|
||||
|
||||
|
||||
class RandomizerPerm(Perm):
|
||||
def __init__(self, inner: Perm) -> None:
|
||||
self.children = [inner]
|
||||
self.perm_count = inner.perm_count
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
text = self.children[0].evaluate(seed, state)
|
||||
return "\n".join(
|
||||
[
|
||||
"",
|
||||
"#pragma _permuter randomizer start",
|
||||
text,
|
||||
"#pragma _permuter randomizer end",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
def is_random(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class GeneralPerm(Perm):
|
||||
def __init__(self, candidates: List[Perm]) -> None:
|
||||
self.perm_count = _count_either(candidates)
|
||||
self.children = candidates
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
return _eval_either(seed, self.children, state)
|
||||
|
||||
|
||||
class OncePerm(Perm):
|
||||
def __init__(self, key: str, inner: Perm) -> None:
|
||||
self.key = key
|
||||
self.children = [inner]
|
||||
self.perm_count = inner.perm_count
|
||||
|
||||
def preprocess(self, state: PreprocessState) -> None:
|
||||
state.once_options[self.key].append(self)
|
||||
super().preprocess(state)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
if state.once_choices[self.key] is self:
|
||||
return self.children[0].evaluate(seed, state)
|
||||
return ""
|
||||
|
||||
|
||||
class VarPerm(Perm):
|
||||
def __init__(self, var_name: Perm, expansion: Optional[Perm]) -> None:
|
||||
if expansion:
|
||||
self.children = [var_name, expansion]
|
||||
else:
|
||||
self.children = [var_name]
|
||||
self.perm_count = _count_all(self.children)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
var_name_perm = self.children[0]
|
||||
seed, sub_seed = divmod(seed, var_name_perm.perm_count)
|
||||
var_name = var_name_perm.evaluate(sub_seed, state).strip()
|
||||
if len(self.children) > 1:
|
||||
ret = self.children[1].evaluate(seed, state)
|
||||
state.vars[var_name] = ret
|
||||
return ""
|
||||
else:
|
||||
if var_name not in state.vars:
|
||||
raise Exception(f"Tried to read undefined PERM_VAR {var_name}")
|
||||
return state.vars[var_name]
|
||||
|
||||
|
||||
class LineSwapPerm(Perm):
|
||||
def __init__(self, lines: List[Perm]) -> None:
|
||||
self.children = lines
|
||||
self.own_count = math.factorial(len(lines))
|
||||
self.perm_count = self.own_count * _count_all(self.children)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
sub_seed, variation = divmod(seed, self.own_count)
|
||||
texts = _eval_all(sub_seed, self.children, state)
|
||||
return "\n".join(_shuffle(texts, variation))
|
||||
|
||||
|
||||
class LineSwapAstPerm(Perm):
|
||||
def __init__(self, lines: List[Perm]) -> None:
|
||||
self.children = lines
|
||||
self.own_count = math.factorial(len(lines))
|
||||
self.perm_count = self.own_count * _count_all(self.children)
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
sub_seed, variation = divmod(seed, self.own_count)
|
||||
texts = _eval_all(sub_seed, self.children, state)
|
||||
return state.gen_ast_statement_perm(self, variation, statements=texts)
|
||||
|
||||
def eval_statement_ast(self, args: List[Statement], seed: int) -> List[Statement]:
|
||||
ret = []
|
||||
for item in _shuffle(args, seed):
|
||||
assert isinstance(item, ca.Compound)
|
||||
ret.extend(item.block_items or [])
|
||||
return ret
|
||||
|
||||
|
||||
class IntPerm(Perm):
|
||||
def __init__(self, low: int, high: int) -> None:
|
||||
assert low <= high
|
||||
self.low = low
|
||||
self.children = []
|
||||
self.perm_count = high - low + 1
|
||||
|
||||
def evaluate(self, seed: int, state: EvalState) -> str:
|
||||
return str(self.low + seed)
|
||||
@@ -1,279 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
import difflib
|
||||
import itertools
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Iterator,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .candidate import Candidate, CandidateResult
|
||||
from .compiler import Compiler
|
||||
from .error import CandidateConstructionFailure
|
||||
from .perm.perm import EvalState
|
||||
from .perm.eval import perm_evaluate_one, perm_gen_all_seeds
|
||||
from .perm.parse import perm_parse
|
||||
from .profiler import Profiler
|
||||
from .scorer import Scorer
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalError:
|
||||
exc_str: Optional[str]
|
||||
seed: Optional[Tuple[int, int]]
|
||||
|
||||
|
||||
EvalResult = Union[CandidateResult, EvalError]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finished:
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
text: str
|
||||
|
||||
|
||||
class NeedMoreWork:
|
||||
pass
|
||||
|
||||
|
||||
class _CompileFailure(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkDone:
|
||||
perm_index: int
|
||||
result: EvalResult
|
||||
|
||||
|
||||
Task = Union[Finished, Tuple[int, int]]
|
||||
FeedbackItem = Union[Finished, Message, NeedMoreWork, WorkDone]
|
||||
Feedback = Tuple[FeedbackItem, int, Optional[str]]
|
||||
|
||||
|
||||
class Permuter:
|
||||
"""
|
||||
Represents a single source from which permutation candidates can be generated,
|
||||
and which keeps track of good scores achieved so far.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dir: str,
|
||||
fn_name: Optional[str],
|
||||
compiler: Compiler,
|
||||
scorer: Scorer,
|
||||
source_file: str,
|
||||
source: str,
|
||||
*,
|
||||
force_seed: Optional[int],
|
||||
force_rng_seed: Optional[int],
|
||||
keep_prob: float,
|
||||
need_profiler: bool,
|
||||
need_all_sources: bool,
|
||||
show_errors: bool,
|
||||
best_only: bool,
|
||||
better_only: bool,
|
||||
) -> None:
|
||||
self.dir = dir
|
||||
self.compiler = compiler
|
||||
self.scorer = scorer
|
||||
self.source_file = source_file
|
||||
self.source = source
|
||||
|
||||
if fn_name is None:
|
||||
# Semi-legacy codepath; all functions imported through import.py have a
|
||||
# function name. This would ideally be done on AST level instead of on the
|
||||
# pre-macro'ed source code, but we don't care enough to make that
|
||||
# refactoring.
|
||||
fns = _find_fns(source)
|
||||
if len(fns) == 0:
|
||||
raise Exception(f"{self.source_file} does not contain any function!")
|
||||
if len(fns) > 1:
|
||||
raise Exception(
|
||||
f"{self.source_file} must contain only one function, "
|
||||
"or have a function.txt next to it with a function name."
|
||||
)
|
||||
self.fn_name = fns[0]
|
||||
else:
|
||||
self.fn_name = fn_name
|
||||
self.unique_name = self.fn_name
|
||||
|
||||
self._permutations = perm_parse(source)
|
||||
|
||||
self._force_seed = force_seed
|
||||
self._force_rng_seed = force_rng_seed
|
||||
self._cur_seed: Optional[Tuple[int, int]] = None
|
||||
|
||||
self.keep_prob = keep_prob
|
||||
self.need_profiler = need_profiler
|
||||
self._need_all_sources = need_all_sources
|
||||
self._show_errors = show_errors
|
||||
self._best_only = best_only
|
||||
self._better_only = better_only
|
||||
|
||||
(
|
||||
self.base_score,
|
||||
self.base_hash,
|
||||
self.base_source,
|
||||
) = self._create_and_score_base()
|
||||
self.best_score = self.base_score
|
||||
self.hashes = {self.base_hash}
|
||||
self._cur_cand: Optional[Candidate] = None
|
||||
self._last_score: Optional[int] = None
|
||||
|
||||
def _create_and_score_base(self) -> Tuple[int, str, str]:
|
||||
base_source, eval_state = perm_evaluate_one(self._permutations)
|
||||
base_cand = Candidate.from_source(
|
||||
base_source, eval_state, self.fn_name, rng_seed=0
|
||||
)
|
||||
o_file = base_cand.compile(self.compiler, show_errors=True)
|
||||
if not o_file:
|
||||
raise CandidateConstructionFailure(f"Unable to compile {self.source_file}")
|
||||
base_result = base_cand.score(self.scorer, o_file)
|
||||
assert base_result.hash is not None
|
||||
return base_result.score, base_result.hash, base_cand.get_source()
|
||||
|
||||
def _need_to_send_source(self, result: CandidateResult) -> bool:
|
||||
return self._need_all_sources or self.should_output(result)
|
||||
|
||||
def _eval_candidate(self, seed: int) -> CandidateResult:
|
||||
t0 = time.time()
|
||||
|
||||
# Determine if we should keep the last candidate.
|
||||
# Don't keep 0-score candidates; we'll only create new, worse, zeroes.
|
||||
keep = (
|
||||
self._permutations.is_random()
|
||||
and random.uniform(0, 1) < self.keep_prob
|
||||
and self._last_score != 0
|
||||
and self._last_score != self.scorer.PENALTY_INF
|
||||
) or self._force_rng_seed
|
||||
|
||||
self._last_score = None
|
||||
|
||||
# Create a new candidate if we didn't keep the last one (or if the last one didn't exist)
|
||||
# N.B. if we decide to keep the previous candidate, we will skip over the provided seed.
|
||||
# This means we're not guaranteed to test all seeds, but it doesn't really matter since
|
||||
# we're randomizing anyway.
|
||||
if not self._cur_cand or not keep:
|
||||
eval_state = EvalState()
|
||||
cand_c = self._permutations.evaluate(seed, eval_state)
|
||||
rng_seed = self._force_rng_seed or random.randrange(1, 10 ** 20)
|
||||
self._cur_seed = (seed, rng_seed)
|
||||
self._cur_cand = Candidate.from_source(
|
||||
cand_c, eval_state, self.fn_name, rng_seed=rng_seed
|
||||
)
|
||||
|
||||
# Randomize the candidate
|
||||
if self._permutations.is_random():
|
||||
self._cur_cand.randomize_ast()
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
self._cur_cand.get_source()
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
o_file = self._cur_cand.compile(self.compiler)
|
||||
if not o_file and self._show_errors:
|
||||
raise _CompileFailure()
|
||||
|
||||
t3 = time.time()
|
||||
|
||||
result = self._cur_cand.score(self.scorer, o_file)
|
||||
|
||||
t4 = time.time()
|
||||
|
||||
if self.need_profiler:
|
||||
profiler = Profiler()
|
||||
profiler.add_stat(Profiler.StatType.perm, t1 - t0)
|
||||
profiler.add_stat(Profiler.StatType.stringify, t2 - t1)
|
||||
profiler.add_stat(Profiler.StatType.compile, t3 - t2)
|
||||
profiler.add_stat(Profiler.StatType.score, t4 - t3)
|
||||
result.profiler = profiler
|
||||
|
||||
self._last_score = result.score
|
||||
|
||||
if not self._need_to_send_source(result):
|
||||
result.source = None
|
||||
result.hash = None
|
||||
|
||||
return result
|
||||
|
||||
def should_output(self, result: CandidateResult) -> bool:
|
||||
"""Check whether a result should be outputted. This must be more liberal
|
||||
in child processes than in parent ones, or else sources will be missing."""
|
||||
return (
|
||||
result.score <= self.base_score
|
||||
and result.hash is not None
|
||||
and result.source is not None
|
||||
and not (result.score > self.best_score and self._best_only)
|
||||
and (
|
||||
result.score < self.base_score
|
||||
or (result.score == self.base_score and not self._better_only)
|
||||
)
|
||||
and result.hash not in self.hashes
|
||||
)
|
||||
|
||||
def record_result(self, result: CandidateResult) -> None:
|
||||
"""Record a new result, updating the best score and adding the hash to
|
||||
the set of hashes we have already seen. No hash is recorded for score
|
||||
0, since we are interested in all score 0's, not just the first."""
|
||||
self.best_score = min(self.best_score, result.score)
|
||||
if result.score != 0 and result.hash is not None:
|
||||
self.hashes.add(result.hash)
|
||||
|
||||
def seed_iterator(self) -> Iterator[int]:
|
||||
"""Create an iterator over all seeds for this permuter. The iterator
|
||||
will be infinite if we are randomizing."""
|
||||
if self._force_seed is None:
|
||||
return iter(perm_gen_all_seeds(self._permutations))
|
||||
if self._permutations.is_random():
|
||||
return itertools.repeat(self._force_seed)
|
||||
return iter([self._force_seed])
|
||||
|
||||
def try_eval_candidate(self, seed: int) -> EvalResult:
|
||||
"""Evaluate a seed for the permuter."""
|
||||
try:
|
||||
return self._eval_candidate(seed)
|
||||
except _CompileFailure:
|
||||
return EvalError(exc_str=None, seed=self._cur_seed)
|
||||
except Exception:
|
||||
return EvalError(exc_str=traceback.format_exc(), seed=self._cur_seed)
|
||||
|
||||
def diff(self, other_source: str) -> str:
|
||||
"""Compute a unified white-space-ignoring diff from the (pretty-printed)
|
||||
base source against another source generated from this permuter."""
|
||||
|
||||
class Line(str):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, str) and self.strip() == other.strip()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.strip())
|
||||
|
||||
a = list(map(Line, self.base_source.split("\n")))
|
||||
b = list(map(Line, other_source.split("\n")))
|
||||
return "\n".join(
|
||||
difflib.unified_diff(a, b, fromfile="before", tofile="after", lineterm="")
|
||||
)
|
||||
|
||||
|
||||
def _find_fns(source: str) -> List[str]:
|
||||
fns = re.findall(r"(\w+)\([^()\n]*\)\s*?{", source)
|
||||
return [
|
||||
fn
|
||||
for fn in fns
|
||||
if not fn.startswith("PERM") and fn not in ["if", "for", "switch", "while"]
|
||||
]
|
||||
@@ -1,10 +0,0 @@
|
||||
from typing import List
|
||||
import subprocess
|
||||
|
||||
|
||||
def preprocess(filename: str, cpp_args: List[str] = []) -> str:
|
||||
return subprocess.check_output(
|
||||
["cpp"] + cpp_args + ["-P", "-nostdinc", filename],
|
||||
universal_newlines=True,
|
||||
encoding="utf-8",
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from .permuter import Permuter
|
||||
|
||||
# Number of additional characters to replace with spaces, in addition to what
|
||||
# is required based on the length of the previous progress line. This could be
|
||||
# set to 0, but it's nice to have some margin to deal with e.g. zero-width
|
||||
# control characters.
|
||||
SAFETY_PAD = 10
|
||||
|
||||
|
||||
class Printer:
|
||||
_last_progress: Optional[str] = None
|
||||
|
||||
def progress(self, message: str) -> None:
|
||||
if self._last_progress is None:
|
||||
clear = ""
|
||||
else:
|
||||
pad = max(len(self._last_progress) - len(message) + SAFETY_PAD, 0)
|
||||
clear = "\b" * pad + " " * pad + "\r"
|
||||
print(clear + message, end="", flush=True)
|
||||
self._last_progress = message
|
||||
|
||||
def print(
|
||||
self,
|
||||
message: str,
|
||||
permuter: Optional[Permuter],
|
||||
who: Optional[str],
|
||||
*,
|
||||
color: str = "",
|
||||
keep_progress: bool = False,
|
||||
) -> None:
|
||||
if self._last_progress is not None:
|
||||
if keep_progress:
|
||||
print()
|
||||
else:
|
||||
pad = len(self._last_progress) + SAFETY_PAD
|
||||
print("\r" + " " * pad + "\r", end="")
|
||||
if permuter is not None:
|
||||
message = f"[{permuter.unique_name}] {message}"
|
||||
if who is not None:
|
||||
message = f"[{who}] {message}"
|
||||
if color:
|
||||
message = f"{color}{message}\u001b[0m"
|
||||
print(message)
|
||||
self._last_progress = None
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Profiler:
|
||||
class StatType(Enum):
|
||||
perm = 1
|
||||
stringify = 2
|
||||
compile = 3
|
||||
score = 4
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.time_stats = {x: 0.0 for x in Profiler.StatType}
|
||||
|
||||
def add_stat(self, stat: StatType, time_taken: float) -> None:
|
||||
self.time_stats[stat] += time_taken
|
||||
|
||||
def get_str_stats(self) -> str:
|
||||
total_time = sum(self.time_stats[e] for e in self.time_stats)
|
||||
timings = ", ".join(
|
||||
f"{round(100 * self.time_stats[e] / total_time)}% {e.name}"
|
||||
for e in self.time_stats
|
||||
)
|
||||
return timings
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,145 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
import difflib
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Tuple, List, Optional
|
||||
from collections import Counter
|
||||
|
||||
|
||||
from .objdump import objdump, get_arch
|
||||
|
||||
|
||||
@dataclass(init=False, unsafe_hash=True)
|
||||
class DiffAsmLine:
|
||||
line: str = field(compare=False)
|
||||
mnemonic: str
|
||||
|
||||
def __init__(self, line: str) -> None:
|
||||
self.line = line
|
||||
self.mnemonic = line.split("\t")[0]
|
||||
|
||||
|
||||
class Scorer:
|
||||
PENALTY_INF = 10 ** 9
|
||||
|
||||
PENALTY_STACKDIFF = 1
|
||||
PENALTY_REGALLOC = 5
|
||||
PENALTY_REORDERING = 60
|
||||
PENALTY_INSERTION = 100
|
||||
PENALTY_DELETION = 100
|
||||
|
||||
def __init__(self, target_o: str, *, stack_differences: bool):
|
||||
self.target_o = target_o
|
||||
self.arch = get_arch(target_o)
|
||||
self.stack_differences = stack_differences
|
||||
_, self.target_seq = self._objdump(target_o)
|
||||
self.differ: difflib.SequenceMatcher[DiffAsmLine] = difflib.SequenceMatcher(
|
||||
autojunk=False
|
||||
)
|
||||
self.differ.set_seq2(self.target_seq)
|
||||
|
||||
def _objdump(self, o_file: str) -> Tuple[str, List[DiffAsmLine]]:
|
||||
ret = []
|
||||
lines = objdump(o_file, self.arch, stack_differences=self.stack_differences)
|
||||
for line in lines:
|
||||
ret.append(DiffAsmLine(line))
|
||||
return "\n".join(lines), ret
|
||||
|
||||
def score(self, cand_o: Optional[str]) -> Tuple[int, str]:
|
||||
if not cand_o:
|
||||
return Scorer.PENALTY_INF, ""
|
||||
|
||||
objdump_output, cand_seq = self._objdump(cand_o)
|
||||
|
||||
score = 0
|
||||
deletions = []
|
||||
insertions = []
|
||||
|
||||
def lo_hi_match(old: str, new: str) -> bool:
|
||||
old_lo = old.find("%lo")
|
||||
old_hi = old.find("%hi")
|
||||
new_lo = new.find("%lo")
|
||||
new_hi = new.find("%hi")
|
||||
|
||||
if old_lo != -1 and new_lo != -1:
|
||||
old_idx = old_lo
|
||||
new_idx = new_lo
|
||||
elif old_hi != -1 and new_hi != -1:
|
||||
old_idx = old_hi
|
||||
new_idx = new_hi
|
||||
else:
|
||||
return False
|
||||
|
||||
if old[:old_idx] != new[:new_idx]:
|
||||
return False
|
||||
|
||||
old_inner = old[old_idx + 4 : -1]
|
||||
new_inner = new[new_idx + 4 : -1]
|
||||
return old_inner.startswith(".") or new_inner.startswith(".")
|
||||
|
||||
def diff_sameline(old: str, new: str) -> None:
|
||||
nonlocal score
|
||||
if old == new:
|
||||
return
|
||||
|
||||
if lo_hi_match(old, new):
|
||||
return
|
||||
|
||||
ignore_last_field = False
|
||||
if self.stack_differences:
|
||||
oldsp = re.search(self.arch.re_sprel, old)
|
||||
newsp = re.search(self.arch.re_sprel, new)
|
||||
if oldsp and newsp:
|
||||
oldrel = int(oldsp.group(1) or "0", 0)
|
||||
newrel = int(newsp.group(1) or "0", 0)
|
||||
score += abs(oldrel - newrel) * self.PENALTY_STACKDIFF
|
||||
ignore_last_field = True
|
||||
|
||||
# Probably regalloc difference, or signed vs unsigned
|
||||
|
||||
# Compare each field in order
|
||||
newfields, oldfields = new.split(","), old.split(",")
|
||||
if ignore_last_field:
|
||||
newfields = newfields[:-1]
|
||||
oldfields = oldfields[:-1]
|
||||
for nf, of in zip(newfields, oldfields):
|
||||
if nf != of:
|
||||
score += self.PENALTY_REGALLOC
|
||||
# Penalize any extra fields
|
||||
score += abs(len(newfields) - len(oldfields)) * self.PENALTY_REGALLOC
|
||||
|
||||
def diff_insert(line: str) -> None:
|
||||
# Reordering or totally different codegen.
|
||||
# Defer this until later when we can tell.
|
||||
insertions.append(line)
|
||||
|
||||
def diff_delete(line: str) -> None:
|
||||
deletions.append(line)
|
||||
|
||||
self.differ.set_seq1(cand_seq)
|
||||
for (tag, i1, i2, j1, j2) in self.differ.get_opcodes():
|
||||
if tag == "equal":
|
||||
for k in range(i2 - i1):
|
||||
old = self.target_seq[j1 + k].line
|
||||
new = cand_seq[i1 + k].line
|
||||
diff_sameline(old, new)
|
||||
if tag == "replace" or tag == "delete":
|
||||
for k in range(i1, i2):
|
||||
diff_insert(cand_seq[k].line)
|
||||
if tag == "replace" or tag == "insert":
|
||||
for k in range(j1, j2):
|
||||
diff_delete(self.target_seq[k].line)
|
||||
|
||||
insertions_co = Counter(insertions)
|
||||
deletions_co = Counter(deletions)
|
||||
for item in insertions_co + deletions_co:
|
||||
ins = insertions_co[item]
|
||||
dels = deletions_co[item]
|
||||
common = min(ins, dels)
|
||||
score += (
|
||||
(ins - common) * self.PENALTY_INSERTION
|
||||
+ (dels - common) * self.PENALTY_DELETION
|
||||
+ self.PENALTY_REORDERING * common
|
||||
)
|
||||
|
||||
return (score, hashlib.sha256(objdump_output.encode()).hexdigest())
|
||||
@@ -1,74 +0,0 @@
|
||||
import re
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _find_bracket_end(input: str, start_index: int) -> int:
|
||||
level = 1
|
||||
assert input[start_index] == "{"
|
||||
i = start_index + 1
|
||||
while i < len(input):
|
||||
if input[i] == "{":
|
||||
level += 1
|
||||
elif input[i] == "}":
|
||||
level -= 1
|
||||
if level == 0:
|
||||
break
|
||||
i += 1
|
||||
|
||||
assert level == 0, "unbalanced {}"
|
||||
return i
|
||||
|
||||
|
||||
def strip_other_fns(source: str, keep_fn_name: str) -> str:
|
||||
result = ""
|
||||
remain = source
|
||||
while True:
|
||||
fn_regex = re.compile(r"^.*\s+\**(\w+)\(.*\)\s*?{", re.M)
|
||||
fn = re.search(fn_regex, remain)
|
||||
if fn is None:
|
||||
result += remain
|
||||
remain = ""
|
||||
break
|
||||
|
||||
fn_name = fn.group(1)
|
||||
bracket_end = _find_bracket_end(remain, fn.end() - 1)
|
||||
if fn_name.startswith("PERM"):
|
||||
result += remain[: bracket_end + 1]
|
||||
elif fn_name == keep_fn_name:
|
||||
result += "\n\n" + remain[: bracket_end + 1] + "\n\n"
|
||||
else:
|
||||
result += remain[: fn.end() - 1].rstrip() + ";"
|
||||
|
||||
remain = remain[bracket_end + 1 :]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def strip_other_fns_and_write(
|
||||
source: str, fn_name: str, out_filename: Optional[str] = None
|
||||
) -> None:
|
||||
stripped = strip_other_fns(source, fn_name)
|
||||
|
||||
if out_filename is None:
|
||||
print(stripped)
|
||||
else:
|
||||
with open(out_filename, "w", encoding="utf-8") as f:
|
||||
f.write(stripped)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Remove all but a single function definition from a file."
|
||||
)
|
||||
parser.add_argument("c_file", help="File containing the function.")
|
||||
parser.add_argument("fn_name", help="Function name.")
|
||||
args = parser.parse_args()
|
||||
|
||||
source = Path(args.c_file).read_text()
|
||||
strip_other_fns_and_write(source, args.fn_name, args.c_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,18 +0,0 @@
|
||||
#-----------------------------------------------------------------
|
||||
# pycparser: __init__.py
|
||||
#
|
||||
# This package file exports some convenience functions for
|
||||
# interacting with pycparser
|
||||
#
|
||||
# Eli Bendersky [https://eli.thegreenplace.net/]
|
||||
# License: BSD
|
||||
#-----------------------------------------------------------------
|
||||
__all__ = ['c_parser', 'c_ast']
|
||||
__version__ = '2.19'
|
||||
|
||||
from typing import Any, List, Union
|
||||
from . import c_ast
|
||||
from .c_parser import CParser
|
||||
|
||||
def preprocess_file(filename: str, cpp_path: str='cpp', cpp_args: Union[List[str], str]='') -> str: ...
|
||||
def parse_file(filename: str, use_cpp: bool=False, cpp_path: str='cpp', cpp_args: str='', parser: Any=None) -> c_ast.FileAST: ...
|
||||
@@ -1,719 +0,0 @@
|
||||
# -----------------------------------------------------------------
|
||||
# pycparser: c_ast.py
|
||||
#
|
||||
# AST Node classes.
|
||||
#
|
||||
# Eli Bendersky [https://eli.thegreenplace.net/]
|
||||
# License: BSD
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
|
||||
from typing import TextIO, Iterable, List, Any, Optional, Union as Union_
|
||||
from .plyparser import Coord
|
||||
import sys
|
||||
|
||||
|
||||
class Node(object):
|
||||
coord: Optional[Coord]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterable[Node]:
|
||||
...
|
||||
|
||||
def children(self) -> Iterable[Node]:
|
||||
...
|
||||
|
||||
def show(
|
||||
self,
|
||||
buf: TextIO = sys.stdout,
|
||||
offset: int = 0,
|
||||
attrnames: bool = False,
|
||||
nodenames: bool = False,
|
||||
showcoord: bool = False,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
|
||||
Expression = Union_[
|
||||
"ArrayRef",
|
||||
"Assignment",
|
||||
"BinaryOp",
|
||||
"Cast",
|
||||
"CompoundLiteral",
|
||||
"Constant",
|
||||
"ExprList",
|
||||
"FuncCall",
|
||||
"ID",
|
||||
"TernaryOp",
|
||||
"UnaryOp",
|
||||
]
|
||||
Statement = Union_[
|
||||
Expression,
|
||||
"Break",
|
||||
"Case",
|
||||
"Compound",
|
||||
"Continue",
|
||||
"Decl",
|
||||
"Default",
|
||||
"DoWhile",
|
||||
"EmptyStatement",
|
||||
"For",
|
||||
"Goto",
|
||||
"If",
|
||||
"Label",
|
||||
"Return",
|
||||
"Switch",
|
||||
"Typedef",
|
||||
"While",
|
||||
"Pragma",
|
||||
]
|
||||
Type = Union_["PtrDecl", "ArrayDecl", "FuncDecl", "TypeDecl"]
|
||||
InnerType = Union_["IdentifierType", "Struct", "Union", "Enum"]
|
||||
ExternalDeclaration = Union_["FuncDef", "Decl", "Typedef", "Pragma"]
|
||||
AnyNode = Union_[
|
||||
Statement,
|
||||
Type,
|
||||
InnerType,
|
||||
"FuncDef",
|
||||
"EllipsisParam",
|
||||
"Enumerator",
|
||||
"EnumeratorList",
|
||||
"FileAST",
|
||||
"InitList",
|
||||
"NamedInitializer",
|
||||
"ParamList",
|
||||
"Typename",
|
||||
]
|
||||
|
||||
|
||||
class NodeVisitor:
|
||||
def visit(self, node: Node) -> None:
|
||||
...
|
||||
|
||||
def generic_visit(self, node: Node) -> None:
|
||||
...
|
||||
|
||||
def visit_ArrayDecl(self, node: ArrayDecl) -> None:
|
||||
...
|
||||
|
||||
def visit_ArrayRef(self, node: ArrayRef) -> None:
|
||||
...
|
||||
|
||||
def visit_Assignment(self, node: Assignment) -> None:
|
||||
...
|
||||
|
||||
def visit_BinaryOp(self, node: BinaryOp) -> None:
|
||||
...
|
||||
|
||||
def visit_Break(self, node: Break) -> None:
|
||||
...
|
||||
|
||||
def visit_Case(self, node: Case) -> None:
|
||||
...
|
||||
|
||||
def visit_Cast(self, node: Cast) -> None:
|
||||
...
|
||||
|
||||
def visit_Compound(self, node: Compound) -> None:
|
||||
...
|
||||
|
||||
def visit_CompoundLiteral(self, node: CompoundLiteral) -> None:
|
||||
...
|
||||
|
||||
def visit_Constant(self, node: Constant) -> None:
|
||||
...
|
||||
|
||||
def visit_Continue(self, node: Continue) -> None:
|
||||
...
|
||||
|
||||
def visit_Decl(self, node: Decl) -> None:
|
||||
...
|
||||
|
||||
def visit_DeclList(self, node: DeclList) -> None:
|
||||
...
|
||||
|
||||
def visit_Default(self, node: Default) -> None:
|
||||
...
|
||||
|
||||
def visit_DoWhile(self, node: DoWhile) -> None:
|
||||
...
|
||||
|
||||
def visit_EllipsisParam(self, node: EllipsisParam) -> None:
|
||||
...
|
||||
|
||||
def visit_EmptyStatement(self, node: EmptyStatement) -> None:
|
||||
...
|
||||
|
||||
def visit_Enum(self, node: Enum) -> None:
|
||||
...
|
||||
|
||||
def visit_Enumerator(self, node: Enumerator) -> None:
|
||||
...
|
||||
|
||||
def visit_EnumeratorList(self, node: EnumeratorList) -> None:
|
||||
...
|
||||
|
||||
def visit_ExprList(self, node: ExprList) -> None:
|
||||
...
|
||||
|
||||
def visit_FileAST(self, node: FileAST) -> None:
|
||||
...
|
||||
|
||||
def visit_For(self, node: For) -> None:
|
||||
...
|
||||
|
||||
def visit_FuncCall(self, node: FuncCall) -> None:
|
||||
...
|
||||
|
||||
def visit_FuncDecl(self, node: FuncDecl) -> None:
|
||||
...
|
||||
|
||||
def visit_FuncDef(self, node: FuncDef) -> None:
|
||||
...
|
||||
|
||||
def visit_Goto(self, node: Goto) -> None:
|
||||
...
|
||||
|
||||
def visit_ID(self, node: ID) -> None:
|
||||
...
|
||||
|
||||
def visit_IdentifierType(self, node: IdentifierType) -> None:
|
||||
...
|
||||
|
||||
def visit_If(self, node: If) -> None:
|
||||
...
|
||||
|
||||
def visit_InitList(self, node: InitList) -> None:
|
||||
...
|
||||
|
||||
def visit_Label(self, node: Label) -> None:
|
||||
...
|
||||
|
||||
def visit_NamedInitializer(self, node: NamedInitializer) -> None:
|
||||
...
|
||||
|
||||
def visit_ParamList(self, node: ParamList) -> None:
|
||||
...
|
||||
|
||||
def visit_PtrDecl(self, node: PtrDecl) -> None:
|
||||
...
|
||||
|
||||
def visit_Return(self, node: Return) -> None:
|
||||
...
|
||||
|
||||
def visit_Struct(self, node: Struct) -> None:
|
||||
...
|
||||
|
||||
def visit_StructRef(self, node: StructRef) -> None:
|
||||
...
|
||||
|
||||
def visit_Switch(self, node: Switch) -> None:
|
||||
...
|
||||
|
||||
def visit_TernaryOp(self, node: TernaryOp) -> None:
|
||||
...
|
||||
|
||||
def visit_TypeDecl(self, node: TypeDecl) -> None:
|
||||
...
|
||||
|
||||
def visit_Typedef(self, node: Typedef) -> None:
|
||||
...
|
||||
|
||||
def visit_Typename(self, node: Typename) -> None:
|
||||
...
|
||||
|
||||
def visit_UnaryOp(self, node: UnaryOp) -> None:
|
||||
...
|
||||
|
||||
def visit_Union(self, node: Union) -> None:
|
||||
...
|
||||
|
||||
def visit_While(self, node: While) -> None:
|
||||
...
|
||||
|
||||
def visit_Pragma(self, node: Pragma) -> None:
|
||||
...
|
||||
|
||||
|
||||
class ArrayDecl(Node):
|
||||
type: Type
|
||||
dim: Optional[Expression]
|
||||
dim_quals: List[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Type,
|
||||
dim: Optional[Node],
|
||||
dim_quals: List[str],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class ArrayRef(Node):
|
||||
name: Expression
|
||||
subscript: Expression
|
||||
|
||||
def __init__(self, name: Node, subscript: Node, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Assignment(Node):
|
||||
op: str
|
||||
lvalue: Expression
|
||||
rvalue: Expression
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
op: str,
|
||||
lvalue: Expression,
|
||||
rvalue: Expression,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class BinaryOp(Node):
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
def __init__(self, op: str, left: Node, right: Node, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Break(Node):
|
||||
def __init__(self, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Case(Node):
|
||||
expr: Expression
|
||||
stmts: List[Statement]
|
||||
|
||||
def __init__(
|
||||
self, expr: Expression, stmts: List[Statement], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Cast(Node):
|
||||
to_type: "Typename"
|
||||
expr: Expression
|
||||
|
||||
def __init__(
|
||||
self, to_type: "Typename", expr: Expression, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Compound(Node):
|
||||
block_items: Optional[List[Statement]]
|
||||
|
||||
def __init__(
|
||||
self, block_items: Optional[List[Statement]], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class CompoundLiteral(Node):
|
||||
type: "Typename"
|
||||
init: "InitList"
|
||||
|
||||
def __init__(
|
||||
self, type: "Typename", init: "InitList", coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Constant(Node):
|
||||
type: str
|
||||
value: str
|
||||
|
||||
def __init__(self, type: str, value: str, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Continue(Node):
|
||||
def __init__(self, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Decl(Node):
|
||||
name: Optional[str]
|
||||
quals: List[str] # e.g. const
|
||||
storage: List[str] # e.g. register
|
||||
funcspec: List[str] # e.g. inline
|
||||
type: Union_[Type, "Struct", "Union", "Enum"]
|
||||
init: Optional[Union_[Expression, "InitList"]]
|
||||
bitsize: Optional[Expression]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str],
|
||||
quals: List[str],
|
||||
storage: List[str],
|
||||
funcspec: List[str],
|
||||
type: Union_[Type, "Struct", "Union", "Enum"],
|
||||
init: Optional[Union_[Expression, "InitList"]],
|
||||
bitsize: Optional[Expression],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class DeclList(Node):
|
||||
decls: List[Decl]
|
||||
|
||||
def __init__(self, decls: List[Decl], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Default(Node):
|
||||
stmts: List[Statement]
|
||||
|
||||
def __init__(self, stmts: List[Statement], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class DoWhile(Node):
|
||||
cond: Expression
|
||||
stmt: Statement
|
||||
|
||||
def __init__(
|
||||
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class EllipsisParam(Node):
|
||||
def __init__(self, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class EmptyStatement(Node):
|
||||
def __init__(self, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Enum(Node):
|
||||
name: Optional[str]
|
||||
values: "Optional[EnumeratorList]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str],
|
||||
values: "Optional[EnumeratorList]",
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Enumerator(Node):
|
||||
name: str
|
||||
value: Optional[Expression]
|
||||
|
||||
def __init__(
|
||||
self, name: str, value: Optional[Expression], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class EnumeratorList(Node):
|
||||
enumerators: List[Enumerator]
|
||||
|
||||
def __init__(self, enumerators: List[Enumerator], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class ExprList(Node):
|
||||
exprs: List[Union_[Expression, Typename]] # typename only for offsetof
|
||||
|
||||
def __init__(
|
||||
self, exprs: List[Union_[Expression, Typename]], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class FileAST(Node):
|
||||
ext: List[ExternalDeclaration]
|
||||
|
||||
def __init__(self, ext: List[ExternalDeclaration], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class For(Node):
|
||||
init: Union_[None, Expression, DeclList]
|
||||
cond: Optional[Expression]
|
||||
next: Optional[Expression]
|
||||
stmt: Statement
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
init: Union_[None, Expression, DeclList],
|
||||
cond: Optional[Expression],
|
||||
next: Optional[Expression],
|
||||
stmt: Statement,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class FuncCall(Node):
|
||||
name: Expression
|
||||
args: Optional[ExprList]
|
||||
|
||||
def __init__(
|
||||
self, name: Expression, args: Optional[ExprList], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class FuncDecl(Node):
|
||||
args: Optional[ParamList]
|
||||
type: Type # return type
|
||||
|
||||
def __init__(
|
||||
self, args: Optional[ParamList], type: Type, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class FuncDef(Node):
|
||||
decl: Decl
|
||||
param_decls: Optional[List[Decl]]
|
||||
body: Compound
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decl: Decl,
|
||||
param_decls: Optional[List[Decl]],
|
||||
body: Compound,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Goto(Node):
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class ID(Node):
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class IdentifierType(Node):
|
||||
names: List[str] # e.g. ['long', 'int']
|
||||
|
||||
def __init__(self, names: List[str], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class If(Node):
|
||||
cond: Expression
|
||||
iftrue: Statement
|
||||
iffalse: Optional[Statement]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond: Expression,
|
||||
iftrue: Statement,
|
||||
iffalse: Optional[Statement],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class InitList(Node):
|
||||
exprs: List[Union_[Expression, "NamedInitializer"]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exprs: List[Union_[Expression, "NamedInitializer"]],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Label(Node):
|
||||
name: str
|
||||
stmt: Statement
|
||||
|
||||
def __init__(self, name: str, stmt: Statement, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class NamedInitializer(Node):
|
||||
name: List[Expression] # [ID(x), Constant(4)] for {.x[4] = ...}
|
||||
expr: Expression
|
||||
|
||||
def __init__(
|
||||
self, name: List[Expression], expr: Expression, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class ParamList(Node):
|
||||
params: List[Union_[Decl, ID, Typename, EllipsisParam]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Union_[Decl, ID, Typename, EllipsisParam]],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class PtrDecl(Node):
|
||||
quals: List[str]
|
||||
type: Type
|
||||
|
||||
def __init__(self, quals: List[str], type: Type, coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Return(Node):
|
||||
expr: Optional[Expression]
|
||||
|
||||
def __init__(self, expr: Optional[Expression], coord: Optional[Coord] = None):
|
||||
...
|
||||
|
||||
|
||||
class Struct(Node):
|
||||
name: Optional[str]
|
||||
decls: Optional[List[Union_[Decl, Pragma]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str],
|
||||
decls: Optional[List[Union_[Decl, Pragma]]],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class StructRef(Node):
|
||||
name: Expression
|
||||
type: str
|
||||
field: ID
|
||||
|
||||
def __init__(
|
||||
self, name: Expression, type: str, field: ID, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Switch(Node):
|
||||
cond: Expression
|
||||
stmt: Statement
|
||||
|
||||
def __init__(
|
||||
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class TernaryOp(Node):
|
||||
cond: Expression
|
||||
iftrue: Expression
|
||||
iffalse: Expression
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond: Expression,
|
||||
iftrue: Expression,
|
||||
iffalse: Expression,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class TypeDecl(Node):
|
||||
declname: Optional[str]
|
||||
quals: List[str]
|
||||
type: InnerType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
declname: Optional[str],
|
||||
quals: List[str],
|
||||
type: InnerType,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Typedef(Node):
|
||||
name: str
|
||||
quals: List[str]
|
||||
storage: List[str]
|
||||
type: Type
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
quals: List[str],
|
||||
storage: List[str],
|
||||
type: Type,
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Typename(Node):
|
||||
name: None
|
||||
quals: List[str]
|
||||
type: Type
|
||||
|
||||
def __init__(
|
||||
self, name: None, quals: List[str], type: Type, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class UnaryOp(Node):
|
||||
op: str
|
||||
expr: Union_[Expression, Typename]
|
||||
|
||||
def __init__(
|
||||
self, op: str, expr: Union_[Expression, Typename], coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Union(Node):
|
||||
name: Optional[str]
|
||||
decls: Optional[List[Union_[Decl, Pragma]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str],
|
||||
decls: Optional[List[Union_[Decl, Pragma]]],
|
||||
coord: Optional[Coord] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class While(Node):
|
||||
cond: Expression
|
||||
stmt: Statement
|
||||
|
||||
def __init__(
|
||||
self, cond: Expression, stmt: Statement, coord: Optional[Coord] = None
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
class Pragma(Node):
|
||||
string: str
|
||||
|
||||
def __init__(self, string: str, coord: Optional[Coord] = None):
|
||||
...
|
||||
@@ -1,13 +0,0 @@
|
||||
#------------------------------------------------------------------------------
|
||||
# pycparser: c_generator.py
|
||||
#
|
||||
# C code generator from pycparser AST nodes.
|
||||
#
|
||||
# Eli Bendersky [https://eli.thegreenplace.net/]
|
||||
# License: BSD
|
||||
#------------------------------------------------------------------------------
|
||||
from . import c_ast
|
||||
|
||||
class CGenerator:
|
||||
def __init__(self) -> None: ...
|
||||
def visit(self, node: c_ast.Node) -> str: ...
|
||||
@@ -1,15 +0,0 @@
|
||||
#------------------------------------------------------------------------------
|
||||
# pycparser: c_parser.py
|
||||
#
|
||||
# CParser class: Parser and AST builder for the C language
|
||||
#
|
||||
# Eli Bendersky [https://eli.thegreenplace.net/]
|
||||
# License: BSD
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
from . import c_ast
|
||||
|
||||
class CParser:
|
||||
def __init__(self) -> None: ...
|
||||
def parse(self, text: str, filename: str='', debuglevel: int=0) -> c_ast.FileAST: ...
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# -----------------------------------------------------------------
|
||||
# plyparser.py
|
||||
#
|
||||
# PLYParser class and other utilites for simplifying programming
|
||||
# parsers with PLY
|
||||
#
|
||||
# Eli Bendersky [https://eli.thegreenplace.net/]
|
||||
# License: BSD
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Coord:
|
||||
file: str
|
||||
line: int
|
||||
column: Optional[int]
|
||||
|
||||
def __init__(self, file: str, line: int, column: Optional[int] = None):
|
||||
...
|
||||
|
||||
def __str__(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
@@ -1,13 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
|
||||
from pycparser import parse_file, c_generator
|
||||
|
||||
fname = "test.c" if len(sys.argv) < 2 else sys.argv[1]
|
||||
|
||||
|
||||
# ast = c_parser.CParser().parse(src)
|
||||
ast = parse_file(fname, use_cpp=True)
|
||||
# ast.show()
|
||||
# print(c_generator.CGenerator().visit(ast))
|
||||
print(ast)
|
||||
@@ -1,2 +0,0 @@
|
||||
#!/bin/bash
|
||||
mips-linux-gnu-gcc -O2 -fno-PIC -fno-common -ffreestanding -mno-shared -mno-abicalls -G 0 -c "$@"
|
||||
@@ -1,247 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Any, Optional
|
||||
import unittest
|
||||
|
||||
from src.compiler import Compiler
|
||||
from src.preprocess import preprocess
|
||||
from src import main
|
||||
|
||||
|
||||
class TestPermMacros(unittest.TestCase):
|
||||
def go(
|
||||
self,
|
||||
intro: str,
|
||||
outro: str,
|
||||
base: str,
|
||||
target: str,
|
||||
*,
|
||||
fn_name: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
) -> int:
|
||||
base = intro + "\n" + base + "\n" + outro
|
||||
target = intro + "\n" + target + "\n" + outro
|
||||
compiler = Compiler("test/compile.sh", show_errors=True)
|
||||
|
||||
# For debugging, to avoid the auto-deleted directory:
|
||||
# target_dir = tempfile.mkdtemp()
|
||||
with tempfile.TemporaryDirectory() as target_dir:
|
||||
with open(os.path.join(target_dir, "base.c"), "w") as f:
|
||||
f.write(base)
|
||||
|
||||
target_o = compiler.compile(target, show_errors=True)
|
||||
assert target_o is not None
|
||||
shutil.move(target_o, os.path.join(target_dir, "target.o"))
|
||||
|
||||
shutil.copy2("test/compile.sh", os.path.join(target_dir, "compile.sh"))
|
||||
|
||||
if fn_name:
|
||||
with open(os.path.join(target_dir, "function.txt"), "w") as f:
|
||||
f.write(fn_name)
|
||||
|
||||
opts = main.Options(directories=[target_dir], stop_on_zero=True, **kwargs)
|
||||
return main.run(opts)[0]
|
||||
|
||||
def test_general(self) -> None:
|
||||
score = self.go(
|
||||
"int test() {",
|
||||
"}",
|
||||
"return PERM_GENERAL(32,64);",
|
||||
"return 64;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_not_found(self) -> None:
|
||||
score = self.go(
|
||||
"int test() {",
|
||||
"}",
|
||||
"return PERM_GENERAL(32,64);",
|
||||
"return 92;",
|
||||
)
|
||||
self.assertNotEqual(score, 0)
|
||||
|
||||
def test_multiple_functions(self) -> None:
|
||||
score = self.go(
|
||||
"",
|
||||
"",
|
||||
"""
|
||||
int ignoreme() {}
|
||||
int foo() { return PERM_GENERAL(32,64); }
|
||||
int ignoreme2() {}
|
||||
""",
|
||||
"int foo() { return 64; }",
|
||||
fn_name="foo",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_general_multiple(self) -> None:
|
||||
score = self.go(
|
||||
"int test() {",
|
||||
"}",
|
||||
"return PERM_GENERAL(1,2,3) + PERM_GENERAL(3,6,9);",
|
||||
"return 9;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_general_nested(self) -> None:
|
||||
score = self.go(
|
||||
"int test() {",
|
||||
"}",
|
||||
"return PERM_GENERAL(1,PERM_GENERAL(100,101),3) + PERM_GENERAL(3,6,9);",
|
||||
"return 110;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_cast(self) -> None:
|
||||
score = self.go(
|
||||
"int test(int a, int b) {",
|
||||
"}",
|
||||
"return a / PERM_GENERAL(,(unsigned int),(float)) b;",
|
||||
"return a / (float) b;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_cast_threaded(self) -> None:
|
||||
score = self.go(
|
||||
"int test(int a, int b) {",
|
||||
"}",
|
||||
"return a / PERM_GENERAL(,(unsigned int),(float)) b;",
|
||||
"return a / (float) b;",
|
||||
threads=2,
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_ignore(self) -> None:
|
||||
score = self.go(
|
||||
"int test(int a, int b) {",
|
||||
"}",
|
||||
"PERM_IGNORE( return a / PERM_GENERAL(a, b); )",
|
||||
"return a / b;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_pretend(self) -> None:
|
||||
score = self.go(
|
||||
"int global;",
|
||||
"",
|
||||
"""
|
||||
PERM_IGNORE( inline void foo() { )
|
||||
PERM_PRETEND( void foo(); void bar() { )
|
||||
PERM_RANDOMIZE(
|
||||
global = 1;
|
||||
)
|
||||
PERM_IGNORE( } void bar() { )
|
||||
PERM_RANDOMIZE(
|
||||
global = 2; foo();
|
||||
)
|
||||
}
|
||||
""",
|
||||
"""
|
||||
inline void foo() { global = 1; }
|
||||
void bar() { foo(); global = 2; }
|
||||
""",
|
||||
fn_name="bar",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_once1(self) -> None:
|
||||
score = self.go(
|
||||
"volatile int A, B, C; void test() {",
|
||||
"}",
|
||||
"""
|
||||
PERM_ONCE(B = 2;)
|
||||
A = 1;
|
||||
PERM_ONCE(B = 2;)
|
||||
C = 3;
|
||||
PERM_ONCE(B = 2;)
|
||||
""",
|
||||
"A = 1; B = 2; C = 3;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_once2(self) -> None:
|
||||
score = self.go(
|
||||
"volatile int A, B, C; void test() {",
|
||||
"}",
|
||||
"""
|
||||
PERM_VAR(emit,)
|
||||
PERM_VAR(bademit,)
|
||||
PERM_ONCE(1, PERM_VAR(bademit, A = 7;) A = 2;)
|
||||
PERM_ONCE(1, PERM_VAR(emit, A = 1;))
|
||||
PERM_VAR(emit)
|
||||
PERM_VAR(bademit)
|
||||
PERM_ONCE(2, B = 2;)
|
||||
PERM_ONCE(2, B = 1;)
|
||||
PERM_ONCE(2,)
|
||||
PERM_ONCE(3, PERM_VAR(bademit, A = 9))
|
||||
PERM_ONCE(3, PERM_VAR(bademit, A = 9))
|
||||
C = 3;
|
||||
""",
|
||||
"A = 1; B = 2; C = 3;",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_lineswap(self) -> None:
|
||||
score = self.go(
|
||||
"void a(); void b(); void c(); void test(void) {",
|
||||
"}",
|
||||
"""
|
||||
PERM_LINESWAP(
|
||||
a();
|
||||
b();
|
||||
c();
|
||||
)
|
||||
""",
|
||||
"b(); a(); c();",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_lineswap_text(self) -> None:
|
||||
score = self.go(
|
||||
"void a(); void b(); void c(); void test(void) {",
|
||||
"}",
|
||||
"""
|
||||
PERM_LINESWAP_TEXT(
|
||||
a();
|
||||
b();
|
||||
c();
|
||||
)
|
||||
""",
|
||||
"b(); a(); c();",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_randomizer(self) -> None:
|
||||
score = self.go(
|
||||
"void foo(); void bar(); void test(void) {",
|
||||
"}",
|
||||
"PERM_RANDOMIZE(bar(); foo();)",
|
||||
"foo(); bar();",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_auto_randomizer(self) -> None:
|
||||
score = self.go(
|
||||
"void foo(); void bar(); void test(void) {",
|
||||
"}",
|
||||
"bar(); foo();",
|
||||
"foo(); bar();",
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
def test_randomizer_threaded(self) -> None:
|
||||
score = self.go(
|
||||
"void foo(); void bar(); void test(void) {",
|
||||
"}",
|
||||
"PERM_RANDOMIZE(bar(); foo();)",
|
||||
"foo(); bar();",
|
||||
threads=2,
|
||||
)
|
||||
self.assertEqual(score, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user