Open shreyas-kowshik opened 5 years ago
I think having a section on GANs here or in Metalhead.jl is fine.
First order of business is to get a general model working. Are there any specifics that you see that need to be addressed in Flux to get this to run on a broad basis?
Generator :
struct DownConv
    max_pool
    double_conv # Chain(Conv,BN,Relu) * 2
end
DownConv(in_ch,out_ch) = ... # (Define a DownConv instance here)
function (dconv::DownConv)(x)
    ... # Implement a max pool and a double_conv here
end
# This is for the down sampling part.
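struct UpConv
    double_conv
    conv_transpose
end
UpConv(in_ch,out_ch) = ... # (Define UpConv Instance)
function (uc::UpConv)(x,y)
    # Apply this block's own layers (order assumed from the field order);
    # calling uc(x) here would dispatch back to UpConv itself and fail.
    z = uc.conv_transpose(uc.double_conv(x))
    # Concatenate with the skip connection y along the channel dimension (WHCN layout)
    cat(z,y; dims=3)
end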
struct UNet
    down_convs
    up_convs
end
function UNet()
    # Define down_convs as an array of down_conv
    # Define up_convs as an array up_conv
    UNet(down_convs,up_convs)
end
function (un::UNet)(x)
    # Run down_convs and save outputs
    # Run up_convs
end
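To make the "run down_convs and save outputs, then run up_convs" step concrete, here is a rough sketch of the skip-connection bookkeeping in plain Julia. Everything in it is an assumption for illustration: `SketchUNet`, `downsample`, `upsample` and `up_block` are hypothetical stand-ins (strided slicing and `repeat` instead of real Flux layers), and it saves the input of each down block as the skip, which is only one of several reasonable choices.

```julia
# Stand-ins for the real layers (hypothetical, not Flux API).
# A "down" block halves the spatial size; an "up" block doubles it and
# concatenates the skip connection along the channel dimension (dim 3, WHCN).
downsample(x) = x[1:2:end, 1:2:end, :, :]
upsample(x) = repeat(x; inner = (2, 2, 1, 1))
up_block(x, skip) = cat(upsample(x), skip; dims = 3)

struct SketchUNet
    down_blocks::Vector{Function}
    up_blocks::Vector{Function}
end

function (un::SketchUNet)(x)
    skips = Any[]
    for down in un.down_blocks
        push!(skips, x)          # save the activation before each down block
        x = down(x)
    end
    for up in un.up_blocks
        x = up(x, pop!(skips))   # most recently saved skip is used first
    end
    return x
end

net = SketchUNet(Function[downsample, downsample], Function[up_block, up_block])
x = rand(Float32, 8, 8, 3, 1)
y = net(x)
size(y)   # spatial size is restored; channels grow with each concatenation
```

In the real model the up blocks would also need their input channel counts sized to account for the concatenated skip channels.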
Discriminator :
struct ConvBlock
    layers
end
CB(in_ch,out_ch) = ConvBlock(Chain(...))
dis = Chain(CB(...), ...) # Max number of channels will be taken as a hyperparameter
There would be a generator and a discriminator defined for each domain.
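One possible way to treat the maximum channel count as a hyperparameter is to compute the per-block channel widths up front. The sketch below is only an illustration: `channel_plan` is a hypothetical helper, and the doubling schedule and the example values (`base = 64`, `max_ch = 256`) are assumptions rather than anything fixed in this thread.

```julia
# Hypothetical helper: (in, out) channel pairs for `n_blocks` discriminator blocks.
# Widths double from `base` at every block and are capped at `max_ch`.
function channel_plan(in_ch::Int, base::Int, n_blocks::Int, max_ch::Int)
    widths = [min(base * 2^(i - 1), max_ch) for i in 1:n_blocks]
    ins = vcat(in_ch, widths[1:end-1])   # each block consumes the previous block's output
    return collect(zip(ins, widths))
end

plan = channel_plan(3, 64, 5, 256)
```

Each pair in `plan` could then be passed to a block constructor such as `CB(in_ch, out_ch)` when assembling the `Chain`.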
Loss functions :
function loss(x,y) # x and y are images, one of each domain
    # identity losses
    # reconstruction losses
    # adversarial losses
end
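The three loss terms above could be put together roughly as follows. This is a sketch under stated assumptions, not the final `loss.jl`: it assumes a least-squares adversarial term and L1 reconstruction and identity terms, the names `generator_loss`, `discriminator_loss`, `G_xy`, `G_yx`, `D_x`, `D_y` are hypothetical, and the weights `lambda_cyc` and `lambda_id` are placeholder defaults. The models are passed in as plain callables so the snippet runs without Flux.

```julia
using Statistics: mean

l1(a, b) = mean(abs.(a .- b))
mse(a, b) = mean(abs2.(a .- b))

# G_xy maps domain X -> Y, G_yx maps Y -> X; D_x and D_y score images of each domain.
# lambda_cyc and lambda_id are assumed weights, not values taken from this thread.
function generator_loss(G_xy, G_yx, D_x, D_y, x, y; lambda_cyc = 10f0, lambda_id = 5f0)
    fake_y = G_xy(x)
    fake_x = G_yx(y)
    # adversarial (least-squares form): the generators want fakes scored as 1
    adv = mse(D_y(fake_y), 1f0) + mse(D_x(fake_x), 1f0)
    # reconstruction (cycle consistency): X -> Y -> X and Y -> X -> Y
    cyc = l1(G_yx(fake_y), x) + l1(G_xy(fake_x), y)
    # identity: a generator fed an image of its target domain should leave it unchanged
    idt = l1(G_xy(y), y) + l1(G_yx(x), x)
    return adv + lambda_cyc * cyc + lambda_id * idt
end

# Each discriminator wants real images scored as 1 and generated images as 0.
discriminator_loss(D, real, fake) = (mse(D(real), 1f0) + mse(D(fake), 0f0)) / 2

# Toy check with identity "generators" and a discriminator that always outputs 1.
x = rand(Float32, 4, 4, 3, 1)
y = rand(Float32, 4, 4, 3, 1)
ident = z -> z
D_one = z -> ones(Float32, 1, 1, 1, 1)
g = generator_loss(ident, ident, D_one, D_one, x, y)
d = discriminator_loss(D_one, x, y)
```

Splitting the generator and discriminator losses like this keeps the two update steps separate in the training loop, since the generators and discriminators are optimised against each other.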
Utility functions :
function get_images()
I plan to organise the code as :
loss.jl
models.jl
util.jl
train.jl
inference.jl
Hey Shreyas, this all looks reasonable to me; feel free to start hacking this out, and let us know when you run into something that is difficult; let's do our best to try and figure out what are going to be the difficult pieces as early as possible, and devote appropriate resources into making them as painless as possible! :)
@staticfloat I've put together the pieces here : https://github.com/shreyas-kowshik/CycleGAN.jl
I am unable to test the implementation as the model does not fit on my GPU. I've talked to @dhairyagandhi96 regarding this.
Yeah, we will get them access to some GPU machines to test their code out on.
As suggested by @dhairyagandhi96, I have chalked out an implementation plan for the CycleGAN model:
Write the UNet architecture for the generator [For 256x256 and 128x128 images]
Write the discriminator. The paper's implementation reference would be followed.
Training would be done on the apple2orange dataset first, followed by the horse2zebra dataset.
Formulation of the identity loss for the generators and discriminators and the adversarial losses.
The code will be organised in a separate repository with the utility functions, I/O, model definitions and training files.
What more details are required? @staticfloat