ThroughputMonitor¶
- class lightning.pytorch.callbacks.ThroughputMonitor(batch_size_fn, length_fn=None, **kwargs)[source]¶
- Bases: Callback
- Computes and logs throughput with the Throughput utility.
- Example:

      class MyModel(LightningModule):
          def setup(self, stage):
              with torch.device("meta"):
                  model = MyModel()

                  def sample_forward():
                      batch = torch.randn(..., device="meta")
                      return model(batch)

                  self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)


      logger = ...
      throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
      trainer = Trainer(max_steps=1000, log_every_n_steps=10, callbacks=throughput, logger=logger)
      model = MyModel()
      trainer.fit(model)

- Notes
  - It assumes that the batch size is the same during all iterations.
  - It will try to access a flops_per_batch attribute on your LightningModule on every iteration. We suggest using the measure_flops() function for this, as shown in the example above. You might want to compute it differently each time based on your setup; see the sketch after this list for one way to refresh it per batch.
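A minimal sketch of that second point, assuming a model whose batch shape changes between iterations; VariableShapeModel, its build_meta_model helper, and the tensor-batch layout are hypothetical illustrations, not part of the Lightning API:

```python
import torch
from lightning.pytorch import LightningModule
from lightning.fabric.utilities.throughput import measure_flops


class VariableShapeModel(LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # ThroughputMonitor reads ``self.flops_per_batch`` on every iteration,
        # so refreshing it here keeps the logged FLOPs in sync with the current batch shape.
        with torch.device("meta"):
            meta_model = self.build_meta_model()  # hypothetical helper that rebuilds the model on the "meta" device
            meta_batch = torch.randn(*batch.shape, device="meta")
            self.flops_per_batch = measure_flops(
                meta_model,
                lambda: meta_model(meta_batch),
                loss_fn=torch.Tensor.sum,
            )
```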
- Parameters (see the sketch after this list for example callables):
  - batch_size_fn: a function that returns the number of samples in a given batch (for example, lambda batch: batch.size(0)).
  - length_fn: an optional function that returns the number of items, such as tokens, in a given batch. Defaults to None.
  - **kwargs: additional keyword arguments forwarded to the underlying Throughput utility.
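A minimal sketch of these parameters, assuming a hypothetical batch that is a dict containing an input_ids tensor of shape (batch, sequence_length); the key name and the window_size value are illustrative assumptions:

```python
from lightning.pytorch.callbacks import ThroughputMonitor

throughput = ThroughputMonitor(
    # number of samples in the current batch
    batch_size_fn=lambda batch: batch["input_ids"].size(0),
    # optional: number of items (here, tokens) in the current batch,
    # which enables per-item throughput logging
    length_fn=lambda batch: batch["input_ids"].numel(),
    # remaining keyword arguments are forwarded to the Throughput utility,
    # e.g. the size of the moving-average window
    window_size=50,
)
```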
 - on_predict_batch_end(trainer, pl_module, outputs, batch, *_, **__)[source]¶
- Called when the predict batch ends.
- Return type: None
 
 - on_test_batch_end(trainer, pl_module, outputs, batch, *_, **__)[source]¶
- Called when the test batch ends.
- Return type: None
 
 - on_train_batch_end(trainer, pl_module, outputs, batch, *_)[source]¶
- Called when the train batch ends.
- Return type: None
- Note: The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step, as illustrated in the sketch below.
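A small sketch with hypothetical numbers to illustrate the normalization:

```python
# If training_step returns a loss of 2.0 and the Trainer is configured with
# accumulate_grad_batches=4, the value seen as outputs["loss"] in this hook
# is the loss divided by the number of accumulation batches.
loss_from_training_step = 2.0
accumulate_grad_batches = 4
outputs_loss = loss_from_training_step / accumulate_grad_batches
print(outputs_loss)  # 0.5
```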