Source code for flexrag.retriever.web_retrievers.web_downloader
import asyncio
import io
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Annotated, Any, Optional
from httpx import Client
from PIL import Image
from flexrag.utils import Choices, Register, configure
from .utils import WebResource
[docs]
@configure
class WebDownloaderBaseConfig:
    """Settings shared by every ``WebDownloaderBase`` implementation.

    :param allow_parallel: When True, the resources of one batch are fetched
        concurrently; when False they are fetched one after another.
        Default is True.
    :type allow_parallel: bool
    """

    allow_parallel: bool = True
[docs]
class WebDownloaderBase(ABC):
    """The base class for the ``WebDownloader``.

    Subclasses implement :meth:`_download_item`, which fetches one resource.
    :meth:`download` and :meth:`async_download` apply it to a batch, either
    concurrently or sequentially depending on ``allow_parallel``.
    """

    def __init__(self, cfg: WebDownloaderBaseConfig) -> None:
        self.allow_parallel = cfg.allow_parallel
        return

    def download(self, resources: WebResource | list[WebResource]) -> list[WebResource]:
        """Download the web resources.

        :param resources: The resources to download. A single resource is
            treated as a one-element list.
        :type resources: WebResource | list[WebResource]
        :return: The downloaded web resources, in input order.
        :rtype: list[WebResource]
        """
        if not isinstance(resources, list):
            resources = [resources]
        if self.allow_parallel:
            # ``executor.map`` yields results in input order.
            with ThreadPoolExecutor() as executor:
                return list(executor.map(self._download_item, resources))
        return [self._download_item(resource) for resource in resources]

    async def async_download(
        self, resources: WebResource | list[WebResource]
    ) -> list[WebResource]:
        """Download the web resources asynchronously.

        ``_download_item`` is a blocking call, so each invocation is run in a
        worker thread via :func:`asyncio.to_thread`. With ``allow_parallel``
        the items are awaited together; otherwise one after another.

        :param resources: The resources to download. A single resource is
            treated as a one-element list.
        :type resources: WebResource | list[WebResource]
        :return: The downloaded web resources, in input order.
        :rtype: list[WebResource]
        """
        # Normalize exactly like ``download`` so a lone WebResource is wrapped.
        if not isinstance(resources, list):
            resources = [resources]
        if self.allow_parallel:
            # Pass the resource positionally: the abstract signature names the
            # parameter ``resource``, so a keyword such as ``url=`` would fail.
            return await asyncio.gather(
                *(
                    asyncio.to_thread(self._download_item, resource)
                    for resource in resources
                )
            )
        results = []
        for resource in resources:
            results.append(await asyncio.to_thread(self._download_item, resource))
        return results

    @abstractmethod
    def _download_item(self, resource: WebResource) -> Any:
        """Download the resource.

        :param resource: The web resource to download.
        :type resource: WebResource
        :return: The downloaded web resource.
        :rtype: WebResource
        """
        return
# Registry of web downloader implementations. Concrete downloaders add
# themselves with ``@WEB_DOWNLOADERS(<name>, config_class=<ConfigClass>)``.
WEB_DOWNLOADERS = Register[WebDownloaderBase]("web_downloader")
[docs]
@configure
class SimpleWebDownloaderConfig(WebDownloaderBaseConfig):
    """Settings for ``SimpleWebDownloader`` (an httpx-based fetcher).

    All three options are forwarded to the underlying httpx ``Client``.

    :param proxy: Proxy for outgoing requests, or None for a direct
        connection. Default is None.
    :type proxy: Optional[str]
    :param timeout: Request timeout handed to httpx. Default is 3.0.
    :type timeout: float
    :param headers: Extra request headers, or None. Default is None.
    :type headers: Optional[dict]
    """

    proxy: Optional[str] = None
    timeout: float = 3.0
    headers: Optional[dict] = None
[docs]
@WEB_DOWNLOADERS("simple", config_class=SimpleWebDownloaderConfig)
class SimpleWebDownloader(WebDownloaderBase):
    """Download the html content using httpx."""

    def __init__(self, cfg: SimpleWebDownloaderConfig) -> None:
        super().__init__(cfg)
        # httpx >= 0.26 takes ``proxy``; the older ``proxies`` keyword was
        # deprecated there and removed in 0.28. Try the modern keyword first
        # and fall back for older httpx releases that reject it.
        try:
            self.client = Client(
                headers=cfg.headers,
                proxy=cfg.proxy,
                timeout=cfg.timeout,
            )
        except TypeError:
            self.client = Client(
                headers=cfg.headers,
                proxies=cfg.proxy,
                timeout=cfg.timeout,
            )
        return

    def _download_item(self, resource: WebResource) -> WebResource:
        """Fetch ``resource.url`` and store the response body text.

        :param resource: The web resource to download; its ``data`` attribute
            is overwritten with the response text.
        :type resource: WebResource
        :return: The same resource object, now populated.
        :rtype: WebResource
        :raises httpx.HTTPStatusError: If the response has an error status
            (raised by ``raise_for_status``).
        """
        response = self.client.get(resource.url)
        response.raise_for_status()
        resource.data = response.text
        return resource
[docs]
@configure
class PlaywrightWebDownloaderConfig(WebDownloaderBaseConfig):
    """The configuration for the ``PlaywrightWebDownloader``.

    :param headless: Whether to run the browser in headless mode. Default is True.
    :type headless: bool
    :param browser: The browser to use. Default is `chromium`.
        Available choices are `chromium`, `firefox`, `webkit`, and `msedge`.
    :type browser: str
    :param device: The name of the Playwright device descriptor to emulate.
        Default is `Desktop Chrome`.
    :type device: str
    :param page_width: Overrides the viewport width of the emulated device.
        Default is None (keep the device's width).
    :type page_width: Optional[int]
    :param page_height: Overrides the viewport height of the emulated device.
        Default is None (keep the device's height).
    :type page_height: Optional[int]
    :param proxy: The proxy server to use. Default is None.
    :type proxy: Optional[str]
    :param return_screenshot: Whether to return a full-page screenshot instead
        of the page HTML. Default is False.
    :type return_screenshot: bool
    """

    headless: bool = True
    browser: Annotated[str, Choices("chromium", "firefox", "webkit", "msedge")] = (
        "chromium"
    )
    device: str = "Desktop Chrome"
    page_width: Optional[int] = None
    page_height: Optional[int] = None
    proxy: Optional[str] = None
    return_screenshot: bool = False
[docs]
@WEB_DOWNLOADERS("playwright", config_class=PlaywrightWebDownloaderConfig)
class PlaywrightWebDownloader(WebDownloaderBase):
"""Download the web resources using playwright."""
def __init__(self, cfg: PlaywrightWebDownloaderConfig) -> None:
super().__init__(cfg)
# load the playwright
try:
from playwright.async_api import async_playwright
from playwright.sync_api import sync_playwright
self.async_playwright = async_playwright
self.sync_playwright = sync_playwright
except ImportError:
raise ImportError(
"Please install playwright using `pip install pytest-playwright`."
"Then, execute `playwright install`."
)
# set the arguments
self.headless = cfg.headless
self.proxy = {"server": cfg.proxy} if cfg.proxy is not None else None
self.browser = cfg.browser
self.device = cfg.device
self.page_width = cfg.page_width
self.page_height = cfg.page_height
self.return_screenshot = cfg.return_screenshot
return
[docs]
def download(self, resources: WebResource | list[WebResource]) -> WebResource:
asyncio.get_event_loop().run_until_complete(self.async_download(resources))
return resources
[docs]
async def async_download(self, resources):
if not isinstance(resources, list):
resources = [resources]
async with self.async_playwright() as p:
# launch the browser
match self.browser:
case "chromium":
browser = await p.chromium.launch(headless=self.headless)
case "firefox":
browser = await p.firefox.launch(headless=self.headless)
case "webkit":
browser = await p.webkit.launch(headless=self.headless)
case "msedge":
browser = await p.chromium.launch(headless=self.headless)
case _:
raise ValueError(f"Browser {self.browser} is not supported.")
# set the browser context
ctx_param = p.devices[self.device]
if self.page_height is not None:
ctx_param["viewport"]["height"] = self.page_height
if self.page_width is not None:
ctx_param["viewport"]["width"] = self.page_width
ctx_param["proxy"] = self.proxy
context = await browser.new_context(**ctx_param)
# download the resources
async def get_content(r: WebResource):
page = await context.new_page()
await page.goto(r.url)
if self.return_screenshot:
img_bytes = await page.screenshot(full_page=True)
r.data = Image.open(io.BytesIO(img_bytes))
else:
r.data = await page.content()
await page.close()
return r
if self.allow_parallel:
resources = await asyncio.gather(*[get_content(r) for r in resources])
else:
resources = []
for r in resources:
r = await get_content(r)
resources.append(r)
# close the browser
await browser.close()
return resources
def _download_item(self, resource: WebResource) -> WebResource:
raise NotImplementedError(
"Please use the `download` or `async_download` method."
)
# Config class generated from the WEB_DOWNLOADERS registry. Presumably it
# selects one registered downloader by name and carries that downloader's
# config — TODO confirm against ``Register.make_config``.
WebDownloaderConfig = WEB_DOWNLOADERS.make_config(config_name="WebDownloaderConfig")