paper_based_rag/utils/retriever.py
Юра Цепліцький
Switch to cohere command r model
1a0f750
from utils.settings import configure_settings
from utils.index import load_index
from utils.constant import INDEX_PATH, TOP_K_RETRIEVAL, TOP_N_RERANKER
from llama_index.core import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.query_engine import RetrieverQueryEngine
import Stemmer
class QueryEngineManager:
    """Process-wide singleton that owns the loaded index and builds query engines.

    Every ``QueryEngineManager()`` call returns the same instance. The first
    successful construction applies ``configure_settings()``, loads the index
    from ``INDEX_PATH`` and builds an ``LLMRerank`` post-processor. Later
    constructions leave that state untouched.
    """

    _instance = None

    # Answer-synthesis prompt; uses the template variables {context_str} and
    # {query_str}.
    _QA_TEMPLATE_TEXT = (
        "Given the following context and question, provide a detailed response.\n"
        "Context: {context_str}\n"
        "Question: {query_str}\n"
        "Let me explain this in detail:"
    )

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self.index = None
        self.nodes = []
        self.retriever = None
        self.reranker = None
        self.query_engine = None
        self._configure()
        # Only flag the singleton as ready once configuration succeeded, so a
        # failed index load is retried on the next construction instead of
        # leaving a half-built instance (index=None) behind forever.
        self._initialized = True

    def initialize_index(self):
        """Load the index from ``INDEX_PATH`` and cache its docstore nodes."""
        self.index = load_index(path=INDEX_PATH)
        # The BM25 retriever in get_engine() is built from these raw nodes.
        self.nodes = list(self.index.docstore.docs.values())

    def _configure(self):
        """Apply global settings, load the index and build the reranker."""
        configure_settings()
        self.initialize_index()
        # NOTE(review): the reranker is created here but is not currently
        # attached to the query engine built in get_engine().
        self.reranker = LLMRerank(top_n=TOP_N_RERANKER)

    def get_engine(self, bm25: bool = False, semantic: bool = False):
        """Build a ``RetrieverQueryEngine`` over the loaded index.

        Args:
            bm25: Use a BM25 keyword retriever over ``self.nodes``. Takes
                precedence when ``semantic`` is also set.
            semantic: Use the index's default retriever
                (``self.index.as_retriever``).

        If neither flag is set, the retriever configured by the previous call
        is reused.

        Returns:
            The new query engine, also stored on ``self.query_engine``.

        Raises:
            ValueError: If neither flag is set and no retriever has been
                configured by an earlier call.
        """
        if bm25:
            self.retriever = BM25Retriever.from_defaults(
                nodes=self.nodes,
                stemmer=Stemmer.Stemmer("english"),
                similarity_top_k=TOP_K_RETRIEVAL,
                language="english",
            )
        elif semantic:
            self.retriever = self.index.as_retriever(similarity_top_k=TOP_K_RETRIEVAL)
        elif self.retriever is None:
            # Fail fast: building an engine around retriever=None would only
            # break later, at query time, with an unhelpful error.
            raise ValueError(
                "get_engine() needs bm25=True or semantic=True when no "
                "retriever has been configured yet"
            )

        qa_template = PromptTemplate(self._QA_TEMPLATE_TEXT, prompt_type="text_qa")
        self.query_engine = RetrieverQueryEngine.from_args(
            retriever=self.retriever,
            text_qa_template=qa_template,
            # Reranking is disabled for now; pass
            # node_postprocessors=[self.reranker] to enable it.
        )
        return self.query_engine
def get_engine(bm25: bool = False, semantic: bool = False):
    """Build a query engine via the shared ``QueryEngineManager`` singleton.

    Args:
        bm25: Forwarded to ``QueryEngineManager.get_engine``.
        semantic: Forwarded to ``QueryEngineManager.get_engine``.

    Returns:
        A ``(query_engine, manager)`` tuple, where ``manager`` is the singleton
        that built the engine.
    """
    manager = QueryEngineManager()
    engine = manager.get_engine(bm25=bm25, semantic=semantic)
    return engine, manager