apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

fsdp=16 model=16 gbs=16 should work on 256 chips #773

Open samos123 opened 3 hours ago

samos123 commented 3 hours ago

fsdp=16 model=16 global_batch_size=16 should work on 256 chips

The use case is training with a global batch size smaller than the total number of JAX processes.

This is supported in maxtext by using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291

I'm trying to get a 405B model running on v6e-256 (fsdp=16 model=16), but I'm hitting this error:

I1022 20:32:33.715831 139189201369088 trainer.py:323] gpt_trainer process  19 step       -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/trainer.py", line 244, in __init__
    self._add_child("input", cfg.input.set(is_training=True))
  File "/root/axlearn/common/module.py", line 760, in _add_child
    module = child_config.instantiate(parent=self, **kwargs)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
    self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
  File "/root/axlearn/common/config.py", line 801, in instantiate
    return self.fn(*args, **kwargs)
  File "/root/axlearn/common/input_tf_data.py", line 799, in batch
    raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
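
For context, the error comes down to simple arithmetic: the log reports 64 JAX processes (data feeds) for the 256 chips, and 16 is not a multiple of 64. Below is a minimal sketch of that kind of divisibility check. The function name `per_process_batch_size` is hypothetical and this is not the actual code in `input_tf_data.py`; it only reproduces the condition stated in the error message.

```python
def per_process_batch_size(global_batch_size: int, num_processes: int) -> int:
    """Hypothetical helper: split a global batch evenly across data feeds.

    Mirrors the condition in the error message above; not axlearn's code.
    """
    if global_batch_size % num_processes != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"number of JAX processes (data feeds) ({num_processes})."
        )
    return global_batch_size // num_processes


# A batch of 64 splits evenly across 64 processes, one example each.
print(per_process_batch_size(64, 64))

# The configuration from this issue fails, since 16 % 64 != 0.
try:
    per_process_batch_size(16, 64)
except ValueError as e:
    print(e)
```

With an even split, every process must contribute at least one example, so the smallest global batch this check admits on 64 processes is 64.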

markblee commented 2 hours ago

Hi @samos123, you can use the input dispatcher:

https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_tf_data.py#L1165-L1167
https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_dispatch.py#L17-L33

Some hosts will produce padding feeds which will be dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.
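
To illustrate the padding-feed idea described above, here is a toy simulation in plain Python. It is only a sketch under stated assumptions and does not use or reproduce the axlearn input dispatcher: the `Feed` class, the `make_feeds` and `dispatch` functions, the choice of which processes produce padding, and the one-example-per-feed size are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Feed:
    """Hypothetical per-process data feed (not an axlearn type)."""

    process_index: int
    examples: list[int]
    is_padding: bool


def make_feeds(global_batch_size: int, num_processes: int) -> list[Feed]:
    """Every process produces a feed, but only some carry real data.

    Assumption for this sketch: the first `global_batch_size` processes
    each feed one real example, and the remaining processes feed padding.
    """
    feeds = []
    for p in range(num_processes):
        is_padding = p >= global_batch_size
        example = -1 if is_padding else p  # -1 stands in for a padding example
        feeds.append(Feed(process_index=p, examples=[example], is_padding=is_padding))
    return feeds


def dispatch(feeds: list[Feed]) -> list[int]:
    """Assemble the global batch, dropping padding feeds."""
    batch = []
    for feed in feeds:
        if not feed.is_padding:
            batch.extend(feed.examples)
    return batch


feeds = make_feeds(global_batch_size=16, num_processes=64)
global_batch = dispatch(feeds)
print(len(feeds), len(global_batch))
```

In this toy setup all 64 processes still participate in every step, which keeps the per-process input shape uniform, while only 16 real examples reach the global batch.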