google-deepmind / emergent_communication_at_scale


Access values at execution time #2

Open bastienvanderplaetse opened 1 year ago

bastienvanderplaetse commented 1 year ago

Hi,

Thanks for releasing the source code, it helps a lot. I'm currently working on a similar project, and I'm wondering whether there is any way to display the current values of each tensor computed with rlax during training. When I try to print them, it only shows their shapes:

ListenerLossOutputs(loss=Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)> tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>

Do you have any advice? Thanks in advance

rahmacha commented 1 year ago

Hi Bastien,

This framework is based on Jaxline. You can try "--jaxline_disable_pmap_jit", as explained at https://github.com/deepmind/jaxline, to debug the tensors.

Good luck,
Best,
Rahma
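For reference, a minimal sketch of the situation being debugged (not this repository's code): inside a `jit`- or `pmap`-compiled function a plain `print` only shows JAX tracers, like the output Bastien posted, while `jax.debug.print` emits concrete values at execution time. The loss function below is a hypothetical stand-in, not the actual `ListenerLossOutputs` computation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def loss_fn(logits, targets):
    # Hypothetical loss, for illustration only.
    loss = jnp.mean((logits - targets) ** 2)
    print(loss)  # Inside jit this prints a tracer: Traced<ShapedArray(float32[])>...
    jax.debug.print("loss = {x}", x=loss)  # Prints the concrete value when the step runs.
    return loss


loss_fn(jnp.ones((4,)), jnp.zeros((4,)))
```

Jaxline's `--jaxline_disable_pmap_jit` flag achieves a similar effect globally: with pmap/jit disabled, the step runs eagerly, so ordinary `print` statements and debuggers see concrete arrays instead of tracers.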
