Source code for flexrag.retriever.flex_retriever

import os
import shutil
from collections import defaultdict
from typing import Annotated, Any, Generator, Iterable, Optional

from jinja2 import Template

from flexrag.database import (
    LMDBRetrieverDatabase,
    NaiveRetrieverDatabase,
    RetrieverDatabaseBase,
)
from flexrag.utils import (
    __VERSION__,
    LOGGER_MANAGER,
    TIME_METER,
    Choices,
    SimpleProgressLogger,
    configure,
)
from flexrag.utils.configure import extract_config
from flexrag.utils.dataclasses import Context, RetrievedContext

from .index import (
    RETRIEVER_INDEX,
    MultiFieldIndex,
    MultiFieldIndexConfig,
    RetrieverIndexConfig,
)
from .retriever_base import RETRIEVERS, LocalRetriever, LocalRetrieverConfig

logger = LOGGER_MANAGER.get_logger("flexrag.retreviers.flex")


RETRIEVER_CARD_TEMPLATE = Template(
    """---
language: en
library_name: FlexRAG
tags:
- FlexRAG
- retrieval
- search
- lexical
- RAG
---

# FlexRAG Retriever

This is a {{ retriever_type }} created with the [`FlexRAG`](https://github.com/ictnlp/flexrag) library (version `{version}`).

## Installation

You can install the `FlexRAG` library with `pip`:

```bash
pip install flexrag
```

## Loading a `FlexRAG` retriever

You can use this retriever for information retrieval tasks. Here is an example:

```python
from flexrag.retriever import LocalRetriever

{% if repo_id is not none %}
# Load the retriever from the HuggingFace Hub
retriever = LocalRetriever.load_from_hub("{{ repo_id }}")
{% else %}
# Load the retriever from a local path
retriever = LocalRetriever.load_from_local("{{ repo_path }}")
{% endif %}

# You can retrieve now
results = retriever.search("Who is Bruce Wayne?")
```

FlexRAG Related Links:
* 📚[Documentation](https://flexrag.readthedocs.io/en/latest/)
* 💻[GitHub Repository](https://github.com/ictnlp/flexrag)
"""
)


[docs] @configure class FlexRetrieverConfig(LocalRetrieverConfig): """Configuration class for FlexRetriever. :param indexes_merge_method: Method to merge the scores of multiple indexes. Available choices are "rrf" and "linear". Default is "rrf". * "rrf": Reciprocal Rank Fusion (RRF) method. * "linear": Linear combination of the scores. :type indexes_merge_method: str :param merge_weights: List of weights for each index. Default is None. If None, all indexes will be treated equally. This option is used in both "rrf" and "linear" methods. :type merge_weights: Optional[list[float]] :param used_indexes: List of indexes to use for retrieval. Default is None. If None, all indexes will be used. :type used_indexes: Optional[list[str]] :param rrf_base: Base for the RRF method. Default is 60. This option is only used when `indexes_merge_method` is "rrf". :type rrf_base: int """ indexes_merge_method: Annotated[str, Choices("rrf", "linear")] = "rrf" indexes_merge_weights: Optional[list[float]] = None used_indexes: Optional[list[str]] = None rrf_base: int = 60
[docs] @RETRIEVERS("flex", config_class=FlexRetrieverConfig) class FlexRetriever(LocalRetriever): """FlexRetriever is a retriever implemented by FlexRAG team. FlexRetriever supports multi-index and multi-field retrieval. """ cfg: FlexRetrieverConfig def __init__(self, cfg: FlexRetrieverConfig) -> None: super().__init__(cfg) self.cfg = extract_config(cfg, FlexRetrieverConfig) # load the retriever if the retriever_path is set self.database = self._load_database() self.index_table = self._load_index() # consistency check self._check_consistency() return @TIME_METER("flex_retriever", "add-passages") def add_passages(self, passages: Iterable[Context]): def get_batch() -> Generator[tuple[list[dict], list[str]], None, None]: batch = [] ids = [] for passage in passages: if len(batch) == self.cfg.batch_size: yield batch, ids batch = [] ids = [] data = passage.data.copy() ids.append(passage.context_id) batch.append(data) if batch: yield batch, ids return # add data to database context_ids = [] p_logger = SimpleProgressLogger(logger, interval=self.cfg.log_interval) for batch, ids in get_batch(): self.database[ids] = batch context_ids.extend(ids) p_logger.update(step=len(batch), desc="Adding passages") # update the indexes self._update_index(context_ids) logger.info("Finished adding passages.") return @TIME_METER("flex_retriever", "search") def search( self, query: list[str] | str, **search_kwargs, ) -> list[list[RetrievedContext]]: if isinstance(query, str): query = [query] top_k = search_kwargs.pop("top_k", self.cfg.top_k) used_indexes = search_kwargs.pop("used_indexes", self.cfg.used_indexes) if used_indexes is None: used_indexes = list(self.index_table.keys()) for index_name in used_indexes: assert index_name in self.index_table, f"Index {index_name} not found." assert len(used_indexes) > 0, "`used_indexes` is empty." # retrieve indices using `used_indexes` all_context_ids = [] all_scores = [] for index_name in used_indexes: r = self.index_table[index_name].search(query, top_k, **search_kwargs) all_context_ids.append(r[0]) all_scores.append(r[1]) # merge the indices and scores merged_ids: list[list[str]] = [] merged_scores: list[list[float]] = [] if len(all_scores) == 1: # only one index is activated merged_scores = all_scores[0].tolist() merged_ids = all_context_ids[0] else: # merge multiple indexes merge_method = search_kwargs.pop( "indexes_merge_method", self.cfg.indexes_merge_method ) match merge_method: case "rrf": # prepare merge weights if self.cfg.indexes_merge_weights is not None: assert len(self.cfg.indexes_merge_weights) == len(used_indexes) merge_weights = [ i / sum(self.cfg.indexes_merge_weights) for i in self.cfg.indexes_merge_weights ] else: merge_weights = [1.0 / len(all_scores)] * len(all_scores) # recompute the scores according to the rank for i in range(len(query)): scores_dict = defaultdict(float) for ctx_ids, scores, merge_weight in zip( all_context_ids, all_scores, merge_weights ): sort_ranks = scores[i].argsort()[::-1] + 1 for ctx_id, rank in zip(ctx_ids[i], sort_ranks): scores_dict[ctx_id] += merge_weight / ( rank + self.cfg.rrf_base ) sorted_items = sorted(scores_dict.items(), key=lambda x: -x[1]) merged_ids.append([item[0] for item in sorted_items][:top_k]) merged_scores.append([item[1] for item in sorted_items][:top_k]) case "linear": # prepare merge weights if self.cfg.indexes_merge_weights is not None: assert len(self.cfg.indexes_merge_weights) == len(used_indexes) merge_weights = [ i / sum(self.cfg.indexes_merge_weights) for i in self.cfg.indexes_merge_weights ] else: merge_weights = [1.0 / len(all_scores)] * len(all_scores) # According to "An Analysis of Fusion Functions for Hybrid Retrieval", # we employ the TMM normalization method to normalize the scores. if any( self.index_table[index_name].infimum == float("-inf") for index_name in used_indexes ): use_infimum = False else: use_infimum = True for n in range(len(used_indexes)): index_name = used_indexes[n] if use_infimum: infimum = self.index_table[index_name].infimum else: infimum = all_scores[n].min(axis=1, keepdims=True) all_scores[n] = (all_scores[n] - infimum) / ( all_scores[n].max(axis=1, keepdims=True) - infimum ) # merge the scores for i in range(len(query)): scores_dict = defaultdict(float) for ctx_ids, scores, merge_weight in zip( all_context_ids, all_scores, merge_weights ): for ctx_id, score in zip(ctx_ids[i], scores[i]): scores_dict[ctx_id] += score * merge_weight sorted_items = sorted(scores_dict.items(), key=lambda x: -x[1]) merged_ids.append([item[0] for item in sorted_items][:top_k]) merged_scores.append([item[1] for item in sorted_items][:top_k]) case _: raise ValueError(f"Unknown merge method: {merge_method}") # form the final results results: list[list[RetrievedContext]] = [] for i, (q, score, context_id) in enumerate( zip(query, merged_scores, merged_ids) ): results.append([]) ctxs = self.database[context_id] for j, (s, ctx_id, ctx) in enumerate(zip(score, context_id, ctxs)): results[-1].append( RetrievedContext( context_id=ctx_id, retriever="FlexRetriever", query=q, score=float(s), data=ctx, ) ) return results
[docs] def clear(self) -> None: # clear the indexes for index_name in self.index_table: self.index_table[index_name].clear() # clear the database self.database.clear() # clear the directory if self.cfg.retriever_path is not None: if os.path.exists(self.cfg.retriever_path): shutil.rmtree(self.cfg.retriever_path) return
def __len__(self) -> int: return len(self.database) @property def fields(self) -> list[str]: return self.database.fields @TIME_METER("flex_retriever", "add-index") def add_index( self, index_name: str, index_config: RetrieverIndexConfig, # type: ignore indexed_fields_config: MultiFieldIndexConfig, ) -> None: """Add an index to the retriever. :param index_name: Name of the index. :type index_name: str :param index_config: Configuration of the index. :type index_config: RetrieverIndexConfig :param indexed_fields_config: Configuration of the indexed fields. :type indexed_fields_config: MultiFieldIndexConfig :raises ValueError: If the index name already exists. :return: None :rtype: None """ # check if the index name is valid if index_name in self.index_table: raise ValueError( f"Index {index_name} already exists. Please remove it first." ) # prepare the index index = RETRIEVER_INDEX.load(index_config) index = MultiFieldIndex(indexed_fields_config, index) index.build_index(self.database.ids, self.database.values()) # prepare index path if self.cfg.retriever_path is not None: index_path = os.path.join(self.cfg.retriever_path, "indexes", index_name) else: index_path = None if index_path is not None: index.save_to_local(index_path) # add index to the index table self.index_table[index_name] = index self._check_consistency() logger.info(f"Finished adding index: {index_name}") return
[docs] def remove_index(self, index_name: str) -> None: """Remove an index from the retriever. :param index_name: Name of the index. :type index_name: str :raises ValueError: If the index name does not exist. :return: None :rtype: None """ if index_name not in self.index_table: raise ValueError(f"Index {index_name} does not exist.") # remove the index index = self.index_table.pop(index_name) index.clear() # update the configuration if index_name in self.cfg.used_indexes: self.cfg.used_indexes.remove(index_name) return
[docs] def save_to_local(self, retriever_path: str = None) -> None: # check if the retriever is serializable if self.cfg.retriever_path is not None: if retriever_path == self.cfg.retriever_path: return # skip saving if the path is the same else: assert retriever_path is not None, "`retriever_path` is not set." self.cfg.retriever_path = retriever_path self._check_retriever_path(retriever_path) logger.info(f"Serializing retriever to {retriever_path}") # save the database def get_data() -> Generator[tuple[list[str], list[dict]], None, None]: batch_ids = [] batch_data = [] for ctx_id, ctx in self.database.items(): # unify the schema # FIXME: if the schema is not consistent, we need to handle it ctx = {k: ctx.get(k, "") for k in self.fields} batch_ids.append(ctx_id) batch_data.append(ctx) if len(batch_ids) == self.cfg.batch_size: yield batch_ids, batch_data batch_ids = [] batch_data = [] if batch_ids: yield batch_ids, batch_data return new_db = LMDBRetrieverDatabase(os.path.join(retriever_path, "database.lmdb")) for batch_ids, batch_data in get_data(): new_db[batch_ids] = batch_data self.database = new_db # save the index for index_name, index in self.index_table.items(): index_path = os.path.join(retriever_path, "indexes", index_name) index.save_to_local(index_path) return
[docs] def detach(self): """Detach the retriever from the local disk to memory. This function will not delete the database or the indexes.""" def get_data() -> Generator[tuple[list[str], list[dict]], None, None]: batch_ids = [] for ctx_id in self.database.ids: batch_ids.append(ctx_id) if len(batch_ids) == self.cfg.batch_size: yield batch_ids, self.database[batch_ids] batch_ids = [] if batch_ids: yield batch_ids, self.database[batch_ids] return # detach the database if isinstance(self.database, LMDBRetrieverDatabase): new_db = NaiveRetrieverDatabase() for batch_ids, batch_data in get_data(): new_db[batch_ids] = batch_data self.database = new_db # detach the indexes for index_name, index in self.index_table.items(): index.index.cfg.index_path = None # update the configuration self.cfg.retriever_path = None return
def _update_index(self, context_ids: list[str]) -> None: def get_data() -> Generator[tuple[Any, int], None, None]: for ctx_id in context_ids: yield self.database[ctx_id] for index_name, index in self.index_table.items(): if index.is_addable: index.insert_batch(context_ids, get_data(), serialize=True) else: logger.warning( f"Index {index_name} is not addable. Rebuilding the index." ) index.clear() index.build_index(get_data()) return def _load_database(self) -> RetrieverDatabaseBase: if self.cfg.retriever_path is not None: database_path = os.path.join(self.cfg.retriever_path, "database.lmdb") database = LMDBRetrieverDatabase(database_path) else: database = NaiveRetrieverDatabase() return database def _load_index(self) -> dict[str, MultiFieldIndex]: # load indexes indexes = {} if self.cfg.retriever_path is None: return indexes if not os.path.exists(os.path.join(self.cfg.retriever_path, "indexes")): return indexes indexes_names = os.listdir(os.path.join(self.cfg.retriever_path, "indexes")) for index_name in indexes_names: index_path = os.path.join(self.cfg.retriever_path, "indexes", index_name) index = MultiFieldIndex.load_from_local(index_path) indexes[index_name] = index return indexes def _check_consistency(self) -> None: if self.cfg.retriever_path is not None: self._check_retriever_path(self.cfg.retriever_path) for index_name, index in self.index_table.items(): assert len(index) == len(self.database), "Index and database size mismatch" return def _check_retriever_path(self, retriever_path: str) -> None: if not os.path.exists(retriever_path): os.makedirs(retriever_path) # save the retriever card card_path = os.path.join(retriever_path, "README.md") if not os.path.exists(card_path): retriever_card = RETRIEVER_CARD_TEMPLATE.render( retriever_type=self.__class__.__name__, version=__VERSION__, repo_path=self.cfg.retriever_path, ) with open(card_path, "w", encoding="utf-8") as f: f.write(retriever_card) # save the configuration cfg_path = os.path.join(retriever_path, "config.yaml") if not os.path.exists(cfg_path): self.cfg.dump(cfg_path) id_path = os.path.join(retriever_path, "cls.id") if not os.path.exists(id_path): with open(id_path, "w", encoding="utf-8") as f: f.write(self.__class__.__name__)