dribnet / clipit_old

VQGAN+CLIP with some additional tuning. For notebooks and the command line.
MIT License

PixelDrawer support for limited colors #2

Open dribnet opened 3 years ago

dribnet commented 3 years ago

The PixelDrawer currently has no support for limiting the colors - it only supports full three-channel RGB. It would be preferable to be able to limit the colorspace to either subsets of RGB or to a fixed color palette.

As an initial step, I would like to create a version of the PixelDrawer class that supports a grayscale image where R=G=B. In principle this should be possible by optimizing only one tensor per pixel and then having that tensor be connected to the R, G, and B channels.

The hope is that the PixelDrawer would be able to evolve from there to next support a discrete set of fixed colors, and ultimately also to optionally optimize the color lookup table as it also optimizes the color assignments of the pixels.

Note also that this issue is related to upstream issue BachiLi/diffvg#23

dribnet commented 3 years ago

I have added a do_mono setting to the options which gets passed to the PixelDrawer. Currently this does two things internally:

1) RGB values are initialized to random values such that R=G=B
2) The post-optimization method "clip_z" resets each color to the average of the components:

                if self.do_mono:
                    # force grayscale: replace R, G, B with their mean
                    avg_amount = torch.mean(group.fill_color.data[:3])
                    group.fill_color.data[:3] = avg_amount

Together these options do implement a sort of grayscale constraint. So instead of 'sunset river snow mountain' producing a result such as this:

pixel_landscape_01

do_mono constrains the results to grayscale to produce a result like this:

pixel_landscape_02

However, internally the implementation is sloppy: the optimizer continues to modify all three channels independently, and they are simply constrained back to grayscale after the fact. So is there a smarter implementation of do_mono? Ideally we can come up with a better version in which only a single one-channel tensor per pixel gets optimized and is then propagated to the red, green, and blue channels, as sketched below.
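For illustration, a minimal sketch of that idea outside of diffvg (which is the hard part); the names and sizes here are hypothetical:

    import torch

    # one learnable value per pixel (sizes are illustrative)
    gray = torch.rand(1, 1, 56, 56, requires_grad=True)
    optimizer = torch.optim.Adam([gray], lr=0.1)

    # R=G=B by construction; gradients from all three channels
    # flow back into the single grayscale tensor
    rgb = gray.expand(-1, 3, -1, -1)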

torridgristle commented 3 years ago

I do limited colors by way of softmax and matrix multiplication with a palette.

Say I've got a palette of 8 colors, RGB, a tensor of [8,3]. My canvas is [batch, palette count, height, width], so like [1,8,56,56], and then I can softmax2d the canvas, permute channels to the last dimension, do a matrix multiplication with the palette, and then permute the resulting RGB channels back to dim 1, and it's a normal batch, channel, height, width RGB image. I'm not familiar with diffvg but perhaps this can be used to define the colors of shapes.
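A minimal sketch of that pipeline in PyTorch (the palette and canvas values are random placeholders):

    import torch
    import torch.nn.functional as F

    palette = torch.rand(8, 3)                               # [palette count, RGB]
    canvas = torch.randn(1, 8, 56, 56, requires_grad=True)   # [batch, palette count, H, W]

    weights = F.softmax(canvas, dim=1)             # softmax2d: per-pixel palette weights
    rgb = weights.permute(0, 2, 3, 1) @ palette    # [1, 56, 56, 3]
    rgb = rgb.permute(0, 3, 1, 2)                  # back to [batch, channel, H, W]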

Getting it to strongly prefer just one color is an issue I haven't fully solved. It will approach that if you multiply the canvas by some scalar value before softmaxing it, but I suspect this causes issues with loss of accuracy in the backprop, so I mix in a bit of another copy of the image that didn't have a multiplier on the values before softmax.
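Continuing the sketch above, that mixing trick might look like this (the scale and mix ratio are made-up values):

    sharp = F.softmax(canvas * 10.0, dim=1)   # scaled logits push weights toward one-hot
    soft = F.softmax(canvas, dim=1)           # unscaled copy keeps gradients informative
    weights = 0.9 * sharp + 0.1 * soft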

I also tried a loss that checks that the post-softmax output has a maximum value of 1: just softmax, max on the channel dimension, subtract that from 1, square it, average it, and that's the loss. But I couldn't seem to weight it in a good way; it would end up prioritizing using a single color over actually making an image that matches the prompt.
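Written out, that loss would be roughly:

    probs = F.softmax(canvas, dim=1)
    peak = probs.max(dim=1).values               # each pixel's strongest palette weight
    one_color_loss = ((1.0 - peak) ** 2).mean()  # zero when every pixel is one-hot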

Here's a notebook I whipped up last night that demonstrates the softmax stuff, but it's pixels and not vector. PublicCLIP+_Chunky_RGB_Optimization_v0_1.zip

dribnet commented 3 years ago

Thanks @torridgristle - I've also copied that notebook into this gist. This does look like the sort of thing I was hoping to do - perhaps also with some noise added to the palette assignments, similar to gumbel-softmax. But the immediate issue seems to be figuring out how to wire this kind of change into diffvg, so I think initially it might be useful to first fix the do_mono implementation just so I can wrap my head around how diffvg deals with these sorts of constraints. If that goes well I'd like to revisit what you've done here to see if something more ambitious like this could also fit within the diffvg framework.
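For reference, PyTorch ships a built-in Gumbel-Softmax that could supply that kind of noise; a sketch against the canvas/palette names from the earlier example:

    # F.gumbel_softmax adds Gumbel noise to the logits before the softmax
    weights = F.gumbel_softmax(canvas, tau=1.0, hard=False, dim=1)
    rgb = (weights.permute(0, 2, 3, 1) @ palette).permute(0, 3, 1, 2)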

dribnet commented 3 years ago

I've pushed the do_mono feature a bit further, and now it outputs only 1 bit per pixel: black and white.

volcano

This is done at the last moment when an image is generated from the tensor:

        if self.do_mono:
            img = img[1] # take the green channel (they should all be the same)
            s = img.shape
            # per-pixel stochastic threshold: the mean of 5 uniforms follows a
            # Bates distribution, approximately gaussian around 0.5 on [0,1]
            random_bates = np.average(np.random.uniform(size=(5, s[0], s[1])), axis=0)
            img = np.where(img > random_bates, 1, 0)
            img = np.uint8(img * 255)
            pimg = PIL.Image.fromarray(img, mode="L")

Basically this is thresholding the pixels, with the caveat that the threshold is stochastic: it's roughly a gaussian with a mean of 0.5, but different at each pixel location. I felt this was mostly in the spirit of gumbel-softmax and might encourage the values to settle at either high or low values over time. Here's an example random_bates threshold map rendered as an image:

bates_debug

Anyway, this is slow, but it's still an interesting proof of concept further toward true indexed color, though I haven't thought through the implications of the optimised tensor now being perceived only indirectly by CLIP. But in testing it is already making some interesting and, to me, visually unique decisions on shading and attempts at detail.

pixel_landscape_05

Next steps might be a lookup color replacement for black that also gets optimized, and/or trying a 2nd threshold layer.

torridgristle commented 3 years ago

Perhaps optimizing a small model that handles the color-limited palette could be possible. I've got a PyTorch nn module that could be useful:

import torch
import torch.nn as nn

class GaussianBins(nn.Module):
    """Input multiplied by the gaussian function with learnable sigma and bias per channel
    """
    def __init__(self, in_channels, bins, multiply=True):
        super(GaussianBins, self).__init__()
        # learnable per-channel, per-bin bias and sigma, squashed to [-1,1] at init
        self.bin_bias  = nn.Parameter(torch.randn(1,in_channels,bins,1,1).tanh())
        self.bin_sigma = nn.Parameter(torch.randn(1,in_channels,bins,1,1).tanh())
        self.multiply  = multiply
        self.bins      = bins

    def gaussian_pdf(self, x, sigma):
        # unnormalized gaussian centered at zero
        return torch.exp(-0.5 * (x / sigma).pow(2))

    def forward(self, x):
        # x: [batch, channels, height, width] -> insert a bin dimension
        x = x.unsqueeze(2)
        # each bin responds most strongly where the input is near its center
        gaussian_mask = self.gaussian_pdf(x * (self.bins-1) + self.bin_bias, self.bin_sigma)
        result = x * gaussian_mask if self.multiply else gaussian_mask
        # fold the bin dimension into the channel dimension
        b,c1,c2,h,w = result.shape
        result = result.reshape(b,c1*c2,h,w)
        return result

With this (or perhaps a few of these in sequence to better mix the 3 input channels together, or this plus a 2d conv with a 1x1 kernel to handle mixing the values), it can take a continuous scalar input and split it by amplitude into different channels when multiply is set to False, so that each output channel increases as the input value nears a certain level. This can then be paired with softmax and a set palette, perhaps with the model optimizing the palette itself.
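A hypothetical usage sketch of that pairing (all sizes and the palette are placeholders):

    # split a grayscale canvas into 8 amplitude bins, then map onto a palette
    binner = GaussianBins(in_channels=1, bins=8, multiply=False)
    canvas = torch.rand(1, 1, 56, 56)           # continuous scalar input
    logits = binner(canvas)                     # [1, 8, 56, 56] bin responses
    weights = torch.softmax(logits, dim=1)      # per-pixel palette weights
    palette = torch.rand(8, 3)                  # could itself be an nn.Parameter
    rgb = (weights.permute(0, 2, 3, 1) @ palette).permute(0, 3, 1, 2)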

dribnet commented 3 years ago

I realised that some of my thinking in my previous commit was incorrect, and have redone how this is handled a bit upstream from before, in a separate branch. I think the main thing to determine is whether there is any way to make changes / constraints to the diffvg colors upstream of the canvas (I don't see how myself), or to continue making changes downstream from the canvas output (which is what I'm currently doing). Anyway, I have a differentiable version of the mono mode that produces similar looking images with slightly different internals.

volcano_20

So I think I have a better place to perform the surgery on the current system, but still need to think about what logic to put there.