CUNY-CL / yoyodyne

Small-vocabulary sequence-to-sequence generation with optional feature conditioning
Apache License 2.0

Attempts migration to PyTorch >= 2.0.0. #173

Closed kylebgorman closed 2 months ago

kylebgorman commented 2 months ago

See #60 for context, though we are not yet done with it since we are not migrating to PyTorch-Lightning 2.

Adamits commented 2 months ago

I have been running tests with torch 2.1.0 (though 2.2.x is out), and everything seems to work ok without any changes to code.

Given the additional args you added (or uncommented) to the torch mha calls, have you tested the pg-transformer?

Additionally, torch 2.0 introduces torch.compile, which I believe offers some JIT method of optimizing models. On my Mac CPU, this was erroring. I think the reason torch >= 2.0 is backwards compatible is that torch.compile is optional. We should get an understanding of what it gets us and how to use it, though.

kylebgorman commented 2 months ago

I have been running tests with torch 2.1.0 (though 2.2.x is out), and everything seems to work ok without any changes to code.

Given the additional args you added (or uncommented) to the torch mha calls, have you tested the pg-transformer?

I do have a test for that:

# Model parameters from Singer & Kann:
# https://aclanthology.org/2020.sigmorphon-1.8/
yoyodyne-train \
    --experiment "${LANGUAGE}" \
    --train "${TRAIN}" \
    --val "${VAL}" \
    --model_dir "${MODEL_DIR}" \
    --arch "${ARCH}" \
    --batch_size 400 \
    --hidden_size 1024 \
    --embedding_size 256 \
    --decoder_layers 4 \
    --encoder_layers 4 \
    --dropout .3 \
    --beta2 .98 \
    --gradient_clip_val 3 \
    --check_val_every_n_epoch 16 \
    --log_every_n_step 2 \
    --max_epochs 800 \
    --scheduler warmupinvsqrt \
    --warmup_steps 4000 \
    --seed 49 \
    --accelerator gpu

Happy to tweak parameters if you think those are bad.

Additionally, torch 2.0 introduces torch.compile, which I believe offers some JIT method of optimizing models. On my Mac CPU, this was erroring. I think the reason torch >= 2.0 is backwards compatible is that torch.compile is optional. We should get an understanding of what it gets us and how to use it, though.

Sure, let's play with that too.
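
For reference, a minimal sketch of how torch.compile is typically applied (toy module, not wired into yoyodyne; whether it actually buys us anything is what we'd need to measure):

    import torch

    # torch.compile wraps an nn.Module (or a plain function) and optimizes
    # it lazily; the wrapped object is a drop-in replacement.
    model = torch.nn.Linear(8, 8)
    compiled_model = torch.compile(model)

    x = torch.randn(2, 8)
    y = compiled_model(x)  # the first call triggers compilation
    # (As noted above, this step has been seen to error on some Mac CPU setups.)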

kylebgorman commented 2 months ago

Warnings I'm seeing with the transformer and Torch 2.2:

/home/kbg/.miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
/home/kbg/.miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/functional.py:5109: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.

Both seem harmless. I am not sure what the fix is for the former; for the latter, presumably you just make sure both masks are created with the same dtype?
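
To make the second warning concrete, a minimal sketch (toy shapes, not yoyodyne's actual call) of the kind of dtype mismatch it complains about:

    import torch

    # The stock mask helper returns a float mask (-inf above the diagonal),
    # while the padding mask below is boolean, so the two masks disagree in
    # dtype.
    attn_mask = torch.nn.Transformer.generate_square_subsequent_mask(4)
    key_padding_mask = torch.zeros(2, 4, dtype=torch.bool)
    mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    x = torch.randn(2, 4, 8)
    # On torch 2.x this combination emits the "mismatched key_padding_mask
    # and attn_mask" UserWarning.
    out, _ = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)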

Adamits commented 2 months ago

Warnings I'm seeing with the transformer and Torch 2.2:

/home/kbg/.miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
/home/kbg/.miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/functional.py:5109: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.

Both seem harmless. I am not sure what the fix is for the former; for the latter, presumably you just make sure both masks are created with the same dtype?

Yeah, I saw these too. For the mismatched type, I was struggling to figure out which types need to change and did not solve it on my end, though I only tried a couple of things.

kylebgorman commented 2 months ago

Yeah, I saw these too. For the mismatched type, I was struggling to figure out which types need to change and did not solve it on my end, though I only tried a couple of things.

We could ask the devs, or we could silence it like we did for other warnings.
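
If we go the silencing route, it would presumably be the standard warnings filter, with the message prefix assumed here:

    import warnings

    # Assumed message prefix; filterwarnings matches it as a regex against
    # the start of the warning text.
    warnings.filterwarnings(
        "ignore",
        message="Support for mismatched key_padding_mask and attn_mask",
        category=UserWarning,
    )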

Adamits commented 2 months ago

We could ask the devs, or we could silence it like we did for other warnings.

Ok. Just realized this is a draft. Shall we split up any pending subtasks?

Which architectures still need to be tested? (I have already tested transformer and attentive_lstm.) I can also work on the boolean warning. Then maybe you could double-check the is_causal flag behavior?

I think torch.compile is outside the scope of what we want here for now.

kylebgorman commented 2 months ago

Ok. Just realized this is a draft. Shall we split up any pending subtasks?

Which architectures still need to be tested? (I have already tested transformer and attentive_lstm.) I can also work on the boolean warning. Then maybe you could double-check the is_causal flag behavior?

Yeah, why don't you check on the warnings?

I'll be done testing before end of day.

I think torch.compile is outside the scope of what we want here for now.

+1.

Adamits commented 2 months ago

I'll be done testing before end of day.

Nice, I will try to keep up with you then. I need to finish up making my lecture and teach today :P. But should have time in the afternoon to look into the warning more deeply.

kylebgorman commented 2 months ago

Hard error on the transducer: RuntimeError: could not create a primitive descriptor for an LSTM forward propagation primitive. No idea what this means. (Will check to see if the same error is present at Yoyodyne HEAD.)

Adamits commented 2 months ago

Note: is_causal may have been replaced in more recent torch releases? https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html

Adamits commented 2 months ago

I don't think I can fork this, but I had just confused myself on the "mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead" warning.

We can solve it in

by changing generate_square_subsequent_mask to

    @staticmethod
    def generate_square_subsequent_mask(length: int) -> torch.Tensor:
        """Generates the target mask so the model cannot see future states.

        Args:
            length (int): length of the sequence.

        Returns:
            torch.Tensor: mask of shape length x length.
        """
        return torch.triu(
            torch.ones((length, length), dtype=torch.bool), diagonal=1
        )
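
As a quick self-contained check (toy sizes, not yoyodyne's actual module), the boolean causal mask can be combined with a boolean key_padding_mask without tripping the dtype warning:

    import torch

    # Toy sizes for illustration only.
    length, batch, d_model = 4, 2, 8
    causal_mask = torch.triu(
        torch.ones((length, length), dtype=torch.bool), diagonal=1
    )
    key_padding_mask = torch.zeros(batch, length, dtype=torch.bool)
    mha = torch.nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
    x = torch.randn(batch, length, d_model)
    # Both masks are boolean, so no mismatched-type warning is emitted.
    out, _ = mha(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)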

kylebgorman commented 2 months ago

Hard error on the transducer: RuntimeError: could not create a primitive descriptor for an LSTM forward propagation primitive. No idea what this means. (Will check to see if the same error is present at Yoyodyne HEAD.)

So, the breakage of the transducer can be isolated to the Torch migration. (It still works at HEAD.) But the error we get is basically useless, so I'm not sure what to do with it. That seems like a major blocker.

kylebgorman commented 2 months ago

We can solve it in

by changing generate_square_subsequent_mask to

This works, thanks.

kylebgorman commented 2 months ago

So, the breakage of the transducer can be isolated to the Torch migration. (It still works at HEAD.) But the error we get is basically useless, so I'm not sure what to do with it. That seems like a major blocker.

One more update on this: it only shows up on CPU with certain combinations of precision, so I think it's just a bad error message saying that it doesn't support that combination of accelerators and precisions.
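
For the record, the failure mode can be sketched outside yoyodyne entirely (hypothetical minimal example, not the original command; whether it actually errors depends on the torch/oneDNN build):

    import torch

    # An LSTM forward pass on CPU in reduced precision. On some builds this
    # raises RuntimeError: could not create a primitive descriptor for an
    # LSTM forward propagation primitive.
    lstm = torch.nn.LSTM(input_size=8, hidden_size=8).to(torch.float16)
    x = torch.randn(5, 1, 8, dtype=torch.float16)  # (seq_len, batch, input_size)
    output, (h, c) = lstm(x)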

kylebgorman commented 2 months ago

All my local tests pass so I am converting to a true PR. @Adamits lmk what you think when you get a chance.

Adamits commented 2 months ago

One more update on this: it only shows up on CPU with certain combinations of precision, so I think it's just a bad error message saying that it doesn't support that combination of accelerators and precisions.

Out of curiosity could you share your script? I ran transducer on my mac with torch 2.2.0 and it ran fine.

kylebgorman commented 2 months ago

Out of curiosity could you share your script? I ran transducer on my mac with torch 2.2.0 and it ran fine.

To trigger this, try:

I don't think any of the other options matter re: this.