ankane / torch.rb

Deep learning for Ruby, powered by LibTorch

"Attention is all you need"-related stuff #27

Closed. orlando-labs closed this issue 3 years ago.

orlando-labs commented 3 years ago

Hi, @ankane. Please review the attention- and transformer-related code, and please consider merging it. Thanks.

ankane commented 3 years ago

Thanks @orlando-labs, really great work!! Found a few minor things - will update in a follow-up commit.

ankane commented 3 years ago

Does this look right to you? https://github.com/ankane/torch.rb/commit/26b1678f722c58362597dac366b526ec0d66aac6

Also, for Torch::NN::Transformer, is there a reason generate_square_subsequent_mask is a class method and aliased? In PyTorch 1.9.0, it looks like it's an instance method without an alias.
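For reference, here are the two shapes being compared (a sketch; exact signatures assumed):

require "torch"

# class-method form, as in this PR: no transformer instance needed
mask = Torch::NN::Transformer.generate_square_subsequent_mask(5)

# instance-method form, matching PyTorch 1.9.0
model = Torch::NN::Transformer.new
mask = model.generate_square_subsequent_mask(5)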

orlando-labs commented 3 years ago

The mask can be needed for several reasons (as I've experienced in my own work) without instantiating an entire transformer. The instance method is still present for full compatibility with the Python version.
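For example, something like this, where only a mask and an attention module are needed (a sketch; the MultiheadAttention call shape is assumed from PyTorch):

require "torch"

# causal mask for standalone attention, no full Transformer instantiated
mask = Torch::NN::Transformer.generate_square_subsequent_mask(4)

attn = Torch::NN::MultiheadAttention.new(512, 8)
q = Torch.rand(4, 1, 512) # (seq, batch, embed_dim)
out, weights = attn.call(q, q, q, attn_mask: mask)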

orlando-labs commented 3 years ago

And of course, thanks for the fresh look and for correcting my typos in the code paths not covered by test cases.

ankane commented 3 years ago

Think I'd prefer to keep the API consistent with Python in this case (no extra methods).

orlando-labs commented 3 years ago

The additional methods don't conflict with the Python API, and they bring real improvements when working with encoders and decoders separately, which I believe is worth stepping away from strict method matching.
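For example, an encoder-only setup still needs the causal mask (a sketch; constructor arguments and the mask keyword are assumed):

require "torch"

layer = Torch::NN::TransformerEncoderLayer.new(512, 8)
encoder = Torch::NN::TransformerEncoder.new(layer, 6)

src = Torch.rand(10, 32, 512) # (seq, batch, d_model)
mask = Torch::NN::Transformer.generate_square_subsequent_mask(10) # class-method form from this PR
out = encoder.call(src, mask: mask)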

ankane commented 2 years ago

Hey @orlando-labs, I've thought about this a bit more and still want to keep the API consistent, so I've removed the class method and aliases. Users who need the mask without instantiating a transformer can copy the method into their project.
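Roughly, the version to copy could look like this (a sketch that produces the same values as the PyTorch method: 0.0 where attention is allowed, -Infinity above the diagonal; the exact torch.rb calls are assumed):

require "torch"

# additive causal mask: 0.0 on and below the diagonal (attend),
# -Infinity strictly above it (future positions are masked out)
def generate_square_subsequent_mask(sz)
  Torch.triu(Torch.full([sz, sz], -Float::INFINITY, dtype: :float32), 1)
end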

I've also added a test for it (for #30) and will push a new release shortly.

orlando-labs commented 2 years ago

Hi, I still don't see this as an inconsistency. Because of the way Python implements instance methods, you can call one directly on the class, like this:

torch.nn.Transformer.generate_square_subsequent_mask(None, 1)

This doesn't involve instantiating anything other than the mask itself.
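Ruby, by contrast, checks the receiver's class, so the closest equivalent raises (a sketch; Ruby 2.7+ for bind_call):

require "torch"

Torch::NN::Transformer
  .instance_method(:generate_square_subsequent_mask)
  .bind_call(nil, 1)
# raises TypeError: bind argument must be an instance of Torch::NN::Transformer

That's why a class-level entry point is convenient on the Ruby side.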