FluxML / model-zoo

Please do not feed the models
https://fluxml.ai/

Addition Of CycleGAN model #135

Open shreyas-kowshik opened 5 years ago

shreyas-kowshik commented 5 years ago

As suggested by @dhairyagandhi96, I have sketched out an implementation of the CycleGAN model.

The code will be organised in a separate repository with the utility functions, I/O, model definitions and training files.

What more details are required? @staticfloat

DhairyaLGandhi commented 5 years ago

I think having a section on GANs here or in Metalhead.jl is fine.

First order of business is to get a general model working. Are there any specifics that you see that need to be addressed in Flux to get this to run on a broad basis?

shreyas-kowshik commented 5 years ago

Generator :

struct DownConv
 max_pool 
 double_conv # Chain(Conv,BN,Relu) * 2
end

DownConv(in_ch,out_ch) = ... # (Define a DownConv instance here)
function (dconv::DownConv)(x)
 ... # Implement a max pool and a double_conv here
end

# This is for the down sampling part.
struct UpConv
 double_conv
 conv_transpose
end

UpConv(in_ch,out_ch) = ... # (Define UpConv Instance)

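function (uc::UpConv)(x,y)
 z = uc.conv_transpose(uc.double_conv(x)) # apply this block's own layers; uc(x) has no one-argument method
 cat(z,y; dims=3) # concatenate with the saved down-sampling output along the channel dimension
end
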
struct UNet
 down_convs
 up_convs
end

function UNet()
 # Define down_convs as an array of DownConv
 # Define up_convs as an array of UpConv
 UNet(down_convs,up_convs)
end

function (un::UNet)(x)
 # Run down_convs and save outputs
 # Run up_convs
end
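
To make the generator outline above concrete, here is a minimal two-level sketch in Flux. It is only an illustration under stated assumptions, not the code that will go in the repository: the name `TinyUNet`, the `stem` and `head` fields, the channel counts, the kernel sizes and the final `tanh` are all made up for the example. The forward pass is written out explicitly for two levels rather than looping over arrays.

```julia
using Flux

# Two rounds of Conv -> BatchNorm -> relu; 3x3 kernels with pad=1 keep the spatial size.
double_conv(in_ch, out_ch) = Chain(
    Conv((3, 3), in_ch => out_ch; pad=1), BatchNorm(out_ch, relu),
    Conv((3, 3), out_ch => out_ch; pad=1), BatchNorm(out_ch, relu))

struct DownConv
    max_pool
    double_conv
end
DownConv(in_ch, out_ch) = DownConv(MaxPool((2, 2)), double_conv(in_ch, out_ch))
(dc::DownConv)(x) = dc.double_conv(dc.max_pool(x))  # halve H and W, then convolve

struct UpConv
    double_conv
    conv_transpose
end
# Assumption: the transposed convolution keeps the channel count and doubles H and W.
UpConv(in_ch, out_ch) = UpConv(double_conv(in_ch, out_ch),
                               ConvTranspose((2, 2), out_ch => out_ch; stride=2))
function (uc::UpConv)(x, y)
    z = uc.conv_transpose(uc.double_conv(x))
    cat(z, y; dims=3)  # skip connection: join along the channel dimension (WHCN layout)
end

# Hypothetical two-level network; `stem` and `head` are not part of the outline above.
struct TinyUNet
    stem
    down_convs
    up_convs
    head
end
TinyUNet() = TinyUNet(double_conv(3, 8),
                      (DownConv(8, 16), DownConv(16, 32)),
                      (UpConv(32, 16), UpConv(32, 8)),
                      Conv((1, 1), 16 => 3))

function (un::TinyUNet)(x)
    s1 = un.stem(x)             # full resolution, 8 channels
    s2 = un.down_convs[1](s1)   # 1/2 resolution, 16 channels
    b  = un.down_convs[2](s2)   # 1/4 resolution, 32 channels
    h  = un.up_convs[1](b, s2)  # back to 1/2 resolution, 16 + 16 channels
    h  = un.up_convs[2](h, s1)  # back to full resolution, 8 + 8 channels
    tanh.(un.head(h))           # map to 3 image channels in [-1, 1]
end

model = TinyUNet()
x = rand(Float32, 32, 32, 3, 1)  # one 32x32 RGB image
out = model(x)
```

The output has the same shape as the input, which is what an image-to-image generator needs. For training, the custom structs would also have to be registered with Flux so their parameters are collected (`Flux.@functor` or `Flux.@layer`, depending on the Flux version).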

Discriminator :

struct ConvBlock
 layers
end

CB(in_ch,out_ch) = ConvBlock(Chain(...))

dis = Chain(CB...) # Max number of channels will be taken as a hyperparameter

There would be a generator and a discriminator defined for each domain.
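
A possible shape for the discriminator outline above, again as a sketch under assumptions: the 4x4 stride-2 convolutions, the `leakyrelu` activation, the one-channel output map and the helper name `build_discriminator` are choices made for illustration. The `max_ch` keyword plays the role of the maximum-channels hyperparameter.

```julia
using Flux

struct ConvBlock
    layers
end
(cb::ConvBlock)(x) = cb.layers(x)

# Assumption: each block halves H and W with a 4x4 stride-2 convolution.
CB(in_ch, out_ch) = ConvBlock(Chain(
    Conv((4, 4), in_ch => out_ch; stride=2, pad=1),
    BatchNorm(out_ch, leakyrelu)))

# Hypothetical helper: double the channels per block until max_ch is reached,
# then reduce to a single-channel map of real/fake scores.
function build_discriminator(in_ch; base_ch=8, max_ch=32)
    blocks = Any[CB(in_ch, base_ch)]
    ch = base_ch
    while ch < max_ch
        next_ch = min(2ch, max_ch)
        push!(blocks, CB(ch, next_ch))
        ch = next_ch
    end
    push!(blocks, Conv((4, 4), ch => 1; pad=1))
    Chain(blocks...)
end

dis = build_discriminator(3)
scores = dis(rand(Float32, 32, 32, 3, 1))
```

With the defaults this gives three blocks (3 => 8 => 16 => 32 channels) followed by the scoring convolution, so a 32x32 input is reduced to a small grid of scores rather than a single number.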

Loss functions :

function loss(x,y) # x and y are images, one of each domain
 # identity losses
 # reconstruction losses
 # adversarial losses
end
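
One way the three terms of the generator-side loss could be combined is sketched below in plain Julia, with the networks passed in as ordinary callables. The L1 distance, the least-squares adversarial term, the weights `cyc_weight` and `id_weight`, and all function names are assumptions for illustration; the discriminator loss is not shown.

```julia
using Statistics: mean

l1(a, b) = mean(abs.(a .- b))                    # mean absolute error
lsq(pred, target) = mean((pred .- target) .^ 2)  # least-squares adversarial term

# G_xy maps domain X to Y, G_yx maps Y to X; D_x and D_y score images of each domain.
function generator_loss(G_xy, G_yx, D_x, D_y, x, y; cyc_weight=10f0, id_weight=5f0)
    fake_y = G_xy(x)
    fake_x = G_yx(y)
    # identity losses: a generator fed an image of its target domain should not change it
    id_loss = l1(G_xy(y), y) + l1(G_yx(x), x)
    # reconstruction losses: translating there and back should recover the input
    rec_loss = l1(G_yx(fake_y), x) + l1(G_xy(fake_x), y)
    # adversarial losses: the generators want the discriminators to output 1 on fakes
    adv_loss = lsq(D_y(fake_y), 1f0) + lsq(D_x(fake_x), 1f0)
    adv_loss + cyc_weight * rec_loss + id_weight * id_loss
end

# Sanity check with stand-in networks: identity generators, constant discriminators.
x = rand(Float32, 4, 4, 3, 1)
y = rand(Float32, 4, 4, 3, 1)
fooled = a -> ones(Float32, 1, 1, 1, 1)
unfooled = a -> zeros(Float32, 1, 1, 1, 1)
loss_fooled = generator_loss(identity, identity, fooled, fooled, x, y)
loss_unfooled = generator_loss(identity, identity, unfooled, unfooled, x, y)
```

With identity generators the identity and reconstruction terms vanish, so only the adversarial term depends on which stand-in discriminator is used.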

Utility functions :

function get_images()

I plan to organise the code as :

loss.jl
models.jl
util.jl
train.jl
inference.jl

staticfloat commented 5 years ago

Hey Shreyas, this all looks reasonable to me. Feel free to start hacking this out, and let us know when you run into something difficult. Let's do our best to figure out what the difficult pieces are going to be as early as possible, and devote appropriate resources to making them as painless as possible! :)

shreyas-kowshik commented 5 years ago

@staticfloat I've put together the pieces here : https://github.com/shreyas-kowshik/CycleGAN.jl

I am unable to test the implementation as the model does not fit on my GPU. I've talked to @dhairyagandhi96 regarding this.

DhairyaLGandhi commented 5 years ago

Yeah, we will get them access to some GPU machines to test their code out on.