NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

bmm for style loss #402

Open buoyancy99 opened 5 years ago

buoyancy99 commented 5 years ago

In many applications we need bmm for gram matrix calculation, as in neural style transfer. However, it seems the gram matrix computation with opt_level O1 always produces NaN.

See the issue here https://github.com/pytorch/pytorch/issues/3651

I encountered the same problem with apex. It seems the input should not be cast to FP16 in this case, because the input matrices can contain large values. A good solution would be to scale them down before the bmm and multiply the result back afterwards; this solves the problem.
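The overflow is easy to reproduce without apex: the half-precision format tops out at 65504, so squaring even moderately large activations (exactly what a gram matrix does) overflows to inf, and further arithmetic on inf yields NaN:

```python
import torch

# The largest finite FP16 value is 65504, so 300**2 = 90000 overflows.
x = torch.tensor([300.0], dtype=torch.float16)

print(x * x)          # overflows to inf
print(x * x - x * x)  # arithmetic on inf produces NaN
```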

This approach should be implemented internally, since in most cases users don't know whether they need to scale manually.
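The scale-down / multiply-back idea can be sketched in plain PyTorch. `gram_matrix` and its `scale` argument are illustrative names, not part of apex; the scale is derived from the largest input magnitude so the intermediate products stay within the FP16 range, then squared out of the result:

```python
import torch

def gram_matrix(features, scale=None):
    """Gram matrix for style loss with pre-bmm scaling (illustrative sketch).

    Features are divided by `scale` before the bmm so the intermediate
    products fit in FP16, then the result is multiplied back by scale**2.
    """
    b, c, h, w = features.shape
    flat = features.view(b, c, h * w)
    if scale is None:
        # Derive a scale from the largest magnitude in the input.
        scale = flat.abs().max().clamp(min=1.0)
    flat = flat / scale
    gram = torch.bmm(flat, flat.transpose(1, 2))
    # Undo the scaling and apply the usual style-loss normalization.
    return gram * scale * scale / (c * h * w)
```

Mathematically the scaling cancels exactly; numerically it keeps every intermediate value of the bmm finite in half precision.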

ptrblck commented 5 years ago

Hi @buoyancy99,

if you are expecting large values outside the valid range of FP16, you might decorate the operation with this guard:

with amp.disable_casts():
    # my critical FP32 operation

Would that work for you?