trainer module
- training.trainer.train_diresa(model: diresa_torch.arch.models.Diresa, train_loader: torch.utils.data.DataLoader, criteria: list, optimizer: torch.optim.Optimizer, num_epochs: int = 10, val_loader: torch.utils.data.DataLoader | None = None, loss_weights: list = [1.0, 1.0, 1.0], staged_training: bool = False, train_twin_loader: torch.utils.data.DataLoader | None = None, val_twin_loader: torch.utils.data.DataLoader | None = None, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>, callbacks: list | None = []) dict[str, dict[str, list[float]]]
Trains the model. Multiple loss functions must be provided in order to train the different parts of the model: CovarianceLoss and DistanceLoss are used to produce an interpretable latent space, while ReconstructionLoss is used to produce the reconstructed output. Input for the twin encoder can be supplied by a separate twin dataloader, which supports shuffling over the whole dataset. In this case the twin dataloader must have the same batch size and number of batches as the input dataloader. If the twin dataloader is None, input for the twin is produced by shuffling the batch.
- Parameters:
model – The model to train
train_loader – Training data loader
criteria – List of loss functions, in the order [ReconstructionLoss, CovarianceLoss, DistanceLoss]
optimizer – Optimizer
num_epochs – Number of epochs
val_loader – Optional validation loader
loss_weights – Weighting factors for the different losses, in the order [ReconstructionLoss, CovarianceLoss, DistanceLoss]
staged_training – If True, trains the encoder and the decoder separately, for num_epochs each
train_twin_loader – Twin training data loader (also requires val_twin_loader); if None, shuffling is done per batch
val_twin_loader – Twin validation data loader; if None, shuffling is done per batch
input_filter – Function used to filter input data, default is no filtering
target_filter – Function used to filter target data, default is no filtering
callbacks – Optional list of callback objects
- Returns:
Dict with training results (losses, metrics) and, if val_loader is provided, validation results.
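The twin-input rules described for train_diresa can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: `FakeLoader`, `twin_loader_is_compatible` and `shuffle_within_batch` are hypothetical names introduced here, and the only assumption about a real `torch.utils.data.DataLoader` is that it exposes `batch_size` and `len()`.

```python
import random


class FakeLoader:
    """Hypothetical stand-in exposing the two DataLoader features used below."""

    def __init__(self, num_batches, batch_size):
        self._num_batches = num_batches
        self.batch_size = batch_size

    def __len__(self):
        # For a DataLoader, len() is the number of batches per epoch.
        return self._num_batches


def twin_loader_is_compatible(loader, twin_loader):
    """Check the documented constraint: same batch size and number of batches."""
    return (loader.batch_size == twin_loader.batch_size
            and len(loader) == len(twin_loader))


def shuffle_within_batch(batch, rng):
    """Fallback when no twin loader is given: permute the samples of one batch.

    Returns a new list; the original batch is left untouched.
    """
    indices = list(range(len(batch)))
    rng.shuffle(indices)
    return [batch[i] for i in indices]
```

With a separate twin loader the twin sample can come from anywhere in the dataset, whereas the per-batch fallback only pairs samples that already share a batch.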
- training.trainer.evaluate_diresa(model: diresa_torch.arch.models.Diresa, test_loader: torch.utils.data.DataLoader, criteria: list, loss_weights: list = [1.0, 1.0, 1.0], test_twin_loader: torch.utils.data.DataLoader | None = None, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>) Dict[str, float]
Evaluates the model using test_loader and, optionally, test_twin_loader, which supports shuffling over the whole dataset. If provided, test_twin_loader must have the same batch size and number of batches as test_loader. If test_twin_loader is None, input for the twin is produced by shuffling each batch of test_loader.
- Parameters:
model – The model to evaluate
test_loader – Test data loader
criteria – List of loss functions [ReconstructionLoss, CovarianceLoss, DistanceLoss]
loss_weights – Weighting factors for the different losses, in the same order as criteria
test_twin_loader – Twin test data loader; if None, shuffling is done per batch
input_filter – Function used to filter input data, default is no filtering
target_filter – Function used to filter target data, default is no filtering
- Returns:
Dictionary with averaged losses: individual criterion loss + weighted total loss
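How the averaged individual losses and the weighted total relate can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: the dictionary keys in `LOSS_NAMES`, the helper names, and the choice to average per batch (rather than per sample) are all hypothetical.

```python
# Hypothetical keys; the real dictionary keys are not given in this reference.
LOSS_NAMES = ["reconstruction", "covariance", "distance"]


def weighted_total(losses, loss_weights=(1.0, 1.0, 1.0)):
    """Combine one batch's losses using weights in the same order as criteria."""
    if len(losses) != len(loss_weights):
        raise ValueError("losses and loss_weights must have the same length")
    return sum(w * value for w, value in zip(loss_weights, losses))


def average_losses(per_batch_losses, loss_weights=(1.0, 1.0, 1.0)):
    """Average each criterion and the weighted total over all batches.

    per_batch_losses holds one [reconstruction, covariance, distance]
    triple per batch (assumption: simple per-batch averaging).
    """
    n = len(per_batch_losses)
    averaged = {
        name: sum(batch[i] for batch in per_batch_losses) / n
        for i, name in enumerate(LOSS_NAMES)
    }
    averaged["total"] = sum(
        weighted_total(batch, loss_weights) for batch in per_batch_losses
    ) / n
    return averaged
```

Setting a weight to 0.0 removes that criterion from the total while its individual average is still reported.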
- training.trainer.predict_diresa(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, input_filter: callable = <function <lambda>>) torch.Tensor
Returns the reconstruction of the dataset in data_loader after passing it through model. Inference is faster because distance and covariance are not computed.
- Parameters:
model – The model used to produce the prediction
data_loader – The data to be reconstructed
input_filter – Function used to filter input data, default is no filtering
- Returns:
prediction
- training.trainer.order_diresa(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, cumul=False, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>) list
Sets the ordering of the OrderingLayer. Limitation: assumes a flat latent space (the latent has rank 2). If cumul is True, the next component is selected iteratively, based on the additional explanatory power it contributes in combination with the previously selected components. If cumul is False, components are sorted only by the R² of each latent dimension separately.
- Parameters:
model – The model on which to produce the ordering
data_loader – The data_loader from which to produce the ordering
cumul – If True, selects components iteratively by combined R²; if False, sorts by single-dimension R² only
input_filter – Function used to filter input data, default is no filtering
target_filter – Function used to filter target data, default is no filtering
- Returns:
(cumulative) R² scores of latent components
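The difference between the two cumul modes can be shown with a small plain-Python sketch of greedy forward selection. It is a sketch under assumptions, not the library's algorithm: `order_components` and the `score` callback (a stand-in for an R² computed on a subset of latent components) are hypothetical, and returning both the order and the scores is a choice made here for readability.

```python
def order_components(num_components, score, cumul=False):
    """Order latent components by explanatory power.

    score(subset) is a hypothetical callback returning an R²-like value
    for a tuple of component indices. Returns (order, scores).
    """
    if not cumul:
        # Independent mode: rank each component by its own score.
        singles = [(score((i,)), i) for i in range(num_components)]
        singles.sort(key=lambda pair: pair[0], reverse=True)
        return [i for _, i in singles], [s for s, _ in singles]

    # Cumulative mode: repeatedly add the component that gives the best
    # combined score together with the components already selected.
    selected, scores = [], []
    remaining = set(range(num_components))
    while remaining:
        best = max(sorted(remaining),
                   key=lambda i: score(tuple(selected + [i])))
        selected.append(best)
        scores.append(score(tuple(selected)))
        remaining.remove(best)
    return selected, scores


# Toy score: each component "explains" a set of features out of 10.
# Component 1 is almost redundant with component 0.
EXPLAINED = {0: set("abcdef"), 1: set("abcde"), 2: set("ghi")}


def toy_score(subset):
    covered = set().union(*(EXPLAINED[i] for i in subset))
    return len(covered) / 10
```

In the toy data, component 1 ranks second on its own but adds nothing once component 0 is selected, so the cumulative mode places component 2 ahead of it.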