jonbarron / robust_loss_pytorch

A pytorch port of google-research/google-research/robust_loss/
Apache License 2.0

3D Adaptive Loss #17

Closed danieltudosiu closed 4 years ago

danieltudosiu commented 4 years ago

Hi Jonbarron, thank you for your amazing work!

I have used your loss in TensorFlow and am porting my codebase to PyTorch, so I need to modify your loss to work in 3D here as well.

The modifications I made are the following. I added the following class to adaptive.py:

class AdaptiveVolumeLossFunction(nn.Module):
    """A wrapper around AdaptiveLossFunction for handling volumes."""

    def __init__(self, image_size, device, float_dtype=np.float32, **kwargs):
        super(AdaptiveVolumeLossFunction, self).__init__()

        assert len(image_size) == 4

        self.image_size = image_size

        if float_dtype == np.float32:
            float_dtype = torch.float32
        if float_dtype == np.float64:
            float_dtype = torch.float64
        self.float_dtype = float_dtype

        self.device = device

        if (
            isinstance(device, int)
            or (isinstance(device, str) and "cuda" in device)
            or (isinstance(device, torch.device) and device.type == "cuda")
        ):
            torch.cuda.set_device(self.device)

        x_example = torch.zeros([1] + list(self.image_size)).type(self.float_dtype)
        x_example_mat = self.transform_to_mat(x_example)
        self.num_dims = x_example_mat.shape[1]
        self.adaptive_lossfun = AdaptiveLossFunction(
            self.num_dims, self.float_dtype, self.device, **kwargs
        )

    def lossfun(self, x):
        x_mat = self.transform_to_mat(x)

        loss_mat = self.adaptive_lossfun.lossfun(x_mat)

        # Reshape the loss function's outputs to have the shapes as the input.
        loss = torch.reshape(loss_mat, [-1] + list(self.image_size))
        return loss

    def alpha(self):
        # This wrapper always builds an AdaptiveLossFunction and never sets
        # self.use_students_t, so no Student's t check is made here.
        return torch.reshape(self.adaptive_lossfun.alpha(), self.image_size)

    def scale(self):
        return torch.reshape(self.adaptive_lossfun.scale(), self.image_size)

    def transform_to_mat(self, x):
        assert len(x.shape) == 5
        x = torch.as_tensor(x)

        _, channels, depth, width, height = x.shape
        x_stack = torch.reshape(x, (-1, depth, width, height))

        x_stack = util.volume_dct(x_stack)

        x_mat = torch.reshape(
            torch.reshape(x_stack, (-1, channels, depth, width, height)),
            [-1, width * height * depth * channels],
        )
        return x_mat

And in util.py, the following functions:

def volume_dct(image):
    """Does a type-II DCT (aka "The DCT") on axes 1, 2 and 3 of a rank-4 tensor."""
    image = torch.as_tensor(image)
    dct_z = torch_dct.dct(image, norm="ortho").permute(0, 3, 1, 2)
    dct_y = torch_dct.dct(dct_z, norm="ortho").permute(0, 3, 1, 2)
    dct_x = torch_dct.dct(dct_y, norm="ortho").permute(0, 3, 1, 2)
    return dct_x

def volume_idct(dct_x):
    """Inverts volume_dct(), by performing a type-III DCT."""
    dct_x = torch.as_tensor(dct_x)
    dct_y = torch_dct.idct(dct_x.permute(0, 3, 1, 2), norm="ortho")
    dct_z = torch_dct.idct(dct_y.permute(0, 3, 1, 2), norm="ortho")
    image = torch_dct.idct(dct_z.permute(0, 3, 1, 2), norm="ortho")
    return image

I left the network to train for a bit and, weirdly, it gets stuck at a value of 1.19 every time.

One peculiar thing I noticed is this part of your code, where I do not understand the need for the permutation.

Are my modifications wrong?

danieltudosiu commented 4 years ago

The problem was that I mishandled the shape, due to peculiar input requirements, and made mistakes in the volume dct/idct. The modified code will be available here.
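For readers who hit similar shape problems, here is a minimal sketch of one way to sanity-check a separable 3D DCT and its inverse. It is not the fix referred to above. It uses NumPy with a hand-built orthonormal DCT-II matrix as a stand-in for torch_dct, and the names dct_matrix, dct_last_axis, idct_last_axis, volume_dct_sketch and volume_idct_sketch are hypothetical. The assumption is a rank-4 input laid out as (batch, depth, width, height). Each pass transforms the last axis and then rotates the three spatial axes, which is the role the permutation plays when a 1D transform only acts on the last axis.

```python
import numpy as np


def dct_matrix(n):
    """Orthonormal type-II DCT matrix of size n x n (rows index frequency)."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    mat = np.sqrt(2.0 / n) * np.cos(np.pi * (i + 0.5) * k / n)
    mat[0, :] = np.sqrt(1.0 / n)  # the DC row uses a different scale
    return mat


def dct_last_axis(x):
    """Applies the orthonormal DCT-II along the last axis of x."""
    return x @ dct_matrix(x.shape[-1]).T


def idct_last_axis(x):
    """Inverts dct_last_axis(); the matrix is orthogonal, so its inverse is its transpose."""
    return x @ dct_matrix(x.shape[-1])


def volume_dct_sketch(x):
    """Separable 3D DCT over axes 1, 2 and 3 of a rank-4 array."""
    for _ in range(3):
        # Transform the current last axis, then rotate the spatial axes so a
        # different one is last on the next pass.
        x = np.transpose(dct_last_axis(x), (0, 3, 1, 2))
    return x  # three rotations restore the original axis order


def volume_idct_sketch(y):
    """Inverts volume_dct_sketch() by rotating first, then inverting the last axis."""
    for _ in range(3):
        y = idct_last_axis(np.transpose(y, (0, 3, 1, 2)))
    return y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3, 4, 5))  # (batch, depth, width, height)
    y = volume_dct_sketch(x)
    print(y.shape == x.shape)
    print(np.allclose(volume_idct_sketch(y), x))
```

Using deliberately unequal depth, width and height in such a check makes axis mix-ups show up as shape errors rather than as silently wrong values.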