Source code for flexrag.retriever.retriever_base

import asyncio
import os
import tempfile
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, field
from typing import Any, Generator, Iterable, Optional

import numpy as np
from huggingface_hub import HfApi

from flexrag.text_process import TextProcessPipeline, TextProcessPipelineConfig
from flexrag.utils import (
    __VERSION__,
    FLEXRAG_CACHE_DIR,
    LOGGER_MANAGER,
    LRUPersistentCache,
    Register,
    SimpleProgressLogger,
    configure,
)
from flexrag.utils.dataclasses import Context, RetrievedContext

# Module-level logger used by the retrievers defined in this file.
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers")


# Process-wide retrieval cache, configured through environment variables:
#   DISABLE_CACHE -- the exact string "True" turns caching off (cache is ``None``).
#   CACHE_SIZE    -- forwarded as ``maxsize``; unset means ``None``.
#   CACHE_PATH    -- on-disk location; defaults to a folder under FLEXRAG_CACHE_DIR.
RETRIEVAL_CACHE: LRUPersistentCache | None
if os.environ.get("DISABLE_CACHE", "False") != "True":
    # NOTE(review): environment values are strings, so ``maxsize`` receives a
    # ``str`` when CACHE_SIZE is set -- confirm LRUPersistentCache accepts that.
    cache_size = os.environ.get("CACHE_SIZE", None)
    cache_path = os.environ.get(
        "CACHE_PATH", os.path.join(FLEXRAG_CACHE_DIR, "retrieval_cache")
    )
    RETRIEVAL_CACHE = LRUPersistentCache(maxsize=cache_size, cache_path=cache_path)
else:
    RETRIEVAL_CACHE = None


# FIXME: fix this
def batched_cache(func):
    """The helper function to cache the retrieval results in batch.
    You can use this function to decorate the `search` method of the retriever class to cache the retrieval results in batch.
    """

    def dict_to_retrieved(data: list[dict] | None) -> list[RetrievedContext] | None:
        if data is None:
            return None
        return [RetrievedContext(**r) for r in data]

    def check(data: list):
        for d in data:
            assert isinstance(d, list)
            for r in d:
                assert isinstance(r, RetrievedContext)
        return

    def wrapper(
        self,
        query: list[str],
        disable_cache: bool = False,
        **search_kwargs,
    ):
        # check query
        if isinstance(query, str):
            query = [query]

        # direct search
        if (RETRIEVAL_CACHE is None) or disable_cache:
            return func(self, query, **search_kwargs)

        # search from cache
        cfg = asdict(self.cfg)
        keys = [
            {
                "retriever_config": cfg,
                "query": q,
                "search_kwargs": search_kwargs,
            }
            for q in query
        ]
        results = [dict_to_retrieved(RETRIEVAL_CACHE.get(k, None)) for k in keys]

        # search from database
        new_query = [q for q, r in zip(query, results) if r is None]
        new_indices = [n for n, r in enumerate(results) if r is None]
        if new_query:
            new_results = func(self, new_query, **search_kwargs)
            # update cache
            for n, r in zip(new_indices, new_results):
                results[n] = r
                RETRIEVAL_CACHE[keys[n]] = asdict(r)
        # check results
        check(results)
        return results

    return wrapper


@configure
class RetrieverBaseConfig:
    """Base configuration class for all retrievers.

    :param log_interval: The interval of logging. Default: 100.
    :type log_interval: int
    :param top_k: The number of retrieved documents. Default: 10.
    :type top_k: int
    :param batch_size: The batch size for retrieval. Default: 32.
    :type batch_size: int
    :param query_preprocess_pipeline: The text process pipeline for query.
        Default: TextProcessPipelineConfig.
    :type query_preprocess_pipeline: TextProcessPipelineConfig
    """

    log_interval: int = 100
    top_k: int = 10
    batch_size: int = 32
    # A factory is used so that every config instance owns its own pipeline config.
    query_preprocess_pipeline: TextProcessPipelineConfig = field(  # type: ignore
        default_factory=TextProcessPipelineConfig
    )
class RetrieverBase(ABC):
    """The base class for all retrievers.

    The subclasses should implement the ``search`` method and the ``fields`` property.
    """

    cfg: RetrieverBaseConfig

    def __init__(self, cfg: RetrieverBaseConfig):
        """Build the query preprocessing pipeline from the configuration.

        :param cfg: The retriever configuration.
        :type cfg: RetrieverBaseConfig
        """
        # NOTE(review): ``self.cfg`` is read by ``search_batch`` and ``test_speed``
        # but is not assigned here; presumably subclasses assign it -- confirm.
        # load preprocess pipeline
        self.query_preprocess_pipeline = TextProcessPipeline(
            cfg.query_preprocess_pipeline
        )
        return

    @batched_cache
    def search_batch(
        self,
        query: Iterable[Any],
        no_preprocess: bool = False,
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Search queries in batches.

        The queries are grouped into batches of ``cfg.batch_size`` and every batch
        is passed to ``search``. As this method is wrapped by ``batched_cache``,
        callers may additionally pass ``disable_cache=True`` to bypass the cache.

        :param query: Queries to search.
        :type query: list[Any]
        :param no_preprocess: Whether to skip the query preprocessing pipeline.
            Default: False.
        :type no_preprocess: bool
        :param search_kwargs: Other search arguments.
        :type search_kwargs: Any
        :return: A batch of list that contains k RetrievedContext.
        :rtype: list[list[RetrievedContext]]
        """

        def get_batch() -> Generator[list[Any], None, None]:
            # Lazily preprocess the queries and group them into batches.
            batch = []
            for q in query:
                if not no_preprocess:
                    batch.append(self.query_preprocess_pipeline(q))
                else:
                    batch.append(q)
                if len(batch) == self.cfg.batch_size:
                    yield batch
                    batch = []
            # flush the last, possibly smaller, batch
            if batch:
                yield batch
            return

        final_results = []
        # The total is only known for sized inputs (not for generators).
        total = len(query) if hasattr(query, "__len__") else None
        p_logger = SimpleProgressLogger(logger, total, self.cfg.log_interval)
        for batch in get_batch():
            results_ = self.search(batch, **search_kwargs)
            final_results.extend(results_)
            # NOTE(review): the logger is advanced by 1 per batch while ``total``
            # counts queries -- confirm this is the intended unit.
            p_logger.update(1, "Retrieving")
        return final_results

    @abstractmethod
    def search(
        self,
        query: list[Any] | Any,
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Search a batch of queries.

        :param query: Queries to search.
        :type query: list[Any] | Any
        :param search_kwargs: Keyword arguments, contains other search arguments.
        :type search_kwargs: Any
        :return: A batch of list that contains k RetrievedContext.
        :rtype: list[list[RetrievedContext]]
        """
        return

    @property
    @abstractmethod
    def fields(self) -> list[str]:
        """The fields of the retrieved data."""
        return

    def test_speed(
        self,
        sample_num: int = 10000,
        test_times: int = 10,
        **search_kwargs,
    ) -> float:
        """Test the speed of the retriever.

        Sentences from the NLTK ``brown`` corpus are used as queries; they are
        repeated cyclically when ``sample_num`` exceeds the number of sentences.

        :param sample_num: The number of samples to test.
        :type sample_num: int, optional
        :param test_times: The number of times to test.
        :type test_times: int, optional
        :param search_kwargs: Other search arguments forwarded to ``search``.
            ``top_k`` defaults to ``cfg.top_k`` when not given.
        :type search_kwargs: Any
        :return: The average time (in seconds) consumed by one retrieval run.
        :rtype: float
        :raises ValueError: If ``test_times`` is not positive.
        """
        if test_times <= 0:
            raise ValueError("`test_times` must be a positive integer.")

        from nltk.corpus import brown

        # ``search`` is declared as ``search(query, **search_kwargs)``, so ``top_k``
        # has to be passed by keyword; an explicit caller value takes precedence.
        search_kwargs.setdefault("top_k", self.cfg.top_k)

        total_times = []
        sents = [" ".join(i) for i in brown.sents()]
        for _ in range(test_times):
            query = [sents[i % len(sents)] for i in range(sample_num)]
            start_time = time.perf_counter()
            _ = self.search(query, **search_kwargs)
            end_time = time.perf_counter()
            total_times.append(end_time - start_time)
        avg_time = sum(total_times) / test_times
        std_time = np.std(total_times)
        logger.info(
            f"Retrieval {sample_num} items consume: {avg_time:.4f} ± {std_time:.4f} s"
        )
        # Return the same averaged figure that is logged, not only the last run.
        return avg_time
# Registry of the available retriever implementations. Entries are looked up by
# name and expose the retriever class under "item" and its configuration class
# under "config_class" (this is how ``LocalRetriever.load_from_local`` uses it).
# NOTE(review): the meaning of the second positional argument (``True``) is
# defined by ``Register`` -- confirm there.
RETRIEVERS = Register[RetrieverBase]("retriever", True)
@configure
class EditableRetrieverConfig(RetrieverBaseConfig):
    """Configuration class for EditableRetriever.

    It declares no options of its own; all fields come from ``RetrieverBaseConfig``.
    """
class EditableRetriever(RetrieverBase):
    """The base class for all `editable` retrievers.

    In FlexRAG, the ``EditableRetriever`` is a concept referring to a retriever that
    includes the ``add_passages`` and ``clear`` methods, allowing you to build the
    retriever using your own knowledge base.
    FlexRAG provides following editable retrievers: ``FlexRetriever``,
    ``ElasticRetriever``, ``TypesenseRetriever``, and ``HydeRetriever``.

    The subclasses should implement the ``add_passages``, ``clear``, and ``__len__``
    methods.
    """

    @abstractmethod
    def add_passages(self, passages: Iterable[Context]):
        """Add passages to the retriever database.

        :param passages: The passages to add.
        :type passages: Iterable[Context]
        :return: None
        """
        return

    @abstractmethod
    def clear(self) -> None:
        """Clear the retriever database."""
        return

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of documents in the retriever database."""
        return
@configure
class LocalRetrieverConfig(EditableRetrieverConfig):
    """The configuration class for LocalRetriever.

    :param retriever_path: The path to the local database. Default: None.
        If specified, all modifications to the retriever will be applied
        simultaneously on the disk.
        If not specified, the retriever will be kept in memory.
    :type retriever_path: Optional[str]
    """

    retriever_path: Optional[str] = None
class LocalRetriever(EditableRetriever):
    """The base class for all `local` retrievers.

    In FlexRAG, the ``LocalRetriever`` is a concept referring to a retriever that can
    be saved to the local disk. The subclasses provide the ``save_to_local`` and
    ``load_from_local`` methods to save and load the retriever from the local disk,
    and the ``save_to_hub`` and ``load_from_hub`` methods to save and load the
    retriever from the HuggingFace Hub.
    FlexRAG provides following local retrievers: ``FlexRetriever``, and
    ``HydeRetriever``.

    For example, to load a retriever hosted on the HuggingFace Hub, you can run the
    following code:

    .. code-block:: python

        from flexrag.retriever import LocalRetriever

        retriever = LocalRetriever.load_from_hub("flexrag/wiki2021_atlas_bm25s")

    To save a retriever to the HuggingFace Hub, you can run the following code:

    .. code-block:: python

        retriever.save_to_hub("<your-repo-id>", token="<your-token>")
    """

    cfg: LocalRetrieverConfig

    @staticmethod
    def load_from_hub(
        repo_id: str,
        revision: Optional[str] = None,
        token: Optional[str] = None,
        cache_dir: str = FLEXRAG_CACHE_DIR,
        **kwargs,
    ) -> "LocalRetriever":
        """Load a retriever from the HuggingFace Hub.

        :param repo_id: The repo id of the retriever on the HuggingFace Hub.
        :type repo_id: str
        :param revision: The revision of the retriever on the HuggingFace Hub.
            Default: None.
        :type revision: str
        :param token: The token to access the HuggingFace Hub. Default: None.
        :type token: str
        :param cache_dir: The cache directory to store the retriever.
            Default: FLEXRAG_CACHE_DIR.
        :type cache_dir: str
        :param kwargs: Additional arguments for the retriever.
        :type kwargs: Any
        :raises ValueError: If the repository information cannot be obtained.
        :raises RuntimeError: If the snapshot download returns nothing.
        :return: The loaded retriever.
        :rtype: LocalRetriever
        """
        # check if the retriever exists
        api = HfApi(token=token)
        repo_info = api.repo_info(repo_id)
        if repo_info is None:
            raise ValueError(f"Retriever {repo_id} not found on the HuggingFace Hub.")
        repo_id = repo_info.id
        # "namespace/name" -> "namespace--name"; joining every part also copes with
        # ids that carry no namespace instead of failing on a missing index.
        dir_name = os.path.join(cache_dir, "--".join(repo_id.split("/")))
        # lancedb does not support loading the database from a symlink
        snapshot = api.snapshot_download(
            repo_id=repo_id,
            revision=revision,
            token=token,
            local_dir=dir_name,
        )
        if snapshot is None:
            raise RuntimeError(f"Retriever {repo_id} download failed.")

        # load the retriever
        return LocalRetriever.load_from_local(snapshot, **kwargs)

    def save_to_hub(
        self,
        repo_id: str,
        token: Optional[str] = None,
        commit_message: str = "Update FlexRAG retriever",
        retriever_card: str = None,
        private: bool = False,
        **kwargs,
    ) -> str:
        """Save the retriever to the HuggingFace Hub.

        If ``cfg.retriever_path`` is not set, the retriever is first saved to a
        temporary directory, uploaded from there, and ``cfg.retriever_path`` is
        reset to ``None`` afterwards.

        :param repo_id: The repo id of the retriever on the HuggingFace Hub.
        :type repo_id: str
        :param token: The token to access the HuggingFace Hub. Default: None,
            in which case the ``HF_TOKEN`` environment variable is read at call time.
        :type token: str
        :param commit_message: The commit message for the retriever.
            Default: "Update FlexRAG retriever".
        :type commit_message: str
        :param retriever_card: The markdown readme file for the retriever.
            Default: None.
        :type retriever_card: str
        :param private: Whether to create a private repo. Default: False.
        :type private: bool
        :param kwargs: Additional arguments for uploading the retriever.
        :type kwargs: Any
        :return: The repo url of the retriever.
        :rtype: str
        """
        # Resolve the token when the method is called rather than when the class
        # is defined, so that a later change of HF_TOKEN is honored.
        if token is None:
            token = os.environ.get("HF_TOKEN", None)

        # NOTE(review): ``retriever_card`` is accepted but not uploaded by this
        # method -- confirm whether ``save_to_local`` is expected to write it.

        # make a temporary directory if retriever_path is not specified
        if self.cfg.retriever_path is None:
            with tempfile.TemporaryDirectory(prefix="flexrag-retriever") as tmp_dir:
                logger.info(
                    (
                        "As the `retriever_path` is not set, "
                        f"the retriever will be saved temporarily at {tmp_dir}."
                    )
                )
                try:
                    # NOTE(review): ``update_config`` is not part of the abstract
                    # ``save_to_local`` signature; subclasses must accept it.
                    self.save_to_local(tmp_dir, update_config=True)
                    return self._push_folder_to_hub(
                        folder_path=tmp_dir,
                        repo_id=repo_id,
                        token=token,
                        commit_message=commit_message,
                        private=private,
                        **kwargs,
                    )
                finally:
                    # The temporary directory is about to be removed, so the
                    # configuration must not keep pointing at it.
                    self.cfg.retriever_path = None

        return self._push_folder_to_hub(
            folder_path=self.cfg.retriever_path,
            repo_id=repo_id,
            token=token,
            commit_message=commit_message,
            private=private,
            **kwargs,
        )

    @staticmethod
    def _push_folder_to_hub(
        folder_path: str,
        repo_id: str,
        token: Optional[str],
        commit_message: str,
        private: bool,
        **kwargs,
    ) -> str:
        """Create the model repo if needed and upload ``folder_path`` to it."""
        # prepare the client
        api = HfApi(token=token)

        # create repo if not exists
        repo_url = api.create_repo(
            repo_id=repo_id,
            token=api.token,
            private=private,
            repo_type="model",
            exist_ok=True,
        )
        repo_id = repo_url.repo_id

        # push to hub
        api.upload_folder(
            repo_id=repo_id,
            commit_message=commit_message,
            folder_path=folder_path,
            **kwargs,
        )
        return repo_url

    @staticmethod
    def load_from_local(repo_path: str = None, **kwargs) -> "LocalRetriever":
        """Load a retriever from the local disk.

        The directory must contain a ``cls.id`` file holding the registered name
        of the retriever and a ``config.yaml`` file holding its configuration.

        :param repo_path: The path to the local database. Default: None.
        :type repo_path: str
        :param kwargs: Accepted for interface compatibility; not used here.
        :type kwargs: Any
        :raises ValueError: If ``repo_path`` is not given.
        :return: The loaded retriever.
        :rtype: LocalRetriever
        """
        if repo_path is None:
            raise ValueError("`repo_path` must be specified to load a retriever.")

        # prepare the cls
        id_path = os.path.join(repo_path, "cls.id")
        with open(id_path, "r", encoding="utf-8") as f:
            # tolerate a trailing newline or surrounding whitespace in the file
            retriever_name = f.read().strip()
        retriever_cls = RETRIEVERS[retriever_name]["item"]
        config_cls = RETRIEVERS[retriever_name]["config_class"]

        # prepare the configuration
        config_path = os.path.join(repo_path, "config.yaml")
        cfg: LocalRetrieverConfig = config_cls.load(config_path)
        # always point the loaded retriever at the directory it was loaded from
        cfg.retriever_path = repo_path

        # load the retriever
        retriever = retriever_cls(cfg)
        return retriever

    @abstractmethod
    def save_to_local(self, retriever_path: str = None):
        """Save the retriever to the local disk.

        :param retriever_path: The path to the local database. Default: None.
        :type retriever_path: str
        :return: None
        :rtype: None
        """
        return

    @abstractmethod
    def detach(self):
        """Detach the retriever from the local database.

        After detaching, the retriever will be kept in memory and all modifications
        will not be applied to the disk.
        """
        return