# Source code for flexrag.models.hf_model

import asyncio
from dataclasses import field
from typing import Annotated, Optional

import numpy as np
import torch
import torch.nn.functional as F
from PIL.Image import Image
from PIL.ImageFile import ImageFile
from torch.nn.parallel import DataParallel as DP
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BertModel,
    BertPreTrainedModel,
    CLIPModel,
)
from transformers import GenerationConfig as HFGenerationConfig
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    XLMRobertaModel,
    XLMRobertaPreTrainedModel,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from flexrag.prompt import ChatPrompt, MultiModelChatPrompt, load_template
from flexrag.utils import LOGGER_MANAGER, TIME_METER, Choices, configure

from .model_base import (
    ENCODERS,
    GENERATORS,
    EncoderBase,
    EncoderBaseConfig,
    GenerationConfig,
    GeneratorBase,
    VLMGeneratorBase,
)
from .utils import configure_attn, guess_model_name

logger = LOGGER_MANAGER.get_logger("flexrag.models.hf_model")


def get_colbert_model(
    base_model: str = "bert",
    output_dim: int = 128,
    model_path: str = None,
):
    """Code adapted from https://github.com/hotchpotch/JQaRA/blob/main/evaluator/reranker/colbert_reranker.py"""
    match base_model:
        case "bert":
            pretrained_class = BertPreTrainedModel
            model_class = BertModel
        case "xlm-roberta":
            pretrained_class = XLMRobertaPreTrainedModel
            model_class = XLMRobertaModel
        case "self_implemented":
            model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
            assert "AutoModel" in model_cfg.auto_map
            model_class_str = model_cfg.auto_map["AutoModel"]
            pretrained_class_str = model_class_str.replace("Model", "PreTrainedModel")
            model_class = get_class_from_dynamic_module(model_class_str, model_path)
            pretrained_class = get_class_from_dynamic_module(
                pretrained_class_str, model_path
            )
        case _:
            raise ValueError(f"Unsupported base model: {base_model}")

    class ColBERTModel(pretrained_class):
        def __init__(self, config):
            super().__init__(config)
            setattr(self, self.base_model_prefix, model_class(config))
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=False)
            self.init_weights()
            return

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
        ):
            outputs = getattr(self, self.base_model_prefix)(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=True,  # Always output hidden states
            )

            sequence_output = outputs[0]
            return self.linear(sequence_output)

    return ColBERTModel


def load_hf_model(
    model_path: str,
    tokenizer_path: Optional[str] = None,
    model_type: Optional[str] = None,
    device_id: list[int] = [],
    load_dtype: str = "auto",
    trust_remote_code: bool = False,
    pipeline_parallel: bool = False,
    is_training: bool = False,
    colbert_base_model: str = "bert",
    colbert_dim: int = 128,
    other_model_kwargs: dict = {},
    other_tokenizer_kwargs: dict = {},
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    # prepare dtype
    load_in_4bit = False
    load_in_8bit = False
    match load_dtype:
        case "bfloat16":
            load_dtype = torch.bfloat16
        case "bf16":
            load_dtype = torch.bfloat16
        case "float32":
            load_dtype = torch.float32
        case "fp32":
            load_dtype = torch.float32
        case "float16":
            load_dtype = torch.float16
        case "fp16":
            load_dtype = torch.float16
        case "half":
            load_dtype = torch.float16
        case "8bit":
            load_dtype = None
            load_in_8bit = True
        case "4bit":
            load_dtype = None
            load_in_4bit = True
        case "auto":
            load_dtype = "auto"
        case _:
            raise ValueError(f"Unsupported load_dtype: {load_dtype}")

    # prepare device
    if pipeline_parallel:
        device_map = "auto"
    elif torch.cuda.is_available() and (len(device_id) > 0):
        device_map = device_id[0]
    else:
        device_map = None

    # configure attention implementation
    attn_args = configure_attn(
        model_path=model_path,
        device_id=device_id,
        load_dtype=load_dtype,
        trust_remote_code=trust_remote_code,
    )

    # load model
    match model_type:
        case "causal_lm":
            model_class = AutoModelForCausalLM
        case "seq2seq":
            model_class = AutoModelForSeq2SeqLM
        case "sequence_classification":
            model_class = AutoModelForSequenceClassification
        case "colbert":
            model_class = get_colbert_model(colbert_base_model, colbert_dim, model_path)
        case "masked_lm":
            model_class = AutoModelForMaskedLM
        case "auto":
            model_class = AutoModel
        case "clip":
            model_class = AutoModel
        case "vlm":
            model_class = AutoModelForVision2Seq
        case _:
            model_class = AutoModel
    model = model_class.from_pretrained(
        model_path,
        device_map=device_map,
        torch_dtype=load_dtype,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        trust_remote_code=trust_remote_code,
        **other_model_kwargs,
        **attn_args,
    )

    # patch: some model does not support `int` device_map
    if isinstance(device_map, int):
        model = model.to(torch.device(device_map))

    if not is_training:
        model.eval()

    # load tokenizer
    if tokenizer_path is not None:
        tokenizer_path = tokenizer_path
    else:
        tokenizer_path = model_path
    match model_type:
        case "clip":
            tokenizer = (
                AutoTokenizer.from_pretrained(
                    tokenizer_path,
                    trust_remote_code=trust_remote_code,
                    **other_tokenizer_kwargs,
                ),
                AutoImageProcessor.from_pretrained(
                    tokenizer_path,
                    trust_remote_code=trust_remote_code,
                    **other_tokenizer_kwargs,
                ),
            )
        case "vlm":
            tokenizer = AutoProcessor.from_pretrained(
                tokenizer_path,
                trust_remote_code=trust_remote_code,
                **other_tokenizer_kwargs,
            )
        case _:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path,
                trust_remote_code=trust_remote_code,
                **other_tokenizer_kwargs,
            )
    return model, tokenizer


@configure
class HFModelConfig:
    """Shared configuration for the Hugging Face backed models.

    This is the base of the configs for ``HFGenerator``, ``HFVLMGenerator``,
    ``HFEncoder`` and ``HFClipEncoder``.

    :param model_path: Path of the model. Required.
    :type model_path: str
    :param tokenizer_path: Path of the tokenizer. None reuses ``model_path``.
        Default is None.
    :type tokenizer_path: Optional[str]
    :param trust_remote_code: Whether remote code may be executed while
        loading. Default is False.
    :type trust_remote_code: bool
    :param device_id: Ids of the devices to use. An empty list means CPU.
        Default is [].
    :type device_id: list[int]
    :param load_dtype: The dtype used to load the model. Default is "auto".
        Available choices are "bfloat16", "bf16", "float32", "fp32",
        "float16", "fp16", "half", "8bit", "4bit" and "auto".
    :type load_dtype: str
    """

    model_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
    trust_remote_code: bool = False
    device_id: list[int] = field(default_factory=list)
    load_dtype: Annotated[
        str,
        Choices(
            "bfloat16",
            "bf16",
            "float32",
            "fp32",
            "float16",
            "fp16",
            "half",
            "8bit",
            "4bit",
            "auto",
        ),
    ] = "auto"
@configure
class HFGeneratorConfig(HFModelConfig):
    """Configuration of ``HFGenerator``.

    :param pipeline_parallel: Whether to shard the model across the given
        devices with pipeline parallelism. Default is False.
    :type pipeline_parallel: bool
    :param use_minference: Whether to patch the model with MInference for
        long-sequence inference. Default is False.
    :type use_minference: bool
    :param model_type: Architecture of the model. Either "causal_lm" or
        "seq2seq". Default is "causal_lm".
    :type model_type: str
    """

    pipeline_parallel: bool = False
    use_minference: bool = False
    model_type: Annotated[str, Choices("causal_lm", "seq2seq")] = "causal_lm"
@GENERATORS("hf", config_class=HFGeneratorConfig)
class HFGenerator(GeneratorBase):
    """Text generator backed by a Hugging Face causal LM or seq2seq model."""

    model: PreTrainedModel

    def __init__(self, cfg: HFGeneratorConfig) -> None:
        # load model
        self.model, self.tokenizer = load_hf_model(
            model_path=cfg.model_path,
            tokenizer_path=cfg.tokenizer_path,
            model_type=cfg.model_type,
            device_id=cfg.device_id,
            load_dtype=cfg.load_dtype,
            trust_remote_code=cfg.trust_remote_code,
            pipeline_parallel=cfg.pipeline_parallel,
        )
        # `_patch_model` reads `self.model_type`, so set it first.
        self.model_type = cfg.model_type
        self._patch_model()

        # prepare prompt function
        model_name = guess_model_name(self.model.config)
        self.template = load_template(model_name=model_name, tokenizer=self.tokenizer)

        # load minference
        if cfg.use_minference:
            assert (
                not cfg.pipeline_parallel
            ), "Minference does not support pipeline parallel"
            from minference import MInference

            try:
                inf_patch = MInference("minference", model_name)
                self.model = inf_patch(self.model)
            except Exception as e:
                # Best effort: keep the unpatched model if patching fails.
                logger.warning(f"Unable to load minference: {e}")
        return

    @TIME_METER("hf_generate")
    @torch.no_grad()
    def _generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate ``sample_num`` continuations for every prefix.

        :param prefixes: The input texts.
        :param generation_config: Sampling / decoding options.
        :return: For each prefix, a list of ``sample_num`` decoded responses.
        :raises ValueError: If ``self.model_type`` is neither "causal_lm" nor
            "seq2seq".
        """
        bsz = len(prefixes)
        sample_num = generation_config.sample_num
        inputs = self.tokenizer(
            prefixes, return_tensors="pt", padding=True, truncation=True
        )
        inputs = inputs.to(self.model.device)

        # prepare generation config
        hf_gen_cfg = self._get_options(generation_config)
        if generation_config.eos_token_id is not None:
            eos_token_id = generation_config.eos_token_id
        else:
            eos_token_id = self.tokenizer.eos_token_id

        # generate
        outputs = self.model.generate(
            **inputs,
            eos_token_id=eos_token_id,
            generation_config=hf_gen_cfg,
            tokenizer=self.tokenizer,  # for stop_strings
        )
        # `generate` returns `num_return_sequences` rows per input, consecutively.
        outputs = outputs.view(bsz, sample_num, -1)

        # truncate the input tokens
        if self.model_type == "causal_lm":
            # Decoder-only `generate` returns the *padded* prompt followed by
            # the new tokens, so every row shares the same prompt width.
            # Cutting at the un-padded length would keep part of the prompt
            # for shorter, left-padded inputs.
            prompt_width = inputs["input_ids"].shape[1]
            outputs = outputs[:, :, prompt_width:]
        elif self.model_type != "seq2seq":
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        return [
            [
                self.tokenizer.decode(sample, skip_special_tokens=True)
                for sample in samples
            ]
            for samples in outputs
        ]

    async def async_generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        # Run the blocking generation in a worker thread.
        return await asyncio.to_thread(
            self.generate,
            prefixes=prefixes,
            generation_config=generation_config,
        )

    def _chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        assert self.template is not None, "Chat function is disabled."
        prefixes = [self.template.render_to_text(prompt) for prompt in prompts]
        return self.generate(prefixes, generation_config)

    async def async_chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        # Run the blocking chat in a worker thread.
        return await asyncio.to_thread(
            self.chat,
            prompts=prompts,
            generation_config=generation_config,
        )

    def _get_options(self, generation_config: GenerationConfig) -> HFGenerationConfig:
        """Translate a flexrag ``GenerationConfig`` into a HF ``GenerationConfig``."""
        cfg = HFGenerationConfig(
            do_sample=generation_config.do_sample,
            temperature=generation_config.temperature,
            max_new_tokens=generation_config.max_new_tokens,
            top_p=generation_config.top_p,
            top_k=generation_config.top_k,
            num_return_sequences=generation_config.sample_num,
        )
        if generation_config.stop_str:  # empty list is not allowed
            cfg.stop_strings = list(generation_config.stop_str)
        return cfg

    def _patch_model(self) -> None:
        """Prepare the tokenizer for batched generation.

        If the tokenizer has no pad token, this adds a ``<pad>`` special token
        and resizes the model embeddings. Decoder-only inputs are padded on
        the left, so every prompt ends right before the generated tokens.
        """
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
            self.model.resize_token_embeddings(len(self.tokenizer))
        if self.model_type == "causal_lm":
            self.tokenizer.padding_side = "left"
        return
@configure
class HFVLMGeneratorConfig(HFModelConfig):
    """Configuration of ``HFVLMGenerator``.

    :param pipeline_parallel: Whether to shard the model across the given
        devices with pipeline parallelism. Default is False.
    :type pipeline_parallel: bool
    """

    pipeline_parallel: bool = False
@GENERATORS("hf_vlm", config_class=HFVLMGeneratorConfig)
class HFVLMGenerator(VLMGeneratorBase):
    """Vision-language generator backed by a Hugging Face Vision2Seq model."""

    model: PreTrainedModel

    def __init__(self, cfg: HFVLMGeneratorConfig) -> None:
        # load model
        self.model, self.processor = load_hf_model(
            model_path=cfg.model_path,
            tokenizer_path=cfg.tokenizer_path,
            model_type="vlm",
            device_id=cfg.device_id,
            load_dtype=cfg.load_dtype,
            trust_remote_code=cfg.trust_remote_code,
            pipeline_parallel=cfg.pipeline_parallel,
        )
        return

    @TIME_METER("hf_generate")
    @torch.no_grad()
    def _generate(
        self,
        prefixes: list[str],
        images: list[Image],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate ``sample_num`` continuations for every (text, images) input.

        :param prefixes: Already-templated input texts.
        :param images: Images passed to the processor alongside ``prefixes``.
        :param generation_config: Sampling / decoding options.
        :return: For each prefix, a list of ``sample_num`` decoded responses.
        """
        bsz = len(prefixes)
        sample_num = generation_config.sample_num
        inputs = self.processor(
            text=prefixes,
            images=images,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        )
        inputs = inputs.to(self.model.device)
        # Match the image tensor to the model's dtype (e.g. bf16 / fp16).
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.dtype)

        # prepare generation config
        hf_gen_cfg = self._get_options(generation_config)
        extra_kwargs = {}
        if generation_config.eos_token_id is not None:
            extra_kwargs["eos_token_id"] = generation_config.eos_token_id
        if generation_config.stop_str:
            # `stop_strings` needs a tokenizer to be handed to `generate`.
            extra_kwargs["tokenizer"] = getattr(
                self.processor, "tokenizer", self.processor
            )

        # generate
        outputs = self.model.generate(
            **inputs, generation_config=hf_gen_cfg, **extra_kwargs
        )

        # truncate the input tokens: the returned rows start with the full
        # (padded) prompt, so cut at the padded width rather than at each
        # prompt's un-padded length.
        outputs = outputs.view(bsz, sample_num, -1)
        prompt_width = inputs["input_ids"].shape[1]
        return [
            [
                self.processor.decode(sample, skip_special_tokens=True)
                for sample in samples[:, prompt_width:]
            ]
            for samples in outputs
        ]

    def _chat(
        self,
        prompts: list[MultiModelChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        input_texts = [
            self.processor.apply_chat_template(p.to_list(), add_generation_prompt=True)
            for p in prompts
        ]
        images = [p.images for p in prompts]
        return self.generate(input_texts, images, generation_config)

    def _get_options(self, generation_config: GenerationConfig) -> HFGenerationConfig:
        """Translate a flexrag ``GenerationConfig`` into a HF ``GenerationConfig``."""
        cfg = HFGenerationConfig(
            do_sample=generation_config.do_sample,
            temperature=generation_config.temperature,
            max_new_tokens=generation_config.max_new_tokens,
            top_p=generation_config.top_p,
            top_k=generation_config.top_k,
            num_return_sequences=generation_config.sample_num,
        )
        if generation_config.stop_str:  # empty list is not allowed
            cfg.stop_strings = list(generation_config.stop_str)
        return cfg
@configure
class HFEncoderConfig(HFModelConfig, EncoderBaseConfig):
    """Configuration of ``HFEncoder``.

    :param max_encode_length: Maximum length of the input sequence.
        Default is 512.
    :type max_encode_length: int
    :param encode_method: How token states are pooled into one embedding.
        "cls" takes the first token; "mean" takes the masked average.
        Default is "mean".
    :type encode_method: str
    :param normalize: Whether to L2-normalize the embeddings. Default is False.
    :type normalize: bool
    :param prompt: Prefix prepended to every input text. Default is "".
    :type prompt: str
    :param task: Task name forwarded to jina-embedding's ``encode``.
        Default is "".
    :type task: str
    """

    max_encode_length: int = 512
    encode_method: Annotated[str, Choices("cls", "mean")] = "mean"
    normalize: bool = False
    prompt: str = ""  # used in nomic-text-embedding
    task: str = ""  # used in jina-embedding
@ENCODERS("hf", config_class=HFEncoderConfig)
class HFEncoder(EncoderBase):
    """Text encoder that pools the last hidden states of a Hugging Face model."""

    def __init__(self, cfg: HFEncoderConfig):
        super().__init__(cfg)
        # load model
        self.model, self.tokenizer = load_hf_model(
            model_path=cfg.model_path,
            tokenizer_path=cfg.tokenizer_path,
            load_dtype=cfg.load_dtype,
            device_id=cfg.device_id,
            trust_remote_code=cfg.trust_remote_code,
        )

        # setup model: wrap in DataParallel when several devices are given,
        # except for jina models, which are encoded through their own `encode`.
        self.devices = cfg.device_id
        if len(self.devices) > 1:
            if self.is_jina:
                logger.warning("Data parallel does not support self implemented model.")
                self.dp_model = None
            else:
                self.dp_model = DP(self.model, device_ids=self.devices)
        else:
            self.dp_model = None

        # setup arguments
        self.max_encode_length = cfg.max_encode_length
        self.encode_method = cfg.encode_method
        self.normalize = cfg.normalize
        self.prompt = cfg.prompt
        self.task = cfg.task
        return

    def get_embedding(
        self, hidden: torch.Tensor, attn_mask: torch.Tensor
    ) -> np.ndarray:
        """Pool token states into one embedding per sequence.

        :param hidden: Last hidden states, indexed as (batch, seq, dim).
        :param attn_mask: Attention mask of shape (batch, seq); non-zero marks
            real tokens.
        :return: A float32 array of shape (batch, dim).
        :raises ValueError: If ``self.encode_method`` is not "mean" or "cls".
        """
        if self.encode_method == "mean":
            attn_mask = attn_mask.to(hidden.device)
            # Accumulate in float32: summing many low-precision (e.g. fp16)
            # states can overflow or lose precision.
            hidden = hidden.float()
            embeddings = hidden.masked_fill(~attn_mask[..., None].bool(), 0.0)
            embeddings = embeddings.sum(dim=1) / attn_mask.sum(dim=1)[..., None]
        elif self.encode_method == "cls":
            embeddings = hidden[:, 0]
        else:
            raise ValueError(f"Unsupported encode method: {self.encode_method}")
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)
        return embeddings.float().cpu().numpy()

    @TIME_METER("hf_encode")
    @torch.no_grad()
    def _encode(self, texts: list[str | list[str]]) -> np.ndarray:
        if self.is_jina:  # for jina-embedding
            return self.model.encode(
                texts,
                task=self.task,
                max_length=self.max_encode_length,
                batch_size=len(texts),
                show_progress_bar=False,
                convert_to_numpy=True,
            )

        # add prompt if needed
        if self.prompt:
            texts = [f"{self.prompt}{i}" for i in texts]

        # prepare encoder: only use DataParallel for batches that are large
        # enough (at least 8 inputs per device).
        if (len(texts) >= len(self.devices) * 8) and (self.dp_model is not None):
            encoder = self.dp_model
        else:
            encoder = self.model

        # encode (`__call__` replaces the deprecated `batch_encode_plus`)
        input_dict = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_encode_length,
            padding=True,
            truncation=True,
        )
        # TODO: This step is slow
        # DataParallel scatters the inputs itself; a plain model needs them
        # on its own device.
        if not isinstance(encoder, DP):
            input_dict = input_dict.to(encoder.device)
        mask = input_dict["attention_mask"]
        output = encoder(**input_dict).last_hidden_state
        return self.get_embedding(output, mask)

    async def async_encode(self, texts: list[str]) -> np.ndarray:
        # Run the blocking encoder in a worker thread.
        return await asyncio.to_thread(self.encode, texts)

    @property
    def embedding_size(self) -> int:
        return self.model.config.hidden_size

    @property
    def is_jina(self) -> bool:
        # Jina embedding models load as `XLMRobertaLoRA` and expose `encode`.
        return self.model.__class__.__name__ == "XLMRobertaLoRA" and hasattr(
            self.model, "encode"
        )
@configure
class HFClipEncoderConfig(HFModelConfig, EncoderBaseConfig):
    """Configuration of ``HFClipEncoder``.

    :param max_encode_length: Maximum length of the input text sequence.
        Default is 512.
    :type max_encode_length: int
    :param normalize: Whether to L2-normalize the embeddings. Default is False.
    :type normalize: bool
    :param convert_to_rgb: Whether to convert images to RGB before
        processing. Default is False.
    :type convert_to_rgb: bool
    """

    max_encode_length: int = 512
    normalize: bool = False
    convert_to_rgb: bool = False
@ENCODERS("hf_clip", config_class=HFClipEncoderConfig)
class HFClipEncoder(EncoderBase):
    """Text / image encoder backed by a CLIP-style Hugging Face model."""

    model: CLIPModel
    tokenizer: PreTrainedTokenizer

    def __init__(self, cfg: HFClipEncoderConfig):
        super().__init__(cfg)
        self.devices = cfg.device_id
        # load model (the "clip" loader returns a (tokenizer, image processor) pair)
        self.model, (self.tokenizer, self.processor) = load_hf_model(
            model_type="clip",
            model_path=cfg.model_path,
            tokenizer_path=cfg.tokenizer_path,
            load_dtype=cfg.load_dtype,
            device_id=cfg.device_id,
            trust_remote_code=cfg.trust_remote_code,
        )

        # setup arguments
        self.max_encode_length = cfg.max_encode_length
        self.normalize = cfg.normalize
        self.convert_to_rgb = cfg.convert_to_rgb
        return

    def _encode(self, data: list[str | Image]) -> np.ndarray:
        """Encode a homogeneous batch of texts or of PIL images.

        :raises TypeError: If the batch mixes texts and images or contains
            other objects.
        """
        if isinstance(data[0], str):
            if not all(isinstance(d, str) for d in data):
                raise TypeError("A batch must contain only texts or only images.")
            return self.encode_text(data)
        # Any PIL image is accepted (ImageFile is a subclass of Image).
        if not all(isinstance(d, Image) for d in data):
            raise TypeError("A batch must contain only texts or only images.")
        return self.encode_image(data)

    @TIME_METER("hf_clip_encode")
    @torch.no_grad()
    def encode_image(self, images: list[Image]) -> np.ndarray:
        if self.convert_to_rgb:
            images = [img.convert("RGB") for img in images]
        input_dict = self.processor(images=images, return_tensors="pt")
        input_dict = input_dict.to(self.model.device)
        embeddings = self.model.get_image_features(**input_dict)
        if self.normalize:
            embeddings = F.normalize(embeddings, dim=1)
        return embeddings.float().cpu().numpy()

    @TIME_METER("hf_clip_encode")
    @torch.no_grad()
    def encode_text(self, texts: list[str]) -> np.ndarray:
        # CLIP-like text towers have a fixed position table, whose size is
        # presumably reflected by the tokenizer's `model_max_length` (77 for
        # CLIP). Truncating beyond it would index past that table.
        max_length = min(self.max_encode_length, self.tokenizer.model_max_length)
        input_dict = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=max_length,
            padding=True,
            truncation=True,
        )
        input_dict = input_dict.to(self.model.device)
        embeddings = self.model.get_text_features(**input_dict)
        if self.normalize:
            embeddings = F.normalize(embeddings, dim=1)
        return embeddings.float().cpu().numpy()

    @property
    def embedding_size(self) -> int:
        # Prefer the joint projection size; fall back to the hidden size.
        if hasattr(self.model.config, "projection_dim"):
            return self.model.config.projection_dim
        if hasattr(self.model.config, "hidden_size"):
            return self.model.config.hidden_size
        raise ValueError("Cannot determine embedding size from model config.")