danieltudosiu closed this issue 4 years ago.
Hi Jonbarron, thank you for your amazing work!
I have used your loss in TensorFlow and am now porting my codebase to PyTorch, so I need to modify your loss to work on 3D volumes here as well.
The modifications I made are the following. I added the following class to adaptive.py:
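```python
# Already imported at the top of adaptive.py; AdaptiveLossFunction is defined
# earlier in the same file.
import numpy as np
import torch
from torch import nn

from robust_loss_pytorch import util


class AdaptiveVolumeLossFunction(nn.Module):
    """A wrapper around AdaptiveLossFunction for handling volumes."""

    def __init__(self, image_size, device, float_dtype=np.float32, **kwargs):
        super(AdaptiveVolumeLossFunction, self).__init__()
        # image_size is channels-first: (channels, depth, width, height).
        assert len(image_size) == 4
        self.image_size = image_size
        # Accept NumPy dtypes and map them to the matching torch dtypes.
        if float_dtype == np.float32:
            float_dtype = torch.float32
        if float_dtype == np.float64:
            float_dtype = torch.float64
        self.float_dtype = float_dtype
        self.device = device
        if (
            isinstance(device, int)
            or (isinstance(device, str) and "cuda" in device)
            or (isinstance(device, torch.device) and device.type == "cuda")
        ):
            torch.cuda.set_device(self.device)

        # Push a dummy volume through the transform to find the number of
        # coefficients, i.e. the number of independent (alpha, scale) pairs.
        x_example = torch.zeros([1] + list(self.image_size)).type(self.float_dtype)
        x_example_mat = self.transform_to_mat(x_example)
        self.num_dims = x_example_mat.shape[1]

        self.adaptive_lossfun = AdaptiveLossFunction(
            self.num_dims, self.float_dtype, self.device, **kwargs
        )

    def lossfun(self, x):
        x_mat = self.transform_to_mat(x)
        loss_mat = self.adaptive_lossfun.lossfun(x_mat)
        # Reshape the loss function's outputs to have the same shape as the input.
        loss = torch.reshape(loss_mat, [-1] + list(self.image_size))
        return loss

    def alpha(self):
        # This wrapper always uses AdaptiveLossFunction (never the Student's t
        # variant) and never sets `use_students_t`, so the original
        # `assert not self.use_students_t` raised AttributeError; removed.
        return torch.reshape(self.adaptive_lossfun.alpha(), self.image_size)

    def scale(self):
        return torch.reshape(self.adaptive_lossfun.scale(), self.image_size)

    def transform_to_mat(self, x):
        # Expects a channels-first batch: (batch, channels, depth, width, height).
        assert len(x.shape) == 5
        x = torch.as_tensor(x)
        _, channels, depth, width, height = x.shape
        # (batch, channels, D, W, H) -> (batch * channels, D, W, H)
        x_stack = torch.reshape(x, (-1, depth, width, height))
        # 3D DCT of each channel independently.
        x_stack = util.volume_dct(x_stack)
        # (batch * channels, D, W, H) -> (batch, channels * D * W * H)
        x_mat = torch.reshape(
            torch.reshape(x_stack, (-1, channels, depth, width, height)),
            [-1, width * height * depth * channels],
        )
        return x_mat
```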
And in util.py the following functions:
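```python
import torch
import torch_dct  # the torch-dct package, also used by util.image_dct()


def volume_dct(image):
    """Does a type-II DCT (aka "The DCT") on axes 1, 2 and 3 of a rank-4 tensor."""
    image = torch.as_tensor(image)
    # torch_dct.dct() transforms the last axis. After each 1D DCT the cyclic
    # permute(0, 3, 1, 2) brings the next spatial axis into last position:
    # (N, D, W, H) -> (N, H, D, W) -> (N, W, H, D) -> (N, D, W, H).
    dct_z = torch_dct.dct(image, norm="ortho").permute(0, 3, 1, 2)
    dct_y = torch_dct.dct(dct_z, norm="ortho").permute(0, 3, 1, 2)
    dct_x = torch_dct.dct(dct_y, norm="ortho").permute(0, 3, 1, 2)
    return dct_x


def volume_idct(dct_x):
    """Inverts volume_dct(), by performing a type-III DCT on axes 1, 2 and 3."""
    dct_x = torch.as_tensor(dct_x)
    dct_y = torch_dct.idct(dct_x.permute(0, 3, 1, 2), norm="ortho")
    dct_z = torch_dct.idct(dct_y.permute(0, 3, 1, 2), norm="ortho")
    image = torch_dct.idct(dct_z.permute(0, 3, 1, 2), norm="ortho")
    return image
```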
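One way to check the axis bookkeeping in `volume_dct()`/`volume_idct()` is to compare the cyclic-permutation scheme against a direct separable 3D DCT and to confirm the round trip. The sketch below is NumPy-only so it runs without torch: an explicit orthonormal DCT-II matrix stands in for `torch_dct.dct(..., norm="ortho")` (assumed to transform the last axis, which is how the functions above use it), and `np.transpose(..., (0, 3, 1, 2))` mirrors `.permute(0, 3, 1, 2)`. All function names are hypothetical. A non-cubic shape is used so that swapped axes would show up as a mismatch.

```python
import numpy as np


def dct_matrix(n):
    """Orthonormal DCT-II matrix; row k is the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    m = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    m[0, :] /= np.sqrt(2.0)
    return m


def dct_last(x):
    """Orthonormal DCT-II along the last axis (stand-in for torch_dct.dct)."""
    return x @ dct_matrix(x.shape[-1]).T


def idct_last(x):
    """Inverse of dct_last: the matrix is orthogonal, so its inverse is its transpose."""
    return x @ dct_matrix(x.shape[-1])


def volume_dct_np(v):
    # Same scheme as volume_dct(): DCT the last axis, then rotate the axes.
    for _ in range(3):
        v = np.transpose(dct_last(v), (0, 3, 1, 2))
    return v


def volume_idct_np(c):
    # Same scheme as volume_idct(): rotate the axes, then invert the last axis.
    for _ in range(3):
        c = idct_last(np.transpose(c, (0, 3, 1, 2)))
    return c


def dct3_reference(v):
    """Direct separable 3D DCT over axes 1, 2 and 3 of (N, D, W, H)."""
    _, d, w, h = v.shape
    return np.einsum("nijk,ai,bj,ck->nabc", v,
                     dct_matrix(d), dct_matrix(w), dct_matrix(h))


v = np.random.default_rng(0).standard_normal((2, 3, 4, 5))  # (N, D, W, H)
coeffs = volume_dct_np(v)
print(np.allclose(coeffs, dct3_reference(v)))  # True
print(np.allclose(volume_idct_np(coeffs), v))  # True
```

If the torch version is available, the same two comparisons can be run directly on `util.volume_dct()` and `util.volume_idct()` outputs.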
I left the network training for a while, and the loss oddly keeps getting stuck at a value of 1.19.
One peculiar thing I noticed is this part of your code, where I do not understand why the permutation is needed.
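Assuming the permutation in question is the `permute(0, 3, 1, 2)` in `AdaptiveImageLossFunction.transform_to_mat()`, its purpose follows from the image wrapper taking channels-last batches, `(batch, width, height, num_channels)`: the channel axis must be moved next to the batch axis before reshaping to `(batch * num_channels, width, height)`, otherwise each 2D slice interleaves values from different channels. The sketch uses NumPy, whose `transpose`/`reshape` follow the same row-major rules as `torch.permute`/`torch.reshape`; `stack_channels()` is a hypothetical helper. It also shows that a channels-first volume, as unpacked in `AdaptiveVolumeLossFunction.transform_to_mat()`, needs no permutation.

```python
import numpy as np


def stack_channels(x, permute):
    """(batch, width, height, channels) -> (batch * channels, width, height)."""
    _, width, height, channels = x.shape
    if permute:
        # Mirrors x.permute(0, 3, 1, 2): move channels next to the batch axis.
        x = np.transpose(x, (0, 3, 1, 2))
    return np.reshape(x, (-1, width, height))


# Channels-last image batch where every pixel of channel c holds the value c.
img = np.tile(np.arange(3.0).reshape(1, 1, 1, 3), (1, 4, 5, 1))

good = stack_channels(img, permute=True)   # good[c] is entirely channel c
bad = stack_channels(img, permute=False)   # bad[0] mixes all three channels

# Channels-first volumes (batch, channels, depth, width, height) are already
# stored channel by channel, so a plain reshape separates them correctly.
vol = np.tile(np.arange(3.0).reshape(1, 3, 1, 1, 1), (2, 1, 2, 4, 5))
vol_stack = np.reshape(vol, (-1, 2, 4, 5))  # vol_stack[i] is channel i % 3
```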
Are my modifications wrong?
The problem was mishandled shapes, caused by the peculiar input requirements, plus mistakes in the volume DCT/IDCT. The modified code will be available here.
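`transform_to_mat()` in the class above only asserts that the input is rank 5 and then takes every dimension from the input itself, so a batch in the wrong layout (for example channels-last) with the same number of elements is reshaped without error and the DCT silently runs over the wrong axes. A minimal guard, sketched here as a hypothetical helper `check_volume_batch()` that is not part of robust_loss_pytorch, compares the trailing dimensions with `image_size` before any reshaping. It assumes the channels-first `(channels, depth, width, height)` convention and works on anything with a `.shape`, NumPy arrays or torch tensors alike.

```python
import numpy as np


def check_volume_batch(x, image_size):
    """Hypothetical guard: raise unless x has shape (batch,) + image_size."""
    if len(x.shape) != 5:
        raise ValueError(f"expected a rank-5 batch, got shape {tuple(x.shape)}")
    if tuple(x.shape[1:]) != tuple(image_size):
        raise ValueError(
            f"expected trailing shape {tuple(image_size)}, got {tuple(x.shape[1:])}"
        )
    return x


image_size = (2, 8, 16, 16)  # (channels, depth, width, height)
ok = np.zeros((4,) + image_size)
# The same amount of data stored channels-last: (batch, depth, width, height, channels).
channels_last = np.zeros((4, 8, 16, 16, 2))

check_volume_batch(ok, image_size)  # passes silently
try:
    check_volume_batch(channels_last, image_size)
except ValueError as err:
    print(err)  # prints: expected trailing shape (2, 8, 16, 16), got (8, 16, 16, 2)
```

Calling such a check at the start of `lossfun()` turns a silent layout mismatch into an immediate error.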