Every time we call `agent.module` during the environment interaction to unwrap the agent from the distributed strategy, we also unwrap it from the precision plugin. This means that if we are training an agent with `float16` or `bfloat16`, the environment interaction happens in `float32`.

I suggest wrapping every `player` agent with a `_FabricModule`, i.e. `_FabricModule(agent, precision=fabric.precision)`, so that the agent is unwrapped from the strategy while keeping the precision plugin.
cc @michele-milesi