from typing import Annotated, Generator, Iterable, Optional
from flexrag.utils import (
LOGGER_MANAGER,
TIME_METER,
Choices,
SimpleProgressLogger,
configure,
)
from flexrag.utils.configure import extract_config
from flexrag.utils.dataclasses import Context, RetrievedContext
from .retriever_base import RETRIEVERS, EditableRetriever, EditableRetrieverConfig
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.typesense")
[docs]
@configure
class TypesenseRetrieverConfig(EditableRetrieverConfig):
"""Configuration class for TypesenseRetriever.
:param host: Host of the Typesense server. Default: "localhost".
:type host: str
:param port: Port of the Typesense server. Default: 8108.
:type port: int
:param protocol: Protocol of the Typesense server. Default: "http".
Available options: "https", "http".
:type protocol: str
:param api_key: API key for the Typesense server. Required.
:type api_key: str
:param index_name: Name of the Typesense collection. Required.
:type index_name: str
:param timeout: Timeout for the connection. Default: 200.0.
:type timeout: float
"""
host: str = "localhost"
port: int = 8108
protocol: Annotated[str, Choices("https", "http")] = "http"
api_key: Optional[str] = None
index_name: Optional[str] = None
timeout: float = 200.0
[docs]
@RETRIEVERS("typesense", config_class=TypesenseRetrieverConfig)
class TypesenseRetriever(EditableRetriever):
def __init__(self, cfg: TypesenseRetrieverConfig) -> None:
super().__init__(cfg)
self.cfg = extract_config(cfg, TypesenseRetrieverConfig)
assert self.cfg.api_key is not None, "`api_key` must be provided"
assert self.cfg.index_name is not None, "`index_name` must be provided"
# load database
import typesense
self.typesense = typesense
self.client = typesense.Client(
{
"nodes": [
{
"host": cfg.host,
"port": cfg.port,
"protocol": cfg.protocol,
}
],
"api_key": cfg.api_key,
"connection_timeout_seconds": cfg.timeout,
}
)
self.index_name = cfg.index_name
return
@TIME_METER("typesense", "add_passages")
def add_passages(self, passages: Iterable[Context]) -> None:
def get_batch() -> Generator[list[dict[str, str]], None, None]:
batch = []
for passage in passages:
if len(batch) == self.cfg.batch_size:
yield batch
batch = []
data = passage.data.copy()
data[self.id_field_name] = passage.context_id
batch.append(data)
if batch:
yield batch
return
# create collection if not exists
if self.index_name not in self.indices:
schema = {
"name": self.index_name,
"fields": [
{"name": ".*", "type": "auto", "index": True, "infix": True}
],
}
self.client.collections.create(schema)
# import documents
p_logger = SimpleProgressLogger(logger, interval=self.cfg.log_interval)
for batch in get_batch():
r = self.client.collections[self.index_name].documents.import_(batch)
assert all([i["success"] for i in r])
p_logger.update(len(batch), desc="Adding passages")
logger.info("Finished adding passages.")
return
@TIME_METER("typesense", "search")
def search(
self,
query: list[str],
**search_kwargs,
) -> list[list[RetrievedContext]]:
# prepare search parameters
search_params = [
{
"collection": self.index_name,
"q": q,
"query_by": ",".join(self.fields),
"per_page": search_kwargs.pop("top_k", self.cfg.top_k),
**search_kwargs,
}
for q in query
]
# search
try:
responses = self.client.multi_search.perform(
search_queries={"searches": search_params},
common_params={},
)
except self.typesense.exceptions.TypesenseClientError as e:
logger.error(f"Typesense error: {e}")
logger.error(f"Current query: {query}")
return [[] for _ in query]
# form final results
retrieved = []
for q, response in zip(query, responses["results"]):
retrieved.append([])
for i in response["hits"]:
data = i["document"]
context_id = data.pop(self.id_field_name)
retrieved[-1].append(
RetrievedContext(
context_id=context_id,
retriever="Typesense",
query=q,
data=i["document"],
source=self.index_name,
score=i["text_match"],
)
)
return retrieved
[docs]
def clear(self) -> None:
if self.index_name in self.indices:
self.client.collections[self.index_name].delete()
return
def __len__(self) -> int:
info = self.client.collections.retrieve()
info = [i for i in info if i["name"] == self.index_name]
if len(info) > 0:
return info[0]["num_documents"]
return 0
@property
def indices(self) -> list[str]:
return [i["name"] for i in self.client.collections.retrieve()]
@property
def fields(self) -> list[str]:
return [
i["name"]
for i in self.client.collections[self.index_name].retrieve()["fields"]
if i["name"] != ".*"
]
@property
def id_field_name(self) -> str:
return "id" # `id` is the reserved field name in Typesense