ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

assert issubclass(model_cls, TFModelV2) when trying to use pytorch model #7279

Closed. drozzy closed this issue 4 years ago.

drozzy commented 4 years ago

I'm trying to use a custom model (a PyTorch VisionNetwork) for a DQN policy.

When I try this:

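from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from my_model_pytorch import VisionNetwork  # custom PyTorch model (module name taken from the traceback)

# env_creator is the reporter's own environment factory (not shown in the report)
ModelCatalog.register_custom_model("my_model", VisionNetwork)
register_env("zzz", env_creator)

tune.run(DQNTrainer, config={
    "env": "zzz",
    "model": {
        "custom_model": "my_model",
        "dim": 150,
        "conv_filters": [
            [16, [4, 4], 2],
            [32, [4, 4], 2],
            [32, [4, 4], 2],
            [32, [4, 4], 2],
            [512, [10, 10], 1],
        ],
    },
})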

I get this error:

assert issubclass(model_cls, TFModelV2), model_cls
AssertionError: <class 'my_model_pytorch.VisionNetwork'>
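The failing line is a plain subclass test, so any model class outside the TFModelV2 hierarchy is rejected regardless of what it contains. A minimal sketch of that mechanism, using hypothetical stand-in classes named after the ones in the traceback (ModelV2, TFModelV2, TorchModelV2, VisionNetwork, TFVisionNetwork) and hypothetical helpers check_tf_only / is_accepted; none of this is RLlib's actual code:

```python
# Hypothetical stand-ins: these are NOT Ray's real classes. They only
# mimic the split between TF-based and torch-based model base classes.
class ModelV2:
    """Framework-agnostic model base (stand-in)."""


class TFModelV2(ModelV2):
    """Stand-in for a TensorFlow model base class."""


class TorchModelV2(ModelV2):
    """Stand-in for a PyTorch model base class."""


class VisionNetwork(TorchModelV2):
    """Stand-in for the reporter's custom PyTorch vision model."""


class TFVisionNetwork(TFModelV2):
    """Hypothetical TF model, for contrast."""


def check_tf_only(model_cls):
    # Same check as the failing line in the traceback: a TF-only code
    # path rejects any class that does not derive from TFModelV2.
    assert issubclass(model_cls, TFModelV2), model_cls
    return model_cls


def is_accepted(model_cls) -> bool:
    # Report whether the TF-only check would let this class through.
    try:
        check_tf_only(model_cls)
        return True
    except AssertionError:
        return False


print(is_accepted(TFVisionNetwork))  # True
print(is_accepted(VisionNetwork))    # False
```

Under this reading, the torch model itself isn't broken; the DQN code path simply had no torch branch to accept it.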
ericl commented 4 years ago

I guess this will get fixed once DQN supports torch. cc @sven1977

drozzy commented 4 years ago

Closing, as I'm seeing PyTorch support beginning to land in RLlib.
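For later RLlib releases in which DQN does run with PyTorch, the framework is chosen explicitly in the trainer config rather than inferred from the model class. A minimal sketch, assuming the `"framework"` config key of those later releases and reusing the env and model names from the report:

```python
# Sketch for later RLlib releases (assumption): the "framework" key selects
# which backend RLlib builds policies and models with; "tf" selects TensorFlow.
config = {
    "env": "zzz",          # env name registered via register_env in the report
    "framework": "torch",  # assumed key: request torch policies/models
    "model": {
        "custom_model": "my_model",  # torch model registered with ModelCatalog
        "dim": 150,
        "conv_filters": [
            [16, [4, 4], 2],
            [32, [4, 4], 2],
            [32, [4, 4], 2],
            [32, [4, 4], 2],
            [512, [10, 10], 1],
        ],
    },
}
```

This dict would then be passed as `tune.run(DQNTrainer, config=config)`. The custom model would also need to subclass RLlib's TorchModelV2 (and `torch.nn.Module`) rather than TFModelV2.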