darius-lam closed this issue 1 year ago
Thank you for your question! Here g represents the gradient with respect to the output -- equivalent to grad_out in this example.
The FFTConv op can be found here: https://github.com/HazyResearch/H3/blob/main/src/ops/fftconv.py
It simply wraps the CUDA kernel and saves the tensors needed for the backward pass.
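For anyone reading along, here is a minimal sketch of that pattern: a custom torch.autograd.Function whose forward saves what the backward needs, and whose backward receives g, the gradient of the loss with respect to the output. This is plain PyTorch with an FFT-based convolution standing in for the fused CUDA kernel -- FFTConvSketch and its internals are illustrative names, not the actual H3 bindings.

```python
import torch

class FFTConvSketch(torch.autograd.Function):
    """Sketch of wrapping a (hypothetical) fused FFT-conv kernel.
    Assumes u and k have the same shape, e.g. both (L,)."""

    @staticmethod
    def forward(ctx, u, k):
        L = u.shape[-1]
        # Zero-pad to 2L so the circular convolution equals a linear one.
        u_f = torch.fft.rfft(u, n=2 * L)
        k_f = torch.fft.rfft(k, n=2 * L)
        y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]
        # "Stores the right values for the backward pass."
        ctx.save_for_backward(u, k)
        return y

    @staticmethod
    def backward(ctx, g):
        # g is the gradient w.r.t. the output (grad_out). No loss is
        # computed here -- autograd supplies g from whatever loss the
        # caller used upstream.
        u, k = ctx.saved_tensors
        L = u.shape[-1]
        g_f = torch.fft.rfft(g, n=2 * L)
        u_f = torch.fft.rfft(u, n=2 * L)
        k_f = torch.fft.rfft(k, n=2 * L)
        # Adjoint of convolution is correlation: multiply by the conjugate.
        du = torch.fft.irfft(g_f * k_f.conj(), n=2 * L)[..., :L]
        dk = torch.fft.irfft(g_f * u_f.conj(), n=2 * L)[..., :L]
        return du, dk

# Usage: autograd computes g = d(loss)/dy and passes it into backward.
u = torch.randn(64, requires_grad=True)
k = torch.randn(64, requires_grad=True)
y = FFTConvSketch.apply(u, k)
loss = ((y - torch.ones_like(y)) ** 2).mean()  # any loss, e.g. MSE
loss.backward()  # u.grad and k.grad now hold du and dk from above
```

Note that backward does not perform anything like an MSE loss itself: g comes from whichever loss the caller chose, and the op just propagates it through the convolution.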
Thanks!
Great work! In this snippet: what is 'g'? Does fast_bwd perform something like an MSE loss over g and the output? Also, fftconv_func doesn't seem to be within src/ops/ -- is this intentional?