callback module

class training.callback.Callback

Base class for building training callbacks. An instance can be passed as a parameter to the train_diresa function.

__init__()
on_train_start(d)

Executed before the start of non-staged training.

Parameters:

d – struct, currently containing only an ‘optimizer’ key (used by LR scheduler callbacks)

on_train_encoder_start(d)

Executed before the start of encoder training, in the case of staged training.

Parameters:

d – struct, currently containing only an ‘optimizer’ key (used by LR scheduler callbacks)

on_train_decoder_start(d)

Executed before the start of decoder training, in the case of staged training.

Parameters:

d – struct, currently containing only an ‘optimizer’ key (used by LR scheduler callbacks)
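The three start hooks receive the same d struct, so a callback can pick up the optimizer for each training phase. The sketch below is a hypothetical LRLogger that records the learning rate of every parameter group at the start of each phase. It assumes d is dict-like (d["optimizer"]) and that the optimizer is a torch.optim.Optimizer, whose param_groups attribute is a list of dicts with an "lr" entry. The fallback Callback class is only a stand-in for environments where training.callback is not importable.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class LRLogger(Callback):
    """Hypothetical callback: records the learning rate(s) at the start of each phase."""

    def __init__(self):
        super().__init__()
        self.phases = []  # list of (phase name, [lr per param group])

    def _record(self, phase, d):
        # Assumes d is dict-like and d["optimizer"] exposes PyTorch-style param_groups.
        optimizer = d["optimizer"]
        self.phases.append((phase, [group["lr"] for group in optimizer.param_groups]))

    def on_train_start(self, d):
        self._record("train", d)

    def on_train_encoder_start(self, d):
        self._record("encoder", d)

    def on_train_decoder_start(self, d):
        self._record("decoder", d)
```

In a staged training this would yield one "encoder" and one "decoder" entry; in a non-staged training, a single "train" entry.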

on_train_batch_end(epoch, batch_idx, loss)

Executed at the end of each batch in the training loop.

Parameters:
  • epoch – epoch number

  • batch_idx – batch number

  • loss – list of all loss values for the batch
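Because loss is a list holding every loss term for the batch, a callback can aggregate each term separately. The hypothetical BatchLossAverager below keeps a per-epoch running mean of each loss component. It assumes each list element is a float or a single-element tensor (float() accepts both) and that the list has the same length for every batch.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class BatchLossAverager(Callback):
    """Hypothetical callback: per-epoch mean of each loss component."""

    def __init__(self):
        super().__init__()
        self.sums = {}    # epoch -> running sum per loss component
        self.counts = {}  # epoch -> number of batches seen

    def on_train_batch_end(self, epoch, batch_idx, loss):
        values = [float(v) for v in loss]
        if epoch not in self.sums:
            self.sums[epoch] = [0.0] * len(values)
            self.counts[epoch] = 0
        self.sums[epoch] = [s + v for s, v in zip(self.sums[epoch], values)]
        self.counts[epoch] += 1

    def epoch_means(self, epoch):
        """Mean of each loss component over the batches seen in `epoch`."""
        return [s / self.counts[epoch] for s in self.sums[epoch]]
```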

on_train_epoch_end(model, epoch, history)

Executed at the end of each training epoch.

Parameters:
  • model – the model being trained (e.g. used by model checkpoint callbacks)

  • epoch – epoch number

  • history – history of loss values (e.g. used by MLOps callbacks)
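A common use of on_train_epoch_end is reporting progress from history. The hypothetical EpochPrinter below prints the most recent value of every metric after each epoch. It assumes history maps metric names to lists of per-epoch values; the actual structure of history is not specified above, so adapt the lookup if it differs.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class EpochPrinter(Callback):
    """Hypothetical callback: prints the latest value of each metric after every epoch."""

    def __init__(self):
        super().__init__()
        self.lines = []  # keep what was printed, e.g. for later inspection

    def on_train_epoch_end(self, model, epoch, history):
        # Assumes history: dict[str, list[float]], with one entry appended per epoch.
        parts = [f"{name}={values[-1]:.4f}" for name, values in history.items() if values]
        line = f"epoch {epoch}: " + ", ".join(parts)
        self.lines.append(line)
        print(line)
```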

on_val_epoch_end(model, epoch, history)

Executed at the end of the validation pass in each epoch.

Parameters:
  • model – the model being trained (e.g. used by model checkpoint callbacks)

  • epoch – epoch number

  • history – history of loss values (e.g. used by MLOps callbacks)
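Since on_val_epoch_end runs after validation, it is a natural place to track the best validation score. The hypothetical BestValTracker below remembers the epoch with the lowest value of a monitored metric. It borrows 'WeightedLoss_val' from the LRRedOnPlateau default and assumes that name is a history key whose list gets one entry per epoch.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class BestValTracker(Callback):
    """Hypothetical callback: remembers the epoch with the lowest validation metric."""

    def __init__(self, metric="WeightedLoss_val"):
        super().__init__()
        self.metric = metric
        self.best_value = float("inf")
        self.best_epoch = None

    def on_val_epoch_end(self, model, epoch, history):
        # Assumes history[self.metric] is a list whose last entry belongs to this epoch.
        value = history[self.metric][-1]
        if value < self.best_value:
            self.best_value = value
            self.best_epoch = epoch
```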

class training.callback.LRSched(SchedulerClass, **scheduler_args)

LR scheduler callback that wraps a PyTorch LR scheduler; pass an instance as a parameter to the train_diresa function. In the case of staged training, the scheduler is applied to the encoder and decoder separately. For ReduceLROnPlateau, use the LRRedOnPlateau callback instead.

__init__(SchedulerClass, **scheduler_args)
Parameters:
  • SchedulerClass – PyTorch LR scheduler class

  • scheduler_args – PyTorch LR scheduler arguments
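With the real class, a step-decay schedule would presumably be requested as LRSched(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5), where the keyword arguments are forwarded to the scheduler. The sketch below is a hypothetical re-implementation (SimpleLRSched), not the package's code. It shows one way such a callback could combine the documented hooks: it builds a fresh scheduler from d["optimizer"] at the start of each phase, which keeps encoder and decoder schedules separate. It also assumes the schedule advances once per epoch in on_train_epoch_end.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class SimpleLRSched(Callback):
    """Hypothetical sketch of an LRSched-style callback."""

    def __init__(self, SchedulerClass, **scheduler_args):
        super().__init__()
        self.SchedulerClass = SchedulerClass
        self.scheduler_args = scheduler_args
        self.scheduler = None

    def _new_scheduler(self, d):
        # One scheduler per phase, built on that phase's optimizer.
        self.scheduler = self.SchedulerClass(d["optimizer"], **self.scheduler_args)

    def on_train_start(self, d):
        self._new_scheduler(d)

    def on_train_encoder_start(self, d):
        self._new_scheduler(d)

    def on_train_decoder_start(self, d):
        self._new_scheduler(d)

    def on_train_epoch_end(self, model, epoch, history):
        # Assumption: epoch-based schedule, stepped once per training epoch.
        self.scheduler.step()
```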

class training.callback.LRRedOnPlateau(metric='WeightedLoss_val', **reduce_on_plateau_args)

Callback that applies PyTorch's ReduceLROnPlateau scheduler; pass an instance as a parameter to the train_diresa function. In the case of staged training, ReduceLROnPlateau is applied to the encoder and decoder separately.

__init__(metric='WeightedLoss_val', **reduce_on_plateau_args)
Parameters:
  • metric – metric to monitor

  • reduce_on_plateau_args – PyTorch ReduceLROnPlateau arguments
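ReduceLROnPlateau differs from other PyTorch schedulers because its step() takes the monitored value, which explains the separate callback and its metric parameter. The sketch below is a hypothetical re-implementation (SimpleRedOnPlateau), not the package's code. It assumes the monitored value is the latest entry of history[metric] and that stepping happens in on_val_epoch_end. The scheduler_class keyword is a hypothetical addition that lets the sketch run without PyTorch; when it is left as None, torch.optim.lr_scheduler.ReduceLROnPlateau is used.

```python
try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class SimpleRedOnPlateau(Callback):
    """Hypothetical sketch of an LRRedOnPlateau-style callback."""

    def __init__(self, metric="WeightedLoss_val", scheduler_class=None, **reduce_on_plateau_args):
        super().__init__()
        self.metric = metric
        self.scheduler_class = scheduler_class  # hypothetical; None means torch's ReduceLROnPlateau
        self.args = reduce_on_plateau_args      # e.g. mode, factor, patience
        self.scheduler = None

    def _new_scheduler(self, d):
        cls = self.scheduler_class
        if cls is None:
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            cls = ReduceLROnPlateau
        self.scheduler = cls(d["optimizer"], **self.args)

    def on_train_start(self, d):
        self._new_scheduler(d)

    def on_train_encoder_start(self, d):
        self._new_scheduler(d)

    def on_train_decoder_start(self, d):
        self._new_scheduler(d)

    def on_val_epoch_end(self, model, epoch, history):
        # ReduceLROnPlateau.step(metrics) expects the monitored value itself.
        self.scheduler.step(history[self.metric][-1])
```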

class training.callback.Checkpoint(save_dir='.', every_epoch=1)

Saves model weights to disk at a fixed epoch interval, set by every_epoch.

__init__(save_dir='.', every_epoch=1)
Parameters:
  • save_dir – directory to save model checkpoints

  • every_epoch – number of epochs between checkpoints
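The every_epoch interval amounts to a modulo check on the epoch counter. The sketch below is a hypothetical SimpleCheckpoint, not the package's implementation. It assumes 0-based epoch numbers (so with every_epoch=2 it saves after epochs 1, 3, 5, ...), a file-name pattern of its own invention, and that weights are saved as model.state_dict(). The save_fn keyword is a hypothetical addition for running without PyTorch; it defaults to torch.save.

```python
import os

try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class SimpleCheckpoint(Callback):
    """Hypothetical sketch: saves model weights every `every_epoch` epochs."""

    def __init__(self, save_dir=".", every_epoch=1, save_fn=None):
        super().__init__()
        self.save_dir = save_dir
        self.every_epoch = every_epoch
        self.save_fn = save_fn  # hypothetical; None means torch.save
        self.saved = []         # paths written so far

    def on_train_epoch_end(self, model, epoch, history):
        # Assumes 0-based epochs: save after the every_epoch-th, 2*every_epoch-th, ... epoch.
        if (epoch + 1) % self.every_epoch != 0:
            return
        save = self.save_fn
        if save is None:
            import torch
            save = torch.save
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, f"checkpoint_epoch{epoch}.pt")  # hypothetical name
        save(model.state_dict(), path)
        self.saved.append(path)
```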

class training.callback.HistorySaver(save_dir='.', every_epoch=1)

Saves the training history to disk at a fixed epoch interval, set by every_epoch.

__init__(save_dir='.', every_epoch=1)
Parameters:
  • save_dir – directory to save model history

  • every_epoch – number of epochs between history saves
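Here is a minimal sketch of the same idea applied to history, using a hypothetical SimpleHistorySaver that writes JSON. Several things are assumptions, not the package's documented format: history is taken to be a JSON-serializable dict of metric lists, the file name history.json is made up, and saving happens in on_val_epoch_end so that each write includes that epoch's validation metrics. Epochs are again treated as 0-based.

```python
import json
import os

try:
    from training.callback import Callback  # import path as documented above
except ImportError:
    class Callback:
        """Stand-in exposing the documented hooks as no-ops."""
        def __init__(self): pass
        def on_train_start(self, d): pass
        def on_train_encoder_start(self, d): pass
        def on_train_decoder_start(self, d): pass
        def on_train_batch_end(self, epoch, batch_idx, loss): pass
        def on_train_epoch_end(self, model, epoch, history): pass
        def on_val_epoch_end(self, model, epoch, history): pass


class SimpleHistorySaver(Callback):
    """Hypothetical sketch: writes the history as JSON every `every_epoch` epochs."""

    def __init__(self, save_dir=".", every_epoch=1):
        super().__init__()
        self.save_dir = save_dir
        self.every_epoch = every_epoch

    def on_val_epoch_end(self, model, epoch, history):
        if (epoch + 1) % self.every_epoch != 0:
            return
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, "history.json")  # hypothetical file name
        with open(path, "w") as f:
            json.dump(history, f)  # overwrites the previous snapshot
```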