# Source code for flexrag.retriever.web_retrievers.wikipedia_retriever
import time
from typing import Optional
from urllib.parse import quote_plus

import httpx
from bs4 import BeautifulSoup

from flexrag.utils import LOGGER_MANAGER, SimpleProgressLogger, configure
from flexrag.utils.configure import extract_config
from flexrag.utils.dataclasses import RetrievedContext

from ..retriever_base import RETRIEVERS, RetrieverBase, RetrieverBaseConfig
# Module-level logger; ``WikipediaRetriever.search`` passes it to
# ``SimpleProgressLogger`` to report search progress.
logger = LOGGER_MANAGER.get_logger("flexrag.retrievers.web_retriever")
[docs]
@configure
class WikipediaRetrieverConfig(RetrieverBaseConfig):
    """Configuration options for the ``WikipediaRetriever``.

    :param search_url: URL prefix to which the search query is appended.
        Defaults to "https://en.wikipedia.org/w/index.php?search=".
    :type search_url: str
    :param proxy: Proxy handed to the HTTP client. Defaults to None.
    :type proxy: Optional[str]
    """

    search_url: str = "https://en.wikipedia.org/w/index.php?search="
    proxy: Optional[str] = None
[docs]
@RETRIEVERS("wikipedia", config_class=WikipediaRetrieverConfig)
class WikipediaRetriever(RetrieverBase):
    """WikipediaRetriever retrieves information from Wikipedia directly.

    Every query is sent to ``cfg.search_url`` and the returned HTML is parsed
    with BeautifulSoup into a single ``RetrievedContext`` per query.

    Adapted from https://github.com/ysymyth/ReAct
    """

    name = "wikipedia"

    def __init__(self, cfg: WikipediaRetrieverConfig):
        """Store the configuration and create the HTTP client.

        :param cfg: The retriever configuration.
        :type cfg: WikipediaRetrieverConfig
        """
        super().__init__(cfg)
        # set basic configs
        self.cfg = extract_config(cfg, WikipediaRetrieverConfig)
        # httpx does not follow redirects unless asked to. The search endpoint
        # may answer a direct title hit with a redirect to the article, which
        # would otherwise leave us parsing an empty response body.
        self.client = httpx.Client(proxy=cfg.proxy, follow_redirects=True)

    def search(
        self,
        query: list[str] | str,
        delay: float = 0.1,
        **search_kwargs,
    ) -> list[list[RetrievedContext]]:
        """Search Wikipedia for one or several queries.

        :param query: A single query or a list of queries.
        :type query: list[str] | str
        :param delay: Seconds to sleep before each request. Default is 0.1.
        :type delay: float
        :param search_kwargs: Extra keyword arguments forwarded to
            ``search_item``.
        :return: One single-element list of contexts per query, in the same
            order as the queries.
        :rtype: list[list[RetrievedContext]]
        """
        if isinstance(query, str):
            query = [query]

        # search & parse
        results = []
        p_logger = SimpleProgressLogger(logger, len(query), self.cfg.log_interval)
        for q in query:
            time.sleep(delay)
            p_logger.update(1, "Searching")
            results.append([self.search_item(q, **search_kwargs)])
        return results

    def search_item(self, query: str, **kwargs) -> RetrievedContext:
        """Search Wikipedia for a single query.

        If the returned page looks like a disambiguation page (it contains
        "may refer to:"), the search is retried once with the query wrapped
        in square brackets.

        :param query: The query string.
        :type query: str
        :param kwargs: Accepted for interface compatibility; currently unused.
        :return: The retrieved context. ``data`` holds ``raw_page``,
            ``page_content``, ``summary`` and ``similar_entities``.
        :rtype: RetrievedContext
        """
        return self._search_page(query, retry_on_disambiguation=True)

    def _search_page(
        self, query: str, retry_on_disambiguation: bool
    ) -> RetrievedContext:
        """Fetch and parse the search response for ``query``.

        :param query: The query string.
        :param retry_on_disambiguation: Whether a page containing
            "may refer to:" triggers one bracketed retry. The retry itself
            passes ``False`` so the recursion cannot run away.
        :return: The retrieved context for ``query``.
        """
        # Percent-encode the query so characters such as "&", "#" or "+"
        # cannot alter the URL; spaces still become "+".
        search_url = self.cfg.search_url + quote_plus(query)
        response_text = self.client.get(search_url).text
        soup = BeautifulSoup(response_text, features="html.parser")
        result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})

        if result_divs:  # mismatch: a result listing instead of a page
            similar_entities = [
                self._clear_str(div.get_text().strip()) for div in result_divs
            ]
            page_content = None
            summary = None
        else:
            similar_entities = []
            page = [
                p.get_text().strip()
                for p in soup.find_all("p") + soup.find_all("ul")
            ]
            if retry_on_disambiguation and any("may refer to:" in p for p in page):
                return self._search_page(
                    "[" + query + "]", retry_on_disambiguation=False
                )
            # concatenate all paragraphs that have more than two words,
            # one per line
            page_content = "".join(
                self._clear_str(p) + "\n" for p in page if len(p.split(" ")) > 2
            )
            summary = self._get_summary(page_content)

        return RetrievedContext(
            retriever=self.name,
            query=query,
            data={
                "raw_page": response_text,
                "page_content": page_content,
                "summary": summary,
                "similar_entities": similar_entities,
            },
            source=search_url,
        )

    def _clear_str(self, text: str) -> str:
        """Decode literal backslash escape sequences contained in ``text``.

        The round trip through ``unicode-escape`` and ``latin1`` leaves plain
        text unchanged. It raises for malformed escapes (e.g. ``C:\\Users``)
        and for escapes that do not survive the ``latin1``/``utf-8`` step; in
        those cases the text is returned unchanged instead of failing the
        whole search.

        :param text: The text to clean.
        :return: The cleaned text, or ``text`` itself if it cannot be decoded.
        """
        try:
            return (
                text.encode()
                .decode("unicode-escape")
                .encode("latin1")
                .decode("utf-8")
            )
        except UnicodeError:
            return text

    def _get_summary(self, page: str) -> str:
        """Build a summary from the first five sentences of ``page``.

        :param page: The page content, one paragraph per line.
        :return: The first five sentences joined by single spaces.
        """
        # find all paragraphs
        paragraphs = [p.strip() for p in page.split("\n") if p.strip()]

        # find all sentences
        sentences = []
        for paragraph in paragraphs:
            for piece in paragraph.split(". "):
                piece = piece.strip()
                if not piece:
                    continue
                # Splitting on ". " removes the period from every sentence but
                # the last one of a paragraph, so only restore it where the
                # sentence has no terminal punctuation (avoids "..").
                if not piece.endswith((".", "!", "?")):
                    piece += "."
                sentences.append(piece)

        summary = " ".join(sentences[:5])
        # drop literal backslash-n sequences left in the text
        return summary.replace("\\n", "")

    @property
    def fields(self) -> list[str]:
        """The keys of the ``data`` dict of every returned context."""
        return ["raw_page", "page_content", "summary", "similar_entities"]