Closed balthazarneveu closed 6 months ago
Visualize networks (for reports)? Maybe add a visualization method to the base model.
from torchview import draw_graph
class DummyModel(torch.nn.Module):
    """Small convolutional baseline: three Conv2d layers separated by ReLUs.

    Maps a 3-channel input to a 3-channel output. None of the convolutions
    set ``padding`` (PyTorch's default is 0), so each spatial dimension
    shrinks by 4 + 2 + 4 = 10 pixels (kernel sizes 5, 3 and 5).
    """

    def __init__(self, *args, **kwargs) -> None:
        # Extra arguments are forwarded unchanged to torch.nn.Module.
        super().__init__(*args, **kwargs)
        layers = [
            torch.nn.Conv2d(3, 64, 5),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 3, 5),
        ]
        # Wrapped in a single Sequential named `stack`, so parameter names
        # are stack.0.*, stack.2.*, stack.4.*.
        self.stack = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv stack.

        Expects 3 input channels, e.g. a (N, 3, H, W) tensor -- shape
        beyond the channel count is an assumption, confirm with callers.
        """
        out = self.stack(x)
        return out
# Build the dummy model and render its architecture with torchview.
batch_size = 2
# Single source of truth for the input shape used by both the sanity
# forward pass and the graph drawing below.
input_size = (batch_size, 3, 32, 32)
model_dummy = DummyModel()

# Sanity check: one real forward pass so shape/channel mismatches surface
# early. The output is discarded, so skip building an autograd graph.
with torch.no_grad():
    model_dummy(torch.randn(input_size))

device = "meta"  # -> no memory is consumed for visualization
# Pass the variable (instead of re-typing the literal) so the two can't drift.
model_graph = draw_graph(model_dummy, input_size=input_size, device=device)
# A bare expression only shows something in an interactive session
# (e.g. as the last line of a notebook cell); in a script it displays nothing.
model_graph.visual_graph
NAFNet performs much better than the stacked-convolution baseline on the blind denoising problem.
Add architectures and start training.