MAZiqing / FEDformer

Getting a 'nan' Loss after about 12 training epochs #85

Open marc-martini opened 2 months ago

marc-martini commented 2 months ago

Hi,

Well done on this amazing work and thank you so much for putting this on github and sharing.

My apologies if this is a silly question. I have been struggling to figure out where I am going wrong.

I am trying to recreate the results on the Electricity dataset. Training runs perfectly up until about epoch 12 or 13, at which point I get a 'nan' loss.

Please help me understand where I am going wrong.

thank you

tianzhou2011 commented 2 months ago

I am not sure what happened here... maybe adding a breakpoint with pdb.set_trace() can help you identify the problem. Just add something like if torch.isnan(loss): pdb.set_trace() (note that loss == nan is always False, so use torch.isnan) to check the intermediate variables or input values.
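
For example, a rough sketch of such a check, assuming a standard PyTorch training loop (model, criterion, optimizer and train_loader here are placeholders, not the exact names in this repo):

    import pdb
    import torch

    for i, (batch_x, batch_y) in enumerate(train_loader):
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)

        # loss == float('nan') is always False, so test with torch.isnan / torch.isinf instead
        if torch.isnan(loss) or torch.isinf(loss):
            pdb.set_trace()  # inspect outputs, batch_x, batch_y and model parameters here

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()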

efg001 commented 2 months ago

I was going to open an issue on this... @marc-martini, in addition to what Tianzhou shared, also make sure you are running on an A100-equivalent GPU with good FP64 support for training (let me know what device you ran it on :) ). If you are running on an A100, continue reading...

I ran into a similar issue and isolated it to the line that applies the activation function to the result of the frequency-domain attention, which contains complex numbers (the line xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xkft): all entries of the input matrices are 'good' [not nan or inf], but the result contains some inf entries). I 'fixed' it by using the softmax activation instead of tanh; can you give it a try? First set softmax as the activation function in the run config, then add activation=configs.cross_activation to https://github.com/MAZiqing/FEDformer/blob/c0f6b972def125691434d62be1ecadf710ae921a/models/FEDformer.py#L58-L73 (details below).
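
Concretely, the change looks roughly like this. This is a sketch from memory, not a verified diff: it assumes the FourierCrossAttention constructor exposes an activation keyword (as in FourierCorrelation.py) and that cross_activation is set to 'softmax' in the run config; the surrounding keyword arguments are only illustrative of the existing call.

    # models/FEDformer.py, in the block linked above (sketch only):
    # pass the configured activation through to the frequency-domain cross attention.
    decoder_cross_att = FourierCrossAttention(
        in_channels=configs.d_model,
        out_channels=configs.d_model,
        seq_len_q=self.seq_len // 2 + self.label_len,
        seq_len_kv=self.seq_len,
        modes=configs.modes,
        mode_select_method=configs.mode_select,
        activation=configs.cross_activation,  # new: use softmax instead of the tanh default
    )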

First of all, I don't know if PyTorch's support for complex numbers has stood the test of time; see 1. https://github.com/pytorch/pytorch/issues/47052 and 2. https://pytorch.org/docs/stable/complex_numbers.html ("Complex tensors is a beta feature and subject to change.")

Second, I believe the default activation function, tanh, is not a good fit for complex numbers.
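
A toy example of the problem (not the repo's code): tanh is unbounded on the complex plane, with poles at i*(pi/2 + k*pi), so entries whose imaginary part lands near pi/2 produce enormous values, while a softmax over real-valued magnitudes stays bounded in [0, 1].

    import math
    import torch

    # tanh has poles on the imaginary axis, so a complex input near i*pi/2 blows up.
    z = torch.complex(torch.tensor(1e-8), torch.tensor(math.pi / 2))
    print(torch.tanh(z), torch.tanh(z).abs())  # magnitude on the order of 1e7 with float32 rounding

    # softmax over real-valued magnitudes stays in [0, 1]
    scores = torch.tensor([1e-8, 1.0, 10.0])
    print(torch.softmax(scores, dim=0))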


I swapped out tanh for softmax and no longer see nan weights/losses/gradients.


I haven't seen this when running the code on some of the ETT datasets, and I haven't tried running it on the Electricity data either. I got nan when trying to run the model on my own dataset.

This is just something I found. I haven't gathered enough evidence for my theory to justify opening an issue, so feel free to ignore it, TianZhou -- I am only sharing it now because Marc just ran into the same issue. I am working on something else at the moment and will loop back to this.

marc-martini commented 2 months ago

Thank you for the guidance. I tried changing to softmax, but it made no difference. What I have narrowed it down to is that the weights of the Conv1d and Linear layers in the TokenEmbedding and TemporalEmbedding layers become 'nan' after some training time. Any ideas?

thank you

efg001 commented 2 months ago

Weights becoming nan is the result of an invalid gradient update, and invalid gradients are computed from invalid layer outputs during the SGD backward pass.

Whatever code you added to capture the nan, I'm guessing it either 1. missed the initial nan output from a layer, or 2. did not catch an invalid but non-nan layer output, for example inf. I think we want to know which layer is the root cause of the nan weights.

        # This lives inside the experiment class's model-building method;
        # it needs `import pdb` and `import torch` at the top of the file
        # (torch.nn is already imported as nn there).
        def nan_hook(module, input, output, model_name):
            # Forward hook: flag the first module whose output contains nan.
            # The output may be a tuple, so normalize it to one.
            if isinstance(output, tuple):
                outputs = output
            else:
                outputs = (output,)

            for out in outputs:
                # Skip non-tensor outputs (e.g. the attention list, which is
                # probably only returned for debugging).
                if not torch.is_tensor(out):
                    continue
                if torch.isnan(out).any():
                    print(f"NaN detected in {module}, model {model_name}")
                    for name, param in module.named_parameters():
                        if torch.isnan(param).any():
                            print(f"NaN detected in parameter: {name}")
                    pdb.set_trace()
                    # raise ValueError("NaN detected during forward pass")

        model = model_dict[self.args.model].Model(self.args).float()
        # `detect_nan` is a flag I added to the run arguments.
        if self.args.detect_nan:
            def register_hooks(module):
                # Register the hook on the current module regardless of
                # whether it has children...
                module.register_forward_hook(
                    lambda m, i, o: nan_hook(m, i, o, self.args.model))

                # ...then recurse through child modules so every submodule
                # is covered as well.
                for child_module in module.children():
                    register_hooks(child_module)

            register_hooks(model)
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model
...
This is what I used.

Also, try printing the loss for every iteration and check whether the loss and gradients are within a reasonable range.
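
For example, something along these lines (just a sketch; model, loss, optimizer and i are placeholders for whatever your training loop uses):

    # Inside the training loop, after loss.backward():
    with torch.no_grad():
        grad_norm = torch.sqrt(sum((p.grad ** 2).sum()
                                   for p in model.parameters() if p.grad is not None))
    print(f"iter {i}: loss={loss.item():.6f}  grad_norm={grad_norm.item():.6f}")

    # And after optimizer.step(): flag any parameter that has gone nan/inf,
    # e.g. the Conv1d/Linear weights in the embedding layers mentioned above.
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            print(f"iter {i}: parameter {name} contains nan/inf")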