dotnet / TorchSharp

A .NET library that provides access to the library that powers PyTorch.
MIT License

Exception occurred when getting the state dictionary(state_dict). #1272

Open lintao185 opened 8 months ago

lintao185 commented 8 months ago

In TorchSharp, I defined a model that contains an nn.Sequential named cv4. However, when I obtained the state_dict of the entire model, the entries for cv4 were missing, which is very strange. Other submodules also contain nn.Sequential, and their entries are all present; only those of the last layer are missing.

lintao185 commented 8 months ago

I suspect it's because of this API that it fails to recognize all the modules:

this.add_module(nameof(model), this.model);

lintao185 commented 8 months ago

When the elements of ModuleList are nn.Module<Tensor, Tensor> instead of nn.Sequential, the state_dict of the ModuleList elements can be captured correctly.

lintao185 commented 8 months ago

When model.add_module is called, it traverses the module's sub-items and calls RegisterComponents() on each of them. Sequential, however, does not do this. As a temporary workaround, force a call to RegisterComponents() in each custom model after its fields are initialized.

lintao185 commented 8 months ago

Comparing Sequential and ModuleList, it can be observed that ModuleList overrides RegisterComponents, whereas Sequential does not. Because ModuleList overrides RegisterComponents, it gains the ability to automatically invoke the RegisterComponents of its child items.

NiklasGustafsson commented 8 months ago

Is there a smallish repro case that I can debug?

lintao185 commented 8 months ago

using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

// Top-level statements must come before the type declarations.
var a = new AModel();
var aStateDict = a.state_dict();

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;

    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;

    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
        // Note: RegisterComponents() is not called here.
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}

In this case, the state_dict of cv1 cannot be obtained.

using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

// Top-level statements must come before the type declarations.
var a = new AModel();
var aStateDict = a.state_dict();

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;

    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;

    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
        // The only change from the first version: BModel registers itself.
        RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}

In this case, the state_dict of cv1 can be obtained.

yueyinqiu commented 8 months ago

So I guess the problem is that Sequential.RegisterComponents won't call RegisterComponents on its submodules.

lintao185 commented 8 months ago

Sequential and ModuleList implement RegisterComponents differently.

yueyinqiu commented 8 months ago

I suppose the issue could be simply solved by adding a call to the submodule's RegisterComponents in Sequential.Add.

However, in my opinion, every module should always call RegisterComponents itself (or register its modules, parameters, and buffers in some other way), so there is no need to deal with the submodules because they will do that on their own.

But that does not seem to be the case... Even ModuleList and Sequential are not doing that... I'm a bit confused now... This makes it impossible to use:

using static TorchSharp.torch.nn;

var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0

And RegisterComponents is protected, so I have to create a wrapping module around it? I believe something has been ill designed here...

lintao185 commented 8 months ago

protected override void RegisterComponents()
{
    if (_registered) return;

    for (int i = 0; i < _list.Count; i++) {
        register_module($"{i}", _list[i]);
    }
    _registered = true;
}

This is the implementation of ModuleList, and I think adding similar code in Sequential should be able to fix it.

yueyinqiu commented 8 months ago

> protected override void RegisterComponents()
> {
>     if (_registered) return;
>
>     for (int i = 0; i < _list.Count; i++) {
>         register_module($"{i}", _list[i]);
>     }
>     _registered = true;
> }
>
> This is the implementation of ModuleList, and I think adding similar code in Sequential should be able to fix it.

However, you will still find that Sequential(BModel()).state_dict() does not work, since Sequential's RegisterComponents is never called by Sequential itself, so your models' RegisterComponents is not invoked either. That's probably because Sequential allows modules to be appended dynamically, so they have to be registered dynamically instead of through a single RegisterComponents call.

So one solution might be to call the submodule's RegisterComponents in Sequential.Add. However it might make RegisterComponents be called too early, especially when the submodules are also mutable. I'm not sure what the expected behavior should be.

And let me repeat my suggestion. Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

lintao185 commented 8 months ago

We can call RegisterComponents once in the top-level model, so that other models will be registered automatically, which is the most convenient.

yueyinqiu commented 8 months ago

ahh... Sequential cannot even access the submodules' RegisterComponents, since it's protected. Now I have no idea how to implement it without breaking other things... (ModuleList uses register_module, but Sequential currently does not; it keeps a List<torch.nn.IModule<Tensor, Tensor>> instead.)

lintao185 commented 8 months ago

[image] I wonder if this could serve as a relatively good solution.

yueyinqiu commented 8 months ago

I think this could work without side effects... but... humm... I can't say...

protected override void RegisterComponents()
{
    foreach(var module in this._modules) {
        this.register_module("sub", (nn.Module)module);
        _internal_submodules.Clear();
    }
}

lintao185 commented 8 months ago

To facilitate the use of pre-trained weights in TorchSharp, it is advisable to maintain consistency with PyTorch as much as possible.

yueyinqiu commented 8 months ago

> To facilitate the use of pre-trained weights in PyTorch, it is advisable to maintain consistency with PyTorch as much as possible.

That is what I mean. A module should register its parameters by itself. In PyTorch this is done by the module's __setattr__ and __getattr__. That is impossible in C#, which is why RegisterComponents exists. If you want a module to behave like it does in PyTorch, it should always call RegisterComponents in its own constructor rather than rely on others to call it.

In other words, every module should be usable on its own, without having to be part of another module. In PyTorch, __setattr__ and __getattr__ take care of that automatically. In C#, a module that never calls RegisterComponents cannot work correctly.

Umm... Perhaps the best solution would be a source generator?

lintao185 commented 8 months ago

I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

yueyinqiu commented 8 months ago

> I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. At that point no values have been assigned to the properties and fields, so we can't register them...

lintao185 commented 8 months ago

Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.

yueyinqiu commented 8 months ago

Yes I suppose Fody/SourceGenerator could be a beautiful solution. And we can easily expose properties instead of fields in that way. That's also great.

@NiklasGustafsson Could you please take a look at this?

yueyinqiu commented 8 months ago

I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister

PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

NiklasGustafsson commented 8 months ago

> > I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

> Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. At that point no values have been assigned to the properties and fields, so we can't register them...

That is exactly right. That's why RegisterComponents exists and needs to be called last in the (custom) module constructor.
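
The ordering problem described above can be reproduced in plain C#, without TorchSharp. The following is only a sketch: ModuleBase, EarlyModule, and LateModule are hypothetical names, and the reflection loop is a stand-in for what RegisterComponents does, not TorchSharp's actual implementation.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical stand-in for nn.Module; this is NOT TorchSharp code.
public abstract class ModuleBase
{
    public List<string> Registered { get; } = new List<string>();

    protected ModuleBase(bool registerInBaseConstructor)
    {
        // This body runs before the derived constructor body, so fields
        // assigned in the derived constructor are still null here.
        if (registerInBaseConstructor) RegisterComponents();
    }

    // Stand-in for RegisterComponents: record every non-null public field.
    protected void RegisterComponents()
    {
        var flags = BindingFlags.Public | BindingFlags.Instance;
        foreach (var field in GetType().GetFields(flags))
        {
            if (field.GetValue(this) != null) Registered.Add(field.Name);
        }
    }
}

// Registration attempted from the base constructor: too early.
public sealed class EarlyModule : ModuleBase
{
    public object weight;

    public EarlyModule() : base(true)
    {
        weight = new object();
    }
}

// Registration as the last statement of the derived constructor.
public sealed class LateModule : ModuleBase
{
    public object weight;

    public LateModule() : base(false)
    {
        weight = new object();
        RegisterComponents();
    }
}
```

With these definitions, new EarlyModule().Registered stays empty, while new LateModule().Registered contains "weight". Fields set through C# field initializers would be an exception, since initializers run before the base constructor, but fields assigned in the constructor body, as in the models in this thread, are not.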

NiklasGustafsson commented 8 months ago

> Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.

That capability already exists. For example, in the rewrite we're working on for some of the standard modules, which will enable more attributes to be exposed, the parameters of Linear are defined as:

            const string WeightComponentName = nameof(weight);
            const string BiasComponentName = nameof(bias);

            public Parameter? bias {
                get => _bias;
                set {
                    _bias?.Dispose();
                    _bias = value?.DetachFromDisposeScope() as Parameter;
                    ConditionallyRegisterParameter(BiasComponentName, _bias);
                }
            }

            public Parameter weight {
                get => _weight!;
                set {
                    if (value is null) throw new ArgumentNullException(nameof(weight));
                    if (value.Handle != _weight?.Handle) {
                        _weight?.Dispose();
                        _weight = (value.DetachFromDisposeScope() as Parameter)!;
                        ConditionallyRegisterParameter(WeightComponentName, _weight);
                    }
                }
            }

            [ComponentName(Name = BiasComponentName)]
            private Parameter? _bias;
            [ComponentName(Name = WeightComponentName)]
            private Parameter? _weight;

NiklasGustafsson commented 8 months ago

> Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules

The exception to this rule is modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component, which allows you to assign it 'null' as well as overwrite an already existing parameter.

NiklasGustafsson commented 8 months ago

> I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister
>
> PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

As much as I dislike relying on reflection, which the current scheme does (I dislike it because it prevents AOT), having to use source code generation adds complexity and something that has to be automated. That would be a last resort, I think.

The current scheme works fairly well as long as you follow the instructions very closely and don't do advanced stuff like the Linear module above. I don't know why you would allow setting the parameters after the module has been constructed, but PyTorch does, so TorchSharp should, too.

lintao185 commented 8 months ago

So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

yueyinqiu commented 8 months ago

> > Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

> Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules
>
> The exception to this rule is modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component, which allows you to assign it 'null' as well as overwrite an already existing parameter.

So my understanding is that a custom module should not rely on others calling its RegisterComponents. But then why are we doing that in register_module? I think this may be misleading.

NiklasGustafsson commented 8 months ago

> So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.

NiklasGustafsson commented 8 months ago

> [...] why are we doing that in register_module?

As unsatisfying as this answer is -- I don't recall.

yueyinqiu commented 8 months ago

> > [...] why are we doing that in register_module?

> As unsatisfying as this answer is -- I don't recall.

haha... Perhaps we should check all the modules provided by TorchSharp and remove this call if nothing depends on it? I believe this should be done as early as possible, so that no more projects come to rely on it by mistake.

yueyinqiu commented 8 months ago

> > > [...] why are we doing that in register_module?

> > As unsatisfying as this answer is -- I don't recall.

> haha... Perhaps we should check all the modules provided by TorchSharp and remove this call if nothing depends on it? I believe this should be done as early as possible, so that no more projects come to rely on it by mistake.

oh well there is at least one thing (ModuleList) that depends on that:

using static TorchSharp.torch.nn;

var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0

I suppose it should be modified to have a similar behavior as Sequential.

lintao185 commented 8 months ago

> > So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

> We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.

So for now I call RegisterComponents() in each custom module; it just feels a bit verbose, haha.