flexrag.utils.configure 源代码

import json
import keyword
import os
import types
from copy import deepcopy
from dataclasses import field, fields, is_dataclass
from pathlib import Path
from typing import Annotated, Callable, Generic, Optional, Type, TypeVar

import yaml
from huggingface_hub import HfApi
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic import ConfigDict, Field
from pydantic.dataclasses import ConfigDict, dataclass

from .default_vars import FLEXRAG_CACHE_DIR

T = TypeVar("T")


def extract_config(config, config_cls: Type[T]) -> T:
    """
    Extracts the configuration from a pydantic dataclass, omegaconf.DictConfig or dict.
    """
    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config, resolve=True)
    if isinstance(config, dict):
        config = config_cls(**config)
    elif is_dataclass(config):
        field_names = {f.name for f in fields(config_cls)}
        kwargs = {name: getattr(config, name) for name in field_names}
        config = config_cls(**kwargs)
    else:
        raise TypeError(f"Expected {config_cls}, got {type(config)}")
    return config


def make_dataclass(
    cls_name,
    fields,
    *,
    bases=(),
    namespace=None,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    kw_only=False,
    slots=False,
    config=None,
    validate_on_init=None,
):
    """Return a new dynamically created pydantic dataclass."""

    if namespace is None:
        namespace = {}

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.
    seen = set()
    annotations = {}
    defaults = {}
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = "typing.Any"
        elif len(item) == 2:
            (
                name,
                tp,
            ) = item
        elif len(item) == 3:
            name, tp, spec = item
            defaults[name] = spec
        else:
            raise TypeError(f"Invalid field: {item!r}")

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f"Field names must be valid identifiers: {name!r}")
        if keyword.iskeyword(name):
            raise TypeError(f"Field names must not be keywords: {name!r}")
        if name in seen:
            raise TypeError(f"Field name duplicated: {name!r}")

        seen.add(name)
        annotations[name] = tp

    # Update 'ns' with the user-supplied namespace plus our calculated values.
    def exec_body_callback(ns):
        ns.update(namespace)
        ns.update(defaults)
        ns["__annotations__"] = annotations

    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, {}, exec_body_callback)

    # Apply the normal decorator.
    return dataclass(
        cls,
        init=False,  # pydantic dataclass only supports `init=False`
        repr=repr,
        eq=eq,
        order=order,
        unsafe_hash=unsafe_hash,
        frozen=frozen,
        kw_only=kw_only,
        slots=slots,
        config=config,
        validate_on_init=validate_on_init,
    )


def Choices(*args: str) -> Field:
    """
    A shortcut for creating a pydantic Field with a regex pattern that matches one of the provided choices.
    This is useful as hydra-core does not support `Literal` types.
    """
    choices = list(args)
    pattern = f"^({'|'.join(choices)})$"
    return Field(pattern=pattern)


_T = TypeVar("_T")


def _create_pydantic_dataclass(config: ConfigDict) -> Callable[[Type[_T]], Type[_T]]:
    def decorator(cls: Type[_T]) -> Type[_T]:
        cls = dataclass(config=config)(cls)

        def dumps(self) -> str:
            """Dump the dataclass to a YAML string."""
            return OmegaConf.to_yaml(self, resolve=True)

        def dump(self, path: str | Path):
            """Dump the dataclass to a YAML file."""
            path = Path(path)
            path.write_text(self.dumps(), encoding="utf-8")

        @classmethod
        def loads(cls, s: str) -> _T:
            """Load the dataclass from a YAML string."""
            data = yaml.safe_load(s)
            if not isinstance(data, dict):
                raise ValueError("YAML string must represent a dictionary.")
            return cls(**data)

        @classmethod
        def load(cls, path: str | Path) -> _T:
            """Load the dataclass from a YAML file."""
            path = Path(path)
            return cls(**OmegaConf.to_container(OmegaConf.load(path)))

        setattr(cls, "dumps", dumps)
        setattr(cls, "dump", dump)
        setattr(cls, "loads", loads)
        setattr(cls, "load", load)
        return cls

    return decorator


# These two variables are intended as a shortcut
# for creating pydantic.dataclasses.dataclass instances.
configure = _create_pydantic_dataclass(
    ConfigDict(extra="forbid", validate_assignment=True)
)
data = _create_pydantic_dataclass(
    ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
)

RegistedType = TypeVar("RegistedType")


[文档] class Register(Generic[RegistedType]): def __init__(self, register_name: str = None, allow_load_from_repo: bool = False): """Initialize the register. :param register_name: The name of the register, defaults to None. :type register_name: str, optional :param allow_load_from_repo: Whether to allow loading items from the HuggingFace Hub, defaults to False. :type allow_load_from_repo: bool, optional """ self.name = register_name self.allow_load_from_repo = allow_load_from_repo self._items = {} self._shortcuts = {} return def __call__(self, *short_names: str, config_class=None): """Register an item to the register. :param short_names: The short names of the item. :type short_names: str :param config_class: The config class of the item, defaults to None. :type config_class: dataclass :return: The item. :rtype: Any """ def registe_item(item): main_name = str(item).split(".")[-1][:-2] # check name conflict assert main_name not in self._items, f"Name Conflict {main_name}" assert main_name not in self._shortcuts, f"Name Conflict {main_name}" for name in short_names: assert name not in self._items, f"Name Conflict {name}" assert name not in self._shortcuts, f"Name Conflict {name}" # register the item self._items[main_name] = { "item": item, "main_name": main_name, "short_names": short_names, "config_class": config_class, } for name in short_names: self._shortcuts[name] = main_name return item return registe_item def __iter__(self): return self._items.__iter__() @property def names(self) -> list[str]: """Get the names of the registered items.""" return list(self._items.keys()) + list(self._shortcuts.keys()) @property def mainnames(self) -> list[str]: """Get the main names of the registered items.""" return list(self._items.keys()) @property def shortnames(self) -> list[str]: """Get the short names of the registered items.""" return list(self._shortcuts.keys()) def __getitem__(self, key: str) -> dict: if key not in self._items: key = self._shortcuts[key] return self._items[key]
[文档] def get(self, key: str, default=None) -> dict: """Get the item dict by name. :param key: The name of the item. :type key: str :param default: The default value to return, defaults to None. :type default: Any :return: The item dict containing the item, main_name, short_names, and config_class. :rtype: dict """ if key not in self._items: if key not in self._shortcuts: return default key = self._shortcuts[key] return self._items[key]
[文档] def get_item(self, key: str): """Get the item by name. :param key: The name of the item. :type key: str :return: The item. :rtype: Any """ if key not in self._items: key = self._shortcuts[key] return self._items[key]["item"]
[文档] def make_config( self, allow_multiple: bool = False, default: Optional[str] = None, config_name: str = None, ): """Make a config class for the registered items. :param allow_multiple: Whether to allow multiple items to be selected, defaults to False. :type allow_multiple: bool, optional :param default: The default item to select, defaults to None. :type default: Optional[str], optional :param config_name: The name of the config class, defaults to None. :type config_name: str, optional :return: The config class. :rtype: dataclass """ choice_name = f"{self.name}_type" config_name = f"{self.name}_config" if config_name is None else config_name if allow_multiple: if self.allow_load_from_repo: config_fields = [(choice_name, list[str], field(default_factory=list))] else: config_fields = [ ( choice_name, list[Annotated[str, Choices(*self.names)]], field(default_factory=list), ) ] else: if self.allow_load_from_repo: config_fields = [(choice_name, Optional[str], field(default=default))] else: config_fields = [ ( choice_name, Optional[Annotated[str, Choices(*self.names)]], field(default=default), ) ] config_fields += [ ( f"{self[name]['short_names'][0]}_config", Optional[self[name]["config_class"]], field(default_factory=self._items[name]["config_class"]), ) for name in self.mainnames if self[name]["config_class"] is not None ] generated_config = make_dataclass(config_name, config_fields) # set docstring docstring = ( f"Configuration class for {self.name} " f"(name: {config_name}, default: {default}).\n\n" ) docstring += f":param {choice_name}: The {self.name} type to use.\n" if allow_multiple: docstring += f":type {choice_name}: list[str]\n" else: docstring += f":type {choice_name}: str\n" for name in self.mainnames: if self[name]["config_class"] is not None: docstring += f":param {self[name]['short_names'][0]}_config: The config for {name}.\n" docstring += f":type {self[name]['short_names'][0]}_config: {self[name]['config_class'].__name__}\n" generated_config.__doc__ = docstring return generated_config
[文档] def load( self, config: DictConfig, **kwargs, ) -> RegistedType | list[RegistedType]: """Load the item(s) from the generated config. :param config: The config generated by `make_config` method. :type config: DictConfig :param kwargs: The additional arguments to pass to the item(s). :type kwargs: Any :raises ValueError: If the item type is invalid. :return: The loaded item(s). :rtype: RegistedType | list[RegistedType] """ def load_item(type_str: str) -> RegistedType: # Try to load the item from the HuggingFace Hub First if self.allow_load_from_repo: client = HfApi( endpoint=os.environ.get("HF_ENDPOINT", None), token=os.environ.get("HF_TOKEN", None), ) # download the snapshot from the HuggingFace Hub if type_str.count("/") <= 1: try: assert client.repo_exists(type_str) repo_info = client.repo_info(type_str) assert repo_info is not None repo_id = repo_info.id dir_name = os.path.join( FLEXRAG_CACHE_DIR, f"{repo_id.split('/')[0]}--{repo_id.split('/')[1]}", ) snapshot = client.snapshot_download( repo_id=repo_id, local_dir=dir_name, ) assert snapshot is not None return load_item(snapshot) except AssertionError: pass # load the item from the local repository elif os.path.exists(type_str): # prepare the cls id_path = os.path.join(type_str, "cls.id") with open(id_path, "r") as f: cls_name = f.read().strip() # the configure will be ignored # cfg_name = f"{self[cls_name]['short_names'][0]}_config" # new_cfg = getattr(config, cfg_name, None) # load the item return self[cls_name]["item"].load_from_local(type_str) # Load the item directly if type_str in self: cfg_name = f"{self[type_str]['short_names'][0]}_config" sub_cfg = getattr(config, cfg_name, None) if sub_cfg is None: loaded = self[type_str]["item"](**kwargs) else: loaded = self[type_str]["item"](sub_cfg, **kwargs) else: raise ValueError(f"Invalid {self.name} type: {type_str}") return loaded choice = getattr(config, f"{self.name}_type", None) if choice is None: return None if isinstance(choice, (list, ListConfig)): loaded = [] for name in choice: loaded.append(load_item(str(name))) return loaded return load_item(str(choice))
[文档] def squeeze(self, config_instance): """Convert the nused fields to None.""" new_instance = deepcopy(config_instance) choice = getattr(new_instance, f"{self.name}_type", None) if choice is None: selected_fields = [] elif isinstance(choice, list): selected_fields = [] for choice_item in choice: cfg_name = f"{self[str(choice_item)]['short_names'][0]}_config" if getattr(new_instance, cfg_name, None) is not None: selected_fields.append(cfg_name) else: cfg_name = f"{self[str(choice)]['short_names'][0]}_config" if getattr(new_instance, cfg_name, None) is not None: selected_fields = [cfg_name] for field in fields(new_instance): if field.name == f"{self.name}_type": continue if field.name in selected_fields: continue if field.name.endswith("_config"): setattr(new_instance, field.name, None) return new_instance
def __len__(self) -> int: return len(self._items) def __contains__(self, key: str) -> bool: return key in self.names def __str__(self) -> str: data = { "name": self.name, "items": [ { "main_name": k, "short_names": v["short_names"], "config_class": str(v["config_class"]), } for k, v in self._items.items() ], } return json.dumps(data, indent=4) def __repr__(self) -> str: return str(self) def __add__(self, register: "Register"): new_register = Register() new_register._items = {**self._items, **register._items} new_register._shortcuts = {**self._shortcuts, **register._shortcuts} return new_register