lucidrains / deep-daze

Simple command line tool for text to image generation using OpenAI's CLIP and Siren (Implicit neural representation network). Technique was originally created by https://twitter.com/advadnoun

Size schedule #55

Open NotNANtoN opened 3 years ago

NotNANtoN commented 3 years ago

Hi!

I have some issues understanding the size scheduling. I care about it because in one of my projects I try to create an audio-visual mirror: I need to continuously train on CLIP encodings of new images delivered via the webcam. Therefore I need to understand the size scheduling (I think it's the only kind of scheduling that happens) in order to adapt it to my needs - I guess I'd need to remove it, generalize it, or make it dependent on the amount of change in the CLIP encoding between images.

I have some issues/questions regarding the scheduling:

  1. The scheduling does not take the total number of batches into account - it only generates the schedule for however many batches happen to be required. Is that simply not implemented yet? It seems to me that the thresholds (500, 1000, etc.) and possibly `pieces_per_group` would need to be adjusted based on the total number of batches.
  2. The scheduling partitions seem to change their ordering over time from descending to ascending sizes: the first partition is [4, 5, 3, 2, 1, 1], while the last is [1, 1, 1, 2, 4, 7]. Yet, in line 215 the sampled sizes are sorted. That does not make any sense to me.
  3. Could someone explain the point of the scheduling?

I linked the relevant lines below: https://github.com/lucidrains/deep-daze/blob/79a991eb952166e3c8118c84422037223461bd7c/deep_daze/deep_daze.py#L173-L216

afiaka87 commented 3 years ago

@lucidrains I'm quite curious about this as well. Have been staring at this block of code for a while now and haven't been able to understand it.

NotNANtoN commented 3 years ago

@lucidrains @afiaka87 I took it upon myself to fix this and replaced the scheduling with something much simpler and better - see #61.

afiaka87 commented 3 years ago

Going to give this a test run. Thanks for investigating all of this!

dginev commented 3 years ago

Wish I had seen this issue 7 days ago, since I've caused all the trouble, but here I am now...

> The scheduling partitions seem to change their ordering over time from descending to ascending sizes: the first partition is [4, 5, 3, 2, 1, 1], while the last is [1, 1, 1, 2, 4, 7]. Yet, in line 215 the sampled sizes are sorted. That does not make any sense to me.

Well, it makes some sense, since you are not sorting those arrays - you are sorting the sampled sizes, which are constrained by those partitions. Best to run the code for both cases and check the resulting sizes as an example:

```python
import torch

pieces_per_group = 4

partition = [4, 5, 3, 2, 1, 1]
# partition = [1, 1, 1, 2, 4, 7]

dbase = .38
step = .1
width = 512

sizes = []
for part_index in range(len(partition)):
    groups = partition[part_index]
    for _ in range(groups * pieces_per_group):
        sizes.append(torch.randint(
            int((dbase + step * part_index + .01) * width),
            int((dbase + step * (1 + part_index)) * width),
            ()))
sizes.sort()
print([int(size) for size in sizes])
```

For partition `[4, 5, 3, 2, 1, 1]` the sorted sizes came out as:

```
[199, 201, 202, 202, 203, 208, 212, 220, 222, 222, 227, 232, 235,
 241, 244, 244, 250, 255, 257, 258, 260, 261, 262, 265, 267, 268,
 270, 270, 272, 276, 277, 282, 289, 293, 294, 294, 304, 305, 307,
 310, 313, 314, 315, 320, 323, 333, 343, 345, 361, 361, 365, 369,
 375, 377, 388, 393, 414, 426, 429, 448, 460, 464, 469, 498]
```

For partition `[1, 1, 1, 2, 4, 7]`:

```
[211, 217, 224, 230, 253, 261, 270, 294, 302, 319, 328, 337, 357,
 364, 367, 370, 372, 387, 390, 395, 404, 404, 405, 406, 409, 413,
 417, 422, 427, 433, 439, 439, 442, 442, 444, 447, 455, 457, 457,
 459, 467, 470, 470, 473, 473, 475, 476, 476, 476, 477, 477, 479,
 480, 482, 485, 488, 488, 492, 493, 497, 497, 500, 500, 500]
```

So the different partitions are still respected (more smaller vs more larger), they are just kept monotonically increasing, rather than shuffled.

Is this a "great idea" or "good practice"? Most certainly not, let's not push my luck 😅 - but it worked surprisingly well in practice at increasing the robustness between steps. The training saturated faster, as I suspect the pieces in `image = torch.cat(pieces)` carried a more consistent signal than when shuffled (hand-waving, I didn't do the math).

Starting with more of the smaller windows had a very noticeable effect as well: the images developed more structure earlier on in training, but that often sacrifices coherence. It's particularly nice when you're hoping to make a scene with lots of small details, and particularly terrible when you're trying to get a single high-quality object. In the end the whole experience felt like choosing the most appropriate patch trade-off, and I stopped early since I didn't have the time to get an in-depth understanding of SIREN and find the best way to use its strengths here.

Hope this is at least partially helpful, and I wonder if there is a training regime that allows both a Good first epoch, and high quality continuity in later ones...

afiaka87 commented 3 years ago

@lucidrains I believe there's a PR with a fix for this one.

dginev commented 3 years ago

Second easiest comment should probably go into this issue (arriving here before I jump into the meat of #66). I still want to make the claim that sorting the sampled sizes improves training stability, as I tried to explain above.

To restate why: during each forward pass the `torch.cat`-concatenated tensor contains a monotonically increasing sequence of windows, as opposed to a randomly shuffled one. That ought to let SIREN consistently get more "fine detail" signal from the same locations of the concatenated tensor, more "global coherence" signal from the opposite end, and so on. It's also harder to compare the different aspects of the scheduling you're exploring here if we're suffering from these learning instabilities.
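For concreteness, here is a rough sketch of how I picture the forward pass; the names (`image`, `encode_image`) and the size bounds are stand-ins for illustration, not the actual deep_daze internals:

```python
import torch
import torch.nn.functional as F

# Stand-ins so the sketch runs on its own: `image` would be the SIREN output,
# `encode_image` would be CLIP's image encoder.
image = torch.rand(1, 3, 512, 512)
encode_image = lambda batch: batch.mean(dim=(2, 3))  # dummy "encoder"

batch_size, width = 64, image.shape[-1]
sizes = torch.randint(int(0.1 * width), width + 1, (batch_size,))
sizes, _ = torch.sort(sizes)  # the "+sorted" variant; drop this line for shuffled sizes

pieces = []
for size in sizes.tolist():
    # random square cutout of the current size...
    offset_x = torch.randint(0, width - size + 1, ()).item()
    offset_y = torch.randint(0, width - size + 1, ()).item()
    cutout = image[:, :, offset_y:offset_y + size, offset_x:offset_x + size]
    # ...resized to CLIP's input resolution
    pieces.append(F.interpolate(cutout, size=224, mode='bilinear', align_corners=False))

# With the sort, position i of the batch always holds the i-th smallest cutout.
batch = torch.cat(pieces)
image_embeds = encode_image(batch)
```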

It's a little hard to prove, and likely not possible to do so convincingly. Here's the fresh evidence I can offer with the latest v0.7 release and your new_augmentations branch. In both cases I compare to a +sorted variant that adds the following two lines to deep_daze.py:

```python
sizes = torch.randint(int(lower), int(upper), (self.batch_size,))
# Added in +sorted:
sorted, indexes = torch.sort(sizes, 0)
sizes = sorted
```
```
imagine --num-layers 44 --batch-size 64 "A llama wearing a scarf and glasses, reading a book in a cozy cafe."
```

| branch | result at 700 steps | +sorted, 700 steps |
| --- | --- | --- |
| main v0.7 | (image) | (image) |
| new_augmentations | (image) | (image) |

The loss was -60 for new_augmentations vs -62 when sorted, and -61 for main v0.7 vs -62 when sorted. Maybe not convincing after all, and maybe I'm just experiencing confirmation bias, but it does seem like the sorted variants converge slightly faster and arrive at slightly better-assembled compositions. I'll now try to check whether some of the other experiments in #66 are reproducible/can be made into even better llamas.

NotNANtoN commented 3 years ago

Okay, I think I know where our difference in understanding lies. I assume that the result of `torch.cat` on the image pieces is one large batch of tensors that is fed into CLIP, which encodes every single image. Then, for each encoding, the cosine similarity to the `text_embed` is calculated and averaged. That means all sizes are processed simultaneously and averaged, so the sorting should not make any difference.
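To make that concrete, here is a tiny sanity check with dummy embeddings (not the actual CLIP calls): if the loss is the mean cosine similarity over all cutout embeddings, permuting or sorting the batch cannot change it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_embeds = torch.randn(64, 512)  # one embedding per cutout
text_embed = torch.randn(1, 512)

def loss(embeds):
    # negative mean cosine similarity to the text embedding
    return -F.cosine_similarity(embeds, text_embed, dim=-1).mean()

perm = torch.randperm(64)
print(loss(image_embeds), loss(image_embeds[perm]))  # identical values
```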

That's why I also criticized the sorting in the previous scheduling. There were groups of the form [5, 4, 3, 2, 1, 1], but in the end, the sampled results were sorted, hence this group is equivalent to the group [1, 1, 2, 3, 4, 5].

On another note, I think it is hard to compare losses between our runs. The issue is that the loss is influenced by the kind of augmentations we apply. Without augmentations the loss is minimized easily; with a range of random cutouts from size 0.1 to 1.0 the optimization has a harder time and therefore ends at a higher loss - but that does not mean the images look worse to the human eye, they are just potentially less like adversarial examples (out-of-distribution patterns that maximally/minimally excite the network).

dginev commented 3 years ago

> Then, for each encoding, the cosine similarity to the `text_embed` is calculated and averaged.

If that is true then the sort should really have no particular influence, and I'm in the wrong here. Thanks!

> There were groups of the form [5, 4, 3, 2, 1, 1], but in the end, the sampled results were sorted, hence this group is equivalent to the group [1, 1, 2, 3, 4, 5].

This, however, was never true, sort or no sort. Those partitions lead to different distributions of sizes, and it is those sampled sizes that were sorted - not the partitions themselves; see my first comment.

I'll run some experiments with different fixed sizing schedules and report back with more llamas. In the end results speak for themselves with these models 🦙

dginev commented 3 years ago

I can now contribute 5 "fixed schedule" runs, each with a selection of sizes that remains constant for the entire run. Maybe that gives us a hint as to whether this random sampling is a good idea at all. The rest of the code is identical to v0.7, running with 44 layers, a batch of 64, and 700 steps.

Listing them by the average size in each run, as a percentage of the image width of 512. The exact size lists are in the details section at the bottom.

- mean size 50%, uniform: (image)
- mean size 66%: (image)
- mean size 81%: (image)
- mean size 84%: (image)
- mean size 91%: (image)

At least I think I can conclude that the level of detail in early training increases as the average window size gets smaller, but the sub-fragmentation of the image also becomes more apparent. On the flip side, with the largest size schedule the image is very coherent, but the details are lacking ("foggy") and may or may not get filled in if one lets training continue. To me the 81% variant was surprising in how much coherence and detail appeared together, so there may be some Pareto principle to aim for here. Of course the mean isn't the only important thing; the way the individual sizes are distributed has a pretty direct effect on the final image - so there's more to play with. I'm also curious whether the 70-80% range is even better; I didn't get to try that yet. Food for thought for now...

* mean 50%, uniform

  ```python
  sizes = [  8,  16,  24,  32,  40,  48,  56,  64,
            72,  80,  88,  96, 104, 112, 120, 128,
           136, 144, 152, 160, 168, 176, 184, 192,
           200, 208, 216, 224, 232, 240, 248, 256,
           264, 272, 280, 288, 296, 304, 312, 320,
           328, 336, 344, 352, 360, 368, 376, 384,
           392, 400, 408, 416, 424, 432, 440, 448,
           456, 464, 472, 480, 488, 496, 504, 511]
  ```

* mean 66%

  ```python
  sizes = [ 16,  24,  32,  40,  48,  56,  64,  72,
            80,  88,  96, 104, 112, 120, 128, 136,
           192, 256, 288, 320, 352, 384, 392, 400,
           404, 408, 409, 410, 411, 412, 413, 414,
           415, 416, 418, 420, 424, 428, 432, 436,
           440, 448, 450, 452, 454, 456, 458, 460,
           462, 464, 466, 468, 472, 474, 476, 478,
           484, 486, 490, 494, 498, 502, 506, 510]
  ```

* mean 81%

  ```python
  sizes = [ 16,  32,  64, 128, 192, 256, 288, 320,
           352, 384, 392, 400, 404, 408, 409, 410,
           411, 412, 413, 414, 415, 416, 418, 420,
           424, 428, 432, 436, 440, 448, 450, 452,
           454, 456, 458, 460, 462, 464, 466, 468,
           472, 474, 476, 478, 480, 482, 484, 488,
           489, 490, 491, 492, 494, 496, 498, 499,
           500, 501, 502, 504, 506, 508, 510, 511]
  ```

* mean 84%

  ```python
  sizes = [ 16,  32,  64, 128, 192, 256, 288, 320,
           352, 384, 392, 400, 408, 424, 428, 432,
           436, 440, 448, 450, 452, 454, 456, 458,
           460, 461, 462, 463, 464, 465, 466, 467,
           468, 472, 474, 476, 475, 477, 478, 480,
           482, 484, 488, 489, 490, 491, 492, 493,
           494, 495, 496, 497, 498, 499, 500, 501,
           502, 503, 504, 505, 506, 508, 510, 511]
  ```

* mean 91%

  ```python
  sizes = [ 64, 128, 256, 352, 452, 454, 456, 458,
           459, 460, 461, 462, 463, 464, 465, 466,
           467, 468, 469, 470, 471, 472, 473, 474,
           475, 476, 477, 478, 479, 480, 481, 482,
           483, 484, 485, 486, 487, 488, 489, 490,
           491, 492, 493, 494, 495, 496, 497, 498,
           499, 500, 500, 500, 500, 501, 502, 503,
           504, 505, 506, 507, 508, 509, 510, 511]
  ```
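For completeness, the headline percentages are just the mean of each `sizes` list relative to the 512-pixel width; e.g. a quick check for the uniform schedule (list reconstructed with `range` rather than copied verbatim):

```python
# Mean schedule size relative to the image width (here for the "mean 50%, uniform" run).
width = 512
schedule = list(range(8, 512, 8)) + [511]  # same values as the uniform list above
print(f"{sum(schedule) / len(schedule) / width:.1%}")  # -> 50.8%
```
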
NotNANtoN commented 3 years ago

Hi, @dginev, sorry for not responding for (relatively) long.

I agree with your analysis on the mean of the Gaussian size sampling. Larger means tend to get washed out a bit.

I experimented with some new ideas that you might like. The first is to use a Gaussian not only for sampling the size of the cutout but also for its position. I center the position Gaussian on the center of the image and sample such that large cutouts are likely to be centered there, whereas the position distributions for small cutouts are more spread out.
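Roughly, the sampling looks like this (an illustrative sketch with made-up names and constants, not the exact code I'm running):

```python
import torch

# Gaussian size + position sampling: sizes come from a clamped Gaussian, and the
# cutout center comes from a Gaussian around the image center whose spread shrinks
# as the cutout grows, so large cutouts stay centered while small ones wander.
width = 512
size_mean, size_std = 0.8, 0.2  # e.g. the "Gaussian size mean of 0.8" run

size_frac = torch.clamp(torch.randn(()) * size_std + size_mean, 0.1, 1.0)
size = int(size_frac * width)

center = width / 2
spread = (1.0 - size_frac.item()) * width / 2  # small cutouts get a wider position spread
cx = int(torch.clamp(torch.randn(()) * spread + center, size / 2, width - size / 2))
cy = int(torch.clamp(torch.randn(()) * spread + center, size / 2, width - size / 2))

offset_x, offset_y = cx - size // 2, cy - size // 2
# cutout = image[:, :, offset_y:offset_y + size, offset_x:offset_x + size]
```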

Similar to your experiment, I got the following. Gaussian size mean of 1.0 - if you stop it around second 7 it looks really good imo:

https://user-images.githubusercontent.com/19983153/110557619-3f500c00-8141-11eb-883c-1b95b3e14382.mp4

Gaussian size mean of 0.8:

https://user-images.githubusercontent.com/19983153/110557652-54c53600-8141-11eb-8fa4-2bcc482fff76.mp4

Gaussian size mean of 0.6:

https://user-images.githubusercontent.com/19983153/110557689-61e22500-8141-11eb-9b29-a4adf70f049d.mp4

I also played around with replacing the AdamP optimizer with the DiffGrad optimizer. Here it is combined with uniform size sampling and Gaussian position sampling (I think this is my favorite):

https://user-images.githubusercontent.com/19983153/110557788-90f89680-8141-11eb-9ba1-8a1830af6af7.mp4

Lastly, I got a very fascinating one. I combined uniform size sampling + Gaussian position sampling with the feature-averaging approach I talked about earlier: the mean of all sampled cutout features is optimized, instead of each feature individually. The result is strikingly different from the other generations, and for this image/GIF I really have the impression of being in a cafe (although, unfortunately, the details are a bit washed out):

https://user-images.githubusercontent.com/19983153/110557969-e634a800-8141-11eb-875a-104f39a6e184.mp4
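To spell out the difference between the two aggregation strategies, here is a toy comparison with dummy embeddings (not the actual CLIP calls):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
cutout_embeds = torch.randn(64, 512)  # one CLIP-style embedding per cutout
text_embed = torch.randn(1, 512)

# (a) per-cutout similarity, then average: every cutout is pushed toward the text
loss_per_cutout = -F.cosine_similarity(cutout_embeds, text_embed, dim=-1).mean()

# (b) feature averaging: only the mean cutout embedding is pushed toward the text,
#     so individual cutouts may differ as long as their average matches
loss_averaged = -F.cosine_similarity(cutout_embeds.mean(dim=0, keepdim=True), text_embed, dim=-1)

print(loss_per_cutout.item(), loss_averaged.item())
```

That freedom in (b) might explain both the coherent overall "cafe feeling" and the washed-out details.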

I still don't have a principled way of analysing these generations :/ Not sure if there is one. But I've been playing around with a new prompt - "A wizard in blue robes is painting a completely red image in a castle.". Here's a Gaussian position sampling generation for it:

https://user-images.githubusercontent.com/19983153/110558153-475c7b80-8142-11eb-9e42-426e161e5a53.mp4