flexrag.entrypoints.prepare_retriever 源代码
from dataclasses import field
from typing import Annotated
import hydra
from hydra.core.config_store import ConfigStore
from flexrag.datasets import RAGCorpusDataset, RAGCorpusDatasetConfig
from flexrag.retriever import (
ElasticRetriever,
ElasticRetrieverConfig,
FlexRetriever,
FlexRetrieverConfig,
TypesenseRetriever,
TypesenseRetrieverConfig,
)
from flexrag.utils import LOGGER_MANAGER, Choices, configure, extract_config
logger = LOGGER_MANAGER.get_logger("flexrag.prepare_index")
# fmt: off
[文档]
@configure
class Config(RAGCorpusDatasetConfig):
# retriever configs
retriever_type: Annotated[str, Choices("flex", "elastic", "typesense")] = "flex"
flex_config: FlexRetrieverConfig = field(default_factory=FlexRetrieverConfig)
elastic_config: ElasticRetrieverConfig = field(default_factory=ElasticRetrieverConfig)
typesense_config: TypesenseRetrieverConfig = field(default_factory=TypesenseRetrieverConfig)
reinit: bool = False
# fmt: on
cs = ConfigStore.instance()
cs.store(name="default", node=Config)
@hydra.main(version_base="1.3", config_path=None, config_name="default")
def main(cfg: Config):
cfg = extract_config(cfg, Config)
# load retriever
match cfg.retriever_type:
case "flex":
retriever = FlexRetriever(cfg.flex_config)
case "elastic":
retriever = ElasticRetriever(cfg.elastic_config)
case "typesense":
retriever = TypesenseRetriever(cfg.typesense_config)
case _:
raise ValueError(f"Unsupported retriever type: {cfg.retriever_type}")
# add passages
if cfg.reinit and (len(retriever) > 0):
logger.warning("Reinitializing retriever and removing all passages")
retriever.clear()
retriever.add_passages(passages=RAGCorpusDataset(cfg))
return
if __name__ == "__main__":
main()