idoiaruiz closed this issue 7 years ago
In our case we (Team 1) have addressed this issue by including the number of classes in the name of the last Dense layer, so that the same model used for different datasets has all the layers with the same name except the last one, which is bound to the number of classes.
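In other words, changing the following line in each model you create should solve the transfer learning issue:

```python
# Include the class count in the classifier's name so it differs between datasets
x = Dense(n_classes, name='dense{}'.format(n_classes))(x)
```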
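Here is a toy illustration of why the suffix helps. The snippet is plain Python, not Keras, and only mimics name-based weight loading as described in this thread: weights are matched by layer name, and a matching name with a different shape is an error. The layer names other than `dense{n}`, the shapes, and the class counts 45 and 62 are hypothetical placeholders, not the real TT100K/BTS values.

```python
# Toy model of name-based weight loading (plain Python, NOT the Keras code).
# A "model" is a list of (layer_name, weight_shape) pairs; all names other than
# the dense{n} classifier, and all shapes, are hypothetical placeholders.


def build_layers(n_classes):
    return [
        ('block1_conv1', (3, 3, 3, 64)),
        ('fc1', (25088, 4096)),
        # Classifier named after the class count, as in the Dense line above.
        ('dense{}'.format(n_classes), (4096, n_classes)),
    ]


def load_by_name(target_layers, saved):
    """Return the names of layers that receive weights from `saved`.

    `saved` maps layer name -> weight shape. Layers whose name is missing
    from `saved` are skipped; a name match with a different shape raises.
    """
    loaded = []
    for name, shape in target_layers:
        if name not in saved:
            continue  # e.g. 'dense62' when the file only contains 'dense45'
        if saved[name] != shape:
            raise ValueError('shape mismatch for layer {}'.format(name))
        loaded.append(name)
    return loaded


saved = dict(build_layers(45))  # "weights" trained on the first dataset
print(load_by_name(build_layers(62), saved))  # ['block1_conv1', 'fc1']
```

Only the shared feature layers are matched; the renamed classifier is simply skipped and can be trained from scratch on the new dataset.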
@idoiaruiz you are right! Thanks for pointing this out and for the pull request! Your proposed solution is correct and we can merge it. However, I prefer @santipuch590's solution, simply because it involves fewer changes to the code. Is that OK with you? If everybody agrees and one of you sends another PR with the alternative solution, we can merge it.
I've actually made a commit myself with @santipuch590's solution:
https://github.com/dvazquezcvc/mcv-m5/commit/e7a1ecb7ca0377cc9150cd5a20d48906c7647d12
I did it because other groups were asking about this.
If there is anything else, we can continue the discussion here, OK?
Thanks!
There is an issue when doing transfer learning between two different datasets with the same architecture. In our case, we train VGG on the TT100K dataset and then reuse these trained weights, up to a certain layer, for VGG on the BTS dataset. Weights are loaded only into layers with matching names, and since it is the same model, every layer has the same name in both cases, including the last classification layer. The number of classes differs between the two datasets, so the weight shapes of that layer do not match and an error is raised.
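The failure can be reproduced with a similar toy loader (plain Python, not Keras; the layer names, including `classifier`, the shapes, and the class counts are hypothetical). Because both models come from the same architecture, the classifier keeps the same name, so the loader tries to reuse it and the shapes clash:

```python
# Toy reproduction of the reported error (plain Python, NOT Keras).
# Both models use the same architecture, so the classifier has the same name
# ('classifier', a hypothetical name) regardless of the dataset.


def build_layers(n_classes):
    return [
        ('block1_conv1', (3, 3, 3, 64)),
        ('fc1', (25088, 4096)),
        ('classifier', (4096, n_classes)),
    ]


def load_by_name(target_layers, saved):
    """Match layers by name; raise if a matched layer's shape differs."""
    for name, shape in target_layers:
        if name in saved and saved[name] != shape:
            raise ValueError('{}: expected {}, file has {}'.format(
                name, shape, saved[name]))


saved = dict(build_layers(45))  # e.g. trained on the first dataset
try:
    load_by_name(build_layers(62), saved)  # second dataset, 62 classes
except ValueError as err:
    print(err)  # classifier: expected (4096, 62), file has (4096, 45)
```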
One workaround is to rename the last layer after training, but this is not a good solution, because it has to be done every time, in every model. I propose loading only the weights of the frozen layers into a copy of the model that contains only those layers, and then transferring those weights to the model we want to train. This avoids the error and works for any model. The option can be enabled with the new config parameter "different_datasets".
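Below is a plain-Python sketch of the proposed workaround, assuming the frozen layers are the first `n_frozen` layers of the model. It is not the actual code behind `different_datasets`; the `Layer` class, the `transfer_frozen` helper, the layer names, and the shapes are hypothetical stand-ins for Keras layers and their weights. Because the copy contains only the frozen layers, the classifier never takes part in the name-based load:

```python
# Sketch of "load frozen layers into a copy, then transfer" (plain Python, NOT Keras).
# Layer, transfer_frozen, layer names and shapes are hypothetical stand-ins.


class Layer:
    def __init__(self, name, shape):
        self.name = name
        self.shape = shape
        self.weights = None  # not loaded yet

    def set_weights(self, shape, weights):
        if shape != self.shape:
            raise ValueError('shape mismatch for layer {}'.format(self.name))
        self.weights = weights


def build_model(n_classes):
    return [Layer('block1_conv1', (3, 3, 3, 64)),
            Layer('fc1', (25088, 4096)),
            Layer('classifier', (4096, n_classes))]


def load_by_name(layers, saved):
    # `saved` maps layer name -> (shape, weights)
    for layer in layers:
        if layer.name in saved:
            layer.set_weights(*saved[layer.name])


def transfer_frozen(model, saved, n_frozen):
    # 1) A copy holding only the frozen layers: no classifier, so no mismatch.
    copy = [Layer(layer.name, layer.shape) for layer in model[:n_frozen]]
    load_by_name(copy, saved)
    # 2) Transfer the loaded weights layer by layer into the model to train.
    for src, dst in zip(copy, model):
        dst.set_weights(src.shape, src.weights)


# Weights saved after training on the first dataset (45 classes, illustrative).
saved = {layer.name: (layer.shape, 'w_' + layer.name) for layer in build_model(45)}

model = build_model(62)
transfer_frozen(model, saved, n_frozen=2)
print([layer.weights for layer in model])  # ['w_block1_conv1', 'w_fc1', None]
```

In Keras terms, step 2 would correspond to calling `dst.set_weights(src.get_weights())` for each pair of layers; the classifier of the new model keeps its fresh initialisation.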