lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

Using attention modules in EfficientNet #146

Open Pharaun85 opened 4 years ago

Pharaun85 commented 4 years ago

Hello, I'm trying to extract the outputs of the network's second-to-last and last layers to build two attention modules for semantic segmentation. This is the code:

import torch
from efficientnet_pytorch import EfficientNet

class efficientnet_b5(torch.nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = EfficientNet.from_pretrained('efficientnet-b5')
        self.conv = self.features._conv_stem
        self.bn0 = self.features._bn0
        self.layer0_36 = self.features._blocks[:36]
        self.layer37 = self.features._blocks[37]
        self.layer38 = self.features._blocks[38]
        self.layers = self.features._blocks
        self.conv_head = self.features._conv_head
        self.memswish = self.features._swish

    def forward(self, input):
        x = self.conv(input)
        x = self.bn0(x)
        feature1 = self.layer0_36(x) #512 features
        feature2 = self.layer37(feature1) # 512 features
        feature3 = self.layer38(feature2) # 512 features
        feature4 = self.conv_head(feature3) #2048 features
        feature4 = self.memswish(feature4)
        # global average pooling to build tail
        tail = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(tail, 2, keepdim=True)
        return feature3, feature4, tail
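The two torch.mean calls at the end of forward implement global average pooling over width and then height. The small check below uses a random tensor whose shape is an arbitrary assumption, with 2048 channels mirroring the conv_head comment. It shows that a single torch.mean over both spatial dimensions, or torch.nn.functional.adaptive_avg_pool2d(x, 1), gives the same result:

```python
import torch
import torch.nn.functional as F

# Dummy feature map shaped (batch, channels, height, width).
feature4 = torch.randn(2, 2048, 7, 7)

# Two-step version from the snippet: mean over width (dim 3), then over height (dim 2).
tail_two_step = torch.mean(torch.mean(feature4, 3, keepdim=True), 2, keepdim=True)

# Single-call equivalents: mean over both spatial dims, or adaptive pooling to 1x1.
tail_one_step = torch.mean(feature4, dim=(2, 3), keepdim=True)
tail_pool = F.adaptive_avg_pool2d(feature4, 1)

print(tail_two_step.shape)  # torch.Size([2, 2048, 1, 1])
print(torch.allclose(tail_two_step, tail_one_step, atol=1e-6))
print(torch.allclose(tail_two_step, tail_pool, atol=1e-6))
```

Because every row being averaged has the same number of elements, the mean of the per-column means equals the overall spatial mean, so the three results agree up to floating-point rounding.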

The problem is that I get this error:

epoch 0, lr 0.031250: 0%| | 0/52 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 292, in <module>
    main(args)
  File "train.py", line 254, in main
    train(args, model, optimizer, dataloader_train, dataloader_val)
  File "train.py", line 83, in train
    outputs = model(data)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/Documentos/DualBiSeNet/model/build_DualBiSeNet.py", line 170, in forward
    cx1, cx2, tail = self.context_path(imginput)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/Documentos/DualBiSeNet/model/build_contextpath.py", line 208, in forward
    features3 = self.layers(x)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 96, in forward
    raise NotImplementedError
NotImplementedError

Has anyone else run into the same problem?

Thanks
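The last frame of the traceback is the base torch.nn.Module.forward raising NotImplementedError, which is what happens when a module that does not define its own forward is called. In efficientnet_pytorch, _blocks is an nn.ModuleList, so both self.layers and the slice _blocks[:36] are nn.ModuleList containers with no forward. The pure-PyTorch sketch below reproduces this, using toy convolution layers as assumed stand-ins for the EfficientNet blocks. It also shows that wrapping the same modules in nn.Sequential avoids the error:

```python
import torch
from torch import nn

# Toy stand-in for EfficientNet's _blocks, which efficientnet_pytorch stores as an nn.ModuleList.
blocks = nn.ModuleList([nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(4)])
x = torch.randn(1, 8, 16, 16)

# Slicing an nn.ModuleList returns another nn.ModuleList, a container with no forward().
subset = blocks[:2]
print(type(subset).__name__)  # ModuleList

raised = False
try:
    subset(x)  # falls through to nn.Module's default forward, which raises
except NotImplementedError:
    raised = True
print("NotImplementedError raised:", raised)

# nn.Sequential defines forward() as "apply each contained module in order".
seq = nn.Sequential(*blocks[:2])
out = seq(x)
print(out.shape)  # torch.Size([1, 8, 16, 16])
```

An explicit loop (for block in blocks: x = block(x)) works just as well. nn.Sequential is simply the shortest way to turn a list of modules into something callable.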

lukemelas commented 4 years ago

Sorry for the delay. It might help to change to the regular (non-memory-optimized) swish implementation by calling this after you create the model:

model.set_swish(memory_efficient=False)

Pharaun85 commented 4 years ago

Hello! Thanks for your reply. Could you tell me where I should call that? I load the model by instantiating the efficientnet_b5 class above in a variable called context_path, and then I tried context_path.set_swish(memory_efficient=False), but it says the attribute is missing:

AttributeError: 'efficientnet_b5' object has no attribute 'set_swish'

Thanks!
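The AttributeError arises because set_swish is a method of the EfficientNet model, while the wrapper only stores that model as self.features. nn.Module does not forward attribute lookups to its submodules. The minimal sketch below uses hypothetical Inner and Wrapper classes standing in for EfficientNet and efficientnet_b5:

```python
from torch import nn


class Inner(nn.Module):
    """Hypothetical stand-in for EfficientNet, which defines set_swish()."""

    def __init__(self):
        super().__init__()
        self.memory_efficient = True

    def set_swish(self, memory_efficient=True):
        self.memory_efficient = memory_efficient


class Wrapper(nn.Module):
    """Hypothetical stand-in for efficientnet_b5: the backbone lives in self.features."""

    def __init__(self):
        super().__init__()
        self.features = Inner()


wrapper = Wrapper()

message = ""
try:
    wrapper.set_swish(memory_efficient=False)  # not defined on Wrapper itself
except AttributeError as err:
    message = str(err)
print(message)  # 'Wrapper' object has no attribute 'set_swish'

# Calling the method on the submodule that actually defines it works.
wrapper.features.set_swish(memory_efficient=False)
print(wrapper.features.memory_efficient)  # False
```

With the class from the question, this corresponds to context_path.features.set_swish(memory_efficient=False). Alternatively, call it inside __init__ right after from_pretrained. In efficientnet_pytorch, set_swish replaces the model's _swish with a new module. Because __init__ copies self.features._swish into self.memswish, calling set_swish afterwards would leave self.memswish pointing at the old activation. Calling it before that assignment is therefore the safer place.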

Pharaun85 commented 4 years ago

Hello, I changed the implementation as follows, but the error is still the same.

class efficientnet_b5(torch.nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = EfficientNet.from_pretrained('efficientnet-b5')
        self.features.set_swish(memory_efficient=False)
        self.conv = self.features._conv_stem
        self.bn0 = self.features._bn0
        self.layer0_36 = self.features._blocks[:36]
        self.layer37 = self.features._blocks[37]
        self.layer38 = self.features._blocks[38]
        self.layers = self.features._blocks
        self.conv_head = self.features._conv_head
        self.memswish = self.features._swish

    def forward(self, input):
        x = self.conv(input)
        x = self.bn0(x)
        feature1 = self.layer0_36(x) #512 features
        feature2 = self.layer37(feature1) # 512 features
        feature3 = self.layer38(feature2) # 512 features
        feature4 = self.conv_head(feature3) #2048 features
        feature4 = self.memswish(feature4)
        # global average pooling to build tail
        tail = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(tail, 2, keepdim=True)
        return feature3, feature4, tail

Traceback (most recent call last):
  File "train.py", line 292, in <module>
    main(args)
  File "train.py", line 254, in main
    train(args, model, optimizer, dataloader_train, dataloader_val)
  File "train.py", line 83, in train
    outputs = model(data)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/Documentos/DualBiSeNet/model/build_DualBiSeNet.py", line 173, in forward
    cx1, cx2, tail = self.context_path(imginput)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/Documentos/DualBiSeNet/model/build_contextpath.py", line 209, in forward
    feature1 = self.layer0_36(x) #512 features
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/malvaro/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 96, in forward
    raise NotImplementedError
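The new traceback still ends in nn.Module.forward, this time from self.layer0_36, which is a sliced nn.ModuleList. That suggests the swish setting is unrelated to the error. Below is a sketch of a wrapper that never calls an nn.ModuleList directly. It is not the library's API. The class name EfficientNetFeatures is hypothetical, and it relies on private attributes of efficientnet_pytorch's EfficientNet (_conv_stem, _bn0, _blocks, _conv_head, _bn1, _swish) that may change between versions.

It follows the ordering of the library's extract_features as I understand it: swish after the stem, and _bn1 plus swish after the head, both of which the question's code skips. It also runs every block, whereas _blocks[:36] followed by _blocks[37] skips block 36. Unlike the library, it does not pass a per-block drop_connect_rate. The toy backbone at the bottom is a hypothetical stand-in used only to check output shapes:

```python
import torch
from torch import nn


class EfficientNetFeatures(nn.Module):
    """Hypothetical wrapper returning (last-block features, head features, pooled tail).

    `backbone` is assumed to expose the same private attributes as
    efficientnet_pytorch's EfficientNet: _conv_stem, _bn0, _blocks (an
    nn.ModuleList), _conv_head, _bn1 and _swish.
    """

    def __init__(self, backbone):
        super().__init__()
        self.stem = nn.Sequential(backbone._conv_stem, backbone._bn0)
        # nn.Sequential provides the forward() that nn.ModuleList lacks.
        self.blocks = nn.Sequential(*backbone._blocks)
        self.head = nn.Sequential(backbone._conv_head, backbone._bn1)
        self.swish = backbone._swish

    def forward(self, x):
        x = self.swish(self.stem(x))                # conv_stem -> bn0 -> swish
        feature3 = self.blocks(x)                   # output of the last block
        feature4 = self.swish(self.head(feature3))  # conv_head -> bn1 -> swish
        # Global average pooling over height and width.
        tail = torch.mean(feature4, dim=(2, 3), keepdim=True)
        return feature3, feature4, tail


class ToyBackbone(nn.Module):
    """Hypothetical stand-in with the same attribute names, used only for a shape check."""

    def __init__(self):
        super().__init__()
        self._conv_stem = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
        self._bn0 = nn.BatchNorm2d(8)
        self._blocks = nn.ModuleList(
            [nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(5)]
        )
        self._conv_head = nn.Conv2d(8, 16, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(16)
        self._swish = nn.ReLU()  # any activation module is enough for a shape check


model = EfficientNetFeatures(ToyBackbone()).eval()
with torch.no_grad():
    feature3, feature4, tail = model(torch.randn(2, 3, 32, 32))
print(feature3.shape, feature4.shape, tail.shape)
```

With the real model this would be EfficientNetFeatures(EfficientNet.from_pretrained('efficientnet-b5')). Going by the channel comments in the question, feature3 should then have 512 channels and feature4 2048. If you also want the non-memory-efficient swish, call set_swish(memory_efficient=False) on the backbone before constructing the wrapper, because the wrapper keeps a reference to the backbone's _swish module.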