NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Cannot use `amp.disable_cast()` when `patch_torch_functions=False`. #510

Open huzecong opened 5 years ago

huzecong commented 5 years ago

When using amp.disable_casts() as a function decorator with optimization level O2, the following error is raised:

AttributeError: 'AmpState' object has no attribute 'handle'

This is because O2 sets patch_torch_functions to False, and thus amp.init is not called and no handle is created.

The desired behavior here would be to simply ignore the decorator, i.e. to use a no-op handle. So I think amp.init should still be called, just with enabled=False.

https://github.com/NVIDIA/apex/blob/3ae89c754d945e407a6674aa2006d5a0e35d540e/apex/amp/_initialize.py#L230-L244

FDecaYed commented 5 years ago

If a piece of code needs disable_casts(), it usually means O2 will break in that region. Conversely, if O2 works, you probably don't need disable_casts(). If we make it a no-op, code with disable_casts() will still run after switching to O2, but I suspect it would give bad results that are harder to debug. Do you have another use case that supports making it a no-op?

huzecong commented 5 years ago

In my case, the function I'm using disable_casts() with is type-agnostic and should work with any floating-point type. I used disable_casts() because the inputs to the function are fp32 tensors, but the function calls a whitelisted PyTorch function that would force-cast the tensors to fp16 under O1, and I wanted to avoid that. I've also tested that the exact same code (with disable_casts() removed) works under O2.

FDecaYed commented 5 years ago

What's the input type to the function in O2? I would assume it is also fp16, since O2 converts the model weights (and the first layer's input) to fp16. If for some reason the input to this function is fp32 (for example, because of an explicit cast earlier in your model), then it should not work, since you'd see an input/weight type mismatch. To back that up, I don't think there are any whitelisted ops that don't take a weight. Does "works under O2" mean it works numerically? If so, I'm confused about why you want to keep O1 from converting to fp16 for you.

huzecong commented 5 years ago

It's a bit tricky in my case, because even in O2 I needed certain weights to stay in FP32, so I manually convert them back after calling amp.initialize. This particular function computes a kind of attention over the input tensors and basically only involves torch.bmm (which I believe is a whitelisted function). The reason I had to keep it out of FP16 is that the tensors have fairly large values, and a BMM in FP16 would overflow.

After manually converting part of the parameters back to FP32, the model works numerically. I understand there's no point in going back to O1, but I feel the current behavior goes against the premise of "change the optimization level and it just works". Then again, changing the current behavior might cause more problems, and I may not have thought it through.
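
Here is a rough numerical picture of the overflow described above. The largest finite IEEE 754 half-precision value is 65504, so any `bmm` output entry larger than that cannot be stored in FP16 and becomes `inf`. The pure-Python sketch below uses the `struct` module's `'e'` (binary16) format to round a value to FP16. The vector length and magnitudes are made-up illustrative numbers, and `to_fp16` and `dot` are hypothetical helpers rather than PyTorch calls.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 binary16 value


def to_fp16(x: float) -> float:
    """Round x to the nearest half-precision value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def dot(q, k):
    """One entry of a batched matmul (q @ k), accumulated in double precision."""
    return sum(a * b for a, b in zip(q, k))


# Illustrative "large" attention inputs: 256-dim vectors with every entry equal to 20.
q = [20.0] * 256
k = [20.0] * 256

score = dot(q, k)
print(score)             # 102400.0, representable in FP32
print(score > FP16_MAX)  # True
print(to_fp16(score))    # inf: the value overflows once stored as FP16
```

Since 256 × 20 × 20 = 102400 exceeds 65504, the score is fine in FP32 but overflows once stored as FP16. A whitelisted `bmm` running under O1 would hit exactly this, because its output is produced in FP16.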

FDecaYed commented 5 years ago

I understand your case now: you have code that works with O2, you need to add disable_casts() to be able to switch to O1, and you want the code to keep working, without branching, when you switch back to O2.

I still feel that making disable_casts() a no-op in O2 could be bad, since people using it with O1 will try O2 and find that it doesn't work numerically.

We could potentially make disable_casts() do in O2 what you did manually, so that after switching it still 'just works'. I'm not sure how feasible that is to implement, though.

huzecong commented 5 years ago

In terms of implementation, would it be possible to use the NoOpHandle in place of the functional handle? https://github.com/NVIDIA/apex/blob/0b74bfd92ba0846ca29b9bd2c6dc18dd3a5d9b20/apex/amp/handle.py#L251

As for user expectations, I think users can reasonably expect that not all optimization levels will work (at least, I've seen this mentioned in many tutorials and docs), but I understand your concern. It might help to reiterate this point in the opt_level section of the apex documentation.

This change might also be beneficial for people who don't use opt_level but instead manually set properties.

jbraeburn commented 4 years ago

The original issue mentions that this is a problem with O2, but it's also a problem with O0. Having some kind of a solution would be nice.

lingcong-k commented 3 years ago

No matter which opt level I use, I always get AttributeError: 'AmpState' object has no attribute 'handle'. Any ideas why? I'm using PyTorch Lightning 1.0.8.

aoussou commented 3 years ago

@lingcong-k

I have the same problem with PyTorch Lightning. Did you find a solution? Maybe we should raise an issue in the Lightning repo.