google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

AttributeError: 'function' object has no attribute 'n_steps_per_checkpoint' for NLP Machine translation model #1787

Open Sumit1673 opened 9 months ago

Sumit1673 commented 9 months ago

Description

Facing an error when trying to run training.Loop.

Environment information

OS: Ubuntu 22.04

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
keras @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/keras_1682445665871/work/keras-2.12.0-py2.py3-none-any.whl
safetensors==0.3.2
tensorboard @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorboard_1682445826165/work/tensorboard-2.12.1-py3-none-any.whl
tensorboard-data-server @ file:///croot/tensorboard-data-server_1681498183723/work/tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl
tensorboard-plugin-wit==1.6.0
tensorflow @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorflow-base_1682961422577/work/tensorflow_pkg/tensorflow-2.12.0-cp311-cp311-linux_x86_64.whl
tensorflow-datasets==4.9.2
tensorflow-estimator @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorflow-estimator_1682445976941/work/tensorflow_estimator-2.12.0-py2.py3-none-any.whl
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
tensorflow-text==2.12.1
tensorrt==8.6.1.post1
tensorrt-bindings==8.6.1
tensorrt-libs==8.6.1

$ pip freeze | grep jax
# your output here

$ python -V
3.11.4

import trax
from trax import layers as tl
from trax.supervised import training

def nmt_attention_model(input_vocab_size: int = 33300, target_vocab_size: int = 33300,
                        d_model: int = 1024, n_encoder_layers: int = 2,
                        n_decoder_layers: int = 2, n_attn_heads: int = 1,
                        dropout: float = 0.0, mode: str = "train") -> tl.Serial:
    """Returns an LSTM sequence-to-sequence model with attention."""
    # encoder_fn, pre_attention_decoder and create_attention_inps are helper
    # functions defined elsewhere in my notebook.
    inp_encoder = encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    pre_attn_decoder = pre_attention_decoder(mode, target_vocab_size, d_model=d_model)

    return tl.Serial(
        tl.Select([0, 1, 0, 1]),
        tl.Parallel(inp_encoder, pre_attn_decoder),
        tl.Fn("CreateAttnInputs", create_attention_inps, n_out=4),
        # Nest attention inside a Residual layer so its output is added to the
        # pre-attention decoder activations (i.e. the queries).
        tl.Residual(tl.AttentionQKV(d_model, n_heads=n_attn_heads,
                                    dropout=dropout, mode=mode)),
        # Drop the mask (there are 3 inputs: activations, mask, target tokens).
        tl.Select([0, 2]),
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax()
    )

def train_fun(train_batch_stream):
    return training.TrainTask(
        labeled_data=train_batch_stream,
        loss=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.01),
        lr=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
        n_steps_per_checkpoint=20
    )

def eval_fun(eval_batch_stream):
    return training.EvalTask(
        labeled_data=eval_batch_stream,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]
    )

# output_dir is defined elsewhere in the notebook. Note that train_fun and
# eval_fun are passed here without being called; this is what triggers the
# AttributeError below.
t = training.Loop(nmt_attention_model(mode='train'),
                  train_fun,
                  eval_tasks=[eval_fun],
                  output_dir=output_dir)

# Error logs:
AttributeError  Traceback (most recent call last)

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:216, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    213   assert len(tasks) == 1, 'only single task supported for now'
    214   self._eval_model = model
--> 216 default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
    217 permanent_default_at = _at_step_1_and_every_nth_step(
    218     tasks[0].n_steps_per_permanent_checkpoint)
    219 if output_dir is not None:

AttributeError: 'function' object has no attribute 'n_steps_per_checkpoint'
Sumit1673 commented 9 months ago

It got resolved: I hadn't called train_fun, and had passed the function itself to Loop.
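In other words, the tasks need to be built by calling the helpers, so that Loop receives a TrainTask and an EvalTask rather than bare functions. A minimal sketch of the corrected call (same train_batch_stream, eval_batch_stream and output_dir as above):

train_task = train_fun(train_batch_stream)
eval_task = eval_fun(eval_batch_stream)

t = training.Loop(nmt_attention_model(mode='train'),
                  train_task,
                  eval_tasks=[eval_task],
                  output_dir=output_dir)

However, after resolving this I am facing a new error: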

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    289   layer.weights, layer.state = tl.on_cpu(self._unreplicate(
    290       _make_weights_and_state_same_across_hosts(
    291           self._for_n_devices(weights_and_state))))
    293   # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
    296   # Prepare eval components.
    297   self._eval_at = eval_at or default_at

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
    940   for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
    941     matched_flat_slots = _match_by_shape(
    942         self._to_bits(_flatten_and_remove_empty(trainer.slots)),
    943         _flatten_and_remove_empty(slots))
--> 944     matched_slots, _ = fastmath.tree_unflatten(
    945         self._from_bits(matched_flat_slots),
    946         trainer.slots, copy_from_tree=[None, ()])
    947     trainer.slots = matched_slots
    948     self._step = d['step']

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242   new_tree, rest = [], flat
    243   for t in tree:
--> 244     new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245     new_tree.append(new_t)
    246   new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242   new_tree, rest = [], flat
    243   for t in tree:
--> 244     new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245     new_tree.append(new_t)
    246   new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
    216 def tree_unflatten(flat, tree, copy_from_tree=None):
    217   """Unflatten a list into a tree given the tree shape as second argument.
    218
    219   Args:
    (...)
    237     more were provided than the number of leaves of tree (useful for recursion).
    238   """
--> 239   if copy_from_tree is not None and tree in copy_from_tree:
    240     return tree, flat
    241   if isinstance(tree, (list, tuple)):

File ~/Documents/MyWorkspace/ML/yes/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    256     return binary_op(*args)
    257   if isinstance(other, _rejected_binop_types):
--> 258     raise TypeError(f"unsupported operand type(s) for {opchar}: "
    259                     f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
    260   return NotImplemented

TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
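This new failure happens while Loop restores an existing checkpoint from output_dir (the "# Load checkpoint if it exists." line above). At trax/fastmath/numpy.py:239, tree_unflatten tests "tree in copy_from_tree", and Python's membership test falls back to ==; with trax 1.4.1 on a recent JAX, comparing an array leaf against the () sentinel raises the TypeError. A minimal sketch of the failing comparison, assuming only that jax is installed:

import jax.numpy as jnp

slots_leaf = jnp.zeros(3)  # hypothetical stand-in for an optimizer slot restored from the checkpoint
sentinels = [None, ()]     # the copy_from_tree sentinels trax passes at training.py:946

# Membership testing falls back to ==, and recent JAX arrays reject
# comparison with tuples, so this raises:
# TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
slots_leaf in sentinels

If that reading is right, this looks like a trax/JAX version incompatibility rather than a modeling bug. Starting from an empty output_dir, so that no stale checkpoint is loaded, should sidestep the load_checkpoint path; that is an assumption on my part, not something confirmed in this thread.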