Closed orlando-labs closed 2 years ago
I am thinking about an ActiveRecord-like Module#initialize_dup copy constructor that recursively duplicates named children and other instance variables, but there are some problems with duplicating tensors:
Torch::Tensor.new.dup
*** TypeError Exception: allocator undefined for Torch::Tensor
I noticed that Torch::NN::Parameter supports #clone, but it returns a Torch::Tensor. So, currently, I'm working with the following ugly monkey patch. Maybe it will provide some ideas for implementing this better inside the gem.
module Torch
  module NN
    module Utils
      def _clones(mod, n)
        state = mod.state_dict
        layers = n.times.map { mod.dup }
        ModuleList.new(layers)
      end
    end
    class Module
      def initialize_dup(other)
        cvars = %i[@parameters @modules @buffers]
        other.named_buffers.each do |name, buf|
          duplicate = if buf.is_a? Torch::NN::Parameter
            Torch::NN::Parameter.new(buf.clone)
          else
            buf.clone
          end
          register_buffer name, duplicate
        end
        other.instance_variable_get(:@parameters).each do |name, par|
          duplicate = if par.is_a? Torch::NN::Parameter
            Torch::NN::Parameter.new(par.clone)
          else
            par.dup
          end
          register_parameter name, duplicate
        end
        other.instance_variable_get(:@modules).each do |name, mod|
          add_module name, mod.dup
        end
        (other.instance_variables - cvars).each do |name|
          var = other.send(:instance_variable_get, name)
          duplicate = if var.is_a?(Torch::NN::Parameter)
            Torch::NN::Parameter.new(var.clone)
          elsif var.is_a?(Torch::Tensor)
            var.clone
          elsif var.is_a?(Method) or var.is_a?(UnboundMethod)
            var
          else
            var.dup
          end
          instance_variable_set(name, duplicate)
        end
      end
    end
  end
end
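For readers unfamiliar with the hook used above, here is a minimal plain-Ruby sketch of how overriding initialize_dup turns the default shallow dup into a recursive one. The Layer class and its fields are hypothetical stand-ins and are not part of torch.rb; arrays of floats play the role of tensors.

```ruby
# Hypothetical stand-in for a module holding weights and child modules.
class Layer
  attr_reader :weights, :children

  def initialize(weights, children = {})
    @weights = weights
    @children = children
  end

  # Ruby calls this on the new object after `dup` has shallow-copied
  # the instance variables, so we replace them with independent copies.
  def initialize_dup(other)
    super
    @weights = other.weights.dup
    # Each child is a Layer, so its own initialize_dup runs recursively.
    @children = other.children.transform_values(&:dup)
  end
end

inner = Layer.new([1.0, 2.0])
outer = Layer.new([3.0], { inner: inner })
copy = outer.dup

# Mutate only the copy, at both levels of the hierarchy.
copy.weights[0] = 9.0
copy.children[:inner].weights[0] = 7.0

p outer.weights                    # the original is unaffected
p outer.children[:inner].weights   # and so is its child
```

Without the override, copy.children would be the very same hash as outer.children, and the second assignment would change the original.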
Thanks @orlando-labs, should be fixed in the commit above. A few notes:
- deepcopy keeps a memo dictionary of copied objects. https://docs.python.org/3/library/copy.html
- parameter.clone() returns a Tensor

Took a few more commits. Think it's in a good spot, but could use some further testing.
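To illustrate the memo dictionary mentioned in the notes, here is a rough plain-Ruby sketch of a memoized deep copy. The deep_copy helper is hypothetical and is not the gem's implementation; it only handles arrays, hashes and strings. Its point is that an identity-keyed memo lets objects shared in the original stay shared in the copy, and keeps cycles from recursing forever.

```ruby
# Hypothetical helper: recursively copies arrays, hashes and strings,
# recording every copy in an identity-keyed memo.
def deep_copy(obj, memo = {}.compare_by_identity)
  # An object seen before maps to the copy made the first time.
  return memo[obj] if memo.key?(obj)

  case obj
  when Array
    copy = memo[obj] = [] # register before recursing, so cycles resolve
    obj.each { |item| copy << deep_copy(item, memo) }
    copy
  when Hash
    copy = memo[obj] = {}
    obj.each { |key, value| copy[deep_copy(key, memo)] = deep_copy(value, memo) }
    copy
  when String
    memo[obj] = obj.dup
  else
    obj # numbers, symbols and unknown objects are returned as-is here
  end
end

shared_weights = [1.0, 2.0]
model = { encoder: shared_weights, decoder: shared_weights }
model[:self] = model # a cycle

model_copy = deep_copy(model)

p model_copy[:encoder].equal?(shared_weights)        # copy is a new object
p model_copy[:encoder].equal?(model_copy[:decoder])  # sharing is preserved
```

A copy without the memo would produce two separate arrays for encoder and decoder, which matters when two layers are meant to tie their weights.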
Hi, @ankane. Torch::NN::Utils._clones (https://github.com/ankane/torch.rb/blob/b078d6397e249778eda45f07935e987078abda93/lib/torch/nn/utils.rb#L27) doesn't actually perform deep cloning, so the cloned modules reference the same parameter tensors. The Python package uses copy.deepcopy, which does the trick. The Ruby trick of marshaling and unmarshaling is not a solution here. What, in your opinion, can we implement to resolve the issue?
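The sharing problem described here can be reproduced in plain Ruby. In the sketch below, TinyModule is a hypothetical stand-in for a module with a parameter hash, not a torch.rb class. The comment does not say why marshaling is ruled out; the second half shows one plausible obstacle, assumed here, which is that Marshal raises for objects it cannot serialize.

```ruby
# Hypothetical stand-in for a module whose parameters live in a hash.
class TinyModule
  attr_reader :parameters

  def initialize
    @parameters = { "weight" => [0.0, 0.0] }
  end
end

original = TinyModule.new
clones = 3.times.map { original.dup }

# Every shallow copy points at the same parameter container.
shares_storage = clones.all? { |c| c.parameters.equal?(original.parameters) }

# Writing through one clone is therefore visible through another.
clones.first.parameters["weight"][0] = 1.0
leaked = clones.last.parameters["weight"][0]

# Marshal round-trips plain data, but raises TypeError for objects it
# cannot serialize, such as a Proc.
marshal_error =
  begin
    Marshal.load(Marshal.dump({ "hook" => proc { |x| x } }))
    nil
  rescue TypeError => e
    e.class
  end

p shares_storage
p leaked
p marshal_error
```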