Closed lvjiujin closed 2 years ago
if args.fp16 and _use_native_amp: scaler.scale(loss).backward()
It seems the scaler was never defined; it can be created beforehand with:
scaler = torch.cuda.amp.GradScaler()
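For context, a minimal sketch of where the GradScaler fits in a native-AMP training step (names like model, optimizer, and the toy data are hypothetical; the example falls back to a no-op scaler on CPU so it runs without a GPU):

```python
import torch

# Enable AMP only when CUDA is available; otherwise run a plain fp32 step.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The missing piece from the issue: define the scaler once, before training.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

# Run the forward pass under autocast so eligible ops use fp16.
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
scaler.update()                # adjust the scale factor for the next iteration
optimizer.zero_grad()
```

When enabled=False, the scaler and autocast are no-ops, so the same loop works for both fp16 and fp32 training.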
Hi,
Thanks for the correction and for sharing!