yueyin85 opened this issue 1 year ago (status: Open)
The following error was encountered while running the Imports and Definitions cell:

```
AttributeError Traceback (most recent call last)
in <cell line: 14>()
12 import librosa
13 import note_seq
---> 14 import seqio
15 import t5
16 import t5x

8 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/maps.py in
869 # SPMD batching always gets involved as the last transform before XLA translation
870 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
--> 871 ad.call_param_updaters[xmap_p] = pxla.xla_call_jvp_update_params
872
873 def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):

AttributeError: module 'jax._src.interpreters.pxla' has no attribute 'xla_call_jvp_update_params'
```
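When reporting or narrowing down an error like this, it may help to include the exact installed versions of the packages from the failing cell. It can also help to confirm whether the installed jax's private `pxla` module still defines the attribute that `maps.py` line 871 looks up. The following is a minimal diagnostic sketch, not a fix. It assumes a Python 3.10 environment like the Colab runtime in the traceback. `installed_versions` and `has_xla_call_jvp_update_params` are hypothetical helper names introduced here. The module path `jax._src.interpreters.pxla` is a private jax internal taken from the error message, so it may move or disappear between jax releases.

```python
import importlib
from importlib import metadata

# Distributions imported in the failing cell, plus jaxlib (jax's compiled backend).
PACKAGES = ["jax", "jaxlib", "librosa", "note-seq", "seqio", "t5", "t5x"]


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_xla_call_jvp_update_params():
    """Report whether jax's private pxla module defines the attribute that
    jax/_src/maps.py looks up at line 871 in the traceback.

    Returns True/False, or None if the module cannot be imported at all
    (for example, jax is not installed or the private path has moved).
    """
    try:
        pxla = importlib.import_module("jax._src.interpreters.pxla")
    except ImportError:
        return None
    return hasattr(pxla, "xla_call_jvp_update_params")


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
    print("pxla.xla_call_jvp_update_params present:",
          has_xla_call_jvp_update_params())
```

If the check prints `False`, then `maps.py` is referencing a helper that the installed `pxla.py` no longer provides. That points to mismatched jax internals rather than a bug in `seqio` itself. The version list shows which combination of packages triggered it.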