pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

numerical issue when running SDPA with DTensor #267

Open: tianyu-l opened this issue 7 months ago

tianyu-l commented 7 months ago

The issue comes from the backward computation of aten.mul on two complex-valued DTensors: the gradient comes out as b + ai when it should be a + bi, i.e. the real and imaginary parts are swapped. It is not clear why this happens. By the time the aten operations run, the DTensor inputs have already been unwrapped into plain local tensors, so DTensor itself should not be involved.
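For reference, the gradient the single-device path should produce can be worked out by hand. The sketch below is not PyTorch code: it models the repro's loss, sum(view_as_real(z * f)) for z = x + iy, with plain Python complex numbers. It assumes PyTorch's complex-autograd convention, where the gradient of a complex value under a real loss is packed as dL/dx + i*dL/dy and the backward of mul computes grad_z = grad_w * conj(f). It checks that against finite differences and shows what the reported swap (a + bi returned as b + ai) would look like. The helpers loss, analytic_grad, numeric_grad and swapped are hypothetical names used only for illustration.

```python
import random


def loss(x: float, y: float, f: complex) -> float:
    """Scalar loss from the repro: sum of view_as_real(z * f) with z = x + iy."""
    w = complex(x, y) * f
    return w.real + w.imag


def analytic_grad(f: complex) -> complex:
    """Expected gradient w.r.t. z, packed as dL/dx + i*dL/dy (assumed convention).

    L = Re(w) + Im(w), so the incoming gradient for w is 1 + 1j, and the
    backward of w = z * f multiplies it by conj(f).
    """
    grad_w = 1 + 1j
    return grad_w * f.conjugate()


def numeric_grad(x: float, y: float, f: complex, eps: float = 1e-6) -> complex:
    """Central finite differences of the loss, packed as dL/dx + i*dL/dy."""
    dx = (loss(x + eps, y, f) - loss(x - eps, y, f)) / (2 * eps)
    dy = (loss(x, y + eps, f) - loss(x, y - eps, f)) / (2 * eps)
    return complex(dx, dy)


def swapped(g: complex) -> complex:
    """The faulty pattern described in the issue: a + bi comes back as b + ai."""
    return complex(g.imag, g.real)


if __name__ == "__main__":
    random.seed(0)
    x, y = random.gauss(0, 1), random.gauss(0, 1)
    f = complex(random.gauss(0, 1), random.gauss(0, 1))
    g = analytic_grad(f)
    print("expected:", g, "finite-diff:", numeric_grad(x, y, f), "swapped:", swapped(g))
```

Writing f = c + di, the expected gradient for xq is (c + d, c - d); for example f = 2 + 1j gives 3 + 1j, while a swapped value such as 1 + 3j is the kind of result the issue describes on the DTensor path. Because the loss is linear in z, the finite-difference check agrees with the closed form up to rounding.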

To reproduce, add the following test method to pytorch/test/distributed/tensor/parallel/test_tp_examples.py:

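    @with_comms
    def test_apply_rotary_embedding(self):
        # CommDebugMode, distribute_tensor, Replicate and with_comms are assumed
        # to be imported already at the top of the test file.
        device_mesh = self.build_device_mesh()

        def apply_rotary_emb(xq, freqs_cis):
            # Treat the trailing dim of size 2 as (real, imag) pairs.
            xq_ = torch.view_as_complex(xq)
            # Complex multiply, then view the result as real with a trailing dim of 2.
            xq_out = torch.view_as_real(xq_ * freqs_cis)
            return xq_out

        with CommDebugMode():
            # Single-device baseline on plain tensors, kept for comparison:
            # xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            # freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            # xq_out = apply_rotary_emb(xq, freqs_cis)
            # xq_out.sum().backward()

            # Same computation on replicated DTensors; the backward of the
            # complex aten.mul is where the swapped gradient shows up.
            xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            xq_dt = distribute_tensor(xq, device_mesh, (Replicate(),))
            freqs_cis_dt = distribute_tensor(freqs_cis, device_mesh, (Replicate(),))
            xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
            xq_out_dt.sum().backward()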

tianyu-l commented 4 months ago

A solution is proposed in https://github.com/pytorch/pytorch/issues/130646