epfl-lasa / OptimalModulationDS


`aot_lambda` does not work #2

Open: siddancha opened this issue 9 months ago

siddancha commented 9 months ago

When running python_scripts/ds_mppi/scripts/standalonePlanar2d.py with the current version of `dist_grad_closest_aot` in python_scripts/mlp_learn/sdf/robot_sdf.py:

def dist_grad_closest_aot(self, q):
    return self.aot_lambda(q)
    # return self.functorch_vjp(q)

I get the following error:

Weights loaded!
/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/LinDS.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.q_goal = torch.tensor(q_goal)
/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/deprecated.py:73: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.vjp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.vjp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('vjp')
Traceback (most recent call last):
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/standalonePlanar2d.py", line 223, in <module>
    main_int()
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/standalonePlanar2d.py", line 121, in main_int
    mppi = MPPI(q_0, q_f, dh_params, obs, dt, dt_H, N_traj, DS_ARRAY, dh_a, nn_model, 2)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 71, in __init__
    _, _, _, _, _ = self.propagate()
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 113, in propagate
    distance, self.nn_grad = self.distance_repulsion_nn(q_prev, aot=True)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 259, in distance_repulsion_nn
    nn_dist, nn_grad, nn_minidx = self.nn_model.dist_grad_closest_aot(nn_input[:, 0:-1])
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/robot_sdf.py", line 161, in dist_grad_closest_aot
    return self.aot_lambda(q)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3725, in returned_function
    compiled_fn = create_aot_dispatcher_function(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3379, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 757, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3525, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/robot_sdf.py", line 154, in functorch_vjp
    dists, vjp_fn = vjp(self.model.forward, points)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/deprecated.py", line 74, in vjp
    return _impl.vjp(func, *primals, has_aux=has_aux)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 267, in vjp
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 38, in fn
    return f(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 294, in _vjp_with_argnums
    primals_out = func(*primals)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/network_macros_mod.py", line 143, in forward
    y = self.layers[0](x_nerf)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1376, in dispatch
    ) = self.validate_and_convert_non_fake_tensors(func, converter, args, kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1597, in validate_and_convert_non_fake_tensors
    args, kwargs = tree_map_only(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 353, in tree_map_only
    return tree_map(map_only(ty)(fn), pytree)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 283, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 283, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 334, in inner
    return f(x)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1587, in validate
    raise Exception(
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.t.default(Parameter containing:
tensor([...], size=(256, 15), requires_grad=True))
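The exception says that while AOTAutograd was tracing the function with FakeTensors, it reached a real `Parameter` (the 256 x 15 weight of the first linear layer). A likely cause, not verified against the repo's code, is that `self.aot_lambda` (presumably built with `functorch.compile.aot_function`) wraps a function that calls `self.model.forward`, which closes over the module's real parameters. Newer PyTorch versions reject real tensors that are not inputs to the traced function. One possible workaround is to pass the parameters in as explicit inputs and run the module through `torch.func.functional_call`, so AOTAutograd fakeifies them along with the points. The sketch below assumes a toy `nn.Sequential` in place of the repo's network and a hypothetical `dist_grad_closest` that returns the distances, the gradient of each point's smallest distance, and that distance's index, mirroring the three outputs unpacked in MPPI.py. It is not the repo's implementation.

```python
import torch
from torch import nn
from torch.func import functional_call, vjp
from functorch.compile import aot_function, make_boxed_func

# Hypothetical stand-in for the repo's SDF network (first layer 15 -> 256, as in the traceback).
model = nn.Sequential(nn.Linear(15, 256), nn.ReLU(), nn.Linear(256, 4))


def dist_grad_closest(params, points):
    """Hypothetical: distances, grad of each point's smallest distance, and its index."""

    def forward(x):
        # Run the module with the explicitly passed parameters instead of its own.
        return functional_call(model, params, (x,))

    dists, vjp_fn = vjp(forward, points)
    minidx = dists.argmin(dim=1)
    # One-hot cotangent: row i pulls back only dists[i, minidx[i]].
    cotangent = torch.zeros_like(dists)
    cotangent[torch.arange(dists.shape[0]), minidx] = 1.0
    (grad,) = vjp_fn(cotangent)
    return dists, grad, minidx


def passthrough_compiler(gm, example_inputs):
    # Run the traced FX graph as-is, using the boxed calling convention AOTAutograd expects.
    return make_boxed_func(gm.forward)


aot_dist_grad = aot_function(dist_grad_closest, fw_compiler=passthrough_compiler)

# The parameters are now inputs, so AOTAutograd converts them to FakeTensors while tracing.
params = {name: p.detach() for name, p in model.named_parameters()}
points = torch.randn(8, 15)
dists, grad, minidx = aot_dist_grad(params, points)
```

Detaching the parameters keeps AOTAutograd on its inference path, since only the gradient with respect to the points is needed here. If the real `aot_lambda` also needs parameter gradients, this sketch would need adjusting.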

When I switch to `functorch_vjp`:

def dist_grad_closest_aot(self, q):
    # return self.aot_lambda(q)
    return self.functorch_vjp(q)

It works! Why does `aot_lambda` not work? Should I continue using `functorch_vjp`? Thanks!
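On whether to keep `functorch_vjp`: the deprecation warning in the log says `functorch.vjp` is deprecated as of PyTorch 2.0, will be removed in a future version >= 2.3, and recommends `torch.func.vjp` instead. The sketch below shows an eager version using `torch.func.vjp` with the same `vjp(func, points)` call pattern seen in the traceback. It assumes a toy `nn.Sequential` in place of the repo's network, and the "closest" rule (gradient of each point's smallest predicted distance) is a hypothetical reading of `dist_grad_closest_aot`, not taken from the repo.

```python
import torch
from torch import nn
from torch.func import vjp

# Hypothetical stand-in for the repo's SDF network: 15 inputs, 4 distance outputs.
model = nn.Sequential(nn.Linear(15, 256), nn.ReLU(), nn.Linear(256, 4))


def dist_grad_closest_vjp(points):
    # torch.func.vjp is the non-deprecated replacement for functorch.vjp.
    dists, vjp_fn = vjp(model, points)
    # Index of the smallest predicted distance for each point (hypothetical "closest" rule).
    minidx = dists.argmin(dim=1)
    # One-hot cotangent: row i pulls back only dists[i, minidx[i]].
    cotangent = torch.zeros_like(dists)
    cotangent[torch.arange(dists.shape[0]), minidx] = 1.0
    (grad,) = vjp_fn(cotangent)
    return dists, grad, minidx


points = torch.randn(8, 15)
dists, grad, minidx = dist_grad_closest_vjp(points)
```

Because this runs eagerly, it avoids AOTAutograd tracing altogether, so the FakeTensor check never applies; any speedup the ahead-of-time path was meant to provide is given up.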

erdisayar commented 2 months ago

Do you know how they obtained these trained parameters? https://github.com/epfl-lasa/OptimalModulationDS/tree/master/python_scripts/mlp_learn/models