ProGamerGov / dream-creator

Quickly and easily create / train a custom DeepDream model
MIT License

Visualization fails when using custom image #12

Closed · Morpheus3000 closed this 4 years ago

Morpheus3000 commented 4 years ago

Hello,

Thank you for the great and easy-to-use repository. I was trying to train a custom DeepDream model and then run it on a custom image. While the system works flawlessly on random input (no -content_image), it fails when I set a custom content image.

The specific error I get is in utils/decorrelation.py:

....
def ifft_image(self, input):
    input = input * self.scale  # shape mismatch is raised here
....

The error is a shape mismatch: my input image is 224x224, but the scale variable is 224x113 in the spatial dimensions.

I tried looking at the code, and it seems this line (in utils/decorrelation.py):

fx = self.pytorch_fftfreq(self.w)[: self.w // 2 + wadd]

returns the reduced spatial dimension. When no content image is set, the script initializes a random input, which happens to be 224x113, so there is no complaint.
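For reference, here's a minimal sketch of that halving using the modern torch.fft.fftfreq (standing in for the repo's pytorch_fftfreq helper, and assuming wadd is 1, as it would be for even widths):

import torch

# fftfreq gives the DFT sample frequencies for w samples; slicing to
# w // 2 + 1 keeps only the non-negative frequencies, matching rfft's
# one-sided output and the 224 -> 113 reduction described above:
w = 224
fx = torch.fft.fftfreq(w)[: w // 2 + 1]
print(fx.shape)  # torch.Size([113])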

Additionally, I noticed that the random input tensor is initialized as torch.Size([1, 3, 224, 113, 2]), which leads to another error once I do manage to patch the scale to match the input's spatial dimensions, because my input (set by -content_image) has size torch.Size([1, 3, 224, 224]).

So I was wondering if you have a specific solution for this, and what the 2 in the 5th dimension of the random input tensor is for.

ProGamerGov commented 4 years ago

@Morpheus3000 It's awesome to see someone using this project after all the hard work I've put into it!

The random inputs are initialized as already spatially decorrelated when spatial decorrelation is used; that's why there's the extra 5th dimension and the 4th dimension is temporarily reduced. The 5th dimension holds the imaginary components (I think), and they come from torch.rfft(), which does the opposite of the function used in the main spatial recorrelation step (torch.irfft()). I just fake them with random values if there's no content image, as decorrelating a random input is the same as creating the random input already decorrelated. I still don't fully understand spectral ops like torch.rfft(), but I know enough about them to use them effectively.
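To make the shapes concrete, here's a minimal sketch using the modern torch.fft API (torch.rfft/torch.irfft were removed in PyTorch 1.8; torch.fft.rfft2 and torch.fft.irfft2 are their rough replacements):

import torch

x = torch.randn(1, 3, 224, 224)             # real image tensor
spectrum = torch.fft.rfft2(x)               # complex, one-sided spectrum
print(spectrum.shape)                        # torch.Size([1, 3, 224, 113])

# The FFT of a real signal is conjugate-symmetric, so rfft keeps only
# W // 2 + 1 columns (224 -> 113). The legacy torch.rfft represented
# the complex values with a trailing dimension of size 2 (real and
# imaginary parts), which is where torch.Size([1, 3, 224, 113, 2])
# comes from.
recon = torch.fft.irfft2(spectrum, s=x.shape[-2:])
print(torch.allclose(x, recon, atol=1e-5))  # True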

Random inputs are also initialized as already having their colors decorrelated, if color decorrelation is enabled.

Basically, the images are optimized by the optimizer in their decorrelated form, and the decorrelation classes (spatial & color) take the decorrelated tensors and recorrelate them. I had to add decorrelation code that does the opposite of that recorrelation for content images.

I've just published an update that adds support for content image color & spatial decorrelation: https://github.com/ProGamerGov/dream-creator/pull/11

The update may not be perfect as I haven't done a ton of testing, but it should work. I've found that the visualizations can very easily overpower the content image's content, but I haven't yet figured out how to compensate for that (I need to play around with the Transform class, as it looks like the sigmoid function messes with the content image). Let me know how well it works! Hopefully everything I've said makes sense to you, and feel free to ask questions if it doesn't!

ProGamerGov commented 4 years ago

Okay, I've made it so that the content image no longer gets overwhelmed by the visuals when using spatial and color decorrelation (I replaced .sigmoid() with .clamp(0, 1), but only when a content image is used). You'll have to lower the learning rate a little bit, though, for example:

-content_image <content.jpg> -color_decorrelation -fft_decorrelation -lr 0.075
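
As a quick illustration of why the swap matters for content images: .clamp(0, 1) leaves already-valid pixel values untouched, while .sigmoid() squashes everything toward mid-gray:

import torch

x = torch.tensor([0.0, 0.5, 1.0, 1.4])
print(x.clamp(0, 1))  # tensor([0.0000, 0.5000, 1.0000, 1.0000])
print(x.sigmoid())    # tensor([0.5000, 0.6225, 0.7311, 0.8022])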
Morpheus3000 commented 4 years ago

Hi @ProGamerGov, thanks for such a detailed explanation. I tried the patch that you pushed and it worked. I haven't had time to try out the new clamp version; I'll give it a try over the weekend. Thanks again for the added effort! It's really nice to have a complete repository for training and testing.

In the meantime, I found some bugs with my version of PyTorch, which complains that there is no .T attribute on Tensor, plus a datatype mismatch in another file. Both are easy fixes (replace .T with torch.t(); see the quick check below), so I was planning to push a pull request later during the weekend.
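For reference, a minimal check of the replacement (both forms return views of the same storage, so neither allocates extra memory):

import torch

m = torch.randn(3, 3)
# Tensor.T arrived around PyTorch 1.2; on older versions, Tensor.t()
# (2-D only) or Tensor.transpose(0, 1) are the equivalents:
assert torch.equal(m.t(), m.transpose(0, 1))
assert m.t().data_ptr() == m.data_ptr()  # a view, not a copy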

Thanks again for the quick response!

ProGamerGov commented 4 years ago

@Morpheus3000 .T works for me on the latest version of PyTorch, so you may have to update your version of PyTorch if you have not done so already. Or you could keep a modified local version of the code with .T replaced with .t().

I'm aware of the mismatches and don't believe they're any cause for concern. They're just warnings aimed at individuals doing other things that might be affected by them.

I'd also love to see some results from your custom trained model!

Mayukhdeb commented 3 years ago

Hi @ProGamerGov, I was trying to understand how the ColorDecorrelationLayer is being used in your code. It seems like you're running its forward function at every iteration on a tensor of size torch.Size([1, 3, 224, 224]).

So I'd assume that you're running the color decorrelation on the image tensor.

I have 3 questions:

  1. Are you running the color decorrelation module on the input image tensor, or on something else?
  2. Assuming it is the image tensor that gets decorrelated, does it undergo decorrelation before each forward pass? If not, when does the decorrelation happen?
  3. Can color decorrelation be added to the framework that I've already built? Here's a link to the main script which handles the image updates, for your reference.

Thank you for your time :+1:

ProGamerGov commented 3 years ago

@Mayukhdeb

1 & 2. If it's not using a random initialization, then the colors are decorrelated before the tensor is stored as an nn.Parameter instance. The tensor then has its colors recorrelated in the forward pass each time, before it's sent through the model.

3. Color decorrelation can be implemented as a transform layer, so if your code supports that, then it should be possible (see the sketch below).
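Here's a minimal sketch of such a transform layer, loosely following the Lucid approach. The identity matrix is just a placeholder; dream-creator and Lucid derive the real one from an empirical RGB covariance (e.g., its Cholesky factor):

import torch
import torch.nn as nn

class ColorRecorrelation(nn.Module):
    # Recorrelates the colors of a color-decorrelated image tensor.
    def __init__(self, color_matrix=None):
        super().__init__()
        # Placeholder 3x3 channel transform.
        self.register_buffer(
            "color_matrix",
            color_matrix if color_matrix is not None else torch.eye(3),
        )

    def forward(self, x):  # x: (N, 3, H, W), color-decorrelated
        n, c, h, w = x.shape
        flat = x.permute(0, 2, 3, 1).reshape(-1, c)  # (N*H*W, 3)
        flat = flat @ self.color_matrix.t()          # mix channels back
        return flat.reshape(n, h, w, c).permute(0, 3, 1, 2)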
Mayukhdeb commented 3 years ago

Hi @ProGamerGov, so after some refactoring, my library finally uses torch.optim for optimizing image parameters. Taking some inspiration from this project, I also built an image_parameter() class that handles all the image-related operations.

I've been doing all of this just to add FFT parameterization to my library as a feature that should improve the quality.

But I just can't wrap my head around how the spatial and color decorrelation functions work in dream-creator or Lucid.

If you have any resources that would help me get a better understanding of spatial and color decorrelation, I humbly request you to send them to me here or via Telegram.

Thanks in advance :)

ProGamerGov commented 3 years ago

@Mayukhdeb Basically, when the image is being optimized, you can think of it as being stored inside the optimizer. The format in which the image is stored there influences how the optimizer optimizes it.
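A minimal sketch of that idea, with placeholder names rather than dream-creator's actual API: the optimizer holds the image in its decorrelated (spectral) form, and every step recorrelates it into pixels before computing the loss.

import torch

# Spectral parameter in the legacy (real, imag) layout from this thread.
param = torch.randn(1, 3, 224, 113, 2, requires_grad=True)
optimizer = torch.optim.Adam([param], lr=0.075)

def recorrelate(p):
    # Inverse FFT back to a (1, 3, 224, 224) pixel-space tensor;
    # color recorrelation would also go here.
    return torch.fft.irfft2(torch.view_as_complex(p.contiguous()), s=(224, 224))

for _ in range(3):
    optimizer.zero_grad()
    img = recorrelate(param).clamp(0, 1)
    loss = -img.mean()  # stand-in for a channel/layer objective
    loss.backward()
    optimizer.step()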

I would suggest asking on the Distill Slack, as there are people there who understand it far better than I do: http://slack.distill.pub/. The Distill Slack has the Lucid devs and other individuals who are extremely well-versed in image parameterization!