huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning
Apache License 2.0

Tests are broken on main as of 89c6be8 #180

Closed AshisGhosh closed 4 months ago

AshisGhosh commented 5 months ago

System Info

`main` / 89c6be8

Reproduction

Run

DATA_DIR="tests/data" python -m pytest -sv ./tests

Output:

================================================================= FAILURES =================================================================
_________________________________________ test_backward_compatibility[aloha-act-extra_overrides2] __________________________________________

env_name = 'aloha', policy_name = 'act', extra_overrides = ['policy.n_action_steps=10']

    @pytest.mark.parametrize(
        "env_name, policy_name, extra_overrides",
        [
            ("xarm", "tdmpc", []),
            (
                "pusht",
                "diffusion",
                ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
            ),
            ("aloha", "act", ["policy.n_action_steps=10"]),
        ],
    )
    # As artifacts have been generated on an x86_64 kernel, this test won't
    # pass if it's run on another platform due to floating point errors
    @require_x86_64_kernel
    def test_backward_compatibility(env_name, policy_name, extra_overrides):
        """
        NOTE: If this test does not pass, and you have intentionally changed something in the policy:
            1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
               include a report on what changed and how that affected the outputs.
            2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
               add the policies you want to update the test artifacts for.
            3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
            4. Check that this test now passes.
            5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.
            6. Remember to stage and commit the resulting changes to `tests/data`.
        """
        env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
        saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
        saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
        saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
        saved_actions = load_file(env_policy_dir / "actions.safetensors")

        output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)

        for key in saved_output_dict:
            assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
        for key in saved_grad_stats:
>           assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
E           assert tensor(False)
E            +  where tensor(False) = <built-in method all of Tensor object at 0x7d36a6cda660>()
E            +    where <built-in method all of Tensor object at 0x7d36a6cda660> = tensor(False).all
E            +      where tensor(False) = <built-in method isclose of type object at 0x7d3855386760>(tensor(0.0026), tensor(0.0005), rtol=0.1, atol=1e-07)
E            +        where <built-in method isclose of type object at 0x7d3855386760> = torch.isclose

tests/test_policies.py:274: AssertionError
========================================================= short test summary info ==========================================================
FAILED tests/test_policies.py::test_backward_compatibility[aloha-act-extra_overrides2] - assert tensor(False)
================================================ 1 failed, 38 passed, 26 skipped in 14.17s =================================================
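The assertion message above can be unpacked by hand. `torch.isclose(a, b, rtol, atol)` checks `|a - b| <= atol + rtol * |b|`, and with the values from the traceback (0.0026 vs. the saved 0.0005) the allowed tolerance is only about 5e-5 while the actual difference is 2.1e-3, so the check fails. A minimal reproduction of just that comparison:

```python
import torch

# Values taken from the failing assertion in the traceback above.
computed = torch.tensor(0.0026)  # grad stat produced on this machine
saved = torch.tensor(0.0005)     # reference artifact value

# torch.isclose checks: |computed - saved| <= atol + rtol * |saved|
tolerance = 1e-7 + 0.1 * saved.abs()   # ~5.01e-05
difference = (computed - saved).abs()  # ~2.1e-03

print(difference.item() > tolerance.item())  # True: outside tolerance
print(torch.isclose(computed, saved, rtol=0.1, atol=1e-7))  # tensor(False)
```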

Expected behavior

All tests to pass

aliberts commented 5 months ago

Yes, unfortunately this test is very sensitive to the platform you run it on. It's passing on the CI right now (here and here), so it shouldn't be a cause for concern, but it's definitely not ideal. We're thinking about how to improve it (ideally it should run and pass on any platform). If you have any ideas on how to do that, please don't hesitate to share them, here or in a PR ;)
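One possible direction (just a sketch of an idea, not anything lerobot currently implements): instead of an element-wise `isclose`, compare the overall deviation against the norm of the reference tensor, which is less sensitive to per-element floating-point noise across platforms. The `relative_mismatch` helper below is hypothetical:

```python
import torch

def relative_mismatch(actual: torch.Tensor, expected: torch.Tensor) -> float:
    """Hypothetical helper: deviation as a fraction of the reference norm.

    A whole-tensor measure like this tolerates small per-element FP
    differences between platforms better than element-wise isclose.
    """
    denom = expected.norm().clamp_min(1e-12)  # guard against zero reference
    return ((actual - expected).norm() / denom).item()

# Tiny per-element differences (order 1e-4) pass a 1% whole-tensor bound,
# even though a strict atol=1e-7 element-wise check would reject them.
a = torch.tensor([1.0000, 2.0000, 3.0000])
b = torch.tensor([1.0001, 1.9999, 3.0002])
print(relative_mismatch(a, b) < 0.01)  # True
```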