Open · agoliaei opened this issue 1 year ago
Description

Hi, I am trying to follow this tutorial: https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb. With the Colab runtime set to TPU it still worked a couple of days ago, but now it crashes with this error:

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

This happens at this step: training_loop = training.Loop(model, ...
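For reference, the failing cell looks roughly like the sketch below. Only the training.Loop(...) call mirrors the notebook (its arguments also appear in the traceback further down); the tiny model, dummy data generator, and output directory are placeholders standing in for the notebook's actual NMT setup:

```python
import numpy as np
import trax
from trax import layers as tl
from trax.supervised import training

# Placeholder model and data; the real notebook builds a Transformer/Reformer
# NMT model and feeds tokenized translation data.
model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(2), tl.LogSoftmax())

def dummy_batches():
  while True:
    x = np.random.rand(8, 4).astype(np.float32)       # fake inputs
    y = np.random.randint(0, 2, size=(8,))            # fake targets
    w = np.ones_like(y, dtype=np.float32)             # per-example weights
    yield (x, y, w)

train_task = training.TrainTask(
    labeled_data=dummy_batches(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=10,
)
eval_task = training.EvalTask(
    labeled_data=dummy_batches(),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
)

# On a Colab TPU runtime with the versions listed below, this constructor is
# where the TypeError is raised (same call pattern as in the notebook).
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir='/tmp/nmt_repro')  # placeholder output dir
```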
OS: NAME="Ubuntu" VERSION="18.04.6 LTS (Bionic Beaver)" ID=ubuntu ID_LIKE=debian PRETTY_NAME="Ubuntu 18.04.6 LTS" VERSION_ID="18.04" HOME_URL="https://www.ubuntu.com/" SUPPORT_URL="https://help.ubuntu.com/" BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/" PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy" VERSION_CODENAME=bionic UBUNTU_CODENAME=bionic $ pip freeze | grep trax # trax==1.4.1 $ pip freeze | grep tensor # tensorboard==2.10.0 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorflow==2.10.0 tensorflow-datasets==4.6.0 tensorflow-estimator==2.10.0 tensorflow-gcs-config==2.8.0 tensorflow-hub==0.12.0 tensorflow-io-gcs-filesystem==0.26.0 tensorflow-metadata==1.10.0 tensorflow-probability==0.16.0 tensorflow-text==2.10.0 $ pip freeze | grep jax # jax==0.3.17 jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl $ python -V # Python 3.7.13
For bugs: reproduction and error logs

Steps to reproduce:
Run https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb on a Colab TPU runtime.
...
Error logs:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-2021642a85f0> in <module>
      9     train_task,
     10     eval_tasks=[eval_task],
---> 11     output_dir=output_dir)

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    278
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281
    282     # Sync layers weights/state in memory effcient trainer layers.

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
    278
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281
    282     # Sync layers weights/state in memory effcient trainer layers.

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
    348       task.optimizer.tree_init(model_in_training.weights)
    349       return optimizers.Trainer(
--> 350           model_in_training, task.optimizer, adasum=self._adasum)
    351     # In the memory-efficient path, we initialize the model here.
    352     blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(

/usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
     57     # optimizer slots and opt_params may need to be replicated
     58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
     60
     61     # accelerated version of model+loss to replicate weights and state

/usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
    250   """Puts ``x`` in CPU memory in JAX."""
    251   if fastmath.is_backend(fastmath.Backend.JAX):
--> 252     return jax.device_put(x, jax.devices('cpu')[0])
    253   else:
    254     return x

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725
   2726

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206
    207 def build_tree(treedef, xs):

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206
    207 def build_tree(treedef, xs):

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725
   2726

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
    323     assert (not config.jax_enable_checks or
    324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325     return self.bind_with_trace(find_top_trace(args), args, params)
    326
    327   def bind_with_trace(self, trace, args, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
    326
    327   def bind_with_trace(self, trace, args, params):
--> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    329     return map(full_lower, out) if self.multiple_results else full_lower(out)
    330

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
    684
    685   def process_primitive(self, primitive, tracers, params):
--> 686     return primitive.impl(*tracers, **params)
    687
    688   def process_call(self, primitive, f, tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
   1219     raise TypeError(
   1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))
   1222
   1223

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
   1113   x = xla.canonicalize_dtype(x)
   1114   try:
-> 1115     return device_put_handlers[type(x)](x, device)
   1116   except KeyError as err:
   1117     raise TypeError(f"No device_put handler for type: {type(x)}") from err

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
   1124   if x.dtype == dtypes.float0:
   1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
-> 1126   return (backend.buffer_from_pyval(x, device),)
   1127
   1128 def _device_put_scalar(x, device):

/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
    264
    265 def __array__(self, dtype=None, context=None):
--> 266   return np.asarray(self._value, dtype=dtype)
    267
    268 setattr(device_array, "__array__", __array__)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
    803   npy_value = np.empty(self.aval.shape, self.aval.dtype)
    804   for i in self.one_replica_buffer_indices:
--> 805     npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
    806   self._npy_value = npy_value
    807   return self._npy_value

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
...
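For context, the frames above end in trax's tl.on_cpu helper; a minimal sketch of that call path, reconstructed from the traceback (the surrounding optimizers.Trainer code is omitted), is:

```python
import jax

# Reconstructed from trax/layers/acceleration.py as quoted in the traceback.
# `x` stands for the tuple of replicated optimizer slots/opt_params; on the TPU
# runtime these are ShardedDeviceArrays whose shards are
# jaxlib.tpu_client_extension.PyTpuBuffer objects.
def on_cpu(x):
  """Puts ``x`` in CPU memory in JAX."""
  return jax.device_put(x, jax.devices('cpu')[0])

# Per the frames above, device_put ends up materializing the sharded value as a
# NumPy array (device_array.__array__ -> pxla._sda_value), and np.asarray cannot
# convert the PyTpuBuffer shards, which raises the TypeError reported here.
```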