med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation
134 stars, 12 forks

Feedback on logic error in code (possible) #15

Status: Open. CunminZhao opened this issue 11 months ago.

CunminZhao commented 11 months ago

Feedback on logic error in code (possible): Hello, in this section of your code, the criterion for selecting points_torch_negative is seg < 10, which appears to be incorrect. In the dataset you provided, seg contains only the values 0 and 1 after processing, so seg < 10 selects every voxel. Therefore, I think this condition should be changed to seg < 1.

"""
l = len(torch.where(seg == 1)[0])
points_torch = None
if l > 0:
    sample = np.random.choice(np.arange(l), 40, replace=True)
    x = torch.where(seg == 1)[1][sample].unsqueeze(1)
    y = torch.where(seg == 1)[3][sample].unsqueeze(1)
    z = torch.where(seg == 1)[2][sample].unsqueeze(1)
    points = torch.cat([x, y, z], dim=1).unsqueeze(1).float()
    points_torch = points.to(device)
    points_torch = points_torch.transpose(0,1)
l = len(torch.where(seg < 10)[0])
sample = np.random.choice(np.arange(l), 10000, replace=True)
x = torch.where(seg < 10)[1][sample].unsqueeze(1)
y = torch.where(seg < 10)[3][sample].unsqueeze(1)
z = torch.where(seg < 10)[2][sample].unsqueeze(1)
points = torch.cat([x, y, z], dim=1).unsqueeze(1).float()
points_torch_negative = points.to(device)
points_torch_negative = points_torch_negative.transpose(0, 1)
"""
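The point sampling in the snippet can be mimicked without torch or numpy. The sketch below is a minimal, hypothetical stand-in: where() and sample_points() are illustrative helpers, not functions from 3DSAM-adapter, and a small nested 4-D list replaces the seg tensor. It keeps the snippet's axis order (x from dimension 1, y from dimension 3, z from dimension 2) and its sampling with replacement, and shows that on a binary mask seg < 10 can return foreground voxels while seg < 1 cannot.

```python
import random


def where(mask, predicate):
    # Collect (i0, i1, i2, i3) for every voxel of a nested 4-D list whose value
    # satisfies predicate, in row-major order (the order torch.where(cond) uses).
    hits = []
    for i0, vol in enumerate(mask):
        for i1, plane in enumerate(vol):
            for i2, row in enumerate(plane):
                for i3, v in enumerate(row):
                    if predicate(v):
                        hits.append((i0, i1, i2, i3))
    return hits


def sample_points(mask, predicate, k, rng):
    # Draw k matching voxels with replacement (like np.random.choice(..., replace=True))
    # and reorder each index as (dim1, dim3, dim2), matching x, y, z in the snippet.
    hits = where(mask, predicate)
    if not hits:
        return None  # mirrors points_torch = None when there is no foreground
    picks = [rng.choice(hits) for _ in range(k)]
    return [(h[1], h[3], h[2]) for h in picks]


# Hypothetical toy mask of shape (1, 2, 2, 2) with one foreground voxel at (0, 1, 0, 1).
mask = [[[[0, 0], [0, 0]],
         [[0, 1], [0, 0]]]]

rng = random.Random(0)
pos = sample_points(mask, lambda v: v == 1, 40, rng)       # positive points
neg10 = sample_points(mask, lambda v: v < 10, 10000, rng)  # condition in the code
neg1 = sample_points(mask, lambda v: v < 1, 10000, rng)    # proposed condition

print(set(pos))  # {(1, 1, 0)}
print(len(where(mask, lambda v: v < 10)), len(where(mask, lambda v: v < 1)))  # 8 7
print((1, 1, 0) in set(neg1))  # False: seg < 1 only picks background voxels
```

The snippet additionally stacks the samples into a (1, k, 3) float tensor via unsqueeze, cat, and transpose; this sketch returns a plain list of (x, y, z) tuples instead.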

peterant330 commented 11 months ago

This means that if a sample has no foreground, I randomly sample points from any part of the image. Actually, <1 and <10 give the same result. (We use <10 because we also ran some experiments on multi-label segmentation.)
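A dependency-free sketch of when the two thresholds agree, assuming flat Python lists of voxel labels in place of the seg tensor (the label values below are hypothetical, chosen only for illustration). Under these assumptions, seg < 1 and seg < 10 select exactly the same voxels only when the volume has no foreground; whenever labels are present, seg < 10 still covers the whole volume as long as every label is below 10, which is consistent with the multi-label use mentioned in the reply.

```python
def selected(labels, threshold):
    # Indices of voxels whose label is strictly below threshold (mirrors seg < threshold).
    return [i for i, v in enumerate(labels) if v < threshold]


# Hypothetical flattened label volumes.
cases = {
    "binary, with foreground": [0, 0, 1, 1, 0, 0],
    "binary, no foreground": [0, 0, 0, 0, 0, 0],
    "multi-label, labels 0-3": [0, 2, 3, 1, 0, 0],
}

for name, labels in cases.items():
    below1, below10 = selected(labels, 1), selected(labels, 10)
    print(f"{name}: seg<1 -> {len(below1)}, seg<10 -> {len(below10)}, same={below1 == below10}")

# binary, with foreground: seg<1 -> 4, seg<10 -> 6, same=False
# binary, no foreground: seg<1 -> 6, seg<10 -> 6, same=True
# multi-label, labels 0-3: seg<1 -> 3, seg<10 -> 6, same=False
```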

CunminZhao commented 11 months ago

Thanks for your reply. Sorry, I was mistaken: I had incorrectly taken your "points_torch_negative" to be the points used to indicate background in SAM.