lintao185 opened this issue 8 months ago
I suspect this API is the reason not all the modules are recognized: this.add_module(nameof(model), this.model);
When the elements of ModuleList are nn.Module<Tensor, Tensor> instead of nn.Sequential, the state_dict of the ModuleList elements can be captured correctly.
When calling model.add_module, it traverses its sub-items and calls RegisterComponents() for each sub-item. However, Sequential does not perform such an operation. As a temporary measure, force the call to RegisterComponents() in the custom model after initialization.
Comparing Sequential and ModuleList, it can be observed that ModuleList overrides RegisterComponents, whereas Sequential does not. Because ModuleList overrides RegisterComponents, it gains the ability to automatically invoke the RegisterComponents of its child items.
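To make the difference concrete, here is a self-contained toy model of the behavior described above. ToyModule, ToyModuleList, ToySequential, Leaf, and Parent are hypothetical stand-ins, not TorchSharp types. The one assumption carried over from this thread is that registering a child through the base-class helper (TorchSharp's register_module) also invokes the child's RegisterComponents.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for torch.nn.Module; not the TorchSharp API.
public abstract class ToyModule
{
    protected readonly Dictionary<string, ToyModule> Children = new();
    protected readonly Dictionary<string, float[]> Buffers = new();

    // Assumption from the thread: registering a child also triggers the child's RegisterComponents.
    protected void RegisterModule(string name, ToyModule child)
    {
        child.RegisterComponents();
        Children[name] = child;
    }

    // Base behavior registers nothing; subclasses override it.
    protected virtual void RegisterComponents() { }

    // Flattened, state_dict-style keys such as "cv2.0.stride".
    public IEnumerable<string> StateDictKeys(string prefix = "") =>
        Buffers.Keys.Select(k => prefix + k)
            .Concat(Children.SelectMany(c => c.Value.StateDictKeys(prefix + c.Key + ".")));
}

// Mirrors the ModuleList override quoted in this thread.
public sealed class ToyModuleList : ToyModule
{
    private readonly List<ToyModule> _list;
    private bool _registered;

    public ToyModuleList(params ToyModule[] items) => _list = items.ToList();

    protected override void RegisterComponents()
    {
        if (_registered) return;
        for (int i = 0; i < _list.Count; i++) RegisterModule($"{i}", _list[i]);
        _registered = true;
    }
}

// Records its children directly and keeps the inherited no-op RegisterComponents,
// so the children's own RegisterComponents is never triggered.
public sealed class ToySequential : ToyModule
{
    public ToySequential(params ToyModule[] items)
    {
        for (int i = 0; i < items.Length; i++) Children[$"{i}"] = items[i];
    }
}

// Like BModel: owns one buffer and optionally registers it in its own constructor.
public sealed class Leaf : ToyModule
{
    public Leaf(bool selfRegister) { if (selfRegister) RegisterComponents(); }
    protected override void RegisterComponents() => Buffers["stride"] = new float[4];
}

// Like AModel: registers cv1 and cv2 through the base-class helper.
public sealed class Parent : ToyModule
{
    public Parent(bool leavesSelfRegister)
    {
        RegisterModule("cv1", new ToySequential(new Leaf(leavesSelfRegister), new Leaf(leavesSelfRegister)));
        RegisterModule("cv2", new ToyModuleList(new Leaf(leavesSelfRegister), new Leaf(leavesSelfRegister)));
    }
}
```

With new Parent(false), only cv2.0.stride and cv2.1.stride appear, matching the missing cv1 entries reported in this thread; with new Parent(true), all four keys appear. A standalone ToyModuleList reports no keys at all, because nothing ever calls its RegisterComponents.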
Is there a smallish repro case that I can debug?
using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

var a = new AModel();
var aStateDict = a.state_dict();

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

// In this version, BModel does NOT call RegisterComponents itself.
public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;
    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
    }
    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
In this case, the state_dict of cv1 cannot be obtained.
using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

var a = new AModel();
var aStateDict = a.state_dict();

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

// Unlike the first version, BModel now calls RegisterComponents at the end of its constructor.
public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;
    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
        RegisterComponents();
    }
    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
In this case, the state_dict of cv1 can be obtained.
So I guess the problem is that Sequential.RegisterComponents won't call RegisterComponents on its submodules.
Sequential and ModuleList implement RegisterComponents differently.
I suppose the issue could be solved simply by adding a call to the submodule's RegisterComponents in Sequential.Add.
However, in my opinion, all modules should always call RegisterComponents themselves (or register their modules, parameters, and buffers in other ways), so there is no need to deal with the submodules, because they will handle it on their own.
But it seems not... Even ModuleList and Sequential are not doing that... I'm a bit confused now... This makes it impossible to use:
using static TorchSharp.torch.nn;
var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0
And RegisterComponents is protected, so I have to create a wrapping module outside? I believe something has been ill-designed here...
protected override void RegisterComponents()
{
    if (_registered) return;
    for (int i = 0; i < _list.Count; i++) {
        register_module($"{i}", _list[i]);
    }
    _registered = true;
}
This is the implementation of ModuleList, and I think adding similar code in Sequential should be able to fix it.
However, you will still find that Sequential(BModel()).state_dict() doesn't work, since Sequential's RegisterComponents is never called on its own, so your models' RegisterComponents is not invoked either. That's probably because Sequential allows modules to be appended dynamically, so they have to be registered dynamically instead of by a single call to RegisterComponents.
So one solution might be to call the submodule's RegisterComponents in Sequential.Add. However, that might cause RegisterComponents to be called too early, especially when the submodules are also mutable. I'm not sure what the expected behavior should be.
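The "too early" concern can be reproduced with a small toy sketch. ToyModule, EagerSequential, and MutableLeaf are hypothetical types, not TorchSharp's; EagerSequential.Add registers each child the moment it is added, going through the base-class helper, which, as assumed from this thread, triggers the child's RegisterComponents.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for torch.nn.Module; not the TorchSharp API.
public abstract class ToyModule
{
    public readonly Dictionary<string, ToyModule> Children = new();
    public readonly Dictionary<string, float[]> Buffers = new();

    // Assumption from the thread: registering a child also triggers its RegisterComponents.
    protected void RegisterModule(string name, ToyModule child)
    {
        child.RegisterComponents();
        Children[name] = child;
    }

    protected virtual void RegisterComponents() { }
}

// Hypothetical Sequential variant that registers each child as soon as it is added.
public sealed class EagerSequential : ToyModule
{
    public void Add(ToyModule child) => RegisterModule($"{Children.Count}", child);
}

// A mutable child: Weight may be assigned after construction.
public sealed class MutableLeaf : ToyModule
{
    public float[]? Weight;

    // Registers whatever is present at the moment it is called.
    protected override void RegisterComponents()
    {
        if (Weight is not null) Buffers["weight"] = Weight;
    }
}
```

A MutableLeaf added before its Weight is assigned ends up with no registered buffer, because the registration snapshot was taken at Add time; one whose Weight is set before Add is registered as expected.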
And let me repeat my suggestion. Perhaps all modules should always call RegisterComponents themselves, and parent modules should not care about registering the sub-submodules of their submodules (that is, they should not call RegisterComponents on their submodules).
We can call RegisterComponents once in the top-level model, so that other models will be registered automatically, which is the most convenient.
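The self-registration protocol suggested above, where every module registers its own components and no parent reaches into its children, can be sketched with hypothetical toy types (SelfRegisteringModule, Block, and SelfRegisteringList are not TorchSharp classes). Each constructor finishes by registering what it owns, so a container behaves the same whether it is used standalone or nested.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical base class; registration never recurses into children.
public abstract class SelfRegisteringModule
{
    private readonly Dictionary<string, SelfRegisteringModule> _children = new();
    private readonly Dictionary<string, float[]> _buffers = new();

    protected void RegisterModule(string name, SelfRegisteringModule child) => _children[name] = child;
    protected void RegisterBuffer(string name, float[] value) => _buffers[name] = value;

    // Flattened, state_dict-style keys; children were registered by themselves.
    public IEnumerable<string> StateDictKeys(string prefix = "") =>
        _buffers.Keys.Select(k => prefix + k)
            .Concat(_children.SelectMany(c => c.Value.StateDictKeys(prefix + c.Key + ".")));
}

// Leaf module: registers its own buffer as the last step of its constructor.
public sealed class Block : SelfRegisteringModule
{
    private readonly float[] _stride = new float[4];
    public Block() => RegisterBuffer("stride", _stride);
}

// Container: registers its children when constructed, so it also works standalone.
public sealed class SelfRegisteringList : SelfRegisteringModule
{
    public SelfRegisteringList(params SelfRegisteringModule[] items)
    {
        for (int i = 0; i < items.Length; i++) RegisterModule($"{i}", items[i]);
    }
}
```

Used on its own, new SelfRegisteringList(new Block(), new Block()) reports 0.stride and 1.stride; nested inside another module, the same keys appear under that module's prefix, and no parent ever calls into a child's registration.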
ahh... Sequential cannot even access the RegisterComponents of its submodules, since it's protected. Now I have no idea how to implement it without breaking other things... (ModuleList uses register_module, but currently Sequential does not; it keeps a List<torch.nn.IModule<Tensor, Tensor>> instead.)
I wonder if this could serve as a relatively good solution.
I think this could work without side effects... but... humm... I can't say...
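protected override void RegisterComponents()
{
    foreach (var module in this._modules) {
        // register_module calls the child's RegisterComponents as a side effect...
        this.register_module("sub", (nn.Module)module);
        // ...then the temporary "sub" entry is discarded again,
        // since Sequential keeps its children in _modules instead.
        _internal_submodules.Clear();
    }
}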
To facilitate the use of pre-trained weights in TorchSharp, it is advisable to maintain consistency with PyTorch as much as possible.
That is what I mean. A module should register its parameters by itself. In PyTorch this is done by the module's __setattr__ and __getattr__. However, that's impossible in C#, so there is RegisterComponents. If you want a module to behave like PyTorch, it should always call RegisterComponents in its constructor, rather than leaving it to be called by others.
In other words, every module should be usable on its own, instead of being required to be part of another module. In PyTorch, __setattr__ and __getattr__ deal with that automatically. But in C#, if you don't call RegisterComponents, it can't work correctly.
Umm... Perhaps the best solution would be a source generator?
I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.
Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. Thus no values have been assigned to the properties and fields yet, so we can't register them...
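The ordering problem can be checked with plain C#: the base constructor finishes before the body of the derived constructor runs, so anything the base inspects through a virtual call still has its default value. BaseModule and MyModel are hypothetical names used only for this demonstration. The sketch assigns the field in the constructor body, which is where module fields built from constructor arguments are normally set; C# field initializers, by contrast, run before the base constructor.

```csharp
using System.Collections.Generic;

// Hypothetical base that tries to "register" components from its own constructor.
public abstract class BaseModule
{
    public readonly List<string> SeenInBaseCtor = new();
    public readonly List<string> SeenAtEnd = new();

    // Runs before MyModel's constructor body, so the override sees unassigned fields.
    protected BaseModule() => SeenInBaseCtor.AddRange(DescribeComponents());

    // Derived classes report which of their components are currently non-null.
    protected abstract IEnumerable<string> DescribeComponents();

    // What a derived constructor calls as its last statement.
    protected void RegisterComponentsLast() => SeenAtEnd.AddRange(DescribeComponents());
}

public sealed class MyModel : BaseModule
{
    private float[]? _weight;

    public MyModel()
    {
        _weight = new float[8];   // only happens after BaseModule() has returned
        RegisterComponentsLast();  // so registration must come last, as in TorchSharp's guidance
    }

    protected override IEnumerable<string> DescribeComponents()
    {
        if (_weight is not null) yield return "weight";
    }
}
```

For new MyModel(), SeenInBaseCtor stays empty while SeenAtEnd contains "weight", which is why registration has to be triggered at the end of the derived constructor.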
Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.
Yes, I suppose Fody or a source generator could be a beautiful solution. And that way we could easily expose properties instead of fields, which is also great.
@NiklasGustafsson Could you please take a look at this?
I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister
PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(
I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.
Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. Thus no values have been assigned to the properties and fields yet, so we can't register them...
That is exactly right. That's why RegisterComponents exists and needs to be called last in the (custom) module constructor.
Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.
That capability already exists. For example, in the rewrite we're working on for some of the standard modules, which will enable more attributes to be exposed, the parameters of Linear are defined as:
const string WeightComponentName = nameof(weight);
const string BiasComponentName = nameof(bias);

public Parameter? bias {
    get => _bias;
    set {
        _bias?.Dispose();
        _bias = value?.DetachFromDisposeScope() as Parameter;
        ConditionallyRegisterParameter(BiasComponentName, _bias);
    }
}

public Parameter weight {
    get => _weight!;
    set {
        if (value is null) throw new ArgumentNullException(nameof(weight));
        if (value.Handle != _weight?.Handle) {
            _weight?.Dispose();
            _weight = (value.DetachFromDisposeScope() as Parameter)!;
            ConditionallyRegisterParameter(WeightComponentName, _weight);
        }
    }
}

[ComponentName(Name = BiasComponentName)]
private Parameter? _bias;
[ComponentName(Name = WeightComponentName)]
private Parameter? _weight;
Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).
Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules
The exception to this rule will be modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component; this allows you to assign it 'null' as well as overwrite an already existing parameter.
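The setter pattern in the Linear excerpt can be reduced to a self-contained sketch. ToyLinear, ConditionallyRegister, and the dictionary-backed registry are hypothetical simplifications, and the semantics used here (null removes the registration, a non-null value adds or overwrites it under the same name) are an assumption about what conditional registration means, not a description of TorchSharp's ConditionallyRegisterParameter.

```csharp
using System.Collections.Generic;

// Hypothetical simplification of a module whose parameters can be replaced after construction.
public sealed class ToyLinear
{
    private const string WeightName = "weight";
    private const string BiasName = "bias";

    // Stand-in for the module's parameter registry (what state_dict would read).
    public readonly Dictionary<string, float[]> Parameters = new();

    private float[] _weight = null!;
    private float[]? _bias;

    public ToyLinear(int inFeatures, int outFeatures, bool hasBias = true)
    {
        weight = new float[inFeatures * outFeatures];
        bias = hasBias ? new float[outFeatures] : null;
    }

    // Weight may be replaced but never removed.
    public float[] weight {
        get => _weight;
        set {
            _weight = value ?? throw new System.ArgumentNullException(nameof(weight));
            ConditionallyRegister(WeightName, _weight);
        }
    }

    // Bias may be replaced or removed by assigning null.
    public float[]? bias {
        get => _bias;
        set {
            _bias = value;
            ConditionallyRegister(BiasName, _bias);
        }
    }

    // Assumed semantics: null removes the entry, non-null adds or overwrites it.
    private void ConditionallyRegister(string name, float[]? value)
    {
        if (value is null) Parameters.Remove(name);
        else Parameters[name] = value;
    }
}
```

Assigning lin.bias = null drops the bias entry from the registry, and assigning a new array to lin.weight replaces the registered weight under the same name.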
I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister
PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(
As much as I dislike relying on reflection, which the current scheme does (I dislike it because it prevents AOT), having to use source code generation adds complexity and something that has to be automated. That would be a last resort, I think.
The current scheme works fairly well as long as you follow the instructions very closely and don't do advanced stuff like the Linear module above. I don't know why you would allow setting the parameters after the module has been constructed, but PyTorch does, so TorchSharp should, too.
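For context on the reflection point, a RegisterComponents-style method can discover a module's fields at run time instead of having them listed by hand. The following is a hypothetical sketch of that general idea, not TorchSharp's implementation: ReflectiveModule, Net, and RegisteredNames are made up for illustration, and plain float[] fields stand in for tensors.

```csharp
using System.Collections.Generic;
using System.Reflection;

// Hypothetical base: registers fields found by reflection (TorchSharp's real logic differs).
public abstract class ReflectiveModule
{
    public readonly List<string> RegisteredNames = new();

    protected void RegisterComponents()
    {
        const BindingFlags flags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
        foreach (FieldInfo field in GetType().GetFields(flags))
        {
            // Only non-null float[] fields count as "tensors" in this toy.
            if (field.FieldType == typeof(float[]) && field.GetValue(this) is not null)
                RegisteredNames.Add(field.Name);
        }
    }
}

public sealed class Net : ReflectiveModule
{
    private readonly float[] stride = new float[4];
    public float[] stride2;

    public Net()
    {
        stride2 = new float[4];
        RegisterComponents();   // must run last, after every field is assigned
    }
}
```

new Net() registers stride and stride2 without naming them anywhere. The convenience comes from run-time reflection over fields, which is the part described above as getting in the way of AOT.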
So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?
Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).
Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules
The exception to this rule will be modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component; this allows you to assign it 'null' as well as overwrite an already existing parameter.
So my understanding is that custom modules should not rely on others calling their RegisterComponents. But then why are we doing that in register_module? I think this may be misleading.
So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?
We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.
[...] why we are doing that in register_module?
As unsatisfying as this answer is -- I don't recall.
haha... Perhaps we should check all the modules provided by TorchSharp and remove this call if nothing depends on it? I believe this should be done as early as possible, to keep more projects from relying on it by mistake.
oh well, there is at least one thing (ModuleList) that depends on that:
using static TorchSharp.torch.nn;
var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0
I suppose it should be modified to behave similarly to Sequential.
So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?
We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.
For now, I call RegisterComponents() in the custom module; it just feels a bit verbose, haha.
In TorchSharp, I defined a model that contains an nn.Sequential named cv4. However, when I obtained the state_dict of the entire model, the entries for cv4 were missing, which is very strange. Other parts of the model also contain nn.Sequential, and their entries are all obtained correctly, but not those of the last layer.