Closed — CTouch closed this issue 2 months ago
Hi @CTouch, I encountered the same error you did. May I know how you solved it? Thank you so much.
I rewrote the masked_l1_loss function:
import torch
import torch.nn.functional as F

def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
    if mask is None:
        return trimmed_l1_loss(pred, gt, quantile)
    else:
        sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
        # sum_loss.shape per scene (and total element count):
        #   block    [218255, 1]
        #   apple    [36673, 475, 1]  -> 17,419,675
        #   creeper  [37587, 360, 1]  -> 13,531,320
        #   backpack [37828, 180, 1]  ->  6,809,040
        # Original version, which fails for large inputs:
        # quantile_mask = (
        #     (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
        #     if quantile < 1
        #     else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
        # )
        # Use torch.sort instead of torch.quantile when the input is too large.
        if quantile < 1:
            num = sum_loss.numel()
            if num < 16_000_000:
                threshold = torch.quantile(sum_loss, quantile)
            else:
                # Linear interpolation between the two nearest order statistics,
                # matching torch.quantile's default "linear" interpolation.
                sorted_vals, _ = torch.sort(sum_loss.reshape(-1))
                idxf = quantile * (num - 1)
                idxi = int(idxf)
                nxt = min(idxi + 1, num - 1)  # clamp to stay in bounds
                threshold = sorted_vals[idxi] + (
                    sorted_vals[nxt] - sorted_vals[idxi]
                ) * (idxf - idxi)
            quantile_mask = (sum_loss < threshold).squeeze(-1)
        else:
            quantile_mask = torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
        ndim = sum_loss.shape[-1]
        if normalize:
            return torch.sum((sum_loss * mask)[quantile_mask]) / (
                ndim * torch.sum(mask[quantile_mask]) + 1e-8
            )
        else:
            return torch.mean((sum_loss * mask)[quantile_mask])
Thank you so much for your solution!
You're welcome! Let me know if there are any mistakes or better implementations. :>
Thanks for the great work! When training the apple and teddy scenes, this runtime error occurred: https://github.com/vye16/shape-of-motion/blob/697d462fdaccefda5a91373e9b6047b9dafb9234/flow3d/loss_utils.py#L32 It seems the upstream PyTorch issue has not been resolved: pytorch/pytorch#64947
My environment: torch 2.2.0, CUDA 12.1
Question
Has anyone encountered the same situation? I am trying to use another method to compute the quantile.
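One way around the error is a thin wrapper that only falls back to sorting when the input exceeds the size `torch.quantile` can handle. This is a sketch under the assumption that the limit is roughly 2**24 elements, as discussed in pytorch/pytorch#64947; `safe_quantile` is a hypothetical name, not part of shape-of-motion:

```python
import torch

def safe_quantile(t: torch.Tensor, q: float, max_numel: int = 16_000_000) -> torch.Tensor:
    """Quantile that falls back to torch.sort for large inputs,
    working around the torch.quantile size limit (pytorch/pytorch#64947).
    """
    if t.numel() <= max_numel:
        return torch.quantile(t, q)
    vals, _ = torch.sort(t.reshape(-1))
    n = vals.numel()
    idxf = q * (n - 1)            # 'linear' interpolation rank
    idxi = int(idxf)
    nxt = min(idxi + 1, n - 1)    # clamp so q == 1.0 stays in bounds
    return vals[idxi] + (vals[nxt] - vals[idxi]) * (idxf - idxi)
```

Passing a small `max_numel` forces the sort path, which makes it easy to verify that both branches agree on the same input.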