Open magalhaesdavi opened 2 years ago
Did you figure this out?
Not yet.
I just tried something like this, where I added the function _make_layer_combo,
but I'm unsure if this is exactly how it is supposed to be...
class CoAtNet_6_7(nn.Module):
    """CoAtNet-style network whose ``s3`` stage chains two block groups.

    Stages ``s0``..``s4`` are applied in order, followed by average pooling
    and a bias-free linear classifier. ``s3`` is built from two consecutive
    groups of blocks (see ``_make_layer_combo``); every other stage is a
    single group.

    Args:
        image_size: ``(height, width)`` pair of the input images.
        in_channels: Number of channels of the input tensor.
        num_blocks: Six depths, consumed as s0, s1, s2, s3 (first group),
            s3 (second group), s4.
        channels: Six output widths, consumed in the same order as
            ``num_blocks``; ``channels[-1]`` is the classifier input width.
        num_classes: Size of the classifier output.
        block_types: Five entries, each ``'C'`` (MBConv) or ``'T'``
            (Transformer), consumed as s1, s2, s3 (first group),
            s3 (second group), s4.
    """

    # NOTE(review): the previous default had only four entries, so indexing
    # block_types[4] for s4 raised IndexError. The five-entry default below
    # assumes a conv, conv, conv+transformer, transformer layout -- confirm
    # against the CoAtNet paper before relying on it.
    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000,
                 block_types=('C', 'C', 'C', 'T', 'T')):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer_combo(
            block[block_types[2]], block[block_types[3]], channels[2], channels[3], channels[4],
            num_blocks[3], num_blocks[4], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[4]], channels[4], channels[5], num_blocks[5], (ih // 32, iw // 32))

        # Pool over the whole (ih // 32, iw // 32) map so non-square inputs
        # also reduce to 1x1; a square kernel taken from the height alone
        # leaves extra positions that the .view() in forward() would fold
        # into the batch dimension.
        self.pool = nn.AvgPool2d((ih // 32, iw // 32), 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        """Run ``x`` through s0..s4, pool, flatten and classify."""
        x = self.s0(x)
        x = self.s1(x)
        x = self.s2(x)
        x = self.s3(x)
        x = self.s4(x)

        # Flatten the pooled map to (N, C) using the channel count of s4's output.
        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x

    @staticmethod
    def _stage_blocks(block, inp, oup, depth, image_size):
        """Return ``depth`` blocks: the first maps ``inp -> oup`` with
        ``downsample=True``, the rest map ``oup -> oup``."""
        return [
            block(inp, oup, image_size, downsample=True) if i == 0
            else block(oup, oup, image_size)
            for i in range(depth)
        ]

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Build one stage of ``depth`` blocks as an ``nn.Sequential``."""
        return nn.Sequential(*self._stage_blocks(block, inp, oup, depth, image_size))

    def _make_layer_combo(self, block_1, block_2, inp_1, oup_1, oup_2, depth_1, depth_2, image_size):
        """Build one stage made of two consecutive block groups.

        The first group has ``depth_1`` blocks of ``block_1`` mapping
        ``inp_1 -> oup_1``; the second has ``depth_2`` blocks of ``block_2``
        mapping ``oup_1 -> oup_2``. Both groups receive the same
        ``image_size``.
        """
        # NOTE(review): the first block of the second group is also created
        # with downsample=True while the stage keeps a single image_size.
        # Whether that reduces resolution a second time depends on how the
        # block classes implement downsample -- confirm against them.
        layers = self._stage_blocks(block_1, inp_1, oup_1, depth_1, image_size)
        layers += self._stage_blocks(block_2, oup_1, oup_2, depth_2, image_size)
        return nn.Sequential(*layers)
I'm also not sure, I'll have to check the CoAtNet paper again. But at first glance it seems good, thank you!
Hello, first off, really appreciate your work! Now, how can I use coatnet-6 and coatnet-7? Is it adding a sequential of a 'C' block followed by a 'T' block at s3?