flexrag.ranker.gpt_ranker 源代码

import copy
import re
from pathlib import Path

import numpy as np

from flexrag.models import GENERATORS, GeneratorConfig
from flexrag.prompt import ChatPrompt, ChatTurn
from flexrag.utils import TIME_METER, configure

from .ranker import RANKERS, RankerBase, RankerBaseConfig


[文档] @configure class RankGPTRankerConfig(RankerBaseConfig, GeneratorConfig): """The configuration for the RankGPT ranker. :param step_size: the step size for the slide window ranking. Default is 10. :type step_size: int :param window_size: the window size for the slide window ranking. Default is 20. :type window_size: int :param max_chunk_size: the maximum chunk size for the slide window ranking. Default is 300. :type max_chunk_size: int """ step_size: int = 10 window_size: int = 20 max_chunk_size: int = 300
[文档] @RANKERS("rank_gpt", config_class=RankGPTRankerConfig) class RankGPTRanker(RankerBase): """RankGPTRanker: Rank the candidates based on the query using the Large Language model. Code was adapted from the original implementation from https://github.com/sunnweiwei/RankGPT """ def __init__(self, cfg: RankGPTRankerConfig): super().__init__(cfg) self.generator = GENERATORS.load(cfg) # load prompt prompt_path = Path(__file__).parent / "ranker_prompts" / "rankgpt_prompt.json" self.prompt = ChatPrompt.from_json(prompt_path) # set basic arguments self.step_size = cfg.step_size self.window_size = cfg.window_size self.max_chunk_size = cfg.max_chunk_size return @TIME_METER("rankgpt_rank") def _rank(self, query: str, candidates: list[str]) -> tuple[np.ndarray, None]: # perform slide window ranking indices = list(range(len(candidates))) start_idx = max(len(candidates) - self.window_size, 0) end_idx = len(candidates) while start_idx >= 0: start_idx = max(start_idx, 0) candidates_ = [candidates[i] for i in indices[start_idx:end_idx]] indices_, _ = self._rank_piece(query, candidates_) indices[start_idx:end_idx] = indices_ start_idx = start_idx - self.step_size end_idx -= self.step_size return np.array(indices), None @TIME_METER("rankgpt_rank") async def _async_rank( self, query: str, candidates: list[str] ) -> tuple[np.ndarray, None]: # perform slide window ranking indices = list(range(len(candidates))) start_idx = max(len(candidates) - self.window_size, 0) end_idx = len(candidates) while start_idx >= 0: start_idx = max(start_idx, 0) candidates_ = [candidates[i] for i in indices[start_idx:end_idx]] indices_, _ = await self._async_rank_piece(query, candidates_) indices[start_idx:end_idx] = indices_ start_idx = start_idx - self.step_size end_idx -= self.step_size return np.array(indices), None def _rank_piece(self, query: str, candidates: list[str]) -> tuple[np.ndarray, None]: prompt = self._get_prompt(query=query, candidates=candidates) response = self.generator.chat(prompts=[prompt])[0][0] # convert string to indices response = re.sub(r"\D", " ", response) indices_ = [int(x) - 1 for x in response.split()] # deduplicate indices indices = [] for i in indices_: if i not in indices: indices.append(i) # refine indices ori_indices = list(range(len(candidates))) new_indices = [idx for idx in indices if idx in ori_indices] new_indices = new_indices + [ idx for idx in ori_indices if idx not in new_indices ] return new_indices, None async def _async_rank_piece( self, query: str, candidates: list[str] ) -> tuple[np.ndarray, None]: prompt = self._get_prompt(query=query, candidates=candidates) response = (await self.generator.async_chat(prompts=[prompt]))[0][0] # convert string to indices response = re.sub(r"\D", " ", response) indices_ = [int(x) - 1 for x in response.split()] # deduplicate indices indices = [] for i in indices_: if i not in indices: indices.append(i) # refine indices ori_indices = list(range(len(candidates))) new_indices = [idx for idx in indices if idx in ori_indices] new_indices = new_indices + [ idx for idx in ori_indices if idx not in new_indices ] return new_indices, None def _get_prompt(self, query: str, candidates: list[str]): max_length = 300 prompt = copy.deepcopy(self.prompt) prompt.history[0].content = prompt.history[0].content.format( query=query, num=len(candidates) ) last_turn = prompt.history.pop() last_turn.content = last_turn.content.format(query=query, num=len(candidates)) rank = 0 for cand in candidates: rank += 1 content = cand.replace("Title: Content: ", "") content = content.strip() content = " ".join(content.split()[: int(max_length)]) prompt.update(ChatTurn(role="user", content=f"[{rank}] {content}")) prompt.update( ChatTurn(role="assistant", content=f"Received passage [{rank}].") ) prompt.update(last_turn) return prompt