flexrag.models.cohere_model 源代码
import asyncio
import os
from typing import Annotated, Optional
import httpx
import numpy as np
from numpy import ndarray
from flexrag.utils import TIME_METER, Choices, configure
from .model_base import ENCODERS, EncoderBase, EncoderBaseConfig
[文档]
@configure
class CohereEncoderConfig(EncoderBaseConfig):
"""Configuration for CohereEncoder.
:param model: The model to use. Default is "embed-v4.0".
:type model: str
:param input_type: Specifies the type of input passed to the model.
Required for embedding models v3 and higher. Default is "search_document".
Available options are "search_document", "search_query", "classification", "clustering", "image".
:type input_type: str
:param embedding_size: The size of the embedding. Default is "1536".
Available options are "256", "512", "1024", "1536".
This option is only used for embedding models v4 and newer.
:type embedding_size: str
:param base_url: The base URL of the API. Default is None.
:type base_url: Optional[str]
:param api_key: The API key for the Cohere API.
If not provided, it will use the environment variable `COHERE_API_KEY`.
Defaults to None.
:type api_key: str
:param proxy: The proxy to use. Default is None.
:type proxy: Optional[str]
"""
model: str = "embed-v4.0"
input_type: Annotated[
str,
Choices(
"search_document",
"search_query",
"classification",
"clustering",
"image",
),
] = "search_document"
embedding_size: Annotated[str, Choices("256", "512", "1024", "1536")] = "1536"
base_url: Optional[str] = None
api_key: Optional[str] = None
proxy: Optional[str] = None
[文档]
@ENCODERS("cohere", config_class=CohereEncoderConfig)
class CohereEncoder(EncoderBase):
def __init__(self, cfg: CohereEncoderConfig):
from cohere import ClientV2
if cfg.proxy is not None:
httpx_client = httpx.Client(proxies=cfg.proxy)
else:
httpx_client = None
api_key = cfg.api_key or os.getenv("COHERE_API_KEY")
if not api_key:
raise ValueError(
"API key for Cohere is not provided. "
"Please set it in the configuration or as an environment variable 'COHERE_API_KEY'."
)
self.client = ClientV2(
api_key=api_key,
base_url=cfg.base_url,
httpx_client=httpx_client,
)
self.model = cfg.model
self.input_type = cfg.input_type
self._embedding_size = int(cfg.embedding_size)
super().__init__(cfg)
return
@TIME_METER("cohere_encode")
def _encode(self, texts: list[str]) -> ndarray:
embed_dim = self.embedding_size if self.model == "embed-v4.0" else None
r = self.client.embed(
texts=texts,
model=self.model,
input_type=self.input_type,
embedding_types=["float"],
output_dimension=embed_dim,
)
embeddings = r.embeddings.float
return np.array(embeddings)
@TIME_METER("cohere_encode")
async def async_encode(self, texts: list[str]):
task = asyncio.create_task(
asyncio.to_thread(
self.client.embed,
texts=texts,
model=self.model,
input_type=self.input_type,
embedding_types=["float"],
)
)
embeddings = (await task).embeddings.float
return np.array(embeddings)
@property
def embedding_size(self) -> int:
match self.model:
case "embed-multilingual-light-v3.0":
return 384
case "embed-multilingual-v3.0":
return 1024
case "embed-english-light-v3.0":
return 384
case "embed-english-v3.0":
return 1024
case "embed-v4.0":
if self._embedding_size is not None:
return self._embedding_size
return 1536
case _:
raise ValueError(
f"Unsupported model {self.model} for CohereEncoder. "
"Please specify the embedding size explicitly."
)