mongo/buildscripts/s3_binary/download.py

246 lines
7.5 KiB
Python

#!/usr/bin/env python3
import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import tempfile
import time
import traceback
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from buildscripts.s3_binary.hashes import S3_SHA256_HASHES
def _run(cmd: list[str]) -> tuple[int, str]:
    """Run ``cmd`` and return ``(exit_code, combined_output)``.

    stderr is merged into stdout and the output is decoded as text. A non-zero
    exit status is reported through the returned code rather than raised. If
    the process cannot be run at all (any exception from ``subprocess.run``),
    the pair ``(127, "<ExceptionType>: <message>")`` is returned instead.
    """
    try:
        completed = subprocess.run(
            cmd,
            check=False,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
    except Exception as e:
        # 127 is used here as the "could not run the command" exit code.
        return 127, f"{type(e).__name__}: {e}"
    else:
        return completed.returncode, completed.stdout
def _download_with_curl_or_wget(url: str, out_path: str) -> bool:
    """Download ``url`` to ``out_path`` using curl, falling back to wget.

    curl is tried first (following redirects, failing on HTTP errors, with
    three retries and a 15 second connect timeout). If curl is not installed
    or exits non-zero, wget is tried. For wget, ``SSL_CERT_FILE`` and
    ``SSL_CERT_DIR`` are forwarded explicitly as ``--ca-certificate`` /
    ``--ca-directory`` when they are set; curl is invoked without explicit CA
    flags.

    Args:
        url: Location to fetch.
        out_path: File path the tool writes the response body to.

    Returns:
        True if one of the tools exited with status 0, otherwise False
        (including when neither tool is installed).
    """
    curl = shutil.which("curl")
    if curl:
        code, out = _run(
            [
                curl,
                "--fail",
                "--location",
                "--silent",
                "--show-error",
                "--retry",
                "3",
                "--retry-connrefused",
                "--connect-timeout",
                "15",
                "--output",
                out_path,
                url,
            ]
        )
        if code == 0:
            return True
        # --show-error output is captured by _run; surface it so the reason
        # for falling through to wget is visible in the log.
        if out and out.strip():
            print(f"curl failed with exit code {code}: {out.strip()}")

    wget = shutil.which("wget")
    if wget:
        # Build the CA arguments inline from the environment so the wget
        # fallback has no dependency on any other helper.
        cert_args = []
        cert_file = os.environ.get("SSL_CERT_FILE")
        if cert_file:
            cert_args.append(f"--ca-certificate={cert_file}")
        cert_dir = os.environ.get("SSL_CERT_DIR")
        if cert_dir:
            cert_args.append(f"--ca-directory={cert_dir}")

        code, out = _run([wget, "-q", *cert_args, "-O", out_path, url])
        if code == 0:
            return True
        if out and out.strip():
            print(f"wget failed with exit code {code}: {out.strip()}")

    return False
def read_sha_file(filename):
    """Return the digest stored in a ``.sha256`` file.

    The digest is the first whitespace-delimited token of the file, so both a
    bare digest and the ``<digest>  <name>`` layout are accepted.

    Args:
        filename: Path of the checksum file to read.

    Returns:
        The first token of the file as a string.

    Raises:
        ValueError: If the file is empty or contains only whitespace.
        OSError: If the file cannot be opened.
    """
    with open(filename) as f:
        tokens = f.read().split()
    if not tokens:
        # Fail with a clear message instead of an IndexError from tokens[0].
        raise ValueError(f"No SHA256 digest found in {filename}")
    return tokens[0]
def _fetch_remote_sha256_hash(s3_path: str):
    """Download ``<s3_path>.sha256`` and return the digest it contains.

    The download is attempted with boto, then with requests, then with
    curl/wget. The checksum file is written to a temporary file that is always
    removed before returning, including when an error propagates.

    Args:
        s3_path: URL of the artifact; ``.sha256`` is appended to locate the
            checksum file.

    Returns:
        The digest string, or None when no download method succeeded.

    Raises:
        Exception: Whatever ``read_sha_file`` raises if the downloaded
            checksum file cannot be parsed.
    """
    sha_url = s3_path + ".sha256"

    # Only the name is needed: create the file and close it right away so the
    # downloaders (including external curl/wget processes) can open it
    # themselves.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        tempfile_name = temp_file.name

    try:
        downloaded = False
        try:
            from buildscripts.util.download_utils import download_from_s3_with_boto

            download_from_s3_with_boto(sha_url, tempfile_name)
            downloaded = True
        except Exception:
            try:
                from buildscripts.util.download_utils import download_from_s3_with_requests

                # raise_on_error=True mirrors how the artifact itself is
                # downloaded elsewhere in this module, so a failed request
                # falls through to curl/wget.
                # NOTE(review): assumes the default is not to raise - confirm
                # against download_utils.
                download_from_s3_with_requests(sha_url, tempfile_name, raise_on_error=True)
                downloaded = True
            except Exception:
                # curl/wget fallback
                downloaded = _download_with_curl_or_wget(sha_url, tempfile_name)

        if not downloaded:
            return None
        return read_sha_file(tempfile_name)
    finally:
        # Runs on every exit path so the temp file never leaks.
        if os.path.exists(tempfile_name):
            os.unlink(tempfile_name)
def _sha256_file(filename: str) -> str:
    """Return the SHA-256 hex digest of the file at ``filename``.

    The file is opened in binary mode and hashed in 4096-byte reads so it is
    never loaded into memory all at once.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if chunk == b"":
                break
            digest.update(chunk)
    return digest.hexdigest()
def _verify_s3_hash(s3_path: str, local_path: str, expected_hash: str) -> None:
    """Check that ``local_path`` hashes to ``expected_hash``.

    Args:
        s3_path: Remote location of the artifact; used only in the error text.
        local_path: File on disk whose SHA-256 is computed.
        expected_hash: Hex digest the file must match exactly.

    Raises:
        ValueError: If the computed digest differs from ``expected_hash``.
    """
    actual_hash = _sha256_file(local_path)
    if actual_hash == expected_hash:
        return
    raise ValueError(
        f"Hash mismatch for {s3_path}, expected {expected_hash} but got {actual_hash}"
    )
def validate_file(s3_path, output_path, remote_sha_allowed):
    """Validate the SHA-256 of ``output_path`` for the artifact ``s3_path``.

    The expected digest is looked up in this order:

    1. The hard-coded ``S3_SHA256_HASHES`` table, keyed by ``s3_path``.
    2. Only if ``remote_sha_allowed``: a local ``<output_path>.sha256`` file.
    3. Only if ``remote_sha_allowed`` and no such local file exists: the
       remote ``<s3_path>.sha256`` file.

    Args:
        s3_path: Remote location of the artifact.
        output_path: Local file to validate.
        remote_sha_allowed: Whether digests that are not hard-coded (local or
            remote ``.sha256`` files) may be used.

    Returns:
        True when the file matches the expected digest.

    Raises:
        ValueError: If no digest is available, or the file's digest does not
            match the expected one.
    """
    hexdigest = S3_SHA256_HASHES.get(s3_path)
    if hexdigest:
        print(f"Validating against hard coded sha256: {hexdigest}")
        _verify_s3_hash(s3_path, output_path, hexdigest)
        return True

    if not remote_sha_allowed:
        raise ValueError(f"No SHA256 hash available for {s3_path}")

    sha_file = output_path + ".sha256"
    if os.path.exists(sha_file):
        hexdigest = read_sha_file(sha_file)
        print(f"Validating against sha256 file {hexdigest}\n({sha_file})")
    else:
        hexdigest = _fetch_remote_sha256_hash(s3_path)
        if hexdigest:
            print(f"Validating against remote sha256 {hexdigest}\n({s3_path}.sha256)")
        else:
            print(f"Failed to download remote sha256 at {s3_path}.sha256")

    if not hexdigest:
        raise ValueError(f"No SHA256 hash available for {s3_path}")

    _verify_s3_hash(s3_path, output_path, hexdigest)
    return True
def _download_and_verify(
    s3_path,
    output_path,
    remote_sha_allowed,
    ignore_file_not_exist,
    max_attempts=5,
    retry_delay_secs=3,
):
    """Download ``s3_path`` to ``output_path`` and validate it, with retries.

    Each attempt tries boto, then requests, then curl/wget, and validates the
    result with ``validate_file``. Any exception during an attempt (download
    or validation) is printed and the attempt is repeated after a pause.

    Args:
        s3_path: Remote location of the artifact.
        output_path: Local file the artifact is written to.
        remote_sha_allowed: Passed through to ``validate_file``.
        ignore_file_not_exist: When true, an attempt in which every download
            method fails ends the function immediately (no retry, no error).
        max_attempts: Total number of attempts before giving up.
        retry_delay_secs: Seconds to sleep between attempts.

    Returns:
        True if the file was downloaded and validated. False if the download
        was skipped because of ``ignore_file_not_exist``; in that case
        ``output_path`` holds no verified content and must not be used.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        Exception: The error from the final attempt, once all attempts fail.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    for attempt in range(1, max_attempts + 1):
        try:
            print(f"Downloading {s3_path}...")
            ok = False
            try:
                from buildscripts.util.download_utils import download_from_s3_with_boto

                download_from_s3_with_boto(s3_path, output_path)
                ok = True
            except Exception:
                try:
                    from buildscripts.util.download_utils import download_from_s3_with_requests

                    download_from_s3_with_requests(s3_path, output_path, raise_on_error=True)
                    ok = True
                except Exception:
                    ok = False

            if not ok:
                # curl/wget fallback
                ok = _download_with_curl_or_wget(s3_path, output_path)

            if not ok:
                if ignore_file_not_exist:
                    print("Failed to find remote file. Ignoring and skipping...")
                    # Tell the caller nothing verified was produced.
                    return False
                raise RuntimeError("All download methods failed")

            validate_file(s3_path, output_path, remote_sha_allowed)
            return True
        except Exception:
            print("Download failed:")
            traceback.print_exc()
            if attempt == max_attempts:
                raise
            print("Retrying download...")
            time.sleep(retry_delay_secs)

    # Not reachable: the last attempt either returns or re-raises.
    return False


def download_s3_binary(
    s3_path: str,
    local_path: str = None,
    remote_sha_allowed=False,
    ignore_file_not_exist=False,
) -> bool:
    """Ensure a verified copy of ``s3_path`` exists at ``local_path``.

    If ``local_path`` already exists and validates, nothing is downloaded.
    Otherwise the artifact is downloaded to a temporary file, validated, and
    only then moved onto ``local_path``, so an unverified file is never
    installed at the destination.

    Args:
        s3_path: Remote location of the artifact.
        local_path: Destination path. Defaults to the last ``/``-separated
            component of ``s3_path``.
        remote_sha_allowed: Allow validation against ``.sha256`` files when no
            hard-coded digest exists.
        ignore_file_not_exist: Treat a download that every method fails as a
            non-error. In that case nothing is written to ``local_path``.

    Returns:
        True on success (including an ignored missing remote file), False if
        the download or validation failed.
    """
    import errno  # function-scope import: only needed for the EXDEV fallback

    if local_path is None:
        local_path = s3_path.split("/")[-1]

    if os.path.exists(local_path):
        try:
            print(f"Downloaded file {local_path} already exists, validating...")
            validate_file(s3_path, local_path, remote_sha_allowed)
            return True
        except Exception:
            print("File is invalid, redownloading...")

    tempfile_name = None
    try:
        # Reserve a temp path and close the handle before anything writes to it.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            tempfile_name = temp_file.name

        verified = _download_and_verify(
            s3_path, tempfile_name, remote_sha_allowed, ignore_file_not_exist
        )
        if not verified:
            # The download was skipped: the temp file is empty or partial and
            # was never validated, so it must not replace local_path.
            print(f"Skipped {s3_path}; nothing was written to {local_path}")
            return True

        try:
            os.replace(tempfile_name, local_path)
        except OSError as e:
            # os.replace cannot cross filesystems; fall back to copy + delete.
            if e.errno != errno.EXDEV:
                raise
            shutil.move(tempfile_name, local_path)

        print(f"Downloaded and verified {s3_path} -> {local_path}")
        return True
    except Exception as e:
        print(f"Download failed for {s3_path}: {e}")
        traceback.print_exc()
        return False
    finally:
        # After a successful move the temp path no longer exists.
        if tempfile_name and os.path.exists(tempfile_name):
            os.unlink(tempfile_name)
if __name__ == "__main__":
    # Command-line entry point: download one artifact and exit with status 1
    # if download_s3_binary reports failure.
    cli = argparse.ArgumentParser(description="Download and verify S3 binary.")
    cli.add_argument("s3_path", help="S3 URL to download from")
    cli.add_argument("local_path", nargs="?", help="Optional output file path")
    cli.add_argument("--remote-sha", action="store_true", help="Allow remote .sha256 lookup")
    cli.add_argument(
        "--ignore-file-not-exist",
        action="store_true",
        help="Don't fail when remote file doesn't exist.",
    )
    options = cli.parse_args()

    succeeded = download_s3_binary(
        options.s3_path,
        options.local_path,
        options.remote_sha,
        options.ignore_file_not_exist,
    )
    if not succeeded:
        sys.exit(1)