Closed gchhablani closed 3 years ago
Merging #85 (9b2308f) into main (eec666b) will not change coverage. The diff coverage is 100.00%.
@@ Coverage Diff @@
## main #85 +/- ##
==========================================
Coverage 100.00% 100.00%
==========================================
Files 16 17 +1
Lines 603 704 +101
==========================================
+ Hits 603 704 +101
Impacted Files | Coverage Δ | |
---|---|---|
code_soup/ch5/algorithms/atn.py | 100.00% <100.00%> (ø) |
@gchhablani You can take a look at https://github.com/RanTaimu/Adversarial-Transformation-Network
Adding some code for ImageNet models for another PR:
class BilinearUpsample(nn.Module):
    """Resize a batch of feature maps by a fixed factor using bilinear interpolation.

    Wraps ``F.interpolate`` so a resize can be placed in an ``nn.ModuleList``
    alongside ordinary layers. Bilinear mode in ``F.interpolate`` operates on
    4-D ``(N, C, H, W)`` tensors.

    Args:
        scale_factor: multiplier applied to the spatial size (values below 1
            shrink the map, values above 1 enlarge it).
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        # align_corners=True pins the corner pixels of input and output together.
        resize_options = {"mode": "bilinear", "align_corners": True}
        return F.interpolate(x, scale_factor=self.scale_factor, **resize_options)
class BaseDeconvAAE(AAEBase):
    """AAE-style ATN: a pretrained backbone followed by a transposed-conv decoder.

    ``self.atn`` holds the backbone and then a decoder that turns a
    ``[192, 35, 35]`` feature map into a 3-channel map. The spatial size
    evolves 35 -> 37 (pad) -> 74 -> 147 -> 294 -> 299 (pad) -> 299.

    Args:
        classifier_model: model forwarded to ``AAEBase``.
        pretrained_backbone: feature extractor placed first in ``self.atn``.
            Only the *declared* ``backbone_output_shape`` is checked here, not
            the module's real output (TODO confirm it yields [192, 35, 35]).
        target_idx: target index forwarded to ``AAEBase``.
        alpha: hyperparameter forwarded to ``AAEBase``.
        beta: hyperparameter forwarded to ``AAEBase``.
        backbone_output_shape: declared per-sample backbone output shape.
            ``None`` (the default) means ``[192, 35, 35]``. Any list or tuple,
            including ``torch.Size``, is accepted.

    Raises:
        ValueError: if ``backbone_output_shape`` is not ``[192, 35, 35]``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: "list | tuple | None" = None,
    ):
        expected_shape = [192, 35, 35]
        # None replaces the former mutable list default; the meaning is unchanged.
        if backbone_output_shape is None:
            backbone_output_shape = expected_shape
        # Compare as a list so that equivalent tuples are not rejected
        # (a tuple never compares equal to a list).
        if (
            not isinstance(backbone_output_shape, (list, tuple))
            or list(backbone_output_shape) != expected_shape
        ):
            raise ValueError("Backbone output shape must be [192, 35, 35].")
        super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),  # 35 -> 37
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),  # -> 74
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),  # -> 147
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # -> 294
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),  # 294 -> 299 (left, right, top, bottom)
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),  # -> 299
            nn.Tanh(),
        ]
        self.atn = nn.ModuleList(layers)
class ResizeConvAAE(AAEBase):
    """AAE-style ATN built from convolutions and bilinear resizes (no backbone).

    Three conv/ReLU stages each followed by a 0.5x bilinear resize shrink the
    image. A 1x1 conv and three 2x bilinear resizes interleaved with convs
    grow it back. A final zero pad, 3x3 conv to 3 channels, and ``Tanh``
    finish the network.

    Spatial sizes, assuming a 299x299 input (TODO confirm the intended input
    size): ``F.interpolate`` floors the scaled size, so the resizes give
    299 -> 149 -> 74 -> 37 -> 74 -> 148 -> 296. Padding (2, 1, 2, 1) restores
    296 -> 299 so that the output matches the input. The earlier
    (3, 2, 3, 2) padding produced 301x301.

    Args:
        classifier_model: model forwarded to ``AAEBase``.
        target_idx: target index forwarded to ``AAEBase``.
        alpha: hyperparameter forwarded to ``AAEBase``.
        beta: hyperparameter forwarded to ``AAEBase``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
        layers = [
            nn.Conv2d(3, 128, 5, padding=2),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),  # 299 -> 149
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),  # -> 74
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),  # -> 37
            nn.Conv2d(512, 512, 1, padding=0),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),  # -> 74
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),  # -> 148
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),  # -> 296
            nn.ZeroPad2d((2, 1, 2, 1)),  # 296 -> 299 (left, right, top, bottom)
            nn.Conv2d(128, 3, 3, padding=1),  # size-preserving
            nn.Tanh(),
        ]
        self.atn = nn.ModuleList(layers)
class ConvDeconvAAE(AAEBase):
    """AAE-style ATN: strided-conv encoder with a transposed-conv decoder.

    Encoder channels go 3 -> 256 -> 512 -> 768 through three stride-2 convs.
    Decoder channels go 768 -> 512 -> 256 -> 128 through three stride-2
    transposed convs, then a stride-1 transposed conv maps to 3 channels,
    followed by ``Tanh``.

    The 256 -> 512 stage was missing: a 256-channel output was fed into a
    512-channel conv, which fails at runtime. With all three stages, a
    299x299 input maps back to exactly 299x299:
    299 -> 150 -> 75 -> 38 -> 75 -> 149 -> 297 -> 299 (TODO confirm 299 is the
    intended input size).

    Args:
        classifier_model: model forwarded to ``AAEBase``.
        target_idx: target index forwarded to ``AAEBase``.
        alpha: hyperparameter forwarded to ``AAEBase``.
        beta: hyperparameter forwarded to ``AAEBase``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
        layers = [
            nn.Conv2d(3, 256, 3, stride=2, padding=1),  # 299 -> 150
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # -> 75
            nn.ReLU(),
            nn.Conv2d(512, 768, 3, stride=2, padding=1),  # -> 38
            nn.ReLU(),
            nn.ConvTranspose2d(768, 512, kernel_size=3, stride=2, padding=1),  # -> 75
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),  # -> 149
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),  # -> 297
            nn.ReLU(),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=0),  # -> 299
            nn.Tanh(),
        ]
        self.atn = nn.ModuleList(layers)
class BaseDeconvPATN(PATNBase):
    """P-ATN variant: a pretrained backbone followed by a transposed-conv decoder.

    ``self.atn`` holds the backbone and then a decoder that turns a
    ``[192, 35, 35]`` feature map into a 3-channel map. The spatial size
    evolves 35 -> 37 (pad) -> 74 -> 147 -> 294 -> 299 (pad) -> 299.

    Args:
        classifier_model: model forwarded to ``PATNBase``.
        pretrained_backbone: feature extractor placed first in ``self.atn``.
            Only the *declared* ``backbone_output_shape`` is checked here, not
            the module's real output (TODO confirm it yields [192, 35, 35]).
        target_idx: target index forwarded to ``PATNBase``.
        alpha: hyperparameter forwarded to ``PATNBase``.
        beta: hyperparameter forwarded to ``PATNBase``.
        backbone_output_shape: declared per-sample backbone output shape.
            ``None`` (the default) means ``[192, 35, 35]``. Any list or tuple,
            including ``torch.Size``, is accepted.

    Raises:
        ValueError: if ``backbone_output_shape`` is not ``[192, 35, 35]``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: "list | tuple | None" = None,
    ):
        expected_shape = [192, 35, 35]
        # None replaces the former mutable list default; the meaning is unchanged.
        if backbone_output_shape is None:
            backbone_output_shape = expected_shape
        # Compare as a list so that equivalent tuples are not rejected.
        if (
            not isinstance(backbone_output_shape, (list, tuple))
            or list(backbone_output_shape) != expected_shape
        ):
            raise ValueError("Backbone output shape must be [192, 35, 35].")
        super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)
        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),  # 35 -> 37
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),  # -> 74
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),  # -> 147
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # -> 294
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),  # 294 -> 299 (left, right, top, bottom)
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),  # -> 299
            nn.Tanh(),  # TODO: check whether Tanh is the right output activation
        ]
        self.atn = nn.ModuleList(layers)
class ConvFCPATN(PATNBase):
    """P-ATN variant: three strided convs followed by two fully connected layers.

    The previous version called ``super(BaseDeconvAAE, self)``. Since
    ``ConvFCPATN`` is not a subclass of ``BaseDeconvAAE``, construction raised
    ``TypeError``. It now initialises its own base class.

    The hard-coded sizes correspond to a 3x299x299 input. Three stride-2 convs
    give 299 -> 150 -> 75 -> 38, so the flattened size is
    128 * 38 * 38 = 184832. The output size is 268203 = 3 * 299 * 299.

    Args:
        classifier_model: model forwarded to ``PATNBase``.
        target_idx: target index forwarded to ``PATNBase``.
        alpha: hyperparameter forwarded to ``PATNBase``.
        beta: hyperparameter forwarded to ``PATNBase``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)
        # NOTE(review): there is no nonlinearity between these layers, so the
        # stack before Tanh is affine; confirm this is intended.
        layers = [
            nn.Conv2d(3, 512, 3, stride=2, padding=1),  # 299 -> 150
            nn.Conv2d(512, 256, 3, stride=2, padding=1),  # -> 75
            nn.Conv2d(256, 128, 3, stride=2, padding=1),  # -> 38
            nn.Flatten(),
            nn.Linear(184832, 512),  # 128 * 38 * 38 inputs
            # Flat 3 * 299 * 299 output; presumably reshaped to image form by
            # PATNBase -- TODO confirm.
            nn.Linear(512, 268203),
            nn.Tanh(),
        ]
        self.atn = nn.ModuleList(layers)
I don't think we need tests for reproducibility, if by reproducibility you mean getting the exact same results. That should be taken care of in the tutorials.
I did not check the exact implementation, but if that is correct, then this looks good.
This PR is a WIP on ATN (#19). I will be skipping simultaneous training on multiple networks and the use of "insider" information.
I couldn't find an official implementation, so I am basing my implementation on the paper. The paper does not give the exact hyperparameters, so I am changing a few of them (kernel size, stride, etc.) to make things work. I hope that is okay.