leaderj1001 / Attention-Augmented-Conv2d

Implementing Attention Augmented Convolutional Networks using PyTorch
MIT License


Update (2019.05.11)

- Example, relative=True, stride=1, shape=32
```python
import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

tmp = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=1, shape=32).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 32, 32)

# List the names of the layer's learnable parameters
for name, param in augmented_conv1.named_parameters():
    print('parameter name: ', name)
```

- The printed parameter names include `key_rel_w` and `key_rel_h`, which hold the relative position embeddings used when `relative=True`.

- Example, relative=True, stride=2, shape=16
```python
import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

tmp = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 16, 16)
```
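With `relative=True`, the attention logits also depend on the offset between positions, which is one reason `shape` has to be passed in. Along an axis of length W, the offset between a query position i and a key position j ranges from -(W - 1) to W - 1, so a table of relative embeddings for that axis needs 2W - 1 rows. The plain-Python sketch below shows one way to map offsets to table rows; the helper names are hypothetical, and shifting offsets by W - 1 is an assumption, not a description of how `AugmentedConv` indexes `key_rel_w` and `key_rel_h` internally.

```python
def num_relative_rows(width: int) -> int:
    """Number of distinct offsets j - i for positions 0..width-1 (hypothetical helper)."""
    # Offsets run from -(width - 1) to (width - 1), i.e. 2 * width - 1 values.
    return 2 * width - 1


def relative_index_table(width: int) -> list:
    """Row index into a relative-embedding table for each (query i, key j) pair.

    Hypothetical helper: assumes the offset j - i is shifted by (width - 1)
    so every index lands in [0, 2 * width - 2].
    """
    return [[(j - i) + (width - 1) for j in range(width)] for i in range(width)]


if __name__ == "__main__":
    for row in relative_index_table(4):
        print(row)
    # [3, 4, 5, 6]
    # [2, 3, 4, 5]
    # [1, 2, 3, 4]
    # [0, 1, 2, 3]
    print(num_relative_rows(32))  # 63 rows per axis for shape=32
```

Each diagonal of the table holds a single index, so every query/key pair at the same offset reuses the same embedding row.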

Update (2019.05.02)

- Example, relative=False, stride=1
```python
import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

temp_input = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=1).to(device)
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 32, 32), (batch_size, out_channels, height, width)
```

- Example, relative=False, stride=2
```python
import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

temp_input = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=2).to(device)
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 16, 16), (batch_size, out_channels, height, width)
```
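The expected output shapes in these examples follow from ordinary convolution arithmetic: along each axis, the output size is floor((H + 2p - k) / s) + 1. A plain-Python sketch, assuming the layer uses "same"-style padding p = kernel_size // 2 (an assumption; the padding is not stated here); `conv_output_size` is a hypothetical helper. In the `relative=True` examples, `shape` matches this output size: 32 for `stride=1` and 16 for `stride=2`.

```python
from typing import Optional


def conv_output_size(size: int, kernel_size: int, stride: int,
                     padding: Optional[int] = None) -> int:
    """Output length along one spatial axis of a 2D convolution.

    Hypothetical helper: when padding is None it assumes "same"-style
    padding of kernel_size // 2, which is an assumption about AugmentedConv.
    """
    if padding is None:
        padding = kernel_size // 2
    return (size + 2 * padding - kernel_size) // stride + 1


if __name__ == "__main__":
    # 16 images of 32 x 32 pixels, out_channels=20, kernel_size=3
    for stride in (1, 2):
        side = conv_output_size(32, kernel_size=3, stride=stride)
        print(stride, (16, 20, side, side))
    # 1 (16, 20, 32, 32)
    # 2 (16, 20, 16, 16)
```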

I posted two versions of the "Attention-Augmented Conv".

Reference

Paper

Method
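The method combines a standard convolution with multi-head self-attention over the flattened feature map and concatenates the two along the channel axis. A hedged restatement of the paper's formulation, where X is the flattened input, N_h is the number of heads (`Nh`), d_k^h = d_k / N_h is the per-head key depth, and S_H^rel, S_W^rel are the relative height and width logits (present only with `relative=True`):

```latex
\begin{aligned}
O_h &= \mathrm{Softmax}\left(\frac{Q_h K_h^{\top} + S^{rel}_{H} + S^{rel}_{W}}{\sqrt{d_k^h}}\right) V_h,
  \quad Q_h = X W_q^h,\; K_h = X W_k^h,\; V_h = X W_v^h \\
\mathrm{MHA}(X) &= \mathrm{Concat}\left[O_1, \dots, O_{N_h}\right] W^O \\
\mathrm{AAConv}(X) &= \mathrm{Concat}\left[\mathrm{Conv}(X), \mathrm{MHA}(X)\right]
\end{aligned}
```

The convolution branch has `out_channels - dv` filters and the attention branch outputs `dv` channels, so the concatenation has `out_channels` channels in total; with `relative=False` the two S^rel terms are dropped.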


Input Parameters
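The constructor arguments interact: the per-head depths `dk // Nh` and `dv // Nh` need to divide evenly, and the attention branch's `dv` channels are carved out of `out_channels`. A minimal plain-Python sketch of these checks, assuming the channel split described in the paper (the convolution branch gets `out_channels - dv` filters); `check_aa_params` and `ChannelSplit` are hypothetical names, and the exact rules are assumptions, not the repository's own validation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ChannelSplit:
    conv_channels: int  # filters in the plain convolution branch
    attn_channels: int  # channels produced by the attention branch (dv)
    dk_per_head: int    # key/query depth per head
    dv_per_head: int    # value depth per head


def check_aa_params(out_channels: int, dk: int, dv: int, Nh: int,
                    relative: bool = False,
                    shape: Optional[int] = None) -> ChannelSplit:
    """Validate AugmentedConv-style arguments and return the channel split.

    Hypothetical helper: these rules are assumptions based on the paper's
    formulation, not the repository's own checks.
    """
    if Nh <= 0:
        raise ValueError("Nh must be a positive number of heads")
    if dk % Nh != 0 or dv % Nh != 0:
        raise ValueError("dk and dv must both be divisible by Nh")
    # Assumes at least one plain convolution filter remains.
    if not 0 < dv < out_channels:
        raise ValueError("dv must be positive and smaller than out_channels")
    if relative and shape is None:
        raise ValueError("relative=True needs the output spatial size in shape")
    return ChannelSplit(out_channels - dv, dv, dk // Nh, dv // Nh)


if __name__ == "__main__":
    print(check_aa_params(out_channels=20, dk=40, dv=4, Nh=4, relative=True, shape=32))
    # ChannelSplit(conv_channels=16, attn_channels=4, dk_per_head=10, dv_per_head=1)
```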

Experiments

| Dataset | Model | Accuracy | Epochs | Training Time |
| :---: | :---: | :---: | :---: | :---: |
| CIFAR-10 | Wide-ResNet 28x10 (work in progress) | | | |
| CIFAR-100 | Wide-ResNet 28x10 (work in progress) | | | |
| CIFAR-100 | 3 conv layers only (channels: 64, 128, 192) | 61.6% | 100 | 22m |
| CIFAR-100 | 3 attention-augmented conv layers only (channels: 64, 128, 192) | 59.82% | 35 | 2h 23m |
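The two CIFAR-100 runs used different epoch budgets (100 vs 35), so their total training times are not directly comparable. A small plain-Python calculation of average time per epoch from the reported numbers gives a rough per-epoch comparison; it assumes every epoch took the same time, which the table does not state, and the hardware is not given.

```python
def seconds_per_epoch(total_minutes: float, epochs: int) -> float:
    """Average wall-clock seconds per epoch, assuming uniform epochs."""
    return 60 * total_minutes / epochs


if __name__ == "__main__":
    # CIFAR-100 rows from the table: total training minutes and epochs
    conv_only = seconds_per_epoch(22, 100)          # 22m over 100 epochs
    augmented = seconds_per_epoch(2 * 60 + 23, 35)  # 2h 23m over 35 epochs
    print(f"3 conv layers: {conv_only:.1f} s/epoch")     # 13.2 s/epoch
    print(f"3 AA conv layers: {augmented:.1f} s/epoch")  # 245.1 s/epoch
    print(f"ratio: {augmented / conv_only:.1f}x")        # 18.6x
```

For the same reason, the accuracy gap (61.6% vs 59.82%) does not show which architecture is more accurate at an equal training budget.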

Time complexity
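Self-attention compares every spatial position with every other, so for an H x W feature map each head forms an HW x HW matrix of logits, and the memory for these maps grows as O(Nh (HW)^2), whereas a convolution with a fixed kernel scales linearly in HW. A plain-Python sketch of that count for the example input (batch 16, `Nh=4`), assuming float32 storage and ignoring everything else kept for backpropagation; the 16 x 16 case applies only if attention runs on the strided feature map, which is an assumption here.

```python
def attention_logit_elements(batch: int, heads: int, height: int, width: int) -> int:
    """Number of entries in an attention-logit tensor of shape (batch, heads, HW, HW)."""
    positions = height * width
    return batch * heads * positions * positions


def to_mib(num_elements: int, bytes_per_element: int = 4) -> float:
    """Convert an element count to MiB, assuming float32 (4 bytes) by default."""
    return num_elements * bytes_per_element / 2**20


if __name__ == "__main__":
    # Example input from above: batch 16, Nh=4, stride 1 (32 x 32) vs stride 2 (16 x 16)
    for side in (32, 16):
        n = attention_logit_elements(batch=16, heads=4, height=side, width=side)
        print(side, n, f"{to_mib(n):.0f} MiB")
    # 32 67108864 256 MiB
    # 16 4194304 16 MiB
```

Halving both spatial sides divides the attention-map size by 16, since it scales with the square of HW.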

Requirements