google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Adding Tokens/s/device to the log. #761

Closed: tonyjohnchen closed this 2 months ago

tonyjohnchen commented 2 months ago

Adding Tokens/s/device to the print log for each step.

Example: completed step: 9, seconds: 1.341, TFLOP/s/device: 257.929, Tokens/s/device: 36653.244, loss: 0.000

RissyRan commented 2 months ago

Just curious, what's the model? The math seems right to me, and the throughput is huge.

tonyjohnchen commented 2 months ago

> Just curious, what's the model? The math seems right to me, and the throughput is huge.

Just the MoE model with max_target_length=2048 instead of max_target_length=4096.

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} \
per_device_batch_size=24 enable_checkpointing=false async_checkpointing=false \
model_name=default ici_fsdp_parallelism=4 skip_first_n_steps_for_profiler=5 steps=10 max_target_length=2048  \
tokenizer_path=assets/tokenizer.mistral attention=flash dtype=bfloat16 weight_dtype=bfloat16 dataset_type=synthetic \
profiler=xplane
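
For anyone who wants to sanity-check the number in the example log line, here is a minimal sketch of the arithmetic. It assumes each device processes per_device_batch_size * max_target_length tokens per step; the function name is hypothetical and this is not necessarily how MaxText computes the value internally.

```python
def tokens_per_second_per_device(per_device_batch_size, max_target_length, step_seconds):
    """Hypothetical helper: tokens processed per second on one device.

    Assumes every step consumes per_device_batch_size sequences of
    max_target_length tokens on each device.
    """
    tokens_per_device_per_step = per_device_batch_size * max_target_length
    return tokens_per_device_per_step / step_seconds


# Values taken from the command and the example log line in this thread.
tokens_per_step = 24 * 2048  # per_device_batch_size * max_target_length
rate = tokens_per_second_per_device(24, 2048, 1.341)
print(f"Tokens/s/device: {rate:.3f}")
```

With these inputs the result agrees with the logged Tokens/s/device of 36653.244 to within the rounding of the logged step time (1.341 s).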