google / trax

Trax — Deep Learning with Clear Code and Speed

TypeError: unsupported operand type(s) for ==: 'Array' and 'tuple' #1770

Open mbateman opened 1 year ago

mbateman commented 1 year ago

Description

...

Environment information

OS: <your answer here>

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.2
jaxlib==0.4.2

$ python -V
Python 3.10.9+

For bugs: reproduction and error logs

# Steps to reproduce:
...
training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)
# Error logs:
...
TypeError                                 Traceback (most recent call last)
Cell In[41], line 1
----> 1 training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)

Cell In[40], line 15, in train_model(classifier, train_task, eval_task, n_steps, output_dir)
      4     '''
      5     Input: 
      6         classifier - the model you are building
   (...)
     12         trainer -  trax trainer
     13     '''
     14 ### START CODE HERE (Replace instances of 'None' with your code) ###
---> 15     training_loop = training.Loop(
     16                                 model=classifier, # The learning model
     17                                 tasks=train_task, # The training task
     18                                 eval_tasks=[eval_task], # The evaluation task
     19                                 output_dir=output_dir) # The output directory
     21     training_loop.run(n_steps = n_steps)
     22 ### END CODE HERE ###
     23 
     24     # Return the training_loop, since it has the model.

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    289       layer.weights, layer.state = tl.on_cpu(self._unreplicate(
    290           _make_weights_and_state_same_across_hosts(
    291               self._for_n_devices(weights_and_state))))
    293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
    296 # Prepare eval components.
    297 self._eval_at = eval_at or default_at

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
    940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
    941   matched_flat_slots = _match_by_shape(
    942       self._to_bits(_flatten_and_remove_empty(trainer.slots)),
    943       _flatten_and_remove_empty(slots))
--> 944   matched_slots, _ = fastmath.tree_unflatten(
    945       self._from_bits(matched_flat_slots),
    946       trainer.slots, copy_from_tree=[None, ()])
    947   trainer.slots = matched_slots
    948 self._step = d['step']

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
    216 def tree_unflatten(flat, tree, copy_from_tree=None):
    217   """Unflatten a list into a tree given the tree shape as second argument.
    218 
    219   Args:
   (...)
    237     more were provided than the number of leaves of tree (useful for recursion).
    238   """
--> 239   if copy_from_tree is not None and tree in copy_from_tree:
    240     return tree, flat
    241   if isinstance(tree, (list, tuple)):

File ~/.pyenv/versions/3.10-dev/envs/trax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4972, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4970   return binary_op(*args)
   4971 if isinstance(other, _rejected_binop_types):
-> 4972   raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4973                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
   4974 return NotImplemented

TypeError: unsupported operand type(s) for ==: 'Array' and 'tuple'
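For context on where the failing comparison comes from: Loop.load_checkpoint() restores the optimizer slots and calls fastmath.tree_unflatten(..., copy_from_tree=[None, ()]) (visible in the traceback above). When the recursion reaches a leaf that is a jax Array, the membership test "tree in copy_from_tree" compares the array against the tuple () with ==, which jax 0.4.x rejects outright instead of returning NotImplemented. A minimal sketch of just that comparison, outside of trax:

import jax.numpy as jnp

leaf = jnp.ones(3)           # stands in for an optimizer-slot leaf restored from the checkpoint
copy_from_tree = [None, ()]  # the sentinel list passed down from Loop.load_checkpoint (see traceback)

try:
    # Membership testing uses == element by element; comparing a jax Array
    # to the tuple () raises in jax 0.4.x rather than evaluating to False.
    leaf in copy_from_tree
except TypeError as e:
    print(e)  # unsupported operand type(s) for ==: 'Array' and 'tuple'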
oendnsk675 commented 1 year ago

I encountered the same error yesterday. As a temporary workaround, you can delete the model checkpoint file that was created and run again; the error goes away.
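In case it helps, a sketch of that workaround, assuming the default checkpoint file name model.pkl.gz and an output directory like the output_dir_expand used above (adjust both to your setup):

import os

# Hypothetical path: point this at the output_dir you passed to training.Loop.
output_dir_expand = os.path.expanduser("~/model/")
checkpoint = os.path.join(output_dir_expand, "model.pkl.gz")  # assumed default trax checkpoint name

if os.path.exists(checkpoint):
    os.remove(checkpoint)
# With no checkpoint on disk, Loop.load_checkpoint() has nothing to restore,
# so the failing tree_unflatten call never runs and training starts from scratch.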

dam0vm3nt commented 1 year ago

I hit this as well, but I actually want to be able to load the last checkpoint, so deleting the model is not a solution for me. Is there any way to resume training from a saved checkpoint?
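One possible way to keep the checkpoint is to monkey-patch trax.fastmath.numpy.tree_unflatten so that jax Array leaves are handled before the membership test against copy_from_tree runs. This is an untested sketch; it assumes the original function consumes one element of flat per leaf, and that jax.Array is available (jax >= 0.4):

import jax
from trax.fastmath import numpy as trax_np

_orig_tree_unflatten = trax_np.tree_unflatten

def _tree_unflatten_safe(flat, tree, copy_from_tree=None):
    # An Array leaf can never equal the [None, ()] sentinels, but testing
    # membership compares the Array to a tuple with ==, which newer jax
    # rejects. Consume one flat leaf directly instead (assumed leaf behavior).
    if isinstance(tree, jax.Array):
        return flat[0], flat[1:]
    return _orig_tree_unflatten(flat, tree, copy_from_tree=copy_from_tree)

# Recursive calls inside the original resolve through the module namespace,
# so patching the module attribute also covers the nested calls.
trax_np.tree_unflatten = _tree_unflatten_safe

Apply the patch before constructing training.Loop so that load_checkpoint() picks it up.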