qazwsxal / diffusion-extensions

Reference implementation of Diffusion models on SO(3)

About sampling axis-angles from isotropic Gaussian distributions #3

Open dexin-wang opened 11 months ago

dexin-wang commented 11 months ago

Hello, thanks for providing such a wonderful method and code that helped me improve our project.

However, while testing the code you open-sourced, I found what may be a problem in the code that samples rotations, in axis-angle form, from an isotropic Gaussian distribution:

https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/distributions.py#L42-L43

In the code above, trap_start and trap_end are meant to be the CDF values just before and just after the point where the CDF equals unif. However, for every sample in the batch they are read from the CDF of the first variance, rather than from the CDF of that sample's own variance.

I trained the network on the task of generating $\pm 90$-degree rotations about the z-axis, tested it, and got the following result plot:

Figure_1.

The predicted rotation distribution differs greatly from the true distribution.

I changed this part of the code so that trap_start and trap_end are read, for each sample, from the CDF that corresponds to that sample's own variance. The core of the change is as follows:

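        # Batch size: one variance (eps) per sample
        B = self.eps.shape[0]
        # Batch positions 0..B-1, paired element-wise with idx_0 / idx_1
        bs = torch.arange(0, B)
        # Read each sample's CDF values from its own variance's CDF in self.trap
        trap_start = self.trap[idx_0, bs]
        trap_end = self.trap[idx_1, bs]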
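
To make the effect of the indexing concrete, here is a minimal, dependency-free sketch of inverse-transform sampling with one tabulated CDF per variance. It is not the repository's code: `build_cdf` and `invert` are hypothetical helpers, the density is a toy stand-in rather than the actual IGSO(3) density, and the "mismatched" variant only mimics the behaviour described above (bracketing indices found in a sample's own CDF, but CDF values read from the first variance's CDF).

```python
import bisect
import math


def build_cdf(sigma, n=200):
    """Tabulate a normalized trapezoid CDF on (0, pi] for a toy density.

    The density exp(-x^2 / (2 sigma^2)) * (1 - cos x) is an illustrative
    stand-in, NOT the IGSO(3) density used in the repository.
    """
    xs = [math.pi * i / (n - 1) for i in range(n)]
    pdf = [math.exp(-0.5 * (x / sigma) ** 2) * (1.0 - math.cos(x)) for x in xs]
    cdf, total = [], 0.0
    for i in range(1, n):
        total += 0.5 * (pdf[i - 1] + pdf[i]) * (xs[i] - xs[i - 1])
        cdf.append(total)
    return xs[1:], [c / total for c in cdf]


def invert(index_cdf, value_cdf, locs, u):
    """Linearly interpolate the inverse CDF at u.

    The bracketing indices come from index_cdf; the CDF values used for the
    interpolation weight come from value_cdf. Passing the same table for both
    is the consistent, per-variance lookup.
    """
    i1 = min(bisect.bisect_right(index_cdf, u), len(index_cdf) - 1)
    i0 = max(i1 - 1, 0)
    c0, c1 = value_cdf[i0], value_cdf[i1]
    if c1 == c0:
        return locs[i1]
    w = (u - c0) / (c1 - c0)
    return locs[i0] + w * (locs[i1] - locs[i0])


sigmas = [0.5, 1.0]                      # one variance per batch element
tables = [build_cdf(s) for s in sigmas]
locs = tables[0][0]                      # sample locations are shared
cdfs = [t[1] for t in tables]
u = 0.5                                  # same uniform draw for both samples

# Consistent lookup: each sample uses its own CDF for indices and values.
correct = [invert(cdfs[b], cdfs[b], locs, u) for b in range(len(sigmas))]
# Mismatched lookup: values always come from the first variance's CDF.
mismatched = [invert(cdfs[b], cdfs[0], locs, u) for b in range(len(sigmas))]

print("correct:   ", correct)
print("mismatched:", mismatched)
```

For the first batch element the two lookups agree, since its own CDF is the first CDF. For the second, the interpolation weight is computed from the wrong table, so the resulting angle falls outside the bracketing interval.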

In addition, I made two other changes to make the code runnable:

(1) Original code: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/so3_test.py#L31

Modified code:

R = process.p_sample(R, torch.full((BATCH,), i, device=device, dtype=torch.long))

(2) Original code: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/diffusion.py#L325

Modified code:

sample = IsotropicGaussianSO3(model_stdev).sample()

The result after modifying the code is shown below:

Figure_2

The predicted rotation distribution is now almost the same as the true distribution.

If I have misunderstood something, please point out my mistake so that I can better understand your method. Thank you!

xiexh20 commented 6 months ago

How many steps did you train the model for? I simply ran so3_train.py for 20k steps and then ran so3_test.py, and I already got converged results without the modifications you describe above: rotations_3axis