latent module
- utils.latent.encode_diresa(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, input_filter: callable = <function <lambda>>) → torch.Tensor
Encodes a dataset into a latent representation. Prerequisite: latent components are already ordered.
- Parameters:
model – model to use
data_loader – data to be encoded
input_filter – function used to filter input data, default is no filtering
- Returns:
latent representation of the dataset
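The encode step described above can be pictured with a minimal sketch. It is not the library's implementation: `encode_dataset_sketch` is a hypothetical helper, a plain `torch.nn.Linear` stands in for the Diresa encoder, and it is an assumption that `input_filter` is applied to each batch yielded by the data loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def encode_dataset_sketch(encoder, data_loader, input_filter=lambda x: x):
    """Hypothetical stand-in for an encode loop (not the library code)."""
    encoder.eval()
    chunks = []
    with torch.no_grad():
        for batch in data_loader:
            # Assumption: the filter turns a raw batch into the encoder input.
            chunks.append(encoder(input_filter(batch)))
    # Stack the per-batch latents into one tensor for the whole dataset.
    return torch.cat(chunks, dim=0)


torch.manual_seed(0)
data = torch.randn(10, 6)
loader = DataLoader(TensorDataset(data), batch_size=4)
encoder = torch.nn.Linear(6, 2)  # stand-in encoder, latent_dim = 2

# TensorDataset batches arrive as a one-element sequence, so the filter
# picks out the tensor itself.
latent = encode_dataset_sketch(encoder, loader, input_filter=lambda batch: batch[0])
print(latent.shape)  # torch.Size([10, 2])
```

With the real function, the same kind of callable would be passed as `input_filter`, and the model would be a `Diresa` instance whose latent components are already ordered.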
- utils.latent.latent_r2_per_variable(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, incr: bool = False, input_filter: callable = <function <lambda>>, target_filter: callable = <function <lambda>>, verbose: bool = False) → list
Computes the R2 scores of the latent components for each variable. Variables are expected on axis 1 (axes are zero-indexed). If they are on a different axis, input_filter and target_filter can be used to swap axes. Prerequisite: latent components are already ordered.
- Parameters:
model – Diresa model
data_loader – DataLoader
incr – If True, incremental R2 scores are calculated; default is False
input_filter – Function used to filter input data, default is no filtering
target_filter – Function used to filter target data, default is no filtering
verbose – If True, prints most important component per variable, default is False
- Returns:
R2 scores of latent components per variable, shape (latent_dim, number of variables)
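As an illustration of the axis convention, the sketch below moves a trailing variable axis to axis 1 with a filter-style callable and then evaluates the standard R2 definition (1 minus residual sum of squares over total sum of squares) separately for each variable. The names `to_variables_on_axis_1` and `r2_per_variable_sketch` are hypothetical, and how the library builds the reconstruction for each latent component is not shown here.

```python
import torch


def to_variables_on_axis_1(x):
    # Filter-style callable: (batch, height, width, variables)
    # becomes (batch, variables, height, width).
    return x.movedim(-1, 1)


def r2_per_variable_sketch(target, prediction):
    """Standard R2, reduced over every axis except axis 1 (the variables)."""
    dims = [d for d in range(target.ndim) if d != 1]
    ss_res = ((target - prediction) ** 2).sum(dim=dims)
    mean = target.mean(dim=dims, keepdim=True)
    ss_tot = ((target - mean) ** 2).sum(dim=dims)
    return 1.0 - ss_res / ss_tot


torch.manual_seed(0)
target = torch.randn(8, 4, 5, 3)  # variables on the last axis
prediction = target.clone()
# Predict variable 1 by its mean only, which gives an R2 of about 0.
prediction[..., 1] = target[..., 1].mean()

scores = r2_per_variable_sketch(
    to_variables_on_axis_1(target), to_variables_on_axis_1(prediction)
)
print(scores)  # approximately [1, 0, 1]
```

A callable like `to_variables_on_axis_1` is the kind of function that could be passed as both input_filter and target_filter when the data is stored with the variables on the last axis.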
- utils.latent.latent_vectors(model: diresa_torch.arch.models.Diresa, data_loader: torch.utils.data.DataLoader, factor: float = 0.5, incr: bool = False, input_filter: callable = <function <lambda>>) → torch.Tensor
Calculates decoded latent vectors. Prerequisite: the model must be ordered. See https://journals.ametsoc.org/view/journals/aies/4/3/AIES-D-24-0034.1.xml, Appendix D, section c, "Latent variable interpretation".
- Parameters:
model – Diresa model
data_loader – DataLoader
factor – Multiplication factor for standard deviation
incr – If True, incremental vectors are calculated; default is False
input_filter – Function used to filter input data, default is no filtering
- Returns:
Decoded latent vectors
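The exact construction is defined in the cited appendix and in the library source. The sketch below is only one plausible reading of the parameter descriptions, and every part of it is an assumption: each latent component in turn is shifted from the dataset mean by factor times its standard deviation, and the resulting latent point is decoded. `latent_vectors_sketch` is a hypothetical helper, a plain `torch.nn.Linear` stands in for the Diresa decoder, and the incr option is not modelled.

```python
import torch


def latent_vectors_sketch(decoder, latent, factor=0.5):
    """Assumed reading: decode mean + factor * std along one component at a time."""
    mean = latent.mean(dim=0)  # (latent_dim,)
    std = latent.std(dim=0)    # (latent_dim,)
    latent_dim = latent.shape[1]
    # Row i is the mean latent with only component i shifted.
    probes = mean.repeat(latent_dim, 1) + factor * torch.diag(std)
    with torch.no_grad():
        return decoder(probes)


torch.manual_seed(0)
latent = torch.randn(10, 2)       # e.g. an already encoded dataset
decoder = torch.nn.Linear(2, 6)   # stand-in decoder
vectors = latent_vectors_sketch(decoder, latent, factor=0.5)
print(vectors.shape)  # torch.Size([2, 6])
```

In this reading, the output holds one decoded vector per latent component, which can then be inspected or plotted in the original data space.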