flexrag.entrypoints.run_interactive 源代码

import sys
from pathlib import Path
from typing import Optional

import gradio as gr
import hydra
import PIL
import PIL.Image
from hydra.core.config_store import ConfigStore

from flexrag.assistant import ASSISTANTS
from flexrag.utils import LOGGER_MANAGER, configure, extract_config, load_user_module

# load user modules before loading config
for arg in sys.argv:
    if arg.startswith("user_module="):
        load_user_module(arg.split("=")[1])
        sys.argv.remove(arg)


AssistantConfig = ASSISTANTS.make_config(config_name="AssistantConfig")


[文档] @configure class Config(AssistantConfig): share: bool = False server_name: str = "127.0.0.1" server_port: int = 7860 auth: Optional[list[str]] = None
cs = ConfigStore.instance() cs.store(name="default", node=Config) logger = LOGGER_MANAGER.get_logger("run_interactive") # prepare resources custom_css = """ #logo { background-color: transparent; } """ logo_path = Path(__file__).parents[0] / "assets" / "flexrag.png" wide_logo_path = Path(__file__).parents[0] / "assets" / "flexrag-wide.png" robot_path = Path(__file__).parents[0] / "assets" / "robot.png" user_path = Path(__file__).parents[0] / "assets" / "user.png" @hydra.main(version_base="1.3", config_path=None, config_name="default") def main(config: Config): config = extract_config(config, Config) logger.debug(f"Configs:\n{config.dumps()}") # load assistant assistant = ASSISTANTS.load(config) # launch the gradio app logo = PIL.Image.open(logo_path) wide_logo = PIL.Image.open(wide_logo_path) theme = gr.themes.Soft() with gr.Blocks( theme=theme, title="📖flexrag: A RAG Framework for Information Retrieval and Generation.", fill_height=True, css=custom_css, ) as demo: logo_pic = gr.Image( value=logo, image_mode="RGBA", type="pil", width="40%", show_label=False, show_download_button=False, show_share_button=False, show_fullscreen_button=False, interactive=False, container=True, elem_id="logo", ) with gr.Row(visible=False, max_height="100%") as output_row: chatbot = gr.Chatbot( type="messages", label="History messages", show_copy_button=True, height="100%", max_height="100%", avatar_images=[robot_path, user_path], ) context_box = gr.Chatbot( type="messages", label="Searched contexts", show_copy_button=True, visible=assistant is not None, height="100%", max_height="100%", ) msg = gr.Textbox( visible=True, info="What would you like to know?", show_label=False, submit_btn=True, stop_btn=False, ) clear_btn = gr.ClearButton([msg, chatbot, context_box], visible=False) def rag_chat(message: str, history: list[dict[str, str]]) -> dict: response, contexts, _ = assistant.answer(question=message) history.append(gr.ChatMessage(role="user", content=message)) history.append(gr.ChatMessage(role="assistant", content=response)) ctxs = [ gr.ChatMessage( role="assistant", content=ctx.data["text"], metadata={"title": f"Retrieved by: {ctx.retriever}"}, ) for ctx in contexts ] r = { logo_pic: wide_logo, output_row: gr.Row( visible=True, height=720, ), chatbot: history, msg: "", context_box: ctxs, clear_btn: gr.ClearButton([msg, chatbot, context_box], visible=True), } return r msg.submit( rag_chat, inputs=[msg, chatbot], outputs=[logo_pic, output_row, chatbot, msg, context_box, clear_btn], ) demo.launch( server_name=config.server_name, server_port=config.server_port, share=config.share, auth=config.auth, ) return if __name__ == "__main__": main()