Closed yjb6 closed 3 months ago
I believe there's an issue when calculating `quaternion_loss`. After replacing it with an L1 loss, the code runs without NaNs, but the NTC performance doesn't match the official results. The scene is flame_steak.
Please refer to https://github.com/SJoJoK/3DGStream/issues/8, which reports the same issue as yours. tiny-cuda-nn may perform incorrectly in different environments.
I think I have found the issue. My environment is:
- python 3.10.13
- tinycudann 1.7
- torch 2.2.0+cu118
"As mentioned in #8 , different versions can cause problems. In my version, when calculating
cos_theta = F.cosine_similarity(q1, q2, dim=1)
in def quaternion_loss
, the gradient backpropagation computes the 3/2 power of q1, which exceeds the precision of torch.float16, causing NaNs. The following code can verify this situation:
```python
import commentjson as ctjs  # assuming ctjs is commentjson, as the 3DGStream code suggests
import tinycudann as tcnn
import torch
import torch.nn.functional as F
from torch import autograd

# ntc_conf_paths: list of NTC config file paths (defined elsewhere)
for ntc_conf_path in ntc_conf_paths:
    with open(ntc_conf_path) as ntc_conf_file:
        ntc_conf = ctjs.load(ntc_conf_file)
    ntc = tcnn.NetworkWithInputEncoding(
        n_input_dims=3, n_output_dims=4,
        encoding_config=ntc_conf["encoding"],
        network_config=ntc_conf["network"],
    ).to(torch.device("cuda"))
    ntc_optimizer = torch.optim.Adam(ntc.parameters(), lr=1e-4)
    xyz = torch.rand(2000, 3).cuda()
    ntc_output = ntc(xyz)
    # ntc_output = ntc(xyz).to(torch.float64)  # the fix: upcast here
    gt = torch.tensor([1, 0, 0, 0], dtype=torch.float64).cuda()
    cos_theta1 = F.cosine_similarity(ntc_output, gt, dim=1)
    cos_theta2 = torch.clamp(cos_theta1, -1 + 1e-7, 1 - 1e-7)
    loss = 1 - torch.pow(cos_theta2, 2).mean()
    ntc_grad = autograd.grad(loss, ntc_output, retain_graph=True)[0]
    print(ntc_grad)
    print(ntc_grad.isnan().sum())
    # the network outputs whose gradients contain NaNs
    print(ntc_output[ntc_grad.isnan().sum(dim=-1) != 0])
    test = ntc_grad[ntc_grad.isnan().sum(dim=-1) != 0]
    # the 3/2 power of these values is non-finite in float16
    print(torch.pow(test, 1.5))
```
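The underlying float16 failure can also be seen without tiny-cuda-nn at all. The sketch below is my illustration of the assumed mechanism in plain PyTorch: the squared norm of a small quaternion underflows to zero in float16, its reciprocal becomes inf, and the inf terms in a backward formula then combine into NaN.

```python
import torch

# float16 dynamic range demo (assumed mechanism, not the exact
# tiny-cuda-nn kernel): squaring a small value underflows to 0.
q = torch.full((4,), 1e-4, dtype=torch.float16)
sq_norm = (q * q).sum()   # each (1e-4)**2 = 1e-8 rounds to 0 in float16
inv = 1.0 / sq_norm       # 1/0 -> inf
nan = inv - inv           # inf - inf -> nan, as seen in the gradients
print(sq_norm.item(), inv.item(), nan.item())

# The same values survive in float32.
sq_norm32 = (q.float() * q.float()).sum()
print(sq_norm32.item())   # finite and nonzero
```

This is why upcasting before the loss, as suggested below, keeps the backward pass finite.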
Therefore, increasing the precision of q1 can solve this problem. The code can be modified as shown below:

```python
ntc_output = ntc(ntc_inputs_w_noisy).to(torch.float64)
```
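Put together, a patched loss might look like the sketch below. This is my reconstruction from the snippets in this thread, not the repository's exact code: it mirrors the clamp-and-square formula shown in the verification snippet, with the float64 upcast applied inside the function.

```python
import torch
import torch.nn.functional as F

def quaternion_loss(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    # Upcast before cosine_similarity so the backward pass (which raises
    # the norm terms to the 3/2 power) stays in a representable range.
    q1 = q1.to(torch.float64)
    q2 = q2.to(torch.float64)
    cos_theta = F.cosine_similarity(q1, q2, dim=1)
    cos_theta = torch.clamp(cos_theta, -1 + 1e-7, 1 - 1e-7)
    return 1 - torch.pow(cos_theta, 2).mean()
```

Gradients flow back through the `.to(torch.float64)` cast, so the network itself can keep running in float16; only the loss computation is promoted.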
Nice job, thank you! I'll mention this modification in the README in the next few days.
The README is updated.
When I run the cache_warmup.ipynb script, the loss becomes NaN during the second epoch. What could be causing this issue? I haven't made any modifications to the code after cloning the repository.
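While waiting for a repo-specific answer, a generic way to locate where a loss first turns NaN is PyTorch's autograd anomaly mode (standard PyTorch, not anything from 3DGStream or `cache_warmup.ipynb`): backward raises a `RuntimeError` at the first op whose gradient contains NaN, with a traceback pointing at the forward call that created it.

```python
import torch

# Anomaly mode makes backward fail loudly at the first NaN gradient.
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0], requires_grad=True)
caught = False
try:
    y = torch.sqrt(x)    # sqrt(-1) -> nan in the forward pass
    y.sum().backward()   # sqrt's backward propagates the nan
except RuntimeError:
    caught = True        # anomaly mode reports SqrtBackward0 returned nan
print("NaN located:", caught)

torch.autograd.set_detect_anomaly(False)  # it slows training; disable after
```

Running one warm-up epoch under this mode should point at the op that produces the NaN.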