import os
from abc import ABC, abstractmethod
from dataclasses import field
from typing import Annotated, Any, Generator, Iterable, Optional
from uuid import uuid4
import numpy as np
from huggingface_hub import HfApi
from flexrag.models import ENCODERS, EncoderConfig
from flexrag.utils import (
FLEXRAG_CACHE_DIR,
LOGGER_MANAGER,
TIME_METER,
Choices,
Register,
SimpleProgressLogger,
configure,
)
# Module-level logger used by the index base classes in this file (progress and warnings).
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.index")
@configure
class RetrieverIndexBaseConfig:
    """The configuration for the `RetrieverIndexBase`.

    :param log_interval: The interval to log the progress. Defaults to 10000.
    :type log_interval: int
    :param batch_size: The batch size used when adding, encoding and searching data.
        Defaults to 512.
    :type batch_size: int
    :param index_path: The path to save the index.
        If not specified, the index will be kept in memory.
        Defaults to None.
    :type index_path: Optional[str]
    """

    log_interval: int = 10000
    batch_size: int = 512
    index_path: Optional[str] = None
class RetrieverIndexBase(ABC):
    """The base class for all retriever indexes.

    This class provides the basic interface for building, adding, and searching the index.
    The subclass should implement the following methods:

    - `build_index`: Build the index from the data.
    - `insert`: Add a batch of data to the index.
    - `search`: Search for the top_k most similar data indices to the query.
    - `save_to_local`: Serialize the index to the disk.
    - `clear`: Clear the index and remove the serialized index files.
    - `__len__`: Return the number of data in the index.
    - `is_addable`: Return whether the index is addable.
    - `infimum` / `supremum`: Return the bounds of the similarity scores.
    """

    cfg: RetrieverIndexBaseConfig

    @staticmethod
    def _iter_batches(
        data: Iterable[Any], batch_size: int
    ) -> Generator[list[Any], None, None]:
        """Yield consecutive lists of at most `batch_size` items taken from `data`.

        The final batch may be shorter; nothing is yielded when `data` is empty.
        """
        batch: list[Any] = []
        for item in data:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    @abstractmethod
    def build_index(self, data: Iterable[Any]) -> None:
        """Build the index.

        The index will be serialized automatically if the `index_path` is set.

        :param data: The data to build the index.
        :type data: Iterable[Any]
        :return: None
        """
        return

    def insert_batch(
        self,
        data: Iterable[Any],
        batch_size: Optional[int] = None,
        serialize: bool = True,
    ) -> None:
        """Add data to the index in batches.

        Each batch is passed to `insert` with ``serialize=False``; the index is
        serialized once via `save_to_local` at the end, if `serialize` is True
        and `cfg.index_path` is set.

        :param data: The data to add.
        :type data: Iterable[Any]
        :param batch_size: The batch size to add data to the index.
            Defaults to self.cfg.batch_size.
        :type batch_size: Optional[int]
        :param serialize: Whether to serialize the index after adding data. Defaults to True.
        :type serialize: bool
        :return: None
        """
        assert self.is_addable, "Current index is not addable."
        batch_size = batch_size or self.cfg.batch_size
        total = len(data) if hasattr(data, "__len__") else None
        p_logger = SimpleProgressLogger(logger, total, interval=self.cfg.log_interval)
        for batch in self._iter_batches(data, batch_size):
            # defer serialization so the index is written only once, after all batches
            self.insert(batch, serialize=False)
            p_logger.update(step=len(batch), desc="Adding data")
        # serialize if the `index_path` is set
        if (self.cfg.index_path is not None) and serialize:
            self.save_to_local()
        return

    @abstractmethod
    def insert(
        self,
        data: list[Any],
        serialize: bool = True,
    ) -> None:
        """Add a batch of data to the index.

        :param data: The data to add.
        :type data: list[Any]
        :param serialize: Whether to serialize the index after adding data. Defaults to True.
        :type serialize: bool
        :return: None
        """
        return

    @TIME_METER("retrieve", "index", "search")
    def search_batch(
        self,
        query: Iterable[Any],
        top_k: int = 10,
        batch_size: Optional[int] = None,
        **search_kwargs,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Search for the top_k most similar data indices to the query.

        This method splits `query` into batches, calls `search` on each batch
        and concatenates the per-batch results along the first axis.

        :param query: The query data.
        :type query: Iterable[Any]
        :param top_k: The number of most similar data indices to return, defaults to 10.
        :type top_k: int, optional
        :param batch_size: The batch size to search. Defaults to self.cfg.batch_size.
        :type batch_size: Optional[int]
        :param search_kwargs: Additional search arguments forwarded to `search`.
        :type search_kwargs: Any
        :return: The indices and scores of the top_k most similar data indices.
        :rtype: tuple[np.ndarray, np.ndarray]
        :raises ValueError: If `query` is empty (there is nothing to concatenate).
        """
        batch_size = batch_size or self.cfg.batch_size
        total = len(query) if hasattr(query, "__len__") else None
        p_logger = SimpleProgressLogger(logger, total, interval=self.cfg.log_interval)
        indices = []
        scores = []
        for q in self._iter_batches(query, batch_size):
            batch_indices, batch_scores = self.search(q, top_k, **search_kwargs)
            indices.append(batch_indices)
            scores.append(batch_scores)
            # count the actual batch length: the last batch may be shorter
            p_logger.update(step=len(q), desc="Searching")
        return np.concatenate(indices, axis=0), np.concatenate(scores, axis=0)

    @abstractmethod
    def search(
        self,
        query: list[Any],
        top_k: int,
        **search_kwargs,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Search for the top_k most similar data indices to the query.

        :param query: The query data.
        :type query: list[Any]
        :param top_k: The number of most similar data indices to return.
        :type top_k: int
        :param search_kwargs: Additional search arguments.
        :type search_kwargs: Any
        :return: The indices and scores of the top_k most similar data indices.
        :rtype: tuple[np.ndarray, np.ndarray]
        """
        return

    @property
    @abstractmethod
    def is_addable(self) -> bool:
        """Return whether new data can be added to the index.

        `insert_batch` asserts this property before adding any data.
        """
        return

    @abstractmethod
    def save_to_local(self, index_path: Optional[str] = None) -> None:
        """Serialize the index to self.cfg.index_path.

        If the `index_path` is given, the index will be serialized to the `index_path`.

        :param index_path: The path to serialize the index. Defaults to self.cfg.index_path.
        :type index_path: Optional[str]
        """
        return

    @staticmethod
    def load_from_local(index_path: str) -> "RetrieverIndexBase":
        """Load the index from the local path.

        The directory must contain a ``cls.id`` file naming the registered
        index class and a ``config.yaml`` file with its configuration.

        :param index_path: The path to load the index.
        :type index_path: str
        :return: The loaded index, constructed with its configuration's
            `index_path` set to `index_path`.
        :rtype: RetrieverIndexBase
        """
        assert os.path.exists(index_path), f"Index path {index_path} does not exist."
        # load cls_id
        id_path = os.path.join(index_path, "cls.id")
        assert os.path.exists(id_path), f"Index ID file {id_path} does not exist."
        with open(id_path, "r") as id_file:
            index_name = id_file.read().strip()
        index_cls = RETRIEVER_INDEX[index_name]["item"]
        # load configuration
        config_cls = RETRIEVER_INDEX[index_name]["config_class"]
        config_path = os.path.join(index_path, "config.yaml")
        assert os.path.exists(
            config_path
        ), f"Configuration file {config_path} does not exist."
        cfg = config_cls.load(config_path)
        # point the loaded config at the directory it was actually loaded from
        cfg.index_path = index_path
        # load the index
        return index_cls(cfg)

    @abstractmethod
    def clear(self) -> None:
        """Reset the index and remove the serialized index files."""
        return

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of data in the index."""
        return

    @property
    @abstractmethod
    def infimum(self) -> float:
        """Return the infimum of the similarity scores for the index."""
        return

    @property
    @abstractmethod
    def supremum(self) -> float:
        """Return the supremum of the similarity scores for the index."""
        return
@configure
class DenseIndexBaseConfig(RetrieverIndexBaseConfig):
    """The configuration for the `DenseIndexBase`.

    :param query_encoder_config: Configuration for the query encoder.
        Defaults to a default-constructed `EncoderConfig`.
    :type query_encoder_config: EncoderConfig
    :param passage_encoder_config: Configuration for the passage encoder.
        Defaults to a default-constructed `EncoderConfig`.
    :type passage_encoder_config: EncoderConfig
    :param distance_function: The distance function to use. Defaults to "IP".
        Available choices are "IP", "L2", and "COS".
    :type distance_function: str
    """

    # default_factory gives every config instance its own EncoderConfig object
    query_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)  # type: ignore
    passage_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)  # type: ignore
    distance_function: Annotated[str, Choices("IP", "L2", "COS")] = "IP"
class DenseIndexBase(RetrieverIndexBase):
    """The base class for all dense indexes.

    Data is encoded into float32 embeddings by the passage encoder (queries by
    the query encoder) and handed to the subclass through `add_embeddings`.
    Subclasses implement `add_embeddings` and `embedding_size` in addition to
    the remaining abstract methods of `RetrieverIndexBase`.

    NOTE: `__init__` does not assign `self.cfg`, but the helpers below read
    `self.cfg.batch_size`, `self.cfg.log_interval` and `self.cfg.index_path`,
    so subclasses are expected to set it.
    """

    def __init__(self, cfg: DenseIndexBaseConfig):
        # load the encoders; `encode_data` asserts they are not None before use
        self.query_encoder = ENCODERS.load(cfg.query_encoder_config)
        self.passage_encoder = ENCODERS.load(cfg.passage_encoder_config)
        # set basic args
        self.distance_function = cfg.distance_function
        return

    def encode_data_batch(
        self, data: Iterable[Any], is_query: bool = False, use_memmap: bool = True
    ) -> np.ndarray:
        """A helper function that encodes all data into embeddings.

        :param data: The data to encode.
        :type data: Iterable[Any]
        :param is_query: Whether the data is query data.
            If True, the query encoder will be used.
            If False, the passage encoder will be used.
            Defaults to False.
        :type is_query: bool
        :param use_memmap: Whether to use memory mapping for the embeddings.
            If True, each encoded batch is written to a temporary ``.npy`` file,
            then all batches are copied into ``embeddings/embeddings.npy``
            (under `cfg.index_path` if set, otherwise under FLEXRAG_CACHE_DIR)
            and returned as a float32 `np.memmap`.
            If False, the embeddings will be kept in memory.
            Note that you should remove the memory map file after use.
            Defaults to True.
        :type use_memmap: bool
        :return: The embeddings of the data.
        :rtype: np.ndarray
        :raises ValueError: If `data` is empty.
        """
        # prepare the directory that holds the per-batch files and the memory map
        if use_memmap:
            if self.cfg.index_path is not None:
                mmap_path = os.path.join(self.cfg.index_path, "embeddings")
            else:
                mmap_path = os.path.join(FLEXRAG_CACHE_DIR, "embeddings")
            os.makedirs(mmap_path, exist_ok=True)
        else:
            mmap_path = None

        def get_batch() -> Generator[list[Any], None, None]:
            """Yield lists of at most `cfg.batch_size` items from `data`."""
            batch = []
            for item in data:
                batch.append(item)
                if len(batch) == self.cfg.batch_size:
                    yield batch
                    batch = []
            if batch:
                yield batch

        # encode the data
        p_logger = SimpleProgressLogger(logger, interval=self.cfg.log_interval)
        embeddings = []
        n_embeddings = 0
        for batch in get_batch():
            emb = self.encode_data(batch, is_query=is_query)
            if mmap_path is not None:
                # keep only the file name so earlier batches are not held in memory
                file_name = os.path.join(mmap_path, f"{uuid4()}.npy")
                np.save(file_name, emb)
                embeddings.append(file_name)
            else:
                embeddings.append(emb)
            n_embeddings += emb.shape[0]
            p_logger.update(step=len(batch), desc="Encoding data")

        if not embeddings:
            raise ValueError("Cannot encode an empty collection of data.")

        if mmap_path is None:
            return np.concatenate(embeddings, axis=0)

        # copy the per-batch files into one memory map, deleting each file afterwards
        logger.info("Copying embeddings to memory map")
        emb_dim = np.load(embeddings[0]).shape[1]
        emb_map = np.memmap(
            os.path.join(mmap_path, "embeddings.npy"),
            dtype=np.float32,
            mode="w+",
            shape=(n_embeddings, emb_dim),
        )
        idx = 0
        for emb_path in embeddings:
            emb = np.load(emb_path)
            emb_map[idx : idx + emb.shape[0]] = emb
            idx += emb.shape[0]
            del emb
            os.remove(emb_path)
        # make sure the copied data reaches the backing file
        emb_map.flush()
        return emb_map

    def encode_data(self, data: list[Any], is_query: bool = False) -> np.ndarray:
        """A helper function that encodes the data using the encoder.

        :param data: The data to be encoded.
        :type data: list[Any]
        :param is_query: Whether the data is query data.
            If True, the query encoder will be used.
            If False, the passage encoder will be used.
            Defaults to False.
        :type is_query: bool
        :return: The encoded data, cast to float32.
        :rtype: np.ndarray
        """
        # set the encoder
        if is_query:
            assert self.query_encoder is not None, "Query encoder is not set."
            encoder = self.query_encoder
        else:
            assert self.passage_encoder is not None, "Passage encoder is not set."
            encoder = self.passage_encoder
        # encode the data
        embeds = encoder.encode(data)
        return embeds.astype("float32")

    def add_embeddings_batch(self, embeds: np.ndarray) -> None:
        """A helper function that adds embeddings to the index in batches.

        This method will not serialize the index automatically.
        Thus, you should call the `save_to_local` method after adding all data.

        :param embeds: The embeddings to add; rows are split into slices of
            `cfg.batch_size` along the first axis.
        :type embeds: np.ndarray
        :return: None
        """
        p_logger = SimpleProgressLogger(
            logger, embeds.shape[0], interval=self.cfg.log_interval
        )
        for start in range(0, embeds.shape[0], self.cfg.batch_size):
            batch_embeds = embeds[start : start + self.cfg.batch_size]
            self.add_embeddings(batch_embeds)
            p_logger.update(step=batch_embeds.shape[0], desc="Adding embeddings")
        return

    @abstractmethod
    def add_embeddings(self, embeds: np.ndarray) -> None:
        """A helper function that adds embeddings to the index.

        :param embeds: The embeddings to add.
        :type embeds: np.ndarray
        :return: None
        """
        return

    def insert(self, data: list[Any], serialize: bool = True) -> None:
        """Encode a batch of data with the passage encoder and add it to the index.

        :param data: The data to add.
        :type data: list[Any]
        :param serialize: Whether to serialize the index after adding data.
            Only takes effect when `cfg.index_path` is set. Defaults to True.
        :type serialize: bool
        :return: None
        """
        embeddings = self.encode_data(data, is_query=False)
        self.add_embeddings(embeddings)
        if (self.cfg.index_path is not None) and serialize:
            self.save_to_local()
        return

    @property
    @abstractmethod
    def embedding_size(self) -> int:
        """Return the embedding size of the index."""
        return

    @staticmethod
    def check_configuration(
        cfg: DenseIndexBaseConfig, token: Optional[str] = None
    ) -> None:
        """A helper function that checks the configuration of the index.

        This function is useful before saving the index to the HuggingFace hub.
        It checks if the encoders are available in the HuggingFace hub.
        If an encoder is missing or not hosted on the hub, a warning is logged.

        :param cfg: The configuration of the index.
        :type cfg: DenseIndexBaseConfig
        :param token: The token to access the HuggingFace hub. Defaults to None.
        :type token: Optional[str]
        """
        client = HfApi(token=token)
        DenseIndexBase._check_encoder_config(client, cfg.query_encoder_config, "Query")
        DenseIndexBase._check_encoder_config(
            client, cfg.passage_encoder_config, "Passage"
        )
        return

    @staticmethod
    def _check_encoder_config(
        client: HfApi, encoder_cfg: EncoderConfig, role: str
    ) -> None:
        """Log a warning if one encoder (named by `role`) cannot be fetched from the hub."""
        hint = "Please make sure loading the appropriate encoder when loading the retriever."
        if encoder_cfg.encoder_type is None:
            logger.warning(f"{role} encoder is not provided. {hint}")
        elif encoder_cfg.encoder_type == "hf":
            if not client.repo_exists(encoder_cfg.hf_config.model_path):
                logger.warning(
                    f"{role} encoder model is not available in the HuggingFace model hub. {hint}"
                )
        else:
            logger.warning(
                f"{role} encoder is not a model hosted on the HuggingFace model hub. {hint}"
            )

    @property
    def infimum(self) -> float:
        """Return the infimum of the scores for the configured distance function."""
        # For L2 distance, the infimum is 0.0
        # For COS distance, the infimum is -1.0
        # For IP distance, the infimum is -infinity
        if self.distance_function == "L2":
            return 0.0
        if self.distance_function == "COS":
            return -1.0
        return float("-inf")

    @property
    def supremum(self) -> float:
        """Return the supremum of the scores for the configured distance function."""
        # For L2 distance, the supremum is infinity
        # For IP distance, the supremum is infinity
        # For COS distance, the supremum is 1.0
        if self.distance_function == "COS":
            return 1.0
        return float("inf")
# Registry of retriever index implementations. `RetrieverIndexBase.load_from_local`
# looks entries up by the name stored in an index's `cls.id` file and reads their
# "item" (index class) and "config_class" fields.
RETRIEVER_INDEX = Register[RetrieverIndexBase]("index")