Closed · freshlu11 closed this issue 2 months ago
I only found one `loss.backward()` call, in the function `memory_check`.
But the aim of that function is to find the maximum batch size, am I right? So where does the backward pass actually occur?
Sorry, I have found the update code:

```python
norm_value = scaler(loss * loss_scale, model_optim, clip_grad=max_norm,
                    parameters=self.model.parameters(), create_graph=False,
                    update_grad=((i + 1) % acc_it == 0))
```
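For anyone else tracing this: the backward pass is hidden inside the scaler's `__call__`. Below is a minimal sketch of a timm-style `NativeScaler`, which this call signature resembles; this is an assumption about the repo's implementation (the class name and internals come from timm/MAE-style codebases, not from this issue), so the actual code may differ in details:

```python
import torch


class NativeScaler:
    """Sketch of a timm-style AMP loss scaler; the repo's version may differ."""

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        # The backward pass happens here, on the scaled loss --
        # this is the `loss.backward()` that looks "missing" from the train loop.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        norm = None
        if update_grad:
            if clip_grad is not None:
                # Unscale gradients in place before measuring/clipping them.
                self._scaler.unscale_(optimizer)
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            # Optimizer step is skipped if gradients contain inf/NaN.
            self._scaler.step(optimizer)
            self._scaler.update()
        return norm
```

With `update_grad=((i + 1) % acc_it == 0)`, gradients are accumulated over `acc_it` micro-batches: `backward()` runs every iteration, but the optimizer only steps (and the caller only zeroes gradients) on the last one.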