harvardnlp / TextFlow

errors for charptb_discreteflow_af-scf #1

Open yhgon opened 5 years ago

yhgon commented 5 years ago

Thanks for sharing this awesome work.

I'm trying to reproduce your results on the PTB dataset. The baseline and charptb_discreteflow_af-af configurations train fine, but I get errors for charptb_discreteflow_af-scf. Could you check it?

With charptb_discreteflow_af-scf the loss is NaN and the run eventually crashes with NameError: name 'cur_impatience' is not defined. Thanks. The last epoch of the working af-af run and the failing af-scf log are below.

Epoch 97 | train loss: 0.885, val loss: 0.976, val NLL (0): -0.000 | train kl: 0.000, val kl: 0.000 | kl_weight: 1.000, time: 113.07s/4.26s

Epoch 0 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 762.07s/19.64s
Epoch 1 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 758.30s/19.49s
Epoch 2 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 760.94s/19.00s
Epoch 3 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 760.13s/18.88s
Epoch 4 | train loss: nan, val loss: nan, val NLL (30): nan | train kl: nan, val kl: nan | kl_weight: 0.000, time: 759.14s/59.24s
Epoch 5 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.100, time: 756.40s/18.98s
Epoch 6 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.200, time: 746.24s/18.95s
Epoch 7 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.300, time: 759.76s/18.85s
Epoch 8 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.400, time: 757.32s/18.70s
Epoch 9 | train loss: nan, val loss: nan, val NLL (30): nan | train kl: nan, val kl: nan | kl_weight: 0.500, time: 751.98s/59.42s
Epoch 10 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.600, time: 747.67s/18.53s
Epoch 11 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.700, time: 750.80s/18.60s
Epoch 12 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.800, time: 750.70s/18.71s
Epoch 13 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.900, time: 750.79s/18.66s
Epoch 14 | train loss: nan, val loss: nan, val NLL (30): nan | train kl: nan, val kl: nan | kl_weight: 1.000, time: 750.22s/59.38s
Epoch 15 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 1.000, time: 746.29s/17.96s
Traceback (most recent call last):
  File "main.py", line 217, in <module>
    cur_impatience += 1
NameError: name 'cur_impatience' is not defined

yhgon commented 5 years ago

The charptb_discreteflow_iaf-scf configuration also fails to learn, producing NaN losses:

python main.py --dataset ptb --run_name charptb_discreteflow_iaf-scf --dropout_p 0 --hiddenflow_flow_layers 3 --hiddenflow_scf_layers --prior_type IAF
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
/mnt/git/TextFlow/lstm_flow.py:157: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  y = torch.tensor(x) # [T, B, inp_dim]
/mnt/git/TextFlow/flows.py:188: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  y = torch.tensor(x)
Epoch 0 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 4479.11s/133.32s
Epoch 1 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 4474.15s/128.58s
Epoch 2 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 4920.69s/145.28s
Epoch 3 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.000, time: 4716.16s/135.83s
Epoch 4 | train loss: nan, val loss: nan, val NLL (30): nan | train kl: nan, val kl: nan | kl_weight: 0.000, time: 4872.09s/296.83s
Epoch 5 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.100, time: 4514.57s/141.28s
Epoch 6 | train loss: nan, val loss: nan, val NLL (0): -0.000 | train kl: nan, val kl: nan | kl_weight: 0.200, time: 4950.72s/129.10s
hyunjik11 commented 4 years ago

When running the code on PyTorch 0.4.1 (as listed in the Dependencies section of the README), the NaN behaviour above did not happen and learning seems stable (although I haven't checked whether the results in the paper are replicated yet).

However, when running the code with PyTorch >= 1.0, I can confirm that the NaN issue above emerges. There is a fix, though. I looked into what was causing the NaN and observed that log_p_z in kl_loss was growing toward large negative values, eventually becoming NaN.

Looking at the release notes for PyTorch 1.0 (which list what changed from v0.4.1), it turns out that the change made to torch.tensor was causing the issue.

In version 0.4.1, y = torch.tensor(x) did not return a detached tensor for tensor inputs x (i.e. y.requires_grad == True if x.requires_grad == True), whereas for versions >= 1.0 it does detach (i.e. y.requires_grad == False). This leads to incorrect gradients when calling obj_per_accum.backward(), due to the calls to torch.tensor in SCFLayer.forward, AFLayer.generate, and LSTM_AFLayer.generate. Replacing y = torch.tensor(x) with y = x.clone() seems to resolve the issue and leads to stable training.
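
To illustrate (a minimal standalone sketch, not code from the repository), the version difference and the clone() fix look roughly like this:

    import torch

    x = torch.randn(3, requires_grad=True)

    # On PyTorch >= 1.0, torch.tensor(x) copies x but detaches it from the
    # autograd graph, so no gradients flow back to x through y_detached.
    y_detached = torch.tensor(x)
    print(y_detached.requires_grad)  # False

    # x.clone() keeps the copy in the graph, matching the 0.4.1 behaviour
    # that the flow layers rely on.
    y_cloned = x.clone()
    print(y_cloned.requires_grad)    # True

    y_cloned.sum().backward()
    print(x.grad)                    # tensor([1., 1., 1.])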

hyunjik11 commented 4 years ago

A side note on compatibility of the code with PyTorch 1.0 and 1.1: using in-place operations on tensors that feed into the training loss obj_per_accum can give rise to error messages such as RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation in PyTorch v1.0 or v1.1, which do not have the in-place correctness checks implemented in versions >= 1.2. This can be prevented by replacing the in-place operations with equivalent operations that create new tensors, in particular the lines that operate in place on loss in run_epoch of main.py. Namely, use:

        indices = torch.arange(batch_data.shape[0]).view(-1, 1).to(device)  # [T, 1]
        loss_mask = indices >= lengths.view(1, -1)  # [T, B]
        loss_mask = loss_mask[:, :, None].repeat(1, 1, loss_.shape[-1])  # [T, B, s]

        loss = loss * (1 - loss_mask.float())
        kl_loss = kl_loss * (1 - loss_mask.float())

instead of the original code: https://github.com/harvardnlp/TextFlow/blob/0eb8552ae5b370f69e440bca6c6c226b1734c1ba/main.py#L53-L58

In general, I think it's best to avoid in-place operations with autograd, although for PyTorch >= 1.2 the in-place correctness checks do confirm that you are getting correct gradients whenever there is no error message.
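
As a small, self-contained illustration of the failure mode (a hypothetical example, not taken from main.py): modifying a tensor in place after an op that saves its output for backward triggers the error, while the out-of-place equivalent does not.

    import torch

    w = torch.randn(3, requires_grad=True)
    y = torch.exp(w)   # exp() saves its output for the backward pass

    # An in-place update would corrupt the saved output; on backward() this
    # raises: RuntimeError: one of the variables needed for gradient
    # computation has been modified by an inplace operation
    # y *= 0.5

    # The out-of-place equivalent allocates a new tensor and is always safe:
    y = y * 0.5
    y.sum().backward()
    print(w.grad)      # equals 0.5 * exp(w)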

yhgon commented 4 years ago

Thanks for the confirmation and the guidance. I'll try it.