Adversarial-Deep-Learning / code-soup

This is a collection of algorithms and approaches used in the book Adversarial Deep Learning.
MIT License

Add Adversarial Transformation Networks #85

Closed: gchhablani closed this 3 years ago

gchhablani commented 3 years ago

This PR is a WIP on ATN #19. I will be skipping simultaneous training on multiple networks and the "insider" information variant.

I couldn't find an official implementation, so I am basing mine on the paper. The paper does not mention the exact hyperparameters, so I am changing a few (kernel size, stride, etc.) to make things work. I hope that is okay.
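
For reference, this is roughly the objective I'm implementing from the paper (a beta-weighted reconstruction loss plus a loss against the reranked target distribution). The function names and the L2 choices here are my own sketch, not code from this PR:

import torch.nn.functional as F

def rerank(y, target_idx, alpha=1.5):
    # Reranking function r_alpha from the paper: copy the classifier's
    # softmax output, boost the target class to alpha * max(y), then
    # renormalize so the result is still a valid distribution.
    r = y.clone()
    r[:, target_idx] = alpha * y.max(dim=1).values
    return r / r.sum(dim=1, keepdim=True)

def atn_loss(x, x_adv, y_clean, y_adv, target_idx, alpha=1.5, beta=0.010):
    # beta * L_X keeps the adversarial image close to the input;
    # L_Y pulls the classifier's output on x_adv toward the reranked
    # distribution r_alpha(f(x), target).
    l_x = F.mse_loss(x_adv, x)
    l_y = F.mse_loss(y_adv, rerank(y_clean, target_idx, alpha))
    return beta * l_x + l_y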

codecov[bot] commented 3 years ago

Codecov Report

Merging #85 (9b2308f) into main (eec666b) will not change coverage. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##              main       #85    +/-   ##
==========================================
  Coverage   100.00%   100.00%            
==========================================
  Files           16        17     +1     
  Lines          603       704   +101     
==========================================
+ Hits           603       704   +101     
Impacted Files                     Coverage Δ
code_soup/ch5/algorithms/atn.py    100.00% <100.00%> (ø)

someshsingh22 commented 3 years ago

@gchhablani You can take a look at https://github.com/RanTaimu/Adversarial-Transformation-Network

gchhablani commented 3 years ago

Adding some code for ImageNet models for another PR:


import torch
import torch.nn as nn
import torch.nn.functional as F


class BilinearUpsample(nn.Module):
    """Fixed-factor bilinear resize that can live inside an nn.ModuleList."""

    def __init__(self, scale_factor):
        super(BilinearUpsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(
            x, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
        )

class BaseDeconvAAE(AAEBase):
    """Adversarial autoencoding ATN that decodes a pretrained backbone's features with transposed convolutions."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):
        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)

class ResizeConvAAE(AAEBase):
    """Adversarial autoencoding ATN built from conv + bilinear-resize blocks."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 128, 5, padding=2),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(512, 512, 1, padding=0),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.Conv2d(128, 3, 3, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)

class ConvDeconvAAE(AAEBase):
    """Adversarial autoencoding ATN with strided convs down and transposed convs up."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 256, 3, stride=2, padding=1),
            nn.ReLU(),
            # In-channels fixed to 256 to match the previous layer's output.
            nn.Conv2d(256, 768, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(768, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=0),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)

class BaseDeconvPATN(PATNBase):
    """Perturbation ATN counterpart of BaseDeconvAAE: decodes backbone features with transposed convolutions."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):
        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # TODO: Check if this is the right activation
        ]

        self.atn = nn.ModuleList(layers)

class ConvFCPATN(PATNBase):
    """Perturbation ATN with strided convs followed by fully connected layers."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 512, 3, stride=2, padding=1),
            nn.Conv2d(512, 256, 3, stride=2, padding=1),
            nn.Conv2d(256, 128, 3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(184832, 512),
            nn.Linear(512, 268203),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)
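
For context, none of these classes define forward here; I'm assuming the base classes (AAEBase / PATNBase) simply chain the ModuleList, roughly:

    def forward(self, x):
        # nn.ModuleList has no forward of its own, so the base class
        # presumably applies each stored layer in order.
        for layer in self.atn:
            x = layer(x)
        return x
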
mehulrastogi commented 3 years ago

I don't think we need testing for reproducibility, if by reproducibility you mean getting the exact results. That should be taken care of in tutorials.

mehulrastogi commented 3 years ago

I did not check the exact implementation, but if that part is good then this is good.