# Source code for flexrag.models.model_base

from abc import ABC, abstractmethod
from dataclasses import field
from typing import Optional

import numpy as np
from PIL.Image import Image

from flexrag.prompt import ChatPrompt, MultiModelChatPrompt
from flexrag.utils import LOGGER_MANAGER, Register, SimpleProgressLogger, configure

# Module-level logger obtained from flexrag's LOGGER_MANAGER under the
# "flexrag.models" name; the base classes below use it to warn when an async
# method falls back to its synchronous implementation.
logger = LOGGER_MANAGER.get_logger("flexrag.models")


@configure
class GenerationConfig:
    """Settings that control how text is generated.

    :param do_sample: Use sampling during generation. Defaults to True.
    :type do_sample: bool
    :param sample_num: How many samples to produce per input. Defaults to 1.
    :type sample_num: int
    :param temperature: Temperature of the sampling distribution. Defaults to 1.0.
    :type temperature: float
    :param max_new_tokens: Upper bound on the number of generated tokens.
        Defaults to 512.
    :type max_new_tokens: int
    :param top_p: Cumulative probability mass kept for nucleus sampling.
        Defaults to 0.9.
    :type top_p: float
    :param top_k: Number of candidate tokens kept for top-k sampling.
        Defaults to 50.
    :type top_k: int
    :param eos_token_id: Id of the end-of-sentence token. Defaults to None.
    :type eos_token_id: Optional[int]
    :param stop_str: Strings that stop generation. Defaults to [].
    :type stop_str: list[str]
    """

    do_sample: bool = True
    sample_num: int = 1
    temperature: float = 1.0
    max_new_tokens: int = 512
    top_p: float = 0.9
    top_k: int = 50
    eos_token_id: Optional[int] = None
    # A fresh list per instance, so configs never share stop strings.
    stop_str: list[str] = field(default_factory=list)

    def __post_init__(self):
        # Validate the settings right after construction; any violation
        # raises AssertionError carrying a descriptive message.
        assert self.sample_num > 0, "sample_num must be greater than 0"
        if self.sample_num > 1:
            # Several samples per input only make sense when sampling is on.
            assert self.do_sample, "do_sample must be True when sample_num > 1"
        assert 0 <= self.temperature, "temperature must be greater than or equal to 0"
        assert 0 < self.max_new_tokens, "max_new_tokens must be greater than 0"
        assert self.top_p >= 0 and self.top_p <= 1, "top_p must be between 0 and 1"
        assert 0 < self.top_k, "top_k must be greater than 0"
class GeneratorBase(ABC):
    """Abstract base class for text generators.

    Subclasses implement :meth:`_chat` and :meth:`_generate`. The public
    :meth:`chat` and :meth:`generate` methods normalize their inputs into a
    batch before delegating to them.
    """

    def chat(
        self,
        prompts: list[ChatPrompt] | list[list[dict]] | ChatPrompt | list[dict],
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """Chat with the model using model templates.

        :param prompts: A single prompt or a batch of prompts. Each prompt is
            either a ``ChatPrompt`` or a list of message dicts, which is
            converted with ``ChatPrompt.from_list``. The caller's list is not
            modified.
        :type prompts: list[ChatPrompt] | list[list[dict]] | ChatPrompt | list[dict]
        :param generation_config: GenerationConfig. Defaults to a fresh
            ``GenerationConfig()`` created for this call.
        :type generation_config: Optional[GenerationConfig]
        :return: A batch of chat responses. An empty batch returns ``[]``
            without calling the model.
        :rtype: list[list[str]]
        """
        if generation_config is None:
            # Build the default per call so no config object is shared
            # between calls.
            generation_config = GenerationConfig()
        # A bare ChatPrompt, or a single list of message dicts, is one prompt:
        # wrap it into a batch. The length guard avoids IndexError on [].
        if isinstance(prompts, ChatPrompt) or (
            len(prompts) > 0 and isinstance(prompts[0], dict)
        ):
            prompts = [prompts]
        # Build a new list rather than converting in place, so the caller's
        # list of message lists is left untouched.
        batch = [
            ChatPrompt.from_list(prompt) if isinstance(prompt, list) else prompt
            for prompt in prompts
        ]
        if not batch:
            return []
        return self._chat(batch, generation_config=generation_config)

    @abstractmethod
    def _chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Run chat generation on an already-normalized batch of ChatPrompts."""
        return

    async def async_chat(
        self,
        prompts: list[ChatPrompt] | list[list[dict]] | ChatPrompt | list[dict],
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """The async version of chat.

        This default implementation logs a warning and runs :meth:`chat`
        synchronously; subclasses may override it with a truly async version.
        """
        logger.warning(
            "Current generator does not support asynchronous chat, thus the code will be run in synchronous mode"
        )
        return self.chat(prompts=prompts, generation_config=generation_config)

    def generate(
        self,
        prefixes: list[str] | str,
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """Generate text with the model using the given prefixes.

        :param prefixes: A single prefix or a batch of prefixes.
        :type prefixes: list[str] | str
        :param generation_config: GenerationConfig. Defaults to a fresh
            ``GenerationConfig()`` created for this call.
        :type generation_config: Optional[GenerationConfig]
        :return: A batch of generated text.
        :rtype: list[list[str]]
        """
        if generation_config is None:
            generation_config = GenerationConfig()
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        return self._generate(prefixes, generation_config=generation_config)

    @abstractmethod
    def _generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate text for an already-normalized batch of prefixes."""
        return

    async def async_generate(
        self,
        prefixes: list[str] | str,
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """The async version of generate.

        This default implementation logs a warning and runs :meth:`generate`
        synchronously.
        """
        logger.warning(
            "Current generator does not support asynchronous generate, thus the code will be run in synchronous mode"
        )
        return self.generate(prefixes=prefixes, generation_config=generation_config)
class VLMGeneratorBase(GeneratorBase):
    """Abstract base class for vision-language generators.

    Behaves like ``GeneratorBase``, except that chat prompts are
    ``MultiModelChatPrompt`` objects and :meth:`generate` also takes images.
    """

    def chat(
        self,
        prompts: (
            list[MultiModelChatPrompt]
            | list[list[dict]]
            | MultiModelChatPrompt
            | list[dict]
        ),
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """Chat with the model using model templates.

        :param prompts: A single prompt or a batch of prompts. Each prompt is
            either a ``MultiModelChatPrompt`` or a list of message dicts, which
            is converted with ``MultiModelChatPrompt.from_list``. The caller's
            list is not modified.
        :type prompts: list[MultiModelChatPrompt] | list[list[dict]] | MultiModelChatPrompt | list[dict]
        :param generation_config: GenerationConfig. Defaults to a fresh
            ``GenerationConfig()`` created for this call.
        :type generation_config: Optional[GenerationConfig]
        :return: A batch of chat responses. An empty batch returns ``[]``
            without calling the model.
        :rtype: list[list[str]]
        """
        if generation_config is None:
            # Build the default per call so no config object is shared
            # between calls.
            generation_config = GenerationConfig()
        # A bare prompt object, or a single list of message dicts, is one
        # prompt: wrap it into a batch. The length guard avoids IndexError on [].
        if isinstance(prompts, MultiModelChatPrompt) or (
            len(prompts) > 0 and isinstance(prompts[0], dict)
        ):
            prompts = [prompts]
        # Build a new list rather than converting in place, so the caller's
        # list of message lists is left untouched.
        batch = [
            MultiModelChatPrompt.from_list(prompt)
            if isinstance(prompt, list)
            else prompt
            for prompt in prompts
        ]
        if not batch:
            return []
        return self._chat(batch, generation_config=generation_config)

    @abstractmethod
    def _chat(
        self,
        prompts: list[MultiModelChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Run chat generation on an already-normalized batch of prompts."""
        return

    async def async_chat(
        self,
        prompts: (
            list[MultiModelChatPrompt]
            | list[list[dict]]
            | MultiModelChatPrompt
            | list[dict]
        ),
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """The async version of chat.

        This default implementation logs a warning and runs :meth:`chat`
        synchronously.
        """
        logger.warning(
            "Current generator does not support asynchronous chat, thus the code will be run in synchronous mode"
        )
        return self.chat(prompts=prompts, generation_config=generation_config)

    def generate(
        self,
        prefixes: list[str] | str,
        images: list[Image] | Image,
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """Generate text with the model using the given prefixes and images.

        :param prefixes: A single prefix or a batch of prefixes.
        :param images: A single image or a batch of images.
        :param generation_config: GenerationConfig. Defaults to a fresh
            ``GenerationConfig()`` created for this call.
        :type prefixes: list[str] | str
        :type images: list[Image] | Image
        :type generation_config: Optional[GenerationConfig]
        :return: A batch of generated text.
        :rtype: list[list[str]]
        """
        if generation_config is None:
            generation_config = GenerationConfig()
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        # Mirror the prefix handling: a lone image becomes a batch of one.
        if isinstance(images, Image):
            images = [images]
        return self._generate(prefixes, images, generation_config=generation_config)

    @abstractmethod
    def _generate(
        self,
        prefixes: list[str],
        images: list[Image],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate text for already-normalized batches of prefixes and images."""
        return

    async def async_generate(
        self,
        prefixes: list[str] | str,
        images: list[Image] | Image,
        generation_config: Optional[GenerationConfig] = None,
    ) -> list[list[str]]:
        """The async version of generate.

        This default implementation logs a warning and runs :meth:`generate`
        synchronously.
        """
        logger.warning(
            "Current generator does not support asynchronous generate, thus the code will be run in synchronous mode"
        )
        return self.generate(
            prefixes=prefixes, images=images, generation_config=generation_config
        )
@configure
class EncoderBaseConfig:
    """Settings shared by all encoders.

    :param batch_size: Number of texts encoded per batch. Defaults to 32.
    :type batch_size: int
    :param log_interval: Progress-logging interval used by batched encoding.
        Defaults to 1000.
    :type log_interval: int
    """

    batch_size: int = 32
    log_interval: int = 1000
[文档] class EncoderBase(ABC): def __init__(self, cfg: EncoderBaseConfig): super().__init__() self.batch_size = cfg.batch_size self.log_interval = cfg.log_interval return
[文档] def encode(self, texts: list[str] | str) -> np.ndarray: """Encode the given texts into embeddings. :param texts: A batch of texts. :type texts: list[str] | str :return: A batch of embeddings. :rtype: np.ndarray """ if not isinstance(texts, list): texts = [texts] return self._encode(texts)
[文档] def encode_batch( self, texts: list[str] | str, batch_size: int = None ) -> np.ndarray: """Encode the given texts into embeddings in batches. :param texts: A batch of texts. :type texts: list[str] | str :param batch_size: The size of each batch. Defaults to self.batch_size. :type batch_size: int :return: A batch of embeddings. :rtype: np.ndarray """ batch_size = batch_size or self.batch_size if not isinstance(texts, list): texts = [texts] # prepare progress logger if (len(texts) > self.log_interval) and (len(texts) > batch_size): p_logger = SimpleProgressLogger(logger, len(texts), self.log_interval) else: p_logger = None # encode embeddings = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i : i + batch_size] embeddings.append(self._encode(batch_texts)) if p_logger is not None: p_logger.update(len(batch_texts), desc="Encoding") return np.concatenate(embeddings, axis=0)
@abstractmethod def _encode(self, texts: list[str]) -> np.ndarray: return
[文档] async def async_encode(self, texts: list[str]) -> np.ndarray: """The async version of encode.""" logger.warning( "Current encoder does not support asyncronous encode, thus the code will be run in syncronous mode" ) return self.encode(texts)
@property @abstractmethod def embedding_size(self) -> int: return
# Project registry (flexrag.utils.Register) for GeneratorBase implementations.
GENERATORS = Register[GeneratorBase]("generator")

# Project registry (flexrag.utils.Register) for EncoderBase implementations.
ENCODERS = Register[EncoderBase]("encoder")