HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

What is fftconv_bwd doing? #13

Closed darius-lam closed 1 year ago

darius-lam commented 1 year ago

Great work! In this snippet:

    out = fftconv_ref(u, k, D, dropout_mask)
    out = fftconv_fast(u, k, D, dropout_mask)
    g = torch.randn_like(out)
    fftconv_fast_bwd(g, u, k, D, dropout_mask)

What is 'g'? Does fftconv_fast_bwd compute something like an MSE loss between g and the output? Also, fftconv_func doesn't seem to be in src/ops/. Is this intentional?

DanFu09 commented 1 year ago

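Thank you for your question! Here g is the gradient of the loss with respect to the output, equivalent to grad_out in this example. No loss is computed inside the backward call; g is simply the incoming gradient, and a random g is enough to exercise the backward pass.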
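
To make the role of g concrete, here is a minimal plain-Python sketch, not the repository's CUDA kernel. It assumes a single-channel causal convolution with a scalar skip weight D and no dropout mask; the names conv_fwd, conv_bwd and scalar_loss are hypothetical. The backward function only maps an incoming gradient g to gradients for u, k and D, which is the same as differentiating the surrogate scalar sum(g * y).

```python
import random


def conv_fwd(u, k, D):
    """Causal convolution plus skip term: y[t] = sum_{s<=t} k[s]*u[t-s] + D*u[t]."""
    L = len(u)
    return [sum(k[s] * u[t - s] for s in range(t + 1)) + D * u[t] for t in range(L)]


def conv_bwd(g, u, k, D):
    """Given g = dLoss/dy, return (dLoss/du, dLoss/dk, dLoss/dD)."""
    L = len(u)
    du = [sum(g[t] * k[t - j] for t in range(j, L)) + D * g[j] for j in range(L)]
    dk = [sum(g[t] * u[t - s] for t in range(s, L)) for s in range(L)]
    dD = sum(g[t] * u[t] for t in range(L))
    return du, dk, dD


def scalar_loss(g, u, k, D):
    """Surrogate loss sum(g * y); its gradients are what conv_bwd returns."""
    return sum(gt * yt for gt, yt in zip(g, conv_fwd(u, k, D)))


random.seed(0)
L = 6
u = [random.gauss(0, 1) for _ in range(L)]
k = [random.gauss(0, 1) for _ in range(L)]
D = 0.5
g = [random.gauss(0, 1) for _ in range(L)]  # arbitrary upstream gradient

du, dk, dD = conv_bwd(g, u, k, D)

# Spot check one entry of du against a central finite difference.
eps = 1e-6
u_plus, u_minus = u[:], u[:]
u_plus[2] += eps
u_minus[2] -= eps
fd = (scalar_loss(g, u_plus, k, D) - scalar_loss(g, u_minus, k, D)) / (2 * eps)
print("analytic:", du[2], "finite difference:", fd)
```

Because g is arbitrary, drawing it from a random distribution is a common way to test or benchmark a backward pass without defining a real loss.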

The FFTConv op can be found here: https://github.com/HazyResearch/H3/blob/main/src/ops/fftconv.py

It simply wraps the CUDA code and stores the right values for the backward pass.
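
For readers unfamiliar with the wrapping pattern, the following is a rough PyTorch sketch of the idea, not the code in src/ops/fftconv.py. The class name FFTConvSketch is hypothetical, the FFT math is done with torch.fft instead of the CUDA kernels, the dropout mask is left out, and the shapes (u: batch x channels x length, k: channels x length, D: channels) are assumptions.

```python
import torch


class FFTConvSketch(torch.autograd.Function):
    """Hypothetical stand-in for an FFT convolution op with a hand-written backward."""

    @staticmethod
    def forward(ctx, u, k, D):
        # Assumed shapes: u (B, H, L), k (H, L), D (H,).
        L = u.shape[-1]
        n = 2 * L  # zero-pad so the circular product acts as a causal convolution
        y = torch.fft.irfft(torch.fft.rfft(u, n=n) * torch.fft.rfft(k, n=n), n=n)[..., :L]
        # Store what the backward pass needs.
        ctx.save_for_backward(u, k, D)
        return y + u * D[:, None]

    @staticmethod
    def backward(ctx, g):
        # g is the gradient with respect to the forward output.
        u, k, D = ctx.saved_tensors
        L = u.shape[-1]
        n = 2 * L
        g_f = torch.fft.rfft(g, n=n)
        # Conjugation in the frequency domain gives the correlations needed here.
        du = torch.fft.irfft(g_f * torch.fft.rfft(k, n=n).conj(), n=n)[..., :L]
        du = du + g * D[:, None]
        dk = torch.fft.irfft(g_f * torch.fft.rfft(u, n=n).conj(), n=n)[..., :L].sum(dim=0)
        dD = (g * u).sum(dim=(0, 2))
        return du, dk, dD


torch.manual_seed(0)
u = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
k = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
D = torch.randn(3, dtype=torch.float64, requires_grad=True)

out = FFTConvSketch.apply(u, k, D)
g = torch.randn_like(out)  # random upstream gradient, as in the snippet from the question
out.backward(g)            # fills u.grad, k.grad and D.grad via the backward above
```

In this sketch, forward saves its inputs with ctx.save_for_backward and backward turns the incoming g into one gradient per input, which is the general shape of a torch.autograd.Function wrapper around custom kernels.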

darius-lam commented 1 year ago

thanks!