Open buoyancy99 opened 5 years ago
Hi @buoyancy99,
if you are expecting large values outside the valid range of FP16, you might decorate the operation with this guard:
with amp.disable_casts():
    # my critical FP32 operation
Would that work for you?
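For example, for a gram-matrix computation done with torch.bmm, a minimal sketch of that guard might look like the following. The gram_matrix helper name is illustrative, and it assumes apex amp has already been initialized (e.g. with opt_level O1) so that activations arrive as FP16:

    import torch
    from apex import amp

    def gram_matrix(features):
        # features: (batch, channels, height * width), possibly FP16 under O1
        b, c, n = features.shape
        with amp.disable_casts():
            # Do the matrix product in FP32 so the large accumulated values
            # do not overflow FP16's maximum (~65504).
            feats = features.float()
            gram = torch.bmm(feats, feats.transpose(1, 2))
        # Normalize by the number of elements per feature map
        return gram / (c * n)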
In many applications we need bmm for gram matrix calculation, for example in neural style transfer. However, it seems that the gram matrix computation under O1 mode always gives NaN.
See the issue here https://github.com/pytorch/pytorch/issues/3651
I encountered the same problem with apex. It seems the input should not be cast to FP16 in this case, because the input matrices can contain large values. A good solution is to scale them down before bmm and multiply the result back up afterwards; that solves the problem.
This approach should be implemented internally, since in most cases users don't know whether they need to scale manually.
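A minimal sketch of that manual scaling workaround is below; the scaled_bmm name and the 1/256 scale factor are illustrative, not part of apex:

    import torch

    def scaled_bmm(a, b, scale=1.0 / 256):
        # Shrink both operands so the FP16 products and accumulations stay
        # below the overflow threshold, then undo the scaling on the result.
        # bmm(a*s, b*s) == s^2 * bmm(a, b), so dividing by s^2 recovers it.
        # The result can still overflow if the true product itself exceeds
        # the FP16 range, which is why doing this inside the library (or in
        # FP32) would be preferable to asking users to pick a scale by hand.
        return torch.bmm(a * scale, b * scale) / (scale * scale)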