Skip to content

TransformedDataset

npfl138.TransformedDataset

Bases: Dataset

A dataset capable of applying transformations to its items and batches.

Source code in npfl138/transformed_dataset.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class TransformedDataset(torch.utils.data.Dataset):
    """A dataset capable of applying transformations to its items and batches.

    """
    def __init__(self, dataset: torch.utils.data.Dataset) -> None:
        """Create a new transformed dataset using the provided dataset.

        Parameters:
          dataset: The source dataset implementing `__len__` and `__getitem__`.
        """
        self._dataset = dataset

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return len(self._dataset)

    def __getitem__(self, index: int) -> Any:
        """Return the item at the specified index."""
        item = self._dataset[index]
        if self.transform is not None:
            return self.transform(*item) if isinstance(item, tuple) else self.transform(item)
        return item

    @property
    def dataset(self) -> torch.utils.data.Dataset:
        """Return the source dataset."""
        return self._dataset

    transform: Callable | None = None
    """If given, `transform` is called on each item before returning it.

    If the dataset item is a tuple, `transform` is called with the tuple unpacked.
    """

    collate: Callable | None = None
    """If given, `collate` is called on a list of items before returning them as a batch."""

    transform_batch: Callable | None = None
    """If given, `transform_batch` is called on a batch before returning it."""

    def collate_fn(self, batch: list[Any]) -> Any:
        """A function for a DataLoader to collate a batch of items using `collate` and/or `transform_batch`.

        This function is used as the `collate_fn` parameter of a DataLoader when `collate` or `transform_batch` is set.

        Parameters:
          batch: A list of items to collate and/or pass through `transform_batch`.
        """
        batch = self.collate(batch) if self.collate is not None else torch.utils.data.dataloader.default_collate(batch)
        if self.transform_batch is not None:
            batch = self.transform_batch(batch)
        return batch

    def dataloader(self, batch_size=1, *, shuffle=False, num_workers=0, **kwargs) -> torch.utils.data.DataLoader:
        """Create a DataLoader for this dataset.

        This method is a convenience wrapper around `torch.utils.data.DataLoader`
        setting up the required parameters. All arguments are passed to the DataLoader,
        however, when `num_workers` is greater than 0, `persistent_workers` is set to True.
        When `collate` or `transform_batch` is set, the `self.collate_fn` is passed as the
        `collate_fn` parameter.
        """
        if not shuffle and kwargs.get("generator", None) is None:
            # If not shuffling and no generator is given, pass an explicit generator to the Dataloader.
            # Otherwise, the global random generator would generate a number on every iter(dataloader) call.
            kwargs["generator"] = torch.Generator()

        if num_workers > 0:
            # By default, set persistent_workers to True, but allow it to be overridden
            kwargs.setdefault("persistent_workers", True)

        if self.collate is not None or self.transform_batch is not None:
            if "collate_fn" in kwargs:
                raise ValueError("When collate or transform_batch is overridden, collate_fn must not be given.")
            kwargs["collate_fn"] = self.collate_fn

        # Create and return the DataLoader
        return torch.utils.data.DataLoader(
            self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)

__init__

__init__(dataset: Dataset) -> None

Create a new transformed dataset using the provided dataset.

Parameters:

  • dataset (Dataset) –

    The source dataset implementing __len__ and __getitem__.

Source code in npfl138/transformed_dataset.py
15
16
17
18
19
20
21
def __init__(self, dataset: torch.utils.data.Dataset) -> None:
    """Wrap the given source dataset so that transformations can be applied to it.

    The dataset is stored as-is in the `_dataset` attribute.

    Parameters:
      dataset: The source dataset implementing `__len__` and `__getitem__`.
    """
    self._dataset = dataset

__len__

__len__() -> int

Return the number of items in the dataset.

Source code in npfl138/transformed_dataset.py
23
24
25
def __len__(self) -> int:
    """Return how many items the source dataset contains."""
    source = self._dataset
    return len(source)

__getitem__

__getitem__(index: int) -> Any

Return the item at the specified index.

Source code in npfl138/transformed_dataset.py
27
28
29
30
31
32
def __getitem__(self, index: int) -> Any:
    """Return the item at the given index, passed through `transform` when it is set.

    A tuple item is unpacked into positional arguments of `transform`;
    any other item is passed to `transform` as a single argument.
    """
    item = self._dataset[index]
    if self.transform is None:
        return item
    if isinstance(item, tuple):
        return self.transform(*item)
    return self.transform(item)

dataset property

dataset: Dataset

Return the source dataset.

transform class-attribute instance-attribute

transform: Callable | None = None

If given, transform is called on each item before returning it.

If the dataset item is a tuple, transform is called with the tuple unpacked.

collate class-attribute instance-attribute

collate: Callable | None = None

If given, collate is called on a list of items before returning them as a batch.

transform_batch class-attribute instance-attribute

transform_batch: Callable | None = None

If given, transform_batch is called on a batch before returning it.

collate_fn

collate_fn(batch: list[Any]) -> Any

A function for a DataLoader to collate a batch of items using collate and/or transform_batch.

This function is used as the collate_fn parameter of a DataLoader when collate or transform_batch is set.

Parameters:

  • batch (list[Any]) –

    A list of items to collate and/or pass through transform_batch.

Source code in npfl138/transformed_dataset.py
51
52
53
54
55
56
57
58
59
60
61
62
def collate_fn(self, batch: list[Any]) -> Any:
    """Collate a list of items into a batch using `collate` and/or `transform_batch`.

    This function is used as the `collate_fn` parameter of a DataLoader when `collate`
    or `transform_batch` is set. When `collate` is not set, the default PyTorch
    collation is used instead.

    Parameters:
      batch: A list of items to collate and/or pass through `transform_batch`.

    Returns:
      The collated batch, processed by `transform_batch` when it is set.
    """
    if self.collate is None:
        batch = torch.utils.data.dataloader.default_collate(batch)
    else:
        batch = self.collate(batch)

    if self.transform_batch is None:
        return batch
    return self.transform_batch(batch)

dataloader

dataloader(
    batch_size=1, *, shuffle=False, num_workers=0, **kwargs
) -> DataLoader

Create a DataLoader for this dataset.

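This method is a convenience wrapper around torch.utils.data.DataLoader that sets up the required parameters. All arguments are passed to the DataLoader, with the following adjustments. When shuffle is False and no generator is given, a fresh torch.Generator is passed, so that the global random generator does not generate a number on every iter(dataloader) call. When num_workers is greater than 0, persistent_workers defaults to True unless it is given explicitly. When collate or transform_batch is set, self.collate_fn is passed as the collate_fn parameter; giving collate_fn explicitly in that case raises a ValueError.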

Source code in npfl138/transformed_dataset.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def dataloader(self, batch_size=1, *, shuffle=False, num_workers=0, **kwargs) -> torch.utils.data.DataLoader:
    """Create a DataLoader for this dataset.

    This method is a convenience wrapper around `torch.utils.data.DataLoader`;
    all arguments are forwarded to it, with the following adjustments:

    - when not shuffling and no `generator` is given, a fresh `torch.Generator` is passed;
    - when `num_workers` is greater than 0, `persistent_workers` defaults to True
      (an explicitly given value is kept);
    - when `collate` or `transform_batch` is set, `self.collate_fn` is passed as
      the `collate_fn` parameter.

    Raises:
      ValueError: If `collate_fn` is given while `collate` or `transform_batch` is set.
    """
    # Without shuffling and without a user-supplied generator, hand the DataLoader its own
    # generator; otherwise, the global random generator would generate a number on every
    # iter(dataloader) call.
    if not (shuffle or kwargs.get("generator") is not None):
        kwargs["generator"] = torch.Generator()

    # Persistent workers are only a default; an explicit `persistent_workers` argument wins.
    if num_workers > 0 and "persistent_workers" not in kwargs:
        kwargs["persistent_workers"] = True

    uses_batch_hooks = not (self.collate is None and self.transform_batch is None)
    if uses_batch_hooks:
        if "collate_fn" in kwargs:
            raise ValueError("When collate or transform_batch is overridden, collate_fn must not be given.")
        kwargs["collate_fn"] = self.collate_fn

    return torch.utils.data.DataLoader(
        self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)