#!/usr/bin/python3
"""
Custom ld64 wrapper for macOS

This script wraps the standard macOS linker (/usr/bin/ld) to modify executable memory
protection flags in the resulting Mach-O binary. It works in three stages:

1. First, it passes through all arguments to the regular macOS linker to create the binary
2. Then, it parses command line arguments to identify output file and segment protection flags
3. Finally, it patches the output binary's Mach-O load commands so that each segment named by a
   -segprot option carries exactly the max_prot given there; __TEXT defaults to rwx when no
   -segprot option names it, even if the default macOS linker would restrict it

This is necessary because macOS restricts writable+executable memory by default,
but certain applications need this capability for dynamic code generation or JIT compilation.

Usage: Same as the standard ld64 linker, with the added benefit that -segprot options
will have their max_prot values properly preserved in the output binary.
"""

import os
import subprocess
import sys
from itertools import takewhile

from macholib import MachO, ptypes

def parse_rwx(text):
    """Convert a protection string such as ``'rwx'`` or ``'r-x'`` to a bit mask.

    Args:
        text: String in which the presence of the characters ``r``, ``w`` and
            ``x`` selects read (1), write (2) and execute (4) respectively.
            Any other characters (for example ``-``) are ignored.

    Returns:
        The combined mask as a plain ``int``; ``0`` when none of the three
        letters is present.
    """
    prot = 0
    for letter, bit in (('r', 1), ('w', 2), ('x', 4)):
        if letter in text:
            prot |= bit
    return prot

def _maxprot_field(cload, ccmd):
    """Return ``(offset, size)`` of ``maxprot`` relative to the load command start."""
    position = 0
    # A segment command on disk is the generic load_command header followed
    # by the segment-specific fields, so walk both field lists in order.
    for name, typ in cload._fields_ + ccmd._fields_:
        size = ptypes.sizeof(typ)
        if name == 'maxprot':
            return position, size
        position += size
    raise ValueError('load command has no maxprot field')


def apply_maxprots(path, maxprots):
    """Rewrite the ``maxprot`` field of selected segments of a Mach-O file in place.

    Only the first header that macholib reports for the file is processed.
    Segments whose name is not a key of *maxprots*, or whose ``maxprot``
    already equals the requested value, are left untouched.

    Args:
        path: Path of the Mach-O file to modify.
        maxprots: Mapping of segment name (e.g. ``'__TEXT'``) to the desired
            maximum-protection bit mask (see ``parse_rwx``).
    """
    mach = MachO.MachO(path)
    header = mach.headers[0]
    # File position of the first load command.  ``header.offset`` is where
    # macholib found this header in the file (0 for a thin binary), so the
    # computed positions stay valid for a slice inside a universal file.
    offset = header.offset + ptypes.sizeof(header.mach_header)
    # macholib records the image byte order as a struct prefix ('<' or '>').
    byteorder = 'big' if header.endian == '>' else 'little'

    patches = []
    for cload, ccmd, _ in header.commands:
        # Non-segment commands are skipped rather than ending the scan, so a
        # segment that follows one is still considered; the offset must be
        # advanced for every command either way.
        if hasattr(ccmd, 'segname'):
            if hasattr(ccmd.segname, 'to_str'):
                raw_name = ccmd.segname.to_str()
            else:
                raw_name = ccmd.segname
            # Segment names are fixed-width and NUL padded.
            segname = raw_name.decode('utf-8').strip('\0')

            if segname in maxprots and ccmd.maxprot != maxprots[segname]:
                relative, size = _maxprot_field(cload, ccmd)
                # Write the whole field in the image's byte order instead of
                # only its first byte.
                payload = int(maxprots[segname]).to_bytes(size, byteorder)
                patches.append((offset + relative, payload))

        offset += cload.cmdsize

    if not patches:
        return

    # Open the file once for all pending writes.
    with open(path, 'r+b') as fh:
        for position, payload in patches:
            fh.seek(position)
            fh.write(payload)

# Stage 1: run the real linker with the unmodified argument list.  If it
# fails, exit with its return code and leave whatever it produced alone.
try:
    subprocess.check_call(['/usr/bin/ld'] + sys.argv[1:])
except subprocess.CalledProcessError as ex:
    sys.exit(ex.returncode)

# Stage 2: recover the output path and the requested max protections from the
# same argument list.  'a.out' is used when no -o option is given, and __TEXT
# is forced to rwx unless a -segprot option for it says otherwise.
output_file = 'a.out'
segprots = {'__TEXT': parse_rwx('rwx')}  # maxprot = rwx

i = 1
while i < len(sys.argv):
    if sys.argv[i] == '-o' and i + 1 < len(sys.argv):
        output_file = sys.argv[i + 1]
        i += 2
    elif sys.argv[i] == '-segprot' and i + 3 < len(sys.argv):
        # -segprot <segname> <max_prot> <init_prot>; only max_prot is needed,
        # but all three operands are consumed.
        segment = sys.argv[i + 1]
        maxprot = sys.argv[i + 2]
        segprots[segment] = parse_rwx(maxprot)
        i += 4
    else:
        i += 1

# Stage 3: patch the binary.  A successful linker run does not always leave a
# regular file behind (e.g. a version query, or '-o /dev/null'); in that case
# there is nothing to patch and the wrapper must not fail.
if os.path.isfile(output_file):
    apply_maxprots(output_file, segprots)
