flexrag.ranker.jina_ranker 源代码
import os
from typing import Optional
import httpx
import numpy as np
from flexrag.utils import TIME_METER, configure
from .ranker import RANKERS, RankerBase, RankerBaseConfig
[文档]
@configure
class JinaRankerConfig(RankerBaseConfig):
    """Settings used to build a :class:`JinaRanker`.

    :param model: Name of the rerank model sent with every request.
        Defaults to "jina-reranker-v2-base-multilingual".
    :type model: str
    :param base_url: Endpoint URL the HTTP clients are bound to.
        Defaults to "https://api.jina.ai/v1/rerank".
    :type base_url: str
    :param api_key: API key sent as a bearer token.
        When left as None, the ranker falls back to the ``JINA_API_KEY``
        environment variable. Defaults to None.
    :type api_key: Optional[str]
    :param proxy: Proxy handed to the HTTP clients. Defaults to None.
    :type proxy: Optional[str]
    """

    model: str = "jina-reranker-v2-base-multilingual"
    base_url: str = "https://api.jina.ai/v1/rerank"
    api_key: Optional[str] = None
    proxy: Optional[str] = None
[文档]
@RANKERS("jina", config_class=JinaRankerConfig)
class JinaRanker(RankerBase):
    """JinaRanker: The ranker based on the Jina API.

    Each ranking call posts the query and the candidate texts to the
    configured rerank endpoint and returns one relevance score per candidate,
    in the same order as the candidates were given.
    """

    def __init__(self, cfg: JinaRankerConfig) -> None:
        """Create the synchronous and asynchronous HTTP clients.

        :param cfg: The ranker configuration.
        :type cfg: JinaRankerConfig
        :raises ValueError: If no API key is found in ``cfg.api_key`` or in the
            ``JINA_API_KEY`` environment variable.
        """
        super().__init__(cfg)

        # The explicit configuration value wins over the environment variable.
        api_key = cfg.api_key or os.getenv("JINA_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Jina is not provided. "
                "Please set it in the configuration or as an environment variable 'JINA_API_KEY'."
            )

        # Both clients share the same settings; only the I/O model differs.
        client_kwargs = {
            "base_url": cfg.base_url,
            "headers": {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            "proxy": cfg.proxy,
            "follow_redirects": True,
        }
        self.client = httpx.Client(**client_kwargs)
        self.async_client = httpx.AsyncClient(**client_kwargs)

        # Request body skeleton; the per-call fields are filled in a copy.
        self.data_template = {
            "model": cfg.model,
            "query": "",
            "top_n": 0,
            "documents": [],
        }

    def _build_payload(self, query: str, candidates: list[str]) -> dict:
        """Return the JSON body for one rerank request."""
        # A shallow copy is enough: every mutable field is replaced, not mutated.
        payload = dict(self.data_template)
        payload["query"] = query
        payload["documents"] = candidates
        # Ask for a score for every candidate, not just the best few.
        payload["top_n"] = len(candidates)
        return payload

    @staticmethod
    def _collect_scores(results: list[dict], num_candidates: int) -> np.ndarray:
        """Map the API results back onto the original candidate order.

        :param results: The ``"results"`` list from the response body.
        :type results: list[dict]
        :param num_candidates: Number of candidates that were submitted.
        :type num_candidates: int
        :return: Array of length ``num_candidates`` where position ``i`` holds
            the relevance score of candidate ``i``.
        :rtype: np.ndarray
        """
        # Candidates the service did not score stay at -inf so they sort last.
        scores = np.full(num_candidates, -np.inf)
        for position, item in enumerate(results):
            # The rerank API returns results ordered by relevance and tags each
            # one with the "index" of the source document; fall back to the
            # positional order if that field is absent.
            scores[item.get("index", position)] = item["relevance_score"]
        return scores

    @TIME_METER("jina_rank")
    def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]:
        """Score ``candidates`` against ``query`` with a blocking request.

        :param query: The query text.
        :type query: str
        :param candidates: The candidate texts to score.
        :type candidates: list[str]
        :return: ``(None, scores)``. No explicit ordering is returned;
            presumably the caller derives it from ``scores`` -- verify against
            ``RankerBase``.
        :rtype: tuple[np.ndarray, np.ndarray]
        :raises httpx.HTTPStatusError: If the service answers with an error status.
        """
        if not candidates:
            # Nothing to score; skip the network round-trip.
            return None, np.empty(0)
        response = self.client.post("", json=self._build_payload(query, candidates))
        response.raise_for_status()
        results = response.json()["results"]
        return None, self._collect_scores(results, len(candidates))

    @TIME_METER("jina_rank")
    async def _async_rank(
        self, query: str, candidates: list[str]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Score ``candidates`` against ``query`` without blocking the event loop.

        :param query: The query text.
        :type query: str
        :param candidates: The candidate texts to score.
        :type candidates: list[str]
        :return: ``(None, scores)``, with the same layout as :meth:`_rank`.
        :rtype: tuple[np.ndarray, np.ndarray]
        :raises httpx.HTTPStatusError: If the service answers with an error status.
        """
        if not candidates:
            # Nothing to score; skip the network round-trip.
            return None, np.empty(0)
        response = await self.async_client.post(
            "", json=self._build_payload(query, candidates)
        )
        # Only the request itself is awaitable: raise_for_status() and json()
        # are plain synchronous methods on httpx.Response.
        response.raise_for_status()
        results = response.json()["results"]
        return None, self._collect_scores(results, len(candidates))