Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions bin/embeddings_manager
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import os
import re
import sys
from argparse import ArgumentParser
from pathlib import Path, PurePosixPath
from shutil import rmtree
from typing import NamedTuple, Self
from zipfile import ZIP_DEFLATED, ZipFile

import boto3
import botocore.exceptions
from botocore import UNSIGNED
from botocore.client import Config

Expand Down Expand Up @@ -47,6 +49,42 @@ class EmbeddingSelection(NamedTuple):
return cls(*match.groups())



def _handle_s3_error(error: "botocore.exceptions.ClientError", action: str) -> None:
    """Print a human-readable S3 error message to stderr and exit.

    Args:
        error: The ``ClientError`` raised by a boto3/botocore call. Only its
            ``response["Error"]`` mapping (``Code`` / ``Message``) is read.
        action: Short phrase describing the attempted operation; it is
            interpolated after "while trying to" in the message.

    Raises:
        SystemExit: Always, with exit status 1, after the message is printed.
    """
    # Read the payload defensively: a malformed response should still produce
    # a readable message rather than a KeyError from inside the error handler.
    details = error.response.get("Error", {})
    code = str(details.get("Code", "Unknown"))
    message = details.get("Message", "")
    if code in ("403", "AccessDenied"):
        print(
            f"\nERROR: S3 access denied while trying to {action}.\n"
            f" Bucket : {S3_BUCKET}\n"
            f" Code : {code} - {message}\n\n"
            f"Possible causes:\n"
            f" - The S3 bucket is temporarily restricted.\n"
            f" - The requested embedding does not exist on S3.\n\n"
            f"What to do:\n"
            f" 1. Check open issues at https://github.com/reactome/reactome_chatbot/issues\n"
            f" 2. Contact maintainers if the bucket should be publicly accessible.\n",
            file=sys.stderr,
        )
    # A missing object can be reported either as the named code "NoSuchKey"
    # or as the bare HTTP status "404" (body-less HEAD responses), mirroring
    # the "403"/"AccessDenied" pair handled above.
    elif code in ("404", "NoSuchKey"):
        print(
            f"\nERROR: Embedding not found on S3 while trying to {action}.\n"
            f" The requested path does not exist in the bucket.\n\n"
            f"Run this to see available embeddings:\n"
            f" bin/embeddings_manager ls-remote\n",
            file=sys.stderr,
        )
    else:
        print(
            f"\nERROR: S3 error while trying to {action}.\n"
            f" Code : {code}\n"
            f" Message : {message}\n",
            file=sys.stderr,
        )
    sys.exit(1)


def pull(embedding: EmbeddingSelection):
embedding_path:Path = embedding.path(check_exists=False)
zip_tmpfile:Path = EM_ARCHIVE / "tmp.zip"
Expand All @@ -59,6 +97,8 @@ def pull(embedding: EmbeddingSelection):
print("Decompressing...")
with ZipFile(zip_tmpfile, "r") as zipf:
zipf.extractall(embedding_path)
except botocore.exceptions.ClientError as e:
_handle_s3_error(e, action=f"download '{embedding}'")
finally:
zip_tmpfile.unlink(missing_ok=True)
print(f"Saved to {embedding_path}")
Expand Down Expand Up @@ -130,10 +170,13 @@ def ls():
def ls_remote():
    """Print the embeddings available under ``S3_PREFIX`` in ``S3_BUCKET``.

    The bucket is accessed with unsigned requests. Each object key is made
    relative to ``S3_PREFIX`` and printed only when it has exactly four path
    components. S3 client errors are reported through ``_handle_s3_error``,
    which exits the process.
    """
    s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
    s3_bucket = s3.Bucket(S3_BUCKET)
    # The listing is performed exactly once, and entirely inside the ``try``:
    # the request is issued while iterating the collection, so an unguarded
    # loop would raise ClientError before the handler below could see it.
    try:
        for obj in s3_bucket.objects.filter(Prefix=str(S3_PREFIX)):
            relative_path = PurePosixPath(obj.key).relative_to(S3_PREFIX)
            if len(relative_path.parts) == 4:
                print(relative_path)
    except botocore.exceptions.ClientError as e:
        _handle_s3_error(e, action="list remote embeddings")


def which():
Expand All @@ -144,15 +187,13 @@ def which():
if __name__ == "__main__":
parser = ArgumentParser()

# Parent parser for selecting embeddings
selection_parser = ArgumentParser(add_help=False)
selection_parser.add_argument(
"embedding",
type=EmbeddingSelection.parse,
help="Embedding selection: <modelorg>/<model>/<database>/<version>"
)

# Subcommands
subparsers = parser.add_subparsers(required=True)
pull_parser = subparsers.add_parser(
"pull",
Expand Down Expand Up @@ -206,7 +247,6 @@ if __name__ == "__main__":
)
which_parser.set_defaults(func=which)

# Command-specific arguments
make_parser.add_argument(
"--openai-key",
help="API key for OpenAI"
Expand Down
2 changes: 1 addition & 1 deletion src/retrievers/reactome/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def create_reactome_rag(
llm: BaseChatModel,
embedding: Embeddings,
embeddings_directory: Path = EmbeddingEnvironment.get_dir("reactome"),
embeddings_directory: Path = EmbeddingEnvironment.get_dir_or_raise("reactome"),
*,
streaming: bool = False,
) -> Runnable:
Expand Down
2 changes: 1 addition & 1 deletion src/retrievers/uniprot/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def create_uniprot_rag(
llm: BaseChatModel,
embedding: Embeddings,
embeddings_directory: Path = EmbeddingEnvironment.get_dir("uniprot"),
embeddings_directory: Path = EmbeddingEnvironment.get_dir_or_raise("uniprot"),
*,
streaming: bool = False,
) -> Runnable:
Expand Down
37 changes: 35 additions & 2 deletions src/util/embedding_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, env_path: str):
self.embeddings[db] = embedding_path

@classmethod
def _get(cls): # -> Self
def _get(cls):
if EM_CURRENT.exists():
with EM_CURRENT.open("r") as current_fp:
env_path = current_fp.read()
Expand All @@ -33,6 +33,39 @@ def get_dir(cls, key: str) -> Path | None:
else:
return None

@classmethod
def get_dir_or_raise(cls, key: str) -> Path:
    """
    Resolve the embeddings directory for `key`, failing loudly if unusable.

    Delegates to get_dir() and converts its two failure modes (a None
    result, or a path that does not exist on disk) into a RuntimeError
    whose message includes install instructions.

    Raises:
        RuntimeError: if get_dir() returns None for `key`, or if the
            returned path does not exist.
    """
    resolved = cls.get_dir(key)
    if resolved is None:
        # Nothing is configured for this key at all.
        problem = (
            f"\n[ERROR] No embeddings configured for '{key}'.\n"
            f"Install them with:\n\n"
            f" ./bin/embeddings_manager install "
            f"openai/text-embedding-3-large/{key}/ReleaseXX\n\n"
            f"List available versions with:\n"
            f" ./bin/embeddings_manager ls-remote\n"
        )
    elif not resolved.exists():
        # A path is configured, but it is not present on disk.
        problem = (
            f"\n[ERROR] Embeddings directory configured but missing on disk:\n"
            f" {resolved}\n\n"
            f"Re-install with:\n"
            f" ./bin/embeddings_manager install "
            f"openai/text-embedding-3-large/{key}/ReleaseXX\n"
        )
    else:
        return resolved
    raise RuntimeError(problem)

@classmethod
def get_model(cls, key: str) -> str:
return str(cls._get().embeddings[key].parent.parent)
Expand All @@ -44,4 +77,4 @@ def set_one(cls, embedding_path: Path) -> None:
embeddings_dict[db] = embedding_path
env_path: str = ":".join(map(str, embeddings_dict.values()))
with EM_CURRENT.open("w") as current_fp:
current_fp.write(env_path)
current_fp.write(env_path)