NVIDIA / DIGITS

Deep Learning GPU Training System
https://developer.nvidia.com/digits
BSD 3-Clause "New" or "Revised" License

Loss layers in DIGITS #437

Closed: naranjuelo closed this issue 8 years ago

naranjuelo commented 8 years ago

Hi!

I want to train a model in DIGITS with a contrastive loss function, but I get an error telling me to add a loss layer. Is it possible to use a loss function other than "SoftmaxWithLoss"? Thank you!!
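For reference, the contrastive loss mentioned here compares the distance between two embeddings with a same/different label. The sketch below is a plain-Python restatement of the non-legacy formula documented for Caffe's ContrastiveLoss layer, E = 1/(2N) * sum(y*d^2 + (1-y)*max(margin - d, 0)^2). It is not DIGITS or Caffe code; the function name and argument layout are hypothetical.

```python
import math


def contrastive_loss(feats_a, feats_b, similar, margin=1.0):
    """Contrastive loss over a batch of embedding pairs (illustrative sketch).

    feats_a, feats_b: lists of equal-length feature vectors, one per pair
    similar: list of labels, 1 = same identity, 0 = different identity
    margin: dissimilar pairs farther apart than this contribute nothing
    """
    n = len(similar)
    total = 0.0
    for a, b, y in zip(feats_a, feats_b, similar):
        # Euclidean distance between the two embeddings of the pair
        d = math.sqrt(sum((x - z) ** 2 for x, z in zip(a, b)))
        if y:
            total += d * d                       # pull similar pairs together
        else:
            total += max(margin - d, 0.0) ** 2   # push dissimilar pairs past the margin
    return total / (2.0 * n)
```

A similar pair is penalized by its squared distance, while a dissimilar pair is only penalized while it sits inside the margin, so the network is not pushed to spread already-separated pairs further apart.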

gheinrich commented 8 years ago

For Caffe classification networks, DIGITS only supports SoftmaxWithLoss. You can create an "other" type of network if you wish to use other loss functions. Are you trying to train a siamese network perhaps?

naranjuelo commented 8 years ago

Yes, I want to train a siamese network for face verification (CASIA architecture).

naranjuelo commented 8 years ago

What do you mean by an "other" type of network? I am using DIGITS 2.0; I guess you are talking about features added in newer versions. Am I right?

gheinrich commented 8 years ago

Yes, the "other" network type is coming with DIGITS 3 and is already available on the master branch on GitHub. If you want to train non-classification networks, you will need to create your own LMDB dataset so that DIGITS can use it for training. I haven't tried training Siamese networks using DIGITS, but if you read this tutorial, that should get you started.
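One common way to lay out Siamese training data, used by the Caffe MNIST Siamese example, is to stack the two images of a pair as two channels of a single sample and attach a 1/0 "same class" label. The sketch below only builds those pairs in memory; writing them to the image and label LMDBs that DIGITS reads is not shown, and the function name and random pairing strategy are illustrative assumptions.

```python
import random


def make_pairs(images, labels, num_pairs, seed=0):
    """Build Siamese training pairs (illustrative sketch).

    images: list of 2-D images (lists of rows), all the same size
    labels: class label for each image
    Returns (pair, similar) tuples where `pair` stacks the two images along
    a leading channel axis (2 x H x W) and `similar` is 1 if both images
    share a label, else 0.
    """
    rng = random.Random(seed)  # seeded so the pairing is reproducible
    pairs = []
    for _ in range(num_pairs):
        i = rng.randrange(len(images))
        j = rng.randrange(len(images))
        pair = [images[i], images[j]]  # channel 0 = first image, channel 1 = second
        similar = 1 if labels[i] == labels[j] else 0
        pairs.append((pair, similar))
    return pairs
```

With purely random pairing and many classes, "same" pairs will be rare; a sampler that deliberately balances positive and negative pairs may train better.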

This is possibly a bit of a challenge, but you will need to:

Let us know how you're getting on! If you're willing to contribute a DIGITS tutorial once you're done that would be very much appreciated!

naranjuelo commented 8 years ago

Ok, I'm going to update DIGITS and try it, and I'll tell you if I get it to work. Thank you very much!!

naranjuelo commented 8 years ago

I started with a simple example (the Caffe MNIST Siamese tutorial) and it looks good. It converges, and if I plot the features of the test digits I get the following image:

[Image: siamesemnistfeatures, plot of the learned features of the MNIST test digits]

So now I'm trying to apply this to my face representation net, but I have a question about multiple input labels per image in DIGITS. I've read the regression tutorial, and I understand that providing an LMDB label database is the only way to feed vector-shaped labels (more than one value per image) to the net. But when I define the three loss layers of my architecture, how do I specify which value of the label vector (3 values per image) goes to each loss layer? Is that possible in DIGITS?
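Concretely, the label database for this architecture would hold one 3-value row per image pair, and what matters is that the column order is fixed. The sketch below assumes the order described later in this thread (identity of the first image, identity of the second image, same/not-same flag); the function name is hypothetical and writing the rows to LMDB is not shown.

```python
def label_rows(pairs):
    """Build one 3-value label row per image pair (illustrative sketch).

    pairs: list of (identity_a, identity_b) tuples
    Column order is fixed and must match the order in which the network
    slices the label blob:
      column 0 -> identity of the first image
      column 1 -> identity of the second image
      column 2 -> 1.0 if both images show the same identity, else 0.0
    """
    rows = []
    for id_a, id_b in pairs:
        same = 1.0 if id_a == id_b else 0.0
        rows.append([float(id_a), float(id_b), same])
    return rows
```

For example, `label_rows([(3, 3), (3, 7)])` yields one "same" row and one "different" row.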

In case it helps to understand my question, here is the architecture I've designed:

[Image: casia_net, diagram of the Siamese architecture]

gheinrich commented 8 years ago

Can you not use a slice layer to split your label vector into 3 different data streams? Something similar to what you're doing in your slice_pair layer but with three top layers (one for each loss layer) instead of two?

naranjuelo commented 8 years ago

I thought of that too, but I would not have control over which label is which, and I don't want to mix them up (the first corresponds to one image, the second to the other, and the third to the same/not-same class).

gheinrich commented 8 years ago

Your label vector of three scalars would be sliced up in a deterministic order: the first scalar is propagated to the first top layer, the second to the second, and so on. Does this not allow you to retain control of which label goes to each loss layer?
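To illustrate the ordering guarantee, the sketch below mimics in plain Python how a Caffe Slice layer configured along axis 1 (for example `slice_param { axis: 1 slice_point: 1 slice_point: 2 }` with three tops) hands consecutive column ranges to its tops in order. It only handles explicit slice points and is a model of the behaviour, not Caffe code; the function name is hypothetical.

```python
def slice_axis1(blob, slice_points):
    """Split an N x C blob along axis 1 into len(slice_points) + 1 tops.

    Top k receives columns [bounds[k], bounds[k + 1]), so the order of the
    label columns is preserved: top 0 gets the first column range, and so on.
    """
    bounds = [0] + list(slice_points) + [len(blob[0])]
    tops = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        tops.append([row[start:end] for row in blob])
    return tops


# Two pairs, each with [identity_a, identity_b, same] labels.
labels = [[3.0, 7.0, 0.0], [5.0, 5.0, 1.0]]
id_a, id_b, same = slice_axis1(labels, [1, 2])
```

Here `id_a` holds only the first-image identities, `id_b` only the second-image identities, and `same` only the same/not-same flags, so each top can feed its own loss layer without the labels getting mixed.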

naranjuelo commented 8 years ago

I didn't know the order was preserved in that way; then there shouldn't be any problem. I will try it, thank you!!

gheinrich commented 8 years ago

I have pushed a pull request (https://github.com/NVIDIA/DIGITS/pull/453) with a tutorial on how to train the Siamese network from the Caffe tutorial in DIGITS. I hope this helps.

naranjuelo commented 8 years ago

I've seen it, good work! Sorry for not having written the tutorial myself, but I don't have much time at the moment. I also see in the "slice_triplet" layer what you told me about the slice order. Thank you very much for your help!!

lukeyeager commented 8 years ago

The tutorial has been merged into the master branch. Please try it out and let us know if you run into any issues.