Source code for flexrag.assistant.chatqa_assistant

from flexrag.utils import LOGGER_MANAGER
from flexrag.utils.dataclasses import RetrievedContext

from .assistant import ASSISTANTS
from .modular_rag_assistant import ModularAssistant, ModularAssistantConfig

logger = LOGGER_MANAGER.get_logger("flexrag.assistant.chatqa")


[docs] @ASSISTANTS("chatqa", config_class=ModularAssistantConfig) class ChatQAAssistant(ModularAssistant): """The Modular assistant that employs the ChatQA model for response generation.""" sys_prompt = ( "System: This is a chat between a user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. " "The assistant should also indicate when the answer cannot be found in the context." ) instruction = "Please give a full and complete answer for the question." allowed_models = [ "nvidia/Llama3-ChatQA-2-8B", "nvidia/Llama3-ChatQA-2-70B", "nvidia/Llama3-ChatQA-1.5-8B", "nvidia/Llama3-ChatQA-1.5-70B", ] def __init__(self, cfg: ModularAssistantConfig): super().__init__(cfg) logger.warning( f"ChatQA Assistant expects the model to be one of {self.allowed_models}." ) return def answer_with_contexts( self, question: str, contexts: list[RetrievedContext] ) -> tuple[str, str]: prefix = self.get_formatted_input(question, contexts) response = self.generator.generate([prefix], generation_config=self.gen_cfg) return response[0][0], prefix def get_formatted_input( self, question: str, contexts: list[RetrievedContext] ) -> str: # prepare system prompts prefix = f"{self.sys_prompt}\n\n" # prepare context string for n, context in enumerate(contexts): if len(self.used_fields) == 0: ctx = "" for field_name, field_value in context.data.items(): ctx += f"{field_name}: {field_value}\n" elif len(self.used_fields) == 1: ctx = context.data[self.used_fields[0]] else: ctx = "" for field_name in self.used_fields: ctx += f"{field_name}: {context.data[field_name]}\n" prefix += f"Context {n + 1}: {ctx}\n\n" # prepare user instruction prefix += f"User: {self.instruction} {question}\n\nAssistant:" return prefix