mlverse / luz

Higher Level API for torch
https://mlverse.github.io/luz/

`predict`: How to get embedding layer? #130

Open talegari opened 1 year ago

talegari commented 1 year ago

Question: In this example, https://mlverse.github.io/luz/articles/examples/mnist-triplet.html, how do I get the embeddings after fitting the model with triplet loss?

dfalbel commented 1 year ago

You could do something like this:

dataset <- mnist_dataset(dir, transform = transform_to_tensor)
preds <- predict(fitted, dataset)

preds

Calling `predict()` simply calls the model's `forward` method with the model in eval mode.

dfalbel commented 1 year ago

Sorry, the above is wrong. Instead, you could modify the triplet model to something like:

# `net` is the embedding network defined earlier in the vignette
triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  # used during training: embed anchor, positive and negative, then compare
  loss = function(input, ...) {
    embeds <- lapply(input, self$embedding)
    self$criterion(
      embeds$anchor,
      embeds$positive,
      embeds$negative
    )
  },
  # used by predict(): return only the embeddings
  predict = function(x) {
    self$embedding(x)
  }
)

This adds a `predict` method that simply calls the embedding network. Then:

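library(torchvision)
# `dir` and `fitted` come from the vignette
dataset <- mnist_dataset(dir, transform = transform_to_tensor)
# uses the model's `predict` method: one embedding per image
preds <- predict(fitted, dataset)

preds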
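To sanity-check the modified module outside luz and without downloading MNIST, here is a minimal sketch. It assumes a hypothetical `toy_net` (a small two-layer MLP) in place of the vignette's `net`, and random tensors in place of real MNIST triplets; the only point is that `loss()` returns a scalar and `predict()` returns one `embedding_dim`-vector per input image.

```r
library(torch)

# Hypothetical stand-in for the vignette's `net`: flattens a 1 x 28 x 28
# image and maps it to an `embedding_dim`-dimensional vector.
toy_net <- nn_module(
  initialize = function(embedding_dim = 2) {
    self$fc1 <- nn_linear(28 * 28, 32)
    self$fc2 <- nn_linear(32, embedding_dim)
  },
  forward = function(x) {
    x <- torch_flatten(x, start_dim = 2)  # (N, 1, 28, 28) -> (N, 784)
    self$fc2(nnf_relu(self$fc1(x)))
  }
)

# Same structure as the triplet model in this thread, with toy_net swapped in
triplet_model <- nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    self$embedding <- toy_net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  loss = function(input, ...) {
    embeds <- lapply(input, self$embedding)
    self$criterion(embeds$anchor, embeds$positive, embeds$negative)
  },
  predict = function(x) {
    self$embedding(x)
  }
)

torch_manual_seed(1)
model <- triplet_model(embedding_dim = 2)

# Random tensors standing in for a batch of 4 MNIST triplets
triplets <- list(
  anchor   = torch_randn(4, 1, 28, 28),
  positive = torch_randn(4, 1, 28, 28),
  negative = torch_randn(4, 1, 28, 28)
)
loss <- model$loss(triplets)  # scalar training loss

# Embeddings for 4 new images, computed in eval mode without gradients
model$eval()
emb <- with_no_grad(model$predict(torch_randn(4, 1, 28, 28)))
print(emb$shape)
```

Keeping `loss()` and `predict()` as separate methods means training only ever sees triplets, while prediction works on ordinary single images.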

You can also access the embedding module directly from the fitted object, but in that case you have to manually put the model in eval mode, disable gradient tracking, and move the input tensors to the correct device:

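fitted$model$eval()  # eval mode: affects layers such as dropout and batch norm
# device holding the model's parameters ("cpu", "cuda", "mps", ...)
device <- fitted$model$parameters[[1]]$device
with_no_grad({
  # unsqueeze(1) adds the batch dimension the network expects
  fitted$model$embedding(dataset[1]$x$unsqueeze(1)$to(device = device))
})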
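The same manual steps (eval mode, no gradients, device transfer) can be extended from one image to a whole dataset. Below is a minimal sketch using a hypothetical helper `embed_all()` (not part of luz) that batches the data with torch's `dataloader()`; it assumes the first element of each dataset item is the input tensor, as with `mnist_dataset(transform = transform_to_tensor)`.

```r
library(torch)

# Hypothetical helper (not part of luz): returns an N x d tensor on the CPU
# with one embedding per item of `ds`. Assumes the first element of each
# item is the input tensor.
embed_all <- function(embedding_module, ds, batch_size = 256) {
  embedding_module$eval()  # eval mode (dropout off, batch norm uses running stats)
  # run on whatever device holds the module's parameters (cpu, cuda, mps, ...)
  device <- embedding_module$parameters[[1]]$device
  dl <- dataloader(ds, batch_size = batch_size, shuffle = FALSE)
  out <- list()
  with_no_grad({
    coro::loop(for (b in dl) {
      x <- b[[1]]$to(device = device)
      # move each batch of embeddings back to the CPU before collecting it
      out[[length(out) + 1]] <- embedding_module(x)$cpu()
    })
  })
  torch_cat(out, dim = 1)  # stack batches along the first (batch) dimension
}

# Toy usage: a linear "embedding" of 10 random 3-d inputs, in batches of 4
torch_manual_seed(1)
toy <- nn_linear(3, 2)
x <- torch_randn(10, 3)
emb <- embed_all(toy, tensor_dataset(x), batch_size = 4)
print(emb$shape)
```

With the objects from this thread, something like `embed_all(fitted$model$embedding, dataset)` should give one row per MNIST image; `shuffle = FALSE` keeps the rows in dataset order so they can be matched back to their labels.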

talegari commented 1 year ago

Thanks, Daniel. IMHO, most users will want the embeddings after training, so the `predict` method above should be added to the vignette to make it complete.