Source code for flexrag.assistant.basic_assistant

from copy import deepcopy
from typing import Optional

from flexrag.models import GENERATORS, GenerationConfig, GeneratorConfig
from flexrag.prompt import ChatPrompt, ChatTurn
from flexrag.utils import LOGGER_MANAGER, configure

from .assistant import ASSISTANTS, AssistantBase

# Logger for this module, obtained from flexrag's LOGGER_MANAGER under the
# "flexrag.assistant" name (used below to warn about unsupported settings).
logger = LOGGER_MANAGER.get_logger("flexrag.assistant")


@configure
class BasicAssistantConfig(GeneratorConfig, GenerationConfig):
    """Configuration for the basic (retrieval-free) assistant.

    It inherits every option of :class:`GeneratorConfig` (which generator to
    load) and :class:`GenerationConfig` (how to generate), and adds two
    assistant-specific options.

    :param prompt_path: Path of the file the chat prompt is loaded from via
        ``ChatPrompt.from_json``. When None, a default-constructed
        ``ChatPrompt()`` is used instead. Defaults to None.
    :type prompt_path: str, optional
    :param use_history: Whether to keep the chat history between calls so
        the assistant can hold a multi-turn conversation. Defaults to False.
    :type use_history: bool, optional
    """

    prompt_path: Optional[str] = None
    use_history: bool = False
@ASSISTANTS("basic", config_class=BasicAssistantConfig)
class BasicAssistant(AssistantBase):
    """A basic assistant that generates responses without retrieval.

    The user's question is appended to the configured chat prompt (or, when
    ``use_history`` is enabled, to the conversation accumulated so far) and
    sent to the configured generator.
    """

    def __init__(self, cfg: BasicAssistantConfig):
        """Build the assistant from ``cfg``.

        :param cfg: The assistant configuration. It is copied internally, so
            the caller's object is never modified.
        :type cfg: BasicAssistantConfig
        """
        # Work on a private copy: ``sample_num`` may be overridden below, and
        # that override must not leak into the caller's configuration object.
        self.gen_cfg = deepcopy(cfg)
        if self.gen_cfg.sample_num > 1:
            logger.warning("Sample num > 1 is not supported for Assistant")
            self.gen_cfg.sample_num = 1

        # load generator (from the copy, which carries the sample_num override)
        self.generator = GENERATORS.load(self.gen_cfg)

        # load prompts
        if self.gen_cfg.prompt_path is not None:
            self.prompt = ChatPrompt.from_json(self.gen_cfg.prompt_path)
        else:
            self.prompt = ChatPrompt()

        # The history starts as a copy of the base prompt; None disables it.
        if self.gen_cfg.use_history:
            self.history_prompt = deepcopy(self.prompt)
        else:
            self.history_prompt = None

    def answer(self, question: str) -> tuple[str, None, dict[str, ChatPrompt]]:
        """Generate a response to ``question``.

        :param question: The user's question.
        :type question: str
        :return: ``(response, None, {"prompt": prompt})``, where ``prompt`` is
            the ChatPrompt that was sent to the generator. The second element
            is always None because this assistant does no retrieval.
        :rtype: tuple[str, None, dict[str, ChatPrompt]]
        """
        # Start from the conversation so far (history enabled) or from the
        # base prompt; copy so the stored prompts are not modified here.
        if self.history_prompt is not None:
            prompt = deepcopy(self.history_prompt)
        else:
            prompt = deepcopy(self.prompt)
        prompt.update(ChatTurn(role="user", content=question))

        # One prompt is sent and sample_num is forced to 1 in __init__, so
        # the first sample of the first result is the response.
        response = self.generator.chat([prompt], generation_config=self.gen_cfg)[0][0]

        # Record the exchange only after generation succeeded.
        if self.history_prompt is not None:
            self.history_prompt.update(ChatTurn(role="user", content=question))
            self.history_prompt.update(ChatTurn(role="assistant", content=response))
        return response, None, {"prompt": prompt}

    def clear_history(self) -> None:
        """Reset the conversation history back to the base prompt.

        Does nothing when history is disabled (``use_history=False``).
        """
        if self.history_prompt is not None:
            self.history_prompt = deepcopy(self.prompt)