flexrag.models.vllm_model 源代码

import asyncio
from typing import Annotated, Optional

from transformers import AutoConfig, PretrainedConfig

from flexrag.prompt import ChatPrompt, load_template
from flexrag.utils import LOGGER_MANAGER, TIME_METER, Choices, configure

from .model_base import GENERATORS, GenerationConfig, GeneratorBase
from .utils import guess_model_name

logger = LOGGER_MANAGER.get_logger("flexrag.models.vllm")


@configure
class VLLMGeneratorConfig:
    """Settings consumed by ``VLLMGenerator``.

    :param model_path: Path to the model. Must be set: ``VLLMGenerator``
        asserts that it is not ``None`` even though the default is ``None``.
    :type model_path: Optional[str]
    :param gpu_memory_utilization: Fraction of GPU memory to use.
        Defaults to 0.85.
    :type gpu_memory_utilization: float
    :param max_model_len: Maximum length of the model. ``VLLMGenerator`` uses
        the smaller of this value and the model config's
        ``max_position_embeddings`` (when that attribute exists).
        Defaults to 16384.
    :type max_model_len: int
    :param tensor_parallel: The number of tensor parallel. Defaults to 1.
    :type tensor_parallel: int
    :param load_dtype: The dtype to load the model. Defaults to "auto".
        Available options are "auto", "float32", "float16", "bfloat16".
    :type load_dtype: str
    :param use_minference: Whether to use minference for Long Sequence
        Inference. Defaults to False.
    :type use_minference: bool
    :param trust_remote_code: Passed through as ``trust_remote_code`` both to
        ``AutoConfig.from_pretrained`` and to the vLLM ``LLM`` constructor.
        Defaults to False.
    :type trust_remote_code: bool
    """

    model_path: Optional[str] = None
    gpu_memory_utilization: float = 0.85
    max_model_len: int = 16384
    tensor_parallel: int = 1
    load_dtype: Annotated[str, Choices("auto", "float32", "float16", "bfloat16")] = "auto"
    use_minference: bool = False
    trust_remote_code: bool = False
@GENERATORS("vllm", config_class=VLLMGeneratorConfig)
class VLLMGenerator(GeneratorBase):
    """Generator that serves completions from a local vLLM ``LLM`` engine."""

    def __init__(self, cfg: VLLMGeneratorConfig) -> None:
        """Load the model, its tokenizer and the chat template.

        :param cfg: Generator configuration. ``cfg.model_path`` must be set.
        :type cfg: VLLMGeneratorConfig
        :raises AssertionError: If ``cfg.model_path`` is ``None``.
        """
        # vllm is imported lazily so the module can be imported without it.
        from vllm import LLM

        # try to load model arguments from model config
        assert cfg.model_path is not None, "`model_path` must be provided"
        model_cfg: PretrainedConfig = AutoConfig.from_pretrained(
            cfg.model_path,
            trust_remote_code=cfg.trust_remote_code,
        )
        model_name = guess_model_name(model_cfg)
        # Never request a longer context than the model config advertises;
        # fall back to the configured limit when the attribute is missing.
        max_length = min(
            getattr(model_cfg, "max_position_embeddings", cfg.max_model_len),
            cfg.max_model_len,
        )

        # load model
        self.model = LLM(
            cfg.model_path,
            dtype=str(cfg.load_dtype),
            gpu_memory_utilization=cfg.gpu_memory_utilization,
            tensor_parallel_size=cfg.tensor_parallel,
            max_model_len=max_length,
            trust_remote_code=cfg.trust_remote_code,
            # Eager execution is requested only when MInference is enabled.
            enforce_eager=bool(cfg.use_minference),
        )
        self.tokenizer = self.model.get_tokenizer()
        self.template = load_template(model_name=model_name, tokenizer=self.tokenizer)

        # load minference
        if cfg.use_minference:
            # Best effort: any failure (missing package, unsupported model,
            # patch error) leaves the unpatched engine in place.
            try:
                from minference import MInference

                inf_patch = MInference("vllm", model_name)
                self.model = inf_patch(self.model)
            except Exception as e:
                logger.warning(f"Unable to load minference: {e}")
                logger.warning("Fallback to normal mode.")

    @staticmethod
    def _to_prefix_list(prefixes):
        """Wrap a non-list ``prefixes`` value into a one-element list."""
        return prefixes if isinstance(prefixes, list) else [prefixes]

    @staticmethod
    def _outputs_to_text(responses) -> list[list[str]]:
        """Collect the ``text`` of every output of every vLLM response."""
        return [[out.text for out in resp.outputs] for resp in responses]

    @TIME_METER("vllm_generate")
    def _generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate continuations for raw text prefixes.

        :param prefixes: Prompts to complete. A non-list value is treated as
            a single prompt.
        :type prefixes: list[str]
        :param generation_config: Sampling options.
        :type generation_config: GenerationConfig
        :return: One inner list per prefix holding the text of each output
            vLLM returned for it.
        :rtype: list[list[str]]
        """
        responses = self.model.generate(
            self._to_prefix_list(prefixes),
            sampling_params=self._get_options(generation_config),
            use_tqdm=False,
        )
        return self._outputs_to_text(responses)

    async def async_generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Asynchronous counterpart of :meth:`_generate`.

        The blocking ``LLM.generate`` call is run through
        ``asyncio.to_thread`` so it does not block the event loop.

        :param prefixes: Prompts to complete. A non-list value is treated as
            a single prompt.
        :type prefixes: list[str]
        :param generation_config: Sampling options.
        :type generation_config: GenerationConfig
        :return: One inner list per prefix holding the text of each output
            vLLM returned for it.
        :rtype: list[list[str]]
        """
        responses = await asyncio.to_thread(
            self.model.generate,
            self._to_prefix_list(prefixes),
            sampling_params=self._get_options(generation_config),
            use_tqdm=False,
        )
        return self._outputs_to_text(responses)

    def _chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Render chat prompts with the loaded template and generate replies.

        :param prompts: Chat prompts to answer.
        :type prompts: list[ChatPrompt]
        :param generation_config: Sampling options.
        :type generation_config: GenerationConfig
        :return: Whatever ``self.generate`` returns for the rendered prompts.
        :rtype: list[list[str]]
        """
        prefixes = [self.template.render_to_text(prompt) for prompt in prompts]
        # NOTE(review): ``generate`` is not defined in this class; presumably
        # it is the public wrapper inherited from GeneratorBase -- confirm.
        return self.generate(prefixes, generation_config)

    async def async_chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Asynchronous counterpart of :meth:`_chat`.

        :param prompts: Chat prompts to answer.
        :type prompts: list[ChatPrompt]
        :param generation_config: Sampling options.
        :type generation_config: GenerationConfig
        :return: The result of :meth:`async_generate` on the rendered prompts.
        :rtype: list[list[str]]
        """
        prefixes = [self.template.render_to_text(prompt) for prompt in prompts]
        return await self.async_generate(prefixes, generation_config)

    def _get_options(self, generation_config: GenerationConfig):
        """Translate a ``GenerationConfig`` into vLLM ``SamplingParams``.

        The stop token is ``generation_config.eos_token_id`` when set,
        otherwise the tokenizer's ``eos_token_id``. If neither is defined, no
        stop token ids are passed at all.

        :param generation_config: Sampling options.
        :type generation_config: GenerationConfig
        :return: The equivalent ``vllm.SamplingParams`` instance.
        """
        from vllm import SamplingParams

        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        # Avoid handing ``[None]`` to vLLM when the tokenizer has no EOS id.
        stop_token_ids = [] if eos_token_id is None else [eos_token_id]

        return SamplingParams(
            n=generation_config.sample_num,
            max_tokens=generation_config.max_new_tokens,
            temperature=generation_config.temperature,
            top_k=generation_config.top_k,
            top_p=generation_config.top_p,
            stop_token_ids=stop_token_ids,
            stop=generation_config.stop_str,
        )