google / trax

Trax — Deep Learning with Clear Code and Speed

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer' #1760

Open agoliaei opened 1 year ago

agoliaei commented 1 year ago

Description

Hi, I am trying to follow this tutorial: https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb. Setting the runtime to TPU on Colab used to work a couple of days ago, but now it crashes with this error:

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

This happens at this step: training_loop = training.Loop(model, ...) (the full call is shown below).
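For reference, the failing call as it appears in the notebook (reconstructed from the traceback below; model, train_task, eval_task and output_dir are all defined in earlier cells of the tutorial):

from trax.supervised import training

# Failing step from the NMT notebook; model, train_task, eval_task and
# output_dir come from the preceding tutorial cells.
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)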

Environment information

OS: 
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 18.04.6 LTS"
VERSION_ID="18.04"
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
VERSION_CODENAME=bionic
UBUNTU_CODENAME=bionic

$ pip freeze | grep trax
# trax==1.4.1

$ pip freeze | grep tensor
# tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-datasets==4.6.0
tensorflow-estimator==2.10.0
tensorflow-gcs-config==2.8.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.10.0
tensorflow-probability==0.16.0
tensorflow-text==2.10.0

$ pip freeze | grep jax
# jax==0.3.17
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl

$ python -V
# Python 3.7.13
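For completeness, a quick sketch for inspecting the runtime that produced these versions (the expected values in the comments are taken from the pip freeze output above, not from new logs):

# Sketch: check the JAX runtime on the Colab TPU instance.
import jax, jaxlib
print(jax.__version__, jaxlib.__version__)  # 0.3.17 / 0.3.15 per pip freeze above
print(jax.devices())                        # the TPU devices Colab exposes
print(jax.devices('cpu'))                   # the host CPU device trax's on_cpu() targets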

For bugs: reproduction and error logs

# Steps to reproduce:
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb

...
# Error logs:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-8-2021642a85f0>](https://localhost:8080/#) in <module>
      9                               train_task,
     10                               eval_tasks=[eval_task],
---> 11                               output_dir=output_dir)

16 frames
[/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py](https://localhost:8080/#) in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

[/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py](https://localhost:8080/#) in <genexpr>(.0)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

[/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py](https://localhost:8080/#) in _init_trainer(self, task)
    348         task.optimizer.tree_init(model_in_training.weights)
    349       return optimizers.Trainer(
--> 350           model_in_training, task.optimizer, adasum=self._adasum)
    351     # In the memory-efficient path, we initialize the model here.
    352     blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(

[/usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py](https://localhost:8080/#) in __init__(self, model_with_loss, optimizer, n_devices, adasum)
     57     # optimizer slots and opt_params may need to be replicated
     58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
     60 
     61     # accelerated version of model+loss to replicate weights and state

[/usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py](https://localhost:8080/#) in on_cpu(x)
    250   """Puts ``x`` in CPU memory in JAX."""
    251   if fastmath.is_backend(fastmath.Backend.JAX):
--> 252     return jax.device_put(x, jax.devices('cpu')[0])
    253   else:
    254     return x

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in device_put(x, device)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725 
   2726 

[/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py](https://localhost:8080/#) in tree_map(f, tree, is_leaf, *rest)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206 
    207 def build_tree(treedef, xs):

[/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py](https://localhost:8080/#) in <genexpr>(.0)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206 
    207 def build_tree(treedef, xs):

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in <lambda>(y)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725 
   2726 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, *args, **params)
    323     assert (not config.jax_enable_checks or
    324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325     return self.bind_with_trace(find_top_trace(args), args, params)
    326 
    327   def bind_with_trace(self, trace, args, params):

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind_with_trace(self, trace, args, params)
    326 
    327   def bind_with_trace(self, trace, args, params):
--> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    329     return map(full_lower, out) if self.multiple_results else full_lower(out)
    330 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_primitive(self, primitive, tracers, params)
    684 
    685   def process_primitive(self, primitive, tracers, params):
--> 686     return primitive.impl(*tracers, **params)
    687 
    688   def process_call(self, primitive, f, tracers, params):

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _device_put_impl(x, device)
   1219     raise TypeError(
   1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))
   1222 
   1223 

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in device_put(x, device)
   1113   x = xla.canonicalize_dtype(x)
   1114   try:
-> 1115     return device_put_handlers[type(x)](x, device)
   1116   except KeyError as err:
   1117     raise TypeError(f"No device_put handler for type: {type(x)}") from err

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _device_put_array(x, device)
   1124   if x.dtype == dtypes.float0:
   1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
-> 1126   return (backend.buffer_from_pyval(x, device),)
   1127 
   1128 def _device_put_scalar(x, device):

[/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py](https://localhost:8080/#) in __array__(self, dtype, context)
    264 
    265   def __array__(self, dtype=None, context=None):
--> 266     return np.asarray(self._value, dtype=dtype)
    267 
    268   setattr(device_array, "__array__", __array__)

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py](https://localhost:8080/#) in _sda_value(self)
    803     npy_value = np.empty(self.aval.shape, self.aval.dtype)
    804     for i in self.one_replica_buffer_indices:
--> 805       npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
    806     self._npy_value = npy_value
    807   return self._npy_value

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
...
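For what it's worth, the traceback points at on_cpu in trax/layers/acceleration.py, which calls jax.device_put(x, jax.devices('cpu')[0]) on the replicated optimizer slots/opt_params. A minimal sketch that seems to exercise the same code path on this TPU runtime (illustrative only, not taken from the notebook):

# Sketch of the failing code path: replicate values across TPU devices
# (roughly what trax does with optimizer slots/opt_params), then move them
# back to host memory, as trax's on_cpu() does.
import jax
import jax.numpy as jnp

replicated = jax.pmap(lambda x: x)(jnp.ones((jax.local_device_count(), 4)))

# On this runtime the per-device buffers are
# jaxlib.tpu_client_extension.PyTpuBuffer objects; the device_put below
# appears to fail with the same TypeError while converting them via np.asarray.
on_cpu = jax.device_put(replicated, jax.devices('cpu')[0])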