state-spaces / mamba

Mamba SSM architecture

Significant differences in gradients between `_ref` and `_fn` when using the complex formulation. #571

Open karannb opened 1 week ago

karannb commented 1 week ago

First off, the code is really well written and I was able to switch to the complex case very easily, so thanks! I am using complex dynamics for an application and was seeing large differences between the gradients computed by `mamba_inner_ref` and `mamba_inner_fn`. The scan functions worked fine and agreed much more closely in my test (although even for that case I had to lower my tolerance to 1e-6 from the 1e-8 I use in the real case). I am attaching a reproducible sample below that compares `mamba_inner_fn` against `mamba_inner_ref`. I assume this happens because PyTorch's complex-number support is still under development, but I would appreciate any guidance on how to resolve it.
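For reference, the check I ran on the raw scan functions was along the lines of the sketch below. This is a simplified sketch rather than my exact script: it uses constant complex `B`/`C` instead of the input-dependent `B`/`C` that the Mamba block produces, and it only compares the gradient of `u`.

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

device = "cuda"  # the fused kernel needs a CUDA device
batch, d_inner, d_state, seqlen = 4, 10, 3, 7

u = torch.randn(batch, d_inner, seqlen, device=device, requires_grad=True)
delta = torch.rand(batch, d_inner, seqlen, device=device, requires_grad=True)
# Complex A with negative real part for stability
A = torch.complex(-torch.exp(torch.randn(d_inner, d_state, device=device)),
                  torch.randn(d_inner, d_state, device=device)).requires_grad_()
B = torch.randn(d_inner, d_state, dtype=torch.cfloat, device=device, requires_grad=True)
C = torch.randn(d_inner, d_state, dtype=torch.cfloat, device=device, requires_grad=True)
D = torch.randn(d_inner, device=device, requires_grad=True)

out_fn = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True)
out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
print("max output diff:", (out_fn - out_ref).abs().max().item())

out_fn.sum().backward()
grad_u_fn = u.grad.clone()
u.grad = None
out_ref.sum().backward()
print("u grads close (atol=1e-6):", torch.allclose(grad_u_fn, u.grad, atol=1e-6))
```

The full reproducible sample for `mamba_inner_fn` / `mamba_inner_ref` is below.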

```python
'''
Check mamba_inner_fn against mamba_inner_ref (complex formulation).
'''
import math
import torch
from torch import nn
from tqdm import tqdm
from einops import repeat
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref

# Define a random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Set device to CUDA if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def test_gradient_implementation(device=device):
    # Create random input tensors and parameters.
    # NB: with the names used here, `dstate` plays the role of the channel dimension
    # (d_inner) and `dim` plays the role of the SSM state size (d_state); the shapes
    # below follow the (d_inner, d_state) convention expected by mamba_inner_fn.
    batch_size = 4
    dstate = 10
    dim = 3
    seqlen = 7

    xz = torch.randn(batch_size, dstate*2, seqlen, device=device, requires_grad=True)
    conv1d_weight = torch.randn(dstate, 1, 4, device=device, requires_grad=True)
    conv1d_bias = torch.randn(dstate, device=device, requires_grad=True)
    # For complex A, x_proj produces dt (rank 1 here) plus the real and imaginary
    # parts of B and C, hence 1 + 4*dim output rows
    x_proj_weight = torch.randn(dim*4 + 1, dstate, device=device, requires_grad=True)
    dt_proj_weight = torch.randn(dstate, 1, device=device, requires_grad=True)
    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = 1**-0.5
    nn.init.uniform_(dt_proj_weight, -dt_init_std, dt_init_std)
    dt_bias = torch.randn(dstate, device=device, requires_grad=True)
    # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
    dt = torch.exp(
        torch.rand(dstate) * (math.log(0.1) - math.log(0.001))
        + math.log(0.001)
    ).clamp(min=1e-4)
    # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_bias.copy_(inv_dt)
    out_proj_weight = torch.randn(int(dstate/2), dstate, device=device, requires_grad=True)
    out_proj_bias = None
    A_log = torch.log(repeat(0.5 - 1j*torch.arange(0, dim, dtype=torch.float32, device=device),
               "n -> d n",
               d=dstate,
               ).contiguous())
    A_log.requires_grad = True
    A = -torch.exp(A_log).to(torch.cfloat)
    D = torch.randn(dstate, device=device, requires_grad=True)

    # A is a non-leaf tensor (it is computed from A_log), so retain_grad() is needed
    # for A.grad to be populated; D is already a leaf, so retain_grad() is redundant there
    A.retain_grad()
    D.retain_grad()

    # Forward pass through mamba_inner_fn
    output_fn = mamba_inner_fn(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        dt_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        None,  # input-dependent B
        None,  # input-dependent C
        D,
        delta_bias=dt_bias,
        delta_softplus=True
    )

    # Forward pass through mamba_inner_ref
    output_ref = mamba_inner_ref(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        dt_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        None,  # input-dependent B
        None,  # input-dependent C
        D,
        delta_bias=dt_bias,
        delta_softplus=True
    )

    # Check if outputs are the same
    out_mismatch = False
    if not torch.allclose(output_fn, output_ref, atol=1e-6):
        print("Outputs do not match! Diff: ", torch.norm(output_fn - output_ref))
        out_mismatch = True

    # Create dummy targets
    target = torch.randn_like(output_fn)

    # Zero gradients
    def zero_gradients(*tensors):
        for tensor in tensors:
            if tensor is not None and tensor.grad is not None:
                tensor.grad.zero_()

    # Compute loss for mamba_inner_fn
    loss_fn = F.mse_loss(output_fn, target)

    # Backward pass through mamba_inner_fn
    zero_gradients(
        xz, conv1d_weight, conv1d_bias, x_proj_weight,
        dt_proj_weight, out_proj_weight, out_proj_bias, A, D
    )
    loss_fn.backward(retain_graph=True)
    grad_xz_fn = xz.grad.clone() if xz.grad is not None else None
    grad_conv1d_weight_fn = conv1d_weight.grad.clone() if conv1d_weight.grad is not None else None
    grad_conv1d_bias_fn = conv1d_bias.grad.clone() if conv1d_bias.grad is not None else None
    grad_x_proj_weight_fn = x_proj_weight.grad.clone() if x_proj_weight.grad is not None else None
    grad_dt_proj_weight_fn = dt_proj_weight.grad.clone() if dt_proj_weight.grad is not None else None
    grad_out_proj_weight_fn = out_proj_weight.grad.clone() if out_proj_weight.grad is not None else None
    grad_A_fn = A.grad.clone() if A.grad is not None else None
    grad_D_fn = D.grad.clone() if D.grad is not None else None

    # Compute loss for mamba_inner_ref
    loss_ref = F.mse_loss(output_ref, target)

    # Backward pass through mamba_inner_ref
    zero_gradients(
        xz, conv1d_weight, conv1d_bias, x_proj_weight,
        dt_proj_weight, out_proj_weight, out_proj_bias, A, D
    )
    loss_ref.backward(retain_graph=True)
    grad_xz_ref = xz.grad.clone() if xz.grad is not None else None
    grad_conv1d_weight_ref = conv1d_weight.grad.clone() if conv1d_weight.grad is not None else None
    grad_conv1d_bias_ref = conv1d_bias.grad.clone() if conv1d_bias.grad is not None else None
    grad_x_proj_weight_ref = x_proj_weight.grad.clone() if x_proj_weight.grad is not None else None
    grad_dt_proj_weight_ref = dt_proj_weight.grad.clone() if dt_proj_weight.grad is not None else None
    grad_out_proj_weight_ref = out_proj_weight.grad.clone() if out_proj_weight.grad is not None else None
    grad_A_ref = A.grad.clone() if A.grad is not None else None
    grad_D_ref = D.grad.clone() if D.grad is not None else None

    mismatch = False
    # Check if gradients are the same
    for grad_fn, grad_ref, name in zip(
        [grad_xz_fn, grad_conv1d_weight_fn, grad_conv1d_bias_fn, grad_x_proj_weight_fn,
         grad_dt_proj_weight_fn, grad_out_proj_weight_fn,
         grad_A_fn, grad_D_fn],
        [grad_xz_ref, grad_conv1d_weight_ref, grad_conv1d_bias_ref, grad_x_proj_weight_ref,
         grad_dt_proj_weight_ref, grad_out_proj_weight_ref,
         grad_A_ref, grad_D_ref],
        # out_proj_bias is None in this test, so it has no gradient entry; the name list
        # must be the same length as the gradient lists or zip() mislabels A and D
        ["xz", "conv1d_weight", "conv1d_bias", "x_proj_weight", "dt_proj_weight",
         "out_proj_weight", "A", "D"]
    ):
        if grad_fn is not None and grad_ref is not None:
            if not torch.allclose(grad_fn, grad_ref, atol=1e-5):
                mismatch = True
                print(f"Gradient mismatch for {name}! Diff: {torch.norm(grad_fn - grad_ref)}")
        elif grad_fn is None or grad_ref is None:
            print(f"Gradient does not exist for {name} at least in one of the functions.")

    return out_mismatch, mismatch

# Call the test function
out_correct = 0
grad_correct = 0
for _ in tqdm(range(1000)):
    out_mismatch, grad_mismatch = test_gradient_implementation(device)
    out_correct += 1 if not out_mismatch else 0
    grad_correct += 1 if not grad_mismatch else 0

print(f"Outputs match in {out_correct} out of 1000 runs.")
print(f"Gradients match in {grad_correct} out of 1000 runs.")

With this setup I get a gradient match in about 906 out of 1000 runs, and the outputs always match. However, unlike the real case, the gradient differences are quite large, especially for `xz` and `out_proj_weight`.
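One way to make "quite large" concrete is to look at the difference relative to the size of the reference gradient, e.g. with a small helper like this (a sketch; `grad_fn` / `grad_ref` are the cloned gradients collected in the comparison loop above):

```python
import torch

def relative_grad_error(grad_fn, grad_ref, eps=1e-12):
    # Norm of the difference, normalized by the norm of the reference gradient
    return (torch.norm(grad_fn - grad_ref) / (torch.norm(grad_ref) + eps)).item()

# e.g. inside the comparison loop:
# print(f"{name}: relative gradient error = {relative_grad_error(grad_fn, grad_ref):.3e}")
```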

I also noticed that the two functions only produce agreeing gradients when the inputs are in the range the functions expect, which is why I initialized the parameters as in the original code instead of using purely random samples (agreement is much lower in that case). I am writing my own CUDA kernels for an application and wanted to check how far off my implementation is compared to how far off the reference implementation already is.