talegari opened this issue 1 year ago.
You could do something like this:
dataset <- mnist_dataset(dir, transform = transform_to_tensor)
preds <- predict(fitted, dataset)
preds
Calling predict just calls the forward method with the model in eval mode.
Sorry, the above is wrong. You could modify the triplet model to something like this:
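triplet_model <- torch::nn_module(
  initialize = function(embedding_dim = 2, margin = 1) {
    # `net` is the embedding network defined earlier in the example
    self$embedding <- net(embedding_dim = embedding_dim)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  # used during training: embed anchor, positive and negative, then compute the triplet loss
  loss = function(input, ...) {
    embeds <- lapply(input, self$embedding)
    self$criterion(
      embeds$anchor,
      embeds$positive,
      embeds$negative
    )
  },
  # used by predict(): return the embeddings themselves
  predict = function(x) {
    self$embedding(x)
  }
)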
This adds a predict method that just calls the embedding network. Then:
dataset <- mnist_dataset(dir, transform = transform_to_tensor)
preds <- predict(fitted, dataset)
preds
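A minimal sketch for checking what the two methods return without downloading MNIST. It keeps the structure of triplet_model above but, as an assumption for illustration only, swaps the example's net for a hypothetical nn_linear(4, 2) and feeds random tensors instead of images.

```r
library(torch)

# Same structure as triplet_model, with a hypothetical nn_linear(4, 2)
# standing in for `net` so the sketch runs on its own.
toy_triplet <- nn_module(
  initialize = function(margin = 1) {
    self$embedding <- nn_linear(4, 2)
    self$criterion <- nn_triplet_margin_loss(margin = margin)
  },
  loss = function(input, ...) {
    embeds <- lapply(input, self$embedding)  # lapply keeps the list names
    self$criterion(embeds$anchor, embeds$positive, embeds$negative)
  },
  predict = function(x) {
    self$embedding(x)
  }
)

m <- toy_triplet()
batch <- list(
  anchor = torch_randn(8, 4),
  positive = torch_randn(8, 4),
  negative = torch_randn(8, 4)
)
l <- m$loss(batch)            # scalar triplet loss, what training minimises
e <- m$predict(batch$anchor)  # one 2-dimensional embedding per row
```

The loss method reduces a whole triplet batch to a single number, while predict returns one embedding per input row, which is why the extra method is needed to get embeddings out of a model trained with triplet loss.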
You can also access the embedding module directly from the fitted object, but in this case you have to manually put the model in eval mode, disable gradient tracking, and move tensors to the correct device:
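fitted$model$eval()  # put the model in eval mode
with_no_grad({
  # unsqueeze(1) adds a batch dimension; "mps" is the Apple Silicon GPU,
  # use "cuda" or "cpu" to match the device the model is on
  fitted$model$embedding(dataset[1]$x$unsqueeze(1)$to(device = "mps"))
})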
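To embed a whole dataset rather than a single item, the manual route can be wrapped in a loop over a dataloader. The sketch below is a hypothetical helper, embed_all(), not a luz function. It assumes dataset items are lists with an x element (as in dataset[1]$x above) and reads the device from the model's first parameter instead of hard-coding "mps". A toy dataset and an nn_linear stand in for MNIST and the trained embedding network so the sketch runs on its own.

```r
library(torch)

# Hypothetical helper (not part of luz): embed every item of `ds` batch by
# batch and return one tensor with a row per item, in dataset order.
embed_all <- function(embedding, ds, batch_size = 256) {
  embedding$eval()                            # eval mode, as predict() would do
  device <- embedding$parameters[[1]]$device  # run where the weights live
  dl <- dataloader(ds, batch_size = batch_size, shuffle = FALSE)
  out <- list()
  with_no_grad({
    coro::loop(for (batch in dl) {
      # items are assumed to be lists with an `x` element
      out[[length(out) + 1]] <- embedding(batch$x$to(device = device))$cpu()
    })
  })
  torch_cat(out, dim = 1)
}

# Toy dataset standing in for mnist_dataset: n items of 4 features each
toy_ds <- dataset(
  initialize = function(n) {
    self$x <- torch_randn(n, 4)
  },
  .getitem = function(i) list(x = self$x[i, ]),
  .length = function() self$x$size(1)
)

ds <- toy_ds(10)
emb <- nn_linear(4, 2)
all_embeds <- embed_all(emb, ds, batch_size = 3)
```

With the objects from this thread it would be called as embed_all(fitted$model$embedding, dataset). The result has one row per image, in dataset order because shuffle = FALSE.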
Thanks Daniel.
IMHO, most people will need the embeddings after training, so the predict method above should be added to the vignette to make it complete.
Question: How do I get the embeddings after fitting the model with triplet loss in this example? https://mlverse.github.io/luz/articles/examples/mnist-triplet.html