Source code for flexrag.datasets.dataset

import operator
from collections.abc import Iterable, Iterator
from typing import Any, Generic, TypeVar

# Type variables that parameterise the item type of the dataset classes below.
ItemTypeI = TypeVar("ItemTypeI")  # item type yielded by IterableDataset
ItemTypeM = TypeVar("ItemTypeM")  # item type returned by MappingDataset
ItemTypeChain = TypeVar("ItemTypeChain")  # item type yielded by ChainDataset
ItemTypeConcat = TypeVar("ItemTypeConcat")  # item type returned by ConcatDataset


class IterableDataset(Iterable[ItemTypeI], Generic[ItemTypeI]):
    r"""Base class for datasets whose items are obtained by iteration.

    Subclasses are required to provide ``__iter__``:

    >>> # produce the items of the dataset one by one.
    >>> def __iter__(self) -> Iterator[ItemTypeI]: ...

    This base class supplies:

    >>> # chain this dataset with another IterableDataset.
    >>> def __add__(self, other: IterableDataset[ItemTypeI]) -> IterableDataset[ItemTypeI]: ...

    Example:

    >>> class MyDataset(IterableDataset[int]):
    ...     def __init__(self, n: int):
    ...         self.n = n
    ...
    ...     def __iter__(self) -> Iterator[int]:
    ...         yield from range(self.n)
    ...
    >>> dataset = MyDataset(3)
    >>> for item in dataset:
    ...     print(item)
    0
    1
    2
    """

    def __add__(
        self, other: "IterableDataset[ItemTypeI]"
    ) -> "IterableDataset[ItemTypeI]":
        """Return a ChainDataset that yields this dataset's items, then ``other``'s."""
        chained = ChainDataset(self, other)
        return chained
class MappingDataset(Generic[ItemTypeM]):
    r"""MappingDataset is a base class for datasets that can be indexed by integers.

    The subclasses of MappingDataset should implement the following methods:

    >>> # return the item at the given index.
    >>> def __getitem__(self, index: int) -> ItemTypeM: ...
    >>> # return the number of items in the dataset.
    >>> def __len__(self) -> int: ...

    The following methods are implemented automatically:

    >>> # concatenate multiple MappingDatasets.
    >>> def __add__(self, other: MappingDataset[ItemTypeM]) -> MappingDataset[ItemTypeM]: ...
    >>> # return whether the dataset contains the given index.
    >>> def __contains__(self, key: int) -> bool: ...
    >>> # return an iterator over the items in the dataset.
    >>> def __iter__(self) -> Iterator[ItemTypeM]: ...
    >>> # return the item at the given index, or ``default`` if the index is invalid.
    >>> def get(self, index: int, default: Any = None) -> ItemTypeM: ...

    A valid index is an integer (any object implementing ``__index__``)
    in ``range(len(self))``. Negative indices are not valid.

    For example:

    >>> class MyDataset(MappingDataset[int]):
    ...     def __init__(self, n: int):
    ...         self.n = n
    ...         return
    ...
    ...     def __getitem__(self, index: int) -> int:
    ...         if 0 <= index < self.n:
    ...             return index
    ...         raise IndexError(f"Index {index} out of range.")
    ...
    ...     def __len__(self) -> int:
    ...         return self.n
    ...
    >>> dataset = MyDataset(3)
    >>> for i in range(len(dataset)):
    ...     print(dataset[i])
    0
    1
    2
    """

    def __add__(
        self, other: "MappingDataset[ItemTypeM]"
    ) -> "MappingDataset[ItemTypeM]":
        """Return a ConcatDataset holding ``self`` followed by ``other``."""
        return ConcatDataset(self, other)

    def _checked_index(self, key: object):
        """Return ``key`` as an ``int`` if it is a valid index, else ``None``.

        ``operator.index`` accepts ``int`` and integer-like objects (e.g. numpy
        integers) but rejects floats and other types. A bare range comparison
        would wrongly accept values such as ``0.5``.
        """
        try:
            index = operator.index(key)
        except TypeError:
            return None
        return index if 0 <= index < len(self) else None

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` is a valid integer index into the dataset.

        Non-integer keys return False instead of raising TypeError.
        """
        return self._checked_index(key) is not None

    def __iter__(self) -> Iterator[ItemTypeM]:
        """Yield ``self[0]``, ``self[1]``, ..., ``self[len(self) - 1]`` in order."""
        for i in range(len(self)):
            yield self[i]

    def get(self, index: int, default: Any = None) -> ItemTypeM:
        """Return ``self[index]``, or ``default`` when ``index`` is not a valid index.

        :param index: integer position in ``range(len(self))``.
        :param default: value returned for out-of-range or non-integer indices.
        """
        checked = self._checked_index(index)
        if checked is None:
            return default
        return self[checked]
class ChainDataset(IterableDataset[ItemTypeChain]):
    """Iterable dataset that walks several IterableDatasets back to back.

    Iteration is lazy: a member is only iterated after the previous member
    is exhausted, and every call to ``iter()`` starts again from the first
    member.
    """

    def __init__(self, *datasets: IterableDataset):
        # Kept as the tuple of positional arguments, in call order.
        self.datasets = datasets

    def __iter__(self) -> Iterator[ItemTypeChain]:
        for member in self.datasets:
            # Delegate fully to each member's own iterator.
            yield from member
class ConcatDataset(MappingDataset[ItemTypeConcat]):
    """ConcatDataset concatenates multiple MappingDatasets.

    Indices are laid out end to end. The first ``len(datasets[0])``
    indices map to the first dataset, the next ``len(datasets[1])`` map to
    the second, and so on. Valid indices are ``0 <= index < len(self)``.
    """

    def __init__(self, *datasets: MappingDataset):
        self.datasets = datasets
        return

    def __getitem__(self, index: int) -> ItemTypeConcat:
        """Return the item at global position ``index``.

        :raises IndexError: if ``index`` is negative or ``>= len(self)``.
        """
        # Reject negatives explicitly. Forwarding them would silently index
        # from the end of the *first* dataset instead of the concatenation.
        if index < 0:
            raise IndexError(f"Index {index} out of range.")
        offset = index
        for dataset in self.datasets:
            size = len(dataset)
            if offset < size:
                return dataset[offset]
            offset -= size
        # Report the caller's index, not the leftover offset.
        raise IndexError(f"Index {index} out of range.")

    def __iter__(self) -> Iterator[ItemTypeConcat]:
        # Iterate each member directly rather than via __getitem__, which
        # would rescan the member list for every item.
        for dataset in self.datasets:
            yield from dataset
        return

    def __len__(self) -> int:
        return sum(len(dataset) for dataset in self.datasets)