mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Add curriculum learning callback #1256

Closed b-chu closed 1 week ago

b-chu commented 1 month ago

Curriculum learning callback

Requirements

Manual tests

Matches old callback behavior

[screenshot of loss curves]

Resumes correctly in the middle of the schedule

[screenshot of loss curves]

Resumes correctly when new datamix added to schedule

[screenshot of loss curves]

Resumes correctly when callback added after initial training run

[screenshot of loss curves]
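The resumption tests above hinge on one piece of bookkeeping: given the schedule's per-entry token durations and the token count restored from a checkpoint, the callback must work out which datamix is active and how far into it training is. A minimal standalone sketch of that lookup, assuming token-denominated durations (the function name and shape are illustrative, not the actual llm-foundry implementation):

```python
# Hypothetical sketch: locate the active schedule entry on resumption.
# `durations_tok` is the list of per-entry durations in tokens;
# `tokens_seen` is the global token count restored from the checkpoint.

def find_schedule_position(durations_tok, tokens_seen):
    """Return (entry_index, tokens_into_entry) for the schedule entry
    that covers `tokens_seen` tokens of total progress."""
    consumed = 0
    for i, dur in enumerate(durations_tok):
        if tokens_seen < consumed + dur:
            return i, tokens_seen - consumed
        consumed += dur
    # Past the end of the schedule: stay on the last entry.
    return len(durations_tok) - 1, tokens_seen - (consumed - durations_tok[-1])

# Three 5M-token datamixes, resuming after 7M tokens:
# we are 2M tokens into the second entry.
print(find_schedule_position([5_000_000] * 3, 7_000_000))  # (1, 2000000)
```

With this position in hand, the callback can rebuild the matching train_loader and fast-forward it, which is what the "resumes correctly in the middle of the schedule" test exercises.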

API

Old API:

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 0

Start a new run

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 1

Start a new run

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 2

New API:

train_loader:
  <dataloader parameters>
callbacks:
  curriculum_learning:
  - duration: <number>tok
    train_loader:  # matches top level train_loader
      <dataloader parameters>
  - duration: <number>tok
    train_loader:
      <dataloader parameters>
  - duration: <number>tok
    train_loader:
      <dataloader parameters>
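Each `duration` in the schedule above is a token-denominated timestring like `5000000tok`. In practice llm-foundry parses these with Composer's `Time` type; the standalone regex parser below is only an illustrative sketch of the `<number>tok` format, not the real implementation:

```python
import re

# Hypothetical parser for the `<number>tok` duration strings shown in the
# schedule config. Illustrative only; the actual code relies on Composer's
# Time abstraction, which also supports other units.
def parse_token_duration(s: str) -> int:
    m = re.fullmatch(r'(\d+)tok', s.strip())
    if m is None:
        raise ValueError(f'expected "<number>tok", got {s!r}')
    return int(m.group(1))

print(parse_token_duration('5000000tok'))  # 5000000
```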
snarayan21 commented 3 weeks ago

@b-chu about the new API, a couple of questions:

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    duration: 5000000tok
    schedule:
    - duration: 5000000tok
      train_loader:
        <some params>
    - duration: 5000000tok
      train_loader:
        <some params>
1. So I still have to specify train_loader as a top-level entry?
2. The first duration specified is for the top-level train_loader?
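Under the semantics asked about here, the top-level train_loader plus `curriculum_learning.duration` act as the implicit first schedule entry, followed by the explicit `schedule` entries. A small sketch of that normalization over plain dicts (function name and example datamix names are made up for illustration):

```python
# Hypothetical normalization of the proposed config: fold the top-level
# train_loader and curriculum_learning.duration into the first entry of
# the full schedule. Not the actual llm-foundry code.

def build_full_schedule(config):
    cl = config['callbacks']['curriculum_learning']
    first = {'duration': cl['duration'],
             'train_loader': config['train_loader']}
    return [first] + cl['schedule']

cfg = {
    'train_loader': {'dataset': 'mix_a'},  # illustrative datamix name
    'callbacks': {'curriculum_learning': {
        'duration': '5000000tok',
        'schedule': [
            {'duration': '5000000tok', 'train_loader': {'dataset': 'mix_b'}},
            {'duration': '5000000tok', 'train_loader': {'dataset': 'mix_c'}},
        ],
    }},
}
print(len(build_full_schedule(cfg)))  # 3
```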
snarayan21 commented 3 weeks ago

Also, I'm worried about the loss curves in the plots you shared; they don't look fully deterministic to me. What model size and batch size were you running at, and with which datasets? Longer training runs with a bigger model and a small batch size, without shuffling, would help determine whether the loss curves are actually deterministic. Looking only at the first few steps, most training runs will look pretty similar regardless of data ordering.

b-chu commented 2 weeks ago

Yes, this needs a Composer release; I'll rerun CI/CD after that release and before merging. Yes, train_loader is still specified, and curriculum_learning.duration is its duration. We discussed offline with the data team, and they'll try the callback later when doing a longer training run. I think there are slight discrepancies in RNG when running interactively, but compared against a run with no CL callback, the new callback matches the loss exactly while the old callback differs slightly. Also, when comparing two different datasets/splits, the loss difference is much greater than in the plots above.

[screenshot of loss curves]