flexrag.models.jina_model 源代码

import os
from typing import Annotated, Optional

import httpx
import numpy as np
from numpy import ndarray

from flexrag.utils import TIME_METER, Choices, configure

from .model_base import ENCODERS, EncoderBase, EncoderBaseConfig


[文档] @configure class JinaEncoderConfig(EncoderBaseConfig): """Configuration for JinaEncoder. :param model: The model to use. Default is "jina-embeddings-v3". :type model: str :param base_url: The base URL of the Jina embeddings API. Default is "https://api.jina.ai/v1/embeddings". :type base_url: str :param api_key: The API key for the Jina embeddings API. If not provided, it will use the environment variable `JINA_API_KEY`. Defaults to None. :type api_key: str :param embedding_size: The dimension of the embeddings. Default is 1024. :type embedding_size: int :param task: The task for the embeddings. Default is None. Available options are "retrieval.query", "retrieval.passage", "separation", "classification", and "text-matching". :type task: str :param proxy: The proxy to use. Defaults to None. :type proxy: Optional[str] """ model: str = "jina-embeddings-v3" base_url: str = "https://api.jina.ai/v1/embeddings" api_key: Optional[str] = None embedding_size: int = 1024 task: Optional[ Annotated[ str, Choices( "retrieval.query", "retrieval.passage", "separation", "classification", "text-matching", ), ] ] = None proxy: Optional[str] = None
[文档] @ENCODERS("jina", config_class=JinaEncoderConfig) class JinaEncoder(EncoderBase): def __init__(self, cfg: JinaEncoderConfig): super().__init__(cfg) api_key = cfg.api_key or os.getenv("JINA_API_KEY") if not api_key: raise ValueError( "API key for Jina embeddings is not provided. " "Please set it in the configuration or as an environment variable 'JINA_API_KEY'." ) # prepare client self.client = httpx.Client( headers={ "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", }, proxy=cfg.proxy, base_url=cfg.base_url, follow_redirects=True, ) self.async_client = httpx.AsyncClient( headers={ "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", }, proxy=cfg.proxy, base_url=cfg.base_url, follow_redirects=True, ) # prepare template self.data_template = { "model": cfg.model, "task": cfg.task, "dimensions": cfg.embedding_size, "late_chunking": False, "embedding_type": "float", "input": [], } return @TIME_METER("jina_encode") def _encode(self, texts: list[str]) -> ndarray: data = self.data_template.copy() data["input"] = texts response = self.client.post("", json=data) response.raise_for_status() embeddings = [i["embedding"] for i in response.json()["data"]] return np.array(embeddings)[:, : self.embedding_size] @TIME_METER("jina_encode") async def async_encode(self, texts: list[str]) -> ndarray: data = self.data_template.copy() data["input"] = texts response = await self.async_client.post("", json=data) embeddings = [i["embedding"] for i in (await response.json())["data"]] return np.array(embeddings)[:, : self.embedding_size] @property def embedding_size(self) -> int: return self.data_template["dimensions"]