Source code for flexrag.datasets.rag_dataset

from dataclasses import field
from typing import Iterator, Optional

from flexrag.text_process import TextProcessPipeline, TextProcessPipelineConfig
from flexrag.utils import LOGGER_MANAGER, configure, data
from flexrag.utils.dataclasses import Context

from .hf_dataset import HFDataset, HFDatasetConfig
from .line_delimited_dataset import LineDelimitedDataset, LineDelimitedDatasetConfig

# Module-level logger, obtained from ``LOGGER_MANAGER`` under this module's dotted name.
logger = LOGGER_MANAGER.get_logger("flexrag.datasets.rag_dataset")


@data
class RAGEvalData:
    """A single sample of a knowledge-intensive QA task.

    :param question: The question to be answered. Required.
    :type question: str
    :param golden_contexts: The contexts related to the question. Default: None.
    :type golden_contexts: Optional[list[Context]]
    :param golden_answers: The golden answers for the question. Default: None.
    :type golden_answers: Optional[list[str]]
    :param meta_data: Any additional metadata attached to the sample. Default: {}.
    :type meta_data: dict
    """

    question: str
    golden_contexts: Optional[list[Context]] = None
    golden_answers: Optional[list[str]] = None
    # A factory is used so that every instance gets its own dict.
    meta_data: dict = field(default_factory=dict)
@data
class RAGMultipleChoiceData:
    """A single sample of a multiple-choice task.

    :param question: The question to be answered. Required.
    :type question: str
    :param options: The candidate options for the question. Required.
    :type options: list[str]
    :param golden_options: The golden option(s) for the question. Default: None.
    :type golden_options: Optional[list[int]]
    :param golden_contexts: The contexts related to the question. Default: None.
    :type golden_contexts: Optional[list[Context]]
    :param meta_data: Any additional metadata attached to the sample. Default: {}.
    :type meta_data: dict
    """

    question: str
    options: list[str]
    golden_options: Optional[list[int]] = None
    golden_contexts: Optional[list[Context]] = None
    # A factory is used so that every instance gets its own dict.
    meta_data: dict = field(default_factory=dict)


@data
class RAGTrueFalseData:
    """A single sample of a true/false task.

    :param question: The question (statement) to be judged. Required.
    :type question: str
    :param golden_contexts: The contexts related to the question. Default: None.
    :type golden_contexts: Optional[list[Context]]
    :param golden_answer: The golden answer for the question. Default: None.
    :type golden_answer: Optional[bool]
    :param meta_data: Any additional metadata attached to the sample. Default: {}.
    :type meta_data: dict
    """

    question: str
    golden_contexts: Optional[list[Context]] = None
    golden_answer: Optional[bool] = None
    # A factory is used so that every instance gets its own dict.
    meta_data: dict = field(default_factory=dict)
@configure
class RAGEvalDatasetConfig(HFDatasetConfig):
    """Configuration of ``RAGEvalDataset``.

    ``RAGEvalDataset`` loads the evaluation datasets collected by
    `FlashRAG <https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets>`_.
    Iterating over the dataset yields `RAGEvalData` objects
    (`RAGMultipleChoiceData` for records that carry a ``choices`` field).

    For example, the `test` split of the `NaturalQuestions` dataset can be
    loaded as follows:

    .. code-block:: python

        from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig

        cfg = RAGEvalDatasetConfig(
            name="nq",
            split="test",
        )
        dataset = RAGEvalDataset(cfg)

    The dataset can also be loaded from a local copy of the repository.
    First download it:

    .. code-block:: bash

        git lfs install
        git clone https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets flashrag

    Then point the configuration at the downloaded files:

    .. code-block:: python

        from flexrag.datasets import RAGEvalDataset, RAGEvalDatasetConfig

        cfg = RAGEvalDatasetConfig(
            path="json",
            data_files=["flashrag/nq/test.jsonl"],
            split="train",
        )
        dataset = RAGEvalDataset(cfg)

    Available datasets (with their splits) include:

        - 2wikimultihopqa: dev, train
        - ambig_qa: dev, train
        - arc: dev, test, train
        - asqa: dev, train
        - ay2: dev, train
        - bamboogle: test
        - boolq: dev, train
        - commonsenseqa: dev, train
        - curatedtrec: test, train
        - eli5: dev, train
        - fermi: dev, test, train
        - fever: dev, train
        - hellaswag: dev, train
        - hotpotqa: dev, train
        - mmlu: 5_shot, dev, test, train
        - msmarco-qa: dev, train
        - musique: dev, train
        - narrativeqa: dev, test, train
        - nq: dev, test, train
        - openbookqa: dev, test, train
        - piqa: dev, train
        - popqa: test
        - quartz: dev, test, train
        - siqa: dev, train
        - squad: dev, train
        - t-rex: dev, train
        - triviaqa: dev, test, train
        - truthful_qa: dev
        - web_questions: test, train
        - wikiasp: dev, test, train
        - wikiqa: dev, test, train
        - wned: dev
        - wow: dev, train
        - zero-shot_re: dev, train
    """

    # Defaults to the FlashRAG dataset repository on the Hugging Face Hub.
    path: str = "RUC-NLPIR/FlashRAG_datasets"
class RAGEvalDataset(HFDataset):
    """The dataset for loading RAG evaluation data.

    Each record returned by the parent ``HFDataset`` is converted into a
    ``RAGMultipleChoiceData`` (when it carries a ``choices`` field) or a
    ``RAGEvalData`` (otherwise).
    """

    def __init__(self, cfg: RAGEvalDatasetConfig) -> None:
        """Build the dataset from the given configuration.

        :param cfg: The configuration forwarded to ``HFDataset``.
        :type cfg: RAGEvalDatasetConfig
        """
        super().__init__(cfg)
        return

    def __getitem__(self, index: int) -> RAGEvalData | RAGMultipleChoiceData:
        """Return the record at ``index`` as a structured evaluation sample.

        The fields ``question``, ``golden_contexts``, ``golden_answers`` and
        ``choices`` are consumed; every remaining field of the record is merged
        into ``meta_data`` of the returned object.

        :param index: The index of the record to fetch.
        :type index: int
        :return: The formatted sample.
        :rtype: RAGEvalData | RAGMultipleChoiceData
        :raises KeyError: If the record has no ``question`` field.
        """
        data = super().__getitem__(index)

        raw_contexts = data.pop("golden_contexts", None)
        if raw_contexts is None:
            golden_contexts = None
        else:
            golden_contexts = [Context(**context) for context in raw_contexts]

        if "choices" in data:
            # multiple choice data
            formatted_data = RAGMultipleChoiceData(
                question=data.pop("question"),
                options=data.pop("choices"),
                golden_options=data.pop("golden_answers", None),
                golden_contexts=golden_contexts,
            )
        else:
            # knowledge intensive qa data
            formatted_data = RAGEvalData(
                question=data.pop("question"),
                golden_contexts=golden_contexts,
                golden_answers=data.pop("golden_answers", None),
            )

        # A record may carry an explicit null ``meta_data``; treat it like a
        # missing one so that the ``update`` below does not fail on ``None``.
        meta_data = data.pop("meta_data", None)
        if meta_data is None:
            meta_data = {}
        formatted_data.meta_data = meta_data
        # Whatever is left in the record is kept as metadata.
        formatted_data.meta_data.update(data)
        return formatted_data

    def __iter__(self) -> Iterator[RAGEvalData | RAGMultipleChoiceData]:
        """Iterate over the samples by delegating to the parent iterator."""
        yield from super().__iter__()
@configure
class RAGCorpusDatasetConfig(LineDelimitedDatasetConfig):
    """Configuration of ``RAGCorpusDataset``.

    ``RAGCorpusDataset`` loads pre-processed corpus data for RAG retrieval.
    Iterating over the dataset yields `Context` objects.

    :param saving_fields: The fields to keep in the context.
        If not specified, all fields will be kept.
    :type saving_fields: list[str]
    :param id_field: The field to use as the context_id.
        If not specified, the ordinal number will be used.
    :type id_field: Optional[str]
    :param processors: The preprocessors for each field. Default is {}.
        The key is the field name, and the value is the `TextProcessPipelineConfig`.
    :type processors: dict[str, TextProcessPipelineConfig]

    For example, to load the corpus provided by
    `Atlas <https://github.com/facebookresearch/atlas>`_, first download it:

    .. code-block:: bash

        wget https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
        wget https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/infobox.jsonl

    Then load the corpus with a length filter:

    .. code-block:: python

        from flexrag.datasets import RAGCorpusDataset, RAGCorpusDatasetConfig
        from flexrag.text_process import TextProcessPipelineConfig, LengthFilterConfig

        cfg = RAGCorpusDatasetConfig(
            file_paths=[
                "infobox.jsonl",
                "text-list-100-sec.jsonl",
            ],
            saving_fields=["title", "text"],
            processors={
                "text": TextProcessPipelineConfig(
                    processor_type=["length_filter"],
                    length_filter_config=LengthFilterConfig(
                        max_chars=4096,
                        min_chars=10,
                    ),
                )
            },
            encoding="utf-8",
        )
        dataset = RAGCorpusDataset(cfg)

    The code above loads the corpus from the given files and runs the `text`
    field through a length filter: any text shorter than 10 or longer than
    4096 characters is filtered out.
    """

    saving_fields: list[str] = field(default_factory=list)
    id_field: Optional[str] = None
    processors: dict[str, TextProcessPipelineConfig] = field(default_factory=dict)  # type: ignore
class RAGCorpusDataset(LineDelimitedDataset):
    """The dataset for loading pre-processed corpus data for RAG retrieval.

    Every record produced by the parent ``LineDelimitedDataset`` is turned into
    a ``Context``: the id is extracted, unwanted fields are dropped, the
    configured processors are applied, and records with a ``None`` field are
    skipped.
    """

    def __init__(self, cfg: RAGCorpusDatasetConfig) -> None:
        """Build the dataset and the per-field text processing pipelines.

        :param cfg: The configuration of the dataset.
        :type cfg: RAGCorpusDatasetConfig
        :raises AssertionError: If ``saving_fields`` is non-empty and a
            processor is configured for a field that is not listed in it.
        """
        super().__init__(cfg)

        # load arguments
        self.saving_fields = cfg.saving_fields
        self.id_field = cfg.id_field
        if self.id_field is None:
            logger.warning("No id field is provided, using the index as the id field")

        # load processors for each field
        # An empty ``saving_fields`` means that every field is kept, so any
        # field may have a processor; only validate against an explicit list.
        assert len(self.saving_fields) == 0 or all(
            key in self.saving_fields for key in cfg.processors
        ), f"The field to process is not in the saving fields: {self.saving_fields}."
        self.processors = {
            key: TextProcessPipeline(cfg.processors[key]) for key in cfg.processors
        }
        return

    def __iter__(self) -> Iterator[Context]:
        """Yield one ``Context`` per record that survives preprocessing.

        :return: An iterator over the processed contexts.
        :rtype: Iterator[Context]
        :raises KeyError: If ``id_field`` is set but missing from a record.
        """
        for n, data in enumerate(super().__iter__()):
            # prepare context_id
            if self.id_field is not None:
                context_id = data.pop(self.id_field)
            else:
                # the ordinal counts every record, including the skipped ones
                context_id = str(n)

            # remove unused fields; missing fields default to an empty string
            if len(self.saving_fields) > 0:
                data = {key: data.get(key, "") for key in self.saving_fields}

            # preprocess each field that has a pipeline configured
            for key, processor in self.processors.items():
                if key in data:
                    data[key] = processor(data[key])

            # filter the data: a ``None`` field drops the whole record
            if any(value is None for value in data.values()):
                continue
            yield Context(context_id=context_id, data=data)