aistairc / rnng-pytorch

training fails for in-order model on new dataset #4

Open benlipkin opened 1 year ago

benlipkin commented 1 year ago

Thanks for releasing this codebase. It has made training time much more efficient for some experiments I have been running with RNNGs.

While attempting to scale up and train these models on a larger dataset, I ran into an error: the top-down variant trains successfully, but the in-order model fails repeatedly during training.

I have attached the error trace from running train.py. It looks like the kind of error I have typically seen when dimensions are mismatched between predictions and targets, or when an index is out of bounds. I have also used the logger to print the token and action sequences it failed on, so that you can attempt to replicate this.

I wanted to report this in case there's an unusual edge case that might have been missed, and to ask whether you have suggestions for a patch.

Please let me know if I can provide any additional information that might help with debugging.

Thanks!

2022-11-10 01:44:02,545:__main__:INFO: tokens: torch.Size([1, 152])
2022-11-10 01:44:02,545:__main__:INFO: tokens: ['but', 'it', 'got', 'so', 'time', 'consuming', 'it', 'got', 'to', 'be', 'to', 'where', 'um', 'uh', '-', 'huh', 'i', 'had', 'to', 'decide', 'either', 'i', 'was', 'gon', 'na', 'just', 'quit', 'all', 'that', 'mess', 'and', 'go', 'study', 'or', 'or', 'or', 'you', 'know', 'just', 'basically', 'skin', 'through', 'and', 'at', 'that', 'point', 'did', 'you', 'almost', 'say', 'gee', 'i', "'", 'm', 'gon', 'na', 'have', 'to', 'get', 'out', 'of', 'this', 'school', 'to', 'do', 'that', 'uh', 'well', 'i', 'quit', 'i', 'quit', 'football', 'i', 'quit', 'football', 'because', 'i', 'just', 'uh', 'figured', 'that', 'it', "'s", 'not', 'worth', 'it', 'what', 'am', 'i', 'going', 'to', 'do', 'with', 'this', 'after', 'i', 'get', 'out', 'i', "'", 'm', 'not', 'gon', 'na', 'go', 'pro', 'football', 'uh', 'i', "'", 'm', 'not', 'good', 'enough', 'i', 'faced', 'the', 'you', 'know', 'faced', 'that', 'up', 'front', 'and', 'say', 'you', 'know', 'i', "'d", 'rather', 'get', 'my', 'grades', 'and', 'that', "'s", 'what', 'i', 'did', 'i', 'quit', 'and', 'went', 'after', 'after', 'the', 'grades', 'have', 'you', 'been', 'sorry']
2022-11-10 01:44:02,545:__main__:INFO: actions: torch.Size([1, 465])
2022-11-10 01:44:02,545:__main__:INFO: actions: ['SHIFT', 'NT(S)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(ADJP)', 'SHIFT', 'SHIFT', 'REDUCE', 'NT(ADJP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(PP)', 'SHIFT', 'NT(WHADVP)', 'REDUCE', 'NT(SBAR)', 'SHIFT', 'NT(S)', 'SHIFT', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'SHIFT', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'REDUCE', 'REDUCE', 'SHIFT', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(ADVP)', 'SHIFT', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(PP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(PP)', 'SHIFT', 'NT(PP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(NP)', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 
'NT(PP)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(PP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(NP)', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(NP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'NT(ADVP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'REDUCE', 'NT(ADVP)', 'SHIFT', 'REDUCE', 'NT(ADVP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'NT(ADJP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'NT(ADJP)', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 'SHIFT', 'SHIFT', 'NT(ADJP)', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'REDUCE', 'NT(VP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(NP)', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(NP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(ADVP)', 'SHIFT', 'REDUCE', 'NT(SBAR)', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(ADVP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(WHNP)', 'REDUCE', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'REDUCE', 'NT(S)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(NP)', 'REDUCE', 'SHIFT', 'NT(VP)', 'REDUCE', 'REDUCE', 'NT(VP)', 'SHIFT', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(PP)', 'REDUCE', 'SHIFT', 'NT(SBAR)', 'SHIFT', 'NT(NP)', 'SHIFT', 'REDUCE', 'NT(NP)', 'SHIFT', 'SHIFT', 'NT(NP)', 'REDUCE', 'REDUCE', 'NT(NP)', 'SHIFT', 'NT(VP)', 'SHIFT', 'NT(ADJP)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(SBAR)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(SBAR)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(SBAR)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 
'REDUCE', 'NT(FRAG)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(S)', 'REDUCE', 'REDUCE', 'REDUCE', 'NT(SBAR)', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'REDUCE', 'FINISH']
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [1,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [2,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [3,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [4,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [5,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [6,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [7,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
Traceback (most recent call last):
  File "train.py", line 679, in <module>
    main(args)
  File "train.py", line 495, in main
    token_ids, action_ids, max_stack_size, subword_end_mask
  File "train.py", line 467, in try_batch_step
    torch.cuda.empty_cache()
  File "/om5/group/evlab/u/lipkinb/.conda/envs/rnng/lib/python3.6/site-packages/torch/cuda/memory.py", line 114, in empty_cache
    torch._C._cuda_emptyCache()
RuntimeError: CUDA error: device-side assert triggered

hiroshinoji commented 1 year ago

Thank you very much for reporting!

Could you try the command again with the CUDA_LAUNCH_BLOCKING=1 prefix, as below?

CUDA_LAUNCH_BLOCKING=1 python train.py ...

Without this, CUDA kernels are launched asynchronously (for efficiency), so CPU and GPU computations are not synchronized and the traceback may point at an unrelated line, which makes debugging hard.
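If prefixing the command is inconvenient (for example under a job scheduler), the variable can also be set from Python. The sketch below is an assumption about how one might do that, not something train.py does; `enable_blocking_launches` is a hypothetical helper, and it has to run before CUDA is initialized (safest: before `import torch`).

```python
import os


def enable_blocking_launches(env=os.environ):
    """Ask CUDA to launch kernels synchronously in this process.

    Hypothetical helper, not part of rnng-pytorch. It only has an effect
    if it runs before the CUDA context is created.
    """
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env["CUDA_LAUNCH_BLOCKING"]


enable_blocking_launches()
# import torch  # import torch only after the variable has been set
```

Running the failing batch on CPU, where it fits in memory, is another way to get a traceback that points at the offending indexing line.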

I suspect that the error is caused by a bug that only occurs for very long sentences like this one. Another possibility is that your input parse tree is not well formed (e.g., split into two subtrees?), which would cause errors during tree construction.
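A quick way to test the second possibility is to count complete top-level trees on each line of the input file. This is only a sketch: `count_top_level_trees` is a hypothetical helper, not part of this repository, and it assumes one bracketed tree per line with literal parentheses in tokens already escaped (e.g. as -LRB- / -RRB-).

```python
def count_top_level_trees(line):
    """Return how many complete top-level bracketed trees `line` contains.

    Raises ValueError if the brackets are unbalanced.
    """
    depth = 0
    roots = 0
    for ch in line:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                raise ValueError("unmatched ')'")
            if depth == 0:
                # A bracket that returns to depth 0 closes one root.
                roots += 1
    if depth != 0:
        raise ValueError("%d unclosed '('" % depth)
    return roots
```

A well-formed line should give exactly 1; a value of 2 or more would mean the line was split into several subtrees.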

benlipkin commented 1 year ago

Thanks for getting back to me. So, this was actually run using the CUDA_LAUNCH_BLOCKING=1 flag. Without it, the

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [0,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.

assertions do not appear. There are 8 of them here, since I squeezed --h_dim down to 8 to speed things up while debugging. If I set --h_dim to, e.g., 128, then there are 128 assertion errors, at thread indices [0-127].

I can provide some more info, though. I went back to the original corpus (OANC) and found the tree corresponding to this failed sample:

(S (CC but) (S (NP (PRP it)) (VP (VBD got) (ADJP (ADJP (RB so) (NN time) (JJ consuming)) (SBAR (S (NP (PRP it)) (VP (VBD got) (S (VP (TO to) (VP (VB be) (PP (IN to) (SBAR (WHADVP (WRB where)) (S (RB um) (RB uh) (HYPH -) (RB huh) (NP (PRP i)) (VP (VBD had) (S (VP (TO to) (VP (VP (VB decide) (CC either) (S (S (NP (PRP i)) (VP (VBD was) (NP (NN gon) (FW na)) (ADVP (RB just)) (VP (VB quit) (NP (PDT all) (DT that) (NN mess))))) (CC and) (VP (VB go) (NP (NN study)))) (CC or) (CC or) (CC or) (S (NP (PRP you)) (VP (VBP know) (ADVP (RB just) (RB basically)) (VP (NN skin) (RB through))))) (CC and) (S (VP (PP (IN at) (NP (DT that) (NN point))) (VBD did) (NP (PRP you)) (ADVP (RB almost)) (VP (VB say) (SBAR (PRP gee) (NP (PRP i)) (`` ') (VBP m) (NP (NN gon) (FW na)) (VP (VB have) (S (VP (TO to) (VP (VP (VB get) (PP (IN out) (PP (IN of) (NP (DT this) (NN school))))) (S (VP (TO to) (VP (VB do) (NP (NP (DT that)) (RB uh) (IN well) (NP (PRP i)) (VP (VBD quit) (NP (PRP i)) (VP (VP (VB quit) (NP (NN football))) (NP (PRP i)) (VP (VBD quit) (NP (NN football)) (SBAR (IN because) (S (NP (PRP i)) (VP (ADVP (RB just)) (RB uh) (VBD figured) (SBAR (IN that) (S (NP (PRP it)) (VP (VP (VBZ 's) (RB not) (VP (JJ worth) (PP (PRP it)))) (WP what) (S (VP (. am) (VP (. 
i) (VP (VBG going) (S (VP (TO to) (VP (VB do) (PP (IN with) (NP (NP (DT this)) (SBAR (IN after) (FRAG (NP (NP (PRP i)) (VP (VB get) (ADVP (ADVP (ADVP (ADVP (RP out)) (NP (PRP i))) (`` ')) (VP (VB m) (RB not) (ADJP (NP (VB gon) (FW na)) (VP (VB go) (ADJP (NP (JJ pro) (RB football)) (UH uh) (PRP i) (VP (`` ') (VB m) (RB not) (ADJP (JJ good) (RB enough) (SBAR (S (NP (PRP i)) (VP (VP (VBN faced)) (NP (NP (DT the) (NP (NP (PRP you)) (ADVP (VB know)))) (VP (VBN faced) (IN that) (SBAR (ADVP (IN up) (RB front)) (CC and) (VP (VB say) (SBAR (S (NP (PRP you)) (VP (VBP know) (SBAR (S (NP (PRP i)) (VP (MD 'd) (ADVP (RB rather)) (VP (VP (VB get) (NP (PRP$ my) (NNS grades))) (CC and) (SBAR (IN that) (S (VP (VBZ 's) (SBAR (WHNP (WP what)) (S (NP (PRP i)) (VP (VP (VBD did) (NP (PRP i)) (VP (VB quit))) (CC and) (VP (VBD went) (PP (IN after)) (SBAR (IN after) (NP (NP (NP (DT the) (NNS grades)) (VBP have) (NP (PRP you))) (VP (VBN been) (ADJP (JJ sorry))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))

It was successfully handled by preprocess.py, which leads me to believe it is a valid tree (unless it was parsed improperly there?), although it is indeed very long and unusual, since it is a transcription of a spoken news interview.

I've also tried running without the --fixed_stack flag, and I actually see a different, potentially more informative error there.

Traceback (most recent call last):
  File "train.py", line 677, in <module>
    main(args)
  File "train.py", line 493, in main
    token_ids, action_ids, max_stack_size, subword_end_mask
  File "train.py", line 476, in try_batch_step
    token_ids, action_ids, max_stack_size, subword_end_mask, num_divides * 2
  File "train.py", line 459, in try_batch_step
    token_ids, action_ids, max_stack_size, subword_end_mask, num_divides
  File "train.py", line 440, in batch_step
    div_token_ids, div_action_ids, max_stack_size, div_subword_end_mask
  File "train.py", line 412, in calc_loss
    subword_end_mask=subword_end_mask,
  File "/om5/group/evlab/u/lipkinb/.conda/envs/rnng/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/storage001.ib.cluster/om2/group/evlab/u/lipkinb/projects/rnng/rnng-pytorch/models.py", line 219, in forward
    a_loss, _ = self.action_loss(actions, self.action_dict, action_contexts)
  File "/net/storage001.ib.cluster/om2/group/evlab/u/lipkinb/projects/rnng/rnng-pytorch/models.py", line 240, in action_loss
    assert hiddens.size()[:2] == actions.size()
AssertionError

It seems that some mismatch is occurring between the length of the actions derived during preprocess.py and those generated on the fly during unroll_states.
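One way to narrow this down would be to replay the logged action sequence outside the model. The sketch below assumes the textbook in-order transition system (SHIFT pushes a token, NT(X) opens a nonterminal above the completed item on top of the stack, REDUCE closes the most recent open nonterminal together with its left child). `check_in_order_actions` is a hypothetical helper; it is not the repository's implementation and may differ from what unroll_states actually does.

```python
ITEM, OPEN = "item", "open"


def check_in_order_actions(actions, num_tokens):
    """Replay in-order actions on a symbolic stack; return a list of problems."""
    problems = []
    stack = []
    shifts = 0
    for step, action in enumerate(actions):
        if action == "SHIFT":
            shifts += 1
            stack.append(ITEM)
        elif action.startswith("NT("):
            # In-order: the left child must already be complete on the stack.
            if not stack or stack[-1] != ITEM:
                problems.append("step %d: %s has no completed left child" % (step, action))
            stack.append(OPEN)
        elif action == "REDUCE":
            # Pop the right-hand children, then the open NT and its left child.
            while stack and stack[-1] == ITEM:
                stack.pop()
            if len(stack) < 2 or stack[-2] != ITEM:
                problems.append("step %d: REDUCE without open nonterminal and left child" % step)
                break
            stack.pop()
            stack.pop()
            stack.append(ITEM)
        elif action == "FINISH":
            if step != len(actions) - 1:
                problems.append("step %d: FINISH is not the last action" % step)
        else:
            problems.append("step %d: unknown action %r" % (step, action))
    if shifts != num_tokens:
        problems.append("%d SHIFT actions for %d tokens" % (shifts, num_tokens))
    if stack != [ITEM]:
        problems.append("final stack holds %d elements, expected 1" % len(stack))
    return problems
```

For example, `check_in_order_actions(actions, len(tokens))` on the two lists logged above should return an empty list if the sequence is consistent under these assumptions; any reported step would be a starting point for comparing against the preprocessed actions.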

Let me know what other info would be helpful as you explore this. And thanks again for looking into this.