ankane / torch.rb

Deep learning for Ruby, powered by LibTorch

'Torch.normal' seems unavailable and some ideas #15

Closed: weili-git closed this issue 3 years ago

weili-git commented 3 years ago

Hello, this is a very interesting project, but I ran into the following problems.

1. Torch.normal(mean, std, out) seems unavailable.

When I call it the same way as in Python, it raises Torch::Error (This should never happen. Please report a bug with normal.). Working around it with Torch.randn() is inconvenient.

2. '==' operation for tensors

For example,

require "torch"

a = Torch.tensor([1, 2, 3])
b = Torch.tensor([1, 1, 3])
a == b  # expected: tensor([1, 0, 1])

3. A method to get a sub-matrix

require "torch"

a = Torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
p a[1..2][1..2]  # expected: tensor([[5, 6], [8, 9]])

For 3, I tried to solve it myself (https://github.com/weili-git/torch-rb-matplotlib-practices/blob/main/Extension.rb), but there are still some problems because I don't check whether the input is a matrix.

By the way, I am new to Ruby and Torch, and I ran into the questions above while implementing basic machine learning algorithms. I'd really appreciate any ideas on how to handle these.

ankane commented 3 years ago

Hey @weili-git, thanks for reporting!

For 1, try upgrading to the latest version of Torch.rb (0.4.1). If that doesn't solve it, can you share an example that causes the error?

For 2, you'll need to use a.eq(b) instead of a == b (the same goes for other comparison operations). I considered overriding == to call eq like Python, but it leads to unintuitive behavior in Ruby.

Python prints good:

import torch

if torch.tensor(0).eq(1):
  print('bad')
else:
  print('good')

Ruby prints bad:

require "torch"

if Torch.tensor(0).eq(1)
  puts "bad"
else
  puts "good"
end
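The root of the difference is Ruby truthiness: only nil and false are falsy, so any object, including a tensor that holds false, takes the true branch of an if. The plain-Ruby sketch below uses a hypothetical FakeBoolTensor class (not part of Torch.rb) to show the effect, and why unwrapping to a Ruby scalar first restores the expected branch. In Torch.rb that unwrapping is presumably done with item, which the error messages later in this thread show exists; treat its exact return type for bool tensors as an assumption.

```ruby
# Hypothetical stand-in for a one-element boolean tensor.
# This is NOT the Torch.rb Tensor class; it only mimics the truthiness issue.
class FakeBoolTensor
  def initialize(value)
    @value = value
  end

  # Mimics converting a one-element tensor to a Ruby scalar.
  def item
    @value
  end
end

# Returns "bad" for any truthy condition, "good" for nil/false.
def branch(condition)
  condition ? "bad" : "good"
end

t = FakeBoolTensor.new(false)

wrapped   = branch(t)      # the object itself is truthy, so this is "bad"
unwrapped = branch(t.item) # the plain Ruby false is falsy, so this is "good"

puts wrapped
puts unwrapped
```

This is why overriding == to return a tensor would be surprising in Ruby: code like `if a == b` would always take the true branch.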

For 3, the latest version should follow the same behavior as Python and return tensor([[7, 8, 9]]).

Edit: to return [[5, 6], [8, 9]], use a[1..2, 1..2].
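Chained indexing behaves the same way on plain Ruby nested arrays, which makes it easy to see why a[1..2][1..2] returns [[7, 8, 9]]: the second [1..2] indexes the list of rows again rather than the columns. The sketch below also includes a hypothetical submatrix helper (not a Torch.rb method) that adds the "is this a matrix?" check mentioned in point 3. It works on ordinary arrays only; for Torch.rb tensors the two-argument form a[1..2, 1..2] is the approach to use.

```ruby
# Plain Ruby nested arrays, not Torch tensors.
a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

rows    = a[1..2]    # => [[4, 5, 6], [7, 8, 9]]
chained = rows[1..2] # => [[7, 8, 9]] (selects rows again, not columns)

# Hypothetical helper: check that the input is a rectangular matrix
# (an array of equal-length arrays), then slice rows, then columns.
def submatrix(matrix, row_range, col_range)
  unless matrix.is_a?(Array) &&
         matrix.all? { |r| r.is_a?(Array) } &&
         matrix.map(&:size).uniq.size <= 1
    raise ArgumentError, "expected a rectangular matrix"
  end

  # matrix[row_range] is nil when the range starts past the end.
  (matrix[row_range] || []).map { |r| r[col_range] }
end

p submatrix(a, 1..2, 1..2) # => [[5, 6], [8, 9]]
```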

weili-git commented 3 years ago

Thank you so much! I upgraded from 0.3.0 to 0.4.1 and all of these are fixed.

ankane commented 3 years ago

No problem, let me know if you run into anything else weird.

The regression example with real-time chart updates is awesome, btw.

weili-git commented 3 years ago

Hey, it's me again. I've found a few more small problems:

1. Torch.save(net.state_dict, "net.pth") works fine, but Torch.save(net, "net.pth") raises:

in `to_ivalue': Unknown type: Torch::NN::Sequential (Torch::Error)

What should I do to save the full net along with its weights?

2. Torch::Optim::SGD.new(net_Momentum.parameters, lr: LR, momentum: 0.8)

in `item': only one element tensors can be converted to Ruby scalars (Torch::Error)

It seems that the momentum argument causes this error. It is raised when I call optimizer.step.

3. loader = Torch::Utils::Data::DataLoader.new(torch_dataset, batch_size: BATCH_SIZE, shuffle: true, num_workers: 2)

in `initialize': unknown keyword: num_workers (ArgumentError)

4. test_dataset = TorchVision::Datasets::MNIST.new(...)

I can't select a subset of the data with test_dataset[0..200], or just the image data with test_dataset[0..200, 0]. In Python, this is test_data.test_data[:200] and test_data.test_labels[:200]. Instead I get:

in `item': only one element tensors can be converted to Ruby scalars (Torch::Error)

ankane commented 3 years ago

Hey @weili-git, thanks for more feedback!

For 1, from my understanding, saving the state dict (which includes the weights) is recommended over saving the entire model (https://pytorch.org/tutorials/beginner/saving_loading_models.html). Can you explain more about what you're trying to do that's not supported with the state dict method?

For 2, fixed in fa31895ee571f4b534e2d175fb613d788cfb87e6.

For 3, num_workers isn't supported yet. I'm not sure when it will be, so I'd just remove it for now.

For 4, it looks like test_data and test_labels are deprecated in Python in favor of data and targets. I've added those methods in https://github.com/ankane/torchvision/commit/f6674e5f9058f766efcc52ff55b7278781a3d4e9.

You can get the fixes by updating your Gemfile to use GitHub:

gem 'torch-rb', github: 'ankane/torch.rb'
gem 'torchvision', github: 'ankane/torchvision'
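One detail when porting the Python slices from point 4: Python's [:200] stops before index 200, while Ruby's 0..200 includes it, so the two select different counts. The sketch below uses plain Ruby arrays with made-up placeholder values standing in for the dataset's data and targets; it does not touch Torch.rb. The same inclusive/exclusive range semantics presumably apply when indexing Torch.rb tensors, but that is an assumption here.

```ruby
# Hypothetical placeholders for data and targets (plain arrays, made-up values).
images = (0...1000).to_a
labels = images.map { |i| i % 10 }

inclusive = images[0..200]   # two dots: includes index 200, so 201 elements
exclusive = images[0...200]  # three dots: excludes 200, like Python's [:200]
first_n   = images.first(200) # another way to take the first 200 elements

label_subset = labels[0...200] # matching slice of the targets

puts inclusive.size    # 201
puts exclusive.size    # 200
puts label_subset.size # 200
```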

weili-git commented 3 years ago

Thanks a lot for your quick reply! You're right that saving the state_dict is preferred, and num_workers doesn't make much difference overall. I'll give you more feedback later. :+1: