Source code for flexrag.datasets.line_delimited_dataset

import json
from csv import reader as csv_reader
from dataclasses import field
from glob import glob
from os import PathLike
from pathlib import Path
from typing import Iterator

from flexrag.utils import configure

from .dataset import IterableDataset


@configure
class LineDelimitedDatasetConfig:
    """Configuration for ``LineDelimitedDataset``.

    :param file_paths: The files to load. Each entry is a ``str`` that may be
        a unix style path pattern (e.g. ``"data/*.jsonl"``); every entry must
        match at least one file.
    :type file_paths: list[str]
    :param data_ranges: Optional ``[start_point, end_point]`` pairs, one per
        resolved file, selecting which records to load. Indices are 0-based
        and the range is half-open (``start_point <= i < end_point``); for
        csv/tsv files the header row is not counted as a record. An
        ``end_point`` of -1 (the dataset treats any non-positive value the
        same way) reads to the end of the file. When left empty (the
        default), every file is read in full. It cannot be combined with a
        path pattern that matches more than one file.
    :type data_ranges: list[list[int, int]]
    :param encoding: The text encoding used to open the files.
    :type encoding: str

    Example 1: Loading specific lines from given files.

    >>> cfg = LineDelimitedDatasetConfig(
    ...     file_paths=["data1.jsonl", "data2.csv"],
    ...     data_ranges=[[0, 10], [0, 20]],
    ...     encoding="utf-8",
    ... )
    >>> dataset = LineDelimitedDataset(cfg)
    >>> items = [i for i in dataset]

    Example 2: Loading multiple files using unix style path pattern.

    >>> cfg = LineDelimitedDatasetConfig(
    ...     file_paths=["data/*.jsonl"],
    ...     encoding="utf-8",
    ... )
    >>> dataset = LineDelimitedDataset(cfg)
    >>> items = [i for i in dataset]
    """

    # Paths or unix style patterns; expanded by LineDelimitedDataset.
    file_paths: list[str]
    # Per-file [start_point, end_point]; empty means "read everything".
    data_ranges: list[list[int, int]] = field(default_factory=list)
    # Encoding passed to ``open`` for every file.
    encoding: str = "utf-8"
[docs] class LineDelimitedDataset(IterableDataset): """The iterative dataset for loading line delimited files (csv, tsv, jsonl).""" def __init__(self, cfg: LineDelimitedDatasetConfig) -> None: # process unix style path file_paths: list[Path] = [] for p in cfg.file_paths: if isinstance(p, str): paths = [Path(p_) for p_ in glob(p)] # elif isinstance(p, PathLike): # paths = [Path(p)] else: raise TypeError(f"Unsupported file path type: {type(p)}") if len(paths) == 0: raise FileNotFoundError(f"File {p} does not exist.") if len(paths) > 1: assert ( len(cfg.data_ranges) == 0 ), "`data_ranges` do not support unix style path pattern" file_paths.extend(paths) # file_paths = [glob(p) for p in cfg.file_paths] # for p, p_ in zip(file_paths, cfg.file_paths): # assert len(p) != 0, f"File {p_} does not exsit." # if len(p) != 1: # assert ( # len(cfg.data_ranges) == 0 # ), "`data_ranges` do not support unix style path pattern" # file_paths = [p for file_path in file_paths for p in file_path] # check data_ranges consistency if len(cfg.data_ranges) != 0: assert len(cfg.data_ranges) == len(file_paths), "Invalid data ranges" else: cfg.data_ranges = [[0, -1] for _ in file_paths] self.file_paths = file_paths self.data_ranges = cfg.data_ranges self.encoding = cfg.encoding return def __iter__(self) -> Iterator[dict]: # read data for file_path, data_range in zip(self.file_paths, self.data_ranges): start_point, end_point = data_range if end_point > 0: assert end_point > start_point, f"Invalid data range: {data_range}" match file_path.suffix: case ".jsonl": with open(file_path, "r", encoding=self.encoding) as f: for i, line in enumerate(f): if i < start_point: continue if (end_point > 0) and (i >= end_point): break yield json.loads(line) case ".tsv": title = [] with open(file_path, "r", encoding=self.encoding) as f: for i, row in enumerate(csv_reader(f, delimiter="\t")): if i == 0: title = row continue if i <= start_point: continue if (end_point > 0) and (i > end_point): break yield 
dict(zip(title, row)) case ".csv": title = [] with open(file_path, "r", encoding=self.encoding) as f: for i, row in enumerate(csv_reader(f)): if i == 0: title = row continue if i <= start_point: continue if (end_point > 0) and (i > end_point): break yield dict(zip(title, row)) case _: raise ValueError(f"Unsupported file format: {file_path}") return def __repr__(self) -> str: return f"LineDelimitedDataset({self.file_paths})"