Closed: I-C-Karakozis closed this issue 5 years ago.
Fixed it. Here is the solution for anyone running into the same problem (it was very similar to my previous error):
def initLinear(linear, val=None):
    """Initialise a linear layer's parameters in place from a symmetric uniform distribution.

    Both the weight and (if present) the bias are drawn from
    ``U(-spread, spread)``.

    Args:
        linear: layer exposing ``in_features``, ``out_features``, ``weight``
            and ``bias`` (e.g. ``torch.nn.Linear``).
        val: half-width of the sampling interval. When ``None``, the
            half-width defaults to ``sqrt(2) * sqrt(2 / fan)``, which equals
            ``2 / sqrt(fan)``, where ``fan = in_features + out_features``.

    Returns:
        None; ``linear`` is modified in place.
    """
    if val is None:
        fan = linear.in_features + linear.out_features
        spread = math.sqrt(2.0) * math.sqrt(2.0 / fan)
    else:
        spread = val
    linear.weight.data.uniform_(-spread, spread)
    # Layers built with bias=False carry ``bias = None``; skip the bias
    # instead of failing with AttributeError on ``None.data``.
    if linear.bias is not None:
        linear.bias.data.uniform_(-spread, spread)
class ResNet34_Pretrained(nn.Module):
    """ResNet-34 backbone followed by a two-layer fully-connected classifier head.

    The stem (``conv1``/``bn1``/ReLU/max-pool) and the four residual stages come
    from ``torchvision.models.resnet34(pretrained=True)``. torchvision's own
    ``avgpool``/``fc`` head is never called in ``forward``. Instead, the
    stage-4 feature map is flattened and passed through ``linear1``
    (7*7*512 -> 1024, LeakyReLU) and ``linear2`` (1024 -> n_classes).
    """

    # Channel count assumed for the output of ``layer4``; used to size ``linear1``.
    def base_size(self): return 512
    # Width of the hidden representation produced by ``linear1``.
    def rep_size(self): return 1024

    def __init__(self, n_classes: int):
        """Build the backbone and the classifier head.

        Args:
            n_classes: number of output scores produced by ``linear2``.
        """
        super(ResNet34_Pretrained, self).__init__()
        self.resnet = torchvision.models.resnet34(pretrained=True)
        # Expose the stem and residual stages as direct attributes.
        # NOTE(review): ``self.resnet`` is itself a registered submodule, so these
        # layers are reachable under both ``resnet.*`` and their own names
        # (e.g. duplicated keys in ``state_dict``); confirm this is intended
        # for saved checkpoints.
        self.conv1 = self.resnet.conv1
        self.bn1 = self.resnet.bn1
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4
        # define the classifier head
        self.n_classes = n_classes
        # The 7*7*base_size() input width assumes layer4 emits a 7x7 spatial map
        # (presumably from a 224x224 input image; TODO confirm).
        self.linear1 = nn.Linear(7 * 7 * self.base_size(), self.rep_size())
        self.linear2 = nn.Linear(self.rep_size(), self.n_classes)
        # Channel-wise dropout, applied to the 4-D stage-4 feature map in forward.
        self.dropout2d = nn.Dropout2d(.5)
        # Element-wise dropout, applied to the hidden representation in forward.
        self.dropout = nn.Dropout(.5)
        # Head activation; the stem in forward uses the backbone's own ReLU instead.
        self.relu = nn.LeakyReLU()
        # Re-initialise the new linear layers with initLinear's uniform scheme.
        initLinear(self.linear1)
        initLinear(self.linear2)

    def forward(self, out0):
        """Compute unnormalised class scores.

        Args:
            out0: input image batch. Presumably (N, 3, H, W) with H and W such
                that layer4 yields a 7x7 map; TODO confirm.

        Returns:
            Output of ``linear2``, which has ``n_classes`` columns.
        """
        # Stem: conv -> batch-norm -> backbone ReLU -> max-pool.
        x = self.conv1(out0)
        x = self.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        # Four residual stages.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.dropout2d(x)
        # Flatten to rows of 7*7*base_size() features, then project with LeakyReLU.
        # NOTE(review): view(-1, ...) infers the batch dimension. If layer4's
        # output is not 7x7, features from different samples may be silently
        # regrouped into rows instead of raising. Consider x.view(x.size(0), -1)
        # once it is confirmed the converter accepts it.
        x = self.relu(self.linear1(x.view(-1, 7*7*self.base_size())))
        # The .clone() calls leave values unchanged; presumably they are a
        # workaround for the pytorch2keras converter. Keep them unless
        # conversion is re-verified without them.
        x = self.dropout(x.clone())
        cls_scores = self.linear2(x.clone())
        return cls_scores
def load_resnet34(n_classes, pretrained=False):
    """Build a ResNet-34 classifier and move it to the best available device.

    Args:
        n_classes: number of output classes passed to the model constructor.
        pretrained: if true, build ``ResNet34_Pretrained``; otherwise build
            ``ResNet34``.

    Returns:
        A ``(net, device)`` tuple, where ``device`` is the string ``'cuda'``
        when CUDA is available and ``'cpu'`` otherwise.
    """
    model_cls = ResNet34_Pretrained if pretrained else ResNet34
    model = model_cls(n_classes)

    use_cuda = torch.cuda.is_available()
    target = 'cuda' if use_cuda else 'cpu'
    model = model.to(target)

    if use_cuda:
        # Let cuDNN benchmark and pick convolution algorithms.
        # NOTE: torch.nn.DataParallel wrapping was intentionally left disabled.
        cudnn.benchmark = True
    return model, target
I also had to revert to pytorch2keras 0.1.14 and torch 0.4.1.
Describe the bug
I changed my ResNet code to make it more flexible for what I am trying to do, and the converter no longer works (see below for more information and a link to a related issue I posted earlier). Any ideas why this is the case?
To Reproduce
Logs
The error message:
Environment (please complete the following information):
Additional context
My previous bug report: https://github.com/nerox8664/pytorch2keras/issues/50 (I was using a similar model there).