MNIST

npfl138.datasets.mnist.MNIST

Source code in npfl138/datasets/mnist.py
class MNIST:
    C: int = 1
    """The number of image channels."""
    H: int = 28
    """The image height."""
    W: int = 28
    """The image width."""
    LABELS: int = 10
    """The number of labels."""

    Element = TypedDict("Element", {"image": torch.Tensor, "label": torch.Tensor})
    """The type of a single dataset element."""
    Elements = TypedDict("Elements", {"images": torch.Tensor, "labels": torch.Tensor})
    """The type of the whole dataset."""

    URL: str = "https://ufal.mff.cuni.cz/~straka/courses/npfl138/2425/datasets/"

    class Dataset(torch.utils.data.Dataset):
        def __init__(self, data: "MNIST.Elements") -> None:
            self._data = {key: torch.as_tensor(value) for key, value in data.items()}
            self._data["images"] = self._data["images"].view(-1, MNIST.C, MNIST.H, MNIST.W)

        @property
        def data(self) -> "MNIST.Elements":
            """Return the whole dataset as a `MNIST.Elements` object."""
            return self._data

        def __len__(self) -> int:
            """Return the number of elements in the dataset."""
            return len(self._data["images"])

        def __getitem__(self, index: int) -> "MNIST.Element":
            """Return the `index`-th element of the dataset."""
            return {key.removesuffix("s"): value[index] for key, value in self._data.items()}

        def batches(
            self, size: int, shuffle: bool = False, generator: torch.Generator | None = None,
        ) -> Iterator["MNIST.Element"]:
            permutation = torch.randperm(len(self), generator=generator) if shuffle else torch.arange(len(self))

            while len(permutation):
                batch_size = min(size, len(permutation))
                batch_perm = permutation[:batch_size]
                permutation = permutation[batch_size:]

                batch = {key: value[batch_perm] for key, value in self._data.items()}
                yield batch

    def __init__(self, dataset: str = "mnist", sizes: dict[str, int] = {}) -> None:
        """Load the MNIST dataset, downloading it if necessary.

        Parameters:
          dataset: The name of the dataset, typically `mnist`.
          sizes: An optional dictionary overriding the sizes of the `train`, `dev`, and `test` splits.
        """
        path = "{}.npz".format(dataset)
        if not os.path.exists(path):
            print("Downloading {} dataset...".format(dataset), file=sys.stderr)
            urllib.request.urlretrieve("{}/{}".format(self.URL, path), filename="{}.tmp".format(path))
            os.rename("{}.tmp".format(path), path)

        mnist = np.load(path)
        for dataset in ["train", "dev", "test"]:
            data = {key[len(dataset) + 1:]: mnist[key][:sizes.get(dataset, None)]
                    for key in mnist if key.startswith(dataset)}
            setattr(self, dataset, self.Dataset(data))

    train: Dataset
    """The training dataset."""
    dev: Dataset
    """The development dataset."""
    test: Dataset
    """The test dataset."""

C class-attribute instance-attribute

C: int = 1

The number of image channels.

H class-attribute instance-attribute

H: int = 28

The image height.

W class-attribute instance-attribute

W: int = 28

The image width.

LABELS class-attribute instance-attribute

LABELS: int = 10

The number of labels.

Element class-attribute instance-attribute

Element = TypedDict('Element', {'image': Tensor, 'label': Tensor})

The type of a single dataset element.

Elements class-attribute instance-attribute

Elements = TypedDict('Elements', {'images': Tensor, 'labels': Tensor})

The type of the whole dataset.
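
To illustrate the difference (a sketch, continuing the example above): a single Element uses singular keys and holds one sample, while Elements uses plural keys and holds the whole split.

single: MNIST.Element = mnist.train[0]       # {"image": ..., "label": ...}
whole: MNIST.Elements = mnist.train.data     # {"images": ..., "labels": ...}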

Dataset

Bases: torch.utils.data.Dataset

Source code in npfl138/datasets/mnist.py
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data: "MNIST.Elements") -> None:
        self._data = {key: torch.as_tensor(value) for key, value in data.items()}
        self._data["images"] = self._data["images"].view(-1, MNIST.C, MNIST.H, MNIST.W)

    @property
    def data(self) -> "MNIST.Elements":
        """Return the whole dataset as a `MNIST.Elements` object."""
        return self._data

    def __len__(self) -> int:
        """Return the number of elements in the dataset."""
        return len(self._data["images"])

    def __getitem__(self, index: int) -> "MNIST.Element":
        """Return the `index`-th element of the dataset."""
        return {key.removesuffix("s"): value[index] for key, value in self._data.items()}

    def batches(
        self, size: int, shuffle: bool = False, generator: torch.Generator | None = None,
    ) -> Iterator["MNIST.Element"]:
        permutation = torch.randperm(len(self), generator=generator) if shuffle else torch.arange(len(self))

        while len(permutation):
            batch_size = min(size, len(permutation))
            batch_perm = permutation[:batch_size]
            permutation = permutation[batch_size:]

            batch = {key: value[batch_perm] for key, value in self._data.items()}
            yield batch
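
A sketch of manual batching with a seeded generator (assuming the mnist.train split from the example above; the generator only influences the shuffling order):

import torch

generator = torch.Generator().manual_seed(42)
for batch in mnist.train.batches(64, shuffle=True, generator=generator):
    images, labels = batch["images"], batch["labels"]  # note the plural keys
    # ... train on the batch; the last batch may be smaller than 64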

data property

data: Elements

Return the whole dataset as a MNIST.Elements object.
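
For example (a sketch; the exact dtypes depend on the arrays stored in the .npz file):

images = mnist.train.data["images"]   # shape (N, 1, 28, 28) after the view in __init__
labels = mnist.train.data["labels"]   # one label per image, typically shape (N,)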

__len__

__len__() -> int

Return the number of elements in the dataset.

Source code in npfl138/datasets/mnist.py
def __len__(self) -> int:
    """Return the number of elements in the dataset."""
    return len(self._data["images"])

__getitem__

__getitem__(index: int) -> Element

Return the index-th element of the dataset.

Source code in npfl138/datasets/mnist.py
def __getitem__(self, index: int) -> "MNIST.Element":
    """Return the `index`-th element of the dataset."""
    return {key.removesuffix("s"): value[index] for key, value in self._data.items()}

__init__

__init__(dataset: str = 'mnist', sizes: dict[str, int] = {}) -> None

Load the MNIST dataset, downloading it if necessary.

Parameters:

  • dataset (str, default: 'mnist') – The name of the dataset, typically mnist.

  • sizes (dict[str, int], default: {}) – An optional dictionary overriding the sizes of the train, dev, and test splits.

Source code in npfl138/datasets/mnist.py
def __init__(self, dataset: str = "mnist", sizes: dict[str, int] = {}) -> None:
    """Load the MNIST dataset, downloading it if necessary.

    Parameters:
      dataset: The name of the dataset, typically `mnist`.
      sizes: An optional dictionary overriding the sizes of the `train`, `dev`, and `test` splits.
    """
    path = "{}.npz".format(dataset)
    if not os.path.exists(path):
        print("Downloading {} dataset...".format(dataset), file=sys.stderr)
        urllib.request.urlretrieve("{}/{}".format(self.URL, path), filename="{}.tmp".format(path))
        os.rename("{}.tmp".format(path), path)

    mnist = np.load(path)
    for dataset in ["train", "dev", "test"]:
        data = {key[len(dataset) + 1:]: mnist[key][:sizes.get(dataset, None)]
                for key in mnist if key.startswith(dataset)}
        setattr(self, dataset, self.Dataset(data))
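
For instance, to work with smaller splits (a sketch; splits not listed in sizes keep their full size):

mnist = MNIST("mnist", sizes={"train": 5_000, "dev": 1_000})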

train instance-attribute

train: Dataset

The training dataset.

dev instance-attribute

dev: Dataset

The development dataset.

test instance-attribute

test: Dataset

The test dataset.
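
Since MNIST.Dataset subclasses torch.utils.data.Dataset, each split can also be wrapped in a standard DataLoader (a sketch; the default collate function stacks the per-element dictionaries):

import torch

train_loader = torch.utils.data.DataLoader(mnist.train, batch_size=64, shuffle=True)
for batch in train_loader:
    images, labels = batch["image"], batch["label"]  # singular keys, stacked along dim 0
    # images: (batch, 1, 28, 28), labels: (batch,)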