callback module
- class training.callback.Callback
Base class for building training callbacks. An instance can be passed as a parameter to the train_diresa function.
- __init__()
- on_train_start(d)
Executed before the start of a non-staged training.
- Parameters:
d – struct, currently containing only the ‘optimizer’ key (used in LR scheduler callbacks)
- on_train_encoder_start(d)
Executed before the start of the encoder training in case of a staged training.
- Parameters:
d – struct, currently containing only the ‘optimizer’ key (used in LR scheduler callbacks)
- on_train_decoder_start(d)
Executed before the start of the decoder training in case of a staged training.
- Parameters:
d – struct, currently containing only the ‘optimizer’ key (used in LR scheduler callbacks)
- on_train_batch_end(epoch, batch_idx, loss)
Executed at the end of each batch in the training loop.
- Parameters:
epoch – epoch number
batch_idx – batch number
loss – list with all losses for the batch
- on_train_epoch_end(model, epoch, history)
Executed at the end of each training epoch.
- Parameters:
model – model (e.g. can be used in saving model checkpoint callbacks)
epoch – epoch number
history – history with loss values (e.g. can be used in MLOps callbacks)
- on_val_epoch_end(model, epoch, history)
Executed at the end of validation in each epoch.
- Parameters:
model – model (e.g. can be used in saving model checkpoint callbacks)
epoch – epoch number
history – history with loss values (e.g. can be used in MLOps callbacks)
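A minimal sketch of a custom callback built on the hooks listed above. To stay runnable without the library installed, it defines a no-op stand-in for `training.callback.Callback`; in real use you would subclass the library class instead. `BatchLossTracker` is a hypothetical name, and treating the first entry of `loss` as the value to track is an assumption made only for illustration.

```python
class Callback:
    """Stand-in for training.callback.Callback: every hook is a no-op."""
    def on_train_start(self, d): pass
    def on_train_encoder_start(self, d): pass
    def on_train_decoder_start(self, d): pass
    def on_train_batch_end(self, epoch, batch_idx, loss): pass
    def on_train_epoch_end(self, model, epoch, history): pass
    def on_val_epoch_end(self, model, epoch, history): pass


class BatchLossTracker(Callback):
    """Hypothetical callback: averages the first loss entry over each epoch."""
    def __init__(self):
        super().__init__()
        self.batch_losses = []   # first loss entry of each batch in the current epoch
        self.epoch_means = {}    # epoch number -> mean of those entries

    def on_train_batch_end(self, epoch, batch_idx, loss):
        # Assumption: `loss` is a list of numbers; only the first one is tracked.
        self.batch_losses.append(float(loss[0]))

    def on_train_epoch_end(self, model, epoch, history):
        if self.batch_losses:
            self.epoch_means[epoch] = sum(self.batch_losses) / len(self.batch_losses)
        self.batch_losses = []   # reset for the next epoch


# Simulate the calls a training loop would make (two epochs, two batches each).
tracker = BatchLossTracker()
for epoch in range(2):
    for batch_idx, loss in enumerate([[0.5, 0.1], [0.3, 0.1]]):
        tracker.on_train_batch_end(epoch, batch_idx, loss)
    tracker.on_train_epoch_end(model=None, epoch=epoch, history=None)
```

Only the hooks of interest need to be overridden; the remaining ones keep the base-class behaviour.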
- class training.callback.LRSched(SchedulerClass, **scheduler_args)
LR scheduler callback using PyTorch’s LR schedulers. An instance is passed as a parameter to the train_diresa function. In a staged training, the scheduler is applied to the encoder and the decoder separately. For ReduceLROnPlateau, use the LRRedOnPlateau callback instead.
- __init__(SchedulerClass, **scheduler_args)
- Parameters:
SchedulerClass – PyTorch LR scheduler class
scheduler_args – PyTorch LR scheduler arguments
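The sketch below illustrates how a callback with this signature could work: it stores the scheduler class and its arguments, then builds the scheduler once the optimizer becomes available through the `d` argument of the start hooks. It is not the library's implementation. `LRSchedSketch`, `FakeOptimizer` and `HalveEvery` are hypothetical stand-ins so the example runs without PyTorch, and two points are assumptions: that `d` can be indexed like a dict, and that the scheduler is stepped once per epoch.

```python
class FakeOptimizer:
    """Stand-in for a PyTorch optimizer: only exposes param_groups."""
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class HalveEvery:
    """Stand-in scheduler following the PyTorch convention
    Scheduler(optimizer, **kwargs) with a step() method."""
    def __init__(self, optimizer, step_size=1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.epochs = 0

    def step(self):
        self.epochs += 1
        if self.epochs % self.step_size == 0:
            for group in self.optimizer.param_groups:
                group["lr"] *= 0.5


class LRSchedSketch:
    """Hypothetical re-implementation for illustration only."""
    def __init__(self, SchedulerClass, **scheduler_args):
        self.SchedulerClass = SchedulerClass
        self.scheduler_args = scheduler_args
        self.scheduler = None

    def _build(self, d):
        # Assumption: d behaves like a dict holding the optimizer.
        self.scheduler = self.SchedulerClass(d["optimizer"], **self.scheduler_args)

    # The same construction runs at the start of every training phase, so a
    # staged training gets a fresh scheduler for the encoder and the decoder.
    on_train_start = _build
    on_train_encoder_start = _build
    on_train_decoder_start = _build

    def on_train_epoch_end(self, model, epoch, history):
        # Assumption: the scheduler is stepped once per epoch.
        self.scheduler.step()


# Simulate a non-staged training of four epochs.
opt = FakeOptimizer(lr=0.01)
cb = LRSchedSketch(HalveEvery, step_size=2)
cb.on_train_start({"optimizer": opt})
for epoch in range(4):
    cb.on_train_epoch_end(model=None, epoch=epoch, history=None)
final_lr = opt.param_groups[0]["lr"]
```

Passing the class rather than an instance is what allows the scheduler to be created late, after the training function has created its optimizer.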
- class training.callback.LRRedOnPlateau(metric='WeightedLoss_val', **reduce_on_plateau_args)
Callback for using PyTorch’s ReduceLROnPlateau scheduler. An instance is passed as a parameter to the train_diresa function. In a staged training, ReduceLROnPlateau is applied to the encoder and the decoder separately.
- __init__(metric='WeightedLoss_val', **reduce_on_plateau_args)
- Parameters:
metric – metric to monitor
reduce_on_plateau_args – PyTorch ReduceLROnPlateau arguments
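Unlike other schedulers, a reduce-on-plateau rule needs the value of a monitored metric at every step, which is why this callback takes a `metric` argument. The following is a simplified, pure-Python sketch of such a rule, not PyTorch's implementation: `PlateauSketch` is a hypothetical class, the `factor` and `patience` names mirror the corresponding ReduceLROnPlateau arguments, and the numbers fed to `step` stand in for the per-epoch values of the monitored metric (by default ‘WeightedLoss_val’).

```python
class PlateauSketch:
    """Simplified plateau rule: multiply lr by `factor` once the metric has
    failed to improve for more than `patience` consecutive epochs."""
    def __init__(self, lr, factor=0.1, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")   # lowest metric value seen so far
        self.bad_epochs = 0        # consecutive epochs without improvement

    def step(self, metric_value):
        if metric_value < self.best:
            self.best = metric_value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


# Hypothetical validation losses over five epochs.
sched = PlateauSketch(lr=1.0, factor=0.5, patience=1)
lrs = [sched.step(v) for v in [0.9, 0.8, 0.8, 0.8, 0.7]]
```

In this run the learning rate is halved at the fourth epoch, after two consecutive epochs without improvement.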