trainer module

training.trainer.train_diresa(model: diresa_torch.arch.models.Diresa, train_loader: torch.utils.data.DataLoader, criteria: list, optimizer: torch.optim.Optimizer, num_epochs: int = 10, val_loader: torch.utils.data.DataLoader | None = None, loss_weights: list = [1.0, 1.0, 1.0], staged_training: bool = False, train_twin_loader: torch.utils.data.DataLoader | None = None, val_twin_loader: torch.utils.data.DataLoader | None = None, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>, callbacks: list | None = None) → dict[str, dict[str, list[float]]]

Trains the model. Multiple loss functions must be provided in order to train the different parts of the model: CovarianceLoss and DistanceLoss are used to produce an interpretable latent space, while ReconstructionLoss is used to produce a reconstructed output. Input for the twin encoder can be supplied by a separate twin dataloader, which allows shuffling over the whole dataset. In that case the twin dataloader must have the same batch size and number of batches as the input dataloader. If the twin dataloader is None, the twin input is produced by shuffling within each batch.
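The two ways of producing twin input described above can be sketched in plain Python. This is an illustration of the pairing logic only, not the library's implementation: the helper name pair_with_twin, the use of plain lists in place of tensors and DataLoaders, and the seed argument are all assumptions made for this example.

```python
import random
from typing import List, Optional, Sequence, Tuple

Batch = List[float]


def pair_with_twin(
    batches: Sequence[Batch],
    twin_batches: Optional[Sequence[Batch]] = None,
    seed: int = 0,
) -> List[Tuple[Batch, Batch]]:
    """Pair every input batch with a twin batch (hypothetical helper).

    Without twin_batches, the twin is a shuffled copy of the same batch.
    With twin_batches, the number of batches and each batch size must match.
    """
    rng = random.Random(seed)
    if twin_batches is None:
        pairs = []
        for batch in batches:
            twin = list(batch)
            rng.shuffle(twin)  # shuffling stays inside the batch
            pairs.append((list(batch), twin))
        return pairs

    if len(batches) != len(twin_batches):
        raise ValueError("twin loader must have the same number of batches")
    pairs = []
    for batch, twin in zip(batches, twin_batches):
        if len(batch) != len(twin):
            raise ValueError("twin loader must have the same batch size")
        pairs.append((list(batch), list(twin)))
    return pairs


batches = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
# Per-batch shuffling: each twin holds the same samples as its batch.
per_batch = pair_with_twin(batches)
# Separate twin source: twins may come from anywhere in the dataset.
whole_set = pair_with_twin(batches, [[6.0, 1.0, 4.0], [2.0, 5.0, 3.0]])
```

With per-batch shuffling a sample can only be paired with samples from its own batch; a separate twin loader removes that restriction, at the cost of the size constraints checked above.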

Parameters:
  • model – The model to train

  • train_loader – Training data loader

  • criteria – List of loss functions, in the order [ReconstructionLoss, CovarianceLoss, DistanceLoss]

  • optimizer – Optimizer

  • num_epochs – Number of epochs

  • val_loader – Optional validation loader

  • loss_weights – Weighting factors for the losses, in the order [ReconstructionLoss, CovarianceLoss, DistanceLoss]

  • staged_training – If True, the encoder and the decoder are trained separately, for num_epochs each

  • train_twin_loader – Twin training data loader (also requires val_twin_loader); if None, shuffling is done per batch

  • val_twin_loader – Twin validation data loader; if None, shuffling is done per batch

  • input_filter – Function used to filter input data, default is no filtering

  • target_filter – Function used to filter target data, default is no filtering

  • callbacks – Optional list of callback functions (not implemented yet)

Returns:

Dictionary with the training history (losses, metrics) and, if val_loader is provided, the validation history.
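The return annotation dict[str, dict[str, list[float]]] suggests one list of per-epoch values for each named quantity. A minimal sketch of reading such a structure follows; the key names "train", "val" and "total", and the numbers themselves, are hypothetical placeholders, since the actual keys are not listed here.

```python
from typing import Dict, List, Tuple

History = Dict[str, Dict[str, List[float]]]


def best_epoch(history: History, split: str, name: str) -> Tuple[int, float]:
    """Return (0-based epoch index, value) of the lowest recorded value."""
    values = history[split][name]
    if not values:
        raise ValueError("no values recorded for this entry")
    epoch = min(range(len(values)), key=values.__getitem__)
    return epoch, values[epoch]


# Hypothetical history; key names and numbers are made up for illustration.
example_history: History = {
    "train": {"total": [0.9, 0.5, 0.4]},
    "val": {"total": [1.0, 0.6, 0.7]},
}
best = best_epoch(example_history, "val", "total")
```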

training.trainer.evaluate_diresa(model: diresa_torch.arch.models.Diresa, test_loader: torch.utils.data.DataLoader, criteria: list, loss_weights: list = [1.0, 1.0, 1.0], test_twin_loader: torch.utils.data.DataLoader | None = None, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>) → Dict[str, float]

Evaluates the model using test_loader and, optionally, test_twin_loader, which allows shuffling over the whole dataset. If given, test_twin_loader must have the same batch size and number of batches as test_loader. If test_twin_loader is None, the twin input is produced by shuffling within each batch of test_loader.

Parameters:
  • model – The model to evaluate

  • test_loader – Test data loader

  • criteria – List of loss functions, in the order [ReconstructionLoss, CovarianceLoss, DistanceLoss]

  • loss_weights – Weighting factors for the losses, in the same order as criteria

  • test_twin_loader – Twin test data loader; if None, shuffling is done per batch

  • input_filter – Function used to filter input data, default is no filtering

  • target_filter – Function used to filter target data, default is no filtering

Returns:

Dictionary with the averaged losses: each individual criterion loss plus the weighted total loss
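The relation between the individual losses and the weighted total can be sketched as follows. This is a plain-Python illustration under stated assumptions: the key names in LOSS_NAMES and "total" are hypothetical, and batches are averaged with equal weight, which may differ from how the library aggregates.

```python
from typing import Dict, Sequence

# Hypothetical key names, in the order [Reconstruction, Covariance, Distance].
LOSS_NAMES = ("reconstruction", "covariance", "distance")


def average_losses(
    batch_losses: Sequence[Sequence[float]],
    loss_weights: Sequence[float] = (1.0, 1.0, 1.0),
) -> Dict[str, float]:
    """Average per-batch loss triples and add their weighted sum."""
    if not batch_losses:
        raise ValueError("at least one batch is required")
    n_batches = len(batch_losses)
    # Mean of each criterion over all batches (equal weight per batch).
    means = [
        sum(batch[i] for batch in batch_losses) / n_batches
        for i in range(len(LOSS_NAMES))
    ]
    result = dict(zip(LOSS_NAMES, means))
    # Weighted total: sum of weight * averaged loss, matched by position.
    result["total"] = sum(w * m for w, m in zip(loss_weights, means))
    return result


# Two batches, each with [reconstruction, covariance, distance] values.
summary = average_losses([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]], [1.0, 0.5, 2.0])
```

Because weights and losses are matched by position, loss_weights has to follow the same order as criteria.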

training.trainer.predict_diresa(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, input_filter: callable = <function <lambda>>) → torch.Tensor

Returns the reconstruction of the dataset in data_loader, obtained by passing it through model. Inference is faster than a full forward pass because the distance and covariance outputs are not computed.

Parameters:
  • model – Model used to produce the prediction

  • data_loader – Data to be reconstructed

  • input_filter – Function used to filter input data, default is no filtering

Returns:

Tensor with the prediction (the reconstructed data)

training.trainer.order_diresa(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, cumul=False, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>) → list

Sets the ordering of the OrderingLayer. Limitation: assumes a flat latent space (the latent tensor has rank 2). If cumul is True, the next component is selected iteratively, based on the explanatory power it adds in combination with the previously selected components. If cumul is False, the components are sorted by the R² of each latent dimension separately.

Parameters:
  • model – The model on which to produce the ordering

  • data_loader – The data_loader from which to produce the ordering

  • cumul – If True, components are selected iteratively by cumulative R²; if False, they are sorted by single-dimension R² only

  • input_filter – Function used to filter input data, default is no filtering

  • target_filter – Function used to filter target data, default is no filtering

Returns:

List of the (cumulative) R² scores of the latent components
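The difference between the two ordering modes can be sketched with a generic scoring function. This is an illustration of the selection logic only, not the library's code: order_components, the r2_score callback and the toy score below are hypothetical, and how R² is actually computed from the model is outside this sketch.

```python
from typing import Callable, List, Sequence, Tuple


def order_components(
    n_components: int,
    r2_score: Callable[[Sequence[int]], float],
    cumul: bool = False,
) -> Tuple[List[int], List[float]]:
    """Order latent components by a score of component subsets.

    r2_score(subset) is assumed to return the R² obtained when only the
    components in subset are used.
    """
    if not cumul:
        # Score every component on its own and sort in descending order.
        scores = [r2_score([c]) for c in range(n_components)]
        order = sorted(range(n_components), key=lambda c: scores[c], reverse=True)
        return order, [scores[c] for c in order]

    # Greedy forward selection: add the component that yields the best
    # combined score together with the components already selected.
    selected: List[int] = []
    cumulative: List[float] = []
    remaining = list(range(n_components))
    while remaining:
        best = max(remaining, key=lambda c: r2_score(selected + [c]))
        selected.append(best)
        remaining.remove(best)
        cumulative.append(r2_score(selected))
    return selected, cumulative


# Toy score: each component "explains" a set of 4 features; components 0
# and 1 overlap, so component 1 adds nothing once component 0 is selected.
EXPLAINS = {0: {"a", "b", "c"}, 1: {"a", "b"}, 2: {"d"}}


def toy_r2(subset: Sequence[int]) -> float:
    covered = set().union(*(EXPLAINS[c] for c in subset))
    return len(covered) / 4


single_order, single_scores = order_components(3, toy_r2, cumul=False)
cumul_order, cumul_scores = order_components(3, toy_r2, cumul=True)
```

In this toy case the separate ranking places component 1 second, while the cumulative selection places component 2 second, because component 1 is redundant with component 0.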