gwinndr / MusicTransformer-Pytorch

MusicTransformer written for MaestroV2 using the PyTorch framework for music generation
MIT License
236 stars, 48 forks

forward() got an unexpected keyword argument 'is_causal' #21

Open ShunyaOz opened 6 months ago

ShunyaOz commented 6 months ago

Hello, I'm a Japanese college student, and this is my first time using MusicTransformer-Pytorch. I followed the README and the Google Colab version and tried to train the model.

When I ran `python3 train.py -output_dir rpr --rpr -batch_size=4 -epochs=150 -max_sequence=2048`, the terminal printed `TypeError: forward() got an unexpected keyword argument 'is_causal'`.

Please tell me how to train this model.

SkylarShadow commented 6 months ago

Hello, I encountered the same issue when I ran the code. I found two workarounds:

1. Install these pinned package versions (I haven't tried this one yet):
   `!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchtext==0.14.1 torchaudio==0.13.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu117`
2. Modify `forward()` in `rpr.py` and `music_transformer.py` by adding `**kwargs` to its arguments:
   - https://github.com/gwinndr/MusicTransformer-Pytorch/blob/7161590fe306578ad02d851e74e1141bd39004af/model/rpr.py#L34
     Change it to: `def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):`
   - https://github.com/gwinndr/MusicTransformer-Pytorch/blob/7161590fe306578ad02d851e74e1141bd39004af/model/music_transformer.py#L191
     Change it to: `def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, **kwargs):`

The second method worked for me on Colab. BTW, please forgive my poor English😂
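Why adding `**kwargs` helps can be sketched in plain Python, without PyTorch. This is a minimal illustration under stated assumptions, not the repository's code: `EncoderBefore`, `EncoderAfter`, `DummyDecoderAfter`, `call_like_newer_pytorch`, and `try_run` are hypothetical stand-ins. `is_causal` comes from the error message, while the decoder keywords `tgt_is_causal` and `memory_is_causal` are assumptions about what newer PyTorch versions pass to custom encoder/decoder modules.

```python
# Plain-Python sketch of why adding **kwargs to forward() avoids
# "forward() got an unexpected keyword argument 'is_causal'".
# All class and function names here are hypothetical stand-ins.


class EncoderBefore:
    """Mimics a custom encoder whose forward() predates the is_causal keyword."""

    def forward(self, src, mask=None, src_key_padding_mask=None):
        return src


class EncoderAfter:
    """Same encoder with **kwargs added, as in the workaround."""

    def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
        # Unknown keywords such as is_causal are collected in kwargs and ignored.
        return src


class DummyDecoderAfter:
    """Mimics a pass-through decoder that returns the encoder output unchanged."""

    def forward(self, tgt, memory, tgt_mask, memory_mask,
                tgt_key_padding_mask, memory_key_padding_mask, **kwargs):
        # Assumed extra keywords (tgt_is_causal, memory_is_causal) end up in kwargs.
        return memory


def call_like_newer_pytorch(encoder, decoder, src, mask):
    """Hypothetical caller that forwards extra *is_causal keywords."""
    memory = encoder.forward(src, mask=mask, src_key_padding_mask=None,
                             is_causal=False)
    return decoder.forward(src, memory, tgt_mask=None, memory_mask=None,
                           tgt_key_padding_mask=None,
                           memory_key_padding_mask=None,
                           tgt_is_causal=False, memory_is_causal=False)


def try_run(encoder, decoder):
    """Return the output, or the TypeError message if the call fails."""
    try:
        return call_like_newer_pytorch(encoder, decoder, [1, 2, 3], mask=None)
    except TypeError as err:
        return f"TypeError: {err}"


if __name__ == "__main__":
    print(try_run(EncoderBefore(), DummyDecoderAfter()))  # a TypeError about 'is_causal'
    print(try_run(EncoderAfter(), DummyDecoderAfter()))   # [1, 2, 3]
```

Because `**kwargs` silently discards the extra flags, the patched modules rely entirely on the explicit mask argument for causal masking. Accepting `is_causal=None` explicitly would be a narrower alternative that still rejects misspelled keywords.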