ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2
48 stars, 7 forks

Image / multimodal input #9

Open gmmyung opened 4 months ago

gmmyung commented 4 months ago

Has there been any work on RGB image or multimodal input? It seems fairly straightforward to implement, so I might work on it if nobody has started on this yet.

ShaneFlandermeyer commented 4 months ago

I agree! Feel free to submit a pull request if you decide to do this. The original PyTorch repo supports multimodal inputs to some extent (see its definition of WorldModel._encoder), but I think it would be nice to generalize this implementation so that users can specify a network architecture for each observation type.