Open Sumit1673 opened 9 months ago
It got resolved: I didn't call the train_fun and used the function directly in Loop. However, after resolving this, I am facing a new error:
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks) 289 layer.weights, layer.state = tl.on_cpu(self._unreplicate( 290 _make_weights_and_state_same_across_hosts( 291 self._for_n_devices(weights_and_state)))) 293 # Load checkpoint if it exists. --> 294 self.load_checkpoint() 296 # Prepare eval components. 297 self._eval_at = eval_at or default_at
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename) 940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']): 941 matched_flat_slots = _match_by_shape( 942 self._to_bits(_flatten_and_remove_empty(trainer.slots)), 943 _flatten_and_remove_empty(slots)) --> 944 matched_slots, _ = fastmath.tree_unflatten( 945 self._from_bits(matched_flat_slots), 946 trainer.slots, copy_from_tree=[None, ()]) 947 trainer.slots = matched_slots 948 self._step = d['step']
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree) 242 new_tree, rest = [], flat 243 for t in tree: --> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) 245 new_tree.append(new_t) 246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree) 242 new_tree, rest = [], flat 243 for t in tree: --> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) 245 new_tree.append(new_t) 246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree) 216 def tree_unflatten(flat, tree, copy_from_tree=None): 217 """Unflatten a list into a tree given the tree shape as second argument. 218 219 Args: (...) 237 more were provided than the number of leaves of tree (useful for recursion). 238 """ --> 239 if copy_from_tree is not None and tree in copy_from_tree: 240 return tree, flat 241 if isinstance(tree, (list, tuple)):
File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258, in _defer_to_unrecognized_arg.
TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
Description
Facing an error when trying to run training.Loop
Environment information