balthazarneveu / blind-deblurring-from-synthetic-data

MVA ENS Paris-Saclay - Image restoration project: deblurring learned on dead leaves

Architectures #4

Closed by balthazarneveu 6 months ago

balthazarneveu commented 6 months ago

Add architectures and start training.

balthazarneveu commented 6 months ago

Visualize networks (for reports)? Maybe add a method to the base model.

```python
import torch
from torchview import draw_graph

class DummyModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Plain stack of unpadded convolutions (32x32 input -> 22x22 output)
        self.stack = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 5), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 3, 5),
        )

    def forward(self, x):
        return self.stack(x)

batch_size = 2
model_dummy = DummyModel()
model_dummy(torch.randn((batch_size, 3, 32, 32)))  # sanity-check forward pass
device = "meta"  # meta tensors: no memory is consumed for visualization
model_graph = draw_graph(model_dummy, input_size=(batch_size, 3, 32, 32), device=device)
model_graph.visual_graph  # graphviz.Digraph, rendered inline in a notebook
```

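Following the idea of putting the visualization on the base model, a minimal sketch could look like the following. `BaseModel`, `draw` and `count_parameters` are hypothetical names, not existing code in this repository; the sketch assumes torchview and graphviz are installed. It draws a deep copy of the model, on the assumption that torchview may move the model it draws onto `device`, which would otherwise send real (possibly trained) weights to the meta device.

```python
import copy

import torch


class BaseModel(torch.nn.Module):
    """Hypothetical base class sketch shared by all architectures."""

    def count_parameters(self) -> int:
        """Number of trainable parameters (handy for report tables)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def draw(self, input_size, filename=None, device="meta"):
        """Draw the network graph with torchview and optionally save it.

        torchview is imported lazily so that training code does not
        depend on it.
        """
        from torchview import draw_graph

        # Draw a copy so the real weights are never moved to `device`.
        model_graph = draw_graph(
            copy.deepcopy(self), input_size=input_size, device=device
        )
        if filename is not None:
            # visual_graph is a graphviz.Digraph; this writes <filename>.png
            model_graph.visual_graph.render(filename, format="png", cleanup=True)
        return model_graph.visual_graph
```

A model such as `DummyModel` would then inherit from `BaseModel` instead of `torch.nn.Module`, and a figure for the report could be produced with something like `model.draw((2, 3, 32, 32), filename="dummy")` (the file name is only an example).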
balthazarneveu commented 6 months ago

NAFNet performs much better than the stacked convolutions on the blind denoising problem.

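To illustrate how NAFNet differs from a plain stack of convolutions, here is a minimal sketch of a NAFNet-style block, re-implemented from the description in the NAFNet paper ("Simple Baselines for Image Restoration"). It replaces nonlinear activations with a SimpleGate (elementwise product of two channel halves), adds a simplified channel attention, and scales both residual branches by parameters initialized to zero. This is a simplified illustration (no dropout, example widths), not the NAFNet code used in this repository.

```python
import torch
from torch import nn


class LayerNorm2d(nn.Module):
    """Layer norm over the channel dimension of an (N, C, H, W) tensor."""

    def __init__(self, channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1, channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mu = x.mean(dim=1, keepdim=True)
        var = (x - mu).pow(2).mean(dim=1, keepdim=True)
        return self.weight * (x - mu) / torch.sqrt(var + self.eps) + self.bias


class SimpleGate(nn.Module):
    """Activation-free gate: split the channels in two halves and multiply."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    def __init__(self, c: int, dw_expand: int = 2, ffn_expand: int = 2) -> None:
        super().__init__()
        dw, ffn = c * dw_expand, c * ffn_expand
        self.norm1 = LayerNorm2d(c)
        self.conv1 = nn.Conv2d(c, dw, 1)
        self.conv2 = nn.Conv2d(dw, dw, 3, padding=1, groups=dw)  # depthwise 3x3
        self.sg = SimpleGate()  # halves the channel count
        # Simplified channel attention: global average pooling + 1x1 conv
        self.sca = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(dw // 2, dw // 2, 1))
        self.conv3 = nn.Conv2d(dw // 2, c, 1)
        self.norm2 = LayerNorm2d(c)
        self.conv4 = nn.Conv2d(c, ffn, 1)
        self.conv5 = nn.Conv2d(ffn // 2, c, 1)
        # Residual scales start at zero, so the block is the identity at init
        self.beta = nn.Parameter(torch.zeros(1, c, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, c, 1, 1))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        x = self.sg(self.conv2(self.conv1(self.norm1(inp))))
        x = self.conv3(x * self.sca(x))
        y = inp + x * self.beta
        x = self.conv5(self.sg(self.conv4(self.norm2(y))))
        return y + x * self.gamma


block = NAFBlock(16)
out = block(torch.randn(2, 16, 32, 32))
print(out.shape)  # torch.Size([2, 16, 32, 32])
```

Unlike the unpadded `DummyModel` stack, every convolution here preserves the spatial size, so such blocks can be chained with skip connections.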