HazyResearch / flash-fft-conv

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Apache License 2.0
267 stars · 27 forks

Training Hyena-based Models with FlashFFTConv + Safari #16

Open guyjacob opened 8 months ago

guyjacob commented 8 months ago

I saw in #9 that it should be possible to run training with FlashFFTConv. I integrated the library into the Safari codebase (based on the Hyena example in this repo - I followed the README there and also diff-ed the code). Trying to run The Pile experiment, I'm seeing some issues:

  • With sequence length 4096 I get the following error:

    RuntimeError: Function FlashFFTConvFuncBackward returned an invalid gradient at index 1 - got [864, 4096] but expected shape compatible with [1, 864, 4096]

    The same thing happens with sequence length 2048 (with shape [864, 2048], of course).

  • With sequence length 1024 I get:

    File "/work/venvs/safari/lib/python3.10/site-packages/flashfftconv-0.0.0-py3.10.egg/flashfftconv/conv.py", line 608, in forward
      return monarch_conv_forward_r2r(
    RuntimeError: k_f must have shape (H, fftsize + 1, 2)

    Note that this error also occurs if I run benchmark_fwd.py in the Hyena folder of this repo with sequence length 1024, so at least this one seems unrelated to any mistakes I might have made while integrating the code.

Should I expect this combo of Safari + Hyena + FlashFFTConv to work for training? If so, any ideas how to address the errors above?

Thank you!

DanFu09 commented 8 months ago

You should squeeze the kernel once before you pass it in, so the shape is (H, L). Then it should work! (All Hyena experiments in the paper were run on a private fork of Safari.)
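For anyone hitting the same shape error, here is a minimal sketch of the squeeze fix, assuming the Hyena filter is materialized with a leading singleton batch dimension (the H=864, L=4096 shapes are taken from the error message above):

```python
import torch

# Shapes taken from the error message above: H channels, sequence length L.
H, L = 864, 4096

# In the Safari/Hyena code path the filter comes out with a leading
# singleton dimension, e.g. (1, H, L), which is what produces the
# "expected shape compatible with [1, 864, 4096]" gradient error.
k = torch.randn(1, H, L)

# Squeeze once before passing the kernel to FlashFFTConv, so it is (H, L).
k = k.squeeze(0)
assert k.shape == (H, L)
```

This mirrors the suggestion above; the exact place to add the squeeze depends on where the filter is materialized in your copy of Safari.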

guyjacob commented 8 months ago

Thanks for the quick reply.

Funny, I actually tried this suggestion before opening the issue, and indeed it bypasses the errors mentioned above, but then other issues came up. Since I wasn't sure it was a valid fix to begin with, I didn't want to divert the discussion unnecessarily.

So after adding the squeeze op, training only runs if I pass find_unused_parameters=True to DDPStrategy. Otherwise, once the first global batch finishes, it crashes with an error saying that some parameters did not receive a gradient - specifically, these appear to be all the bias parameters (and only those):

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 1: _forward_module.model.backbone.layers.17.mixer.filter_fn.bias, _forward_module.model.backbone.layers.17.mixer.in_proj.bias, _forward_module.model.backbone.layers.16.mixer.filter_fn.bias, _forward_module.model.backbone.layers.16.mixer.in_proj.bias, _forward_module.model.backbone.layers.15.mixer.filter_fn.bias, _forward_module.model.backbone.layers.15.mixer.in_proj.bias, _forward_module.model.backbone.layers.14.mixer.filter_fn.bias,
...

(I pasted only part of the error message for brevity.) Is this expected?
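As a side note on diagnosing this: below is a minimal, hypothetical sketch (not the Safari code) of how a parameter that never participates in the forward pass ends up with no gradient, which is exactly what trips DDP's reducer. The same loop can be run on the real model after a single backward pass to list the offending parameters:

```python
import torch
import torch.nn as nn

# Hypothetical module for illustration: a bias that is registered but
# never used in forward() will receive no gradient after backward().
class Mixer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.bias = nn.Parameter(torch.zeros(8))  # defined but unused below

    def forward(self, x):
        return self.proj(x)  # note: self.bias never participates

model = Mixer()
model(torch.randn(2, 8)).sum().backward()

# List parameters that received no gradient in this iteration.
unused = [name for name, p in model.named_parameters() if p.grad is None]
print(unused)  # -> ['bias']
```

If the fused FlashFFTConv path really does bypass those bias parameters, then find_unused_parameters=True is the standard workaround (at the cost of an extra graph traversal per iteration); alternatively, the unused biases could be disabled in the model config.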

In addition, when running specifically with sequence length 4096, the following error shows up repeatedly (it doesn't crash because of this; it just keeps printing the error over and over):

  ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 819 of file /home/guyj/work/flash-fft-conv/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
  CUDA Runtime Error at: /home/guyj/work/flash-fft-conv/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:909
  invalid argument

Does any of this make sense?
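For what it's worth, cudaFuncSetAttribute returning "invalid argument" when requesting 102400 bytes (100 KB) of dynamic shared memory is usually a sign that the GPU's per-block opt-in limit is smaller than the request. A rough sketch of the limits follows (values taken from the CUDA C Programming Guide's compute-capability table; treat them as an assumption to verify against your GPU):

```python
# Approximate per-block opt-in dynamic shared memory limits in bytes,
# keyed by compute capability (major, minor). Values from the CUDA C
# Programming Guide; check your device before relying on them.
MAX_OPTIN_SMEM = {
    (7, 0): 96 * 1024,   # V100
    (7, 5): 64 * 1024,   # T4, RTX 20xx
    (8, 0): 163 * 1024,  # A100
    (8, 6): 99 * 1024,   # RTX 30xx, A40
    (9, 0): 227 * 1024,  # H100
}

def can_request(cc, requested_bytes=102400):
    """True if a kernel asking for `requested_bytes` of dynamic shared
    memory can run on a device with compute capability `cc`."""
    return MAX_OPTIN_SMEM.get(cc, 0) >= requested_bytes

# The bwd kernel above asks for 100 KB, which fits on A100/H100 but not
# on consumer Ampere (99 KB) or Turing (64 KB) parts.
print(can_request((8, 0)), can_request((8, 6)))  # -> True False
```

torch.cuda.get_device_properties(0) reports the major/minor compute capability, which can then be looked up in a table like this one.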