Source code for flexrag.retriever.index.faiss_index

import os
import shutil
from copy import deepcopy
from dataclasses import field
from typing import Annotated, Any, Iterable, Optional

import faiss
import numpy as np

from flexrag.models import ENCODERS
from flexrag.utils import LOGGER_MANAGER, Choices, configure
from flexrag.utils.configure import extract_config

from .index_base import RETRIEVER_INDEX, DenseIndexBase, DenseIndexBaseConfig

logger = LOGGER_MANAGER.get_logger("flexrag.retriever.index.faiss")


[docs] @configure class FaissIndexConfig(DenseIndexBaseConfig): """The configuration for the `FaissIndex`. :param index_type: Building param: the type of the index. Defaults to "auto". available choices are "FLAT", "IVF", "PQ", "IVFPQ", and "auto". If set to "auto", the index will be set to "IVF{n_list},PQ{embedding_size//2}x4fs". :type index_type: str :param n_subquantizers: Building param: the number of subquantizers. Defaults to 8. This parameter is only used when the index type is "PQ" or "IVFPQ". :type n_subquantizers: int :param n_bits: Building param: the number of bits per subquantizer. Defaults to 8. This parameter is only used when the index type is "PQ" or "IVFPQ". :type n_bits: int :param n_list: Building param: the number of cells. Defaults to 1000. This parameter is only used when the index type is "IVF" or "IVFPQ". :type n_list: int :param factory_str: Building param: the factory string to build the index. Defaults to None. If set, the `index_type` will be ignored. :type factory_str: Optional[str] :param index_train_num: Building param: the number of data used to train the index. Defaults to -1. If set to -1, all data will be used to train the index. :type index_train_num: int :param n_probe: Inference param: the number of probes. Defaults to None. If not set, the number of probes will be set to `n_list // 8`. This parameter is only used when the index type is "IVF" or "IVFPQ". :type n_probe: Optional[int] :param device_id: Inference param: the device(s) to use. Defaults to []. [] means CPU. If set, the index will be accelerated with GPU. :type device_id: list[int] :param k_factor: Inference param: the k factor for search. Defaults to 10. :type k_factor: int :param polysemous_ht: Inference param: the polysemous hash table. Defaults to 0. :type polysemous_ht: int :param efSearch: Inference param: the efSearch for HNSW. Defaults to 100. :type efSearch: int """ index_type: Annotated[str, Choices("FLAT", "IVF", "PQ", "IVFPQ", "auto")] = "auto" n_subquantizers: int = 8 n_bits: int = 8 n_list: int = 1000 factory_str: Optional[str] = None index_train_num: int = -1 # Inference Arguments n_probe: Optional[int] = None device_id: list[int] = field(default_factory=list) k_factor: int = 10 polysemous_ht: int = 0 efSearch: int = 100
[docs] @RETRIEVER_INDEX("faiss", config_class=FaissIndexConfig) class FaissIndex(DenseIndexBase): """FaissIndex employs `faiss <https://github.com/facebookresearch/faiss>`_ library to build and search indexes with embeddings. FaissIndex supports both CPU and GPU acceleration. FaissIndex supports various index types, including FLAT, IVF, PQ, IVFPQ, and auto. FaissIndex provides a flexible and efficient way to build and search indexes with embeddings. """ cfg: FaissIndexConfig def __init__(self, cfg: FaissIndexConfig) -> None: super().__init__(cfg) self.cfg = extract_config(cfg, FaissIndexConfig) # prepare index self.index = None # load the index if index_path is provided if self.cfg.index_path is not None: if os.path.exists(self.cfg.index_path): logger.info(f"Loading index from {self.cfg.index_path}") try: index_path = os.path.join(self.cfg.index_path, "index.faiss") cpu_index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP) self.index = self._set_index(cpu_index) except: raise FileNotFoundError( f"Unable to load index from {self.cfg.index_path}" ) return
[docs] def build_index(self, data: Iterable[Any]) -> None: self.clear() embeddings = self.encode_data_batch(data, is_query=False) self.index = self._prepare_index( index_type=self.cfg.index_type, distance_function=self.cfg.distance_function, embedding_size=embeddings.shape[1], embedding_length=embeddings.shape[0], n_list=self.cfg.n_list, n_subquantizers=self.cfg.n_subquantizers, n_bits=self.cfg.n_bits, factory_str=self.cfg.factory_str, ) self._train_index(embeddings) self.add_embeddings_batch(embeddings) if isinstance(embeddings, np.memmap): emb_path = embeddings.filename os.remove(emb_path) return
def _prepare_index( self, index_type: str, distance_function: str, embedding_size: int, # the dimension of the embeddings embedding_length: int, # the number of the embeddings n_list: int, # the number of cells n_subquantizers: int, # the number of subquantizers n_bits: int, # the number of bits per subquantizer factory_str: Optional[str] = None, ): # prepare distance function match distance_function: case "IP": basic_index = faiss.IndexFlatIP(embedding_size) basic_metric = faiss.METRIC_INNER_PRODUCT case "COS": basic_index = faiss.IndexFlatIP(embedding_size) basic_metric = faiss.METRIC_INNER_PRODUCT case "L2": basic_index = faiss.IndexFlatL2(embedding_size) basic_metric = faiss.METRIC_L2 case _: raise ValueError(f"Unknown distance function: {distance_function}") if index_type == "auto": n_list = 2 ** int(np.log2(np.sqrt(embedding_length))) factory_str = f"IVF{n_list},PQ{embedding_size//2}x4fs" logger.info(f"Auto set index to {factory_str}") logger.info( f"We recommend to set n_probe to {n_list//8} for better inference performance" ) if factory_str is not None: # using string factory to build the index index = faiss.index_factory( embedding_size, factory_str, basic_metric, ) else: # prepare optimized index match index_type: case "FLAT": index = basic_index case "IVF": index = faiss.IndexIVFFlat( basic_index, embedding_size, n_list, basic_metric, ) case "PQ": index = faiss.IndexPQ( embedding_size, n_subquantizers, n_bits, ) case "IVFPQ": index = faiss.IndexIVFPQ( basic_index, embedding_size, n_list, n_subquantizers, n_bits, ) case _: raise ValueError(f"Unknown index type: {index_type}") # post process index = self._set_index(index) return index def _train_index(self, embeddings: np.ndarray) -> None: if self.is_trained: logger.info("Index is trained already.") return logger.info("Training index") if (self.cfg.index_train_num >= embeddings.shape[0]) or ( self.cfg.index_train_num == -1 ): if embeddings.dtype != np.float32: embeddings = embeddings.astype("float32") self.index.train(embeddings) else: selected_indices = np.random.choice( embeddings.shape[0], self.cfg.index_train_num, replace=False, ) selected_indices = np.sort(selected_indices) selected_embeddings = embeddings[selected_indices].astype("float32") self.index.train(selected_embeddings) return
[docs] def add_embeddings(self, embeddings: np.ndarray) -> None: embeddings = embeddings.astype("float32") assert self.is_trained, "Index should be trained first" self.index.add(embeddings) return
def _prepare_search_params(self, **kwargs): """A helper function to prepare search parameters for the index. :return: The search parameters for the index. :rtype: faiss.SearchParameters """ # set search kwargs k_factor = kwargs.get("k_factor", self.cfg.k_factor) n_probe = kwargs.get("n_probe", self.cfg.n_probe) if n_probe is None: n_probe = getattr(self.index, "nlist", 256) // 8 polysemous_ht = kwargs.get("polysemous_ht", self.cfg.polysemous_ht) efSearch = kwargs.get("efSearch", self.cfg.efSearch) def get_search_params(index): if isinstance(index, faiss.IndexRefine): params = faiss.IndexRefineSearchParameters( k_factor=k_factor, base_index_params=get_search_params( faiss.downcast_index(index.base_index) ), ) elif isinstance(index, faiss.IndexPreTransform): params = faiss.SearchParametersPreTransform( index_params=get_search_params(faiss.downcast_index(index.index)) ) elif isinstance(index, faiss.IndexIVFPQ): if hasattr(index, "quantizer"): params = faiss.IVFPQSearchParameters( nprobe=n_probe, polysemous_ht=polysemous_ht, quantizer_params=get_search_params( faiss.downcast_index(index.quantizer) ), ) else: params = faiss.IVFPQSearchParameters( nprobe=n_probe, polysemous_ht=polysemous_ht ) elif isinstance(index, faiss.IndexIVF): if hasattr(index, "quantizer"): params = faiss.SearchParametersIVF( nprobe=n_probe, quantizer_params=get_search_params( faiss.downcast_index(index.quantizer) ), ) else: params = faiss.SearchParametersIVF(nprobe=n_probe) elif isinstance(index, faiss.IndexHNSW): params = faiss.SearchParametersHNSW(efSearch=efSearch) elif isinstance(index, faiss.IndexPQ): params = faiss.SearchParametersPQ(polysemous_ht=polysemous_ht) else: params = None return params return get_search_params(self.index)
[docs] def search( self, query: list[Any], top_k: int, **search_kwargs, ) -> tuple[np.ndarray, np.ndarray]: query_vectors = self.encode_data(query, is_query=True) search_params = self._prepare_search_params(**search_kwargs) scores, indices = self.index.search(query_vectors, top_k, params=search_params) return indices, scores
[docs] def save_to_local(self, index_path: str = None) -> None: # check if the index is serializable if index_path is not None: self.cfg.index_path = index_path assert self.cfg.index_path is not None, "`index_path` is not set." assert self.index.is_trained, "Index should be trained first." if not os.path.exists(index_path): os.makedirs(self.cfg.index_path) logger.info(f"Serializing index to {self.cfg.index_path}") # save the configuration cfg = deepcopy(self.cfg) cfg.query_encoder_config = ENCODERS.squeeze(cfg.query_encoder_config) cfg.passage_encoder_config = ENCODERS.squeeze(cfg.passage_encoder_config) cfg.index_path = "" config_path = os.path.join(self.cfg.index_path, "config.yaml") cfg.dump(config_path) id_path = os.path.join(self.cfg.index_path, "cls.id") with open(id_path, "w", encoding="utf-8") as f: f.write(self.__class__.__name__) # serialize the index index_path = os.path.join(index_path, "index.faiss") if self.support_gpu: cpu_index = faiss.index_gpu_to_cpu(self.index) else: cpu_index = self.index faiss.write_index(cpu_index, index_path) return
[docs] def clear(self): if self.index is None: return self.index.reset() if self.cfg.index_path is not None: if os.path.exists(self.cfg.index_path): shutil.rmtree(self.cfg.index_path) return
@property def embedding_size(self) -> int: if self.index is not None: return self.index.d if self.passage_encoder is not None: return self.passage_encoder.embedding_size if self.query_encoder is not None: return self.query_encoder.embedding_size raise ValueError("Index is not initialized.") @property def is_trained(self) -> bool: if self.index is None: return False return self.index.is_trained @property def is_addable(self) -> bool: return self.is_trained @property def support_gpu(self) -> bool: return hasattr(faiss, "GpuMultipleClonerOptions") and ( len(self.cfg.device_id) > 0 ) def _set_index(self, index): if self.support_gpu: logger.info("Accelerating index with GPU.") option = faiss.GpuMultipleClonerOptions() option.useFloat16 = True option.shard = True if isinstance(index, faiss.IndexIVFFlat): option.common_ivf_quantizer = True index = faiss.index_cpu_to_gpus_list( index, co=option, gpus=self.cfg.device_id, ngpu=len(self.cfg.device_id), ) elif len(self.cfg.device_id) > 0: logger.warning( "The installed faiss does not support GPU acceleration. " "Please install faiss-gpu." ) return index def __len__(self) -> int: if self.index is None: return 0 return self.index.ntotal