import os
import shutil
from typing import Annotated, Any, Iterable, Optional
import bm25s
import numpy as np
from flexrag.utils import LOGGER_MANAGER, Choices, configure
from flexrag.utils.configure import extract_config
from .index_base import RETRIEVER_INDEX, RetrieverIndexBase, RetrieverIndexBaseConfig
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.index.bm25")
[docs]
@configure
class BM25IndexConfig(RetrieverIndexBaseConfig):
"""Configuration class for BM25Index.
:param method: BM25S method. Default: "lucene".
Available options: "atire", "bm25l", "bm25+", "lucene", "robertson".
:type method: str
:param idf_method: IDF method. Default: None.
Available options: "atire", "bm25l", "bm25+", "lucene", "robertson".
:type idf_method: Optional[str]
:param backend: Backend for BM25S. Default: "auto".
Available options: "numpy", "numba", "auto".
:type backend: str
:param k1: BM25S parameter k1. Default: 1.5.
:type k1: float
:param b: BM25S parameter b. Default: 0.75.
:type b: float
:param delta: BM25S parameter delta. Default: 0.5.
:type delta: float
:param lang: Language for Tokenization. Default: "english".
:type lang: str
"""
method: Annotated[
str,
Choices(
"atire",
"bm25l",
"bm25+",
"lucene",
"robertson",
),
] = "lucene"
idf_method: Optional[
Annotated[
str,
Choices(
"atire",
"bm25l",
"bm25+",
"lucene",
"robertson",
),
]
] = None
backend: Annotated[str, Choices("numpy", "numba", "auto")] = "auto"
k1: float = 1.5
b: float = 0.75
delta: float = 0.5
lang: str = "english"
[docs]
@RETRIEVER_INDEX("bm25", config_class=BM25IndexConfig)
class BM25Index(RetrieverIndexBase):
"""BM25Index is a index that retrieves passages using the BM25 algorithm.
The implementation is based on the `bm25s <https://github.com/xhluca/bm25s>`_ project.
"""
def __init__(self, cfg: BM25IndexConfig) -> None:
self.cfg = extract_config(cfg, BM25IndexConfig)
try:
import Stemmer
self._stemmer = Stemmer.Stemmer(cfg.lang)
except:
logger.warning(
"Stemmer is not available. "
"You can install `PyStemmer` by `pip install PyStemmer` for better results."
)
self._stemmer = None
# initialize the index
self.index = bm25s.BM25(
method=cfg.method,
idf_method=cfg.idf_method,
backend=cfg.backend,
k1=cfg.k1,
b=cfg.b,
delta=cfg.delta,
)
# load the index if index_path is provided
if self.cfg.index_path is not None:
if os.path.exists(self.cfg.index_path):
logger.info(f"Loading index from {self.cfg.index_path}")
try:
self.index = bm25s.BM25.load(self.cfg.index_path, mmap=True)
except:
raise FileNotFoundError(
f"Unable to load index from {self.cfg.index_path}"
)
return
[docs]
def build_index(self, data: Iterable[Any]) -> None:
# prepare the data
logger.info("Preparing the passages for indexing.")
items = list(data)
# tokenize and build index
logger.info("Building the index.")
indexed_tokens = bm25s.tokenize(
items, stopwords=self.cfg.lang, stemmer=self._stemmer
)
self.index.index(indexed_tokens)
# serialize index
if self.cfg.index_path is not None:
self.save_to_local()
return
[docs]
def insert(self, data: list[Any]) -> None:
raise NotImplementedError(
"BM25Index currently does not support inserting data."
)
[docs]
def search(
self,
query: list[str], # bm25 algorithm can only process string
top_k: int,
**search_kwargs,
) -> tuple[np.ndarray, np.ndarray]:
query_tokens = bm25s.tokenize(query, stemmer=self._stemmer, show_progress=False)
contexts, scores = self.index.retrieve(
query_tokens,
k=top_k,
show_progress=False,
**search_kwargs,
)
return contexts, scores
@property
def is_addable(self) -> bool:
return False
[docs]
def save_to_local(self, index_path: str = None) -> None:
# check if the index is serializable
if index_path is not None:
self.cfg.index_path = index_path
assert self.cfg.index_path is not None, "`index_path` is not set."
if not os.path.exists(index_path):
os.makedirs(self.cfg.index_path)
logger.info(f"Serializing index to {self.cfg.index_path}")
# save the configuration
config_path = os.path.join(self.cfg.index_path, "config.yaml")
self.cfg.dump(config_path)
id_path = os.path.join(self.cfg.index_path, "cls.id")
with open(id_path, "w", encoding="utf-8") as f:
f.write(self.__class__.__name__)
# serialize the index
self.index.save(self.cfg.index_path)
return
[docs]
def clear(self) -> None:
del self.index.scores
del self.index.vocab_dict
if self.cfg.index_path is not None:
if os.path.exists(self.cfg.index_path):
shutil.rmtree(self.cfg.index_path)
return
def __len__(self) -> int:
if hasattr(self.index, "scores"):
return self.index.scores.get("num_docs", 0)
return 0
@property
def infimum(self) -> float:
return 0.0
@property
def supremum(self) -> float:
return float("inf")