from copy import deepcopy
from dataclasses import field
from typing import Annotated, Any, Optional
from flexrag.context_refine import REFINERS, RefinerConfig
from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig
from flexrag.prompt import ChatPrompt, ChatTurn
from flexrag.ranker import RANKERS, RankerConfig
from flexrag.retriever import RETRIEVERS, RetrieverConfig
from flexrag.utils import LOGGER_MANAGER, Choices, data
from flexrag.utils.dataclasses import RetrievedContext
from .assistant import ASSISTANTS, PREDEFINED_PROMPTS, AssistantBase, SearchHistory
logger = LOGGER_MANAGER.get_logger("flexrag.assistant.modular")
[docs]
@data
class ModularAssistantConfig(
GeneratorConfig, GenerationConfig, RetrieverConfig, RankerConfig, RefinerConfig
):
"""The configuration for the modular assistant.
:param response_type: The type of response to generate.
Defaults to "short". Available options are: "short", "long", "original", "custom".
:type response_type: str, optional
:param prompt_with_context_path: The path to the prompt file for response with context. Defaults to None.
:type prompt_with_context_path: str, optional
:param prompt_without_context_path: The path to the prompt file for response without context. Defaults to None.
:type prompt_without_context_path: str, optional
:param used_fields: The fields to use in the context. Defaults to [].
:type used_fields: list[str], optional
"""
response_type: Annotated[
str,
Choices(
"short",
"long",
"original",
"custom",
),
] = "short"
prompt_with_context_path: Optional[str] = None
prompt_without_context_path: Optional[str] = None
used_fields: list[str] = field(default_factory=list)
[docs]
@ASSISTANTS("modular", config_class=ModularAssistantConfig)
class ModularAssistant(AssistantBase):
"""The modular RAG assistant that supports retrieval, reranking, and generation."""
def __init__(self, cfg: ModularAssistantConfig):
# set basic args
self.gen_cfg = cfg
if self.gen_cfg.sample_num > 1:
logger.warning("Sample num > 1 is not supported for Assistant")
self.gen_cfg.sample_num = 1
self.used_fields = cfg.used_fields
# load generator
self.generator = GENERATORS.load(cfg)
assert self.generator is not None, "Generator is not loaded."
# load retriever
self.retriever = RETRIEVERS.load(cfg)
# load ranker
self.reranker = RANKERS.load(cfg)
# load refiners
self.refiner = REFINERS.load(cfg)
# load prompts
match cfg.response_type:
case "short":
self.prompt_with_ctx = PREDEFINED_PROMPTS["shortform_with_context"]
self.prompt_wo_ctx = PREDEFINED_PROMPTS["shortform_without_context"]
case "long":
self.prompt_with_ctx = PREDEFINED_PROMPTS["longform_with_context"]
self.prompt_wo_ctx = PREDEFINED_PROMPTS["longform_without_context"]
case "original":
self.prompt_with_ctx = ChatPrompt()
self.prompt_wo_ctx = ChatPrompt()
case "custom":
self.prompt_with_ctx = ChatPrompt.from_json(
cfg.prompt_with_context_path
)
self.prompt_wo_ctx = ChatPrompt.from_json(
cfg.prompt_without_context_path
)
case _:
raise ValueError(f"Invalid response type: {cfg.response_type}")
return
def answer(
self, question: str
) -> tuple[str, list[RetrievedContext], dict[str, Any]]:
ctxs, history = self.search(question)
response, prompt = self.answer_with_contexts(question, ctxs)
return response, ctxs, {"prompt": prompt, "search_histories": history}
def search(
self, question: str
) -> tuple[list[RetrievedContext], list[SearchHistory]]:
if self.retriever is None:
return [], []
# searching for contexts
search_histories = []
ctxs = self.retriever.search(query=[question])[0]
search_histories.append(
SearchHistory(query=f"search: {question}", contexts=ctxs)
)
# reranking
if self.reranker is not None:
results = self.reranker.rank(question, ctxs)
ctxs = results.candidates
search_histories.append(
SearchHistory(query=f"rerank: {question}", contexts=ctxs)
)
# refine
for refiner in self.refiner:
ctxs = refiner.refine(ctxs)
search_histories.append(
SearchHistory(query=f"refine: {question}", contexts=ctxs)
)
return ctxs, search_histories
def answer_with_contexts(
self, question: str, contexts: list[RetrievedContext] = []
) -> tuple[str, ChatPrompt]:
# prepare system prompts
if len(contexts) > 0:
prompt = deepcopy(self.prompt_with_ctx)
else:
prompt = deepcopy(self.prompt_wo_ctx)
# prepare user prompt
usr_prompt = ""
for n, context in enumerate(contexts):
if len(self.used_fields) == 0:
ctx = ""
for field_name, field_value in context.data.items():
ctx += f"{field_name}: {field_value}\n"
elif len(self.used_fields) == 1:
ctx = context.data[self.used_fields[0]]
else:
ctx = ""
for field_name in self.used_fields:
ctx += f"{field_name}: {context.data[field_name]}\n"
usr_prompt += f"Context {n + 1}: {ctx}\n\n"
usr_prompt += f"Question: {question}"
prompt.update(ChatTurn(role="user", content=usr_prompt))
# generate response
response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0]
return response, prompt