kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Please help with fine-tuning small dataset #153

Closed. ilyakar closed this issue 2 years ago

ilyakar commented 2 years ago

Hi! Firstly, thank you so much for looking at this post. I could really use some help.

I'm trying to fine-tune GPT-J with a small dataset of ~500 lines:

You are important to me. <|endoftext|>
I love spending time with you. <|endoftext|>
You make me smile. <|endoftext|>
I feel so lucky to be your friend. <|endoftext|>
You can always talk to me, even if it’s about something that makes you nervous or scared or sad. <|endoftext|>

Running the create_finetune_tfrecords.py script outputs a file with a 2 in its name, which I understand means it contains 2 sequences.
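
That roughly matches my own back-of-the-envelope count (just a sketch of my reasoning; the average tokens per line is my own guess, not something measured):

# Rough check of why ~500 short lines pack into only 2 sequences of 2048 tokens.
lines = 500
tokens_per_line = 8              # my guess for short sentences plus <|endoftext|>
seq_len = 2048                   # "seq" in the config below

total_tokens = lines * tokens_per_line        # ~4,000 tokens
sequences = -(-total_tokens // seq_len)       # ceiling division -> 2 sequences
print(total_tokens, sequences)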

I could really use some advice on the .json config file. What constants do you recommend for this small dataset? The best I came up with while following the guide is:

{
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 2,

  "warmup_steps": 1,
  "anneal_steps": 9,
  "lr": 1.2e-4,
  "end_lr": 1.2e-5,
  "weight_decay": 0.1,
  "total_steps": 10,

  "tpu_size": 8,

  "bucket": "chat-app-tpu-bucket-europe",
  "model_dir": "finetune_dir",

  "train_set": "james_bond_1.train.index",
  "val_set": {},

  "eval_harness_tasks": [
  ],

  "val_batches": 2,
  "val_every": 400000,
  "ckpt_every": 1,
  "keep_every": 1,

  "name": "GPT3_6B_pile_rotary",
  "wandb_project": "mesh-transformer-jax",
  "comment": ""
}
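
For what it's worth, here is the arithmetic behind my step counts (a rough sketch of my own reasoning; I'm assuming the global batch is per_replica_batch × replicas × gradient_accumulation_steps):

# My rough reasoning about the schedule above, given only 2 training sequences.
tpu_size = 8
cores_per_replica = 8
per_replica_batch = 1
gradient_accumulation_steps = 2
dataset_sequences = 2            # from the tfrecords file name
total_steps = 10

replicas = tpu_size // cores_per_replica                                   # 1 replica on a v3-8
global_batch = per_replica_batch * replicas * gradient_accumulation_steps  # 2 sequences per step
steps_per_epoch = max(1, dataset_sequences // global_batch)                # 1
epochs = total_steps / steps_per_epoch                                     # ~10 passes over the data
print(global_batch, steps_per_epoch, epochs)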

Very much looking forward to hearing from you! :)

kingoflolz commented 2 years ago

You need to have much more data; 2 sequences is not enough. Tens of megabytes is around the minimum required.
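
As a rough back-of-the-envelope for what that means in sequences (assuming ~4 bytes of raw English text per GPT-2 BPE token, which is only an approximation):

# What "tens of megabytes" translates to in 2048-token training sequences.
dataset_bytes = 20 * 1024**2     # e.g. 20 MB of plain text
bytes_per_token = 4              # rough average for English under the GPT-2 BPE
seq_len = 2048

tokens = dataset_bytes // bytes_per_token     # ~5.2M tokens
sequences = tokens // seq_len                 # ~2,500 sequences vs. your 2
print(tokens, sequences)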

ilyakar commented 2 years ago

500 lines works well with helloforefront.com (they fine-tune GPT-J on their side), and it also works very well with OpenAI. So I don't see why tens of MB are needed here.