StochasticWeightAveraging¶
- class lightning.pytorch.callbacks.StochasticWeightAveraging(swa_lrs, swa_epoch_start=0.8, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))[source]¶
- Bases: Callback

  Implements the Stochastic Weight Averaging (SWA) callback to average a model.

  Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

  This documentation is heavily inspired by PyTorch's work on SWA. The callback arguments follow the scheme defined in PyTorch's swa_utils package.

  For an explanation of SWA, please take a look here.

  Warning: This is an experimental feature.

  Warning: StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

  Warning: StochasticWeightAveraging is currently only supported on every epoch.

  See also how to enable it directly on the Trainer; a usage sketch follows the parameter list below.

  Parameters:
- swa_lrs¶ (Union[float, list[float]]) – The SWA learning rate to use:

  - float. Use this value for all parameter groups of the optimizer.
  - List[float]. A list of values, one for each parameter group of the optimizer.
 
- swa_epoch_start¶ (Union[int, float]) – If provided as an int, the procedure will start from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure will start from the int(swa_epoch_start * max_epochs)-th epoch.
- annealing_epochs¶ (int) – Number of epochs in the annealing phase (default: 10).
- annealing_strategy¶ (Literal['cos', 'linear']) – Specifies the annealing strategy (default: "cos"):

  - "cos". For cosine annealing.
  - "linear". For linear annealing.
 
- avg_fn¶ (Optional[Callable[[Tensor, Tensor, Tensor], Tensor]]) – The averaging function used to update the parameters; the function must take in the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged; if None, an equally weighted average is used (default: None).
- device¶ (Union[device, str, None]) – If provided, the averaged model will be stored on the device. When None is provided, the device will be inferred from pl_module (default: "cpu").
 
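A minimal usage sketch, assuming a user-defined LightningModule named MyLightningModule and a dataloader named train_loader (both are placeholders, not part of this API):

    import lightning as L
    from lightning.pytorch.callbacks import StochasticWeightAveraging

    model = MyLightningModule()  # placeholder: substitute your own LightningModule

    # Start SWA after 75% of max_epochs, using one SWA learning rate for all
    # parameter groups and a 5-epoch cosine annealing phase.
    swa = StochasticWeightAveraging(
        swa_lrs=1e-2,
        swa_epoch_start=0.75,
        annealing_epochs=5,
        annealing_strategy="cos",
    )

    trainer = L.Trainer(max_epochs=20, callbacks=[swa])
    trainer.fit(model, train_loader)  # placeholder dataloader

The callback can also be passed inline, e.g. Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)]).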
 - static avg_fn(averaged_model_parameter, model_parameter, num_averaged)[source]¶
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.

  Return type: Tensor
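As a rough sketch of what this default, equally weighted average computes (a running mean over the models seen so far, mirroring the referenced swa_utils code; not the verbatim implementation):

    import torch

    def equal_weighted_avg(averaged_param: torch.Tensor,
                           model_param: torch.Tensor,
                           num_averaged: torch.Tensor) -> torch.Tensor:
        # After num_averaged models have been folded in, incorporate the next one
        # so the result stays the arithmetic mean of all models seen so far.
        return averaged_param + (model_param - averaged_param) / (num_averaged + 1)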
 
 - load_state_dict(state_dict)[source]¶
- Called when loading a checkpoint; implement this to reload the callback state given the callback's state_dict.
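As a generic illustration of this hook (a hypothetical custom callback unrelated to SWA; the names CounterCallback and batches_seen are made up for the example):

    import lightning as L

    class CounterCallback(L.Callback):
        """Keeps a counter across training and restores it from checkpoints."""

        def __init__(self):
            self.batches_seen = 0

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            self.batches_seen += 1

        def state_dict(self):
            # Whatever is returned here is saved into the checkpoint for this callback.
            return {"batches_seen": self.batches_seen}

        def load_state_dict(self, state_dict):
            # Called when loading a checkpoint; restore the callback's internal state.
            self.batches_seen = state_dict["batches_seen"]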
 - on_train_epoch_end(trainer, *args)[source]¶
- Called when the train epoch ends.

  To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

      class MyLightningModule(L.LightningModule):
          def __init__(self):
              super().__init__()
              self.training_step_outputs = []

          def training_step(self):
              loss = ...
              self.training_step_outputs.append(loss)
              return loss


      class MyCallback(L.Callback):
          def on_train_epoch_end(self, trainer, pl_module):
              # do something with all training_step outputs, for example:
              epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
              pl_module.log("training_epoch_mean", epoch_mean)
              # free up the memory
              pl_module.training_step_outputs.clear()

  Return type: None
 
 - reset_batch_norm_and_save_state(pl_module)[source]¶
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 (a combined sketch of this helper and reset_momenta follows the reset_momenta entry below).

  Return type: None
 
 - reset_momenta()[source]¶
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.

  Return type: None
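A simplified, hypothetical sketch of the batch-norm handling that the two referenced swa_utils snippets perform (the function names reset_bn_and_save_momenta and restore_momenta are illustrative, not the callback's internals):

    import torch

    def reset_bn_and_save_momenta(module: torch.nn.Module) -> dict:
        # Before the extra SWA forward passes: clear BatchNorm running statistics
        # and remember each layer's momentum.
        momenta = {}
        for m in module.modules():
            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                m.reset_running_stats()   # reset running_mean / running_var
                momenta[m] = m.momentum   # save the configured momentum
                m.momentum = None         # None -> cumulative moving average
        return momenta

    def restore_momenta(momenta: dict) -> None:
        # After the statistics have been recomputed: restore the original momenta.
        for bn_module, momentum in momenta.items():
            bn_module.momentum = momentum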
 
 - setup(trainer, pl_module, stage)[source]¶
- Called when fit, validate, test, predict, or tune begins.

  Return type: None
 
 - static update_parameters(average_model, model, n_averaged, avg_fn)[source]¶
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.

  Return type: None
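A hypothetical sketch of the parameter-averaging step this performs, following the referenced swa_utils code (names are illustrative; n_averaged is assumed to be a 0-dim tensor so the in-place increment is visible to the caller):

    import torch

    def update_parameters_sketch(average_model, model, n_averaged, avg_fn):
        # Fold the current model's weights into the running average, parameter by parameter.
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if n_averaged == 0:
                # First model seen: the average is just a copy of its weights.
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(avg_fn(p_swa.detach(), p_model_, n_averaged.to(device)))
        n_averaged += 1  # in-place on a 0-dim tensor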