timaeus-research / icl

Understand the development of in-context learning in transformers with linear regression data

Train models on TPUs with PyTorch/XLA #8

Open matomatical opened 1 year ago

matomatical commented 1 year ago

We would like to run our experiments on TPUs, which requires PyTorch/XLA. At a minimum, this involves the following basics:

  1. Python 3.8: PyTorch/XLA runs only on Python 3.8, but a few of our dependencies and parts of our own code require Python 3.9+. We should identify these and work around them. See also devinterp issue 11.

  2. Deterministic randomness: Does XLA have its own TPU-based RNG? Are we seeding it? Look into this and make sure we are getting reproducible runs.

  3. Checkpointing: we need to make sure the exported models can be loaded properly. We had some trouble with this, so it seems we should first move them to the CPU, then save and store the checkpoint.

  4. Basic optimisation: Context: PyTorch/XLA tensors are computed lazily. Tensor operations construct a computational graph, and on demand (or on explicit request) the graph is compiled by an optimising compiler and then executed on the TPU. Compilation is very expensive and only pays off if the same graph is used repeatedly (such as in each iteration of a training loop), so that the compiled result can be cached and reused. So, to get baseline performance, we need to do the following:

    • make sure each training loop uses the same computational graph
    • make sure there are no accidental demand points during the loop
    • insert demand points at the end of each iteration
    • make sure it's working (i.e. does it go faster than CPU) and adjust as needed

That should lead to a ~3x speed-up once everything is working on the TPU.
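To make the compile-caching argument concrete, here is a toy model in plain Python. It is not PyTorch/XLA's implementation, and all names in it are hypothetical. Operations are only recorded into a pending trace, `mark_step` acts as the demand point that "compiles" the trace only if its structure has not been seen before, and a step-dependent Python constant baked into the graph (for example an index like `step % 100`) defeats the cache. Real PyTorch/XLA caches compiled graphs by a hash of the traced graph, and whether a given Python scalar becomes a graph constant depends on the op, so treat this as an illustration of the principle only.

```python
class ToyLazyDevice:
    """Toy stand-in for a lazy-tensor backend (hypothetical, illustration only)."""

    def __init__(self):
        self.trace = []        # ops recorded since the last mark_step
        self.cache = set()     # graph signatures that have already been "compiled"
        self.compilations = 0  # how many times we paid the compile cost

    def op(self, name, *constants):
        # Recording an op does no computation; it only extends the pending graph.
        # Python constants passed in become part of the graph's identity.
        self.trace.append((name, constants))

    def mark_step(self):
        # Demand point: compile the pending graph if it is new, then "execute" it.
        signature = tuple(self.trace)
        if signature not in self.cache:
            self.cache.add(signature)
            self.compilations += 1
        self.trace = []


def train(num_steps, step_dependent):
    dev = ToyLazyDevice()
    for step in range(num_steps):
        dev.op("forward")
        dev.op("backward")
        dev.op("optimizer_step")
        if step_dependent:
            # e.g. writing the loss into a buffer at index step % 100
            dev.op("write_loss", step % 100)
        dev.mark_step()
    return dev.compilations


print(train(1000, step_dependent=False))  # 1
print(train(1000, step_dependent=True))   # 100
```

With an identical graph every iteration the compile cost is paid once; with a step-dependent constant it is paid once per distinct constant.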

Then there are some pathways to further optimisation that seem low-hanging enough to be worth exploring:

  5. Parallelisation across TPUs (10x speed-up): Google TPU Research Cloud offers 5 x TPU v2 and 5 x TPU v3, so 10 TPUs can be conducting independent training runs. The challenge is to manage sweeps efficiently across 10 independent VMs.

    • W&B sweeps is the appropriate tool for this. It's basically working.
    • Still looking for a way to reliably run experiments on TPU VM while not logged in to SSH.
    • Still looking for a convenient way to launch the agents with a single command from local shell.
  6. Parallelisation within TPUs (up to 4x speed-up): Each TPU v2-8 or v3-8 actually has four two-core chips (so-called 'devices') that can compute in parallel. In other words, so far we are only using 1/4 of each TPU. Possibilities for further parallelisation across the four chips:

    • Parallelise across batch: Split the batch in four, forward, backward, aggregate, update, repeat. This is the common approach in most code examples I've seen. Our batches and models are pretty small, so this may not be worth the sync overhead.
    • Parallelise across runs: Get each of the four devices doing a training run in a separate, non-interacting process. Seems easier. I don't foresee any serious bottlenecks: I don't think we are anywhere near 25% of the CPU/TPU memory limits for a single process, and network access for W&B syncing seems unlikely to bottleneck. CPU-TPU communication may bottleneck; hopefully we are compute-bound and CPU access falls into a rhythm 🕺.
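For the parallelise-across-batch option, the "aggregate" step relies on a simple identity: for a mean-reduced loss and equal-sized shards, averaging the per-shard gradients gives exactly the full-batch gradient. A minimal pure-Python check, assuming a one-parameter linear model `y_pred = w * x` with MSE loss (the model, data and shard count are illustrative choices, not our actual setup):

```python
def mse_grad(w, xs, ys):
    # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def sharded_grad(w, xs, ys, num_shards=4):
    # Split the batch into equal shards (one per device), compute each shard's
    # gradient independently, then aggregate by averaging (an all-reduce mean).
    n = len(xs)
    assert n % num_shards == 0, "equal-sized shards assumed"
    size = n // num_shards
    grads = [
        mse_grad(w, xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
        for i in range(num_shards)
    ]
    return sum(grads) / num_shards


xs = [0.5, -1.0, 2.0, 3.0, 1.5, -0.5, 0.0, 4.0]
ys = [1.0, -2.0, 4.5, 6.0, 3.0, -1.0, 0.5, 8.0]
w = 1.7
print(abs(mse_grad(w, xs, ys) - sharded_grad(w, xs, ys)) < 1e-12)  # True
```

With unequal shards (e.g. a ragged final batch) a size-weighted average would be needed instead. On a real TPU the averaging is a cross-device reduction every step, which is where the sync overhead comes from.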

Stretch goals:

  7. Also use preemptible TPUs (11x speed-up): Google TPU Research Cloud offers a further 100 free preemptible TPU v2s (Dan clarifies that 'preemptible' means each VM can be killed at any point and lasts up to 24 hours, I think, after which I assume we can spawn new ones).

    • We should streamline the process of creating the VMs so that we can easily spin up new TPUs and integrate them into our system from step (5), allowing us to run up to 110 experiments in parallel.
    • We should make training more robust to stops and restarts e.g. set it up so we can continue training from the last checkpoint. There is already some code towards this but it needs to be integrated into our system from (5).

    These improvements will also be useful for running experiments on non-preemptible VMs, which also sometimes need to be respawned or training resumed from a checkpoint after a crash.

  8. More optimisation (uncertain, small speed-ups): Beyond just getting the TPU to run faster than the CPU in step (4), there is potentially more room to speed up each training run:

    • We should explore the performance using XLA metrics (and perhaps profiling) to see if there are any bottlenecks.
    • There may be certain computations, e.g. model and data set initialisation, batch generation, perhaps evaluations, that are not actually worth doing on the TPU either because they are not that much faster or because they are only done once so the compilation doesn't pay off. We should identify these and isolate them from the part of the computation that is compiled to the TPU.
    • Although the optimising compiler improves our computational graphs, we may still be able to improve the graphs we hand to it. We previously did some light profiling of attention computation and causal masking methods on GPUs. Worth revisiting for TPUs, because the conclusions might be different.
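For the stretch goal of making training robust to stops and restarts: because the training loop reseeds at every step from `sampling_seed + step`, a resumed run only needs to know the latest checkpointed step, and the batches it draws afterwards match those of an uninterrupted run. A minimal sketch, using Python's `random` as a stand-in for the data sampler; the helper names and the in-memory set of saved steps are hypothetical, and a real resume would also reload the model, optimiser and scheduler state from that checkpoint.

```python
import random


def sample_batch(sampling_seed, step, batch_size=4):
    # Per-step seeding: the batch depends only on (sampling_seed, step), not on
    # how many steps this particular process has already run.
    rng = random.Random(sampling_seed + step)
    return [rng.random() for _ in range(batch_size)]


def run(num_steps, sampling_seed, checkpoint_steps, saved_steps, crash_at=None):
    """Simulate a run that resumes after the latest saved checkpoint (if any)."""
    start = max(saved_steps) + 1 if saved_steps else 0
    seen = {}
    for step in range(start, num_steps):
        if step == crash_at:
            break  # simulate the VM being preempted
        seen[step] = sample_batch(sampling_seed, step)
        if step in checkpoint_steps:
            saved_steps.add(step)  # stand-in for checkpointer.save_file(...)
    return seen


ckpts = {10, 50, 90}
reference = run(100, 0, ckpts, saved_steps=set())  # uninterrupted run
saved = set()
run(100, 0, ckpts, saved, crash_at=73)  # preempted at step 73; saved == {10, 50}
resumed = run(100, 0, ckpts, saved)     # restarts from step 51
print(min(resumed))                                      # 51
print(all(resumed[s] == reference[s] for s in resumed))  # True
```

Steps between the last checkpoint and the crash (51 to 72 here) are simply recomputed, so denser checkpoints trade disk space for less repeated work.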
matomatical commented 1 year ago

@jqhoogland please note in particular that I did not get around to testing the TPU RNG for the first full sweep. If the runs were not deterministic as a result, this has two consequences for the data from this first run:

I will look into the deterministic TPU stuff when I get a chance. I don't know how to check the checkpoints for clashes.

jqhoogland commented 1 year ago

The checkpointer just overwrites any existing files, so the checkpoints should all be from the most recent run.

matomatical commented 1 year ago

It might be that multiple runs for the same config were happening at roughly the same time, in which case it is not clear which of the wandb runs the checkpoints belong to.

matomatical commented 1 year ago

On path 6 (parallelism within TPUs, by device)

It turns out this was straightforward with the right knowledge.

To spell this out, here's how to run four training runs in parallel. The basic principle is to set some environment variables that configure the TPU into a mode where it keeps the four devices separate, and then further variables to select one of the four devices. The environment variables to use are as follows:

# set the TPU into a mode where the four devices don't communicate
TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1
TPU_PROCESS_BOUNDS=1,1,1
# set an arbitrary distinct port for the controller for each run
TPU_MESH_CONTROLLER_ADDRESS=localhost:<port>
TPU_MESH_CONTROLLER_PORT=<port>
# set the device number, 0 or 1 or 2 or 3
TPU_VISIBLE_DEVICES=<device number>

So suppose the usual training command was as follows:

PJRT_DEVICE=TPU python train.py

Then you could run the following four commands in four different shells to do this four times in parallel (the commands differ only in the port numbers and the visible device number):

# in terminal 0:
TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476 TPU_VISIBLE_DEVICES=0 PJRT_DEVICE=TPU python train.py
# in terminal 1:
TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477 TPU_VISIBLE_DEVICES=1 PJRT_DEVICE=TPU python train.py
# in terminal 2:
TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478 TPU_VISIBLE_DEVICES=2 PJRT_DEVICE=TPU python train.py
# in terminal 3:
TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479 TPU_VISIBLE_DEVICES=3 PJRT_DEVICE=TPU python train.py

The final step would be to find an easy way to set these variables for an agent without having to copy/paste them every time (tedious and error-prone).

Edit to add: Write these commands once in a script with custom output redirection, nohup, and backgrounding; then these four lines can simply be pasted into a single SSH session for each TPU VM. That seems easy enough, and it seems to work (currently running 40 experiments at once). Will add to the guide and call it done.
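A sketch of such a launcher, written in Python rather than shell, assuming the environment variables listed above and one port per device starting at 8476 (the base port and log-file names are arbitrary choices, and `device_env`/`launch_all` are hypothetical names). `start_new_session=True` puts each child in its own session so it is not hung up when the SSH session ends, playing the role of `nohup ... & disown`.

```python
import os
import subprocess
import sys


def device_env(device, base_port=8476, base_env=None):
    """Environment for one isolated TPU device (variables as listed in this issue)."""
    env = dict(os.environ if base_env is None else base_env)
    port = base_port + device  # a distinct controller port per process
    env.update({
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_PROCESS_BOUNDS": "1,1,1",
        "TPU_MESH_CONTROLLER_ADDRESS": f"localhost:{port}",
        "TPU_MESH_CONTROLLER_PORT": str(port),
        "TPU_VISIBLE_DEVICES": str(device),
        "PJRT_DEVICE": "TPU",
    })
    return env


def launch_all(command, num_devices=4, log_prefix="device"):
    """Start one detached process per device, each appending to its own log file."""
    procs = []
    for device in range(num_devices):
        with open(f"{log_prefix}{device}.log", "ab") as log:
            procs.append(subprocess.Popen(
                command,
                env=device_env(device),
                stdout=log,
                stderr=subprocess.STDOUT,
                stdin=subprocess.DEVNULL,
                start_new_session=True,  # survive logout, like nohup + disown
            ))
    return procs


if __name__ == "__main__" and len(sys.argv) > 1:
    # e.g. python launch.py wandb agent <sweep_id>
    launch_all(sys.argv[1:])
```

Usage, once per TPU VM: `python launch.py wandb agent <sweep_id>`, after which you can log out. The parent closes its copy of each log file right after spawning; each child keeps its own handle.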

matomatical commented 1 year ago

On step 2 (Deterministic randomness)

XLA does appear to have its own RNG with a seed function. However, when I inspected some of the tensor values in the first iteration of a training run, I noticed they were already identical across runs, even though we haven't seeded it. This is confusing, and we should investigate why it is happening to ensure our runs are reproducible.

One consequence is that the runs we did in the past may already have been effectively seeded. We could confirm that, but going forward we should aim to resolve the above confusion and, if necessary, seed the XLA RNG to ensure deterministic computation.

If we do need to seed the XLA RNG manually, then here's how I think that would work:
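One possible approach, as a sketch: seed every RNG in play at once, including the XLA device RNG via `torch_xla.core.xla_model.set_rng_state`. The helper name `seed_all` is hypothetical, and whether our existing `set_seed` already covers the XLA RNG is exactly the open question above. The NumPy and torch parts are skipped if those packages are not installed, and the XLA part only runs when `xla=True`.

```python
import random


def seed_all(seed: int, xla: bool = False) -> None:
    """Seed Python, NumPy, torch and (optionally) the XLA device RNG."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass
    if xla:
        # The XLA device RNG is seeded separately here, for the default XLA device.
        import torch_xla.core.xla_model as xm
        xm.set_rng_state(seed)
```

In the training loop this could be called as `seed_all(config.task_config.sampling_seed + step, xla=XLA)`; whether reseeding the XLA RNG every step has any effect on compilation would be worth checking with the XLA metrics.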

matomatical commented 1 year ago

On 5 (Parallelisation across TPUs)

W&B sweeps is the right tool for this. The instructions have been updated accordingly.

The command to launch an agent that keeps running after logging out is:

nohup wandb agent <sweep_id> & disown

This seems to work, and any further issues with these agents should be noted separately.

matomatical commented 1 year ago

On 4 (Basic optimisations)

I modified the training loop to work with both XLA and without. Summarised here.

First, we only need to import the XLA libraries if the device is XLA. I propose configuring this with a string ('xla') and checking at the start of training whether we want to use XLA. I tried overwriting config.device with the initialised device object, but pydantic rejected it, so I introduced a new variable, device, which should be used throughout instead of config.device.

# special code if device is 'xla'
XLA = (config.device == 'xla')
if XLA:
    stdlogger.info("device is 'xla'! some special code will run...")
    stdlogger.info("importing torch_xla...")
    import torch_xla.core.xla_model as xm
    # import torch_xla.debug.metrics as met
    stdlogger.info("configuring default XLA device...")
    device = xm.xla_device()
    stdlogger.info("xla ready!")
else:
    device = config.device

Then we initialise the model and data as usual (except using device instead of config.device):

# model initialisation
stdlogger.info("initialising model")
model = config.task_config.model_factory().to(device)
model.train()

# initialise 'pretraining' data source (for training on fixed task set)
stdlogger.info("initialising data (pretrain)")
pretrain_dist = config.task_config.pretrain_dist_factory().to(device)

# initialise 'true' data source (for evaluation, including unseen tasks)
stdlogger.info("initialising data (true)")
true_dist = config.task_config.true_dist_factory().to(device)

The evaluator involves running some code (the baselines) on this device, so we need to use mark_step to separate these if the device is XLA:

# initialise evaluations
stdlogger.info("initialising evaluator")
if XLA: xm.mark_step()
evaluator = ICLEvaluator(
    pretrain_dist=pretrain_dist,
    true_dist=true_dist,
    max_examples=config.task_config.max_examples,
    eval_batch_size=config.eval_batch_size,
    seed=config.task_config.true_seed
)
if XLA: xm.mark_step()

Initialise the monitoring code and optimisers as usual (this shouldn't need any XLA-specific handling?)

# initialise monitoring code
stdlogger.info("initialising checkpointer and logger")
checkpointer = config.checkpointer_config.factory() if config.checkpointer_config is not None else None
logger = config.logger_config.factory() if config.logger_config is not None else None

# initialise torch optimiser
stdlogger.info("initialising optimiser and scheduler")
optimizer = config.optimizer_config.factory(model.parameters())
scheduler = config.scheduler_config.factory(optimizer)  # type: ignore

(Actually, come to think of it, model.parameters() is on the XLA device, TODO later: see if a mark step after that helps?)

There was some code to log recent losses; however, I think this might affect the computational graph because of its dependence on 'step'? I don't know, it's worth testing; for now I have disabled it.

# TODO: this is unused and may be slowing down XLA... use it or lose it
# recent_losses = torch.zeros(100, device=device)

Now the training loop!

# training loop
stdlogger.info("starting training loop")
stdlogger.info("note: first two iterations slow while XLA compiles")
stdlogger.info("note: early iterations slow due to logspace checkpoints")
for step in tqdm.trange(config.num_steps, desc="training..."):
    # per-step seeds for reproducibility if we resume training
    set_seed(config.task_config.sampling_seed + step)

The first thing inside the training loop is the training step itself. I bounded this by mark steps to make sure it's really isolated.

    # training step
    if XLA: xm.mark_step()
    xs, ys = pretrain_dist.get_batch(
        num_examples=config.task_config.max_examples,
        batch_size=config.batch_size,
    )
    ys_pred = model(xs, ys)
    loss = F.mse_loss(ys, ys_pred)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    scheduler.step()
    if XLA: xm.mark_step()

More of the recent-losses code, commented out:

    # see above
    # recent_losses[step % 100] = loss

Logging the batch loss calls loss.item(), which pulls the value from the device (a demand point); consider marking?

    # wandb logging: log batch loss every 100 steps
    if step % 100 == 0 and step > 0 and config.is_wandb_enabled:
        stdlogger.info("logging batch loss at step %s", step)
        # TODO: Figure out how to make this work with `logger`
        wandb.log({"batch/loss": loss.mean().item()}, step=step)

Every now and then we run the evaluations; definitely mark these so they are compiled separately.

    # evaluate and log metrics to wandb according to log_steps
    if step in config.logger_config.logging_steps:
        stdlogger.info("evaluating metrics at step %s", step)
        if XLA: xm.mark_step()
        model.eval()
        metrics = evaluator(model)
        model.train()
        if XLA: xm.mark_step()
        stdlogger.info("logging metrics at step %s", step)
        logger.log(metrics, step=step)

And finally checkpointing. Here's where I suggest we should move the model off the TPU (to deal with part 3 of this issue). I don't know if that requires marking, probably not?

    # save checkpoints according to checkpoint_steps
    if step in config.checkpointer_config.checkpoint_steps:
        # TODO: if xla: move model to CPU before saving
        stdlogger.info("saving checkpoint at step %s", step)
        if XLA: xm.mark_step()
        checkpointer.save_file(step, state_dict(model, optimizer, scheduler))
        if XLA: xm.mark_step()
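For the CPU move flagged in the TODO, one option is to recursively copy every tensor in the state dict to host memory before handing it to the checkpointer. A minimal sketch: `to_cpu` is a hypothetical helper that treats anything with a `.cpu()` method as a tensor, so it handles the nested model, optimiser and scheduler state without importing torch. torch_xla also provides `xm.save`, which does this transfer itself, though using it would bypass our checkpointer.

```python
def to_cpu(obj):
    """Recursively copy tensors (anything with a .cpu() method) to host memory."""
    if hasattr(obj, "cpu"):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(value) for value in obj)
    return obj  # ints, floats, strings, None, ... are left as-is
```

In the loop, this would wrap the existing call: `checkpointer.save_file(step, to_cpu(state_dict(model, optimizer, scheduler)))`. Copying to the CPU forces the pending computation to run, so it is itself a demand point, which the surrounding mark steps already bracket.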

That's all folks.

if config.is_wandb_enabled:
    wandb.finish()

We probably shouldn't return a model that lives on the TPU device to calling code that probably doesn't use the TPU.

# TODO: if XLA, move model off TPU?
return model
matomatical commented 1 year ago

^Oh yeah, it runs fast like this (at least faster than on CPU), but I'm unsure whether it's optimal. In particular, there may be more mark steps than necessary, and I don't know whether they add overhead. Leaving that for later (see the 'More optimisation' stretch goal in the top post of this issue).