thuml / Anomaly-Transformer

About: Code release for "Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy" (ICLR 2022 Spotlight), https://openreview.net/forum?id=LzQQ89U1qm_
MIT License
758 stars, 199 forks

About 3090 #11

Closed: xuchunyu123 closed this issue 2 years ago

xuchunyu123 commented 2 years ago

Hello, in other issues I noticed you mentioned that your environment uses an RTX 3090, but the RTX 3090 is only supported by CUDA 11 and above. How did you solve this?

wuhaixu2016 commented 2 years ago

You can change the CUDA toolkit version.

elisejiuqizhang commented 2 years ago

Hey,

Bravo to the authors @wuhaixu2016 @Jiehui-Xu, great work. I would just like to add one minor comment that may be relevant here (not exactly an "issue" though): with the following modification, I believe this implementation can run on torch 1.5 or higher.

I saw that README.md specifies PyTorch 1.4 for your implementation, which by default comes with CUDA 10.1.

My GPU is also an NVIDIA GeForce RTX 3090 (queried on the command line with nvidia-smi --query-gpu=name --format=csv,noheader) and my CUDA version is 11.5, so I decided to use torch 1.7 with CUDA 11.
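A quick way to check what PyTorch itself sees: nvidia-smi reports the GPU and the highest CUDA version the driver supports, whereas torch.version.cuda reports the CUDA version the installed PyTorch binary was built against, which is what has to be new enough (11.x) for an RTX 3090. The helper name describe_torch_cuda is made up for this sketch; the torch calls it uses are standard.

```python
import torch


def describe_torch_cuda() -> dict:
    """Summarize the installed PyTorch build and the first visible GPU, if any.

    describe_torch_cuda is a hypothetical helper written for this example.
    """
    info = {
        "torch": torch.__version__,
        # CUDA version this PyTorch binary was built against (None for CPU-only builds).
        "built_with_cuda": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device"] = torch.cuda.get_device_name(0)
        # Compute capability as (major, minor); an RTX 3090 reports (8, 6).
        info["capability"] = torch.cuda.get_device_capability(0)
    return info


print(describe_torch_cuda())
```

If "built_with_cuda" shows 10.x on a 3090 machine, the binary predates CUDA 11 and a newer torch/cudatoolkit combination is needed.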

In my environment, the error was initially raised in solver.py, specifically in the minimax training step, so it looked like something was going wrong with autograd. The error message was roughly as follows (with torch.autograd.set_detect_anomaly(True) enabled to get a more detailed error message):

    File "main.py", line 54, in <module>
        main(config)
    File "main.py", line 23, in main
        solver.train()
    File "/usr/local/data/elisejzh/Projects/Anomaly-Transformer/solver.py", line 191, in train
        loss2.backward()
    File "/usr/local/data/elisejzh/anaconda3/envs/AnoTrans/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
    File "/usr/local/data/elisejzh/anaconda3/envs/AnoTrans/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 55]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
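The same error can be reproduced with a much smaller model. This is a sketch under stated assumptions, not the Anomaly Transformer's solver.py: the two-layer network and the losses loss1/loss2 are hypothetical placeholders for the real minimax losses, and old_ordering_raises is a name invented for this example. It keeps the old ordering (step() between the two backward() calls) and reports whether autograd raises the in-place modification error; on torch 1.5 or later it should print True.

```python
import torch
import torch.nn as nn


def old_ordering_raises(seed: int = 0) -> bool:
    """Call step() between the two backward() calls; return True if autograd complains.

    The model and both losses are hypothetical stand-ins, not the real solver.py code.
    """
    torch.manual_seed(seed)
    # Two layers, so the second Linear's input requires grad and its weight is saved for backward.
    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    out = model(torch.randn(8, 4))
    loss1 = out.pow(2).mean()   # placeholder for the first minimax loss
    loss2 = -out.abs().mean()   # placeholder for the second minimax loss

    optimizer.zero_grad()
    loss1.backward(retain_graph=True)
    optimizer.step()            # updates the weights in place, bumping their version counters
    try:
        loss2.backward()        # still needs the weights saved during the forward pass
    except RuntimeError as err:
        return "inplace operation" in str(err)
    return False


print(old_ordering_raises())
```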

I did a bit of searching on the PyTorch forum and found the following explanation:
Link to the first discussion: https://discuss.pytorch.org/t/training-with-two-optimizers-fails-in-torch-1-5-0-works-in-1-1-0/84194
Link to a detailed explanation within the above thread: https://github.com/pytorch/pytorch/issues/39141#issuecomment-636881953

So, in answer to @xuchunyu123's question (which I believe was also raised in an earlier issue by @Alex-seven about "gradient computation"): if you want to use a higher version of PyTorch (1.5 or higher) with CUDA support, the solution seems to be simply reordering the lines of the minimax strategy in train() of solver.py so that both step() calls come after the backward() calls, like this:

    loss1.backward(retain_graph=True)
    loss2.backward()
    self.optimizer.step()
    self.optimizer.step()

Then it works even in environments with a higher torch version.

The rationale, as explained in the forum, seems to be that torch versions up to and including 1.4 did not always compute the correct gradient in this situation, and this was fixed in 1.5 and later.

In earlier versions, calling step() before the second backward() would run without errors, but step() modifies in place the parameters that backward() still needs for its gradient computation, so the resulting gradients could in fact be wrong.

This was fixed in later versions, so if you still order the code the way that worked for 1.4 and earlier (i.e., step() before the gradient computation in backward()), autograd raises the "modified by an inplace operation" error shown above, because a tensor it saved for the backward pass was changed in place afterwards. So make sure step() is called only after all gradients have been computed by backward().
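To make the version-counter explanation concrete without needing PyTorch, here is a toy model in plain Python. ToyTensor, SavedForBackward and grad_wrt_input are hypothetical names invented for this sketch; they only mimic the idea behind autograd's check (a saved tensor remembers its version, and backward refuses to use it if an in-place update has bumped that version since), not PyTorch's actual implementation.

```python
# Toy model of autograd's saved-tensor version check; NOT PyTorch's implementation.


class ToyTensor:
    """Hypothetical stand-in for a parameter: a value plus a version counter."""

    def __init__(self, value: float):
        self.value = value
        self.version = 0

    def add_(self, delta: float) -> None:
        # Every in-place update bumps the version, as an optimizer step() would.
        self.value += delta
        self.version += 1


class SavedForBackward:
    """Hypothetical record of which version of a tensor the forward pass saved."""

    def __init__(self, tensor: ToyTensor):
        self.tensor = tensor
        self.saved_version = tensor.version

    def unpack(self) -> float:
        # Refuse to use a tensor that was changed after it was saved.
        if self.tensor.version != self.saved_version:
            raise RuntimeError(
                "one of the variables needed for gradient computation has been "
                f"modified by an inplace operation: is at version {self.tensor.version}; "
                f"expected version {self.saved_version} instead"
            )
        return self.tensor.value


def grad_wrt_input(saved_w: SavedForBackward, grad_out: float) -> float:
    """Backward of y = w * x with respect to x: dy/dx = w, so the saved w must be unchanged."""
    return grad_out * saved_w.unpack()


w = ToyTensor(2.0)
saved = SavedForBackward(w)        # forward pass of y = w * x saves w
print(grad_wrt_input(saved, 1.0))  # first backward() uses w = 2.0
w.add_(-0.1)                       # step() between the two backward() calls
try:
    grad_wrt_input(saved, 1.0)     # second backward() now fails the version check
except RuntimeError as err:
    print("RuntimeError:", err)
```

Moving the w.add_(-0.1) call after the second grad_wrt_input call lets both "backward" calls succeed, which is the same reordering suggested above.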

I also wonder whether the authors would consider looking into this and testing it in a newer torch version compatible with your GPU's CUDA version, since I think force-installing an old torch version like 1.4 alongside a much newer cudatoolkit (e.g., 11.x) could cause further compatibility issues.

Cheers, Elise

wuhaixu2016 commented 2 years ago

@elisejiuqizhang Thanks for your wonderful comment. I will test this asap.

wuhaixu2016 commented 2 years ago

Hi, I have updated the code in solver.py, and it now works in torch>=1.4.0 environments. The updated code is as follows:

    loss1.backward(retain_graph=True)
    loss2.backward()
    self.optimizer.step()

Calling step() just once is fine. Thanks a lot for your help, @elisejiuqizhang.
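A matching sketch of the updated ordering, again with a hypothetical toy model and placeholder losses rather than the real solver.py code (minimax_step is a name made up for this example): both backward() calls run first, so the gradients of loss1 and loss2 accumulate in each parameter's .grad, and a single step() then applies one update. Calling step() a second time would simply apply another update computed from the same accumulated gradients.

```python
import torch
import torch.nn as nn


def minimax_step(model: nn.Module, optimizer: torch.optim.Optimizer, x: torch.Tensor):
    """Both backward() calls first, then one step().

    minimax_step and the two losses are hypothetical stand-ins for the real solver.py code.
    """
    out = model(x)
    loss1 = out.pow(2).mean()   # placeholder for the first minimax loss
    loss2 = -out.abs().mean()   # placeholder for the second minimax loss

    optimizer.zero_grad()
    loss1.backward(retain_graph=True)  # keep the graph alive for the second backward()
    loss2.backward()                   # loss2's gradients are added to loss1's in .grad
    optimizer.step()                   # weights change only after all gradients exist
    return loss1.item(), loss2.item()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
before = [p.detach().clone() for p in model.parameters()]
l1, l2 = minimax_step(model, optimizer, torch.randn(8, 4))
changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
print(changed)
```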