TrainableModule

The Keras-inspired high-level API for training PyTorch models.

The TrainableModule class is a high-level API for training PyTorch models. It is a subclass of torch.nn.Module and:

  • It provides a high-level API for training, evaluation, and prediction via fit, evaluate, and predict methods. Each can be customized by overriding the corresponding train_step, test_step, or predict_step methods.

  • The module automatically handles moving the model to the specified device, using the first available accelerator (GPU, MPS, XPU) by default. To this end, configure or load_weights must be called before using the high-level API.

  • The module provides API for serialization and deserialization of the model, both the weights (save_weights, load_weights) and the configuration (save_config, load_config).

  • The module keeps a collection of metrics implementing the MetricProtocol (e.g., any metric from torchmetrics), and stores the computed logs in a text file, in TensorBoard logs, and in the console.
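
For illustration, here is a minimal usage sketch of the high-level API; the network architecture, the torchmetrics metric, and the train_dataloader, dev_dataloader, and test_dataloader variables are assumptions, not part of the library:

import torch
import torchmetrics

from npfl138 import TrainableModule

# Wrap a plain PyTorch network; the architecture is illustrative only.
model = TrainableModule(torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
))

# Configure the optimizer, loss, metrics, and log directory; this also moves
# the module to the first available accelerator.
model.configure(
    optimizer=torch.optim.Adam(model.parameters()),
    loss=torch.nn.CrossEntropyLoss(),
    metrics={"accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)},
    logdir="logs",
)

# Train, evaluate, and predict; the dataloaders yield (input, target) pairs.
model.fit(train_dataloader, epochs=5, dev=dev_dataloader)
model.evaluate(test_dataloader)
predictions = model.predict(test_dataloader, data_with_labels=True)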

npfl138.TrainableModule

Bases: Module

A simple Keras-like module for training with raw PyTorch.

The module provides fit/evaluate/predict methods, computes loss and metrics, and generates TensorBoard, text file, and console logs. By default, it uses an accelerator (GPU, MPS, XPU) if available, and CPU otherwise.

The input to the model can be either a single tensor/PackedSequence or a tuple of those. Similarly, the output can be a single tensor/PackedSequence or a tuple of those. However, when there are multiple outputs, you must handle loss and metrics computation manually.
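
For example, a model with two outputs can override compute_loss (and analogously compute_metrics) to combine the per-output terms manually; the following is only a sketch with illustrative layer sizes and an unweighted sum of losses:

import torch

from npfl138 import TrainableModule

class TwoHeadModule(TrainableModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(16, 32)
        self.classification_head = torch.nn.Linear(32, 10)
        self.regression_head = torch.nn.Linear(32, 1)

    def forward(self, x):
        hidden = torch.relu(self.backbone(x))
        # Multiple outputs are returned as a tuple.
        return self.classification_head(hidden), self.regression_head(hidden)

    def compute_loss(self, y_pred, y, *xs):
        # With multiple outputs, the loss must be combined manually.
        return torch.nn.functional.cross_entropy(y_pred[0], y[0]) \
            + torch.nn.functional.mse_loss(y_pred[1], y[1])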

Source code in npfl138/trainable_module.py
class TrainableModule(torch.nn.Module):
    """A simple Keras-like module for training with raw PyTorch.

    The module provides fit/evaluate/predict methods, computes loss and metrics,
    and generates TensorBoard, text file, and console logs. By default, it uses
    an accelerator (GPU, MPS, XPU) if available, and CPU otherwise.

    The input to the model can be either a single tensor/PackedSequence or
    a tuple of those. Similarly, the output can be a single tensor/PackedSequence
    or a tuple of those. However, when there are multiple outputs, you
    must handle loss and metrics computation manually.
    """
    STOP_TRAINING: Literal["stop_training"] = "stop_training"
    """A constant returned by callbacks to stop the training."""

    def __init__(self, module: torch.nn.Module | None = None):
        """Initialize the module, optionally with an existing PyTorch module.

        Parameters:
          module: An optional existing PyTorch module to wrap, e.g., a [torch.nn.Sequential][]
            or a pretrained Transformer. If given, the module still must be configured.
        """
        super().__init__()
        self.device = None
        self.unconfigure()
        if module is not None:
            self.module = module
            self.forward = self._call_wrapped_module

    def _call_wrapped_module(self, inputs):
        return self.module(inputs)

    def configure(
        self,
        *,
        optimizer: torch.optim.Optimizer | None | KeepPrevious = keep_previous,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None | KeepPrevious = keep_previous,
        loss: LossProtocol | None | KeepPrevious = keep_previous,
        metrics: dict[str, MetricProtocol] | KeepPrevious = keep_previous,
        initial_epoch: int | KeepPrevious = keep_previous,
        logdir: str | None | KeepPrevious = keep_previous,
        device: torch.device | str | Literal["auto"] | KeepPrevious = keep_previous,
    ) -> Self:
        """Configure the module fitting, evaluation, and placement.

        The method can be called multiple times, preserving previously set values by default.

        Note:
          When an input argument cannot be `None`, the corresponding field is
          never `None` after this call.

        Parameters:
          optimizer: The optimizer to use for training.
          scheduler: An optional learning rate scheduler used after every batch.
          loss: The loss function to minimize, implementing the
            [LossProtocol][npfl138.trainable_module.LossProtocol].
          metrics: A dictionary of additional metrics to compute, each being an
            object implementing the [MetricProtocol][npfl138.trainable_module.MetricProtocol]
            (reset/update/compute), e.g., a `torchmetrics.Metric`.
          initial_epoch: The initial epoch of the model used during training and evaluation.
          logdir: An optional directory where textual and TensorBoard logs should be stored.
          device: The device to move the module to. When "auto", or `keep_previous`
            with no previously set device, the first of cuda/mps/xpu is used if available.

        Returns:
          self
        """
        self.optimizer = optimizer if optimizer is not keep_previous else self.optimizer
        self.scheduler = scheduler if scheduler is not keep_previous else self.scheduler
        self.loss = loss if loss is not keep_previous else self.loss
        self.loss_tracker = self.loss_tracker or LossTracker()
        if metrics is not keep_previous or not self.metrics:
            self.metrics = torch.nn.ModuleDict({} if metrics is keep_previous else metrics)
        self.epoch = initial_epoch if initial_epoch is not keep_previous else self.epoch or 0
        if logdir is not keep_previous and logdir != self.logdir:  # reset loggers on a new logdir
            self._log_file, self._tb_writers = None, {}
        self.logdir = logdir if logdir is not keep_previous else self.logdir
        if device is not keep_previous or not self.device:
            self.device = get_auto_device() if device == "auto" or device is keep_previous else torch.device(device)
        self.to(self.device)
        return self

    def unconfigure(self) -> Self:
        """Remove all training configuration of the TrainableModule.

        Only the module device is kept.

        Returns:
          self
        """
        self.optimizer, self.scheduler, self.epoch = None, None, None
        self.loss, self.loss_tracker, self.metrics = None, None, None
        self.logdir, self._log_file, self._tb_writers = None, None, None
        return self

    def fit(
        self,
        dataloader: torch.utils.data.DataLoader,
        *,
        epochs: int,
        dev: torch.utils.data.DataLoader | None = None,
        callbacks: list[CallbackProtocol] = [],
        log_graph: bool = False,
        console: int = console_default(2),
    ) -> Logs:
        """Train the model on the given dataset.

        Parameters:
          dataloader: The training dataset, each element a pair of inputs and outputs;
            the inputs and outputs can be either single tensors or sequences of tensors.
          epochs: The number of epochs to train.
          dev: An optional development dataset to evaluate after every epoch, with the
            same format as the training dataset.
          callbacks: A list of callbacks to call after every epoch, each implementing
            the [CallbackProtocol][npfl138.trainable_module.CallbackProtocol]
            with arguments `self`, `epoch`, and `logs`, possibly returning
            [TrainableModule.STOP_TRAINING][npfl138.TrainableModule.STOP_TRAINING] to stop the
            training (note that the module is set to evaluation mode before calling each callback).
          log_graph: Controls whether to log the model graph to TensorBoard.
          console: Controls the console verbosity: 0 for silent, 1 for epoch logs, 2 for
            additional only-when-writing-to-console progress bar, 3 for persistent progress bar.
            The default is 2, but can be overridden by the `CONSOLE` environment variable.

        Returns:
          logs: A dictionary of logs from the training and optionally dev evaluation.

        Note:
          The module is set to evaluation mode when returning from this method.
        """
        assert self.loss_tracker is not None, "The TrainableModule has not been configured, run configure first."
        logs, epochs, stop_training = {}, self.epoch + epochs, False
        while self.epoch < epochs and not stop_training:
            self.epoch += 1
            self.train()
            self.loss_tracker.reset()
            for metric in self.metrics.values():
                metric.reset()
            start = time.time()
            epoch_message = f"Epoch {self.epoch}/{epochs}"
            data_and_progress = tqdm.tqdm(
                dataloader, epoch_message, unit="batch", leave=False, disable=None if console == 2 else console < 2)
            for batch in data_and_progress:
                xs, y = validate_batch_input_output(batch)
                xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
                y = tuple(y_.to(self.device) for y_ in y) if is_sequence(y) else y.to(self.device)
                log_graph = log_graph and self.log_graph(xs) and False  # log the graph only once
                logs = self.train_step(xs, y)
                if not data_and_progress.disable:
                    logs_message = " ".join([f"{k}={v:#.{0<abs(v)<2e-4 and '2e' or '4f'}}" for k, v in logs.items()])
                    data_and_progress.set_description(f"{epoch_message} {logs_message}", refresh=False)
            logs = {f"train_{k}": v for k, v in logs.items()}
            if dev is not None:
                logs |= {f"dev_{k}": v for k, v in self.eval().evaluate(dev, log_as=None).items()}
            for callback in callbacks:
                stop_training = callback(self.eval(), self.epoch, logs) == self.STOP_TRAINING or stop_training
            self.log_metrics(logs, epochs, time.time() - start, console)
        self.eval()
        return logs

    def train_step(self, xs: TensorOrTensors, y: TensorOrTensors) -> Logs:
        """An overridable method performing a single training step, returning the logs.

        Parameters:
          xs: The input batch to the model, either a single tensor or a sequence of tensors.
          y: The target output batch of the model, either a single tensor or a sequence of tensors.

        Returns:
          logs: A dictionary of logs from the training step.
        """
        self.optimizer.zero_grad()
        y_pred = self(*xs)
        loss = self.compute_loss(y_pred, y, *xs)
        loss.backward()
        with torch.no_grad():
            self.optimizer.step()
            self.scheduler is not None and self.scheduler.step()
            return {"loss": self.loss_tracker(loss)} \
                | ({"lr": self.scheduler.get_last_lr()[0]} if self.scheduler else {}) \
                | self.compute_metrics(y_pred, y, *xs)

    def compute_loss(self, y_pred: TensorOrTensors, y: TensorOrTensors, *xs: tuple[Tensor]) -> torch.Tensor:
        """Compute the loss of the model given the inputs, predictions, and target outputs.

        Parameters:
          y_pred: The model predictions, either a single tensor or a sequence of tensors.
          y: The target output of the model, either a single tensor or a sequence of tensors.
          *xs: The inputs to the model, unpacked, if the input was a sequence of tensors.

        Returns:
          loss: The computed loss.
        """
        return self.loss(y_pred, y)

    def compute_metrics(self, y_pred: TensorOrTensors, y: TensorOrTensors, *xs: TensorOrTensors) -> Logs:
        """Compute and return metrics given the inputs, predictions, and target outputs.

        Parameters:
          y_pred: The model predictions, either a single tensor or a sequence of tensors.
          y: The target output of the model, either a single tensor or a sequence of tensors.
          *xs: The inputs to the model, unpacked, if the input was a sequence of tensors.

        Returns:
          logs: A dictionary of computed metrics.
        """
        for metric in self.metrics.values():
            metric.update(y_pred, y)
        return {name: metric.compute() for name, metric in self.metrics.items()}

    def evaluate(
        self,
        dataloader: torch.utils.data.DataLoader,
        *,
        log_as: str | None = "test",
        callbacks: list[CallbackProtocol] = [],
        console: int = console_default(1),
    ) -> Logs:
        """An evaluation of the model on the given dataset.

        Parameters:
          dataloader: The dataset to evaluate on, each element a pair of inputs and outputs;
            the inputs and outputs can be either a single tensor or a sequence of tensors.
          log_as: The name of the dataset used in the logs; when `None`, no logs are written.
          callbacks: A list of callbacks to call after the evaluation, each implementing
            the [CallbackProtocol][npfl138.trainable_module.CallbackProtocol] with arguments
            `self`, `epoch`, and `logs`.
          console: Controls the console verbosity: 0 for silent, 1 for a single message.
            The default is 1, but can be overridden by the `CONSOLE` environment variable.

        Returns:
          logs: A dictionary of logs from the evaluation.
        """
        assert self.loss_tracker is not None, "The TrainableModule has not been configured, run configure first."
        self.eval()
        self.loss_tracker.reset()
        for metric in self.metrics.values():
            metric.reset()
        start = time.time()
        for batch in dataloader:
            xs, y = validate_batch_input_output(batch)
            xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
            y = tuple(y_.to(self.device) for y_ in y) if is_sequence(y) else y.to(self.device)
            logs = self.test_step(xs, y)
        if log_as is not None:
            logs = {f"{log_as}_{k}": v for k, v in logs.items()}
        for callback in callbacks:
            callback(self, self.epoch, logs)
        if log_as is not None:
            self.log_metrics(logs, elapsed=time.time() - start, console=console)
        return logs

    def test_step(self, xs: TensorOrTensors, y: TensorOrTensors) -> Logs:
        """An overridable method performing a single evaluation step, returning the logs.

        Parameters:
          xs: The input batch to the model, either a single tensor or a sequence of tensors.
          y: The target output batch of the model, either a single tensor or a sequence of tensors.

        Returns:
          logs: A dictionary of logs from the evaluation step.
        """
        with torch.no_grad():
            y_pred = self(*xs)
            loss = self.compute_loss(y_pred, y, *xs)
            return {"loss": self.loss_tracker(loss)} | self.compute_metrics(y_pred, y, *xs)

    def predict(
        self,
        dataloader: torch.utils.data.DataLoader,
        *,
        data_with_labels: bool = False,
        as_numpy: bool = True,
    ) -> list[torch.Tensor | tuple[torch.Tensor, ...] | np.ndarray | tuple[np.ndarray, ...]]:
        """Compute predictions for the given dataset.

        Parameters:
          dataloader: The dataset to predict on, each element either directly the input
            or a pair whose first element is the input; the input can be either
            a single tensor or a sequence of tensors.
          data_with_labels: Specifies whether the dataloader elements
            are (input, labels) pairs or just inputs (the default).
          as_numpy: A flag controlling whether the output should be
            converted to a numpy array or kept as a PyTorch tensor.

        Returns:
          predictions: A Python list whose elements are predictions
            of the individual examples. Note that if the input was padded, so
            will be the predictions, which will then need to be trimmed.
        """
        assert self.device is not None, "No device has been set for the TrainableModule, run configure first."
        self.eval()
        predictions = []
        for batch in dataloader:
            xs = validate_batch_input(batch, with_labels=data_with_labels)
            xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
            y = self.predict_step(xs, as_numpy=as_numpy)
            predictions.extend(y if not isinstance(y, tuple) else zip(*y))
        return predictions

    def predict_step(
        self, xs: TensorOrTensors, as_numpy: bool = True,
    ) -> torch.Tensor | tuple[torch.Tensor, ...] | np.ndarray | tuple[np.ndarray, ...]:
        """An overridable method performing a single prediction step.

        Parameters:
          xs: The input batch to the model, either a single tensor or a sequence of tensors.
          as_numpy: A flag controlling whether the output should be converted to a numpy array.

        Returns:
          predictions: The batch prediction.
        """
        with torch.no_grad():
            y = self(*xs)
            return maybe_unpack(y, as_numpy) if not is_sequence(y) else tuple(maybe_unpack(y_, as_numpy) for y_ in y)

    def save_weights(self, path: str, optimizer_path: str | None = None) -> Self:
        """Save the model weights to the given path.

        Parameters:
          path: The path to save the model weights to; a `.pt` extension is recommended.
          optimizer_path: An optional path to save the optimizer state to, relative to the
            model weights path.

        Returns:
          self
        """
        state_dict = self.state_dict()
        os.path.dirname(path) and os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(state_dict, path)

        # Save the number of epochs, optimizer state, and the scheduler state when requested.
        if optimizer_path is not None:
            optimizer_state = {"epoch": self.epoch}
            self.optimizer is not None and optimizer_state.update(optimizer=self.optimizer.state_dict())
            self.scheduler is not None and optimizer_state.update(scheduler=self.scheduler.state_dict())
            optimizer_path = os.path.join(os.path.dirname(path), optimizer_path)
            os.path.dirname(optimizer_path) and os.makedirs(os.path.dirname(optimizer_path), exist_ok=True)
            torch.save(optimizer_state, optimizer_path)
        return self

    def load_weights(self, path: str, optimizer_path: str | None = None,
                     device: torch.device | str | Literal["auto"] | KeepPrevious = keep_previous) -> Self:
        """Load the model weights from the given path.

        Parameters:
          path: The path to load the model weights from.
          optimizer_path: An optional path to load the optimizer state from, relative to the
            model weights path.
          device: The device to load the model to; when "auto", or `keep_previous` with no previously
            set device, the first of cuda/mps/xpu is used if available.

        Returns:
          self
        """
        if device is not keep_previous or not self.device:
            self.device = get_auto_device() if device == "auto" or device is keep_previous else torch.device(device)
        self.load_state_dict(torch.load(path, map_location=self.device))

        # Load the number of epochs, optimizer state, and the scheduler state when requested.
        if optimizer_path is not None:
            optimizer_path = os.path.join(os.path.dirname(path), optimizer_path)
            optimizer_state = torch.load(optimizer_path, map_location=self.device)
            self.epoch = optimizer_state["epoch"]
            if self.optimizer is not None:
                assert "optimizer" in optimizer_state, "The optimizer state is missing."
                self.optimizer.load_state_dict(optimizer_state["optimizer"])
            else:
                assert "optimizer" not in optimizer_state, "The optimizer state is present, but there is no optimizer."
            if self.scheduler is not None:
                assert "scheduler" in optimizer_state, "The scheduler state is missing."
                self.scheduler.load_state_dict(optimizer_state["scheduler"])
            else:
                assert "scheduler" not in optimizer_state, "The scheduler state is present, but there is no scheduler."
        self.to(self.device)
        return self

    @staticmethod
    def save_config(path: str, config: dict = {}, /, **kwargs: dict) -> None:
        """Save a JSON-serializable configuration to the given path.

        The configuration can be given as a dictionary or as keyword arguments
        and the configuration values might also be [argparse.Namespace][] objects.

        Parameters:
          path: The path to save the configuration to; a `.json` extension is recommended.
          config: The configuration dictionary to save.
          **kwargs: Additional configuration values to save.
        """
        config = dict((k + " : argparse.Namespace", vars(v)) if isinstance(v, argparse.Namespace) else (k, v)
                      for k, v in {**config, **kwargs}.items())
        os.path.dirname(path) and os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as config_file:
            json.dump(config, config_file, ensure_ascii=False, indent=2)

    @staticmethod
    def load_config(path: str) -> dict:
        """Load a JSON-serializable configuration from the given path.

        Parameters:
          path: The path to load the configuration from.

        Returns:
          config: The loaded configuration dictionary.
        """
        with open(path, "r", encoding="utf-8-sig") as config_file:
            config = json.load(config_file)
        return dict((k.removesuffix(" : argparse.Namespace"), argparse.Namespace(**v))
                    if k.endswith(" : argparse.Namespace") else (k, v) for k, v in config.items())

    def log_metrics(
        self, logs: Logs, epochs: int | None = None, elapsed: float | None = None, console: int = console_default(1),
    ) -> Self:
        """Log the given dictionary to file logs, TensorBoard logs, and optionally the console.

        Parameters:
          logs: The dictionary of logs to write.
          epochs: An optional total number of epochs, used when logging the epoch number.
          elapsed: An optional time elapsed since the beginning of the current epoch.
          console: Controls the console verbosity: 0 for silent, 1 for epoch logs.
            The default is 1, but can be overridden by the `CONSOLE` environment variable.

        Returns:
          self
        """
        if self.logdir is not None:
            writers = {}
            for key, value in logs.items():
                writer, metric = key.split("_", maxsplit=1) if "_" in key else ("train", key)
                writers.setdefault(writer, self.get_tb_writer(writer)).add_scalar(metric, value, self.epoch)
            for writer in writers.values():
                writer.flush()
        for file in ([self.get_log_file()] if self.logdir is not None else []) + [sys.stdout] * bool(console):
            print(f"Epoch {self.epoch}" + (f"/{epochs}" if epochs is not None else ""),
                  *[f"{elapsed:.1f}s"] if elapsed is not None else [],
                  *[f"{k}={v:#.{0<abs(v)<2e-4 and '2e' or '4f'}}" for k, v in logs.items()], file=file, flush=True)
        return self

    def log_config(self, config: dict, sort_keys: bool = True, console: int = console_default(1)) -> Self:
        """Log the given dictionary to the file logs, TensorBoard logs, and optionally the console.

        Parameters:
          config: The dictionary of configuration to write.
          sort_keys: Whether to sort the keys of the configuration dictionary.
          console: Controls the console verbosity: 0 for silent, 1 for epoch logs.
            The default is 1, but can be overridden by the `CONSOLE` environment variable.

        Returns:
          self
        """
        if self.logdir is not None:
            config = dict(sorted(config.items())) if sort_keys else config
            writer = self.get_tb_writer("train")
            writer.add_text("config", json.dumps(config, ensure_ascii=False, indent=2), self.epoch)
            writer.flush()
        for file in ([self.get_log_file()] if self.logdir is not None else []) + [sys.stdout] * bool(console):
            print("Config", f"epoch={self.epoch}", *[f"{k}={v}" for k, v in config.items()], file=file, flush=True)
        return self

    def log_graph(self, data: torch.utils.data.DataLoader | TensorOrTensors, data_with_labels: bool = False) -> Self:
        """Log the traced module as a graph to the TensorBoard logs.

        Tracing requires an example batch; either the first batch from the
        dataloader passed in `data` is used, or the `data` itself is used.

        Parameters:
          data: The data to use for tracing the module, either a dataloader (in which case
            the first batch is used) or a single batch of inputs.
          data_with_labels: Specifies whether the dataloader elements
            are (input, labels) pairs or just inputs (the default).

        Returns:
          self
        """
        if self.logdir is not None:
            batch = next(iter(data)) if isinstance(data, torch.utils.data.DataLoader) else data
            xs = validate_batch_input(batch, with_labels=data_with_labels)
            xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
            writer = self.get_tb_writer("train")
            writer.add_graph(self, xs)
            writer.flush()
        return self

    def get_log_file(self) -> TextIO:
        """Possibly create and return a text-based log file for the current log.

        To use this method, nonempty `logdir` must have been set in `configure`.

        Returns:
          file: The opened log file.
        """
        assert self.logdir is not None, "Cannot use get_log_file when logdir is not set."
        if self._log_file is None:
            self._log_file = open(os.path.join(self.logdir, "logs.txt"), "a", encoding="utf-8")
        return self._log_file

    def get_tb_writer(self, name: str) -> torch.utils.tensorboard.SummaryWriter:
        """Possibly create and return a TensorBoard writer for the given name.

        To use this method, nonempty `logdir` must have been set in `configure`.

        Returns:
          writer: The opened TensorBoard writer.
        """
        assert self.logdir is not None, "Cannot use get_tb_writer when logdir is not set."
        if name not in self._tb_writers:
            self._tb_writers[name] = torch.utils.tensorboard.SummaryWriter(os.path.join(self.logdir, name))
        return self._tb_writers[name]

STOP_TRAINING class-attribute instance-attribute

STOP_TRAINING: Literal['stop_training'] = 'stop_training'

A constant returned by callbacks to stop the training.

__init__

__init__(module: Module | None = None)

Initialize the module, optionally with an existing PyTorch module.

Parameters:

  • module (Module | None, default: None ) –

    An optional existing PyTorch module to wrap, e.g., a torch.nn.Sequential or a pretrained Transformer. If given, the module still must be configured.

Source code in npfl138/trainable_module.py
def __init__(self, module: torch.nn.Module | None = None):
    """Initialize the module, optionally with an existing PyTorch module.

    Parameters:
      module: An optional existing PyTorch module to wrap, e.g., a [torch.nn.Sequential][]
        or a pretrained Transformer. If given, the module still must be configured.
    """
    super().__init__()
    self.device = None
    self.unconfigure()
    if module is not None:
        self.module = module
        self.forward = self._call_wrapped_module

configure

configure(
    *,
    optimizer: Optimizer | None | KeepPrevious = keep_previous,
    scheduler: LRScheduler | None | KeepPrevious = keep_previous,
    loss: LossProtocol | None | KeepPrevious = keep_previous,
    metrics: dict[str, MetricProtocol] | KeepPrevious = keep_previous,
    initial_epoch: int | KeepPrevious = keep_previous,
    logdir: str | None | KeepPrevious = keep_previous,
    device: device | str | Literal["auto"] | KeepPrevious = keep_previous
) -> Self

Configure the module for fitting, evaluation, and device placement.

The method can be called multiple times, preserving previously set values by default.

Note

When an input argument cannot be None, the corresponding field is never None after this call.

Parameters:

  • optimizer (Optimizer | None | KeepPrevious, default: keep_previous ) –

    The optimizer to use for training.

  • scheduler (LRScheduler | None | KeepPrevious, default: keep_previous ) –

    An optional learning rate scheduler used after every batch.

  • loss (LossProtocol | None | KeepPrevious, default: keep_previous ) –

    The loss function to minimize, implementing the LossProtocol.

  • metrics (dict[str, MetricProtocol] | KeepPrevious, default: keep_previous ) –

    A dictionary of additional metrics to compute, each being an object implementing the MetricProtocol (reset/update/compute), e.g., a torchmetrics.Metric.

  • initial_epoch (int | KeepPrevious, default: keep_previous ) –

    The initial epoch of the model used during training and evaluation.

  • logdir (str | None | KeepPrevious, default: keep_previous ) –

    An optional directory where textual and TensorBoard logs should be stored.

  • device (device | str | Literal['auto'] | KeepPrevious, default: keep_previous ) –

    The device to move the module to. When "auto", or keep_previous with no previously set device, the first of cuda/mps/xpu is used if available.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def configure(
    self,
    *,
    optimizer: torch.optim.Optimizer | None | KeepPrevious = keep_previous,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None | KeepPrevious = keep_previous,
    loss: LossProtocol | None | KeepPrevious = keep_previous,
    metrics: dict[str, MetricProtocol] | KeepPrevious = keep_previous,
    initial_epoch: int | KeepPrevious = keep_previous,
    logdir: str | None | KeepPrevious = keep_previous,
    device: torch.device | str | Literal["auto"] | KeepPrevious = keep_previous,
) -> Self:
    """Configure the module fitting, evaluation, and placement.

    The method can be called multiple times, preserving previously set values by default.

    Note:
      When an input argument cannot be `None`, the corresponding field is
      never `None` after this call.

    Parameters:
      optimizer: The optimizer to use for training.
      scheduler: An optional learning rate scheduler used after every batch.
      loss: The loss function to minimize, implementing the
        [LossProtocol][npfl138.trainable_module.LossProtocol].
      metrics: A dictionary of additional metrics to compute, each being an
        object implementing the [MetricProtocol][npfl138.trainable_module.MetricProtocol]
        (reset/update/compute), e.g., a `torchmetrics.Metric`.
      initial_epoch: The initial epoch of the model used during training and evaluation.
      logdir: An optional directory where textual and TensorBoard logs should be stored.
      device: The device to move the module to. When "auto", or `keep_previous`
        with no previously set device, the first of cuda/mps/xpu is used if available.

    Returns:
      self
    """
    self.optimizer = optimizer if optimizer is not keep_previous else self.optimizer
    self.scheduler = scheduler if scheduler is not keep_previous else self.scheduler
    self.loss = loss if loss is not keep_previous else self.loss
    self.loss_tracker = self.loss_tracker or LossTracker()
    if metrics is not keep_previous or not self.metrics:
        self.metrics = torch.nn.ModuleDict({} if metrics is keep_previous else metrics)
    self.epoch = initial_epoch if initial_epoch is not keep_previous else self.epoch or 0
    if logdir is not keep_previous and logdir != self.logdir:  # reset loggers on a new logdir
        self._log_file, self._tb_writers = None, {}
    self.logdir = logdir if logdir is not keep_previous else self.logdir
    if device is not keep_previous or not self.device:
        self.device = get_auto_device() if device == "auto" or device is keep_previous else torch.device(device)
    self.to(self.device)
    return self
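
Because every argument defaults to keep_previous, configure can be called repeatedly to change only part of the setup; a small sketch, assuming model is a TrainableModule and the concrete optimizer, loss, and logdir are illustrative:

model.configure(
    optimizer=torch.optim.AdamW(model.parameters()),
    loss=torch.nn.CrossEntropyLoss(),
)
# Later, change only the log directory; the optimizer and loss set above are
# preserved, and the loggers are reset because the logdir changed.
model.configure(logdir="logs/run2")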

unconfigure

unconfigure() -> Self

Remove all training configuration of the TrainableModule.

Only the module device is kept.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def unconfigure(self) -> Self:
    """Remove all training configuration of the TrainableModule.

    Only the module device is kept.

    Returns:
      self
    """
    self.optimizer, self.scheduler, self.epoch = None, None, None
    self.loss, self.loss_tracker, self.metrics = None, None, None
    self.logdir, self._log_file, self._tb_writers = None, None, None
    return self

fit

fit(
    dataloader: DataLoader,
    *,
    epochs: int,
    dev: DataLoader | None = None,
    callbacks: list[CallbackProtocol] = [],
    log_graph: bool = False,
    console: int = console_default(2)
) -> Logs

Train the model on the given dataset.

Parameters:

  • dataloader (DataLoader) –

    The training dataset, each element a pair of inputs and outputs; the inputs and outputs can be either single tensors or sequences of tensors.

  • epochs (int) –

    The number of epochs to train.

  • dev (DataLoader | None, default: None ) –

    An optional development dataset to evaluate after every epoch, with the same format as the training dataset.

  • callbacks (list[CallbackProtocol], default: [] ) –

    A list of callbacks to call after every epoch, each implementing the CallbackProtocol with arguments self, epoch, and logs, possibly returning TrainableModule.STOP_TRAINING to stop the training (note that the module is set to evaluation mode before calling each callback).

  • log_graph (bool, default: False ) –

    Controls whether to log the model graph to TensorBoard.

  • console (int, default: console_default(2) ) –

    Controls the console verbosity: 0 for silent, 1 for epoch logs, 2 for additional only-when-writing-to-console progress bar, 3 for persistent progress bar. The default is 2, but can be overridden by the CONSOLE environment variable.

Returns:

  • logs ( Logs ) –

    A dictionary of logs from the training and optionally dev evaluation.

Note

The module is set to evaluation mode when returning from this method.

Source code in npfl138/trainable_module.py
def fit(
    self,
    dataloader: torch.utils.data.DataLoader,
    *,
    epochs: int,
    dev: torch.utils.data.DataLoader | None = None,
    callbacks: list[CallbackProtocol] = [],
    log_graph: bool = False,
    console: int = console_default(2),
) -> Logs:
    """Train the model on the given dataset.

    Parameters:
      dataloader: The training dataset, each element a pair of inputs and outputs;
        the inputs and outputs can be either single tensors or sequences of tensors.
      epochs: The number of epochs to train.
      dev: An optional development dataset to evaluate after every epoch, with the
        same format as the training dataset.
      callbacks: A list of callbacks to call after every epoch, each implementing
        the [CallbackProtocol][npfl138.trainable_module.CallbackProtocol]
        with arguments `self`, `epoch`, and `logs`, possibly returning
        [TrainableModule.STOP_TRAINING][npfl138.TrainableModule.STOP_TRAINING] to stop the
        training (note that the module is set to evaluation mode before calling each callback).
      log_graph: Controls whether to log the model graph to TensorBoard.
      console: Controls the console verbosity: 0 for silent, 1 for epoch logs, 2 for
        additional only-when-writing-to-console progress bar, 3 for persistent progress bar.
        The default is 2, but can be overridden by the `CONSOLE` environment variable.

    Returns:
      logs: A dictionary of logs from the training and optionally dev evaluation.

    Note:
      The module is set to evaluation mode when returning from this method.
    """
    assert self.loss_tracker is not None, "The TrainableModule has not been configured, run configure first."
    logs, epochs, stop_training = {}, self.epoch + epochs, False
    while self.epoch < epochs and not stop_training:
        self.epoch += 1
        self.train()
        self.loss_tracker.reset()
        for metric in self.metrics.values():
            metric.reset()
        start = time.time()
        epoch_message = f"Epoch {self.epoch}/{epochs}"
        data_and_progress = tqdm.tqdm(
            dataloader, epoch_message, unit="batch", leave=False, disable=None if console == 2 else console < 2)
        for batch in data_and_progress:
            xs, y = validate_batch_input_output(batch)
            xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
            y = tuple(y_.to(self.device) for y_ in y) if is_sequence(y) else y.to(self.device)
            log_graph = log_graph and self.log_graph(xs) and False  # log the graph only once
            logs = self.train_step(xs, y)
            if not data_and_progress.disable:
                logs_message = " ".join([f"{k}={v:#.{0<abs(v)<2e-4 and '2e' or '4f'}}" for k, v in logs.items()])
                data_and_progress.set_description(f"{epoch_message} {logs_message}", refresh=False)
        logs = {f"train_{k}": v for k, v in logs.items()}
        if dev is not None:
            logs |= {f"dev_{k}": v for k, v in self.eval().evaluate(dev, log_as=None).items()}
        for callback in callbacks:
            stop_training = callback(self.eval(), self.epoch, logs) == self.STOP_TRAINING or stop_training
        self.log_metrics(logs, epochs, time.time() - start, console)
    self.eval()
    return logs
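
As an illustration, a simple early-stopping callback might look as follows; the monitored dev_loss key assumes a dev dataloader is passed to fit, and the patience logic is an assumption, not part of the library:

class EarlyStopping:
    """Stop training when the monitored log value stops improving."""
    def __init__(self, monitor="dev_loss", patience=3):
        self.monitor, self.patience = monitor, patience
        self.best, self.failures = float("inf"), 0

    def __call__(self, module, epoch, logs):
        # Called after every epoch; the module is already in evaluation mode.
        if logs[self.monitor] < self.best:
            self.best, self.failures = logs[self.monitor], 0
        else:
            self.failures += 1
            if self.failures >= self.patience:
                return module.STOP_TRAINING  # stops the training loop

model.fit(train_dataloader, epochs=100, dev=dev_dataloader, callbacks=[EarlyStopping()])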

train_step

train_step(xs: TensorOrTensors, y: TensorOrTensors) -> Logs

An overridable method performing a single training step, returning the logs.

Parameters:

  • xs (TensorOrTensors) –

    The input batch to the model, either a single tensor or a sequence of tensors.

  • y (TensorOrTensors) –

    The target output batch of the model, either a single tensor or a sequence of tensors.

Returns:

  • logs ( Logs ) –

    A dictionary of logs from the training step.

Source code in npfl138/trainable_module.py
def train_step(self, xs: TensorOrTensors, y: TensorOrTensors) -> Logs:
    """An overridable method performing a single training step, returning the logs.

    Parameters:
      xs: The input batch to the model, either a single tensor or a sequence of tensors.
      y: The target output batch of the model, either a single tensor or a sequence of tensors.

    Returns:
      logs: A dictionary of logs from the training step.
    """
    self.optimizer.zero_grad()
    y_pred = self(*xs)
    loss = self.compute_loss(y_pred, y, *xs)
    loss.backward()
    with torch.no_grad():
        self.optimizer.step()
        self.scheduler is not None and self.scheduler.step()
        return {"loss": self.loss_tracker(loss)} \
            | ({"lr": self.scheduler.get_last_lr()[0]} if self.scheduler else {}) \
            | self.compute_metrics(y_pred, y, *xs)
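
To customize the training step, override train_step in a subclass; for instance, the following sketch adds gradient clipping before the optimizer step (the max_norm value of 1.0 is illustrative):

import torch

from npfl138 import TrainableModule

class ClippedModule(TrainableModule):
    def train_step(self, xs, y):
        self.optimizer.zero_grad()
        y_pred = self(*xs)
        loss = self.compute_loss(y_pred, y, *xs)
        loss.backward()
        # Clip the gradient norm before updating the parameters.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        with torch.no_grad():
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()
            return {"loss": self.loss_tracker(loss)} | self.compute_metrics(y_pred, y, *xs)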

compute_loss

compute_loss(
    y_pred: TensorOrTensors, y: TensorOrTensors, *xs: tuple[Tensor]
) -> Tensor

Compute the loss of the model given the inputs, predictions, and target outputs.

Parameters:

  • y_pred (TensorOrTensors) –

    The model predictions, either a single tensor or a sequence of tensors.

  • y (TensorOrTensors) –

    The target output of the model, either a single tensor or a sequence of tensors.

  • *xs (tuple[Tensor], default: () ) –

    The inputs to the model, unpacked, if the input was a sequence of tensors.

Returns:

  • loss ( Tensor ) –

    The computed loss.

Source code in npfl138/trainable_module.py
def compute_loss(self, y_pred: TensorOrTensors, y: TensorOrTensors, *xs: tuple[Tensor]) -> torch.Tensor:
    """Compute the loss of the model given the inputs, predictions, and target outputs.

    Parameters:
      y_pred: The model predictions, either a single tensor or a sequence of tensors.
      y: The target output of the model, either a single tensor or a sequence of tensors.
      *xs: The inputs to the model, unpacked, if the input was a sequence of tensors.

    Returns:
      loss: The computed loss.
    """
    return self.loss(y_pred, y)

compute_metrics

compute_metrics(
    y_pred: TensorOrTensors, y: TensorOrTensors, *xs: TensorOrTensors
) -> Logs

Compute and return metrics given the inputs, predictions, and target outputs.

Parameters:

  • y_pred (TensorOrTensors) –

    The model predictions, either a single tensor or a sequence of tensors.

  • y (TensorOrTensors) –

    The target output of the model, either a single tensor or a sequence of tensors.

  • *xs (TensorOrTensors, default: () ) –

    The inputs to the model, unpacked, if the input was a sequence of tensors.

Returns:

  • logs ( Logs ) –

    A dictionary of computed metrics.

Source code in npfl138/trainable_module.py
def compute_metrics(self, y_pred: TensorOrTensors, y: TensorOrTensors, *xs: TensorOrTensors) -> Logs:
    """Compute and return metrics given the inputs, predictions, and target outputs.

    Parameters:
      y_pred: The model predictions, either a single tensor or a sequence of tensors.
      y: The target output of the model, either a single tensor or a sequence of tensors.
      *xs: The inputs to the model, unpacked, if the input was a sequence of tensors.

    Returns:
      logs: A dictionary of computed metrics.
    """
    for metric in self.metrics.values():
        metric.update(y_pred, y)
    return {name: metric.compute() for name, metric in self.metrics.items()}

evaluate

evaluate(
    dataloader: DataLoader,
    *,
    log_as: str | None = "test",
    callbacks: list[CallbackProtocol] = [],
    console: int = console_default(1)
) -> Logs

An evaluation of the model on the given dataset.

Parameters:

  • dataloader (DataLoader) –

    The dataset to evaluate on, each element a pair of inputs and outputs; the inputs and outputs can be either a single tensor or a sequence of tensors.

  • log_as (str | None, default: 'test' ) –

    The name of the dataset used in the logs; when None, no logs are written.

  • callbacks (list[CallbackProtocol], default: [] ) –

    A list of callbacks to call after the evaluation, each implementing the CallbackProtocol with arguments self, epoch, and logs.

  • console (int, default: console_default(1) ) –

    Controls the console verbosity: 0 for silent, 1 for a single message. The default is 1, but can be overridden by the CONSOLE environment variable.

Returns:

  • logs ( Logs ) –

    A dictionary of logs from the evaluation.

Source code in npfl138/trainable_module.py
def evaluate(
    self,
    dataloader: torch.utils.data.DataLoader,
    *,
    log_as: str | None = "test",
    callbacks: list[CallbackProtocol] = [],
    console: int = console_default(1),
) -> Logs:
    """An evaluation of the model on the given dataset.

    Parameters:
      dataloader: The dataset to evaluate on, each element a pair of inputs and outputs;
        the inputs and outputs can be either a single tensor or a sequence of tensors.
      log_as: The name of the dataset used in the logs; when `None`, no logs are written.
      callbacks: A list of callbacks to call after the evaluation, each implementing
        the [CallbackProtocol][npfl138.trainable_module.CallbackProtocol] with arguments
        `self`, `epoch`, and `logs`.
      console: Controls the console verbosity: 0 for silent, 1 for a single message.
        The default is 1, but can be overridden by the `CONSOLE` environment variable.

    Returns:
      logs: A dictionary of logs from the evaluation.
    """
    assert self.loss_tracker is not None, "The TrainableModule has not been configured, run configure first."
    self.eval()
    self.loss_tracker.reset()
    for metric in self.metrics.values():
        metric.reset()
    start = time.time()
    for batch in dataloader:
        xs, y = validate_batch_input_output(batch)
        xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
        y = tuple(y_.to(self.device) for y_ in y) if is_sequence(y) else y.to(self.device)
        logs = self.test_step(xs, y)
    if log_as is not None:
        logs = {f"{log_as}_{k}": v for k, v in logs.items()}
    for callback in callbacks:
        callback(self, self.epoch, logs)
    if log_as is not None:
        self.log_metrics(logs, elapsed=time.time() - start, console=console)
    return logs

test_step

test_step(xs: TensorOrTensors, y: TensorOrTensors) -> Logs

An overridable method performing a single evaluation step, returning the logs.

Parameters:

  • xs (TensorOrTensors) –

    The input batch to the model, either a single tensor or a sequence of tensors.

  • y (TensorOrTensors) –

    The target output batch of the model, either a single tensor or a sequence of tensors.

Returns:

  • logs ( Logs ) –

    A dictionary of logs from the evaluation step.

Source code in npfl138/trainable_module.py
def test_step(self, xs: TensorOrTensors, y: TensorOrTensors) -> Logs:
    """An overridable method performing a single evaluation step, returning the logs.

    Parameters:
      xs: The input batch to the model, either a single tensor or a sequence of tensors.
      y: The target output batch of the model, either a single tensor or a sequence of tensors.

    Returns:
      logs: A dictionary of logs from the evaluation step.
    """
    with torch.no_grad():
        y_pred = self(*xs)
        loss = self.compute_loss(y_pred, y, *xs)
        return {"loss": self.loss_tracker(loss)} | self.compute_metrics(y_pred, y, *xs)

predict

predict(
    dataloader: DataLoader,
    *,
    data_with_labels: bool = False,
    as_numpy: bool = True
) -> list[Tensor | tuple[Tensor, ...] | ndarray | tuple[ndarray, ...]]

Compute predictions for the given dataset.

Parameters:

  • dataloader (DataLoader) –

    The dataset to predict on, each element either directly the input or a pair whose first element is the input; the input can be either a single tensor or a sequence of tensors.

  • data_with_labels (bool, default: False ) –

    Specifies whether the dataloader elements are (input, labels) pairs or just inputs (the default).

  • as_numpy (bool, default: True ) –

    A flag controlling whether the output should be converted to a numpy array or kept as a PyTorch tensor.

Returns:

  • predictions ( list[Tensor | tuple[Tensor, ...] | ndarray | tuple[ndarray, ...]] ) –

    A Python list whose elements are predictions of the individual examples. Note that if the input was padded, so will be the predictions, which will then need to be trimmed.

Source code in npfl138/trainable_module.py
def predict(
    self,
    dataloader: torch.utils.data.DataLoader,
    *,
    data_with_labels: bool = False,
    as_numpy: bool = True,
) -> list[torch.Tensor | tuple[torch.Tensor, ...] | np.ndarray | tuple[np.ndarray, ...]]:
    """Compute predictions for the given dataset.

    Parameters:
      dataloader: The dataset to predict on, each element either directly the input
        or a pair whose first element is the input; the input can be either
        a single tensor or a sequence of tensors.
      data_with_labels: Specifies whether the dataloader elements
        are (input, labels) pairs or just inputs (the default).
      as_numpy: A flag controlling whether the output should be
        converted to a numpy array or kept as a PyTorch tensor.

    Returns:
      predictions: A Python list whose elements are predictions
        of the individual examples. Note that if the input was padded, so
        will be the predictions, which will then need to be trimmed.
    """
    assert self.device is not None, "No device has been set for the TrainableModule, run configure first."
    self.eval()
    predictions = []
    for batch in dataloader:
        xs = validate_batch_input(batch, with_labels=data_with_labels)
        xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
        y = self.predict_step(xs, as_numpy=as_numpy)
        predictions.extend(y if not isinstance(y, tuple) else zip(*y))
    return predictions
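
For example, for a classifier whose dataloader yields (input, label) pairs, predictions can be obtained as follows; model and test_dataloader are assumed from earlier, and the argmax post-processing is an assumption about the model output:

# data_with_labels=True because the dataloader also yields the labels.
outputs = model.predict(test_dataloader, data_with_labels=True)
# With as_numpy=True (the default), each element is a numpy array of logits.
predicted_classes = [output.argmax(-1) for output in outputs]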

predict_step

predict_step(
    xs: TensorOrTensors, as_numpy: bool = True
) -> Tensor | tuple[Tensor, ...] | ndarray | tuple[ndarray, ...]

An overridable method performing a single prediction step.

Parameters:

  • xs (TensorOrTensors) –

    The input batch to the model, either a single tensor or a sequence of tensors.

  • as_numpy (bool, default: True ) –

    A flag controlling whether the output should be converted to a numpy array.

Returns:

  • predictions ( Tensor | tuple[Tensor, ...] | ndarray | tuple[ndarray, ...] ) –

    The batch prediction.

Source code in npfl138/trainable_module.py
def predict_step(
    self, xs: TensorOrTensors, as_numpy: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, ...] | np.ndarray | tuple[np.ndarray, ...]:
    """An overridable method performing a single prediction step.

    Parameters:
      xs: The input batch to the model, either a single tensor or a sequence of tensors.
      as_numpy: A flag controlling whether the output should be converted to a numpy array.

    Returns:
      predictions: The batch prediction.
    """
    with torch.no_grad():
        y = self(*xs)
        return maybe_unpack(y, as_numpy) if not is_sequence(y) else tuple(maybe_unpack(y_, as_numpy) for y_ in y)

save_weights

save_weights(path: str, optimizer_path: str | None = None) -> Self

Save the model weights to the given path.

Parameters:

  • path (str) –

    The path to save the model weights to; a .pt extension is recommended.

  • optimizer_path (str | None, default: None ) –

    An optional path to save the optimizer state to, relative to the model weights path.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def save_weights(self, path: str, optimizer_path: str | None = None) -> Self:
    """Save the model weights to the given path.

    Parameters:
      path: The path to save the model weights to; a `.pt` extension is recommended.
      optimizer_path: An optional path to save the optimizer state to, relative to the
        model weights path.

    Returns:
      self
    """
    state_dict = self.state_dict()
    os.path.dirname(path) and os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state_dict, path)

    # Save the number of epochs, optimizer state, and the scheduler state when requested.
    if optimizer_path is not None:
        optimizer_state = {"epoch": self.epoch}
        self.optimizer is not None and optimizer_state.update(optimizer=self.optimizer.state_dict())
        self.scheduler is not None and optimizer_state.update(scheduler=self.scheduler.state_dict())
        optimizer_path = os.path.join(os.path.dirname(path), optimizer_path)
        os.path.dirname(optimizer_path) and os.makedirs(os.path.dirname(optimizer_path), exist_ok=True)
        torch.save(optimizer_state, optimizer_path)
    return self
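
A checkpointing round trip might look as follows; the paths, the build_network factory, and the training setup are hypothetical:

# Save the weights and, next to them, the epoch/optimizer/scheduler state.
model.save_weights("checkpoints/model.pt", optimizer_path="optimizer.pt")

# Later: recreate the module, configure it (so that the optimizer exists),
# and restore both the weights and the training state to resume training.
model = TrainableModule(build_network())
model.configure(optimizer=torch.optim.Adam(model.parameters()), loss=torch.nn.CrossEntropyLoss())
model.load_weights("checkpoints/model.pt", optimizer_path="optimizer.pt")
model.fit(train_dataloader, epochs=5)  # continues from the restored epoch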

load_weights

load_weights(
    path: str,
    optimizer_path: str | None = None,
    device: device | str | Literal["auto"] | KeepPrevious = keep_previous,
) -> Self

Load the model weights from the given path.

Parameters:

  • path (str) –

    The path to load the model weights from.

  • optimizer_path (str | None, default: None ) –

    An optional path to load the optimizer state from, relative to the model weights path.

  • device (device | str | Literal['auto'] | KeepPrevious, default: keep_previous ) –

    The device to load the model to; when "auto", or keep_previous with no previously set device, the first of cuda/mps/xpu is used if available.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def load_weights(self, path: str, optimizer_path: str | None = None,
                 device: torch.device | str | Literal["auto"] | KeepPrevious = keep_previous) -> Self:
    """Load the model weights from the given path.

    Parameters:
      path: The path to load the model weights from.
      optimizer_path: An optional path to load the optimizer state from, relative to the
        model weights path.
      device: The device to load the model to; when "auto", or `keep_previous` with no previously
        set device, the first of cuda/mps/xpu is used if available.

    Returns:
      self
    """
    if device is not keep_previous or not self.device:
        self.device = get_auto_device() if device == "auto" or device is keep_previous else torch.device(device)
    self.load_state_dict(torch.load(path, map_location=self.device))

    # Load the number of epochs, optimizer state, and the scheduler state when requested.
    if optimizer_path is not None:
        optimizer_path = os.path.join(os.path.dirname(path), optimizer_path)
        optimizer_state = torch.load(optimizer_path, map_location=self.device)
        self.epoch = optimizer_state["epoch"]
        if self.optimizer is not None:
            assert "optimizer" in optimizer_state, "The optimizer state is missing."
            self.optimizer.load_state_dict(optimizer_state["optimizer"])
        else:
            assert "optimizer" not in optimizer_state, "The optimizer state is present, but there is no optimizer."
        if self.scheduler is not None:
            assert "scheduler" in optimizer_state, "The scheduler state is missing."
            self.scheduler.load_state_dict(optimizer_state["scheduler"])
        else:
            assert "scheduler" not in optimizer_state, "The scheduler state is present, but there is no scheduler."
    self.to(self.device)
    return self

save_config staticmethod

save_config(path: str, config: dict = {}, /, **kwargs: dict) -> None

Save a JSON-serializable configuration to the given path.

The configuration can be given as a dictionary or as keyword arguments and the configuration values might also be argparse.Namespace objects.

Parameters:

  • path (str) –

    The path to save the configuration to; a .json extension is recommended.

  • config (dict, default: {} ) –

    The configuration dictionary to save.

  • **kwargs (dict, default: {} ) –

    Additional configuration values to save.

Source code in npfl138/trainable_module.py
@staticmethod
def save_config(path: str, config: dict = {}, /, **kwargs: dict) -> None:
    """Save a JSON-serializable configuration to the given path.

    The configuration can be given as a dictionary or as keyword arguments
    and the configuration values might also be [argparse.Namespace][] objects.

    Parameters:
      path: The path to save the configuration to; a `.json` extension is recommended.
      config: The configuration dictionary to save.
      **kwargs: Additional configuration values to save.
    """
    # Store argparse.Namespace values as plain dicts, marking their keys
    # with a suffix so that load_config can reconstruct them.
    config = dict((k + " : argparse.Namespace", vars(v)) if isinstance(v, argparse.Namespace) else (k, v)
                  for k, v in {**config, **kwargs}.items())
    os.path.dirname(path) and os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, ensure_ascii=False, indent=2)
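
For instance, parsed command-line arguments can be stored alongside plain values; the path below is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
args = parser.parse_args([])

# Namespace values are serialized transparently and can later be
# reconstructed by load_config.
TrainableModule.save_config("logs/options.json", args=args, seed=42)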

load_config staticmethod

load_config(path: str) -> dict

Load a JSON-serializable configuration from the given path.

Parameters:

  • path (str) –

    The path to load the configuration from.

Returns:

  • config ( dict ) –

    The loaded configuration dictionary.

Source code in npfl138/trainable_module.py
@staticmethod
def load_config(path: str) -> dict:
    """Load a JSON-serializable configuration from the given path.

    Parameters:
      path: The path to load the configuration from.

    Returns:
      config: The loaded configuration dictionary.
    """
    with open(path, "r", encoding="utf-8-sig") as config_file:
        config = json.load(config_file)
    # Reconstruct argparse.Namespace values from the marked keys.
    return dict((k.removesuffix(" : argparse.Namespace"), argparse.Namespace(**v))
                if k.endswith(" : argparse.Namespace") else (k, v) for k, v in config.items())
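
Loading the file written in the save_config example above restores the Namespace values:

config = TrainableModule.load_config("logs/options.json")
args = config["args"]                 # an argparse.Namespace again
assert args.batch_size == 32 and config["seed"] == 42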

log_metrics

log_metrics(
    logs: Logs,
    epochs: int | None = None,
    elapsed: float | None = None,
    console: int = console_default(1),
) -> Self

Log the given dictionary to file logs, TensorBoard logs, and optionally the console.

Parameters:

  • logs (Logs) –

    The dictionary of logs to write.

  • epochs (int | None, default: None ) –

    An optional total number of epochs, used when logging the epoch number.

  • elapsed (float | None, default: None ) –

    An optional time elapsed since the beginning of the current epoch.

  • console (int, default: console_default(1) ) –

    Controls the console verbosity: 0 for silent, 1 for epoch logs. The default is 1, but it can be overridden by the CONSOLE environment variable.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def log_metrics(
    self, logs: Logs, epochs: int | None = None, elapsed: float | None = None, console: int = console_default(1),
) -> Self:
    """Log the given dictionary to file logs, TensorBoard logs, and optionally the console.

    Parameters:
      logs: The dictionary of logs to write.
      epochs: An optional total number of epochs, used when logging the epoch number.
      elapsed: An optional time elapsed since the beginning of the current epoch.
      console: Controls the console verbosity: 0 for silent, 1 for epoch logs.
        The default is 1, but it can be overridden by the `CONSOLE` environment variable.

    Returns:
      self
    """
    if self.logdir is not None:
        writers = {}
        for key, value in logs.items():
            writer, metric = key.split("_", maxsplit=1) if "_" in key else ("train", key)
            writers.setdefault(writer, self.get_tb_writer(writer)).add_scalar(metric, value, self.epoch)
        for writer in writers.values():
            writer.flush()
    # Additionally, print the logs to the log file (when logdir is set) and to
    # the console; nonzero values smaller in magnitude than 2e-4 are printed in
    # scientific notation, all others with 4 decimal places.
    for file in ([self.get_log_file()] if self.logdir is not None else []) + [sys.stdout] * bool(console):
        print(f"Epoch {self.epoch}" + (f"/{epochs}" if epochs is not None else ""),
              *[f"{elapsed:.1f}s"] if elapsed is not None else [],
              *[f"{k}={v:#.{0<abs(v)<2e-4 and '2e' or '4f'}}" for k, v in logs.items()], file=file, flush=True)
    return self
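
Note that the part of a log name before the first underscore selects the TensorBoard writer ("train" when there is no underscore). A call with made-up values:

# "dev_accuracy" is written to the "dev" writer under the name "accuracy",
# while "loss" (no underscore) goes to the default "train" writer.
model.log_metrics({"loss": 0.32, "dev_accuracy": 0.91}, epochs=10, elapsed=12.5)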

log_config

log_config(
    config: dict, sort_keys: bool = True, console: int = console_default(1)
) -> Self

Log the given dictionary to the file logs, TensorBoard logs, and optionally the console.

Parameters:

  • config (dict) –

    The dictionary of configuration to write.

  • sort_keys (bool, default: True ) –

    Whether to sort the keys of the configuration dictionary.

  • console (int, default: console_default(1) ) –

    Controls the console verbosity: 0 for silent, 1 for epoch logs. The default is 1, but it can be overridden by the CONSOLE environment variable.

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def log_config(self, config: dict, sort_keys: bool = True, console: int = console_default(1)) -> Self:
    """Log the given dictionary to the file logs, TensorBoard logs, and optionally the console.

    Parameters:
      config: The dictionary of configuration to write.
      sort_keys: Whether to sort the keys of the configuration dictionary.
      console: Controls the console verbosity: 0 for silent, 1 for epoch logs.
        The default is 1, but it can be overridden by the `CONSOLE` environment variable.

    Returns:
      self
    """
    if self.logdir is not None:
        config = dict(sorted(config.items())) if sort_keys else config
        writer = self.get_tb_writer("train")
        writer.add_text("config", json.dumps(config, ensure_ascii=False, indent=2), self.epoch)
        writer.flush()
    for file in ([self.get_log_file()] if self.logdir is not None else []) + [sys.stdout] * bool(console):
        print("Config", f"epoch={self.epoch}", *[f"{k}={v}" for k, v in config.items()], file=file, flush=True)
    return self
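
A typical use is to record the hyperparameters once before training starts; the keys and values below are illustrative, and the module is assumed to be already configured:

model.log_config({"batch_size": 32, "epochs": 10, "lr": 1e-3})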

log_graph

log_graph(
    data: DataLoader | TensorOrTensors, data_with_labels: bool = False
) -> Self

Log the traced module as a graph to the TensorBoard logs.

Tracing requires an example batch; either the first batch from the dataloader passed in data is used, or the data itself is used.

Parameters:

  • data (DataLoader | TensorOrTensors) –

    The data to use for tracing the module, either a dataloader (in which case the first batch is used) or a single batch of inputs.

  • data_with_labels (bool, default: False ) –

    Specifies whether the dataloader elements are (input, labels) pairs or just inputs (the default).

Returns:

  • Self

    self

Source code in npfl138/trainable_module.py
def log_graph(self, data: torch.utils.data.DataLoader | TensorOrTensors, data_with_labels: bool = False) -> Self:
    """Log the traced module as a graph to the TensorBoard logs.

    Tracing requires an example batch; either the first batch from the
    dataloader passed in `data` is used, or the `data` itself is used.

    Parameters:
      data: The data to use for tracing the module, either a dataloader (in which case
        the first batch is used) or a single batch of inputs.
      data_with_labels: Specifies whether the dataloader elements
        are (input, labels) pairs or just inputs (the default).

    Returns:
      self
    """
    if self.logdir is not None:
        batch = next(iter(data)) if isinstance(data, torch.utils.data.DataLoader) else data
        xs = validate_batch_input(batch, with_labels=data_with_labels)
        xs = tuple(x.to(self.device) for x in (xs if is_sequence(xs) else (xs,)))
        writer = self.get_tb_writer("train")
        writer.add_graph(self, xs)
        writer.flush()
    return self
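
For example, to trace the module on the first batch of a training dataloader whose elements are (input, labels) pairs (the dataloader name is hypothetical, and a logdir must have been configured):

model.log_graph(train_dataloader, data_with_labels=True)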

get_log_file

get_log_file() -> TextIO

Create (on first use) and return the text-based log file for the current logdir.

To use this method, nonempty logdir must have been set in configure.

Returns:

  • file ( TextIO ) –

    The opened log file.

Source code in npfl138/trainable_module.py
def get_log_file(self) -> TextIO:
    """Possibly create and return a text-based log file for the current log.

    To use this method, nonempty `logdir` must have been set in `configure`.

    Returns:
      file: The opened log file.
    """
    assert self.logdir is not None, "Cannot use get_log_file when logdir is not set."
    if self._log_file is None:
        self._log_file = open(os.path.join(self.logdir, "logs.txt"), "a", encoding="utf-8")
    return self._log_file
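
The returned file can also be used for custom notes, which then appear in the same logs.txt as the automatic metric logs:

print("Training finished.", file=model.get_log_file(), flush=True)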

get_tb_writer

get_tb_writer(name: str) -> SummaryWriter

Create (on first use) and return the TensorBoard writer for the given name.

To use this method, nonempty logdir must have been set in configure.

Returns:

  • writer ( SummaryWriter ) –

    The opened TensorBoard writer.

Source code in npfl138/trainable_module.py
def get_tb_writer(self, name: str) -> torch.utils.tensorboard.SummaryWriter:
    """Possibly create and return a TensorBoard writer for the given name.

    To use this method, nonempty `logdir` must have been set in `configure`.

    Returns:
      writer: The opened TensorBoard writer.
    """
    assert self.logdir is not None, "Cannot use get_tb_writer when logdir is not set."
    if name not in self._tb_writers:
        self._tb_writers[name] = torch.utils.tensorboard.SummaryWriter(os.path.join(self.logdir, name))
    return self._tb_writers[name]
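
The returned writer can be used to store additional data next to the automatic TensorBoard logs; the tag below is illustrative:

writer = model.get_tb_writer("train")
writer.add_histogram("first_parameter", next(model.parameters()), model.epoch)
writer.flush()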

npfl138.trainable_module.CallbackProtocol

Bases: Protocol

__call__

__call__(
    module: TrainableModule, epoch: int, logs: Logs
) -> Literal["stop_training"] | None

Represents a callback called after every training epoch.

If the callback returns TrainableModule.STOP_TRAINING, the training stops.

Parameters:

  • module (TrainableModule) –

    the module being trained

  • epoch (int) –

    the current epoch number (one-based)

  • logs (Logs) –

    a dictionary of logs; newly computed metrics or losses should be added here

Returns:

  • Literal['stop_training'] | None –

    TrainableModule.STOP_TRAINING when the training should stop, or None to continue.
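
A minimal sketch of a conforming callback; the "dev_accuracy" key depends on the configured metrics, and passing the callback to fit via a callbacks argument is an assumption about the training API:

def stop_on_accuracy(module, epoch, logs):
    # Stop training once development accuracy reaches 99%.
    if logs.get("dev_accuracy", 0.0) >= 0.99:
        return module.STOP_TRAINING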

npfl138.trainable_module.LossProtocol

Bases: Protocol

__call__

__call__(y_pred: TensorOrTensors, y: TensorOrTensors) -> Tensor

Compute the loss of the given predictions and gold outputs.
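
Any callable with this signature satisfies the protocol, e.g., a loss module from torch.nn or a plain function, as in this sketch:

import torch

def mean_squared_error(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A plain function is enough; torch.nn.MSELoss() would work equally well.
    return torch.mean((y_pred - y) ** 2)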

npfl138.trainable_module.MetricProtocol

Bases: Protocol

reset

reset() -> None

Reset the metric to its initial state.

update

update(y_pred: TensorOrTensors, y: TensorOrTensors) -> None

Update the metric with the given predictions and gold outputs.

compute

compute() -> Tensor

Return the current value of the metric.
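
A minimal hand-written metric implementing the protocol, as a sketch:

import torch

class MeanAbsoluteError:
    # Conforms to MetricProtocol via reset/update/compute.
    def reset(self) -> None:
        self._total, self._count = 0.0, 0

    def update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        self._total += (y_pred - y).abs().sum().item()
        self._count += y.numel()

    def compute(self) -> torch.Tensor:
        return torch.tensor(self._total / max(self._count, 1))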

npfl138.trainable_module.Logs module-attribute

Logs: TypeAlias = dict[str, float]

A dictionary of logs, with keys being the log names and values being the log values.

npfl138.trainable_module.Tensor module-attribute

Tensor: TypeAlias = Tensor | PackedSequence

A type alias for a single tensor or a packed sequence of tensors.

npfl138.trainable_module.TensorOrTensors module-attribute

TensorOrTensors: TypeAlias = Tensor | tuple[Tensor, ...] | list[Tensor]

A type alias for a single tensor/packed sequence, or a sequence of them.