flexrag.prompt.template 源代码

from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
from typing import Literal, Optional

from transformers import PreTrainedTokenizer

from flexrag.utils import LOGGER_MANAGER

from .prompt_base import ChatPrompt

# TRUNCATION_STRATEGIES = ["left", "right", "history", "demo", "auto"]


logger = LOGGER_MANAGER.get_logger("flexrag.prompt")


[文档] class ChatTemplate(ABC): @abstractmethod def render_to_text( self, prompt: ChatPrompt, add_generation_prompt: bool = True, ) -> str: return @abstractmethod def render_to_ids( self, prompt: ChatPrompt, max_length: int = None, truncation: Literal["left", "right", "history", "demo", "auto"] = "auto", padding: bool = False, has_label: bool = False, add_generation_prompt: bool = True, ): return
[文档] class HFTemplate(ChatTemplate): def __init__( self, tokenizer: PreTrainedTokenizer, sys_prompt: str = None, chat_template: str = None, # Jinja template ) -> None: super().__init__() self.tokenizer = tokenizer self.sys_prompt = sys_prompt self.chat_template = chat_template return def render_to_text( self, prompt: ChatPrompt, add_generation_prompt: bool = True, ) -> str: # add default system prompt prompt_ = prompt.to_list() if (len(prompt_) == 0) and (self.sys_prompt is not None): prompt_.append({"role": "system", "content": self.sys_prompt}) if (prompt_[0]["role"] != "system") and (self.sys_prompt is not None): prompt_.insert(0, {"role": "system", "content": self.sys_prompt}) # apply template prefix = self.tokenizer.apply_chat_template( prompt_, tokenize=False, chat_template=self.chat_template, add_generation_prompt=add_generation_prompt, ) return prefix def render_to_ids( self, prompt: ChatPrompt, max_length: Optional[int] = None, truncation: Literal["left", "right", "history", "demo", "auto"] = "auto", padding: bool = False, add_generation_prompt: bool = True, ) -> list[int] | tuple[list[int], list[int]]: def _encode( prompt: ChatPrompt, max_length: int = None, truncation: bool = False, truncation_side: str = "left", ) -> list[int]: prefix = self.render_to_text(prompt, add_generation_prompt) self.tokenizer.truncation_side = truncation_side ids = self.tokenizer( prefix, max_length=max_length, padding=padding, truncation=truncation, )["input_ids"] return ids # truncate the prompt prompt = deepcopy(prompt) if max_length is not None: match truncation: case "left": ids = _encode( prompt, max_length=max_length, truncation=True, truncation_side="left", ) case "right": ids = _encode( prompt, max_length=max_length, truncation=True, truncation_side="right", ) case "history": pre_ids = _encode(prompt) while len(pre_ids) > max_length: if len(prompt.history) > 1: prompt.pop_history(0) prompt.pop_history(0) else: raise ValueError( "Unable to truncate the prompt using `history` strategy." ) pre_ids = _encode(prompt) ids = pre_ids case "demo": pre_ids = _encode(prompt) while len(pre_ids) > max_length: if len(prompt.demonstrations) > 0: prompt.pop_demonstration(0) else: raise ValueError( "Unable to truncate the prompt using `demo` strategy." ) case "auto": pre_ids = _encode(prompt) while len(pre_ids) > max_length: if len(prompt.history) > 1: prompt.pop_history(0) prompt.pop_history(0) pre_ids = _encode(prompt) elif len(prompt.demonstrations) > 0: prompt.pop_demonstration(0) pre_ids = _encode(prompt) else: pre_ids = _encode( prompt, truncation=True, truncation_side="left" ) ids = pre_ids case _: raise ValueError("Unsupported truncation strategy.") else: ids = _encode(prompt) return ids
# register default templates Llama3Template = partial( HFTemplate, sys_prompt="You are a pirate chatbot who always responds in pirate speak!", ) Qwen2Template = partial( HFTemplate, sys_prompt="You are a helpful assistant.", ) Phi3Template = HFTemplate
[文档] def load_template( tokenizer: PreTrainedTokenizer, model_name: Optional[str] = None, ) -> ChatTemplate: """ Load ChatTemplate for different models. If model_name is not provided, the default template in the Tokenizer will be used. :param tokenizer: The tokenizer used to encode the prompt. :param model_name: The name of the model. Default is None. :type tokenizer: PreTrainedTokenizer :type model_name: Optional[str] :return: The loaded ChatTemplate :rtype: ChatTemplate """ if model_name is None: logger.warning("model_name is not provided, using default template.") return HFTemplate(tokenizer=tokenizer) if "Phi-3" in model_name: return Phi3Template(tokenizer=tokenizer) elif "Meta-Llama-3" in model_name: return Llama3Template(tokenizer=tokenizer) elif "Qwen/Qwen2" in model_name: return Qwen2Template(tokenizer=tokenizer) raise ValueError(f"Unsupported architecture: {model_name}")