Closed abhi-mosaic closed 1 year ago
Rather than caching `self.num_fwd_flops`, we cache only `self.n_params`, which is necessary to avoid FSDP sharding issues. We then calculate the FLOPs dynamically on every forward pass, for every batch, reading the batch's `max_seq_len`.
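As a rough illustration of the idea, the sketch below caches only the parameter count and derives the forward FLOPs from each batch's shape. The class name `FlopsEstimator`, the fields `n_layers` and `d_model`, and the FLOPs formula itself (a common approximation: 2 FLOPs per parameter per token, plus an attention term quadratic in sequence length) are assumptions for illustration, not necessarily what this PR implements.

```python
from dataclasses import dataclass


@dataclass
class FlopsEstimator:
    """Hypothetical helper: caches n_params, computes FLOPs per batch."""

    # Cached once (e.g. before the model is sharded), so the count
    # reflects the full model rather than a single shard.
    n_params: int
    # Assumed model dimensions used by the attention term below.
    n_layers: int
    d_model: int

    def num_fwd_flops(self, batch_size: int, max_seq_len: int) -> int:
        """Estimate forward-pass FLOPs for one batch.

        Nothing here is cached: the result depends on the batch's
        max_seq_len, so it is recomputed on every call.
        """
        # Approximation: one multiply-add (2 FLOPs) per parameter per token.
        params_flops_per_seq = 2 * self.n_params * max_seq_len
        # Approximation: QK^T and attention-weighted V, 2 FLOPs each,
        # both quadratic in sequence length.
        attn_flops_per_seq = (
            self.n_layers * 2 * 2 * self.d_model * max_seq_len**2
        )
        return (params_flops_per_seq + attn_flops_per_seq) * batch_size


if __name__ == "__main__":
    est = FlopsEstimator(n_params=1000, n_layers=2, d_model=8)
    # Two batches with different sequence lengths give different FLOPs.
    print(est.num_fwd_flops(batch_size=3, max_seq_len=4))
    print(est.num_fwd_flops(batch_size=3, max_seq_len=8))
```

Because the estimate is a function of the batch shape, batches with shorter sequences report fewer FLOPs than a value cached from a fixed sequence length would.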