HazyResearch / H3

Language Modeling with the H3 State Space Model

inconsistent output from fftconv_func and native pytorch fft #15

Open mojanjp opened 1 year ago

mojanjp commented 1 year ago

Hello, I have noticed that the output of an H3 layer, given the same input tensor, differs when `use_fast_fftconv=True` versus when `use_fast_fftconv=False`. Similarly, I have tested the `fftconv_func` and `fftconv_ref` functions (in `fftconv.py`), and they give different outputs when given the same input arguments. Is this behavior expected? Is `fftconv_func` performing some approximation that causes this relatively large Euclidean distance between the two outputs? Thanks in advance for your help.
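For context, here is a minimal sketch of the computation I understand both code paths to implement: a causal convolution of the input with the SSM kernel via FFT, plus a skip term (the fused kernel also handles pointwise ops like GELU, which this ignores). The `(batch, d_model, seqlen)` layout here is my assumption, not necessarily the repo's:

```python
import torch

def fftconv_simplified(u, k, D):
    """Causal convolution y = (k * u)[:seqlen] + D ⊙ u via FFT.

    u: (batch, d_model, seqlen) input; k: (d_model, seqlen) SSM kernel;
    D: (d_model,) skip connection. Zero-padding the FFT to 2*seqlen turns
    the circular convolution into a linear (causal) one.
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k.float(), n=fft_size)
    u_f = torch.fft.rfft(u.float(), n=fft_size)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]
    return y + u * D.unsqueeze(-1)
```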

DanFu09 commented 1 year ago

Hi, thanks for your interest!

Can you provide more details about the differences in output you're seeing? They may be slight numerical errors due to fp32/fp16/bf16 precision.
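For intuition, here is a quick illustration (not code from this repo) of how large the error from a single fp32 → fp16 → fp32 round trip is on a tensor of the size you'd feed the layer:

```python
import torch

torch.manual_seed(0)
x = torch.randn(64, 1024, 128)
# Round-trip through half precision and measure the relative error.
rel_err = ((x - x.half().float()).norm() / x.norm()).item()
print(rel_err)  # roughly 3e-4 -- the same order as typical kernel mismatches
```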

We have a test in tests/ops that checks for numerical errors - you can run it via `pytest tests/ops/test_fftconv.py`. Can you run that and see if all the tests pass?

mojanjp commented 1 year ago

Thanks so much for your reply. After running `tests/ops/test_fftconv.py`, I got 3488 passed, 672 skipped. Here is the code snippet I use to measure the difference in outputs:

```python
import torch
# The H3 layer lives at src/models/ssm/h3.py in this repo; adjust the
# import to match how your checkout is laid out.
from src.models.ssm.h3 import H3

torch.manual_seed(0)
u = torch.randn(64, 1024, 128, device='cuda', dtype=torch.float32, requires_grad=False)
ssm_cfg = dict(mode='diag', measure='diag-lin')
h3_layer = H3(d_model=128, layer_idx=0, use_fast_fftconv=True, **ssm_cfg).to(device='cuda')
h3_layer.eval()
o1 = h3_layer(u)
h3_layer.use_fast_fftconv = False
o2 = h3_layer(u)
error = torch.sum((o1 - o2) ** 2).sqrt()
```

Using the above code, the error is ~1e-4. Would that be within the acceptable range for the numerical errors you mentioned?
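For reference, the same difference expressed as a relative error, which is easier to interpret than an absolute distance (this reuses o1 and o2 from the snippet above):

```python
# Normalize by the magnitude of the reference output so the number
# doesn't depend on the tensor's size or scale.
rel_error = (o1 - o2).norm() / o2.norm()
```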

DanFu09 commented 1 year ago

Yes, that is within the range of standard numerical error that can come from two slightly different but mathematically equivalent implementations of the same operation, or from converting between fp32, fp16, and bf16.
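If you want a programmatic check, something like the following works (the tolerances here are illustrative; the values used in tests/ops/test_fftconv.py may differ), again reusing o1 and o2 from the snippet above:

```python
import torch

# rtol/atol around 1e-3 are typical when fp16/bf16 code paths are involved;
# treat these as illustrative defaults, not the repo's official thresholds.
assert torch.allclose(o1, o2, rtol=1e-3, atol=1e-3), \
    "outputs differ by more than expected numerical error"
```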