Hello, thanks for providing such a wonderful method and code that helped me improve our project.
However, while testing the code you open-sourced, I found what may be a problem in the code that samples rotations in axis-angle form from an isotropic Gaussian distribution (`distributions.py`, lines 42-43).
In that code, `trap_start` and `trap_end` are taken at the indexes just before and after the point where the CDF equals `unif`, but every lookup uses the CDF of the first variance rather than the CDF of the corresponding variance.
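To make the mismatch concrete, here is a toy illustration. It is only a sketch, not the repository's code: NumPy stands in for PyTorch, the table values are invented, and it assumes the CDF table stores one column per variance. When every sample reads column 0, the bracket for the second sample no longer contains its own `unif`.

```python
import numpy as np

# Hypothetical CDF table: rows are grid points, columns are variances.
# Column 0 rises quickly, column 1 rises slowly.
trap = np.array([
    [0.50, 0.10],
    [0.80, 0.20],
    [0.95, 0.40],
    [1.00, 1.00],
])
unif = np.array([0.6, 0.3])  # one uniform draw per variance

# Per column: index of the first grid point whose CDF value exceeds unif.
idx_1 = (trap <= unif[None, :]).sum(axis=0)
idx_0 = np.clip(idx_1 - 1, 0, None)

# Lookup that ignores the column: every sample reads the CDF of variance 0.
start = trap[idx_0, 0]
end = trap[idx_1, 0]

# A valid bracket satisfies start <= unif < end.
brackets = (start <= unif) & (unif < end)
print(brackets)  # [ True False]
```

The first sample is bracketed correctly only because its own CDF happens to be column 0.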
I trained the network and tested it on the task of generating $\pm90$-degree rotations about the z-axis, and got the following result plot:
[Image: result plot]
The predicted rotation distribution differs greatly from the true distribution.
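For concreteness, the two target rotations of this test can be written out from their axis-angle form. The snippet below is a small NumPy sketch using Rodrigues' formula. It is not taken from the repository, and the function name `axis_angle_to_matrix` is made up for illustration.

```python
import numpy as np

def axis_angle_to_matrix(axis, angle):
    """Rotation matrix for a rotation of `angle` radians about `axis`."""
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)
    x, y, z = axis
    # Skew-symmetric cross-product matrix of the unit axis.
    K = np.array([[0.0, -z, y],
                  [z, 0.0, -x],
                  [-y, x, 0.0]])
    # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

z_axis = [0.0, 0.0, 1.0]
# The two modes of the target distribution: +90 and -90 degrees about z.
targets = [axis_angle_to_matrix(z_axis, sign * np.pi / 2) for sign in (+1, -1)]
```

The two matrices are inverses of each other, so a sampler that recovers only one of them or smears mass between them is easy to spot in a plot of the rotation angle about z.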
I changed this part of the code so that `trap_start` and `trap_end` are taken at the indexes just before and after the point where the CDF of each sample's own variance equals `unif`. The core of the change is as follows:
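```python
B = self.eps.shape[0]    # batch size: one entry of eps per batch element
bs = torch.arange(0, B)  # column index b for each batch element
# Pair row idx_0[b] / idx_1[b] with column b, so each sample reads its own CDF
trap_start = self.trap[idx_0, bs]
trap_end = self.trap[idx_1, bs]
```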
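The paired-index lookup can be checked in isolation against a plain per-column loop. This is a hedged sketch rather than the repository's implementation: NumPy stands in for PyTorch (indexing with `[idx, bs]` follows the same rules in both), the helper name `bracket_per_sample` and the random CDF table are hypothetical, and the table is assumed to have shape `(grid points, batch)`.

```python
import numpy as np

def bracket_per_sample(trap, unif):
    """Return the CDF values bracketing unif[b] in column b of trap."""
    n_grid, batch = trap.shape
    # First grid index whose CDF value exceeds unif, computed per column.
    idx_1 = np.clip((trap <= unif[None, :]).sum(axis=0), 0, n_grid - 1)
    idx_0 = np.clip(idx_1 - 1, 0, None)
    bs = np.arange(batch)
    # Paired indexing: element b reads row idx[b] of column b.
    return trap[idx_0, bs], trap[idx_1, bs]

rng = np.random.default_rng(0)
n_grid, batch = 50, 4
# Hypothetical CDF table: each column is a normalised cumulative sum.
pdf = rng.random((n_grid, batch)) + 0.01
trap = np.cumsum(pdf, axis=0)
trap = trap / trap[-1]
unif = rng.random(batch)

trap_start, trap_end = bracket_per_sample(trap, unif)

# Reference: process each column on its own with a plain loop.
for b in range(batch):
    i1 = min(int(np.searchsorted(trap[:, b], unif[b], side="right")), n_grid - 1)
    i0 = max(i1 - 1, 0)
    assert trap_start[b] == trap[i0, b]
    assert trap_end[b] == trap[i1, b]
```

Because each column is searched independently in the reference loop, agreement with the vectorised result indicates that no sample is reading another sample's CDF.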
For how many steps did you train the model? I simply ran `so3_train.py` for 20k steps and then ran `so3_test.py`, and I already got converged results without the modification you describe above.
The sampling code in question: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/distributions.py#L42-L43
Besides the sampling fix, I also made several other changes to make the code runnable:
(1) Original code: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/so3_test.py#L31
modified code:
(2) Original code: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/diffusion.py#L325
modified code:
The result after modifying the code is shown below:
The predicted rotation distribution is now almost identical to the true distribution.
If I have misunderstood something, please point out my mistake so that I can better understand your method. Thank you!