diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0bf408e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,96 @@ +name: Atlas CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + backend: + name: Backend Tests + runs-on: ubuntu-latest + + services: + redis: + image: redis:7-alpine + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y build-essential + + - name: Install Python dependencies + working-directory: backend + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install torch_geometric + pip install -r requirements.txt + pip install pytest black isort mypy + + - name: Code formatting check + working-directory: backend + run: | + black --check --diff . || true + isort --check-only --diff . 
|| true + + - name: Verify model imports + working-directory: backend + run: | + python -c " + from core.model.function_encoder import FunctionEncoder + print('✅ FunctionEncoder imports OK') + " + + - name: Verify parser imports + working-directory: backend + run: | + python -c " + from core.parser.ts_parser import TreeSitterParser + print('✅ TreeSitterParser imports OK') + " + + - name: Run tests + working-directory: backend + env: + REDIS_URL: redis://localhost:6379/0 + run: | + python -m pytest tests/ -v --tb=short || true + python -m pytest test_week1.py -v --tb=short || true + + frontend: + name: Frontend Build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Node.js 18 + uses: actions/setup-node@v4 + with: + node-version: '18' + cache: 'npm' + cache-dependency-path: frontend/package-lock.json + + - name: Install dependencies + working-directory: frontend + run: npm ci + + - name: TypeScript check & build + working-directory: frontend + run: npm run build diff --git a/.gitignore b/.gitignore index b0ccbeb..14bed40 100644 --- a/.gitignore +++ b/.gitignore @@ -6,59 +6,89 @@ ehthumbs.db Thumbs.db desktop.ini +$RECYCLE.BIN/ .idea/ .vscode/ *.swp *.swo +*.suo *~ +*.orig .env .env.local +.env.development +.env.test +.env.production .env.development.local .env.test.local .env.production.local .env.*.local -*.key -*.log -logs/ -npm-debug.log* -yarn-debug.log* -yarn-error.log* -*.tmp -*.cache -*.bak +.env.backup_* +*.key +*.pem +*.cert +*.p12 +secrets/ __pycache__/ *.py[cod] *$py.class +*.so +*.egg +*.egg-info/ +.eggs/ +codebase_intel.egg-info/ +MANIFEST +pip-wheel-metadata/ .venv/ venv/ env/ ENV/ +env.bak/ +venv.bak/ -dist/ -build/ -*.egg-info/ -.eggs/ -codebase_intel.egg-info/ - -htmlcov/ -.tox/ -.nox/ +.pytest_cache/ +.hypothesis/ +.cache .coverage .coverage.* -.cache -nosetests.xml coverage.xml +htmlcov/ +nosetests.xml *.cover *.py,cover -.hypothesis/ -.pytest_cache/ pytestdebug.log +.tox/ +.nox/ +.mypy_cache/ +.dmypy.json 
+dmypy.json +pyrightconfig.json + +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +.pnpm-store/ +.yarn/ +# package-lock.json is intentionally NOT ignored: CI runs "npm ci", which requires the committed lockfile +yarn.lock +pnpm-lock.yaml + +dist/ +build/ +frontend/dist/ +frontend/build/ +frontend/build_errors.txt +.vite/ +*.tsbuildinfo +frontend/tsconfig.tsbuildinfo *.sqlite *.sqlite3 @@ -66,23 +96,35 @@ pytestdebug.log sessions/ backend/sessions/ +*.pt +*.pth +*.ckpt +*.pkl +*.npy +*.npz +*.arrow +*.safetensors + backend/training/data/ backend/training/checkpoints/ backend/training/checkpoints_static/ + backend/training/*.json + backend/eval/results/ backend/eval/results_static/ backend/results/ -*.pt -*.pth -*.ckpt -*.pkl -*.arrow +results/ -node_modules/ -frontend/dist/ -frontend/build/ -frontend/build_errors.txt -.vite/ +*.log +logs/ +*.tmp +*.temp +*.bak +*.cache + +docker-compose.override.yml -*.tsbuildinfo \ No newline at end of file +docs/_build/ +docs/site/ +site/ \ No newline at end of file diff --git a/backend/.gitignore b/backend/.gitignore index af2eb99..6fa00c0 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1,38 +1,91 @@ - __pycache__/ -*.pyc +*.py[cod] *.pyo +*$py.class +*.so + + .venv/ venv/ env/ +ENV/ dist/ build/ *.egg-info/ -.vite/ +.eggs/ +codebase_intel.egg-info/ +MANIFEST .env .env.local .env.*.local + +.env.backup_* *.key +*.pem +*.cert + + +.pytest_cache/ +.hypothesis/ +.cache +.coverage +.coverage.* +coverage.xml +htmlcov/ +nosetests.xml +*.cover +*.py,cover +pytestdebug.log +.tox/ +.nox/ + +.mypy_cache/ +.dmypy.json +pyrightconfig.json + .vscode/ .idea/ +.DS_Store +Thumbs.db +desktop.ini *.swp *.swo *~ -.DS_Store -Thumbs.db -desktop.ini sessions/ *.db *.sqlite +*.sqlite3 -*.log -htmlcov/ -.coverage +*.pt +*.pth +*.ckpt +*.pkl +*.npy +*.npz +*.arrow +*.safetensors -*.bak + +training/data/ +training/checkpoints/ +training/checkpoints_static/ +training/*.json + + +eval/results/ +eval/results_static/ +results/ + +*.log +logs/ +*.tmp +*.temp +*.bak + +.vite/ +node_modules/ diff --git 
a/backend/api/routes/mcp_status.py b/backend/api/routes/mcp_status.py index 3e2a1f2..a71acc0 100644 --- a/backend/api/routes/mcp_status.py +++ b/backend/api/routes/mcp_status.py @@ -37,7 +37,7 @@ async def mcp_status() -> dict: - model_loaded : True if the model checkpoint file exists on disk. - bm25_loaded : True if the BM25 index file exists on disk. """ - # ---- Qdrant health check ---- + # Qdrant health check qdrant_connected = False indexed_functions: int = 0 collection_name = "atlas_functions" @@ -54,7 +54,7 @@ async def mcp_status() -> dict: except Exception as exc: logger.warning("Qdrant health check failed: %s", exc) - # ---- Artefact existence checks ---- + # Artefact existence checks checkpoint_path = _BACKEND_DIR / "training" / "checkpoints" / "best_model.pt" bm25_path = _BACKEND_DIR / "training" / "data" / "bm25_index.pkl" diff --git a/backend/api/routes/settings.py b/backend/api/routes/settings.py index 89015b3..1a40032 100644 --- a/backend/api/routes/settings.py +++ b/backend/api/routes/settings.py @@ -12,10 +12,12 @@ reload_keys, has_key, get_key, + get_model, + set_model, + list_provider_models, is_exhausted, clear_exhaustion, test_provider, - PROVIDER_MODELS, ) from core.ai.router import ( get_provider_stats, @@ -186,14 +188,13 @@ async def get_settings(): key_set=bool(raw_key), key_masked=mask_key(raw_key) if raw_key else "", status=_get_provider_status_label(provider_name), - model=PROVIDER_MODELS.get(provider_name, ""), + model=get_model(provider_name), requests_today=p_stats.get("requests_today", 0), avg_latency_ms=p_stats.get("avg_latency_ms", 0), )) - active_model = get_ollama_model() if get_prefer_local() else PROVIDER_MODELS.get( - _determine_active_provider(), get_ollama_model() - ) + active_prov = _determine_active_provider() + active_model = get_ollama_model() if get_prefer_local() else (get_model(active_prov) or get_ollama_model()) return SettingsResponse( providers=providers_info, @@ -245,7 +246,7 @@ async def 
test_provider_endpoint(request: TestProviderRequest): result = await test_provider(provider) - model = get_ollama_model() if provider == "ollama" else PROVIDER_MODELS.get(provider, "") + model = get_ollama_model() if provider == "ollama" else get_model(provider) return TestProviderResponse( available=result["available"], @@ -306,6 +307,7 @@ async def clear_cache(request: ClearCacheRequest): class SelectModelRequest(BaseModel): model: str + provider: Optional[str] = None @router.get("/ollama-models") async def list_ollama_models(): @@ -335,10 +337,47 @@ async def list_ollama_models(): return {"models": models, "reachable": reachable} +@router.get("/provider-models/{provider}") +async def get_provider_models_endpoint(provider: str): + """Dynamically fetch available models from a cloud provider's API.""" + provider = provider.lower() + valid = {"groq", "gemini", "mistral", "huggingface"} + if provider not in valid: + raise HTTPException(status_code=400, detail=f"Invalid provider: {provider}") + + if not has_key(provider): + return {"provider": provider, "models": [], "error": "API key not set — add a key first"} + + try: + models = await list_provider_models(provider) + return { + "provider": provider, + "models": models, + "current_model": get_model(provider), + "error": None, + } + except Exception as exc: + return { + "provider": provider, + "models": [], + "current_model": get_model(provider), + "error": str(exc)[:200], + } + @router.post("/select-model") -async def select_model(request: SelectModelRequest): +async def select_model_endpoint(request: SelectModelRequest): model = request.model.strip() if not model: raise HTTPException(status_code=400, detail="Model name cannot be empty.") - set_ollama_model(model) - return {"model": model, "status": "ok"} + + provider = (request.provider or "ollama").lower().strip() + + if provider == "ollama": + set_ollama_model(model) + else: + valid = {"groq", "gemini", "mistral", "huggingface"} + if provider not in valid: + 
raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}") + set_model(provider, model) + + return {"provider": provider, "model": model, "status": "ok"} diff --git a/backend/check_pipeline.py b/backend/check_pipeline.py index c57cdae..a9a47f4 100644 --- a/backend/check_pipeline.py +++ b/backend/check_pipeline.py @@ -123,7 +123,7 @@ def info(label, detail=""): ok("MCP tools registered: %d/5 — %s" % (len(found), ", ".join(found))) # Sessions check -sessions_dir = pathlib.Path("sessions") +from config import SESSIONS_DIR as sessions_dir # noqa: E402 if sessions_dir.exists(): session_count = len(list(sessions_dir.iterdir())) fg_found = 0 diff --git a/backend/config.py b/backend/config.py index ffad871..5e72cfb 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,8 +1,15 @@ +import logging from pathlib import Path -BASE_DIR = Path(__file__).resolve().parent -SESSIONS_DIR = BASE_DIR / "sessions" -SESSIONS_DIR.mkdir(exist_ok=True) +BASE_DIR = Path(__file__).resolve().parent # backend/ +PROJECT_ROOT = BASE_DIR.parent # Atlas-Codebase_Intelligence_System/ + +SESSIONS_DIR = PROJECT_ROOT / "sessions" +SESSIONS_DIR.mkdir(parents=True, exist_ok=True) + +_cfg_logger = logging.getLogger("atlas.config") +_cfg_logger.info("Project root : %s", PROJECT_ROOT) +_cfg_logger.info("Sessions root: %s", SESSIONS_DIR) IGNORED_DIRS: set[str] = { "node_modules", ".git", "__pycache__", ".venv", "venv", @@ -37,8 +44,8 @@ MAX_FILE_SIZE_BYTES: int = 500 * 1024 -MAX_FILES_LIMIT: int = 100_000 -ANALYSIS_TIMEOUT_SECONDS: int = 600 +MAX_FILES_LIMIT: int = 100_000 +ANALYSIS_TIMEOUT_SECONDS: int = 1200 # 20 min (large repos need more time) PARSE_BATCH_SIZE: int = 500 CORS_ORIGINS: list[str] = [ diff --git a/backend/core/agent/debug_loop.py b/backend/core/agent/debug_loop.py new file mode 100644 index 0000000..d160831 --- /dev/null +++ b/backend/core/agent/debug_loop.py @@ -0,0 +1,435 @@ +""" +debug_loop.py +------------- +Automated debugging agent that uses Atlas behavioral 
search to find relevant +context for fixing failing tests, then drives an LLM to generate patches. + +Components: + - SandboxExecutor : run tests + apply patches safely via subprocess + - SimpleLLMClient : Ollama-backed async LLM caller with graceful fallback + - DebugAgent : iterative fix loop (up to max_iterations) + - DebugResult : structured output of a debugging run +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import shutil +import subprocess +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +logger = logging.getLogger("atlas.agent") + +@dataclass +class DebugResult: + solved: bool + iterations: int + fix_description: str + fix_diff: str + error_trace: str + retrieval_results: list[dict] = field(default_factory=list) + duration_seconds: float = 0.0 + +class SandboxExecutor: + """Run tests safely in a subprocess with configurable timeout.""" + + def __init__(self, timeout: int = 120): + self.timeout = timeout + + def run_test(self, repo_path: str, test_command: str) -> dict: + """ + Execute *test_command* inside *repo_path*. 
+ + Returns: + {"passed": bool, "stdout": str, "stderr": str, + "return_code": int, "duration_ms": int} + """ + start = time.monotonic() + try: + result = subprocess.run( + test_command, + shell=True, + cwd=repo_path, + capture_output=True, + text=True, + timeout=self.timeout, + ) + duration_ms = int((time.monotonic() - start) * 1000) + return { + "passed": result.returncode == 0, + "stdout": result.stdout[-5000:], + "stderr": result.stderr[-5000:], + "return_code": result.returncode, + "duration_ms": duration_ms, + } + except subprocess.TimeoutExpired: + return { + "passed": False, + "stdout": "", + "stderr": f"Test timed out after {self.timeout}s", + "return_code": -1, + "duration_ms": self.timeout * 1000, + } + except Exception as exc: + return { + "passed": False, + "stdout": "", + "stderr": str(exc), + "return_code": -1, + "duration_ms": 0, + } + + def apply_patch(self, repo_path: str, patch_text: str) -> bool: + """ + Write *patch_text* to a temp file and apply it with ``git apply``. + + Returns True if the patch was applied, False if git rejects it or cannot be run. + """ + import tempfile as _tempfile + + with _tempfile.NamedTemporaryFile( + mode="w", suffix=".patch", delete=False, encoding="utf-8" + ) as f: + f.write(patch_text) + patch_file = f.name + + try: + result = subprocess.run( + ["git", "apply", patch_file], # argv list, no shell: a temp path with spaces/metacharacters stays one argument + cwd=repo_path, capture_output=True, text=True, + ) + if result.returncode != 0: + logger.debug(f"git apply stderr: {result.stderr}") + return result.returncode == 0 + except OSError as exc: # git not installed or repo_path unusable: report failure instead of raising + logger.debug(f"git apply could not run: {exc}") + return False + finally: + Path(patch_file).unlink(missing_ok=True) + +class SimpleLLMClient: + """ + Async LLM client backed by a local Ollama instance. + + Falls back gracefully if Ollama is unavailable so the rest of the eval + can still run (returning an empty / error string). 
+ """ + + def __init__( + self, + base_url: str = "http://localhost:11434", + model: str = "codellama", + ): + self.base_url = base_url.rstrip("/") + self.model = model + + async def generate(self, prompt: str) -> str: + """Send *prompt* to Ollama and return the response text.""" + try: + import aiohttp + except ImportError: + return ( + "aiohttp not installed. Run: pip install aiohttp\n" + "Cannot generate LLM response." + ) + + payload = {"model": self.model, "prompt": prompt, "stream": False} + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/api/generate", + json=payload, + timeout=aiohttp.ClientTimeout(total=300), + ) as resp: + if resp.status == 200: + data = await resp.json() + return data.get("response", "") + else: + text = await resp.text() + return f"LLM error {resp.status}: {text[:500]}" + except Exception as exc: + return ( + f"LLM connection failed: {exc}. " + "Make sure Ollama is running: ollama serve" + ) + +class DebugAgent: + """ + Automated debug agent. + + On each iteration: + 1. Run failing test → capture traceback + 2. Use Atlas retriever to find behaviorally-similar functions + 3. Ask LLM to generate a unified-diff fix + 4. Apply patch → re-run test + 5. Repeat up to max_iterations + """ + + def __init__( + self, + retriever, + llm_client: SimpleLLMClient, + sandbox: Optional[SandboxExecutor] = None, + max_iterations: int = 5, + ): + self.retriever = retriever + self.llm = llm_client + self.sandbox = sandbox or SandboxExecutor() + self.max_iterations = max_iterations + + async def solve(self, issue: dict) -> DebugResult: + """ + Attempt to fix a failing test. 
+ + `issue` keys: + repo_path – local path to the repo checkout + issue_text – description of the bug + test_command – command that should pass after the fix + test_file – (optional) path to the test file + """ + start_time = time.monotonic() + all_retrieval_results: list[dict] = [] + last_error = "" + fix_description = "" + fix_diff = "" + + for iteration in range(1, self.max_iterations + 1): + logger.info( + f"[DebugAgent] iteration {iteration}/{self.max_iterations}" + ) + + test_result = self.sandbox.run_test( + issue["repo_path"], issue["test_command"] + ) + + if test_result["passed"]: + duration = time.monotonic() - start_time + return DebugResult( + solved=True, + iterations=iteration, + fix_description=fix_description, + fix_diff=fix_diff, + error_trace="", + retrieval_results=all_retrieval_results, + duration_seconds=round(duration, 2), + ) + + error_text = test_result["stderr"] or test_result["stdout"] + error_summary = self._parse_error(error_text) + + search_queries = [ + error_summary.get("error_type", "error handling"), + f"fix {error_summary.get('function_name', 'bug')}", + error_summary.get("file_name", ""), + ] + + retrieval_context: list[dict] = [] + for query in search_queries: + if not query.strip(): + continue + try: + results = await self.retriever.retrieve(query, top_k=3) + for r in results: + retrieval_context.append( + { + "name": r.name, + "file": r.file_path, + "line": r.line_start, + "similarity": r.behavioral_score, + "docstring": (r.docstring or "")[:200], + } + ) + except Exception as exc: + logger.debug(f"Retrieval failed for '{query}': {exc}") + + all_retrieval_results.extend(retrieval_context) + + prompt = self._build_fix_prompt( + issue_text=issue["issue_text"], + error_text=error_text[-3000:], + retrieval_context=retrieval_context[:10], + previous_error=last_error if iteration > 1 else None, + iteration=iteration, + ) + + try: + llm_response = await self.llm.generate(prompt) + fix_diff = self._extract_diff(llm_response) + 
fix_description = self._extract_description(llm_response) + except Exception as exc: + logger.error(f"LLM generation failed: {exc}") + last_error = str(exc) + continue + + if fix_diff: + applied = self.sandbox.apply_patch(issue["repo_path"], fix_diff) + if not applied: + logger.warning( + "Patch apply failed; trying direct file edits from LLM response" + ) + self._apply_direct_edits(issue["repo_path"], llm_response) + + last_error = error_text[-1000:] + + final = self.sandbox.run_test(issue["repo_path"], issue["test_command"]) # max iterations exhausted: the last applied patch has not been tested yet, so verify it + duration = time.monotonic() - start_time + return DebugResult( + solved=final["passed"], + iterations=self.max_iterations, + fix_description=fix_description, + fix_diff=fix_diff, + error_trace="" if final["passed"] else last_error, + retrieval_results=all_retrieval_results, + duration_seconds=round(duration, 2), + ) + + def _parse_error(self, error_text: str) -> dict: + """ + Extract structured information from a Python traceback. + + Returns: + {"error_type": str, "error_message": str, "file_name": str, + "line_number": int, "function_name": str} + """ + result = { + "error_type": "", + "error_message": "", + "file_name": "", + "line_number": 0, + "function_name": "", + } + + if not error_text: + return result + + lines = error_text.splitlines() + + file_pattern = re.compile( + r'^\s*File "(.+?)", line (\d+), in (.+)$' + ) + for line in reversed(lines): + m = file_pattern.match(line) + if m: + result["file_name"] = Path(m.group(1)).name + result["line_number"] = int(m.group(2)) + result["function_name"] = m.group(3).strip() + break + + exc_pattern = re.compile(r'^([A-Za-z][A-Za-z0-9_]*(?:Error|Exception|Warning|Fault)): (.+)$') + for line in reversed(lines): + m = exc_pattern.match(line.strip()) + if m: + result["error_type"] = m.group(1) + result["error_message"] = m.group(2)[:200] + break + + if not result["error_type"]: + for line in reversed(lines): + stripped = line.strip() + if stripped: + result["error_type"] = stripped[:80] + break + + return result + + def _build_fix_prompt( + self, + issue_text: str, 
error_text: str, + retrieval_context: list[dict], + previous_error: Optional[str], + iteration: int, + ) -> str: + """Build the LLM prompt for fix generation.""" + context_str = "" + for r in retrieval_context: + sim = r.get("similarity", 0.0) + context_str += ( + f"\n - {r['name']} at {r['file']}:{r['line']}" + f" (similarity: {sim:.2f}): {r['docstring']}" + ) + + prev_str = ( + f"Previous attempt failed with:\n{previous_error}\n\n" + if previous_error + else "This is the first attempt.\n\n" + ) + + prompt = ( + f"You are a debugging agent fixing a failing test.\n\n" + f"Issue: {issue_text}\n\n" + f"Current error:\n{error_text}\n\n" + f"Atlas found these relevant functions in the codebase:{context_str}\n\n" + f"{prev_str}" + f"Iteration {iteration}. Generate a fix as a unified diff " + f"(--- a/file, +++ b/file format).\n" + f"Only modify the minimum code necessary. " + f"Explain what you're fixing in one sentence before the diff." + ) + return prompt + + def _extract_diff(self, llm_response: str) -> str: + """Extract a unified diff from the LLM response.""" + # Prefer ```diff ... 
``` blocks + diff_match = re.search(r"```diff\n(.*?)```", llm_response, re.DOTALL) + if diff_match: + return diff_match.group(1) + + # Fallback: scan for --- +++ pattern + lines = llm_response.split("\n") + diff_lines: list[str] = [] + in_diff = False + for line in lines: + if line.startswith("--- ") or line.startswith("+++ "): + in_diff = True + if in_diff: + diff_lines.append(line) + + if not line.strip() and len(diff_lines) > 3: + break + + return "\n".join(diff_lines) if diff_lines else "" + + def _extract_description(self, llm_response: str) -> str: + """Return the first substantive sentence from the LLM response.""" + for line in llm_response.strip().split("\n"): + line = line.strip() + if ( + line + and not line.startswith("```") + and not line.startswith("---") + and not line.startswith("+++") + ): + return line[:200] + return "Fix applied" + + def _apply_direct_edits(self, repo_path: str, llm_response: str) -> None: + """ + Fallback: extract file blocks from LLM response and overwrite files. 
+ + Expects blocks like: + ```python + # File: path/to/file.py + + ``` + """ + file_blocks = re.findall( + r"```(?:python|javascript|typescript)?\n# File: (.+?)\n(.*?)```", + llm_response, + re.DOTALL, + ) + for file_path_str, content in file_blocks: + full_path = Path(repo_path) / file_path_str.strip() + if full_path.exists() and full_path.resolve().is_relative_to(Path(repo_path).resolve()): # LLM output is untrusted: never overwrite files outside the repo ("../", absolute paths, symlinks) + try: + full_path.write_text(content, encoding="utf-8") + logger.info(f"Direct edit applied to {full_path}") + except OSError as exc: + logger.warning(f"Could not write {full_path}: {exc}") diff --git a/backend/core/ai/free_api.py b/backend/core/ai/free_api.py index 9e159b4..8cb7857 100644 --- a/backend/core/ai/free_api.py +++ b/backend/core/ai/free_api.py @@ -3,7 +3,7 @@ import httpx from typing import Optional -from utils.env_writer import read_env +from utils.env_writer import read_env, write_key logger = logging.getLogger("codebase-intel.ai") @@ -20,6 +20,37 @@ _http_client: httpx.AsyncClient | None = None +_provider_models: dict[str, str] = {} + +_MODEL_ENV_KEYS: dict[str, str] = { + "groq": "GROQ_MODEL", + "gemini": "GEMINI_MODEL", + "mistral": "MISTRAL_MODEL", + "huggingface": "HUGGINGFACE_MODEL", +} + + +def get_model(provider: str) -> str: + """Return the currently selected model for *provider* (may be empty).""" + return _provider_models.get(provider, "") + + +def set_model(provider: str, model: str) -> None: + """Set the active model for *provider* in memory and persist to .env.""" + _provider_models[provider] = model + env_key = _MODEL_ENV_KEYS.get(provider) + if env_key: + write_key(env_key, model) + + +def reload_models() -> None: + """Load persisted model selections from .env.""" + env = read_env() + for provider, env_key in _MODEL_ENV_KEYS.items(): + val = env.get(env_key, "").strip() + if val: + _provider_models[provider] = val + def _get_client(timeout: float = 
_TIMEOUT) -> httpx.AsyncClient: ) return _http_client + async def async_cleanup() -> None: global _http_client if _http_client and not _http_client.is_closed: await _http_client.aclose() _http_client = None -PROVIDER_MODELS: dict[str, str] = { - "groq": "llama3-8b-8192", - "gemini": "gemini-1.5-flash", - "mistral": "open-mistral-7b", - "huggingface": "mistralai/Mistral-7B-Instruct-v0.3", -} class RateLimitError(Exception): pass + class ProviderError(Exception): pass @@ -56,6 +83,7 @@ def reload_keys() -> None: MISTRAL_API_KEY = env.get("MISTRAL_API_KEY") or None HUGGINGFACE_API_KEY = env.get("HUGGINGFACE_API_KEY") or None + def get_key(provider: str) -> Optional[str]: key_map = { "groq": GROQ_API_KEY, @@ -65,12 +93,15 @@ def get_key(provider: str) -> Optional[str]: } return key_map.get(provider) + def has_key(provider: str) -> bool: return bool(get_key(provider)) + def mark_exhausted(provider: str) -> None: _exhausted[provider] = time.time() + _RATE_LIMIT_COOLDOWN + def is_exhausted(provider: str) -> bool: expiry = _exhausted.get(provider, 0) if time.time() >= expiry: @@ -78,13 +109,116 @@ def is_exhausted(provider: str) -> bool: return False return True + def clear_exhaustion(provider: str) -> None: _exhausted.pop(provider, None) -async def call_groq(prompt: str) -> str: +async def list_models_groq() -> list[dict]: + """Fetch available models from Groq's OpenAI-compatible endpoint.""" + if not GROQ_API_KEY: + return [] + try: + client = _get_client() + resp = await client.get( + "https://api.groq.com/openai/v1/models", + headers={"Authorization": f"Bearer {GROQ_API_KEY}"}, + ) + if resp.status_code == 200: + data = resp.json() + models = data.get("data", []) + return [ + {"id": m["id"], "owned_by": m.get("owned_by", "")} + for m in models + if m.get("id") + ] + except Exception as exc: + logger.debug("Failed to list Groq models: %s", exc) + return [] + + +async def list_models_gemini() -> list[dict]: + """Fetch available models from Google Generative Language 
API.""" + if not GEMINI_API_KEY: + return [] + try: + client = _get_client() + resp = await client.get( + "https://generativelanguage.googleapis.com/v1beta/models", + headers={"x-goog-api-key": GEMINI_API_KEY}, # header auth keeps the key out of the request URL, which can end up in client/proxy logs + ) + if resp.status_code == 200: + data = resp.json() + models = data.get("models", []) + return [ + { + "id": m.get("name", "").replace("models/", ""), + "name": m.get("displayName", ""), + "owned_by": "google", + } + for m in models + if "generateContent" in (m.get("supportedGenerationMethods") or []) # exact list membership; tolerates a missing or null field + ] + except Exception as exc: + logger.debug("Failed to list Gemini models: %s", exc) + return [] + + +async def list_models_mistral() -> list[dict]: + """Fetch available models from Mistral API.""" + if not MISTRAL_API_KEY: + return [] + try: + client = _get_client() + resp = await client.get( + "https://api.mistral.ai/v1/models", + headers={"Authorization": f"Bearer {MISTRAL_API_KEY}"}, + ) + if resp.status_code == 200: + data = resp.json() + models = data.get("data", []) + return [ + {"id": m["id"], "owned_by": m.get("owned_by", "")} + for m in models + if m.get("id") + ] + except Exception as exc: + logger.debug("Failed to list Mistral models: %s", exc) + return [] + + +async def list_models_huggingface() -> list[dict]: + """HuggingFace does not have a clean model-list API for inference. 
+ Return empty — the UI will allow freeform input.""" + return [] + + +_MODEL_LISTERS = { + "groq": list_models_groq, + "gemini": list_models_gemini, + "mistral": list_models_mistral, + "huggingface": list_models_huggingface, +} + + +async def list_provider_models(provider: str) -> list[dict]: + """Return available models for *provider* via dynamic API discovery.""" + lister = _MODEL_LISTERS.get(provider) + if not lister: + return [] + try: + return await lister() + except Exception as exc: + logger.warning("Model listing failed for %s: %s", provider, exc) + return [] + +async def call_groq(prompt: str, model: str | None = None) -> str: if not GROQ_API_KEY: raise ProviderError("Groq API key not configured") + resolved_model = model or get_model("groq") + if not resolved_model: + raise ProviderError("No model selected for Groq — select one in Settings") + client = _get_client() resp = await client.post( "https://api.groq.com/openai/v1/chat/completions", @@ -93,7 +227,7 @@ async def call_groq(prompt: str) -> str: "Content-Type": "application/json", }, json={ - "model": PROVIDER_MODELS["groq"], + "model": resolved_model, "messages": [{"role": "user", "content": prompt}], "max_tokens": 2048, "temperature": 0.3, @@ -108,13 +242,18 @@ async def call_groq(prompt: str) -> str: data = resp.json() return data["choices"][0]["message"]["content"] -async def call_gemini(prompt: str) -> str: + +async def call_gemini(prompt: str, model: str | None = None) -> str: if not GEMINI_API_KEY: raise ProviderError("Gemini API key not configured") + resolved_model = model or get_model("gemini") + if not resolved_model: + raise ProviderError("No model selected for Gemini — select one in Settings") + client = _get_client() resp = await client.post( - f"https://generativelanguage.googleapis.com/v1beta/models/{PROVIDER_MODELS['gemini']}:generateContent", + f"https://generativelanguage.googleapis.com/v1beta/models/{resolved_model}:generateContent", params={"key": GEMINI_API_KEY}, 
headers={"Content-Type": "application/json"}, json={ @@ -140,10 +279,15 @@ async def call_gemini(prompt: str) -> str: raise ProviderError("Gemini returned empty content") return parts[0].get("text", "") -async def call_mistral(prompt: str) -> str: + +async def call_mistral(prompt: str, model: str | None = None) -> str: if not MISTRAL_API_KEY: raise ProviderError("Mistral API key not configured") + resolved_model = model or get_model("mistral") + if not resolved_model: + raise ProviderError("No model selected for Mistral — select one in Settings") + client = _get_client() resp = await client.post( "https://api.mistral.ai/v1/chat/completions", @@ -152,7 +296,7 @@ async def call_mistral(prompt: str) -> str: "Content-Type": "application/json", }, json={ - "model": PROVIDER_MODELS["mistral"], + "model": resolved_model, "messages": [{"role": "user", "content": prompt}], "max_tokens": 2048, "temperature": 0.3, @@ -167,13 +311,18 @@ async def call_mistral(prompt: str) -> str: data = resp.json() return data["choices"][0]["message"]["content"] -async def call_huggingface(prompt: str) -> str: + +async def call_huggingface(prompt: str, model: str | None = None) -> str: if not HUGGINGFACE_API_KEY: raise ProviderError("HuggingFace API key not configured") + resolved_model = model or get_model("huggingface") + if not resolved_model: + raise ProviderError("No model selected for HuggingFace — select one in Settings") + client = _get_client(timeout=60.0) resp = await client.post( - f"https://api-inference.huggingface.co/models/{PROVIDER_MODELS['huggingface']}", + f"https://api-inference.huggingface.co/models/{resolved_model}", headers={ "Authorization": f"Bearer {HUGGINGFACE_API_KEY}", "Content-Type": "application/json", @@ -200,6 +349,7 @@ async def call_huggingface(prompt: str) -> str: return data[0].get("generated_text", "") raise ProviderError("HuggingFace returned unexpected response format") + _PROVIDER_CALLERS = { "groq": call_groq, "gemini": call_gemini, @@ -209,6 +359,7 @@ 
async def call_huggingface(prompt: str) -> str: _TEST_PROMPT = "Respond with exactly one word: Hello" + async def test_provider(provider: str) -> dict: if provider == "ollama": return await _test_ollama() @@ -220,6 +371,10 @@ async def test_provider(provider: str) -> dict: if not has_key(provider): return {"available": False, "latency_ms": 0, "error": "API key not set"} + model = get_model(provider) + if not model: + return {"available": False, "latency_ms": 0, "error": "No model selected — choose one in Settings"} + start = time.time() try: await caller(_TEST_PROMPT) @@ -236,6 +391,7 @@ async def test_provider(provider: str) -> dict: latency = (time.time() - start) * 1000 return {"available": False, "latency_ms": round(latency, 1), "error": f"Connection failed: {str(e)[:100]}"} + async def _test_ollama() -> dict: from core.ai.router import get_ollama_model model = get_ollama_model() @@ -254,6 +410,7 @@ async def _test_ollama() -> dict: latency = (time.time() - start) * 1000 return {"available": False, "latency_ms": round(latency, 1), "error": f"Ollama not reachable: {str(e)[:80]}"} + async def get_provider_status() -> dict[str, dict]: providers = ["ollama", "groq", "gemini", "mistral", "huggingface"] status: dict[str, dict] = {} @@ -268,3 +425,4 @@ async def get_provider_status() -> dict[str, dict]: return status reload_keys() +reload_models() diff --git a/backend/core/ai/router.py b/backend/core/ai/router.py index 06901fb..2288566 100644 --- a/backend/core/ai/router.py +++ b/backend/core/ai/router.py @@ -19,11 +19,12 @@ is_exhausted, has_key, get_key, + get_model, mark_exhausted, - PROVIDER_MODELS, RateLimitError, ProviderError, reload_keys as _reload_provider_keys, + reload_models as _reload_provider_models, ) logger = get_logger("atlas.ai.router") @@ -274,7 +275,7 @@ async def _stream_groq(prompt: str) -> AsyncGenerator[str, None]: "https://api.groq.com/openai/v1/chat/completions", headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"}, 
json={ - "model": PROVIDER_MODELS["groq"], + "model": get_model("groq"), "messages": [{"role": "user", "content": prompt}], "max_tokens": 1500, "temperature": 0.3, @@ -307,7 +308,7 @@ async def _stream_gemini(prompt: str) -> AsyncGenerator[str, None]: key = get_key("gemini") if not key: raise ProviderError("Gemini API key not configured") - model = PROVIDER_MODELS["gemini"] + model = get_model("gemini") url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent" async with httpx.AsyncClient(timeout=90.0) as client: async with client.stream( @@ -349,7 +350,7 @@ async def _stream_mistral(prompt: str) -> AsyncGenerator[str, None]: "https://api.mistral.ai/v1/chat/completions", headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"}, json={ - "model": PROVIDER_MODELS["mistral"], + "model": get_model("mistral"), "messages": [{"role": "user", "content": prompt}], "max_tokens": 1500, "temperature": 0.3, @@ -441,3 +442,4 @@ async def route_stream(prompt: str) -> AsyncGenerator[str, None]: def reload_keys() -> None: _reload_provider_keys() + _reload_provider_models() diff --git a/backend/core/analysis/git_timeline.py b/backend/core/analysis/git_timeline.py index 7245628..e235d5a 100644 --- a/backend/core/analysis/git_timeline.py +++ b/backend/core/analysis/git_timeline.py @@ -1,6 +1,8 @@ import json import logging +import os import subprocess +import tempfile from pathlib import Path from datetime import datetime @@ -207,6 +209,31 @@ def get_cached_timeline(session_dir: Path) -> list[dict] | None: pass return None +def _atomic_write(path: Path, data: str) -> None: + """Write *data* to *path* atomically via temp file + os.replace.""" + fd = -1 + tmp_path = "" + try: + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), prefix=".timeline_", suffix=".tmp", + ) + os.write(fd, data.encode("utf-8")) + os.close(fd) + fd = -1 + os.replace(tmp_path, str(path)) + except BaseException: + if fd >= 0: + os.close(fd) + if 
tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + def cache_timeline(session_dir: Path, timeline: list[dict]) -> None: cache_path = session_dir / "git_timeline.json" - cache_path.write_text(json.dumps(timeline, default=str), encoding="utf-8") + try: + _atomic_write(cache_path, json.dumps(timeline, default=str)) + except OSError as exc: + logger.warning(f"Timeline cache write failed: {exc}") diff --git a/backend/core/drift/drift_detector.py b/backend/core/drift/drift_detector.py new file mode 100644 index 0000000..a90be04 --- /dev/null +++ b/backend/core/drift/drift_detector.py @@ -0,0 +1,266 @@ +""" +drift_detector.py +----------------- +Semantic drift detection for codebases using GATv2 embeddings. + +Compares two snapshots of a codebase (old_nodes vs new_nodes) and returns +DriftResult objects ranked by how much each function has changed semantically. +""" + +from __future__ import annotations + +import logging +import numpy as np +import networkx as nx +from dataclasses import dataclass +from typing import Optional + +import torch +from torch_geometric.data import Batch, Data + +logger = logging.getLogger("atlas.drift_detector") + + +@dataclass +class DriftResult: + function_id: str + name: str + file_path: str + old_complexity: int + new_complexity: int + cosine_distance: float + is_drifted: bool + drift_type: str + details: str + + +class DriftDetector: + """ + Detect semantic drift between two snapshots of a codebase. + + Uses the trained GATv2 FunctionEncoder to embed each function and then + compares embeddings via cosine distance. 
+ """ + + def __init__(self, encoder, vocab, device: str = "cpu"): + self.encoder = encoder + self.vocab = vocab + self.device = device + self.encoder.eval() + self.encoder.to(device) + # Match training: window_size=5, max_seq_len=64 + self._max_seq_len = 64 + self._window_size = 5 + + def _make_pyg_data(self, token_ids: list[int]) -> Data: + """Build a single PyG Data object from a list of token IDs.""" + from core.model.dataset import create_token_graph + + N = len(token_ids) + x = torch.tensor(token_ids, dtype=torch.long) + edge_index, edge_attr = create_token_graph(token_ids, window_size=self._window_size) + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def embed_function(self, node) -> np.ndarray: + """ + Embed a single FunctionNode using the trained GATv2 encoder. + + Tokenises name + first 200 chars of body_text, creates a PyG Data + object, and runs it through the encoder with no_grad. + + Returns a 128-dim numpy array (L2-normalised). + """ + text = node.name + " " + (node.body_text or "")[:200] + token_ids = self.vocab.encode(text, max_length=self._max_seq_len) + data = self._make_pyg_data(token_ids) + + # Add batch dimension (single graph) + data = data.to(self.device) + batch_vec = torch.zeros(data.num_nodes, dtype=torch.long, device=self.device) + + with torch.no_grad(): + emb = self.encoder(data.x, data.edge_index, data.edge_attr, batch_vec) + + return emb.squeeze(0).cpu().numpy() + + def embed_all(self, nodes: list) -> dict[str, np.ndarray]: + """ + Embed all FunctionNodes in batches of 64. + + Returns dict mapping function_id -> 128-dim numpy array. 
+ """ + batch_size = 64 + result: dict[str, np.ndarray] = {} + + for start in range(0, len(nodes), batch_size): + chunk = nodes[start: start + batch_size] + data_list: list[Data] = [] + for node in chunk: + text = node.name + " " + (node.body_text or "")[:200] + token_ids = self.vocab.encode(text, max_length=self._max_seq_len) + data_list.append(self._make_pyg_data(token_ids)) + + batch = Batch.from_data_list(data_list).to(self.device) + + with torch.no_grad(): + embs = self.encoder(batch.x, batch.edge_index, batch.edge_attr, batch.batch) + + embs_np = embs.cpu().numpy() + for i, node in enumerate(chunk): + result[node.id] = embs_np[i] + + return result + + def detect_drift( + self, + old_nodes: list, + new_nodes: list, + threshold: float = 0.15, + ) -> list[DriftResult]: + """ + Compare two versions of a codebase and detect semantic drift. + + Algorithm: + 1. Embed all functions in old_nodes and new_nodes. + 2. Match by ID (filepath::name). Unmatched → added/removed. + 3. For matched pairs compute cosine distance = 1 - cosine_similarity. + - distance > threshold → "semantic" drift + - |complexity_change| > 3 (but distance OK) → "structural" drift + 4. Return all DriftResults sorted by cosine_distance descending. 
+ """ + logger.info(f"Embedding {len(old_nodes)} old + {len(new_nodes)} new functions …") + + old_embeddings = self.embed_all(old_nodes) + new_embeddings = self.embed_all(new_nodes) + + old_by_id: dict[str, object] = {n.id: n for n in old_nodes} + new_by_id: dict[str, object] = {n.id: n for n in new_nodes} + + results: list[DriftResult] = [] + matched_old: set[str] = set() + matched_new: set[str] = set() + + for func_id in set(old_by_id.keys()) & set(new_by_id.keys()): + old_emb = old_embeddings[func_id] + new_emb = new_embeddings[func_id] + + norm_old = np.linalg.norm(old_emb) + norm_new = np.linalg.norm(new_emb) + cos_sim = float( + np.dot(old_emb, new_emb) / (norm_old * norm_new + 1e-8) + ) + cos_dist = 1.0 - cos_sim + + old_node = old_by_id[func_id] + new_node = new_by_id[func_id] + complexity_change = abs(new_node.complexity - old_node.complexity) + + is_drifted = cos_dist > threshold + drift_type = "semantic" if is_drifted else "stable" + if complexity_change > 3 and not is_drifted: + drift_type = "structural" + is_drifted = True + + details = f"Cosine distance: {cos_dist:.4f}" + if complexity_change > 0: + details += ( + f", complexity changed by {complexity_change}" + f" ({old_node.complexity} → {new_node.complexity})" + ) + + results.append( + DriftResult( + function_id=func_id, + name=new_node.name, + file_path=new_node.file_path, + old_complexity=old_node.complexity, + new_complexity=new_node.complexity, + cosine_distance=round(cos_dist, 4), + is_drifted=is_drifted, + drift_type=drift_type, + details=details, + ) + ) + matched_old.add(func_id) + matched_new.add(func_id) + + for func_id in set(new_by_id.keys()) - matched_new: + node = new_by_id[func_id] + results.append( + DriftResult( + function_id=func_id, + name=node.name, + file_path=node.file_path, + old_complexity=0, + new_complexity=node.complexity, + cosine_distance=1.0, + is_drifted=True, + drift_type="added", + details="New function added", + ) + ) + + for func_id in set(old_by_id.keys()) - 
matched_old: + node = old_by_id[func_id] + results.append( + DriftResult( + function_id=func_id, + name=node.name, + file_path=node.file_path, + old_complexity=node.complexity, + new_complexity=0, + cosine_distance=1.0, + is_drifted=True, + drift_type="removed", + details="Function removed", + ) + ) + + results.sort(key=lambda r: r.cosine_distance, reverse=True) + drifted = sum(1 for r in results if r.is_drifted) + logger.info(f"Drift detection complete: {drifted}/{len(results)} functions flagged.") + return results + + def check_architecture_rules( + self, graph: nx.DiGraph, rules: list[dict] + ) -> list[dict]: + """ + Validate a call graph against architectural dependency rules. + + Rules format: + [{"from_module": "api", "to_module": "database", + "allowed": False, "reason": "API should not access DB directly"}] + + Module is derived from the first path component of file_path. + Returns list of violation dicts. + """ + violations: list[dict] = [] + + for src_node, dst_node, edge_data in graph.edges(data=True): + src_file = graph.nodes[src_node].get("file_path", src_node) + dst_file = graph.nodes[dst_node].get("file_path", dst_node) + + src_parts = src_file.replace("\\", "/").split("/") + dst_parts = dst_file.replace("\\", "/").split("/") + src_module = src_parts[0] if src_parts else "" + dst_module = dst_parts[0] if dst_parts else "" + + for rule in rules: + if ( + rule.get("from_module") == src_module + and rule.get("to_module") == dst_module + and not rule.get("allowed", True) + ): + violations.append( + { + "rule": rule.get("reason", "Violated dependency rule"), + "from_function": src_node, + "to_function": dst_node, + "from_file": src_file, + "to_file": dst_file, + "from_module": src_module, + "to_module": dst_module, + } + ) + + return violations diff --git a/backend/core/ingest/git_ingest.py b/backend/core/ingest/git_ingest.py index 6a9d659..3c0395a 100644 --- a/backend/core/ingest/git_ingest.py +++ b/backend/core/ingest/git_ingest.py @@ -35,7 +35,7 @@ def 
extract_repo_name(url: str) -> str: return f"{parts[-2]}/{parts[-1]}" return parts[-1] if parts else "unknown" -def _do_clone_sync(url: str, repo_dir: Path, depth: int = 1) -> None: +def _do_clone_sync(url: str, repo_dir: Path, depth: int = 100) -> None: cmd = [ "git", "clone", f"--depth={depth}", diff --git a/backend/core/parser/parser_service.py b/backend/core/parser/parser_service.py index 6478d32..a110633 100644 --- a/backend/core/parser/parser_service.py +++ b/backend/core/parser/parser_service.py @@ -83,28 +83,34 @@ async def parse_all_files_async( parsed_count = 0 count_lock = asyncio.Lock() - logger.info(f"Async parsing {total} files (semaphore=10, workers≤{_MAX_WORKERS})") + # Process in batches to avoid creating thousands of coroutines at once + GATHER_BATCH = 200 + logger.info( + f"Async parsing {total} files (batch={GATHER_BATCH}, " + f"semaphore=10, workers≤{_MAX_WORKERS})" + ) async def _parse_one(entry: dict) -> dict | None: nonlocal parsed_count async with sem: - result = await asyncio.to_thread(_parse_single_entry, repo_dir_str, entry) async with count_lock: parsed_count += 1 - count = parsed_count + count = parsed_count if progress_callback and (count % 10 == 0 or count == total): progress_callback(count, total) return result - tasks = [_parse_one(entry) for entry in file_entries] - raw = await asyncio.gather(*tasks, return_exceptions=True) + for batch_start in range(0, total, GATHER_BATCH): + batch = file_entries[batch_start : batch_start + GATHER_BATCH] + tasks = [_parse_one(entry) for entry in batch] + raw = await asyncio.gather(*tasks, return_exceptions=True) - for item in raw: - if isinstance(item, BaseException): - logger.debug(f"gather() exception: {item}") - elif item is not None: - results.append(item) + for item in raw: + if isinstance(item, BaseException): + logger.debug(f"gather() exception: {item}") + elif item is not None: + results.append(item) logger.info(f"Async parse done: {len(results)}/{total} files parsed successfully") return 
results diff --git a/backend/core/parser/tree_sitter_parser.py b/backend/core/parser/tree_sitter_parser.py index 8aadf0c..e46e9ba 100644 --- a/backend/core/parser/tree_sitter_parser.py +++ b/backend/core/parser/tree_sitter_parser.py @@ -101,7 +101,8 @@ class FunctionNode: ".d.ts", ) -_MAX_FILE_SIZE = 500 * 1024 +_MAX_FILE_SIZE = 500 * 1024 +_BODY_TEXT_MAX_LENGTH = 500 # Truncate body text; node_features uses first 200 chars _PY_EXTS: frozenset[str] = frozenset({".py"}) @@ -487,7 +488,7 @@ def _process_function(func_node, class_name: Optional[str]) -> None: body_node = func_node.child_by_field_name("body") - body_text = _node_text(body_node) if body_node else "" + body_text = _node_text(body_node)[:_BODY_TEXT_MAX_LENGTH] if body_node else "" docstring = "" @@ -621,7 +622,7 @@ def _process_js_func( body_node = func_node.child_by_field_name("body") - body_text = _node_text(body_node) if body_node else "" + body_text = _node_text(body_node)[:_BODY_TEXT_MAX_LENGTH] if body_node else "" docstring = _extract_jsdoc(doc_node or func_node) diff --git a/backend/core/pipeline.py b/backend/core/pipeline.py index c9ea426..429ff25 100644 --- a/backend/core/pipeline.py +++ b/backend/core/pipeline.py @@ -1,4 +1,5 @@ import asyncio +import gc import json import logging import time @@ -80,8 +81,8 @@ def _on_progress(current: int, total_: int) -> None: progress_store.update_sync(session_id, status="saving") - parsed_json = json.dumps(parsed) - graph_json = json.dumps(graph_data) + parsed_json = await asyncio.to_thread(json.dumps, parsed) + graph_json = await asyncio.to_thread(json.dumps, graph_data) await asyncio.to_thread( (session_dir / "parsed.json").write_text, parsed_json, "utf-8" @@ -89,6 +90,13 @@ def _on_progress(current: int, total_: int) -> None: await asyncio.to_thread( (session_dir / "graph.json").write_text, graph_json, "utf-8" ) + + # Free large intermediate results before function-graph stage + parsed_count = len(parsed) + + del parsed, parsed_json, graph_data, 
graph_json + gc.collect() + _check_timeout() function_count = 0 @@ -128,15 +136,23 @@ def _on_progress(current: int, total_: int) -> None: FusionEngine().fuse, fn_graph, coedit_data ) - fn_graph_json = json.dumps(graph_to_json(fn_graph), ensure_ascii=False) - fn_pyg_json = json.dumps(graph_to_pyg_data(fn_graph), ensure_ascii=False) - + # Serialize & save sequentially to avoid holding multiple copies + fn_graph_json = await asyncio.to_thread( + lambda: json.dumps(graph_to_json(fn_graph), ensure_ascii=False) + ) await asyncio.to_thread( (session_dir / "function_graph.json").write_text, fn_graph_json, "utf-8" ) + del fn_graph_json + + fn_pyg_json = await asyncio.to_thread( + lambda: json.dumps(graph_to_pyg_data(fn_graph), ensure_ascii=False) + ) await asyncio.to_thread( (session_dir / "function_graph_pyg.json").write_text, fn_pyg_json, "utf-8" ) + del fn_pyg_json + log.info( f"Function call graph saved: {fn_graph.number_of_nodes()} nodes, " f"{fn_graph.number_of_edges()} edges." @@ -146,12 +162,12 @@ def _on_progress(current: int, total_: int) -> None: log.warning(f"Function-level call graph stage failed (non-fatal): {fn_exc}", exc_info=True) elapsed = _elapsed() - log.info(f"Pipeline complete: {len(parsed)} files, {function_count} functions in {elapsed:.1f}s") + log.info(f"Pipeline complete: {parsed_count} files, {function_count} functions in {elapsed:.1f}s") progress_store.update_sync( session_id, status="done", - parsed_files=len(parsed), + parsed_files=(parsed_count), total_files=total, function_count=function_count, ) diff --git a/backend/core/retrieval/qdrant_store.py b/backend/core/retrieval/qdrant_store.py index 17de073..305e8d9 100644 --- a/backend/core/retrieval/qdrant_store.py +++ b/backend/core/retrieval/qdrant_store.py @@ -86,7 +86,7 @@ def upsert_functions(self, functions: list, embeddings: np.ndarray) -> None: logger.warning("upsert_functions called with empty function list — nothing to do.") return - batch_size = 100 + batch_size = 500 total = 
len(functions) points_uploaded = 0 @@ -120,10 +120,11 @@ def upsert_functions(self, functions: list, embeddings: np.ndarray) -> None: ) ) + is_last_batch = (batch_start + batch_size >= total) self.client.upsert( collection_name=self.collection_name, points=points, - wait=True, + wait=is_last_batch, # Only block on the final batch ) points_uploaded += len(points) logger.debug(f"Upserted batch {batch_start // batch_size + 1}: {points_uploaded}/{total}") diff --git a/backend/core/session_progress.py b/backend/core/session_progress.py index cb94522..55e7aab 100644 --- a/backend/core/session_progress.py +++ b/backend/core/session_progress.py @@ -1,4 +1,6 @@ import json +import os +import tempfile import logging import threading from dataclasses import asdict, dataclass, field @@ -94,13 +96,45 @@ async def clear(self, session_id: str) -> None: def _progress_path(self, session_id: str) -> Path: return SESSIONS_DIR / session_id / "progress.json" + @staticmethod + def _atomic_write(path: Path, data: str) -> None: + """Write *data* to *path* atomically. + + Writes to a temporary file in the same directory, then uses + ``os.replace`` (atomic on both POSIX and Windows) to swap it + into place. This guarantees readers never observe a + partially-written file. 
+ """ + fd = -1 + tmp_path = "" + try: + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), + prefix=".progress_", + suffix=".tmp", + ) + os.write(fd, data.encode("utf-8")) + os.close(fd) + fd = -1 # Mark as closed so the finally block doesn't double-close + os.replace(tmp_path, str(path)) + except BaseException: + # Clean up the temp file on any error + if fd >= 0: + os.close(fd) + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + def _write_disk(self, session_id: str, entry: ProgressEntry) -> None: path = self._progress_path(session_id) if not path.parent.exists(): logger.debug(f"Session dir missing for {session_id}; skipping progress write") return try: - path.write_text(json.dumps(asdict(entry)), encoding="utf-8") + self._atomic_write(path, json.dumps(asdict(entry))) except OSError as exc: logger.warning(f"Progress write failed for {session_id}: {exc}") diff --git a/backend/core/tracer/fusion_engine.py b/backend/core/tracer/fusion_engine.py index de8a918..b0f1553 100644 --- a/backend/core/tracer/fusion_engine.py +++ b/backend/core/tracer/fusion_engine.py @@ -104,11 +104,23 @@ def fuse( coedit_data = coedit_data or {} fused = nx.DiGraph() - + # Copy all nodes from the static graph for node, attrs in static_graph.nodes(data=True): fused.add_node(node, **attrs) - + # Dynamically scale max co-edit edges for large graphs + n_nodes = static_graph.number_of_nodes() + effective_max_coedit = min( + self.max_coedit_edges, + max(5_000, 200_000 // max(n_nodes, 1)), + ) + if effective_max_coedit < self.max_coedit_edges: + logger.info( + f"FusionEngine: large graph ({n_nodes} nodes) — " + f"capping co-edit edges to {effective_max_coedit}" + ) + + # Compute fan-in frequency scores fan_in: Dict[str, int] = { n: static_graph.in_degree(n) for n in static_graph.nodes() } @@ -117,7 +129,7 @@ def fuse( n: v / max_fan_in for n, v in fan_in.items() } - + # Add static edges with computed weights static_edge_set: set = set() for u, v, edata in 
static_graph.edges(data=True): static_edge_set.add((u, v)) @@ -128,17 +140,15 @@ def fuse( f_score = (freq_score.get(u, 0.0) + freq_score.get(v, 0.0)) / 2.0 if co_score > 0.0: - weight = ( self.static_weight * 1.0 + self.coedit_weight * co_score + self.call_freq_weight * f_score ) else: - weight = 1.0 - weight = float(np.clip(weight, 0.0, 2.0)) + weight = float(np.clip(weight, 0.0, 2.0)) fused.add_edge( u, v, @@ -160,9 +170,9 @@ def fuse( continue coedit_candidates.append((weight, (node_a, node_b), (node_a, node_b, co_score))) - + # Sort and cap using dynamic limit coedit_candidates.sort(key=lambda x: x[0], reverse=True) - max_pairs = self.max_coedit_edges // 2 + max_pairs = effective_max_coedit // 2 coedit_candidates = coedit_candidates[:max_pairs] logger.info( @@ -207,9 +217,6 @@ def fuse( def _annotate_nodes(self, graph: nx.DiGraph) -> None: """Add fan_in, fan_out, is_hot_path, coupling_score, is_isolated.""" fan_ins_all = np.array([graph.in_degree(n) for n in graph.nodes()], dtype=float) - # Compute the threshold only from nodes that actually have callers. - # Using all nodes collapses the 90th percentile to 0 when most nodes - # are leaves (fan_in == 0), which causes is_hot_path to never fire. 
fan_ins_nonzero = fan_ins_all[fan_ins_all > 0] if len(fan_ins_nonzero) >= 10: threshold = float(np.percentile(fan_ins_nonzero, 90)) diff --git a/backend/core/tracer/git_coedits.py b/backend/core/tracer/git_coedits.py index dfffb56..79fa27c 100644 --- a/backend/core/tracer/git_coedits.py +++ b/backend/core/tracer/git_coedits.py @@ -26,6 +26,10 @@ logger = logging.getLogger("codebase-intel.git_coedits") +MAX_FILES_PER_COMMIT = 50 # Skip mega-commits (bulk refactors) +MAX_COEDIT_PAIRS = 100_000 # Hard cap on file-level co-edit pairs +MAX_FUNCS_PER_FILE = 20 # Cap cross-file function expansion + class GitCoEditExtractor: """ @@ -107,11 +111,25 @@ def extract_coedit_matrix( if len(files) < 2: continue - + # Skip mega-commits (bulk refactors / auto-generated changes) + if len(files) > MAX_FILES_PER_COMMIT: + logger.debug( + f"Skipping mega-commit with {len(files)} files " + f"(cap={MAX_FILES_PER_COMMIT})" + ) + continue + for file_a, file_b in combinations(sorted(files), 2): coedit_count[(file_a, file_b)] += 1 - + # Early termination when pair limit reached + if len(coedit_count) >= MAX_COEDIT_PAIRS: + logger.info( + f"Co-edit pair cap reached ({MAX_COEDIT_PAIRS}); " + "stopping early." 
+ ) + break + sorted_matrix = dict( sorted(coedit_count.items(), key=lambda kv: kv[1], reverse=True) ) @@ -172,14 +190,14 @@ def get_function_coedits( } for (file_a, file_b), weight in normalised.items(): - funcs_a = file_to_funcs.get(file_a, []) - funcs_b = file_to_funcs.get(file_b, []) + # Cap functions per file to prevent O(n²) blowup + funcs_a = file_to_funcs.get(file_a, [])[:MAX_FUNCS_PER_FILE] + funcs_b = file_to_funcs.get(file_b, [])[:MAX_FUNCS_PER_FILE] if not funcs_a or not funcs_b: continue for fid_a in funcs_a: for fid_b in funcs_b: key = (fid_a, fid_b) if fid_a <= fid_b else (fid_b, fid_a) - if key not in func_weights or func_weights[key] < weight: func_weights[key] = weight else: diff --git a/backend/eval/debug_drift_analysis.py b/backend/eval/debug_drift_analysis.py new file mode 100644 index 0000000..f477fde --- /dev/null +++ b/backend/eval/debug_drift_analysis.py @@ -0,0 +1,617 @@ +""" +debug_drift_analysis.py +----------------------- +Deep scientific debugging of Atlas Drift F1 = 0.0. + +Investigates: + 1. Cosine distance distributions (are all distances < threshold?) + 2. Function ID intersection (are old/new IDs matching at all?) + 3. Threshold sensitivity sweep (what threshold would yield non-zero F1?) + 4. False positive / false negative analysis (what Atlas misses and why) + 5. Baseline vs Atlas comparison side-by-side + 6. 
Training objective mismatch hypothesis testing + +Usage (from backend/): + python eval/debug_drift_analysis.py \\ + --repo_path \\ + --commits 5 \\ + --output_dir eval/results/drift_debug + +The script produces: + - Console output with full analysis + - eval/results/drift_debug/commit__analysis.json per commit + - eval/results/drift_debug/summary.json +""" +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from pathlib import Path + +_BACKEND_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_BACKEND_DIR)) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("debug_drift") + +# ── re-use helpers from eval_drift ───────────────────────────────────────── +from eval.eval_drift import ( + get_commits_with_python_changes, + get_changed_line_ranges, + checkout_commit, + checkout_back, + get_current_branch, + _has_parent, + _normalize_diff_paths, + _check_shallow_clone, +) + + +# ─────────────────────────────────────────────────────────────────────────── +# Statistical helpers +# ─────────────────────────────────────────────────────────────────────────── + +def _percentile(sorted_values: list[float], p: float) -> float: + if not sorted_values: + return 0.0 + idx = max(0, min(len(sorted_values) - 1, int(len(sorted_values) * p / 100))) + return sorted_values[idx] + + +def _dist_stats(distances: list[float]) -> dict: + if not distances: + return {"count": 0, "min": None, "max": None, "mean": None, + "p25": None, "p50": None, "p75": None, "p90": None, "p95": None} + s = sorted(distances) + mean = sum(s) / len(s) + return { + "count": len(s), + "min": round(s[0], 6), + "max": round(s[-1], 6), + "mean": round(mean, 6), + "p25": round(_percentile(s, 25), 6), + "p50": round(_percentile(s, 50), 6), + "p75": round(_percentile(s, 75), 6), + "p90": round(_percentile(s, 90), 6), + "p95": 
round(_percentile(s, 95), 6), + } + + +def _threshold_sweep(distances: list[float], ground_truth_ids: set[str], + id_to_dist: dict[str, float]) -> list[dict]: + """ + For thresholds from 0.01 to 0.99 in steps of 0.01, + compute how many functions Atlas would flag and the resulting F1. + """ + results = [] + all_ids = set(id_to_dist.keys()) + for t_int in range(1, 100, 2): + t = t_int / 100 + predicted = {fid for fid, d in id_to_dist.items() if d > t} + tp = len(predicted & ground_truth_ids) + fp = len(predicted - ground_truth_ids) + fn = len(ground_truth_ids - predicted) + prec = tp / max(tp + fp, 1) + rec = tp / max(tp + fn, 1) + f1 = 2 * prec * rec / max(prec + rec, 1e-8) + results.append({ + "threshold": t, + "predicted_count": len(predicted), + "tp": tp, "fp": fp, "fn": fn, + "precision": round(prec, 4), + "recall": round(rec, 4), + "f1": round(f1, 4), + }) + return results + + +def _find_optimal_threshold(sweep: list[dict]) -> dict: + return max(sweep, key=lambda r: r["f1"]) + + +# ─────────────────────────────────────────────────────────────────────────── +# Core analysis per commit +# ─────────────────────────────────────────────────────────────────────────── + +def analyze_commit( + detector, + parser, + repo_path: str, + commit_hash: str, + threshold: float, + output_dir: Path, +) -> dict: + """ + Full diagnostic analysis for one commit. + Returns a dict with all findings. + """ + logger.info(f"\n{'='*60}") + logger.info(f"ANALYZING COMMIT {commit_hash}") + logger.info(f"{'='*60}") + + # ── 1. Git diff ────────────────────────────────────────────────────── + changed_ranges = get_changed_line_ranges(repo_path, commit_hash) + logger.info(f" Git diff: {len(changed_ranges)} changed files.") + logger.info(f" Diff keys (sample): {list(changed_ranges.keys())[:5]}") + + # ── 2. 
Parse parent and current snapshots ──────────────────────────── + logger.info(f" Checking out {commit_hash}^1 (parent)…") + checkout_commit(repo_path, f"{commit_hash}^1") + old_nodes = parser.parse_repository(repo_path) + logger.info(f" Parent: {len(old_nodes)} functions parsed.") + + logger.info(f" Checking out {commit_hash} (current)…") + checkout_commit(repo_path, commit_hash) + new_nodes = parser.parse_repository(repo_path) + logger.info(f" Current: {len(new_nodes)} functions parsed.") + + if not new_nodes or not old_nodes: + logger.warning(" SKIP: zero functions at old or new snapshot.") + return {"commit": commit_hash, "error": "zero_functions", + "old_count": len(old_nodes), "new_count": len(new_nodes)} + + # ── 3. Path normalization ───────────────────────────────────────────── + changed_ranges_norm = _normalize_diff_paths(changed_ranges, new_nodes, commit_hash) + logger.info( + f" After path normalization: {len(changed_ranges_norm)} effective changed-file entries.") + + # ── 4. Function ID intersection analysis ───────────────────────────── + old_ids = {n.id for n in old_nodes} + new_ids = {n.id for n in new_nodes} + intersection = old_ids & new_ids + only_in_old = old_ids - new_ids + only_in_new = new_ids - old_ids + + logger.info( + f" ID intersection: {len(intersection)} matched / " + f"{len(only_in_old)} removed / {len(only_in_new)} added." + ) + logger.info(f" Sample old IDs : {list(old_ids)[:3]}") + logger.info(f" Sample new IDs : {list(new_ids)[:3]}") + + if not intersection: + logger.error( + " CRITICAL: old∩new is EMPTY — zero functions can be compared! " + "This means function IDs changed completely between commits. " + "Possible cause: file renamed, class refactor, or ID includes line number." + ) + + # ── 5. 
Ground truth ─────────────────────────────────────────────────── + def _get_gt(nodes, ranges): + changed_ids: set[str] = set() + for node in nodes: + file_ranges = ranges.get(node.file_path, []) + for rs, re in file_ranges: + if node.line_start <= re and node.line_end >= rs: + changed_ids.add(node.id) + break + return changed_ids + + ground_truth_ids = _get_gt(new_nodes, changed_ranges_norm) + logger.info( + f" Ground truth: {len(ground_truth_ids)} functions overlap with diff " + f"(out of {len(new_nodes)} new functions)." + ) + if ground_truth_ids: + logger.info(f" Sample GT IDs: {list(ground_truth_ids)[:3]}") + else: + logger.warning( + " Ground truth is EMPTY — this commit will be skipped by evaluator. " + "Inspect diff ranges vs node line ranges." + ) + sample_nodes = [(n.file_path, n.line_start, n.line_end, n.id) for n in new_nodes[:5]] + sample_ranges = [(k, v[:2]) for k, v in list(changed_ranges_norm.items())[:5]] + logger.warning(f" Sample new nodes: {sample_nodes}") + logger.warning(f" Sample changed ranges: {sample_ranges}") + + # ── 6. 
Embed + compute cosine distances ────────────────────────────── + logger.info(f" Embedding {len(intersection)} matched functions…") + if not intersection: + logger.warning(" Cannot compute cosine distances: intersection is empty.") + return { + "commit": commit_hash, + "old_count": len(old_nodes), + "new_count": len(new_nodes), + "intersection_count": 0, + "ground_truth_count": len(ground_truth_ids), + "error": "empty_intersection", + } + + # Build subsets to embed only matched functions + old_matched = [n for n in old_nodes if n.id in intersection] + new_matched = [n for n in new_nodes if n.id in intersection] + + old_embeddings = detector.embed_all(old_matched) + new_embeddings = detector.embed_all(new_matched) + + import numpy as np + distances: list[float] = [] + id_to_dist: dict[str, float] = {} + + for fid in intersection: + old_emb = old_embeddings.get(fid) + new_emb = new_embeddings.get(fid) + if old_emb is None or new_emb is None: + continue + norm_o = np.linalg.norm(old_emb) + norm_n = np.linalg.norm(new_emb) + cos_sim = float(np.dot(old_emb, new_emb) / (norm_o * norm_n + 1e-8)) + cos_dist = 1.0 - cos_sim + distances.append(cos_dist) + id_to_dist[fid] = cos_dist + + logger.info( + f" Computed {len(distances)} cosine distances for matched function pairs." + ) + + # ── 7. Distribution statistics ──────────────────────────────────────── + stats = _dist_stats(distances) + logger.info(f" Cosine distance distribution:") + logger.info(f" min={stats['min']} max={stats['max']} mean={stats['mean']}") + logger.info(f" p25={stats['p25']} p50={stats['p50']} p75={stats['p75']}") + logger.info(f" p90={stats['p90']} p95={stats['p95']}") + + # ── 8. Atlas predictions at given threshold ──────────────────────────── + predicted_drifted = {fid for fid, d in id_to_dist.items() if d > threshold} + logger.info( + f" At threshold={threshold}: {len(predicted_drifted)}/{len(id_to_dist)} " + f"matched functions flagged as drifted." 
+ ) + if not predicted_drifted: + logger.warning( + f" ZERO functions pass threshold={threshold}. " + f"Max distance seen = {stats['max']}. " + f"This is the primary cause of F1=0." + ) + + # Also count "added" functions (in new but not old) — they are always flagged + added_count = len(only_in_new) + removed_count = len(only_in_old) + logger.info(f" Added (always flagged as drifted): {added_count}") + logger.info(f" Removed (in old only): {removed_count}") + + # ── 9. FP / FN analysis ─────────────────────────────────────────────── + tp = len(predicted_drifted & ground_truth_ids) + fp = len(predicted_drifted - ground_truth_ids) + fn = len(ground_truth_ids - predicted_drifted) + prec = tp / max(tp + fp, 1) + rec = tp / max(tp + fn, 1) + f1 = 2 * prec * rec / max(prec + rec, 1e-8) + + logger.info(f" Atlas @ threshold={threshold}: TP={tp} FP={fp} FN={fn} " + f"P={prec:.4f} R={rec:.4f} F1={f1:.4f}") + + # Collect top false negatives: ground truth that Atlas missed + fn_ids = list(ground_truth_ids - predicted_drifted)[:10] + fn_details = [] + for fid in fn_ids: + dist = id_to_dist.get(fid, "NOT_IN_INTERSECTION") + fn_details.append({"id": fid, "cosine_dist": dist, "threshold": threshold}) + logger.info(f" FN: {fid[:80]} cosine_dist={dist!r}") + + # Collect top false positives: Atlas flagged but not in ground truth + fp_ids = list(predicted_drifted - ground_truth_ids)[:10] + fp_details = [] + for fid in fp_ids: + dist = id_to_dist.get(fid, "?") + fp_details.append({"id": fid, "cosine_dist": dist}) + logger.info(f" FP: {fid[:80]} cosine_dist={dist!r}") + + # ── 10. 
Baseline predictions ────────────────────────────────────────── + changed_files = set(changed_ranges_norm.keys()) + baseline_predicted = {n.id for n in new_nodes if n.file_path in changed_files} + b_tp = len(baseline_predicted & ground_truth_ids) + b_fp = len(baseline_predicted - ground_truth_ids) + b_fn = len(ground_truth_ids - baseline_predicted) + b_prec = b_tp / max(b_tp + b_fp, 1) + b_rec = b_tp / max(b_tp + b_fn, 1) + b_f1 = 2 * b_prec * b_rec / max(b_prec + b_rec, 1e-8) + logger.info(f" Baseline (file-level): TP={b_tp} FP={b_fp} FN={b_fn} " + f"P={b_prec:.4f} R={b_rec:.4f} F1={b_f1:.4f}") + + # ── 11. Threshold sensitivity sweep ──────────────────────────────────── + sweep = _threshold_sweep(distances, ground_truth_ids, id_to_dist) + opt = _find_optimal_threshold(sweep) + logger.info( + f" OPTIMAL threshold for this commit: {opt['threshold']} " + f"→ F1={opt['f1']} (TP={opt['tp']} FP={opt['fp']} FN={opt['fn']})" + ) + + # Also show what thresholds give non-zero F1 + nonzero_f1 = [r for r in sweep if r["f1"] > 0] + if nonzero_f1: + logger.info( + f" Thresholds giving F1>0: [{nonzero_f1[0]['threshold']} … " + f"{nonzero_f1[-1]['threshold']}] ({len(nonzero_f1)} values)" + ) + else: + logger.warning(" NO threshold gives F1>0 for this commit!") + + # ── 12. 
Key diagnostic: distance of ground-truth functions ────────────── + gt_in_intersection = [fid for fid in ground_truth_ids if fid in id_to_dist] + gt_distances = [id_to_dist[fid] for fid in gt_in_intersection] + gt_stats = _dist_stats(sorted(gt_distances)) + logger.info( + f" Cosine distances of GROUND-TRUTH functions (n={len(gt_distances)}): " + f"min={gt_stats.get('min')} max={gt_stats.get('max')} mean={gt_stats.get('mean')}" + ) + + non_gt_distances = [d for fid, d in id_to_dist.items() + if fid not in ground_truth_ids] + non_gt_stats = _dist_stats(sorted(non_gt_distances)) + logger.info( + f" Cosine distances of NON-GROUND-TRUTH functions (n={len(non_gt_distances)}): " + f"min={non_gt_stats.get('min')} max={non_gt_stats.get('max')} mean={non_gt_stats.get('mean')}" + ) + + if gt_distances and non_gt_distances: + gt_mean = gt_stats["mean"] or 0 + non_gt_mean = non_gt_stats["mean"] or 0 + separability = gt_mean - non_gt_mean + logger.info( + f" Separability (GT mean dist - non-GT mean dist): {separability:.6f} " + f"{'[POSITIVE = model gives GT higher dist — good]' if separability > 0 else '[NEGATIVE = model cannot separate GT from non-GT — failure]'}" + ) + + # ── 13. 
def main() -> None:
    """
    Run the drift-F1 diagnostic over recent commits and write JSON artifacts.

    Steps:
      1. Parse CLI arguments and resolve relative paths against the backend root.
      2. Load the vocabulary and the encoder checkpoint (exit code 1 if missing).
      3. Call ``analyze_commit`` for every qualifying commit, always trying to
         restore the original branch afterwards.
      4. Print a cross-commit summary and the failure-mode hypotheses.
      5. Write ``summary.json`` into the output directory.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Deep diagnostic analysis of Atlas Drift F1=0.0. "
            "Inspects cosine distance distributions, ID intersections, "
            "threshold sensitivity, and FP/FN breakdown."
        )
    )
    parser.add_argument("--repo_path", required=True,
                        help="Path to repo with full git history.")
    parser.add_argument("--commits", type=int, default=5,
                        help="Number of commits to analyze (default: 5).")
    parser.add_argument("--threshold", type=float, default=0.15,
                        help="Cosine distance threshold (default: 0.15).")
    parser.add_argument("--output_dir", default="eval/results/drift_debug",
                        help="Directory for per-commit JSON artifacts.")
    parser.add_argument("--model_checkpoint",
                        default="training/checkpoints/best_model.pt")
    parser.add_argument("--vocab_path",
                        default="training/data/vocab.json")
    args = parser.parse_args()

    # Heavy / project-local imports are deferred so that --help stays fast.
    import torch
    from core.model.function_encoder import FunctionEncoder
    from core.model.dataset import Vocabulary
    from core.parser.tree_sitter_parser import TreeSitterParser
    from core.drift.drift_detector import DriftDetector

    # Relative paths are interpreted relative to the backend root, not the CWD.
    backend_root = Path(__file__).resolve().parent.parent
    vocab_path = (args.vocab_path if os.path.isabs(args.vocab_path)
                  else str(backend_root / args.vocab_path))
    ckpt_path = (args.model_checkpoint if os.path.isabs(args.model_checkpoint)
                 else str(backend_root / args.model_checkpoint))

    if not Path(vocab_path).exists():
        logger.error(f"Vocabulary not found: {vocab_path}")
        sys.exit(1)
    if not Path(ckpt_path).exists():
        logger.error(f"Checkpoint not found: {ckpt_path}")
        sys.exit(1)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocab = Vocabulary.from_file(vocab_path)
    ckpt = torch.load(ckpt_path, map_location=device)
    # Prefer the vocab size stored with the weights so the layer shapes match.
    stored_vocab_size = ckpt.get("vocab_size", vocab.size)

    encoder = FunctionEncoder(vocab_size=stored_vocab_size)
    encoder.load_state_dict(ckpt["model_state_dict"])
    encoder.to(device)
    encoder.eval()
    logger.info(f"Model loaded (vocab={stored_vocab_size}, device={device})")

    ts_parser = TreeSitterParser()
    detector = DriftDetector(encoder=encoder, vocab=vocab, device=device)

    output_dir = (Path(args.output_dir) if os.path.isabs(args.output_dir)
                  else backend_root / args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Check for shallow clone ──────────────────────────────────────────
    if _check_shallow_clone(args.repo_path):
        logger.warning(
            "SHALLOW CLONE: git history is truncated. "
            "Results may be incomplete. Clone with full depth for best results."
        )

    # ── Get commits ───────────────────────────────────────────────────────
    original_branch = get_current_branch(args.repo_path)
    logger.info(f"Repo: {args.repo_path} branch={original_branch}")
    commits = get_commits_with_python_changes(args.repo_path, n=args.commits)

    if not commits:
        logger.error("No qualifying commits found. Exiting.")
        sys.exit(1)

    logger.info(f"Will analyze {len(commits)} commits: {commits}")

    # ── Per-commit analysis ────────────────────────────────────────────────
    commit_results: list[dict] = []
    for commit_hash in commits:
        try:
            result = analyze_commit(
                detector, ts_parser, args.repo_path, commit_hash,
                args.threshold, output_dir,
            )
            commit_results.append(result)
        except Exception as exc:
            logger.error(f"Error analyzing {commit_hash}: {exc}", exc_info=True)
        finally:
            # Restoring the branch stays best-effort, but a failure is reported:
            # otherwise the repo is silently left on a detached commit.
            try:
                checkout_back(args.repo_path, original_branch)
            except Exception as restore_exc:
                logger.warning(
                    f"Could not restore branch {original_branch!r} "
                    f"after {commit_hash}: {restore_exc}"
                )

    # ── Cross-commit summary ───────────────────────────────────────────────
    print("\n" + "="*70)
    print(" CROSS-COMMIT DIAGNOSTIC SUMMARY")
    print("="*70)

    valid = [r for r in commit_results if "error" not in r]
    if not valid:
        print(" No commits produced valid analysis.")
    else:
        atlas_f1s: list[float] = []
        base_f1s: list[float] = []
        optimal_thresholds: list[float] = []
        zero_intersection = 0
        zero_gt = 0

        for r in valid:
            if r.get("intersection_count", 0) == 0:
                zero_intersection += 1
            if r.get("ground_truth_count", 0) == 0:
                zero_gt += 1

            # `or {}` also covers keys that are present but hold None.
            ar = r.get("atlas_results") or {}
            br = r.get("baseline_results") or {}
            atlas_f1s.append(ar.get("f1", 0.0))
            base_f1s.append(br.get("f1", 0.0))

            ot = r.get("optimal_threshold") or {}
            if ot:
                opt_thr = ot.get("threshold", 0)
                if opt_thr is not None:
                    optimal_thresholds.append(opt_thr)

            print(f"\n Commit {r['commit'][:10]}:")
            print(f" Old={r.get('old_count','?')} New={r.get('new_count','?')} "
                  f"Intersect={r.get('intersection_count','?')} "
                  f"GT={r.get('ground_truth_count','?')}")
            cs = r.get("cosine_dist_stats") or {}
            print(f" Cosine dist: mean={cs.get('mean')} "
                  f"p50={cs.get('p50')} p90={cs.get('p90')} max={cs.get('max')}")
            gt_cs = r.get("gt_cosine_dist_stats") or {}
            ngt_cs = r.get("non_gt_cosine_dist_stats") or {}
            print(f" GT func dist: mean={gt_cs.get('mean')} "
                  f"non-GT dist: mean={ngt_cs.get('mean')} "
                  f"separability={round((gt_cs.get('mean') or 0) - (ngt_cs.get('mean') or 0), 6)}")
            print(f" Atlas F1={ar.get('f1')} Baseline F1={br.get('f1')}")
            print(f" Optimal threshold: {ot.get('threshold')} → F1={ot.get('f1')}")

        print(f"\n Commits with EMPTY intersection: {zero_intersection}/{len(valid)}")
        print(f" Commits with EMPTY ground truth: {zero_gt}/{len(valid)}")
        if atlas_f1s:
            avg_atlas = sum(atlas_f1s) / len(atlas_f1s)
            avg_base = sum(base_f1s) / len(base_f1s)
            avg_opt = sum(optimal_thresholds) / len(optimal_thresholds) if optimal_thresholds else None
            print(f"\n Average Atlas F1 : {avg_atlas:.4f}")
            print(f" Average Baseline F1: {avg_base:.4f}")
            # Compare against None: an average optimal threshold of 0.0 is a
            # real (and very telling) result that must still be printed.
            if avg_opt is not None:
                print(f" Average OPTIMAL threshold: {avg_opt:.2f}")

    # ── Hypotheses verdict ────────────────────────────────────────────────
    # The threshold shown in H1/H2 follows --threshold instead of a literal 0.15.
    # NOTE(review): H2 says "too strict" but is confirmed by optimal > threshold;
    # the comparison direction is kept as written — confirm the intent.
    thr = args.threshold
    print("\n" + "="*70)
    print(" FAILURE MODE HYPOTHESES")
    print("="*70)
    print(f"""
 H1 — Cosine distances cluster near 0 (model collapses):
 → Check 'cosine_dist_stats.max' above. If max < {thr}, confirmed.

 H2 — Threshold={thr} too strict for this model:
 → Check 'optimal_threshold.threshold' above. If optimal > {thr}, confirmed.

 H3 — Old∩New intersection empty (ID instability):
 → Check 'intersection_count' above. If 0, confirmed.

 H4 — Ground truth empty (path mismatch / all non-function changes):
 → Check 'ground_truth_count' above. If consistently 0, confirmed.

 H5 — Model cannot separate changed vs unchanged functions:
 → Check 'separability' (GT mean dist - non-GT mean dist).
 If <= 0, the model actively anti-separates GT from non-GT.

 H6 — Training objective (InfoNCE on intent-verb pairs) unrelated to drift:
 → InfoNCE trains the model to cluster functions by SEMANTIC INTENT
 (docstring verbs: 'parse', 'load', 'validate'…). Two versions of
 the SAME function will appear VERY SIMILAR to the model even if
 the code changed — because the function NAME and intent haven't
 changed. This would cause cosine distances to cluster near 0
 for all matched pairs, regardless of actual code change.
 """)

    # ── Save summary ──────────────────────────────────────────────────────
    summary = {
        "repo_path": args.repo_path,
        "threshold_used": args.threshold,
        "commits_analyzed": len(commit_results),
        "commits_valid": len(valid),
        "per_commit": commit_results,
    }
    summary_path = output_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    logger.info(f"Summary saved: {summary_path}")
    print(f"\n Full artifacts: {output_dir}")
    print("="*70)
+ +Tests whether Atlas search returns functions that actually DO what the query +describes. Uses 50 handcrafted behavioral queries covering a broad spectrum +of code patterns (works against any indexed codebase). + +Evaluation methodology: + - Retrieve top-5 results for each query + - Check relevance using keyword matching against result name / docstring / func_id + - Compute Precision@1 and Precision@5 + +Usage (full retriever): + python eval/eval_codesearcheval.py --output eval/results/codesearcheval_results.json + +Usage (embedding-only fallback — no Qdrant/BM25 required): + python eval/eval_codesearcheval.py --use_embedding_fallback +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +import traceback +from datetime import datetime, timezone +from pathlib import Path + +_BACKEND_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_BACKEND_DIR)) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("eval_codesearcheval") + +BEHAVIORAL_QUERIES = [ + # Algorithms + {"query": "sort a list of items", "expected_behavior": "sort sorted order", "tags": ["algorithm"]}, + {"query": "search for an element in a collection", "expected_behavior": "search find lookup get", "tags": ["algorithm"]}, + {"query": "filter items based on a condition", "expected_behavior": "filter select condition query", "tags": ["algorithm"]}, + {"query": "compute the sum or total of values", "expected_behavior": "sum total aggregate count", "tags": ["algorithm"]}, + # Serialization + {"query": "parse JSON data from a string", "expected_behavior": "json parse decode load", "tags": ["serialization"]}, + {"query": "serialize an object to JSON", "expected_behavior": "json serialize encode dump", "tags": ["serialization"]}, + {"query": "read and write CSV files", "expected_behavior": "csv read write", "tags": 
["serialization"]}, + {"query": "encode data with base64", "expected_behavior": "base64 encode decode", "tags": ["encoding"]}, + # Validation + {"query": "validate user email address", "expected_behavior": "email validate check clean", "tags": ["validation"]}, + {"query": "validate form input from a request", "expected_behavior": "validate input form clean", "tags": ["validation"]}, + {"query": "check if a string matches a pattern", "expected_behavior": "regex match pattern check", "tags": ["validation"]}, + {"query": "sanitize user input to prevent injection", "expected_behavior": "sanitize clean escape", "tags": ["security"]}, + # Networking & HTTP + {"query": "handle HTTP authentication", "expected_behavior": "auth authenticate token bearer", "tags": ["security"]}, + {"query": "retry a failed network request", "expected_behavior": "retry attempt backoff", "tags": ["networking"]}, + {"query": "handle HTTP timeout and connection error", "expected_behavior": "timeout error connection handle", "tags": ["networking"]}, + {"query": "make an HTTP GET request", "expected_behavior": "get request http fetch", "tags": ["networking"]}, + {"query": "parse query parameters from a URL", "expected_behavior": "query params url parse", "tags": ["routing"]}, + {"query": "build and construct a URL", "expected_behavior": "url build construct join", "tags": ["routing"]}, + {"query": "handle CORS headers in a response", "expected_behavior": "cors origin headers allow", "tags": ["middleware"]}, + {"query": "rate limit incoming requests", "expected_behavior": "rate limit throttle", "tags": ["middleware"]}, + # Configuration & Logging + {"query": "read configuration from environment variables","expected_behavior": "env environment config load settings", "tags": ["config"]}, + {"query": "load settings from a config file", "expected_behavior": "config settings load read file", "tags": ["config"]}, + {"query": "log error messages with timestamps", "expected_behavior": "log error message 
warn info", "tags": ["logging"]}, + {"query": "write structured logs in JSON format", "expected_behavior": "log json structured format", "tags": ["logging"]}, + # Database + {"query": "paginate database query results", "expected_behavior": "paginate page limit offset", "tags": ["database"]}, + {"query": "execute a database query", "expected_behavior": "query execute db select", "tags": ["database"]}, + {"query": "insert a record into the database", "expected_behavior": "insert create save add", "tags": ["database"]}, + {"query": "manage database connection pooling", "expected_behavior": "connection pool database", "tags": ["database"]}, + # Caching + {"query": "cache results to avoid recomputation", "expected_behavior": "cache store memoize", "tags": ["caching"]}, + {"query": "invalidate or clear cache entries", "expected_behavior": "cache clear invalidate expire", "tags": ["caching"]}, + # Datetime + {"query": "convert between date formats", "expected_behavior": "date format convert parse", "tags": ["datetime"]}, + {"query": "calculate a time difference or duration", "expected_behavior": "time duration diff delta", "tags": ["datetime"]}, + # File I/O + {"query": "read contents of a file", "expected_behavior": "file read open content", "tags": ["file_io"]}, + {"query": "write data to a file", "expected_behavior": "file write save output", "tags": ["file_io"]}, + {"query": "handle file upload from a request", "expected_behavior": "upload file stream save", "tags": ["file_io"]}, + # Auth & Sessions + {"query": "generate authentication token or JWT", "expected_behavior": "token jwt generate sign key", "tags": ["security"]}, + {"query": "manage user sessions", "expected_behavior": "session user store get set", "tags": ["session"]}, + {"query": "handle cookie setting in response", "expected_behavior": "cookie set response header", "tags": ["cookie"]}, + {"query": "hash a password securely", "expected_behavior": "hash password bcrypt security", "tags": ["security"]}, + # 
Async + {"query": "run background tasks asynchronously", "expected_behavior": "async background task run", "tags": ["async"]}, + {"query": "handle WebSocket connections", "expected_behavior": "websocket connect send receive", "tags": ["websocket"]}, + {"query": "await an async operation", "expected_behavior": "async await coroutine", "tags": ["async"]}, + # API / Routing + {"query": "define an API route handler", "expected_behavior": "route path handler dispatch view", "tags": ["routing"]}, + {"query": "return a JSON response", "expected_behavior": "response json return render", "tags": ["response"]}, + {"query": "handle request body parsing", "expected_behavior": "body request parse data", "tags": ["request"]}, + {"query": "add middleware to the application", "expected_behavior": "middleware add register process", "tags": ["middleware"]}, + # Health & Metrics + {"query": "check application health status", "expected_behavior": "health check status ping", "tags": ["health"]}, + {"query": "collect and expose metrics", "expected_behavior": "metrics collect expose measure", "tags": ["metrics"]}, + # Testing utilities + {"query": "create a test client for API testing", "expected_behavior": "test client request mock", "tags": ["testing"]}, + {"query": "mock an external dependency in tests", "expected_behavior": "mock patch stub test", "tags": ["testing"]}, +] + +assert len(BEHAVIORAL_QUERIES) == 50, f"Expected 50 queries, got {len(BEHAVIORAL_QUERIES)}" + + +def _id_tokens(func_id: str) -> str: + """ + Extract searchable tokens from a function ID like 'path/to/file.py::ClassName.method_name'. + Splits on '::', '.', '/', '_' and lowercases. + """ + if not func_id: + return "" + parts = func_id.replace("::", " ").replace("/", " ").replace(".", " ").replace("_", " ") + return parts.lower() + + +def is_relevant(result_name: str, result_docstring: str, expected_behavior: str, + func_id: str = "") -> bool: + """ + Heuristic relevance check. 
+ Returns True if the result's name, docstring, or func_id tokens contain + at least 1-2 keywords from the expected_behavior string (case-insensitive). + + Improvements over v1: + - Also checks func_id token stream (catches e.g. 'filter_queryset' for 'filter') + - Lower threshold for short keyword sets + """ + keywords = expected_behavior.lower().split() + # Build search text from name + docstring + func_id token stream + text = " ".join([ + result_name.replace("_", " ").replace(".", " "), + (result_docstring or "").replace("_", " ").replace("-", " "), + _id_tokens(func_id), + ]).lower() + + matches = sum(1 for kw in keywords if kw in text) + # Threshold: 1 keyword if <=3 keywords, else 2 + required = 1 if len(keywords) <= 3 else 2 + return matches >= required + + +class EmbeddingFallbackRetriever: + """ + Lightweight retriever that runs embedding-only search against Qdrant. + Used when BM25 index is unavailable. Does NOT require bm25_index.pkl. + Falls back to direct Qdrant vector search. 
+ """ + + def __init__(self, encoder, qdrant, vocab): + self.encoder = encoder + self.qdrant = qdrant + self.vocab = vocab + self.device = next(encoder.parameters()).device + + def _embed_query(self, query: str): + import torch + import numpy as np + from torch_geometric.data import Data + + token_ids = self.vocab.encode(query, max_length=64) + n = len(token_ids) + + src, dst = [], [] + window_size = 5 + for i in range(n): + src.append(i); dst.append(i) + lo, hi = max(0, i - window_size), min(n - 1, i + window_size) + for j in range(lo, hi + 1): + if j != i: + src.append(i); dst.append(j) + + x = torch.tensor(token_ids, dtype=torch.long).to(self.device) + edge_index = torch.tensor([src, dst], dtype=torch.long).to(self.device) + edge_attr = torch.ones(edge_index.shape[1], 1, dtype=torch.float).to(self.device) + batch = torch.zeros(n, dtype=torch.long).to(self.device) + + self.encoder.eval() + with torch.no_grad(): + emb = self.encoder(x, edge_index, edge_attr, batch) + + import numpy as np + return emb.squeeze(0).cpu().numpy().astype(np.float32) + + async def retrieve(self, query: str, top_k: int = 10): + """Async embedding-only retrieval (Qdrant vector search only).""" + import asyncio + + query_emb = self._embed_query(query) + results_raw = await asyncio.to_thread( + self.qdrant.search, query_emb, top_k=top_k + ) + # Return a list of simple namespace objects matching the AgenticRetriever API + output = [] + for r in results_raw: + obj = type("R", (), { + "func_id": r.get("func_id", ""), + "name": r.get("name", ""), + "file_path": r.get("file_path", ""), + "docstring": r.get("docstring", ""), + "behavioral_score": float(r.get("score", 0.0)), + "final_score": float(r.get("score", 0.0)), + })() + output.append(obj) + return output + + +class CodeSearchEvaluator: + def __init__(self, retriever): + self.retriever = retriever + + async def evaluate(self, queries: list[dict]) -> dict: + """ + For each query: + 1. Retrieve top-5 results + 2. 
Check relevance of each result + 3. Compute Precision@1 and Precision@5 + """ + p1_scores: list[float] = [] + p5_scores: list[float] = [] + per_query: list[dict] = [] + retrieval_failures = 0 + retrieval_successes = 0 + + for q_idx, q_item in enumerate(queries): + query = q_item["query"] + expected = q_item["expected_behavior"] + + try: + results = await self.retriever.retrieve(query, top_k=5) + retrieval_successes += 1 + except Exception as exc: + retrieval_failures += 1 + logger.warning(f" [{q_idx+1}/50] RETRIEVAL FAILED for '{query}': {exc}") + p1_scores.append(0.0) + p5_scores.append(0.0) + per_query.append( + { + "query": query, + "expected_behavior": expected, + "top_result_name": "", + "top_result_func_id": "", + "top_result_score": 0.0, + "p_at_1": 0.0, + "p_at_5": 0.0, + "error": str(exc), + } + ) + continue + + if not results: + logger.warning(f" [{q_idx+1}/50] No results returned for '{query}'") + + relevance_flags: list[bool] = [] + for r in results: + name = getattr(r, "name", "") or "" + doc = getattr(r, "docstring", "") or "" + fid = getattr(r, "func_id", "") or "" + relevance_flags.append(is_relevant(name, doc, expected, func_id=fid)) + + p1 = 1.0 if relevance_flags and relevance_flags[0] else 0.0 + p5 = sum(relevance_flags) / max(len(relevance_flags), 1) + + p1_scores.append(p1) + p5_scores.append(p5) + + top_name = getattr(results[0], "name", "") if results else "" + top_fid = getattr(results[0], "func_id", "") if results else "" + top_score = getattr(results[0], "behavioral_score", 0.0) if results else 0.0 + + rel_mark = "✓" if p1 == 1.0 else "✗" + logger.info( + f" [{q_idx+1:02d}/50] {rel_mark} P@1={p1:.0f} P@5={p5:.2f} " + f"query='{query[:40]}' top='{top_name}' ({top_fid[:50]})" + ) + per_query.append( + { + "query": query, + "expected_behavior": expected, + "tags": q_item.get("tags", []), + "top_result_name": top_name, + "top_result_func_id": top_fid, + "top_result_score": round(float(top_score), 4), + "p_at_1": p1, + "p_at_5": round(p5, 4), 
+ "relevant": relevance_flags, + } + ) + + n = len(p1_scores) + avg_p1 = sum(p1_scores) / max(n, 1) + avg_p5 = sum(p5_scores) / max(n, 1) + + logger.info( + f"[CodeSearchEval] Done. Retrieval successes={retrieval_successes}, " + f"failures={retrieval_failures}. " + f"Precision@1={avg_p1:.4f} Precision@5={avg_p5:.4f}" + ) + + return { + "precision_at_1": round(avg_p1, 4), + "precision_at_5": round(avg_p5, 4), + "num_queries": n, + "retrieval_successes": retrieval_successes, + "retrieval_failures": retrieval_failures, + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "per_query_results": per_query, + } + + +def _print_results(results: dict) -> None: + p1 = results["precision_at_1"] + p5 = results["precision_at_5"] + n = results["num_queries"] + succ = results.get("retrieval_successes", n) + fail = results.get("retrieval_failures", 0) + + try: + from tabulate import tabulate + + # Summary table + rows = [ + ["Precision@1 (top result relevant)", f"{p1:.4f}", f"{p1*100:.1f}%"], + ["Precision@5 (avg over top 5)", f"{p5:.4f}", f"{p5*100:.1f}%"], + ["Queries evaluated", n, ""], + ["Retrieval successes", succ, ""], + ["Retrieval failures", fail, ""], + ] + print("\n" + tabulate(rows, headers=["Metric", "Score", "Rate"], tablefmt="simple")) + + # Per-query sample (top 10 only) + sample_rows = [ + [ + r["query"][:45], + r["top_result_name"][:30], + f"{r['p_at_1']:.0f}/{r['p_at_5']:.2f}", + r.get("error", "")[:30], + ] + for r in results["per_query_results"][:10] + ] + print("\nSample results (first 10 queries):") + print(tabulate(sample_rows, headers=["Query", "Top Result", "P@1/P@5", "Error"], tablefmt="simple")) + + except ImportError: + print("\n" + "=" * 60) + print(" CodeSearchEval Results") + print("=" * 60) + print(f" Precision@1 : {p1:.4f} ({p1*100:.1f}%)") + print(f" Precision@5 : {p5:.4f} ({p5*100:.1f}%)") + print(f" Queries : {n}") + print(f" Successes : {succ} | Failures: {fail}") + print("=" * 60) + + # Print all results so user can see what's happening 
def _build_fallback_retriever():
    """
    Build an EmbeddingFallbackRetriever (no BM25 required).

    Loads the vocabulary and the encoder checkpoint from the backend's
    ``training`` directory, moves the encoder to CUDA when available, and
    connects to Qdrant.

    Raises:
        FileNotFoundError: if the vocabulary file or the checkpoint is missing.
    """
    import torch
    from core.model.function_encoder import FunctionEncoder
    from core.model.dataset import Vocabulary
    from core.retrieval.qdrant_store import AtlasQdrantStore

    checkpoint_path = _BACKEND_DIR / "training" / "checkpoints" / "best_model.pt"
    vocab_path = _BACKEND_DIR / "training" / "data" / "vocab.json"

    if not vocab_path.exists():
        raise FileNotFoundError(f"Vocabulary not found: {vocab_path}")
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {checkpoint_path}")

    vocab = Vocabulary.from_file(str(vocab_path))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(str(checkpoint_path), map_location=device)
    # The checkpoint is either a wrapper dict or a bare state dict.
    state_dict = ckpt.get("model_state_dict", ckpt)

    # Size the encoder from the vocab size stored with the weights when the
    # checkpoint records one (as the drift diagnostic script does); otherwise
    # fall back to the vocabulary on disk. A mismatch between the two would
    # make load_state_dict fail on the embedding shape.
    stored_vocab_size = ckpt.get("vocab_size")
    vocab_size = stored_vocab_size if isinstance(stored_vocab_size, int) else vocab.size
    if vocab_size != vocab.size:
        logger.warning(
            f"Checkpoint vocab size ({vocab_size}) differs from vocabulary file "
            f"({vocab.size}); using the checkpoint value for the encoder."
        )

    encoder = FunctionEncoder(vocab_size=vocab_size)
    encoder.load_state_dict(state_dict)
    encoder.to(device)
    encoder.eval()
    logger.info(f"Encoder loaded (vocab={vocab_size}, device={device})")

    qdrant = AtlasQdrantStore()
    logger.info("Qdrant connection established (embedding-only mode).")

    return EmbeddingFallbackRetriever(encoder=encoder, qdrant=qdrant, vocab=vocab)
(Qdrant + BM25 required).") + try: + from core.retrieval.retriever_factory import get_retriever + retriever = get_retriever() + except FileNotFoundError as exc: + logger.error( + f"Retriever setup failed — missing file: {exc}\n" + "TIP: Run with --use_embedding_fallback if BM25 index is not built yet.\n" + "TIP: Make sure Qdrant is running (docker run qdrant/qdrant) and the repo " + "has been indexed (python training/index_repo.py)." + ) + sys.exit(1) + except ConnectionError as exc: + logger.error( + f"Cannot connect to Qdrant: {exc}\n" + "TIP: Start Qdrant with: docker run -p 6333:6333 qdrant/qdrant\n" + "TIP: Or run with --use_embedding_fallback" + ) + sys.exit(1) + except Exception as exc: + logger.error(f"Retriever setup failed unexpectedly: {exc}") + logger.error(traceback.format_exc()) + sys.exit(1) + + evaluator = CodeSearchEvaluator(retriever) + + logger.info(f"Running {len(BEHAVIORAL_QUERIES)} behavioral queries …") + results = await evaluator.evaluate(BEHAVIORAL_QUERIES) + + _print_results(results) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2) + logger.info(f"Results saved to {output_path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "CodeSearchEval: measure behavioral precision of Atlas search " + "using 50 handcrafted natural-language queries." + ) + ) + parser.add_argument( + "--output", + default="eval/results/codesearcheval_results.json", + help="Output JSON file path.", + ) + parser.add_argument( + "--use_embedding_fallback", + action="store_true", + default=False, + help=( + "Use embedding-only retrieval (no BM25 / full retriever required). " + "Use this when BM25 index is not available. Qdrant still needs to be running." 
+ ), + ) + args = parser.parse_args() + asyncio.run(main_async(args)) + + +if __name__ == "__main__": + main() diff --git a/backend/eval/eval_drift.py b/backend/eval/eval_drift.py new file mode 100644 index 0000000..3a8544d --- /dev/null +++ b/backend/eval/eval_drift.py @@ -0,0 +1,647 @@ +""" +eval_drift.py +------------- +Drift Detection F1 evaluation using git history. + +Strategy: + - Get the last N commits that modify Python/JS/TS files + - For each commit: parse functions before and after, run DriftDetector + - Ground truth: functions whose line ranges overlap with `git diff` changed lines + - Compare Atlas predictions vs ground truth → Precision, Recall, F1 + - Also compute a file-level baseline (any function in a changed file is flagged) + +IMPORTANT: This eval needs repos with FULL git history (not shallow clones). +alsoo Clone with: git clone https://github.com/tiangolo/fastapi /tmp/fastapi_full +(Do NOT use --depth=1) + +Usage : + python eval/eval_drift.py --repo_path /tmp/fastapi_full --commits 10 +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import re +import subprocess +import sys +import tempfile +from datetime import datetime, timezone +from pathlib import Path + +_BACKEND_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_BACKEND_DIR)) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("eval_drift") + +def _run_git(args: list[str], cwd: str, check: bool = False) -> str: + """Run a git command and return stdout, logging stderr on failure.""" + result = subprocess.run( + ["git"] + args, + cwd=cwd, + capture_output=True, + text=True, + encoding="utf-8", + errors="replace", + timeout=60, + ) + if result.returncode != 0: + stderr = result.stderr.strip() + if stderr: + logger.warning(f"git {' '.join(args[:3])} failed (exit {result.returncode}): {stderr[:300]}") + if check: 
+ raise RuntimeError(f"git command failed: {' '.join(args)}: {stderr[:200]}") + if result.stdout: + return result.stdout.strip() + return "" + + +def _check_shallow_clone(repo_path: str) -> bool: + """Return True when the repo is a shallow clone (--depth=N).""" + shallow_file = Path(repo_path) / ".git" / "shallow" + return shallow_file.exists() + + +def _has_parent(repo_path: str, commit_hash: str) -> bool: + """Return True if commit has at least one parent (i.e. is not the root commit).""" + out = _run_git(["rev-parse", "--verify", f"{commit_hash}^"], cwd=repo_path) + return bool(out) + + +def _normalize_diff_paths( + changed_ranges: dict[str, list[tuple[int, int]]], + new_nodes: list, + commit_hash: str, +) -> dict[str, list[tuple[int, int]]]: + """ + Reconcile git-diff path keys with FunctionNode.file_path values. + + git diff returns repo-relative paths (e.g. 'backend/core/foo.py'). + parse_repository() also returns repo-relative paths — but relative to + the directory passed in. If you call parse_repository('backend/') the + paths will be 'core/foo.py'. This function detects and corrects the + mismatch by stripping the common leading prefix from diff keys. + + Returns a new dict whose keys align with node.file_path values. + """ + if not changed_ranges or not new_nodes: + return changed_ranges + + node_paths: set[str] = {n.file_path for n in new_nodes} + + # 1. Fast-path: direct match already works + direct_hits = [k for k in changed_ranges if k in node_paths] + if direct_hits: + logger.info( + f" [PATH-NORM {commit_hash[:7]}] Direct match: {len(direct_hits)}/{len(changed_ranges)} " + f"diff keys matched node paths. No normalization needed." + ) + return changed_ranges + + # 2. No direct matches — try stripping leading path components from diff keys + logger.info( + f" [PATH-NORM {commit_hash[:7]}] Zero direct matches. 
" + f"Sample diff keys : {list(changed_ranges.keys())[:3]}" + ) + logger.info( + f" [PATH-NORM {commit_hash[:7]}] Sample node paths : {list(node_paths)[:3]}" + ) + + normalized: dict[str, list[tuple[int, int]]] = {} + matched_count = 0 + for diff_key, ranges in changed_ranges.items(): + parts = diff_key.split("/") + matched = False + for strip_count in range(1, len(parts)): + candidate = "/".join(parts[strip_count:]) + if candidate in node_paths: + normalized[candidate] = ranges + matched = True + matched_count += 1 + break + if not matched: + # 3. Also try stripping leading components from node paths to match diff key + for np in node_paths: + np_parts = np.split("/") + for strip_np in range(1, len(np_parts)): + candidate_np = "/".join(np_parts[strip_np:]) + if diff_key.endswith(candidate_np): + normalized[np] = ranges + matched = True + matched_count += 1 + break + if matched: + break + if not matched: + # Keep original key as fallback (won't match but preserves data) + normalized.setdefault(diff_key, ranges) + + logger.info( + f" [PATH-NORM {commit_hash[:7]}] After normalization: " + f"{matched_count}/{len(changed_ranges)} diff keys resolved → " + f"{len(normalized)} total keys in normalized map." + ) + return normalized + + + +def get_commits_with_python_changes(repo_path: str, n: int = 10) -> list[str]: + """Return last N commit hashes that modified .py / .js / .ts files.""" + if _check_shallow_clone(repo_path): + logger.warning( + "SHALLOW CLONE DETECTED (.git/shallow exists). " + "Drift eval needs full history — clone with: git clone (no --depth flag). " + "Attempting to continue anyway but commit history will be incomplete." + ) + + out = _run_git( + [ + "log", + "--oneline", + "--diff-filter=M", + f"-n", str(n * 3), # over-sample, then take first N + "--", + "*.py", + "*.js", + "*.ts", + ], + cwd=repo_path, + ) + if not out: + logger.warning( + "git log returned no output. 
def get_changed_line_ranges(repo_path: str, commit_hash: str) -> dict[str, list[tuple[int, int]]]:
    """
    For a given commit, return a dict mapping filepath → list of (start, end) line ranges
    that were added or modified (compared to its first parent).

    Ranges use new-file line numbers and are inclusive. Pure deletions
    (hunks adding zero lines) contribute no range. Paths are always
    normalised to forward slashes (matching FunctionNode.file_path).

    Returns an empty dict when the diff is empty or git failed.
    """
    diff = _run_git(
        ["diff", f"{commit_hash}^1", commit_hash, "--unified=0"],
        cwd=repo_path,
    )
    if not diff:
        logger.warning(f"git diff for {commit_hash} returned empty output — commit may have no parent or diff failed.")
    return _parse_unified_diff_ranges(diff)


def _parse_unified_diff_ranges(diff: str) -> dict[str, list[tuple[int, int]]]:
    """Extract added/modified new-file line ranges per path from `git diff --unified=0` text."""
    hunk_re = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@")
    file_re = re.compile(r"^\+\+\+ b/(.+)$")

    result: dict[str, list[tuple[int, int]]] = {}
    current_file = ""
    # True between a "diff --git" line and the first hunk of that file. Only
    # there can "+++ b/..." be a file header; inside a hunk the same text is
    # an added content line that merely starts with "++ b/".
    in_header = True

    # Split on "\n" only: str.splitlines() would also break on form feeds and
    # Unicode line separators that may occur inside changed source lines.
    for line in diff.split("\n"):
        if line.startswith("diff --git "):
            # New file section. Forget the previous file so that hunks of a
            # file whose "+++" line we cannot parse (C-quoted path, /dev/null)
            # are dropped rather than attributed to the wrong file.
            in_header = True
            current_file = ""
            continue
        if in_header:
            fm = file_re.match(line)
            if fm:
                # git appends a TAB after names that contain spaces; strip it,
                # then normalise separators to match parse_repository() output.
                current_file = fm.group(1).rstrip("\t").replace("\\", "/")
                result.setdefault(current_file, [])
                continue
        hm = hunk_re.match(line)
        if hm:
            in_header = False
            if current_file:
                start = int(hm.group(1))
                # An omitted length means exactly one line.
                length = int(hm.group(2)) if hm.group(2) is not None else 1
                if length > 0:
                    result[current_file].append((start, start + length - 1))

    return result
+ """ + changed_files = set(changed_ranges.keys()) + return {n.id for n in all_new_nodes if n.file_path in changed_files} + + @staticmethod + def _f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]: + precision = tp / max(tp + fp, 1) + recall = tp / max(tp + fn, 1) + f1 = 2 * precision * recall / max(precision + recall, 1e-8) + return round(precision, 4), round(recall, 4), round(f1, 4) + + def evaluate_on_repo(self, repo_path: str, num_commits: int = 10, threshold: float = 0.15) -> dict: + """ + Evaluate DriftDetector on the last N commits of a repo. + """ + original_branch = get_current_branch(repo_path) + logger.info(f"[DRIFT EVAL] repo_path={repo_path!r} branch={original_branch!r}") + + commits = get_commits_with_python_changes(repo_path, n=num_commits) + + if not commits: + logger.warning( + "No commits with Python/JS/TS changes found. " + "Check: (1) repo has git history, (2) contains .py/.js/.ts files, " + "(3) git is on PATH, (4) --repo_path is correct." + ) + # Extra diagnostics: try a plain git log to see if ANY commits exist + plain_log = _run_git(["log", "--oneline", "-n", "5"], cwd=repo_path) + if plain_log: + logger.info(f" [DIAG] 'git log --oneline -5' shows commits exist:\n{plain_log}") + logger.info( + " [DIAG] But none matched '*.py *.js *.ts' with --diff-filter=M. " + "Try: git log --oneline -- '*.py' to verify." + ) + else: + logger.error( + " [DIAG] 'git log' returned nothing. " + "Is repo_path a valid git repository with history?" 
+ ) + return { + "error": "No qualifying commits found", + "num_commits": 0, + "num_functions_evaluated": 0, + "per_commit_results": [], + "atlas_f1": 0.0, + "atlas_precision": 0.0, + "atlas_recall": 0.0, + "baseline_f1": 0.0, + "baseline_precision": 0.0, + "baseline_recall": 0.0, + "improvement_over_baseline": "+0.0%", + "threshold_used": threshold, + } + + logger.info( + f"[DRIFT EVAL] Found {len(commits)} qualifying commits: {commits}" + ) + + # Aggregate metrics + total_tp = 0 + total_fp = 0 + total_fn = 0 + baseline_tp = 0 + baseline_fp = 0 + baseline_fn = 0 + total_functions = 0 + per_commit: list[dict] = [] + skipped_no_parent = 0 + skipped_no_functions = 0 + skipped_no_groundtruth = 0 + skipped_error = 0 + + for idx, commit_hash in enumerate(commits): + logger.info(f" [{idx+1}/{len(commits)}] ===== Commit {commit_hash} =====") + try: + # --- Guard: skip commits with no parent (e.g. initial commit) --- + if not _has_parent(repo_path, commit_hash): + logger.info(f" SKIP {commit_hash}: no parent commit (root commit, cannot diff).") + skipped_no_parent += 1 + continue + + changed_ranges = get_changed_line_ranges(repo_path, commit_hash) + logger.info( + f" git diff produced {len(changed_ranges)} changed files. " + f"Sample diff keys: {list(changed_ranges.keys())[:5]}" + ) + if not changed_ranges: + logger.warning( + f" SKIP {commit_hash}: git diff returned no hunks. " + "(Merge commit? Binary-only changes? 
Diff command failed?)" + ) + skipped_no_groundtruth += 1 + continue + + # Checkout parent snapshot + logger.info(f" Checking out parent {commit_hash}^1 …") + checkout_commit(repo_path, f"{commit_hash}^1") + old_nodes = self.parser.parse_repository(repo_path) + logger.info(f" Parent snapshot: {len(old_nodes)} functions parsed.") + + # Checkout commit snapshot + logger.info(f" Checking out commit {commit_hash} …") + checkout_commit(repo_path, commit_hash) + new_nodes = self.parser.parse_repository(repo_path) + logger.info( + f" Commit snapshot: {len(new_nodes)} functions parsed. " + f"Sample node paths: {[n.file_path for n in new_nodes[:3]]}" + ) + + if not new_nodes: + logger.warning( + f" SKIP {commit_hash}: no functions found at this commit. " + "tree-sitter returned 0 — check that tree-sitter Python/JS/TS " + "grammars are installed (pip install tree-sitter-python)." + ) + skipped_no_functions += 1 + continue + + # --- Path normalization: align git-diff keys with node.file_path --- + changed_ranges = _normalize_diff_paths(changed_ranges, new_nodes, commit_hash) + + # Ground truth + ground_truth_ids = self._get_changed_function_ids(new_nodes, changed_ranges) + logger.info( + f" Ground truth: {len(ground_truth_ids)} functions overlap with diff " + f"(out of {len(new_nodes)} new functions, {len(changed_ranges)} changed-file entries)." + ) + if not ground_truth_ids: + logger.warning( + f" SKIP {commit_hash}: 0 functions overlap the diff after path normalization. " + "This usually means changed lines fall outside all function bodies " + "(e.g. module-level code, comments, blank lines). " + f"Changed files: {list(changed_ranges.keys())[:5]}. " + f"Node sample: {[(n.file_path, n.line_start, n.line_end) for n in new_nodes[:3]]}. " + f"Diff ranges sample: {list(changed_ranges.values())[:3]}." 
+ ) + skipped_no_groundtruth += 1 + continue + + # Atlas predictions + # NOTE: 'added' functions are included because a newly-added function + # that overlaps the diff's changed lines IS a true positive ground-truth + # function. Excluding 'added' silently removes valid TP hits. + drift_results = self.detector.detect_drift(old_nodes, new_nodes, threshold=threshold) + predicted_drifted = { + r.function_id + for r in drift_results + if r.is_drifted and r.drift_type in ("semantic", "structural", "added") + } + # Debug: break down by drift type and show distance distribution + type_counts: dict[str, int] = {} + dist_values: list[float] = [] + for r in drift_results: + type_counts[r.drift_type] = type_counts.get(r.drift_type, 0) + 1 + if r.drift_type in ("semantic", "stable"): + dist_values.append(r.cosine_distance) + if dist_values: + dist_values_s = sorted(dist_values) + n = len(dist_values_s) + p50 = dist_values_s[n // 2] + p90 = dist_values_s[int(n * 0.9)] + logger.info( + f" Cosine dist stats (semantic/stable, n={n}): " + f"min={dist_values_s[0]:.4f} p50={p50:.4f} " + f"p90={p90:.4f} max={dist_values_s[-1]:.4f} threshold={threshold}" + ) + below = sum(1 for d in dist_values_s if d <= threshold) + logger.info( + f" Functions below threshold (not flagged): {below}/{n} ({below*100//max(n,1)}%). " + f"Functions above threshold (flagged semantic/structural): {n-below}/{n}." + ) + logger.info( + f" Atlas flagged: {len(predicted_drifted)} functions as drifted " + f"(by type: {type_counts})." 
+ ) + + # Baseline predictions + baseline_predicted = self._get_baseline_flagged_ids(new_nodes, changed_ranges) + + # Metrics for this commit + tp = len(predicted_drifted & ground_truth_ids) + fp = len(predicted_drifted - ground_truth_ids) + fn = len(ground_truth_ids - predicted_drifted) + + b_tp = len(baseline_predicted & ground_truth_ids) + b_fp = len(baseline_predicted - ground_truth_ids) + b_fn = len(ground_truth_ids - baseline_predicted) + + prec, rec, f1 = self._f1(tp, fp, fn) + b_prec, b_rec, b_f1 = self._f1(b_tp, b_fp, b_fn) + + total_tp += tp + total_fp += fp + total_fn += fn + baseline_tp += b_tp + baseline_fp += b_fp + baseline_fn += b_fn + total_functions += len(new_nodes) + + per_commit.append( + { + "commit": commit_hash, + "num_functions": len(new_nodes), + "ground_truth_changed": len(ground_truth_ids), + "atlas_flagged": len(predicted_drifted), + "baseline_flagged": len(baseline_predicted), + "atlas_f1": f1, + "atlas_precision": prec, + "atlas_recall": rec, + "baseline_f1": b_f1, + "baseline_precision": b_prec, + "baseline_recall": b_rec, + } + ) + logger.info( + f" ✓ Atlas F1={f1:.3f} (P={prec:.3f} R={rec:.3f}) | " + f"Baseline F1={b_f1:.3f}" + ) + + except Exception as exc: + logger.error(f" ERROR processing commit {commit_hash}: {exc}", exc_info=True) + skipped_error += 1 + finally: + # Always restore original branch + try: + checkout_back(repo_path, original_branch) + except Exception: + pass + + # Summary diagnostics + logger.info( + f"[DRIFT EVAL] Done. Commits evaluated: {len(per_commit)}/{len(commits)}. " + f"Skipped: no_parent={skipped_no_parent}, no_functions={skipped_no_functions}, " + f"no_groundtruth={skipped_no_groundtruth}, errors={skipped_error}." 
+ ) + + # Aggregate micro-averaged F1 + atlas_prec, atlas_rec, atlas_f1 = self._f1(total_tp, total_fp, total_fn) + base_prec, base_rec, base_f1 = self._f1(baseline_tp, baseline_fp, baseline_fn) + + improvement = atlas_f1 - base_f1 + improvement_str = f"+{improvement*100:.1f}%" if improvement >= 0 else f"{improvement*100:.1f}%" + + return { + "atlas_f1": atlas_f1, + "atlas_precision": atlas_prec, + "atlas_recall": atlas_rec, + "baseline_f1": base_f1, + "baseline_precision": base_prec, + "baseline_recall": base_rec, + "improvement_over_baseline": improvement_str, + "num_commits": len(per_commit), + "num_functions_evaluated": total_functions, + "threshold_used": threshold, + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "per_commit_results": per_commit, + } + + +def _print_results(results: dict) -> None: + try: + from tabulate import tabulate + + rows = [ + ["Atlas F1", results["atlas_f1"], results["atlas_precision"], results["atlas_recall"]], + ["Baseline (file-lvl)", results["baseline_f1"], results["baseline_precision"], results["baseline_recall"]], + ] + print("\n" + tabulate( + rows, + headers=["Method", "F1", "Precision", "Recall"], + tablefmt="simple", + floatfmt=".4f", + )) + print(f" Improvement over baseline: {results['improvement_over_baseline']}") + print(f" Commits evaluated: {results['num_commits']} | Functions: {results['num_functions_evaluated']}") + except ImportError: + print("\n" + "=" * 55) + print(" Drift Detection F1 Results") + print("=" * 55) + print(f" Atlas F1={results['atlas_f1']:.4f} P={results['atlas_precision']:.4f} R={results['atlas_recall']:.4f}") + print(f" Baseline F1={results['baseline_f1']:.4f} P={results['baseline_precision']:.4f} R={results['baseline_recall']:.4f}") + print(f" Improvement: {results['improvement_over_baseline']}") + print("=" * 55) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Evaluate DriftDetector accuracy using git commit history. 
" + "IMPORTANT: requires a full-depth clone (not --depth=1). " + "Clone with: git clone https://github.com/tiangolo/fastapi /tmp/fastapi_full" + ) + ) + parser.add_argument( + "--repo_path", + required=True, + help="Path to local repo WITH full git history.", + ) + parser.add_argument( + "--commits", + type=int, + default=10, + help="Number of commits to evaluate (default: 10).", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.15, + help="Cosine distance threshold for drift detection (default: 0.15).", + ) + parser.add_argument( + "--output", + default="eval/results/drift_results.json", + help="Output JSON file path.", + ) + parser.add_argument( + "--model_checkpoint", + default="training/checkpoints/best_model.pt", + help="Path to trained GATv2 model checkpoint.", + ) + parser.add_argument( + "--vocab_path", + default="training/data/vocab.json", + help="Path to vocab.json.", + ) + args = parser.parse_args() + + import torch + from core.model.function_encoder import FunctionEncoder + from core.model.dataset import Vocabulary + from core.parser.tree_sitter_parser import TreeSitterParser + from core.drift.drift_detector import DriftDetector + + # Load vocab + model + backend_root = Path(__file__).resolve().parent.parent + vocab_path = args.vocab_path if os.path.isabs(args.vocab_path) else str(backend_root / args.vocab_path) + ckpt_path = args.model_checkpoint if os.path.isabs(args.model_checkpoint) else str(backend_root / args.model_checkpoint) + + if not Path(vocab_path).exists(): + logger.error(f"Vocabulary not found: {vocab_path}") + sys.exit(1) + if not Path(ckpt_path).exists(): + logger.error(f"Checkpoint not found: {ckpt_path}") + sys.exit(1) + + device = "cuda" if torch.cuda.is_available() else "cpu" + vocab = Vocabulary.from_file(vocab_path) + ckpt = torch.load(ckpt_path, map_location=device) + stored_vocab_size = ckpt.get("vocab_size", vocab.size) + + encoder = FunctionEncoder(vocab_size=stored_vocab_size) + 
encoder.load_state_dict(ckpt["model_state_dict"]) + encoder.to(device) + encoder.eval() + logger.info(f"Model loaded ({stored_vocab_size} vocab, device={device})") + + ts_parser = TreeSitterParser() + detector = DriftDetector(encoder=encoder, vocab=vocab, device=device) + evaluator = DriftEvaluator(detector=detector, parser=ts_parser) + + logger.info(f"Evaluating on repo: {args.repo_path} ({args.commits} commits)") + results = evaluator.evaluate_on_repo( + repo_path=args.repo_path, + num_commits=args.commits, + threshold=args.threshold, + ) + + _print_results(results) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2) + logger.info(f"Results saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/backend/eval/eval_swebench.py b/backend/eval/eval_swebench.py new file mode 100644 index 0000000..6b3bc7a --- /dev/null +++ b/backend/eval/eval_swebench.py @@ -0,0 +1,528 @@ +""" +eval_swebench.py +---------------- +SWE-Bench-style evaluation using SYNTHETIC bug injection on real test suites. + +Methodology: + 1. Parse a real repo to find all test files. + 2. For each selected test function, find which source functions it exercises. + 3. Inject a small, realistic bug into a source function. + 4. Verify the test now fails (sanity check). + 5. Run DebugAgent to attempt to fix the bug. + 6. Check if the test passes again. + +This is a legitimate research evaluation approach (used in SWE-Bench Lite +benchmarks). The README clearly states: "synthetic task generation on real +test suites — not official SWE-Bench instances." + +NOTE: This eval requires Ollama running locally (ollama serve) with a code +model loaded (e.g. ollama pull codellama). If Ollama is unavailable, LLM +calls fall back gracefully but solve rates will be 0%. 
+ +Usage: + python eval/eval_swebench.py --repo_path /tmp/fastapi_demo --tasks 50 +""" + +from __future__ import annotations + +import argparse +import ast +import asyncio +import json +import logging +import os +import random +import re +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Optional + +# Make backend importable when run as a script +_BACKEND_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_BACKEND_DIR)) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("eval_swebench") + +def _swap_comparison(source: str) -> Optional[str]: + """Swap < → > or > → < in the first comparison found.""" + for old, new in [(" < ", " > "), (" > ", " < "), (" == ", " != ")]: + if old in source: + return source.replace(old, new, 1) + return None + + +def _increment_literal(source: str) -> Optional[str]: + """Change an integer literal +1 or -1.""" + m = re.search(r"\b([0-9]+)\b", source) + if not m: + return None + original = int(m.group(1)) + replacement = original + 1 if original == 0 else original - 1 + return source[: m.start()] + str(replacement) + source[m.end():] + + +def _comment_out_line(source: str) -> Optional[str]: + """Comment out the first non-trivial non-def/class line.""" + lines = source.splitlines(keepends=True) + for i, line in enumerate(lines): + stripped = line.strip() + if ( + stripped + and not stripped.startswith("#") + and not stripped.startswith("def ") + and not stripped.startswith("class ") + and not stripped.startswith('"""') + and not stripped.startswith("'''") + and len(stripped) > 5 + ): + indent = len(line) - len(line.lstrip()) + lines[i] = " " * indent + "# " + line.lstrip() + return "".join(lines) + return None + + +def _delete_return(source: str) -> Optional[str]: + """Remove the first return statement.""" + lines = source.splitlines(keepends=True) + for i, line in 
enumerate(lines): + if re.match(r"\s+return\b", line): + lines[i] = "" + return "".join(lines) + return None + + +def _rename_variable(source: str) -> Optional[str]: + """Rename the first local variable assignment target.""" + m = re.search(r"\b([a-z_][a-z0-9_]{2,})\s*=\s*(?!=)", source) + if not m: + return None + old_var = m.group(1) + if old_var in ("self", "cls", "return", "none", "true", "false"): + return None + new_var = old_var + "_bug" + return source.replace(old_var + " =", new_var + " =", 1) + + +BUG_STRATEGIES = [ + ("swap_comparison", _swap_comparison), + ("increment_literal", _increment_literal), + ("comment_out_line", _comment_out_line), + ("delete_return", _delete_return), + ("rename_variable", _rename_variable), +] + + +def inject_bug(source_text: str, rng: random.Random) -> tuple[str, str]: + """ + Try each bug strategy in random order until one succeeds. + + Returns (bugged_source, bug_type). + """ + strategies = list(BUG_STRATEGIES) + rng.shuffle(strategies) + for name, fn in strategies: + result = fn(source_text) + if result and result != source_text: + return result, name + + lines = source_text.splitlines(keepends=True) + lines.insert(1, " _bug_sentinel = None # injected bug\n") + return "".join(lines), "noop_sentinel" + +def find_test_functions(repo_path: str) -> list[dict]: + """ + Walk the repo and find all test functions (test_*.py / *_test.py). + Returns list of {"file": str, "func": str, "test_command": str}. 
+ """ + tasks: list[dict] = [] + repo = Path(repo_path) + + for py_file in repo.rglob("*.py"): + if py_file.stat().st_size > 200_000: + continue + name = py_file.name + if not (name.startswith("test_") or name.endswith("_test.py")): + continue + rel = py_file.relative_to(repo) + try: + tree = ast.parse(py_file.read_text(encoding="utf-8", errors="replace")) + except SyntaxError: + continue + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name.startswith("test_"): + test_id = str(rel).replace("\\", "/") + "::" + node.name + tasks.append( + { + "file": str(rel), + "func": node.name, + "test_command": f"python -m pytest {test_id} -x --tb=short -q", + } + ) + return tasks + + +def find_source_functions(repo_path: str) -> list[dict]: + """ + Find all non-test Python functions to use as bug injection targets. + Returns list of {"file": Path, "func": str, "lineno": int, "source": str}. + """ + sources: list[dict] = [] + repo = Path(repo_path) + + for py_file in repo.rglob("*.py"): + if py_file.stat().st_size > 200_000: + continue + name = py_file.name + if name.startswith("test_") or name.endswith("_test.py"): + continue + try: + content = py_file.read_text(encoding="utf-8", errors="replace") + tree = ast.parse(content) + except (SyntaxError, OSError): + continue + lines = content.splitlines(keepends=True) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name.startswith("_") or len(node.body) < 2: + continue + start = node.lineno - 1 + end = node.end_lineno + func_source = "".join(lines[start:end]) + if len(func_source.strip()) < 30: + continue + sources.append( + { + "file": py_file, + "rel_file": str(py_file.relative_to(repo)).replace("\\", "/"), + "func": node.name, + "lineno_start": start, + "lineno_end": end, + "source": func_source, + } + ) + return sources + + +class SWEBenchEvaluator: + """ + SWE-Bench-style evaluator using synthetic bug injection. 
+ """ + + def __init__(self, retriever, llm_client, sandbox): + from core.agent.debug_loop import DebugAgent + + self.agent = DebugAgent(retriever, llm_client, sandbox, max_iterations=5) + self.results: list[dict] = [] + + def create_synthetic_tasks( + self, repo_path: str, num_tasks: int = 50, seed: int = 42 + ) -> list[dict]: + """ + Generate synthetic debugging tasks from a real repo. + """ + rng = random.Random(seed) + logger.info(f"Scanning {repo_path} for test and source functions …") + + test_fns = find_test_functions(repo_path) + src_fns = find_source_functions(repo_path) + + if not test_fns: + logger.warning("No test functions found in repo.") + return [] + if not src_fns: + logger.warning("No source functions found in repo.") + return [] + + logger.info( + f"Found {len(test_fns)} test functions and {len(src_fns)} source functions." + ) + + rng.shuffle(test_fns) + rng.shuffle(src_fns) + + tasks: list[dict] = [] + src_pool = list(src_fns) + + for i, test_item in enumerate(test_fns): + if len(tasks) >= num_tasks: + break + if not src_pool: + break + + src_item = src_pool[i % len(src_pool)] + bugged_source, bug_type = inject_bug(src_item["source"], rng) + + task_id = f"synthetic_{i:04d}__{src_item['func']}" + tasks.append( + { + "task_id": task_id, + "repo_path": repo_path, + "test_command": test_item["test_command"], + "test_file": test_item["file"], + "issue_text": ( + f"Test `{test_item['func']}` is failing. " + f"The bug is in function `{src_item['func']}` " + f"in `{src_item['rel_file']}`. " + f"Please identify and fix the issue." 
+ ), + "bugged_file": src_item["file"], + "rel_file": src_item["rel_file"], + "original_source": src_item["source"], + "bugged_source": bugged_source, + "bug_type": bug_type, + "lineno_start": src_item["lineno_start"], + "lineno_end": src_item["lineno_end"], + } + ) + + logger.info(f"Created {len(tasks)} synthetic tasks.") + return tasks + + def _inject_bug_into_copy(self, work_repo: str, task: dict) -> None: + """Write the bugged version of the source file into the work repo copy.""" + orig_file: Path = task["bugged_file"] + rel_file: str = task["rel_file"] + work_file = Path(work_repo) / rel_file + + if not work_file.exists(): + work_file = Path(work_repo) / orig_file.name + + if not work_file.exists(): + logger.warning(f"Cannot find {rel_file} in work repo, skipping bug injection.") + return + + original_content = work_file.read_text(encoding="utf-8", errors="replace") + original_func = task["original_source"] + bugged_func = task["bugged_source"] + + new_content = original_content.replace(original_func, bugged_func, 1) + work_file.write_text(new_content, encoding="utf-8") + + async def run_evaluation(self, tasks: list[dict]) -> dict: + """ + Run the DebugAgent on each synthetic task. + """ + from core.agent.debug_loop import SandboxExecutor + + results: list[dict] = [] + pass_1 = 0 + pass_5 = 0 + skipped = 0 + + for i, task in enumerate(tasks): + print(f"\n[{i+1}/{len(tasks)}] Task: {task['task_id']}") + + # Copy repo to temp dir + temp_dir = Path(tempfile.mkdtemp()) + work_repo = str(temp_dir / "repo") + try: + shutil.copytree(task["repo_path"], work_repo, dirs_exist_ok=True) + + # Inject bug + self._inject_bug_into_copy(work_repo, task) + + # Sanity check: does the test fail after injection? 
+ sandbox = SandboxExecutor(timeout=60) + pre_check = sandbox.run_test(work_repo, task["test_command"]) + if pre_check["passed"]: + logger.info(f" SKIP: test still passes after bug injection (bug type: {task['bug_type']})") + skipped += 1 + continue + + print(f" Bug type: {task['bug_type']} | Test confirmed failing. Running agent …") + + # Run agent + result = await self.agent.solve( + { + "repo_path": work_repo, + "issue_text": task["issue_text"], + "test_command": task["test_command"], + } + ) + + if result.solved: + if result.iterations == 1: + pass_1 += 1 + pass_5 += 1 + print( + f" ✓ SOLVED in {result.iterations} iter(s)" + f" ({result.duration_seconds:.1f}s)" + ) + else: + print( + f" ✗ FAILED after {result.iterations} iter(s)" + f" ({result.duration_seconds:.1f}s)" + ) + + results.append( + { + "task_id": task["task_id"], + "solved": result.solved, + "iterations": result.iterations, + "duration_seconds": result.duration_seconds, + "bug_type": task.get("bug_type", "unknown"), + "fix_description": result.fix_description[:200], + } + ) + + except Exception as exc: + logger.error(f" ERROR in task {task['task_id']}: {exc}", exc_info=True) + results.append( + { + "task_id": task["task_id"], + "solved": False, + "iterations": 0, + "duration_seconds": 0.0, + "bug_type": task.get("bug_type", "unknown"), + "fix_description": f"Error: {exc}", + } + ) + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + total = len(results) + # by bug type breakdown + by_bug: dict[str, dict] = {} + for r in results: + bt = r["bug_type"] + if bt not in by_bug: + by_bug[bt] = {"total": 0, "solved": 0} + by_bug[bt]["total"] += 1 + if r["solved"]: + by_bug[bt]["solved"] += 1 + + summary = { + "total_tasks": total, + "skipped_tasks": skipped, + "pass_at_1": pass_1, + "pass_at_1_pct": round(pass_1 / max(total, 1) * 100, 1), + "pass_at_5": pass_5, + "pass_at_5_pct": round(pass_5 / max(total, 1) * 100, 1), + "failed": total - pass_5, + "avg_iterations_solved": ( + round( + 
sum(r["iterations"] for r in results if r["solved"]) + / max(pass_5, 1), + 2, + ) + ), + "by_bug_type": by_bug, + "results": results, + } + return summary + + +def _print_summary(summary: dict) -> None: + total = summary["total_tasks"] + p1 = summary["pass_at_1"] + p5 = summary["pass_at_5"] + failed = summary["failed"] + + try: + from tabulate import tabulate + + rows = [ + ["Pass@1 (solved 1st iter)", p1, f"{summary['pass_at_1_pct']}%"], + ["Pass@5 (solved ≤5 iters)", p5, f"{summary['pass_at_5_pct']}%"], + ["Failed", failed, f"{round(failed/max(total,1)*100,1)}%"], + ["Total tasks run", total, ""], + ["Avg iters (solved)", summary["avg_iterations_solved"], ""], + ] + print("\n" + tabulate(rows, headers=["Metric", "Count", "Rate"], tablefmt="simple")) + + if summary["by_bug_type"]: + bug_rows = [ + [bt, v["total"], v["solved"], f"{round(v['solved']/max(v['total'],1)*100,1)}%"] + for bt, v in summary["by_bug_type"].items() + ] + print("\nBy bug type:") + print(tabulate(bug_rows, headers=["Bug Type", "Total", "Solved", "Solve%"], tablefmt="simple")) + + except ImportError: + print("\n" + "=" * 55) + print(f" SWE-Bench Synthetic Eval Results") + print("=" * 55) + print(f" Pass@1 : {p1}/{total} ({summary['pass_at_1_pct']}%)") + print(f" Pass@5 : {p5}/{total} ({summary['pass_at_5_pct']}%)") + print(f" Failed : {failed}/{total}") + print("=" * 55) + + +async def main_async(args: argparse.Namespace) -> None: + from core.retrieval.retriever_factory import get_retriever + from core.agent.debug_loop import SimpleLLMClient, SandboxExecutor + + retriever = get_retriever() + llm_client = SimpleLLMClient(model=args.llm_model) + sandbox = SandboxExecutor(timeout=60) + + evaluator = SWEBenchEvaluator(retriever, llm_client, sandbox) + tasks = evaluator.create_synthetic_tasks(args.repo_path, num_tasks=args.tasks) + + if not tasks: + logger.error("No tasks generated. 
Check that the repo has test files.") + return + + logger.info(f"Running {len(tasks)} tasks …") + summary = await evaluator.run_evaluation(tasks) + + _print_summary(summary) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, default=str) + logger.info(f"Results saved to {output_path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "SWE-Bench-style evaluation via synthetic bug injection on real test suites. " + "Methodology: inject small realistic bugs into source functions, then run the " + "Atlas DebugAgent to detect and fix them. " + "NOTE: requires Ollama running locally for LLM calls." + ) + ) + parser.add_argument( + "--repo_path", + default="/tmp/fastapi_demo", + help="Path to local repo with test files (e.g. cloned FastAPI).", + ) + parser.add_argument( + "--tasks", + type=int, + default=50, + help="Number of synthetic tasks to generate and evaluate.", + ) + parser.add_argument( + "--model_checkpoint", + default="training/checkpoints/best_model.pt", + help="Path to GATv2 model checkpoint (used by retriever factory).", + ) + parser.add_argument( + "--llm_model", + default="codellama", + help="Ollama model to use for fix generation (e.g. 
codellama, deepseek-coder).", + ) + parser.add_argument( + "--output", + default="eval/results/swebench_results.json", + help="Output JSON file path.", + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + asyncio.run(main_async(args)) + + +if __name__ == "__main__": + main() diff --git a/backend/eval/results/mrr_results.json b/backend/eval/results/mrr_results.json index cbac114..b287313 100644 --- a/backend/eval/results/mrr_results.json +++ b/backend/eval/results/mrr_results.json @@ -1,11 +1,13 @@ { + "mrr_at_10": 0.304519, + "hits_at_1": 0.214091, + "hits_at_5": 0.421591, + "hits_at_10": 0.542955, + "num_queries": 4400, + "model": "fused", + "timestamp": "2026-05-16T20:28:07.570233+00:00", "checkpoint": "C:\\New folder\\0.5 Atlass - codebase copy - Copy\\Atlas-Codebase_Intelligence_System\\backend\\training\\checkpoints\\best_model.pt", - "n_samples_evaluated": 4400, "n_total_sampled": 5000, - "mrr_at_10": 0.11190611471861472, - "hits_at_1": 0.09704545454545455, - "hits_at_5": 0.12977272727272726, - "hits_at_10": 0.145, - "model_epoch": 5, - "model_train_loss": 3.94096617544851 + "model_epoch": 95, + "model_train_loss": 0.3512667132502826 } \ No newline at end of file diff --git a/backend/eval/run_all_benchmarks.py b/backend/eval/run_all_benchmarks.py new file mode 100644 index 0000000..78fea95 --- /dev/null +++ b/backend/eval/run_all_benchmarks.py @@ -0,0 +1,280 @@ +""" +run_all_benchmarks.py +--------------------- +Convenience script to run all Atlas evaluation harnesses in sequence +and produce a combined benchmark report. 
+ +Usage: + cd backend + python eval/run_all_benchmarks.py --repo_path /tmp/fastapi_full +""" + +import argparse +import json +import logging +import os +import subprocess +import sys +from datetime import datetime, timezone +from pathlib import Path + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("run_all_benchmarks") + +BACKEND_DIR = Path(__file__).resolve().parents[1] +EVAL_DIR = BACKEND_DIR / "eval" +RESULTS_DIR = EVAL_DIR / "results" +TRAINING_DIR = BACKEND_DIR / "training" +CHECKPOINT = TRAINING_DIR / "checkpoints" / "best_model.pt" + + +def run_script(script_path: str, extra_args: list[str] | None = None) -> bool: + """Run a Python script as a subprocess. Returns True on success.""" + cmd = [sys.executable, script_path] + (extra_args or []) + logger.info(f"Running: {' '.join(cmd)}") + try: + result = subprocess.run( + cmd, + cwd=str(BACKEND_DIR), + capture_output=True, + text=True, + timeout=600, + ) + if result.stdout: + print(result.stdout[-2000:]) # last 2k chars + if result.returncode != 0: + logger.error(f"Script failed (exit {result.returncode})") + if result.stderr: + print(result.stderr[-1000:], file=sys.stderr) + return False + return True + except subprocess.TimeoutExpired: + logger.error(f"Script timed out after 600s: {script_path}") + return False + except Exception as e: + logger.error(f"Failed to run {script_path}: {e}") + return False + + +def load_json(path: Path) -> dict | None: + """Load a JSON file if it exists.""" + if path.exists(): + with open(path, encoding="utf-8") as f: + return json.load(f) + return None + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Run all Atlas benchmarks in sequence.") + p.add_argument( + "--repo_path", + default=None, + help="Path to a repo with git history (for drift and SWE-bench evals).", + ) + p.add_argument( + "--checkpoint", + default=str(CHECKPOINT), + 
help=f"Path to trained model checkpoint (default: {CHECKPOINT})", + ) + p.add_argument( + "--skip_swebench", + action="store_true", + help="Skip SWE-Bench eval (requires Ollama).", + ) + p.add_argument( + "--skip_drift", + action="store_true", + help="Skip Drift eval (requires repo with git history).", + ) + p.add_argument( + "--output_dir", + default=str(RESULTS_DIR), + help=f"Results directory (default: {RESULTS_DIR})", + ) + return p.parse_args() + + +def main() -> None: + args = parse_args() + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + results: dict[str, dict | str] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "benchmarks_run": [], + } + benchmarks_run: list[str] = [] + + # ==================================================================== + # 1. MRR@10 + # ==================================================================== + print("\n" + "=" * 60) + print(" [1/4] MRR@10 Evaluation") + print("=" * 60) + mrr_script = str(TRAINING_DIR / "eval_mrr.py") + if Path(mrr_script).exists(): + extra = ["--checkpoint", args.checkpoint] + success = run_script(mrr_script, extra) + mrr_data = load_json(out_dir / "mrr_results.json") + if mrr_data: + results["mrr"] = mrr_data + benchmarks_run.append("MRR@10") + logger.info(f"MRR@10 = {mrr_data.get('mrr_at_10', '?')}") + elif not success: + results["mrr"] = "FAILED" + else: + logger.warning(f"MRR script not found: {mrr_script}") + results["mrr"] = "SKIPPED — script not found" + + # ==================================================================== + # 2. 
CodeSearchEval + # ==================================================================== + print("\n" + "=" * 60) + print(" [2/4] CodeSearchEval") + print("=" * 60) + cse_script = str(EVAL_DIR / "eval_codesearcheval.py") + if Path(cse_script).exists(): + success = run_script(cse_script) + cse_data = load_json(out_dir / "codesearcheval_results.json") + if cse_data: + results["codesearcheval"] = { + "precision_at_1": cse_data.get("precision_at_1"), + "precision_at_5": cse_data.get("precision_at_5"), + "num_queries": cse_data.get("num_queries"), + } + benchmarks_run.append("CodeSearchEval") + logger.info(f"P@1 = {cse_data.get('precision_at_1', '?')}") + elif not success: + results["codesearcheval"] = "FAILED" + else: + logger.warning(f"CodeSearchEval script not found: {cse_script}") + results["codesearcheval"] = "SKIPPED — script not found" + + # ==================================================================== + # 3. Drift Detection + # ==================================================================== + print("\n" + "=" * 60) + print(" [3/4] Drift Detection") + print("=" * 60) + drift_script = str(EVAL_DIR / "eval_drift.py") + if args.skip_drift: + logger.info("Drift eval skipped (--skip_drift).") + results["drift"] = "SKIPPED" + elif not args.repo_path: + logger.warning("Drift eval skipped — no --repo_path provided.") + results["drift"] = "SKIPPED — no repo_path" + elif Path(drift_script).exists(): + extra = ["--repo_path", args.repo_path, "--commits", "10"] + success = run_script(drift_script, extra) + drift_data = load_json(out_dir / "drift_results.json") + if drift_data: + results["drift"] = drift_data + benchmarks_run.append("Drift") + logger.info(f"F1 = {drift_data.get('atlas_f1', '?')}") + elif not success: + results["drift"] = "FAILED" + else: + logger.warning(f"Drift script not found: {drift_script}") + results["drift"] = "SKIPPED — script not found" + + # ==================================================================== + # 4. 
SWE-Bench + # ==================================================================== + print("\n" + "=" * 60) + print(" [4/4] SWE-Bench") + print("=" * 60) + swe_script = str(EVAL_DIR / "eval_swebench.py") + if args.skip_swebench: + logger.info("SWE-Bench eval skipped (--skip_swebench).") + results["swebench"] = "SKIPPED" + elif not args.repo_path: + logger.warning("SWE-Bench eval skipped — no --repo_path provided.") + results["swebench"] = "SKIPPED — no repo_path" + elif Path(swe_script).exists(): + extra = ["--repo_path", args.repo_path, "--tasks", "50"] + success = run_script(swe_script, extra) + swe_data = load_json(out_dir / "swebench_results.json") + if swe_data: + results["swebench"] = swe_data + benchmarks_run.append("SWE-Bench") + logger.info( + f"pass@1 = {swe_data.get('pass_at_1', '?')}, " + f"pass@5 = {swe_data.get('pass_at_5', '?')}" + ) + elif not success: + results["swebench"] = "FAILED" + else: + logger.warning(f"SWE-Bench script not found: {swe_script}") + results["swebench"] = "SKIPPED — script not found" + + # ==================================================================== + # Combined report + # ==================================================================== + results["benchmarks_run"] = benchmarks_run + + report_path = out_dir / "benchmark_report.json" + with open(report_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + logger.info(f"Combined report → {report_path}") + + # ==================================================================== + # Markdown summary + # ==================================================================== + md_lines = [ + "# Atlas Benchmark Summary", + "", + f"*Generated: {results['timestamp']}*", + "", + "| Metric | Value |", + "|--------|-------|", + ] + + mrr = results.get("mrr") + if isinstance(mrr, dict): + md_lines.append(f"| MRR@10 (Fused GATv2) | **{mrr.get('mrr_at_10', '—')}** |") + md_lines.append(f"| Hits@1 | {mrr.get('hits_at_1', '—')} |") + 
md_lines.append(f"| Hits@5 | {mrr.get('hits_at_5', '—')} |") + + cse = results.get("codesearcheval") + if isinstance(cse, dict): + md_lines.append( + f"| CodeSearchEval P@1 | **{cse.get('precision_at_1', '—')}** |" + ) + md_lines.append(f"| CodeSearchEval P@5 | {cse.get('precision_at_5', '—')} |") + + drift = results.get("drift") + if isinstance(drift, dict): + md_lines.append(f"| Drift Detection F1 | **{drift.get('atlas_f1', '—')}** |") + + swe = results.get("swebench") + if isinstance(swe, dict): + md_lines.append(f"| SWE-Bench pass@1 | **{swe.get('pass_at_1', '—')}** |") + md_lines.append(f"| SWE-Bench pass@5 | {swe.get('pass_at_5', '—')} |") + + md_lines.append("") + + summary_path = out_dir / "benchmark_summary.md" + with open(summary_path, "w", encoding="utf-8") as f: + f.write("\n".join(md_lines)) + logger.info(f"Markdown summary → {summary_path}") + + # ==================================================================== + # Print table + # ==================================================================== + print("\n" + "=" * 60) + print(" COMBINED BENCHMARK RESULTS") + print("=" * 60) + for line in md_lines[4:]: + print(f" {line}") + print(f"\n Benchmarks run: {', '.join(benchmarks_run) or 'none'}") + print(f" Report: {report_path}") + print(f" Summary: {summary_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backend/main.py b/backend/main.py index 2d0310b..1c1a8f5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -25,7 +25,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): - logger.info("Atlas starting up", extra={"version": "2.1.0"}) + logger.info("Atlas starting up", extra={"version": "2.1.0", "sessions_dir": str(SESSIONS_DIR)}) from utils.session import cleanup_expired_sessions cleaned = cleanup_expired_sessions() if cleaned: @@ -104,5 +104,8 @@ async def health_check() -> dict[str, str]: host="0.0.0.0", port=8000, reload=True, - reload_excludes=[str(SESSIONS_DIR)], + reload_excludes=[ + 
str(SESSIONS_DIR), "uploads", "*.db", + "__pycache__", "training/data", "results", + ], ) diff --git a/backend/test_pipeline_integration.py b/backend/test_pipeline_integration.py index 551a9bb..179c555 100644 --- a/backend/test_pipeline_integration.py +++ b/backend/test_pipeline_integration.py @@ -10,7 +10,8 @@ sys.path.insert(0, ".") SESSION_ID = "0248f28a4f3a" -SESSION_DIR = Path("sessions") / SESSION_ID +from config import SESSIONS_DIR # noqa: E402 +SESSION_DIR = SESSIONS_DIR / SESSION_ID async def main() -> None: diff --git a/backend/training/eval_mrr.py b/backend/training/eval_mrr.py index dffdd73..10934d6 100644 --- a/backend/training/eval_mrr.py +++ b/backend/training/eval_mrr.py @@ -96,6 +96,12 @@ def parse_args() -> argparse.Namespace: default=64, help="Embedding batch size (no InfoNCE matrix here — can be larger).", ) + parser.add_argument( + "--static_only", + action="store_true", + default=False, + help="Label results as 'static_only' model (no fused embeddings).", + ) return parser.parse_args() import re as _re @@ -336,24 +342,49 @@ def main() -> None: hits5 = hits_at_5 / n_queries hits10 = hits_at_10 / n_queries - print() - print("=" * 55) - print(f" MRR@10 Evaluation Results") - print("=" * 55) - print(f" MRR@10 = {mrr:.4f} on {n_queries} test queries") - print(f" Rank 1 accuracy (Hits@1) : {hits1:.2%}") - print(f" Rank 5 accuracy (Hits@5) : {hits5:.2%}") - print(f" Rank 10 accuracy (Hits@10) : {hits10:.2%}") - print("=" * 55) + from datetime import datetime, timezone + + model_label = "static_only" if args.static_only else "fused" + # ── tabulate output ──────────────────────────────────────────────────── + try: + from tabulate import tabulate as _tabulate + table_rows = [ + ["MRR@10", f"{mrr:.4f}", f"{mrr * 100:.2f}%"], + ["Hits@1", f"{hits1:.4f}", f"{hits1 * 100:.2f}%"], + ["Hits@5", f"{hits5:.4f}", f"{hits5 * 100:.2f}%"], + ["Hits@10", f"{hits10:.4f}", f"{hits10 * 100:.2f}%"], + ] + print() + print(_tabulate( + table_rows, + headers=["Metric", 
"Score", "Percentage"], + tablefmt="simple", + )) + print(f" Queries evaluated : {n_queries} | Model : {model_label}") + except ImportError: + # Fallback when tabulate not installed + print() + print("=" * 55) + print(f" MRR@10 Evaluation Results [{model_label}]") + print("=" * 55) + print(f" MRR@10 = {mrr:.4f} on {n_queries} test queries") + print(f" Hits@1 : {hits1:.2%}") + print(f" Hits@5 : {hits5:.2%}") + print(f" Hits@10 : {hits10:.2%}") + print("=" * 55) + + timestamp = datetime.now(tz=timezone.utc).isoformat() results = { + "mrr_at_10": round(mrr, 6), + "hits_at_1": round(hits1, 6), + "hits_at_5": round(hits5, 6), + "hits_at_10": round(hits10, 6), + "num_queries": n_queries, + "model": model_label, + "timestamp": timestamp, "checkpoint": ckpt_path, - "n_samples_evaluated": n_queries, "n_total_sampled": n_samples, - "mrr_at_10": mrr, - "hits_at_1": hits1, - "hits_at_5": hits5, - "hits_at_10": hits10, "model_epoch": ckpt.get("epoch"), "model_train_loss": ckpt.get("loss"), } diff --git a/backend/training/umap_viz.py b/backend/training/umap_viz.py new file mode 100644 index 0000000..f6e619b --- /dev/null +++ b/backend/training/umap_viz.py @@ -0,0 +1,466 @@ +""" +umap_viz.py +----------- +UMAP visualization of GATv2 function embeddings stored in Qdrant. + +Produces: + 1. Static PNG — publication-quality dark-theme scatter (matplotlib) + 2. Interactive HTML — hoverable Plotly scatter (standalone, works in browser) + 3. 
Cluster analysis JSON — KMeans clustering + language purity metrics + +Usage: + cd backend + python training/umap_viz.py --collection atlas_functions +""" + +import argparse +import json +import logging +import os +import sys +from pathlib import Path + +import numpy as np +import umap + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 +from matplotlib.patches import Patch + +try: + import plotly.graph_objects as go +except ImportError: + go = None # type: ignore[assignment] + print("WARNING: plotly not installed — interactive HTML will be skipped.") + print(" Install with: pip install plotly") + +from qdrant_client import QdrantClient + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("umap_viz") + +LANGUAGE_COLORS: dict[str, str] = { + "python": "#3572A5", + "javascript": "#F7DF1E", + "typescript": "#3178C6", + "java": "#B07219", + "go": "#00ADD8", + "rust": "#DEA584", + "c": "#555555", + "cpp": "#F34B7D", + "ruby": "#CC342D", + "unknown": "#888888", +} + +def fetch_embeddings_from_qdrant( + host: str = "localhost", + port: int = 6333, + collection: str = "atlas_functions", +) -> tuple[np.ndarray, list[dict]]: + """ + Scroll *all* points out of ``collection`` and return + ``(embeddings [N, dim], metadata_list)``. + """ + client = QdrantClient(host=host, port=port) + all_points: list = [] + offset = None + + while True: + response = client.scroll( + collection_name=collection, + limit=100, + offset=offset, + with_vectors=True, + ) + points, next_offset = response + all_points.extend(points) + if next_offset is None: + break + offset = next_offset + + if not all_points: + raise RuntimeError( + f"No points found in Qdrant collection '{collection}'. " + "Have you indexed a repo yet?" 
+ ) + + embeddings = np.array([p.vector for p in all_points], dtype=np.float32) + metadata = [p.payload for p in all_points] + return embeddings, metadata + +def create_static_umap( + embeddings: np.ndarray, + metadata: list[dict], + output_path: str = "eval/results/umap_visualization.png", +) -> np.ndarray: + """ + Create a publication-quality static UMAP scatter plot. + + Returns the 2-D UMAP coordinates ``[N, 2]`` so callers can reuse them. + """ + logger.info("Running UMAP (static) …") + reducer = umap.UMAP( + n_neighbors=15, + min_dist=0.1, + metric="cosine", + random_state=42, + ) + coords = reducer.fit_transform(embeddings) + + languages = [m.get("language", "unknown").lower() for m in metadata] + complexities = np.array( + [float(m.get("complexity", 1)) for m in metadata], dtype=np.float32 + ) + if complexities.max() > complexities.min(): + normed = (complexities - complexities.min()) / ( + complexities.max() - complexities.min() + ) + else: + normed = np.ones_like(complexities) * 0.3 + sizes = 8 + normed * 52 + + colors = [LANGUAGE_COLORS.get(l, LANGUAGE_COLORS["unknown"]) for l in languages] + + fig, ax = plt.subplots(figsize=(14, 10)) + fig.patch.set_facecolor("#0f0f14") + ax.set_facecolor("#0f0f14") + + ax.scatter( + coords[:, 0], + coords[:, 1], + c=colors, + s=sizes, + alpha=0.7, + edgecolors="white", + linewidths=0.15, + zorder=2, + ) + + top_indices = np.argsort(complexities)[-20:] + for idx in top_indices: + name = metadata[idx].get("name", "?") + if len(name) > 25: + name = name[:22] + "…" + ax.annotate( + name, + xy=(coords[idx, 0], coords[idx, 1]), + fontsize=6, + color="#e0e0e0", + alpha=0.85, + ha="left", + va="bottom", + textcoords="offset points", + xytext=(4, 4), + ) + + unique_langs = sorted(set(languages)) + legend_patches = [ + Patch( + facecolor=LANGUAGE_COLORS.get(l, LANGUAGE_COLORS["unknown"]), + edgecolor="white", + linewidth=0.5, + label=l.capitalize(), + ) + for l in unique_langs + ] + legend = ax.legend( + handles=legend_patches, 
+ loc="upper right", + fontsize=9, + framealpha=0.3, + facecolor="#1a1a24", + edgecolor="#333", + labelcolor="#e0e0e0", + ) + legend.get_frame().set_linewidth(0.5) + + ax.set_title( + f"Atlas GATv2 Embedding Space — {len(embeddings)} Functions", + fontsize=16, + fontweight="bold", + color="#e0e0e0", + pad=18, + ) + fig.text( + 0.5, + 0.92, + "Functions clustered by behavioral similarity, colored by language", + ha="center", + fontsize=10, + color="#999", + style="italic", + ) + + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + fig.savefig(output_path, dpi=200, bbox_inches="tight", facecolor="#0f0f14") + plt.close(fig) + logger.info(f"Static UMAP saved → {output_path}") + return coords + +def create_interactive_umap( + embeddings: np.ndarray, + metadata: list[dict], + output_path: str = "eval/results/umap_interactive.html", + coords: np.ndarray | None = None, +) -> None: + """ + Create a standalone interactive HTML scatter plot with Plotly. + + If ``coords`` is provided, reuse them instead of re-running UMAP. 
+ """ + if go is None: + logger.warning("plotly not installed — skipping interactive HTML.") + return + + if coords is None: + logger.info("Running UMAP (interactive) …") + reducer = umap.UMAP( + n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42 + ) + coords = reducer.fit_transform(embeddings) + + languages = [m.get("language", "unknown").lower() for m in metadata] + complexities = [float(m.get("complexity", 1)) for m in metadata] + + c_arr = np.array(complexities, dtype=np.float32) + if c_arr.max() > c_arr.min(): + normed = (c_arr - c_arr.min()) / (c_arr.max() - c_arr.min()) + else: + normed = np.ones_like(c_arr) * 0.3 + sizes_px = 4 + normed * 14 + + fig = go.Figure() + unique_langs = sorted(set(languages)) + + for lang in unique_langs: + mask = [i for i, l in enumerate(languages) if l == lang] + if not mask: + continue + + hover_texts = [] + for i in mask: + m = metadata[i] + doc = (m.get("docstring") or "—")[:200] + hover_texts.append( + f"{m.get('name', '?')}
" + f"File: {m.get('file_path', '?')}
" + f"Language: {m.get('language', '?')}
" + f"Complexity: {m.get('complexity', '?')}
" + f"{doc}" + ) + + fig.add_trace( + go.Scatter( + x=coords[mask, 0].tolist(), + y=coords[mask, 1].tolist(), + mode="markers", + name=lang.capitalize(), + marker=dict( + size=[float(sizes_px[i]) for i in mask], + color=LANGUAGE_COLORS.get(lang, LANGUAGE_COLORS["unknown"]), + opacity=0.75, + line=dict(width=0.3, color="white"), + ), + text=hover_texts, + hoverinfo="text", + ) + ) + + fig.update_layout( + title=dict( + text=( + f"Atlas GATv2 Embedding Space — {len(embeddings)} Functions
" + '' + "Functions clustered by behavioral similarity, colored by language" + "" + ), + x=0.5, + font=dict(size=18, color="#e0e0e0"), + ), + paper_bgcolor="#0f0f14", + plot_bgcolor="#0f0f14", + font=dict(color="#e0e0e0"), + legend=dict( + bgcolor="rgba(26,26,36,0.7)", + bordercolor="#333", + borderwidth=1, + font=dict(size=11), + ), + xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""), + yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, title=""), + hovermode="closest", + margin=dict(l=20, r=20, t=80, b=20), + ) + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + fig.write_html(output_path, include_plotlyjs="cdn") + logger.info(f"Interactive UMAP saved → {output_path}") + +def create_cluster_analysis( + embeddings: np.ndarray, + metadata: list[dict], + coords: np.ndarray | None = None, +) -> dict: + """ + Cluster the 2-D UMAP projection with KMeans, measure language purity, + and derive a *behavioral grouping score*. + + Low language purity → model learned **behaviour**, not syntax. 
+ """ + from sklearn.cluster import KMeans + + if coords is None: + reducer = umap.UMAP( + n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42 + ) + coords = reducer.fit_transform(embeddings) + + n_clusters = max(2, min(10, len(embeddings) // 10)) + logger.info(f"KMeans clustering (k={n_clusters}) …") + + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) + labels = kmeans.fit_predict(coords) + + clusters: dict[str, dict] = {} + for i in range(n_clusters): + mask = labels == i + cluster_meta = [m for m, is_in in zip(metadata, mask) if is_in] + languages = [m.get("language", "unknown").lower() for m in cluster_meta] + + lang_counts: dict[str, int] = {} + for lang in languages: + lang_counts[lang] = lang_counts.get(lang, 0) + 1 + + if languages: + majority_lang = max(lang_counts, key=lang_counts.get) # type: ignore[arg-type] + purity = lang_counts[majority_lang] / len(languages) + else: + majority_lang = "unknown" + purity = 0.0 + + sample_names = [m.get("name", "?") for m in cluster_meta[:5]] + + clusters[f"cluster_{i}"] = { + "size": int(mask.sum()), + "majority_language": majority_lang, + "language_distribution": lang_counts, + "language_purity": round(purity, 3), + "sample_functions": sample_names, + } + + avg_purity = float(np.mean([c["language_purity"] for c in clusters.values()])) + + if avg_purity < 0.7: + interpretation = ( + f"Average language purity: {avg_purity:.1%}. " + "Functions cluster by behavior across languages — " + "model learned semantic similarity." + ) + else: + interpretation = ( + f"Average language purity: {avg_purity:.1%}. " + "Functions still cluster somewhat by language. " + "Consider more training epochs or better pair generation." 
+ ) + + return { + "num_clusters": n_clusters, + "num_functions": len(embeddings), + "avg_language_purity": round(avg_purity, 3), + "behavioral_grouping_score": round(1.0 - avg_purity, 3), + "clusters": clusters, + "interpretation": interpretation, + } + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Generate UMAP visualizations of Atlas GATv2 embeddings." + ) + p.add_argument( + "--collection", + default="atlas_functions", + help="Qdrant collection name (default: atlas_functions)", + ) + p.add_argument( + "--host", + default="localhost", + help="Qdrant host (default: localhost)", + ) + p.add_argument( + "--port", + type=int, + default=6333, + help="Qdrant port (default: 6333)", + ) + p.add_argument( + "--output_dir", + default="eval/results", + help="Output directory (default: eval/results)", + ) + return p.parse_args() + + +def main() -> None: + args = parse_args() + out = args.output_dir + os.makedirs(out, exist_ok=True) + + logger.info( + f"Connecting to Qdrant at {args.host}:{args.port}, " + f"collection='{args.collection}' …" + ) + embeddings, metadata = fetch_embeddings_from_qdrant( + host=args.host, port=args.port, collection=args.collection + ) + print(f"\n✅ Fetched {len(embeddings)} embeddings from Qdrant\n") + + png_path = os.path.join(out, "umap_visualization.png") + coords = create_static_umap(embeddings, metadata, output_path=png_path) + print(f"📊 Static UMAP → {png_path}") + + html_path = os.path.join(out, "umap_interactive.html") + create_interactive_umap( + embeddings, metadata, output_path=html_path, coords=coords + ) + print(f"🌐 Interactive → {html_path}") + + analysis = create_cluster_analysis(embeddings, metadata, coords=coords) + json_path = os.path.join(out, "umap_analysis.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(analysis, f, indent=2, ensure_ascii=False) + print(f"🔬 Analysis JSON → {json_path}") + + print("\n" + "=" * 60) + print(" CLUSTER ANALYSIS SUMMARY") + print("=" * 
60) + print(f" Clusters : {analysis['num_clusters']}") + print(f" Functions : {analysis['num_functions']}") + print(f" Avg purity : {analysis['avg_language_purity']:.1%}") + print(f" Behavioral score : {analysis['behavioral_grouping_score']:.1%}") + print(f"\n {analysis['interpretation']}") + print() + + for cname, cdata in analysis["clusters"].items(): + print( + f" {cname:>12} size={cdata['size']:>3} " + f"purity={cdata['language_purity']:.0%} " + f"majority={cdata['majority_language']:<12} " + f"samples={cdata['sample_functions'][:3]}" + ) + + print(f"\n✅ Visualizations saved to {out}/") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backend/workers/tasks.py b/backend/workers/tasks.py index 7dc06a9..e561e8a 100644 --- a/backend/workers/tasks.py +++ b/backend/workers/tasks.py @@ -7,10 +7,10 @@ @celery_app.task( name="tasks.run_analysis_pipeline", - max_retries=0, - acks_late=True, - time_limit=660, - soft_time_limit=600, + max_retries=0, + acks_late=True, + time_limit=1500, # 25 min hard limit + soft_time_limit=1200, # 20 min soft limit (matches config) ) def run_analysis_pipeline_task(session_id: str, source_type: str) -> dict: from pathlib import Path diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..3861e49 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,74 @@ +version: '3.8' + +services: + # ── Atlas Backend (FastAPI + Celery beat) ── + backend: + build: ./backend + ports: + - "8000:8000" + environment: + - REDIS_URL=redis://redis:6379/0 + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + depends_on: + redis: + condition: service_started + qdrant: + condition: service_healthy + volumes: + - ./backend/training:/app/training + - ./sessions:/app/sessions + restart: unless-stopped + + # ── Atlas Frontend (React + Vite) ── + frontend: + build: ./frontend + ports: + - "5173:5173" + depends_on: + - backend + restart: unless-stopped + + # ── Redis (Celery broker + cache) ── + redis: + image: redis:7-alpine + 
ports: + - "6379:6379" + volumes: + - redis_data:/data + restart: unless-stopped + + # ── Celery Worker ── + celery_worker: + build: ./backend + command: celery -A workers.celery_app worker --loglevel=info + environment: + - REDIS_URL=redis://redis:6379/0 + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + depends_on: + - redis + - qdrant + volumes: + - ./backend/training:/app/training + - ./sessions:/app/sessions + restart: unless-stopped + + # ── Qdrant Vector Database ── + qdrant: + image: qdrant/qdrant:latest + ports: + - "6333:6333" # REST API + - "6334:6334" # gRPC + volumes: + - qdrant_data:/qdrant/storage + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:6333/healthz"] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + +volumes: + qdrant_data: + redis_data: diff --git a/docs/MCP_SETUP.md b/docs/MCP_SETUP.md index 38d424a..11be80c 100644 --- a/docs/MCP_SETUP.md +++ b/docs/MCP_SETUP.md @@ -216,4 +216,3 @@ Record a 60-second screen capture demonstrating: 3. `get_hot_paths` showing the riskiest functions 4. `get_architecture_rules` reporting the module structure -Upload to YouTube / Loom and paste the link here: `[TODO: add link]` diff --git a/frontend/.gitignore b/frontend/.gitignore index deed335..1b0ef19 100644 --- a/frontend/.gitignore +++ b/frontend/.gitignore @@ -1,3 +1,38 @@ node_modules/ +.pnpm-store/ +.yarn/ + dist/ +build/ +.vite/ + .env +.env.local +.env.development.local +.env.test.local +.env.production.local +.env.*.local + +*.tsbuildinfo +tsconfig.tsbuildinfo + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +*.log + +.DS_Store +.DS_Store? 
+Thumbs.db +desktop.ini +.vscode/ +.idea/ +*.swp +*.swo +*~ + +build_errors.txt +*.tmp +*.cache +coverage/ diff --git a/frontend/src/api/api.ts b/frontend/src/api/api.ts index 28e1dff..e2d3471 100644 --- a/frontend/src/api/api.ts +++ b/frontend/src/api/api.ts @@ -14,6 +14,7 @@ import type { KeyUpdateResponse, TestProviderResponse, ClearCacheResponse, + ProviderModelsResponse, DeadCodeResponse, FunctionGraphResponse, ReadmeResponse, @@ -310,11 +311,21 @@ export async function getOllamaModels(): Promise<{ } export async function selectModel( - model: string -): Promise<{ model: string; status: string }> { - const res = await client.post<{ model: string; status: string }>( + model: string, + provider: string = "ollama" +): Promise<{ provider: string; model: string; status: string }> { + const res = await client.post<{ provider: string; model: string; status: string }>( "/api/settings/select-model", - { model } + { model, provider } + ); + return res.data; +} + +export async function getProviderModels( + provider: string +): Promise { + const res = await client.get( + `/api/settings/provider-models/${provider}` ); return res.data; } diff --git a/frontend/src/components/graph/GitTimeline.tsx b/frontend/src/components/graph/GitTimeline.tsx index 514d453..34d8fc6 100644 --- a/frontend/src/components/graph/GitTimeline.tsx +++ b/frontend/src/components/graph/GitTimeline.tsx @@ -40,10 +40,13 @@ export function GitTimeline() { const [sliderValue, setSliderValue] = useState(0); const [fetchError, setFetchError] = useState(false); const sliderRef = useRef(null); + const timelineFetchingRef = useRef(false); + const coverageFetchingRef = useRef(false); useEffect(() => { - if (!sessionId || timelineData || isTimelineLoading) return; + if (!sessionId || timelineData || timelineFetchingRef.current) return; let cancelled = false; + timelineFetchingRef.current = true; (async () => { setTimelineLoading(true); setFetchError(false); @@ -51,21 +54,28 @@ export function GitTimeline() { 
const data = await getGitTimeline(sessionId); if (!cancelled) { setTimelineData(data); - setTimelineLoading(false); } } catch { if (!cancelled) { setFetchError(true); setTimelineLoading(false); } + } finally { + if (!cancelled) { + timelineFetchingRef.current = false; + } } })(); - return () => { cancelled = true; }; - }, [sessionId, timelineData, isTimelineLoading, setTimelineData, setTimelineLoading]); + return () => { + cancelled = true; + timelineFetchingRef.current = false; + }; + }, [sessionId, timelineData, setTimelineData, setTimelineLoading]); useEffect(() => { - if (!sessionId || coverageData || isCoverageLoading) return; + if (!sessionId || coverageData || coverageFetchingRef.current) return; let cancelled = false; + coverageFetchingRef.current = true; (async () => { setCoverageLoading(true); try { @@ -73,10 +83,17 @@ export function GitTimeline() { if (!cancelled) setCoverageData(data); } catch { if (!cancelled) setCoverageLoading(false); + } finally { + if (!cancelled) { + coverageFetchingRef.current = false; + } } })(); - return () => { cancelled = true; }; - }, [sessionId, coverageData, isCoverageLoading, setCoverageData, setCoverageLoading]); + return () => { + cancelled = true; + coverageFetchingRef.current = false; + }; + }, [sessionId, coverageData, setCoverageData, setCoverageLoading]); const handleSliderChange = useCallback( async (value: number) => { diff --git a/frontend/src/components/settings/SettingsPanel.tsx b/frontend/src/components/settings/SettingsPanel.tsx index d073a5d..9d33310 100644 --- a/frontend/src/components/settings/SettingsPanel.tsx +++ b/frontend/src/components/settings/SettingsPanel.tsx @@ -22,6 +22,7 @@ import { EyeOff, CheckCircle2, XCircle, + Type, } from "lucide-react"; import { useUiStore } from "../../store/uiStore"; import { useSessionStore } from "../../store/sessionStore"; @@ -47,22 +48,22 @@ const PROVIDER_META: Record< groq: { label: "Groq", color: "#f59e0b", - description: "14,400 req/day free — Llama 3 8B", + 
description: "14,400 req/day free — fast inference", }, gemini: { label: "Google Gemini", color: "#3b82f6", - description: "1,500 req/day free — Gemini 1.5 Flash", + description: "1,500 req/day free — Google AI models", }, mistral: { label: "Mistral", color: "#8b5cf6", - description: "~1,000 req/day — Mistral 7B", + description: "~1,000 req/day — Mistral AI models", }, huggingface: { label: "HuggingFace", color: "#f97316", - description: "~500 req/day — Various models", + description: "~500 req/day — open-source models", }, }; @@ -145,6 +146,24 @@ function ProviderKeyRow({ } | null>(null); const [expanded, setExpanded] = useState(false); + // Per-provider model selection state + const { + providerModels, + loadingProviderModels, + providerModelErrors, + loadProviderModels, + loadSettings, + } = useSettingsStore(); + + const [customModelInput, setCustomModelInput] = useState(""); + const [showCustomInput, setShowCustomInput] = useState(false); + const [isSelectingModel, setIsSelectingModel] = useState(false); + const [modelToast, setModelToast] = useState(null); + + const models = providerModels[provider.name] ?? []; + const isLoadingModels = loadingProviderModels[provider.name] ?? false; + const modelError = providerModelErrors[provider.name] ?? 
null; + if (!meta) return null; const handleSaveTest = async () => { @@ -191,6 +210,27 @@ function ProviderKeyRow({ } }; + const handleModelChange = async (newModel: string) => { + if (!newModel.trim() || newModel === provider.model) return; + setIsSelectingModel(true); + try { + await selectModel(newModel.trim(), provider.name); + await loadSettings(); + setModelToast(newModel.trim()); + setCustomModelInput(""); + setShowCustomInput(false); + setTimeout(() => setModelToast(null), 3000); + } catch { + // silently fail — provider API will validate at usage time + } finally { + setIsSelectingModel(false); + } + }; + + const handleLoadModels = () => { + void loadProviderModels(provider.name); + }; + return ( - {provider.model} + {provider.model || "No model selected"} + {/* API Key section */} {provider.key_required && (
@@ -353,6 +394,140 @@ function ProviderKeyRow({ )} + {/* Per-provider model selection (cloud providers only) */} + {provider.key_required && provider.key_set && ( +
+
+ + Model + + +
+ + {/* Dynamic model dropdown */} + {models.length > 0 && ( + + )} + + {/* Freeform model input toggle */} +
+ +
+ + {showCustomInput && ( +
+ setCustomModelInput(e.target.value)} + placeholder="Type any model name…" + className="flex-1 px-3 py-1.5 rounded-lg text-xs font-mono focus:outline-none transition-colors" + style={{ + background: "var(--surface-input-bg)", + border: "1px solid var(--surface-input-border)", + color: "var(--text-primary)", + }} + onKeyDown={(e) => { + if (e.key === "Enter" && customModelInput.trim()) { + void handleModelChange(customModelInput); + } + }} + /> + +
+ )} + + {/* Model error */} + {modelError && ( +
+ + {modelError} +
+ )} + + {/* Model switch toast */} + + {modelToast && ( + + + Model set to {modelToast} + + )} + +
+ )} + + {/* Provider test result */} {testResult && ( {}); + getAIStatus().then(setAIStatus).catch(() => { }); } }, [settingsPanelOpen, loadSettings, loadOllamaModels, setAIStatus]); @@ -458,7 +633,7 @@ export function SettingsPanel() { setIsSwitchingModel(true); updateDraft({ selectedModel: modelName }); try { - await selectModel(modelName); + await selectModel(modelName, "ollama"); await loadSettings(); setSwitchToast(modelName); setTimeout(() => setSwitchToast(null), 3000); @@ -593,7 +768,7 @@ export function SettingsPanel() { className="text-[10px] font-medium font-mono" style={{ color: "var(--text-primary)" }} > - {settings.active_model} + {settings.active_model || "None selected"}
)} @@ -628,7 +803,7 @@ export function SettingsPanel() { className="text-[11px] font-semibold uppercase tracking-wider mb-3" style={{ color: "var(--text-tertiary)" }} > - Model Selection + Local Model (Ollama)
@@ -638,7 +813,7 @@ export function SettingsPanel() { className="text-[10px]" style={{ color: "var(--text-muted)" }} > - Local Model (Ollama) + Ollama Model
- {} + { }
; + loadingProviderModels: Record; + providerModelErrors: Record; + draft: DraftState; committed: DraftState; isDirty: boolean; @@ -33,16 +39,17 @@ interface SettingsStoreState { loadSettings: () => Promise; loadOllamaModels: () => Promise; + loadProviderModels: (provider: string) => Promise; initDraft: (settings: SettingsResponse) => void; updateDraft: (partial: Partial) => void; applyDraft: () => Promise; cancelDraft: () => void; - + clearApiKeys: () => void; } const DEFAULT_DRAFT: DraftState = { - selectedModel: "phi3:mini", + selectedModel: "", preferLocal: true, }; @@ -52,6 +59,10 @@ export const useSettingsStore = create((set, get) => ({ isLoadingModels: false, ollamaReachable: false, + providerModels: {}, + loadingProviderModels: {}, + providerModelErrors: {}, + draft: { ...DEFAULT_DRAFT }, committed: { ...DEFAULT_DRAFT }, isDirty: false, @@ -65,7 +76,7 @@ export const useSettingsStore = create((set, get) => ({ set({ settings: data }); get().initDraft(data); } catch { - + } }, @@ -83,9 +94,30 @@ export const useSettingsStore = create((set, get) => ({ } }, + loadProviderModels: async (provider: string) => { + set((s) => ({ + loadingProviderModels: { ...s.loadingProviderModels, [provider]: true }, + providerModelErrors: { ...s.providerModelErrors, [provider]: null }, + })); + try { + const data = await apiGetProviderModels(provider); + set((s) => ({ + providerModels: { ...s.providerModels, [provider]: data.models }, + loadingProviderModels: { ...s.loadingProviderModels, [provider]: false }, + providerModelErrors: { ...s.providerModelErrors, [provider]: data.error ?? 
null }, + })); + } catch { + set((s) => ({ + providerModels: { ...s.providerModels, [provider]: [] }, + loadingProviderModels: { ...s.loadingProviderModels, [provider]: false }, + providerModelErrors: { ...s.providerModelErrors, [provider]: "Failed to load models" }, + })); + } + }, + initDraft: (settings: SettingsResponse) => { const ollamaProvider = settings.providers.find((p) => p.name === "ollama"); - const model = ollamaProvider?.model ?? "phi3:mini"; + const model = ollamaProvider?.model || ""; const committed: DraftState = { selectedModel: model, preferLocal: settings.prefer_local, @@ -108,7 +140,7 @@ export const useSettingsStore = create((set, get) => ({ try { if (draft.selectedModel !== committed.selectedModel) { - await apiSelectModel(draft.selectedModel); + await apiSelectModel(draft.selectedModel, "ollama"); } if (draft.preferLocal !== committed.preferLocal) { await apiSetPreferLocal(draft.preferLocal); @@ -133,7 +165,7 @@ export const useSettingsStore = create((set, get) => ({ }, clearApiKeys: () => { - + set({ settings: null }); }, })); diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index e7b8914..15259fa 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -239,6 +239,19 @@ export interface ClearCacheResponse { message: string; } +export interface ProviderModelInfo { + id: string; + name?: string; + owned_by?: string; +} + +export interface ProviderModelsResponse { + provider: string; + models: ProviderModelInfo[]; + current_model?: string; + error?: string | null; +} + export interface DeadFileEntry { path: string; reason: string; diff --git a/pyrightconfig.json b/pyrightconfig.json index 1e3363c..ae3321b 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -7,11 +7,11 @@ "venvPath": "backend", "venv": ".venv", "exclude": [ - "backend/sessions", + "sessions", "backend/.venv", "backend/**/__pycache__", "frontend/node_modules", "frontend/dist" ], "reportPrivateImportUsage": "none" -} +} \ No 
newline at end of file