kyutai-labs / moshi

Apache License 2.0
6.51k stars, 494 forks

Is the SEANetDecoder implementation actually causal? #134

Closed jerryli27 closed 2 weeks ago

jerryli27 commented 2 weeks ago

Due diligence

Topic

The PyTorch implementation

Question

Hi all,

First of all, thank you for your great work and for open sourcing the Mimi tokenizer. Causal, streamable tokenizers are rare; Encodec is a notable exception.

When I tried to use the tokenizer, it seemed that one of its components, namely the SEANetDecoder, is not causal.

I went ahead and wrote a few unit tests to convince myself and to nail down the culprit. See this commit in my fork. The following classes pass the causality tests: StreamingConv1d, StreamingConvTranspose1d, and SEANetResnetBlock. The test fails for SEANetDecoder, however. Put simply: if I feed it the input [1, 2, 3], the output positions corresponding to [1, 2] differ from the output I get when feeding it just [1, 2].

I find it mildly concerning. Is there a bug, or is it just numerical instability?
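The prefix-consistency property being tested can be sketched with a plain left-padded `Conv1d` (this is an illustration of the check, not the repo's `StreamingConv1d` implementation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

kernel_size = 3
weight = torch.randn(1, 1, kernel_size)

def causal_conv(x):
    # Pad (kernel_size - 1) zeros on the left, so output at time t
    # only depends on inputs at times <= t.
    return F.conv1d(F.pad(x, (kernel_size - 1, 0)), weight)

x = torch.randn(1, 1, 8)
full = causal_conv(x)            # decode the whole sequence
prefix = causal_conv(x[..., :5])  # decode only a prefix

# Causality: the first 5 outputs must match the full-sequence outputs.
torch.testing.assert_close(full[..., :5], prefix)
print("causal prefix check passed")
```

A module is causal exactly when this assertion holds for every prefix length; the unit tests below apply the same idea to the SEANet modules.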

Below are my unit test outputs:

moshi/moshi/modules/seanet_test.py::test_nonstreaming_causal_decode[Large SEANet-length 100] failed: num_timesteps = 100
seanet_kwargs = {'activation': 'ELU', 'causal': True, 'channels': 1, 'compress': 2, ...}

    @pytest.mark.parametrize("num_timesteps", NUM_TIMESTEPS_DATA)
    @pytest.mark.parametrize("seanet_kwargs", SEANET_KWARGS_DATA)
    def test_nonstreaming_causal_decode(num_timesteps, seanet_kwargs):
      """Test that the SEANetDecoder does not depend on future inputs."""

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      decoder = SEANetDecoder(**seanet_kwargs).to(device=device)

      generator = torch.Generator(device=device)
      generator = generator.manual_seed(41)
      decoder.apply(functools.partial(_init_weights, generator=generator))

      rand_generator = torch.Generator(device=device)
      rand_generator.manual_seed(2147483647)
      with torch.no_grad():
        codes = torch.randn(1, seanet_kwargs['dimension'], num_timesteps, generator=rand_generator, device=device)  # [B, K = 8, T]
        expected_decoded = decoder(codes)

        num_timesteps = codes.shape[-1]
        for t in range(num_timesteps):
          current_codes = codes[..., :t+1]
          actual_decoded = decoder(current_codes)
>         torch.testing.assert_close(expected_decoded[..., :actual_decoded.shape[-1]], actual_decoded,
                                     msg=lambda original_msg: f"Failed at t={t}: {original_msg}")
E         AssertionError: Failed at t=0: Tensor-likes are not close!
E         
E         Mismatched elements: 743 / 960 (77.4%)
E         Greatest absolute difference: 0.00010927580296993256 at index (0, 0, 933) (up to 1e-05 allowed)
E         Greatest relative difference: 1.206646203994751 at index (0, 0, 952) (up to 1.3e-06 allowed)

moshi/moshi/modules/seanet_test.py:186: AssertionError
moshi/moshi/modules/seanet_test.py::test_nonstreaming_causal_decode[Large SEANet-length 10] failed: num_timesteps = 10
seanet_kwargs = {'activation': 'ELU', 'causal': True, 'channels': 1, 'compress': 2, ...}

    @pytest.mark.parametrize("num_timesteps", NUM_TIMESTEPS_DATA)
    @pytest.mark.parametrize("seanet_kwargs", SEANET_KWARGS_DATA)
    def test_nonstreaming_causal_decode(num_timesteps, seanet_kwargs):
      """Test that the SEANetDecoder does not depend on future inputs."""

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      decoder = SEANetDecoder(**seanet_kwargs).to(device=device)

      generator = torch.Generator(device=device)
      generator = generator.manual_seed(41)
      decoder.apply(functools.partial(_init_weights, generator=generator))

      rand_generator = torch.Generator(device=device)
      rand_generator.manual_seed(2147483647)
      with torch.no_grad():
        codes = torch.randn(1, seanet_kwargs['dimension'], num_timesteps, generator=rand_generator, device=device)  # [B, K = 8, T]
        expected_decoded = decoder(codes)

        num_timesteps = codes.shape[-1]
        for t in range(num_timesteps):
          current_codes = codes[..., :t+1]
          actual_decoded = decoder(current_codes)
>         torch.testing.assert_close(expected_decoded[..., :actual_decoded.shape[-1]], actual_decoded,
                                     msg=lambda original_msg: f"Failed at t={t}: {original_msg}")
E         AssertionError: Failed at t=0: Tensor-likes are not close!
E         
E         Mismatched elements: 730 / 960 (76.0%)
E         Greatest absolute difference: 0.00013911724090576172 at index (0, 0, 212) (up to 1e-05 allowed)
E         Greatest relative difference: 0.7818552255630493 at index (0, 0, 921) (up to 1.3e-06 allowed)

moshi/moshi/modules/seanet_test.py:186: AssertionError
moshi/moshi/modules/seanet_test.py::test_nonstreaming_causal_decode[Large SEANet-length 2] failed: num_timesteps = 2
seanet_kwargs = {'activation': 'ELU', 'causal': True, 'channels': 1, 'compress': 2, ...}

    @pytest.mark.parametrize("num_timesteps", NUM_TIMESTEPS_DATA)
    @pytest.mark.parametrize("seanet_kwargs", SEANET_KWARGS_DATA)
    def test_nonstreaming_causal_decode(num_timesteps, seanet_kwargs):
      """Test that the SEANetDecoder does not depend on future inputs."""

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      decoder = SEANetDecoder(**seanet_kwargs).to(device=device)

      generator = torch.Generator(device=device)
      generator = generator.manual_seed(41)
      decoder.apply(functools.partial(_init_weights, generator=generator))

      rand_generator = torch.Generator(device=device)
      rand_generator.manual_seed(2147483647)
      with torch.no_grad():
        codes = torch.randn(1, seanet_kwargs['dimension'], num_timesteps, generator=rand_generator, device=device)  # [B, K = 8, T]
        expected_decoded = decoder(codes)

        num_timesteps = codes.shape[-1]
        for t in range(num_timesteps):
          current_codes = codes[..., :t+1]
          actual_decoded = decoder(current_codes)
>         torch.testing.assert_close(expected_decoded[..., :actual_decoded.shape[-1]], actual_decoded,
                                     msg=lambda original_msg: f"Failed at t={t}: {original_msg}")
E         AssertionError: Failed at t=0: Tensor-likes are not close!
E         
E         Mismatched elements: 689 / 960 (71.8%)
E         Greatest absolute difference: 0.00010664388537406921 at index (0, 0, 323) (up to 1e-05 allowed)
E         Greatest relative difference: 0.5122406482696533 at index (0, 0, 457) (up to 1.3e-06 allowed)

moshi/moshi/modules/seanet_test.py:186: AssertionError
LaurentMazare commented 2 weeks ago

Interesting. My guess is that these are numerical issues in the convolutions, which run through cuda/cudnn. I don't get the exact same error values as in your logs (probably because of cudnn benchmarking or some other source of non-determinism), but when disabling cudnn with the following line, your SEANet decode test seems to pass.

torch.backends.cudnn.enabled = False

Alternatively, you can try running the test on CPU; it seems to work properly there too.
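For reference, a small sketch of the relevant switches. The first line is the fix suggested above; the commented-out flags are general PyTorch options for taming cuDNN non-determinism, not something verified against this repo's tests:

```python
import torch

# Option 1 (the fix above): bypass cuDNN entirely for convolutions.
# Slower, but removes cuDNN algorithm selection as a source of variance.
torch.backends.cudnn.enabled = False

# Option 2 (alternative, untested here): keep cuDNN but disable
# benchmark-driven algorithm autotuning and request deterministic kernels.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

print("cudnn enabled:", torch.backends.cudnn.enabled)
```

Note that different cuDNN algorithms can legitimately produce outputs that differ at roughly the 1e-4 level seen in the logs, which is enough to trip `assert_close` at its default tolerances.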

jerryli27 commented 2 weeks ago

That worked. Thank you for the quick response! I didn't realize cudnn could make such a big difference.

Btw if you'd like to use the unit test, just let me know and I can submit a pull request.

LaurentMazare commented 2 weeks ago

Btw if you'd like to use the unit test, just let me know and I can submit a pull request.

Sure, your test commit looks pretty good, so I'd be happy to include it.