flexrag.ranker.mixedbread_ranker 源代码

import asyncio
import os
from typing import Optional

import httpx
import numpy as np

from flexrag.utils import TIME_METER, configure

from .ranker import RANKERS, RankerBase, RankerBaseConfig


[文档] @configure class MixedbreadRankerConfig(RankerBaseConfig): """The configuration for the Mixedbread ranker. :param model: the model name of the ranker. Default is "mxbai-rerank-base-v2". :type model: str :param api_key: the API key for the Mixedbread ranker. If not provided, it will use the environment variable `MIXEDBREAD_API_KEY`. Defaults to None. :type api_key: str :param base_url: the base URL of the Mixedbread ranker. Default is None. :type base_url: Optional[str] :param proxy: the proxy for the request. Default is None. :type proxy: Optional[str] """ model: str = "mxbai-rerank-base-v2" base_url: Optional[str] = None api_key: Optional[str] = None proxy: Optional[str] = None
[文档] @RANKERS("mixedbread", config_class=MixedbreadRankerConfig) class MixedbreadRanker(RankerBase): """MixedbreadRanker: The ranker based on the Mixedbread API.""" def __init__(self, cfg: MixedbreadRankerConfig) -> None: super().__init__(cfg) from mixedbread import Mixedbread if cfg.proxy is not None: httpx_client = httpx.Client(proxies=cfg.proxy) else: httpx_client = None api_key = cfg.api_key or httpx_client or os.getenv("MIXEDBREAD_API_KEY") if not api_key: raise ValueError( "API key for Mixedbread is not provided. " "Please set it in the configuration or as an environment variable 'MIXEDBREAD_API_KEY'." ) self.client = Mixedbread( api_key=api_key, base_url=cfg.base_url, http_client=httpx_client ) self.model = cfg.model return @TIME_METER("mixedbread_rank") def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, np.ndarray]: result = self.client.rerank( query=query, input=candidates, model=self.model, top_k=len(candidates), ) scores = [i.score for i in result.data] return None, scores @TIME_METER("mixedbread_rank") async def _async_rank( self, query: str, candidates: list[str] ) -> tuple[np.ndarray, np.ndarray]: result = await asyncio.create_task( asyncio.to_thread( self.client.rerank, query=query, input=candidates, model=self.model, top_k=len(candidates), ) ) scores = [i.score for i in result.data] return None, scores