diff --git a/.gitignore b/.gitignore
index 68bc17f..413a7e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,4 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+.opencode
diff --git a/package_python_function/nested_zip_loader.py b/package_python_function/nested_zip_loader.py
index ffd52d7..342f003 100644
--- a/package_python_function/nested_zip_loader.py
+++ b/package_python_function/nested_zip_loader.py
@@ -24,33 +24,43 @@
 """
 
 
 def load_nested_zip() -> None:
-    from pathlib import Path
+    import fcntl
+    import importlib
     import sys
     import tempfile
-    import importlib
+    from pathlib import Path
 
     temp_path = Path(tempfile.gettempdir())
     target_package_path = temp_path / "package-python-function"
 
-    if not target_package_path.exists():
-        import zipfile
-        import shutil
-        import os
+    # We use manual locks here to allow target_package_path to stay static,
+    # but avoid race conditions when multiple processes try to run this
+    # function at the same time.
+    lock_path = temp_path / ".package-python-function.lock"
+
+    with open(lock_path, "w") as lock_file:
+        fcntl.flock(lock_file, fcntl.LOCK_EX)
+
+        if not target_package_path.exists():
+            import zipfile
+            import shutil
+            import os
 
-        staging_package_path = temp_path / ".stage.package-python-function"
+            staging_package_path = temp_path / ".stage.package-python-function"
 
-        if staging_package_path.exists():
-            shutil.rmtree(str(staging_package_path))
+            if staging_package_path.exists():
+                shutil.rmtree(str(staging_package_path))
 
-        nested_zip_path = Path(__file__).parent / '.dependencies.zip'
+            nested_zip_path = Path(__file__).parent / ".dependencies.zip"
 
-        zipfile.ZipFile(str(nested_zip_path), 'r').extractall(str(staging_package_path))
+            with zipfile.ZipFile(str(nested_zip_path), "r") as nested_zip:
+                nested_zip.extractall(str(staging_package_path))
 
-        # The idea here is that we don't rename the path until everything has been successfuly extracted.
-        # This is expected to be a an atomic operation. That way, if AWS terminates us during the extraction,
-        # we won't try and use the incomplete extraction.
-        os.rename(str(staging_package_path), str(target_package_path))
+            # The idea here is that we don't rename the path until everything has been successfully extracted.
+            # This is expected to be an atomic operation. That way, if AWS terminates us during the extraction,
+            # we won't try and use the incomplete extraction.
+            os.rename(str(staging_package_path), str(target_package_path))
 
     # Lambda sets up the sys.path like this:
     # ['/var/task', '/opt/python/lib/python3.13/site-packages', '/opt/python',
@@ -65,4 +75,4 @@ def load_nested_zip() -> None:
     sys.path[0] = str(target_package_path)
     importlib.reload(sys.modules[__name__])
 
-load_nested_zip()
\ No newline at end of file
+load_nested_zip()
diff --git a/tests/test_nested_zip_loader.py b/tests/test_nested_zip_loader.py
new file mode 100644
index 0000000..38a40de
--- /dev/null
+++ b/tests/test_nested_zip_loader.py
@@ -0,0 +1,86 @@
+import importlib
+import multiprocessing
+import shutil
+import sys
+import tempfile
+import zipfile
+from pathlib import Path
+
+import pytest
+
+# We have to use importlib here because nested_zip_loader calls load_nested_zip
+# at IMPORT TIME, which causes us a world of hurt in these tests if we try to
+# import it "normally" here.
+LOADER_PATH = Path(__file__).parent.parent / "package_python_function" / "nested_zip_loader.py"
+PKG_NAME = "_test_nested_zip"
+
+
+def _make_deps_zip(path: Path) -> None:
+    with zipfile.ZipFile(path, "w") as zf:
+        zf.writestr(f"{PKG_NAME}/__init__.py", "LOADED = True\n")
+
+
+@pytest.fixture()
+def lambda_env(tmp_path, monkeypatch):
+    """Simulate a Lambda-like layout: a task dir with a package whose __init__.py
+    is the nested_zip_loader code, and a .dependencies.zip with the 'real' code."""
+    task_dir = tmp_path / "task"
+    pkg_dir = task_dir / PKG_NAME
+    pkg_dir.mkdir(parents=True)
+    shutil.copy(LOADER_PATH, pkg_dir / "__init__.py")
+    _make_deps_zip(pkg_dir / ".dependencies.zip")
+
+    tmp_dir = tmp_path / "tmp"
+    tmp_dir.mkdir()
+    monkeypatch.setenv("TMPDIR", str(tmp_dir))
+    tempfile.tempdir = None
+
+    monkeypatch.syspath_prepend(str(task_dir))
+
+    yield tmp_path
+
+    sys.modules.pop(PKG_NAME, None)
+    tempfile.tempdir = None
+
+
+def test_cold_start_extracts(lambda_env):
+    mod = importlib.import_module(PKG_NAME)
+    assert mod.LOADED is True
+    assert (lambda_env / "tmp" / "package-python-function").exists()
+
+
+def test_warm_start_skips_extraction(lambda_env):
+    target_pkg = lambda_env / "tmp" / "package-python-function" / PKG_NAME
+    target_pkg.mkdir(parents=True)
+    (target_pkg / "__init__.py").write_text("LOADED = 'warm'\n")
+
+    mod = importlib.import_module(PKG_NAME)
+    assert mod.LOADED == "warm"
+
+
+def test_stale_staging_cleaned(lambda_env):
+    staging = lambda_env / "tmp" / ".stage.package-python-function"
+    staging.mkdir(parents=True)
+    (staging / "stale.txt").write_text("leftover")
+
+    importlib.import_module(PKG_NAME)
+    assert not staging.exists()
+
+
+def _worker(task_dir):
+    import importlib
+    import sys
+
+    sys.path.insert(0, task_dir)
+    assert importlib.import_module(PKG_NAME).LOADED is True
+
+
+def test_concurrent_no_race(lambda_env):
+    ctx = multiprocessing.get_context("forkserver")
+    procs = [ctx.Process(target=_worker, args=(str(lambda_env / "task"),)) for _ in range(2)]
+    for p in procs:
+        p.start()
+    for p in procs:
+        p.join(timeout=10)
+        assert p.exitcode == 0, "A race condition occurred while extracting."
+    assert (lambda_env / "tmp" / "package-python-function").exists()