RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Fix bsuite MemoryChain runtime exception for num_bits > 1 #61

Open FranzKnut opened 11 months ago

FranzKnut commented 11 months ago

Using MemoryChain-bsuite with num_bits > 1 resulted in

TypeError: select cases must have the same shapes

when selecting (masking) the context for the current time step.

This PR fixes the error by making the negative-case array (all zeros) the same shape as the context.
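The failure can be reproduced in isolation. The sketch below assumes the masking uses `jax.lax.select`, which, unlike `jnp.where`, does not broadcast and requires both cases to have identical shapes. The function names and the `(1,)`-shaped zeros in the buggy version are hypothetical stand-ins chosen to mimic the reported behaviour (works for one bit, fails for more). They are not the actual gymnax code.

```python
import jax
import jax.numpy as jnp


def mask_context_buggy(context: jax.Array, time: jax.Array) -> jax.Array:
    # Hypothetical reproduction: the negative case has a fixed shape (1,),
    # which only matches context.shape when num_bits == 1.
    return jax.lax.select(time == 0, context, jnp.zeros((1,), dtype=context.dtype))


def mask_context_fixed(context: jax.Array, time: jax.Array) -> jax.Array:
    # Fix in the spirit of this PR: build the zero case with the context's
    # own shape and dtype, so both select cases always agree.
    return jax.lax.select(time == 0, context, jnp.zeros_like(context))


if __name__ == "__main__":
    ctx = jnp.array([1.0, 0.0, 1.0])  # num_bits = 3
    try:
        mask_context_buggy(ctx, jnp.array(0))
    except TypeError as err:
        print("buggy:", err)
    print(mask_context_fixed(ctx, jnp.array(0)))  # context shown at t = 0
    print(mask_context_fixed(ctx, jnp.array(2)))  # zeros afterwards
```

Using `jnp.zeros_like` keeps the zero case tied to whatever shape the context has, so the same code path works for any `num_bits` without special-casing.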