# Source code for flexrag.models.sentence_transformers_model
import math
from dataclasses import field
from typing import Any, Optional
import numpy as np
from flexrag.utils import TIME_METER, configure
from .model_base import ENCODERS, EncoderBase, EncoderBaseConfig
# [docs]
@configure
class SentenceTransformerEncoderConfig(EncoderBaseConfig):
    """Configuration for SentenceTransformerEncoder.

    :param model_path: Name or path of the model to load. Required: the encoder
        asserts that it is not None even though the field defaults to None.
    :type model_path: Optional[str]
    :param device_id: CUDA device ids to use. An empty list selects the CPU; the
        first id is used to load the model, and more than one id makes the
        encoder start a multi-process pool over all of them. Defaults to [].
    :type device_id: list[int]
    :param trust_remote_code: Whether to trust remote code when loading the
        model. Defaults to False.
    :type trust_remote_code: bool
    :param task: Default ``task`` argument for the model's encode call; not
        passed when None. Defaults to None.
    :type task: Optional[str]
    :param prompt_name: Default ``prompt_name`` argument for the model's encode
        call; not passed when None. Defaults to None.
    :type prompt_name: Optional[str]
    :param prompt: Default ``prompt`` argument for the model's encode call; not
        passed when None. Defaults to None.
    :type prompt: Optional[str]
    :param prompt_dict: Prompt dictionary handed to the model constructor as
        ``prompts``. Defaults to None.
    :type prompt_dict: Optional[dict]
    :param normalize: Whether to normalize the produced embeddings.
        Defaults to False.
    :type normalize: bool
    :param model_kwargs: Additional keyword arguments for loading the model.
        Defaults to {}.
    :type model_kwargs: dict[str, Any]
    """

    model_path: Optional[str] = None
    device_id: list[int] = field(default_factory=list)
    trust_remote_code: bool = False
    task: Optional[str] = None
    prompt_name: Optional[str] = None
    prompt: Optional[str] = None
    prompt_dict: Optional[dict] = None
    normalize: bool = False
    model_kwargs: dict[str, Any] = field(default_factory=dict)
# [docs]
@ENCODERS("sentence_transformer", config_class=SentenceTransformerEncoderConfig)
class SentenceTransformerEncoder(EncoderBase):
    """Text encoder backed by a ``sentence_transformers.SentenceTransformer``.

    The model is loaded on the first configured CUDA device (or on the CPU when
    no device is configured). When more than one device is configured, a
    multi-process pool is started and used for sufficiently large inputs.

    :param config: The configuration of the encoder.
    :type config: SentenceTransformerEncoderConfig
    """

    def __init__(self, config: SentenceTransformerEncoderConfig) -> None:
        """Load the model and, if several devices are given, start the pool.

        :param config: The configuration of the encoder.
        :type config: SentenceTransformerEncoderConfig
        :raises AssertionError: If ``config.model_path`` is None.
        """
        super().__init__(config)
        # Local import: sentence_transformers is only imported when an encoder
        # is actually constructed.
        from sentence_transformers import SentenceTransformer

        self.devices = config.device_id
        assert config.model_path is not None, "`model_path` must be provided"
        self.model = SentenceTransformer(
            model_name_or_path=config.model_path,
            device=f"cuda:{config.device_id[0]}" if config.device_id else "cpu",
            trust_remote_code=config.trust_remote_code,
            backend="torch",
            prompts=config.prompt_dict,
            model_kwargs=config.model_kwargs,
        )
        # A pool is only worthwhile with at least two target devices.
        if len(config.device_id) > 1:
            self.pool = self.model.start_multi_process_pool(
                target_devices=[f"cuda:{i}" for i in config.device_id]
            )
        else:
            self.pool = None

        # Defaults for the per-call encode arguments.
        self.prompt_name = config.prompt_name
        self.task = config.task
        self.prompt = config.prompt
        self.normalize = config.normalize
        return

    @TIME_METER("st_encode")
    def _encode(self, texts: list[str], **kwargs) -> np.ndarray:
        """Encode ``texts`` in a single batch (per device when the pool is used).

        :param texts: The texts to encode.
        :type texts: list[str]
        :param kwargs: Optional per-call overrides. ``task``, ``prompt_name``
            and ``prompt`` replace the configured defaults for this call; a
            value of None means the argument is not passed to the model.
        :return: The embeddings returned by the model.
        :rtype: np.ndarray
        """
        args = {
            "sentences": texts,
            # Never request a batch size of 0 when `texts` is empty.
            "batch_size": max(len(texts), 1),
            "show_progress_bar": False,
            "convert_to_numpy": True,
            "normalize_embeddings": self.normalize,
        }

        # Resolve each optional argument once and forward the resolved value,
        # so that a per-call override is actually honoured instead of always
        # falling back to the configured default.
        configured_defaults = {
            "task": self.task,
            "prompt_name": self.prompt_name,
            "prompt": self.prompt,
        }
        for name, default in configured_defaults.items():
            value = kwargs.get(name, default)
            if value is not None:
                args[name] = value

        # Only fan out to the pool when every device gets at least 8 texts.
        if (len(texts) >= len(self.devices) * 8) and (self.pool is not None):
            args["pool"] = self.pool
            args["batch_size"] = math.ceil(args["batch_size"] / len(self.devices))
            embeddings = self.model.encode_multi_process(**args)
        else:
            embeddings = self.model.encode(**args)
        return embeddings

    @property
    def embedding_size(self) -> int:
        """The embedding dimension reported by the underlying model."""
        return self.model.get_sentence_embedding_dimension()