ankane / torch.rb

Deep learning for Ruby, powered by LibTorch

Torch::NN::Utils._clones doesn't perform deep cloning #31

Closed by orlando-labs 2 years ago

orlando-labs commented 2 years ago

Hi @ankane. Torch::NN::Utils._clones (https://github.com/ankane/torch.rb/blob/b078d6397e249778eda45f07935e987078abda93/lib/torch/nn/utils.rb#L27) doesn't actually perform deep cloning, so the cloned modules reference the same parameter tensors. The Python package uses copy.deepcopy, which does the trick. The usual Ruby trick of marshaling and unmarshaling doesn't work here. What, in your opinion, could we implement to resolve this?
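To make the problem concrete without depending on torch.rb, here is a minimal pure-Ruby sketch. ToyModule is a hypothetical stand-in for Torch::NN::Module, and a plain Array stands in for a parameter tensor. Object#dup copies instance variables shallowly, so the copy's @parameters Hash is the very same object as the original's:

```ruby
# Hypothetical stand-in for Torch::NN::Module; an Array plays the role of a tensor.
class ToyModule
  attr_reader :parameters

  def initialize
    @parameters = { "weight" => [1.0, 2.0] }
  end
end

original = ToyModule.new
copy = original.dup

# The copy is a distinct object...
puts copy.equal?(original)                        # => false
# ...but its instance variables point at the same Hash and the same "tensor".
puts copy.parameters.equal?(original.parameters)  # => true

# Updating a weight in place through the copy also changes the original.
copy.parameters["weight"][0] = 99.0
puts original.parameters["weight"].inspect        # => [99.0, 2.0]
```

This is the same sharing that makes every module returned by _clones train one set of weights.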

orlando-labs commented 2 years ago

I'm thinking about an ActiveRecord-like Module#initialize_dup copy constructor that recursively duplicates the named children and other instance variables, but duplicating the tensors runs into problems:

Torch::Tensor.new.dup
*** TypeError Exception: allocator undefined for Torch::Tensor
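The copy-constructor idea depends on the order of Ruby's copy hooks: Object#dup allocates a new object, shallow-copies its instance variables, and only then calls initialize_dup(original) on the copy (which by default delegates to initialize_copy). A minimal pure-Ruby sketch, using a hypothetical CopyableModule in place of Torch::NN::Module and Arrays in place of tensors, shows why an override has to build fresh containers before re-registering duplicated entries, and how child modules are deep-copied by calling dup on them recursively. With real tensors, Array#dup would have to be replaced by something like Tensor#clone, since Tensor#dup fails as shown above.

```ruby
# Hypothetical module with parameters (Arrays standing in for tensors) and children.
class CopyableModule
  attr_reader :parameters, :modules

  def initialize
    @parameters = {}
    @modules = {}
  end

  def register_parameter(name, value)
    @parameters[name] = value
  end

  def add_module(name, mod)
    @modules[name] = mod
  end

  # Ruby calls this on the new object after shallow-copying the instance
  # variables, so @parameters and @modules still point at other's Hashes here.
  def initialize_dup(other)
    super
    @parameters = {}
    @modules = {}
    other.parameters.each { |name, value| register_parameter(name, value.dup) }
    # dup on a child runs this same hook, giving a recursive deep copy
    other.modules.each { |name, mod| add_module(name, mod.dup) }
  end
end

root = CopyableModule.new
root.register_parameter("weight", [1.0, 2.0])
child = CopyableModule.new
child.register_parameter("bias", [0.5])
root.add_module("layer", child)

copy = root.dup
copy.parameters["weight"][0] = 7.0
copy.modules["layer"].parameters["bias"][0] = 42.0

puts root.parameters["weight"].inspect                 # => [1.0, 2.0]
puts root.modules["layer"].parameters["bias"].inspect  # => [0.5]
```

Without the two resets at the top of initialize_dup, register_parameter and add_module would write into the original module's Hashes.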
orlando-labs commented 2 years ago

I noticed that Torch::NN::Parameter supports #clone, but it returns a Torch::Tensor. So for now I'm working with the following ugly monkey patch. Maybe it will give you some ideas for implementing this better inside the gem.

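module Torch
  module NN
    module Utils
      def _clones(mod, n)
        layers = n.times.map { mod.dup }
        ModuleList.new(layers)
      end
    end

    class Module
      # Called by #dup on the new object after Ruby has shallow-copied the
      # instance variables, so the container Hashes are still shared with
      # `other` at this point.
      def initialize_dup(other)
        cvars = %i[@parameters @modules @buffers]

        # Give the copy its own containers so the registrations below
        # don't overwrite entries in the original module's Hashes
        @parameters = {}
        @modules = {}
        @buffers = {}

        other.named_buffers.each do |name, buf|
          duplicate = if buf.is_a?(Torch::NN::Parameter)
            Torch::NN::Parameter.new(buf.clone)
          else
            buf.clone
          end
          register_buffer name, duplicate
        end

        other.instance_variable_get(:@parameters).each do |name, par|
          # Parameter#clone returns a plain Torch::Tensor, so re-wrap it
          duplicate = if par.is_a?(Torch::NN::Parameter)
            Torch::NN::Parameter.new(par.clone)
          else
            par.dup
          end
          register_parameter name, duplicate
        end

        # Recursively duplicate child modules
        other.instance_variable_get(:@modules).each do |name, mod|
          add_module name, mod.dup
        end

        # Copy every remaining instance variable (parameters stored as ivars, config, etc.)
        (other.instance_variables - cvars).each do |name|
          var = other.instance_variable_get(name)
          duplicate = if var.is_a?(Torch::NN::Parameter)
            Torch::NN::Parameter.new(var.clone)
          elsif var.is_a?(Torch::Tensor)
            var.clone # Tensor#dup raises "allocator undefined"
          elsif var.is_a?(Method) || var.is_a?(UnboundMethod)
            var
          else
            var.dup
          end
          instance_variable_set(name, duplicate)
        end
      end
    end
  end
end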
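The patch duplicates every tensor independently, so a parameter reachable from two places, such as a weight tied between two layers, would come out of dup as two unrelated copies. Python's copy.deepcopy avoids that by threading a memo dictionary keyed by object identity through the whole traversal. The sketch below shows the same idea in plain Ruby. ToyTensor, ToyParameter, copy_data, ToyLayer and deep_copy are all hypothetical names, with copy_data standing in for Tensor#clone (which, as noted above, returns a plain Torch::Tensor even for a Parameter). It illustrates the technique and is not torch.rb's actual fix:

```ruby
# Hypothetical stand-ins: copy_data returns a plain ToyTensor even when called
# on a ToyParameter, mirroring how Parameter#clone returns a Torch::Tensor.
class ToyTensor
  attr_reader :data

  def initialize(data)
    @data = data
  end

  def copy_data
    ToyTensor.new(@data.dup)
  end
end

class ToyParameter < ToyTensor; end

# Deep-copies obj, reusing copies already made (memo is keyed by object identity,
# like the memo dict in Python's copy.deepcopy) so shared references stay shared.
def deep_copy(obj, memo = {}.compare_by_identity)
  return memo[obj] if memo.key?(obj)

  case obj
  when ToyParameter
    # copy_data loses the subclass, so re-wrap the copied data as a parameter
    memo[obj] = ToyParameter.new(obj.copy_data.data)
  when ToyTensor
    memo[obj] = obj.copy_data
  when Hash
    copy = memo[obj] = {}
    obj.each { |key, value| copy[key] = deep_copy(value, memo) } # keys are shared
    copy
  when Array
    copy = memo[obj] = []
    obj.each { |value| copy << deep_copy(value, memo) }
    copy
  when Integer, Float, Symbol, NilClass, TrueClass, FalseClass, Method, UnboundMethod
    obj # immutable or not meaningfully copyable: share as-is
  else
    # Register the shallow copy before recursing so cycles terminate
    copy = memo[obj] = obj.dup
    obj.instance_variables.each do |name|
      copy.instance_variable_set(name, deep_copy(obj.instance_variable_get(name), memo))
    end
    copy
  end
end

# Hypothetical layers sharing (tying) one weight parameter.
class ToyLayer
  attr_reader :weight

  def initialize(weight)
    @weight = weight
  end
end

shared = ToyParameter.new([1.0, 2.0])
model = { "encoder" => ToyLayer.new(shared), "decoder" => ToyLayer.new(shared) }

copy = deep_copy(model)
copy["encoder"].weight.data[0] = 5.0

puts copy["decoder"].weight.data.inspect  # => [5.0, 2.0]
puts shared.data.inspect                  # => [1.0, 2.0]
puts copy["encoder"].weight.class         # => ToyParameter
```

Because the memo is consulted before anything is copied, both layers in the copy point at the same new ToyParameter, while the original shared weight is left untouched.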
ankane commented 2 years ago

Thanks @orlando-labs, should be fixed in the commit above. A few notes:

ankane commented 2 years ago

Took a few more commits. Think it's in a good spot, but could use some further testing.