tonyjohnchen closed this 2 months ago
Just curious, what's the model? The math seems right to me, and the throughput is huge.
Just the MoE model, with max_target_length=2048 instead of max_target_length=4096:
python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} \
per_device_batch_size=24 enable_checkpointing=false async_checkpointing=false \
model_name=default ici_fsdp_parallelism=4 skip_first_n_steps_for_profiler=5 steps=10 max_target_length=2048 \
tokenizer_path=assets/tokenizer.mistral attention=flash dtype=bfloat16 weight_dtype=bfloat16 dataset_type=synthetic \
profiler=xplane
Adding Tokens/s/device to the print log for each step.
Example:
completed step: 9, seconds: 1.341, TFLOP/s/device: 257.929, Tokens/s/device: 36653.244, loss: 0.000
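
As a sanity check on the example line, the logged figure is consistent with tokens per device per step (per_device_batch_size × max_target_length) divided by the step time. The sketch below is only an illustration of that arithmetic, using the values from the command and log above. The function name is hypothetical and is not taken from MaxText's code.

```python
def tokens_per_second_per_device(per_device_batch_size: int,
                                 max_target_length: int,
                                 step_seconds: float) -> float:
    """Hypothetical helper: tokens processed per device per second for one step."""
    # Each device processes per_device_batch_size sequences of max_target_length tokens.
    tokens_per_device = per_device_batch_size * max_target_length
    return tokens_per_device / step_seconds


# Values from the command (per_device_batch_size=24, max_target_length=2048)
# and from the example log line (seconds: 1.341).
rate = tokens_per_second_per_device(24, 2048, 1.341)
print(f"Tokens/s/device: {rate:.3f}")
```

With these inputs the result agrees with the logged 36653.244 to within rounding of the reported step time.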