# Source code for flexrag.prompt.prompt_base

import base64
import json
from dataclasses import field
from io import BytesIO
from os import PathLike
from typing import Annotated, Optional

from PIL.Image import Image

from flexrag.utils import Choices, data


@data
class ChatTurn:
    """One text-only message of a chat conversation.

    ``role`` is restricted to "user", "assistant" or "system";
    ``content`` is the message text.
    """

    role: Annotated[str, Choices("user", "assistant", "system")]
    content: str

    def to_dict(self) -> dict[str, str]:
        """Return this turn as a ``{"role": ..., "content": ...}`` mapping."""
        serialized = dict(role=self.role, content=self.content)
        return serialized

    @classmethod
    def from_dict(cls, chat_turn: dict[str, str]):
        """Create a turn from a mapping holding ``"role"`` and ``"content"``.

        Any other keys are ignored; a missing key raises ``KeyError``.
        """
        role, content = chat_turn["role"], chat_turn["content"]
        return cls(role=role, content=content)
@data
class MultiModelChatTurn:
    """One chat message whose content may mix text and images.

    ``content`` is either a plain string, or a list whose items are plain
    strings or dicts tagged by ``"type"``:

    * ``{"type": "text", ...}`` -- passed through untouched;
    * ``{"type": "image", "image": ...}`` -- ``"image"`` holds a PIL image
      (in memory) or a base64-encoded PNG string (serialized form).
    """

    role: Annotated[str, Choices("user", "assistant", "system")]
    content: list[dict[str, str | Image]] | str

    @staticmethod
    def _encode_image(image: Image | str) -> str:
        """Return *image* as a base64 PNG string; strings pass through as-is."""
        if isinstance(image, str):
            # Already serialized (e.g. the turn was built from raw strings).
            return image
        if not isinstance(image, Image):
            raise ValueError("Invalid image content")
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @staticmethod
    def _decode_image(image: Image | str) -> Image:
        """Return *image* as a PIL image, decoding a base64 string if needed."""
        if isinstance(image, Image):
            return image
        if isinstance(image, str):
            # ``open`` is a function of the ``PIL.Image`` module, not a method
            # of the ``PIL.Image.Image`` class imported at file level.
            from PIL import Image as PILImage

            return PILImage.open(BytesIO(base64.b64decode(image)))
        raise ValueError("Invalid image content")

    def to_dict(self, encode_img: bool = True) -> dict[str, str | dict]:
        """Serialize this turn to a ``{"role": ..., "content": ...}`` mapping.

        :param encode_img: When True (default), images inside list content are
            PNG-encoded into base64 strings so the result is JSON-serializable.
            When False, ``content`` is returned unchanged.
        :return: The serialized turn. Plain-string content is returned as-is.
        :raises ValueError: If a content item has an unknown ``"type"`` or an
            image item holds neither a PIL image nor a string.
        """
        if not encode_img or isinstance(self.content, str):
            # Nothing to encode; iterating a str would split it into characters.
            return {"role": self.role, "content": self.content}
        new_content = []
        for item in self.content:
            if isinstance(item, str):
                new_content.append(item)
                continue
            item_type = item.get("type")
            if item_type == "text":
                new_content.append(item)
            elif item_type == "image":
                new_content.append(
                    {**item, "image": self._encode_image(item["image"])}
                )
            else:
                raise ValueError("Invalid content type")
        return {"role": self.role, "content": new_content}

    @classmethod
    def from_dict(cls, chat_turn: dict[str, str | dict]):
        """Build a turn from a mapping, decoding base64 image strings.

        This is the inverse of :meth:`to_dict`: image items whose ``"image"``
        is a string are decoded back into PIL images; PIL images are kept.

        :param chat_turn: Mapping with ``"role"`` and ``"content"`` keys.
        :raises ValueError: If a content item has an unknown ``"type"`` or an
            image item holds neither a PIL image nor a string.
        """
        raw_content = chat_turn["content"]
        if isinstance(raw_content, str):
            return cls(role=chat_turn["role"], content=raw_content)
        loaded_content = []
        for item in raw_content:
            if isinstance(item, str):
                loaded_content.append(item)
                continue
            item_type = item.get("type")
            if item_type == "text":
                loaded_content.append(item)
            elif item_type == "image":
                loaded_content.append(
                    {**item, "image": cls._decode_image(item["image"])}
                )
            else:
                raise ValueError("Invalid content type")
        # Use the decoded content (previously the raw input was returned).
        return cls(role=chat_turn["role"], content=loaded_content)
@data
class ChatPrompt:
    """A text-only chat prompt: optional system turn, demonstrations, history.

    :meth:`to_list` flattens the prompt in the order
    system -> demonstrations (each a list of turns) -> history.
    """

    system: Optional[ChatTurn] = None
    history: list[ChatTurn] = field(default_factory=list)
    demonstrations: list[list[ChatTurn]] = field(default_factory=list)

    def __init__(
        self,
        system: Optional[str | ChatTurn] = None,
        history: Optional[list[ChatTurn] | list[dict[str, str]]] = None,
        demonstrations: Optional[
            list[list[ChatTurn]] | list[list[dict[str, str]]]
        ] = None,
    ):
        """Create a prompt.

        :param system: System message as a string, a ``ChatTurn`` or a dict.
        :param history: Conversation turns as ``ChatTurn`` objects or dicts
            (mixing both is allowed).
        :param demonstrations: Few-shot examples, each a list of turns given
            as ``ChatTurn`` objects or dicts.
        """
        if isinstance(system, str):
            system = ChatTurn(role="system", content=system)
        elif isinstance(system, dict):
            system = ChatTurn.from_dict(system)
        self.system = system
        # Always build fresh lists: a shared default (or the caller's list)
        # would otherwise be mutated by ``update``/``pop_history``.
        self.history = [self._to_turn(turn) for turn in (history or [])]
        self.demonstrations = [
            [self._to_turn(turn) for turn in demo] for demo in (demonstrations or [])
        ]

    @staticmethod
    def _to_turn(turn: ChatTurn | dict[str, str]) -> ChatTurn:
        """Return *turn* as a ``ChatTurn``, converting dicts via ``from_dict``."""
        return ChatTurn.from_dict(turn) if isinstance(turn, dict) else turn

    def to_list(self) -> list[dict[str, str]]:
        """Flatten into a list of ``{"role", "content"}`` message dicts."""
        messages = []
        if self.system is not None:
            messages.append({"role": "system", "content": self.system.content})
        for demo in self.demonstrations:
            messages.extend(turn.to_dict() for turn in demo)
        messages.extend(turn.to_dict() for turn in self.history)
        return messages

    def to_json(self, path: str | PathLike):
        """Write the prompt to *path* as UTF-8 JSON.

        The file holds an object with ``"system"`` (``null`` when unset),
        ``"history"`` and ``"demonstrations"`` keys.
        """
        payload = {
            "system": None if self.system is None else self.system.to_dict(),
            "history": [turn.to_dict() for turn in self.history],
            "demonstrations": [
                [turn.to_dict() for turn in demo] for demo in self.demonstrations
            ],
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4, ensure_ascii=False)

    @classmethod
    def from_list(cls, prompt: list[dict[str, str]]) -> "ChatPrompt":
        """Build a prompt from message dicts; a leading system turn is split off."""
        history = [ChatTurn.from_dict(turn) for turn in prompt]
        system = None
        if history and history[0].role == "system":
            system = history.pop(0)
        return cls(system=system, history=history, demonstrations=[])

    @classmethod
    def from_json(cls, path: str | PathLike) -> "ChatPrompt":
        """Load a prompt from a JSON file.

        Accepts either the object layout written by :meth:`to_json` or a plain
        list of message dicts (handled by :meth:`from_list`).
        """
        with open(path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        if isinstance(payload, list):
            return cls.from_list(payload)
        system = payload.get("system")
        return cls(
            system=None if system is None else ChatTurn.from_dict(system),
            history=[ChatTurn.from_dict(turn) for turn in payload.get("history", [])],
            demonstrations=[
                [ChatTurn.from_dict(turn) for turn in demo]
                for demo in payload.get("demonstrations", [])
            ],
        )

    def load_demonstrations(self, demo_path: str | PathLike):
        """Replace demonstrations with those in a JSON list-of-lists file."""
        with open(demo_path, "r", encoding="utf-8") as f:
            demos = json.load(f)
        self.demonstrations = [
            [ChatTurn.from_dict(turn) for turn in demo] for demo in demos
        ]

    def pop_history(self, n: int) -> ChatTurn:
        """Remove and return the history turn at index *n*."""
        return self.history.pop(n)

    def pop_demonstration(self, n: int) -> list[ChatTurn]:
        """Remove and return the demonstration at index *n*."""
        return self.demonstrations.pop(n)

    def update(
        self, chat_turns: ChatTurn | dict[str, str] | list[ChatTurn | dict[str, str]]
    ) -> None:
        """Append one turn or a list of turns (objects or dicts) to history."""
        if not isinstance(chat_turns, list):
            chat_turns = [chat_turns]
        self.history.extend(self._to_turn(turn) for turn in chat_turns)

    def clear(self, clear_system: bool = False):
        """Drop history and demonstrations; also the system turn if requested."""
        if clear_system:
            self.system = None
        self.history = []
        self.demonstrations = []

    def __len__(self) -> int:
        """Total number of turns: system (0 or 1) + history + all demo turns."""
        system_num = 0 if self.system is None else 1
        demo_num = sum(len(demo) for demo in self.demonstrations)
        return system_num + len(self.history) + demo_num
@data
class MultiModelChatPrompt:
    """A chat prompt whose turns may contain images.

    This mirrors :class:`ChatPrompt`. Python generics cannot call a
    classmethod through a TypeVar, so the code is duplicated here instead of
    being shared via a generic base class.
    """

    system: Optional[MultiModelChatTurn] = None
    history: list[MultiModelChatTurn] = field(default_factory=list)
    demonstrations: list[list[MultiModelChatTurn]] = field(default_factory=list)

    def __init__(
        self,
        system: Optional[str | MultiModelChatTurn] = None,
        history: Optional[list[MultiModelChatTurn] | list[dict[str, str]]] = None,
        demonstrations: Optional[
            list[list[MultiModelChatTurn]] | list[list[dict[str, str]]]
        ] = None,
    ):
        """Create a prompt.

        :param system: System message as a string, a ``MultiModelChatTurn``
            or a dict.
        :param history: Conversation turns as turn objects or dicts (mixing
            both is allowed).
        :param demonstrations: Few-shot examples, each a list of turns given
            as turn objects or dicts.
        """
        if isinstance(system, str):
            system = MultiModelChatTurn(role="system", content=system)
        elif isinstance(system, dict):
            system = MultiModelChatTurn.from_dict(system)
        self.system = system
        # Always build fresh lists: a shared default (or the caller's list)
        # would otherwise be mutated by ``update``/``pop_history``.
        self.history = [self._to_turn(turn) for turn in (history or [])]
        self.demonstrations = [
            [self._to_turn(turn) for turn in demo] for demo in (demonstrations or [])
        ]

    @staticmethod
    def _to_turn(turn: MultiModelChatTurn | dict) -> MultiModelChatTurn:
        """Return *turn* as a turn object, converting dicts via ``from_dict``."""
        return MultiModelChatTurn.from_dict(turn) if isinstance(turn, dict) else turn

    def to_list(self) -> list[dict[str, str]]:
        """Flatten into message dicts: system -> demonstrations -> history.

        Demonstration and history turns go through ``to_dict()`` with its
        default (images base64-encoded); the system content is used as-is.
        """
        messages = []
        if self.system is not None:
            messages.append({"role": "system", "content": self.system.content})
        for demo in self.demonstrations:
            messages.extend(turn.to_dict() for turn in demo)
        messages.extend(turn.to_dict() for turn in self.history)
        return messages

    def to_json(self, path: str | PathLike):
        """Write the prompt to *path* as UTF-8 JSON (``"system"`` may be null)."""
        payload = {
            "system": None if self.system is None else self.system.to_dict(),
            "history": [turn.to_dict() for turn in self.history],
            "demonstrations": [
                [turn.to_dict() for turn in demo] for demo in self.demonstrations
            ],
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4, ensure_ascii=False)

    @classmethod
    def from_list(cls, prompt: list[dict[str, str]]) -> "MultiModelChatPrompt":
        """Build a prompt from message dicts; a leading system turn is split off."""
        history = [MultiModelChatTurn.from_dict(turn) for turn in prompt]
        system = None
        if history and history[0].role == "system":
            system = history.pop(0)
        return cls(system=system, history=history, demonstrations=[])

    @classmethod
    def from_json(cls, path: str | PathLike) -> "MultiModelChatPrompt":
        """Load a prompt from a JSON file.

        Accepts either the object layout written by :meth:`to_json` or a plain
        list of message dicts (handled by :meth:`from_list`).
        """
        with open(path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        if isinstance(payload, list):
            return cls.from_list(payload)
        system = payload.get("system")
        return cls(
            system=None if system is None else MultiModelChatTurn.from_dict(system),
            history=[
                MultiModelChatTurn.from_dict(turn)
                for turn in payload.get("history", [])
            ],
            demonstrations=[
                [MultiModelChatTurn.from_dict(turn) for turn in demo]
                for demo in payload.get("demonstrations", [])
            ],
        )

    @property
    def images(self) -> list[Image]:
        """Image payloads found in list-style ``history`` content, in order.

        Plain-string items inside a content list are skipped.
        NOTE(review): system and demonstration turns are not scanned, matching
        the original behavior -- confirm that is intended.
        """
        return [
            item["image"]
            for turn in self.history
            if isinstance(turn.content, list)
            for item in turn.content
            if isinstance(item, dict) and item.get("type") == "image"
        ]

    def load_demonstrations(self, demo_path: str | PathLike):
        """Replace demonstrations with those in a JSON list-of-lists file."""
        with open(demo_path, "r", encoding="utf-8") as f:
            demos = json.load(f)
        self.demonstrations = [
            [MultiModelChatTurn.from_dict(turn) for turn in demo] for demo in demos
        ]

    def pop_history(self, n: int) -> MultiModelChatTurn:
        """Remove and return the history turn at index *n*."""
        return self.history.pop(n)

    def pop_demonstration(self, n: int) -> list[MultiModelChatTurn]:
        """Remove and return the demonstration at index *n*."""
        return self.demonstrations.pop(n)

    def update(
        self,
        chat_turns: (
            MultiModelChatTurn
            | dict[str, str]
            | list[MultiModelChatTurn | dict[str, str]]
        ),
    ) -> None:
        """Append one turn or a list of turns (objects or dicts) to history."""
        if not isinstance(chat_turns, list):
            chat_turns = [chat_turns]
        self.history.extend(self._to_turn(turn) for turn in chat_turns)

    def clear(self, clear_system: bool = False):
        """Drop history and demonstrations; also the system turn if requested."""
        if clear_system:
            self.system = None
        self.history = []
        self.demonstrations = []

    def __len__(self) -> int:
        """Total number of turns: system (0 or 1) + history + all demo turns."""
        system_num = 0 if self.system is None else 1
        demo_num = sum(len(demo) for demo in self.demonstrations)
        return system_num + len(self.history) + demo_num