pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

_get_nested_attr returns None because of a missing return in the general case #1053

Closed: andreykramer closed this issue 1 year ago

andreykramer commented 1 year ago

I noticed that the old_state variable in make_functional.py::FunctionalModuleWithBuffers::forward is always full of None elements. Digging a bit, I saw that there is no return in the general (recursive) case of _get_nested_attr:

    def _get_nested_attr(obj, names):
        if len(names) == 1:
            return getattr(obj, names[0])
        else:
            # No return here: the recursive result is discarded, so the caller gets None
            _get_nested_attr(getattr(obj, names[0]), names[1:])
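To make the failure concrete, here is a minimal sketch that compares the quoted logic with a version that returns the recursive result. It assumes plain Python objects (types.SimpleNamespace) as a stand-in for an nn.Module tree; the function names get_nested_attr_buggy and get_nested_attr_fixed are hypothetical and are not functorch's code.

```python
from types import SimpleNamespace


def get_nested_attr_buggy(obj, names):
    # Mirrors the snippet in the issue: the recursive branch drops its result.
    if len(names) == 1:
        return getattr(obj, names[0])
    else:
        get_nested_attr_buggy(getattr(obj, names[0]), names[1:])  # result discarded


def get_nested_attr_fixed(obj, names):
    # Same recursion, but the general case returns the recursive result.
    if len(names) == 1:
        return getattr(obj, names[0])
    return get_nested_attr_fixed(getattr(obj, names[0]), names[1:])


# Hypothetical stand-in for a module tree such as model.layer.weight.
model = SimpleNamespace(layer=SimpleNamespace(weight="W"))

print(get_nested_attr_buggy(model, ["layer", "weight"]))  # None: depth > 1 loses the value
print(get_nested_attr_fixed(model, ["layer", "weight"]))  # W
print(get_nested_attr_buggy(model.layer, ["weight"]))     # W: a single name still works
```

The only change needed is the added return. An equivalent non-recursive form is functools.reduce(getattr, names, obj), which applies getattr once per name from left to right.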

So the outermost call always returns None when len(names) > 1. If I'm right and this is a bug, what implications does it have? For instance, the "Remove the loaded state on self.stateless_model" step surely wouldn't be doing what it's supposed to.

zou3519 commented 1 year ago

Thanks for the catch, @andreykramer. This is indeed a bug.

From my reading of the code, make_functional/make_functional_with_buffers should still work as expected. The "Remove the loaded state on self.stateless_model" step will still remove the loaded state, but instead of replacing each tensor with a meta tensor (a tensor without storage), it will replace it with None.

It looks like the only implication is that if you directly access self.stateless_model, some of its attributes will be None where they should be meta tensors.
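The round trip described in this explanation can be sketched as follows. This is a simplified model, not functorch's actual _swap_state: swap_state and set_nested_attr are hypothetical helpers, plain SimpleNamespace objects stand in for the module, and the string META stands in for a storage-less meta tensor.

```python
from types import SimpleNamespace

META = "meta-placeholder"  # hypothetical stand-in for a meta tensor


def get_nested_attr_buggy(obj, names):
    # Same logic as the snippet in the issue, including the missing return.
    if len(names) == 1:
        return getattr(obj, names[0])
    get_nested_attr_buggy(getattr(obj, names[0]), names[1:])


def set_nested_attr(obj, names, value):
    # Walk to the parent object, then set the final attribute.
    for name in names[:-1]:
        obj = getattr(obj, name)
    setattr(obj, names[-1], value)


def swap_state(model, names_list, new_values):
    # Hypothetical simplified swap: record the current values, install the new
    # ones, and return the recorded values so they can be swapped back later.
    old_values = []
    for names, value in zip(names_list, new_values):
        old_values.append(get_nested_attr_buggy(model, names))
        set_nested_attr(model, names, value)
    return old_values


# The "stateless" model starts out holding meta placeholders.
stateless = SimpleNamespace(layer=SimpleNamespace(weight=META, bias=META))
names_list = [["layer", "weight"], ["layer", "bias"]]

# forward(): load the real parameters (the call itself would see them correctly)...
old_state = swap_state(stateless, names_list, ["W", "b"])
print(old_state)  # [None, None] instead of [META, META]

# ...then "remove the loaded state" by swapping the recorded values back in.
swap_state(stateless, names_list, old_state)
print(stateless.layer.weight)  # None rather than the meta placeholder
```

While the real parameters are installed, the model sees correct values, which is why make_functional still computes the right result; only the restored state left on the stateless model is wrong.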