huzecong opened this issue 5 years ago
If a piece of code needs `disable_cast()`, it usually means O2 will break in that region. On the other hand, if O2 works, you probably don't need `disable_cast()`.
If we make it a no-op and code with `disable_cast()` still runs when switching to O2, I think it'll give bad results that are harder to debug. Do you have another use case that supports the no-op behavior?
In my case, the function I'm using `disable_cast()` with is type-agnostic and should work with any floating-point type. I used `disable_cast()` because the inputs to the function are fp32 tensors, but the function is a whitelisted PyTorch function, which would force-convert the tensors to fp16 under O1, and I wanted to avoid that. I've also tested that the exact same code (with `disable_cast()` removed) works under O2.
What's the input type to the function under O2?
I would assume it is also fp16, since O2 converts all model weights (and the first-layer input) to fp16. If for some reason the input to this function is fp32 (for example, because of an explicit cast in your model before it), then it should not work, since you'd see an input/weight type mismatch. Also, to back this up, I don't think there are any whitelisted ops that don't have weights.
Does "works under O2" mean it works numerically? If that's the case, then I'm confused about why you want to avoid O1 converting to fp16 for you.
It's a bit tricky in my case because even under O2 I needed a certain part of the weights to stay FP32, so I manually convert them back after calling `amp.initialize`. As for this particular function, it computes some kind of attention over the input tensors, and it basically only involves `torch.bmm` (which I believe is a whitelisted function). The reason I had to prevent it from going to FP16 is that the tensors have pretty large values, and a BMM in FP16 would overflow.
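To make the overflow concern above concrete, here's a quick numeric sketch using NumPy's `float16` as a stand-in for an fp16 tensor (illustrative only; the same arithmetic limits apply to `torch.bmm` in half precision):

```python
import numpy as np

print(np.finfo(np.float16).max)                   # 65504.0, the largest fp16 value

# A single product of moderately large values already overflows fp16:
prod16 = np.float16(300.0) * np.float16(300.0)    # 90000 > 65504 -> inf
prod32 = np.float32(300.0) * np.float32(300.0)
print(prod16, prod32)                             # inf 90000.0

# Even when each product fits, the accumulation inside a dot product
# (the inner loop of a batched matmul) can overflow the fp16 running sum:
acc = np.float16(0.0)
for _ in range(64):                               # a length-64 inner dimension
    acc += np.float16(100.0) * np.float16(100.0)  # each term is 10000, which fits
print(acc)                                        # inf: the sum passes 65504
```

So keeping this one op in fp32 (or pre-scaling the inputs) is the standard way around it.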
After manually converting parts of the parameters back to FP32, the model works numerically. I understand that there's no point going back to O1, but I feel this behavior goes against the premise of "change the optimization level and it just works". Then again, it's possible that changing the current behavior would cause more problems; I haven't thought it through.
I understand your case now: you have code that works with O2, you need to add `disable_cast()` to be able to switch to O1, and you want the code to still work, without branching, when you switch back to O2.
I still feel that making `disable_casts()` a no-op under O2 could be bad, since people using it with O1 will try O2 and find that it doesn't work numerically.
We could potentially make `disable_casts()` actually do under O2 what you did manually there, so that after switching it still "just works". I'm not sure how feasible that is to implement, though.
In terms of implementation, would it be possible to use the `NoOpHandle` in place of the functional handle?
https://github.com/NVIDIA/apex/blob/0b74bfd92ba0846ca29b9bd2c6dc18dd3a5d9b20/apex/amp/handle.py#L251
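For illustration, here's a minimal sketch of the idea (hypothetical names, only loosely mirroring apex's `handle.py`): a functional handle and a no-op handle expose the same `disable_casts()` context manager, but the no-op variant has nothing to flip, so decorated code runs unchanged.

```python
import contextlib

class AmpHandle:
    """Functional handle: disable_casts() really suspends casting."""
    def __init__(self):
        self._is_active = True            # casting patches are in effect

    @contextlib.contextmanager
    def disable_casts(self):
        saved, self._is_active = self._is_active, False
        try:
            yield
        finally:
            self._is_active = saved       # restore on exit

class NoOpHandle:
    """No-op handle: same interface, but there is nothing to disable."""
    @contextlib.contextmanager
    def disable_casts(self):
        yield

def casts_active_inside(handle):
    # Report whether casting is still active inside the guarded region.
    with handle.disable_casts():
        return getattr(handle, "_is_active", False)

print(casts_active_inside(AmpHandle()))   # False: casting is suspended
print(casts_active_inside(NoOpHandle()))  # False: nothing was ever active
```

With this shape, code decorated with `disable_casts()` is safe under either handle, which is exactly the fallback behavior being proposed.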
As for user expectations, I think users can reasonably expect that not every optimization level will work (at least, I've seen this mentioned in many tutorials and docs), but I understand your concern. Maybe it would help to reiterate this point in the `opt_level` section of the `apex` documentation.
This change might also benefit people who don't use `opt_level` but instead set the individual properties manually. The original issue mentions that this is a problem with `O2`, but it's also a problem with `O0`. Having some kind of solution would be nice.
No matter what level I use, it always gives me `AttributeError: 'AmpState' object has no attribute 'handle'`. Any ideas why? I'm using pytorch-lightning 1.0.8.
@lingcong-k I have the same problem with PyTorch Lightning. Did you find a solution? Maybe we should raise an issue in the Lightning repo.
When using `amp.disable_cast()` as a function decorator with optimization level `O2`, an error is thrown. This is because `O2` sets `patch_torch_functions` to `False`, so `amp.init` is never called and no handle is created. A desirable behavior here would be to simply ignore the function decorator, i.e., use a no-op handle. Thus, I think `amp.init` should still be called, but with `enabled=False`.
https://github.com/NVIDIA/apex/blob/3ae89c754d945e407a6674aa2006d5a0e35d540e/apex/amp/_initialize.py#L230-L244
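Until something like that lands, a user-side workaround can be sketched as follows. This is illustrative only: `FakeAmpState` and `disable_casts` here are hypothetical stand-ins for apex's `AmpState` (which under `O0`/`O2` never gets a `handle` attribute), and the fallback is a plain `contextlib.nullcontext()`.

```python
import contextlib

class FakeAmpState:
    """Stand-in for apex's AmpState; under O0/O2, amp.init never ran,
    so no `handle` attribute was ever set on it."""
    pass

def disable_casts(amp_state):
    # If no handle exists, return a do-nothing context manager instead of
    # letting attribute access raise
    # AttributeError: 'AmpState' object has no attribute 'handle'.
    handle = getattr(amp_state, "handle", None)
    if handle is None:
        return contextlib.nullcontext()
    return handle.disable_casts()

state = FakeAmpState()        # mimics running under O2
with disable_casts(state):    # no AttributeError; the block is a no-op
    result = "ran fine"
print(result)                 # ran fine
```

Under `O1`, where a real handle exists, the same call would delegate to the handle's own `disable_casts()`, so the calling code needs no per-level branching.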