# Source code for flexrag.ranker.ranker
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from flexrag.utils import LOGGER_MANAGER, Register, configure
from flexrag.utils.dataclasses import RetrievedContext
# Module-level logger obtained from the project's logger manager under the
# name "flexrag.rankers"; used below to warn when the async fallback runs.
logger = LOGGER_MANAGER.get_logger("flexrag.rankers")
# [docs]
@configure
class RankerBaseConfig:
    """Settings shared by every ranker.

    :param reserve_num: Maximum number of top-ranked candidates to keep.
        Truncation is applied only when this value is positive; with any
        non-positive value (the default is ``-1``) every candidate is kept.
    :type reserve_num: int
    :param ranking_field: Key used to look up the text to rank inside the
        ``data`` mapping of each ``RetrievedContext`` candidate. It must be
        set when ranking ``RetrievedContext`` objects; when left as ``None``
        the ranker can only be given plain strings.
    :type ranking_field: Optional[str]
    """

    # Number of candidates to keep after ranking (non-positive keeps all).
    reserve_num: int = -1
    # Name of the entry in ``RetrievedContext.data`` that holds the text.
    ranking_field: Optional[str] = None
# [docs]
@configure
class RankingResult:
    """The outcome of ranking a list of candidates against a query.

    :param query: The query string the candidates were ranked for. Required.
    :type query: str
    :param candidates: The ranked candidates, most relevant first. Each item
        is whatever the caller passed in: either a ``RetrievedContext`` or a
        plain string. Required.
    :type candidates: list[RetrievedContext | str]
    :param scores: The relevance scores aligned one-to-one with
        ``candidates``, or ``None`` when the ranker produced no scores.
    :type scores: Optional[list[float]]
    """

    query: str
    # Widened to include ``str``: rankers accept plain-string candidates and
    # store them here unchanged, so the annotation must admit both item types.
    candidates: list[RetrievedContext | str]
    scores: Optional[list[float]] = None
# [docs]
class RankerBase(ABC):
    """Abstract base class for rankers.

    Subclasses implement :meth:`_rank` (and may override :meth:`_async_rank`).
    The public :meth:`rank` and :meth:`async_rank` methods take care of
    extracting the text from the candidates, ordering them, truncating to
    ``reserve_num`` and assembling the :class:`RankingResult`.
    """

    def __init__(self, cfg: RankerBaseConfig) -> None:
        """Store the ranking options from the configuration.

        :param cfg: the ranker configuration.
        :type cfg: RankerBaseConfig
        """
        self.reserve_num = cfg.reserve_num
        self.ranking_field = cfg.ranking_field

    def rank(
        self, query: str, candidates: list[RetrievedContext | str]
    ) -> RankingResult:
        """Rank the candidates based on the query.

        :param query: query string.
        :type query: str
        :param candidates: the candidates to rank, either plain strings or
            ``RetrievedContext`` objects. The type of the first element
            decides how the whole list is interpreted.
        :type candidates: list[RetrievedContext | str]
        :return: the candidates in ranked order (truncated to ``reserve_num``
            when it is positive) together with their scores, if any.
            An empty candidate list yields an empty result.
        :rtype: RankingResult
        :raises AssertionError: if the candidates are ``RetrievedContext``
            objects but ``ranking_field`` is not configured, or if ``_rank``
            returns neither indices nor scores.
        """
        # ``len`` rather than truthiness so array-like inputs keep working.
        if len(candidates) == 0:
            # Nothing to rank; avoid indexing ``candidates[0]`` below.
            return RankingResult(query=query, candidates=[])
        texts = self._extract_texts(candidates)
        indices, scores = self._rank(query, texts)
        return self._build_result(query, candidates, indices, scores)

    async def async_rank(
        self, query: str, candidates: list[RetrievedContext | str]
    ) -> RankingResult:
        """The asynchronous version of `rank`.

        Takes the same arguments, returns the same result and raises the same
        errors as :meth:`rank`, but awaits :meth:`_async_rank` for the scoring.
        """
        # ``len`` rather than truthiness so array-like inputs keep working.
        if len(candidates) == 0:
            # Nothing to rank; avoid indexing ``candidates[0]`` below.
            return RankingResult(query=query, candidates=[])
        texts = self._extract_texts(candidates)
        indices, scores = await self._async_rank(query, texts)
        return self._build_result(query, candidates, indices, scores)

    def _extract_texts(self, candidates: list[RetrievedContext | str]) -> list[str]:
        """Return the text to rank for each of the (non-empty) candidates."""
        # Only the first element is inspected: the list is treated as
        # homogeneous, exactly as the public API documents.
        if isinstance(candidates[0], RetrievedContext):
            assert (
                self.ranking_field is not None
            ), "`ranking_field` must be set to rank RetrievedContext candidates."
            return [ctx.data[self.ranking_field] for ctx in candidates]
        return candidates

    def _build_result(
        self,
        query: str,
        candidates: list[RetrievedContext | str],
        indices: Optional[np.ndarray],
        scores: Optional[np.ndarray],
    ) -> RankingResult:
        """Order, truncate and package the candidates into a RankingResult."""
        if indices is None:
            assert (
                scores is not None
            ), "`_rank` must return at least one of indices or scores."
            # No explicit order given: sort by score, highest first.
            indices = np.argsort(scores)[::-1]
        if self.reserve_num > 0:
            indices = indices[: self.reserve_num]
        # The candidate list is filled in place after construction rather than
        # passed to the constructor, matching how results were always built.
        result = RankingResult(query=query, candidates=[])
        result.candidates.extend(candidates[idx] for idx in indices)
        if scores is not None:
            result.scores = [scores[idx] for idx in indices]
        return result

    @abstractmethod
    def _rank(
        self, query: str, candidates: list[str]
    ) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Rank the candidate texts based on the query.

        :param query: query string.
        :type query: str
        :param candidates: list of candidate strings.
        :type candidates: list[str]
        :return: a pair ``(indices, scores)``. ``indices`` lists candidate
            positions in ranked order; it may be ``None``, in which case the
            caller sorts by ``scores`` in descending order. ``scores`` is
            indexed by original candidate position and may be ``None`` when
            ``indices`` is given. At least one of the two must not be ``None``.
        :rtype: tuple[Optional[np.ndarray], Optional[np.ndarray]]
        """

    async def _async_rank(
        self, query: str, candidates: list[str]
    ) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """The asynchronous version of `_rank`.

        This default implementation logs a warning and calls the synchronous
        :meth:`_rank` directly, so it does not yield to the event loop while
        ranking. Subclasses with a real asynchronous backend should override it.
        """
        logger.warning("async_rank is not implemented, using the synchronous version.")
        return self._rank(query, candidates)
# Register instance parameterized by RankerBase and created with the name
# "ranker". Presumably concrete ranker classes are registered here so they can
# be looked up by name -- confirm against flexrag.utils.Register.
RANKERS = Register[RankerBase]("ranker")