Source code for flexrag.retriever.web_retrievers.web_retriever

import json
import time
from abc import abstractmethod

from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed

from flexrag.utils import LOGGER_MANAGER, TIME_METER, SimpleProgressLogger, configure
from flexrag.utils.configure import extract_config
from flexrag.utils.dataclasses import RetrievedContext

from ..retriever_base import (
    RETRIEVERS,
    RetrieverBase,
    RetrieverBaseConfig,
    batched_cache,
)
from .web_reader import WEB_READERS, WebReaderConfig
from .web_seeker import SEARCH_ENGINES, SearchEngineConfig

# Module-level logger for the web retrievers; it is handed to
# ``SimpleProgressLogger`` inside ``WebRetrieverBase.search`` to report progress.
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.web_retriever")


def _save_error_state(retry_state: "RetryCallState") -> Exception:
    """Persist the arguments of a failed call, then re-raise its exception.

    This is installed as tenacity's ``retry_error_callback``. It writes the
    positional and keyword arguments of the failed call to
    ``web_retriever_error_state.json`` in the current working directory and
    then raises the exception stored in ``retry_state.outcome``.

    Saving the state is best-effort: a failure to write the dump must never
    replace the real error, so the original exception is always the one
    that propagates.

    :param retry_state: The tenacity state of the call that ran out of attempts.
    :type retry_state: RetryCallState
    :raises Exception: Always; the exception held by ``retry_state.outcome``.
        The function never returns normally.
    """
    state = {
        "args": retry_state.args,
        "kwargs": retry_state.kwargs,
    }
    try:
        with open("web_retriever_error_state.json", "w", encoding="utf-8") as f:
            # ``default=repr`` is deliberate: this file is a debugging aid, so an
            # argument that is not JSON-serialisable is stored as its repr
            # instead of aborting the dump with a TypeError.
            json.dump(state, f, default=repr)
    except (OSError, TypeError, ValueError) as dump_error:
        # E.g. unwritable directory, unsupported dict keys, circular references.
        logger.warning(
            "Failed to save the web retriever error state: %s", dump_error
        )
    raise retry_state.outcome.exception()


@configure
class WebRetrieverBaseConfig(RetrieverBaseConfig):
    """Retry settings shared by the web retrievers (``WebRetrieverBase``).

    :param retry_times: How many times a web search is tried before giving up.
        Default is 3.
    :type retry_times: int
    :param retry_delay: The pause between two consecutive tries. Default is 0.5.
    :type retry_delay: float
    """

    retry_times: int = 3
    retry_delay: float = 0.5
class WebRetrieverBase(RetrieverBase):
    """The base class for the ``WebRetriever``.

    The WebRetriever is used to retrieve relevant information from the web.
    The subclasses should implement the ``search_item`` method.
    """

    cfg: WebRetrieverBaseConfig

    @TIME_METER("web_retriever", "search")
    @batched_cache
    def search(
        self,
        query: list[str] | str,
        delay: float = 0.1,
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Search every query one by one using the ``search_item`` method.

        As most web apis do not provide batch interface, the queries are
        processed sequentially.

        :param query: A single query or a list of queries.
        :type query: list[str] | str
        :param delay: Seconds to sleep before each query is sent. Default is 0.1.
        :type delay: float
        :param search_kwargs: Extra keyword arguments. ``top_k``, ``retry_times``
            and ``retry_delay`` override the corresponding configuration values
            for this call; everything except ``top_k`` is also forwarded to
            ``search_item`` as keyword arguments.
        :return: One list of retrieved contexts per query, in query order.
        :rtype: list[list[RetrievedContext]]
        """
        if isinstance(query, str):
            query = [query]

        # prepare search method
        # NOTE: ``retry_times`` / ``retry_delay`` are only read here, not removed,
        # so they still reach ``search_item`` through ``search_kwargs``.
        retry_times = search_kwargs.get("retry_times", self.cfg.retry_times)
        retry_delay = search_kwargs.get("retry_delay", self.cfg.retry_delay)
        if retry_times > 1:
            search_func = retry(
                stop=stop_after_attempt(retry_times),
                wait=wait_fixed(retry_delay),
                retry_error_callback=_save_error_state,
            )(self.search_item)
        else:
            search_func = self.search_item

        # ``top_k`` is handed to ``search_item`` positionally, so it has to be
        # removed from the keyword arguments; leaving it in would raise
        # "got multiple values for argument 'top_k'" whenever a caller sets it.
        top_k = search_kwargs.pop("top_k", self.cfg.top_k)

        # search & parse
        results = []
        p_logger = SimpleProgressLogger(logger, len(query), self.cfg.log_interval)
        for q in query:
            time.sleep(delay)
            p_logger.update(1, "Searching")
            results.append(search_func(q, top_k, **search_kwargs))
        return results

    @abstractmethod
    def search_item(
        self,
        query: str,
        top_k: int,
        **search_kwargs,
    ) -> list[RetrievedContext]:
        """Search the query from the web.

        :param query: The query to search.
        :type query: str
        :param top_k: The number of documents to return.
        :type top_k: int
        :param search_kwargs: Extra keyword arguments forwarded by ``search``.
        :return: The retrieved contexts.
        :rtype: list[RetrievedContext]
        """
        return
@configure
class SimpleWebRetrieverConfig(WebRetrieverBaseConfig, WebReaderConfig, SearchEngineConfig):
    """Configuration of the ``SimpleWebRetriever``.

    It adds no fields of its own; it only combines the retriever, web reader
    and search engine configurations it inherits from.
    """
@RETRIEVERS("simple_web", config_class=SimpleWebRetrieverConfig)
class SimpleWebRetriever(WebRetrieverBase):
    """SimpleWebRetriever seeks most relevant web pages using existing search
    engine and reads the content using the WebReader."""

    def __init__(self, cfg: SimpleWebRetrieverConfig):
        """Load the web reader and the search engine described by ``cfg``.

        :param cfg: The configuration of the retriever.
        :type cfg: SimpleWebRetrieverConfig
        :raises AssertionError: If no web reader could be loaded from ``cfg``.
        """
        super().__init__(cfg)
        # presumably narrows ``cfg`` to the fields of SimpleWebRetrieverConfig;
        # confirm against ``extract_config``
        self.cfg = extract_config(cfg, SimpleWebRetrieverConfig)
        # load the web page reader
        self.reader = WEB_READERS.load(cfg)
        assert self.reader is not None, "WebReader is not set."
        # load the search engine
        self.search_engine = SEARCH_ENGINES.load(cfg)

    def search_item(
        self, query: str, top_k: int = 10, **search_kwargs
    ) -> list[RetrievedContext]:
        """Seek web resources for the query and read them into contexts.

        :param query: The query to search.
        :type query: str
        :param top_k: The number of resources requested from the search engine.
            Default is 10.
        :type top_k: int
        :param search_kwargs: Extra keyword arguments passed to the search
            engine's ``seek`` method.
        :return: The contexts produced by the web reader for the found resources.
        :rtype: list[RetrievedContext]
        """
        resources = self.search_engine.seek(query, top_k=top_k, **search_kwargs)
        contexts = self.reader.read(resources)
        return contexts

    @property
    def fields(self):
        """The ``fields`` attribute exposed by the underlying web reader."""
        return self.reader.fields