Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/retrievers/csv_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import chromadb.config
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_chroma.vectorstores import Chroma
Expand Down Expand Up @@ -71,15 +73,25 @@ def create_bm25_chroma_ensemble_retriever(
*,
descriptions_info: dict[str, str],
field_info: dict[str, list[AttributeInfo]],
) -> MergerRetriever:
return HybridRetriever.from_subdirectory(
) -> ContextualCompressionRetriever:
"""Create a HybridRetriever wrapped with EmbeddingsFilter-based
Contextual Compression to filter out low-relevance documents before
they are passed to the LLM, improving answer quality and reducing
hallucinations caused by noisy retrieval."""
base_retriever = HybridRetriever.from_subdirectory(
llm,
embedding,
embeddings_directory,
descriptions_info=descriptions_info,
field_info=field_info,
include_original=True,
)
embeddings_filter = EmbeddingsFilter(
embeddings=embedding, similarity_threshold=0.76
)
return ContextualCompressionRetriever(
base_compressor=embeddings_filter, base_retriever=base_retriever
)


class RetrieverDict(TypedDict):
Expand Down Expand Up @@ -177,7 +189,7 @@ def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]:
)
},
)
doc_lists.append(bm25_docs + vector_docs)
doc_lists.extend([bm25_docs, vector_docs])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs

Expand Down Expand Up @@ -214,9 +226,8 @@ async def aretrieve_documents(
subdirectory_docs: list[Document] = []
for subdir_results in subdirectory_results.values():
results_iter = iter(await asyncio.gather(*subdir_results))
doc_lists: list[list[Document]] = [
bm25_results + vector_results
for bm25_results, vector_results in zip(results_iter, results_iter)
]
doc_lists: list[list[Document]] = []
for bm25_results, vector_results in zip(results_iter, results_iter):
doc_lists.extend([bm25_results, vector_results])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs