xinntao / ESRGAN

ECCV18 Workshops - Enhanced SRGAN. Champion PIRM Challenge on Perceptual Super-Resolution. The training codes are in BasicSR.
https://github.com/xinntao/BasicSR
Apache License 2.0

Regarding Network Parameters #60

Closed vibss2397 closed 5 years ago

vibss2397 commented 5 years ago

Hi, I was building an ESRGAN on my own and tried to keep it as faithful to the paper as possible. However, the number of parameters is quite large (55M for 16 RRDB blocks). Is what I'm doing right, or is there an error somewhere? My RRDBNet architecture is attached below.

def rdb_block(model, kernal_size, filters, strides):

    gen = model
    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
    model = LeakyReLU(alpha = 0.2)(model)
    model = Concatenate()([gen, model])
    gen2 = model

    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
    model = LeakyReLU(alpha = 0.2)(model)
    model = Concatenate()([gen, gen2, model])
    gen3 = model

    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
    model = LeakyReLU(alpha = 0.2)(model)
    model = Concatenate()([gen, gen2,gen3, model])
    gen4 = model

    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
    model = LeakyReLU(alpha = 0.2)(model)
    model = Concatenate()([gen, gen2,gen3, gen4, model])

    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
    model = Lambda(lambda x: x*0.2)(model)
    return Add()([model, gen])

def rrdb_block(model, kernel_size, filters, strides):
    rdb1 = rdb_block(model, kernel_size, filters, strides)
    rdb2 = rdb_block(rdb1, kernel_size, filters, strides)
    rdb2 = rdb_block(rdb2, kernel_size, filters, strides)
    rdb2 = Lambda(lambda x: x*0.2)(rdb2)

    return Add()([rdb2, model])

def generator(self):
        gen_input = Input(shape = self.input_shape)
        model = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
        gen_model = model
        # Using 16 Residual Blocks
        for index in range(16):
            model = rrdb_block(model, 3, 64, 1)

        model = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(model)
        model = add([gen_model, model])

        # Using 2 UpSampling Blocks
        for index in range(2):
            model = up_sampling_block(model, 3, 256, 1)

        model = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(model)

        generator_model = Model(inputs = gen_input, outputs = model)

        return generator_model
xinntao commented 5 years ago

The total number of parameters should be around 16M. The channel counts of the Conv layers in your RDB are not the same as those in the ESRGAN implementation: https://github.com/xinntao/ESRGAN/blob/b7c263b8c8f201a9123921d83dfe9bd73b38ec0c/RRDBNet_arch.py#L14-L23
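As a rough check of the "around 16M" figure, the trunk's parameter count can be worked out by hand. The sketch below assumes 3x3 kernels with biases, nf=64 trunk channels and gc=32 growth channels, and that each dense conv sees the block input plus all earlier gc-channel outputs. The 23-block figure is an assumption about the released model (the thread itself only mentions 16 blocks), and the first conv, upsampling convs and output conv are left out.

```python
# Parameter count of the RRDB trunk under the assumed reference channel layout.
# Assumptions (not stated in the thread): 3x3 kernels, biases enabled,
# nf=64 trunk channels, gc=32 growth channels.

def conv_params(c_in, c_out, k=3):
    """Weights plus biases of one k x k convolution."""
    return k * k * c_in * c_out + c_out

def rdb_params(nf=64, gc=32):
    """One residual dense block: four convs emit gc channels, the fifth emits nf."""
    total = 0
    for i in range(4):
        # Conv i sees the block input plus the i earlier gc-channel outputs.
        total += conv_params(nf + i * gc, gc)
    # The final conv fuses everything back to nf channels.
    total += conv_params(nf + 4 * gc, nf)
    return total

def trunk_params(num_rrdb, nf=64, gc=32):
    """Each RRDB stacks three residual dense blocks."""
    return num_rrdb * 3 * rdb_params(nf, gc)

if __name__ == "__main__":
    print(rdb_params())      # 239808
    print(trunk_params(16))  # 11510784
    print(trunk_params(23))  # 16546752
```

Under these assumptions, 16 blocks come to about 11.5M trunk parameters and 23 blocks to about 16.5M, which is in line with the figure quoted above.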

vibss2397 commented 5 years ago

Hey @xinntao, correct me if I'm wrong, but in PyTorch's Conv2d the first argument is the number of input channels and the second is the number of output channels, while in Keras the filters argument sets only the output channels, which according to your code must be the same.

xinntao commented 5 years ago

Yes, your understanding of PyTorch is right.

For the rdb_block, the number of output channels in the PyTorch code is gc=32, while in your code it is filters=64.
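To see where the 55M figure may come from, one can follow the Concatenate calls in the posted rdb_block literally. This is a sketch based only on the code as posted, assuming 3x3 kernels with biases; the helper names are hypothetical. Note that gen2, gen3 and gen4 are themselves concatenations that already contain the earlier tensors, so as written the input width of each conv doubles (64, 128, 256, 512, 1024) rather than growing by a fixed step.

```python
# Trace channel widths through the posted rdb_block exactly as written.
# Helper names are hypothetical; 3x3 kernels with biases are assumed.

def conv_params(c_in, c_out, k=3):
    """Weights plus biases of one k x k convolution."""
    return k * k * c_in * c_out + c_out

def posted_rdb_params(nf=64, filters=64):
    """Parameter count of one rdb_block, following its Concatenate calls."""
    saved_widths = [nf]  # channel widths of gen, then gen2, gen3, gen4
    c_in = nf
    total = 0
    for _ in range(4):
        total += conv_params(c_in, filters)
        # Concatenate()([gen, gen2, ..., model]): every saved tensor
        # plus the new conv output, even though later saved tensors
        # already contain the earlier ones.
        c_in = sum(saved_widths) + filters
        saved_widths.append(c_in)
    # Fifth conv maps the final concatenation back to nf channels.
    total += conv_params(c_in, nf)
    return total

def posted_trunk_params(num_rrdb=16):
    """Each rrdb_block chains three rdb_block calls."""
    return num_rrdb * 3 * posted_rdb_params()

if __name__ == "__main__":
    print(posted_rdb_params())    # 1143104
    print(posted_trunk_params())  # 54868992
```

The roughly 54.9M result matches the 55M reported in the question, which suggests the count is driven both by filters=64 and by the repeated tensors in the concatenations.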

vibss2397 commented 5 years ago

Wow, thanks! The parameter count is way down now and I can train it in the original configuration. Thanks :)