PPPW / deep-learning-random-explore

Why is my learn.summary() wrong? #4

Closed austinmw closed 5 years ago

austinmw commented 5 years ago

Hi, thanks for creating these examples!

I used the cadene alexnet and dog/cat breed dataset to create a learner. learn.model prints the right model, but learn.summary() prints the wrong model head. Any idea what's going wrong?

pretrained='imagenet'

def alexnet_cadene(*args):
    model = pretrainedmodels.__dict__['alexnet'](pretrained=pretrained)
    sz = pretrainedmodels.pretrained_settings['alexnet']['imagenet']['input_size'][-1]
    data.sz = data.one_batch()[0].size()[-1]
    if data.sz != sz:
        raise ValueError(f'data size should be {sz} but is instead {data.sz}')    
    model.last_linear.out_features = data.c
    all_layers = list(model.children())
    model = nn.Sequential(all_layers[0], nn.Sequential(Flatten(), *all_layers[1:]))    
    return model

arch_summary(lambda _: alexnet_cadene()) # overall
arch_summary(lambda _: next(alexnet_cadene().children())) # body
arch_summary(lambda _: list(alexnet_cadene().children())[1]) # head

learn = create_cnn(data,alexnet_cadene,custom_head=children(alexnet_cadene())[1],metrics=error_rate,
                  split_on= lambda m: (m[0][0][6],m[1],m[1][7]))

get_groups(nn.Sequential(*learn.model[0][0], *learn.model[1]), learn.layer_groups)

print(learn.layer_groups)

print(learn.model) 
print(learn.summary()) # why is the model summary wrong?
# last linear layer says 1000 classes when my dataset has 37

PPPW commented 5 years ago

Hi @austinmw, as you probably found out, the features function that modify_alexnet defines in the Cadene pretrained models (models/torchvision_models.py) calls view on the tensor. That call is plain tensor code inside the forward pass rather than a module, so it is not included in children(), and this model can't be passed directly to create_cnn. torchvision's AlexNet works fine with fastai.
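
To illustrate the point about children(), here is a minimal sketch in plain PyTorch. TinyNet is a hypothetical stand-in rather than the Cadene AlexNet, and nn.Flatten assumes a reasonably recent PyTorch (the code in this thread uses fastai's Flatten for the same purpose). It shows that a view call inside forward() is lost when a model is rebuilt from its children, and that adding an explicit flatten module restores the behaviour.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a model that reshapes inside forward()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # plain tensor call, not a registered module
        return self.fc(x)


model = TinyNet().eval()
x = torch.randn(2, 3, 8, 8)

# children() yields only registered submodules: the conv and the linear layer.
child_types = [type(m).__name__ for m in model.children()]

# Rebuilding from children() drops the view(), so the shapes no longer line up.
naive = nn.Sequential(*model.children())
try:
    naive(x)
    naive_ok = True
except RuntimeError:
    naive_ok = False

# Putting an explicit flatten module between the two parts fixes the rebuild.
fixed = nn.Sequential(model.conv, nn.Flatten(), model.fc)
out_shape = tuple(fixed(x).shape)
same_output = torch.allclose(fixed(x), model(x))
```

The same idea is what the Flatten() in the alexnet_cadene function above is meant to supply between the body and the head.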

austinmw commented 5 years ago

@PPPW Thanks, I also found that I needed to replace the last linear layer completely rather than trying to modify its out_features value.
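
A small sketch of why editing out_features is not enough, assuming plain PyTorch and the 4096 -> 1000 shape of AlexNet's final layer. The attribute only feeds the printed representation; the weight tensor keeps its original shape, so the forward pass still produces 1000 outputs. This would be consistent with learn.model showing the new size while a summary based on actual output shapes still reports 1000.

```python
import torch
from torch import nn

n_classes = 37  # class count mentioned in this thread

layer = nn.Linear(4096, 1000)
layer.out_features = n_classes  # changes only the attribute used by repr()

printed = repr(layer)                                # mentions out_features=37
weight_shape_after_edit = tuple(layer.weight.shape)  # weights are untouched
out_after_edit = tuple(layer(torch.randn(2, 4096)).shape)

# Replacing the layer allocates new weights with the right output size.
replaced = nn.Linear(layer.in_features, n_classes)
out_after_replace = tuple(replaced(torch.randn(2, 4096)).shape)
```

The replacement layer starts from freshly initialised weights, so it needs to be trained on the new dataset.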