pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Kaggle Notebook: model returns loss None on TPU #8402

Open manh3152924 opened 2 days ago

manh3152924 commented 2 days ago

❓ Questions and Help

Hi, I'm getting loss = None when training a model on TPU. Can anyone help?

Simple reproduction: Kaggle notebook link

import os
import time
import pandas as pd
import numpy as np

from tqdm import tqdm

import datasets
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_env_vars as xenv
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd.xla_sharding import Mesh
import torch_xla.runtime as xr

import re
from datasets import Dataset, load_dataset
import transformers
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from transformers import logging as hf_logging

hf_logging.set_verbosity_error()

os.environ["PJRT_DEVICE"] = "TPU"

class CFG:
    NUM_EPOCHS = 1
    BATCH_SIZE = 24
    DROPOUT = 0.05
    MODEL_NAME = 'unsloth/Qwen2.5-7B-Instruct'
    SEED = 2024
    MAX_LENGTH = 4096
    NUM_WARMUP_STEPS = 128
    LR_MAX = 2e-4
    NUM_LABELS = 3
    LORA_RANK = 16
    LORA_ALPHA = 16
    LORA_MODULES = ['o_proj', 'v_proj', 'q_proj', 'k_proj']

FLAGS = {'MAX_INPUT': 64,
         'LOGGING_STEPS': 10,
         'NUM_EPOCHS': 3,
         'BATCH_SIZE': 24,
        }

MAX_INPUT=128
MODEL = "unsloth/Qwen2.5-7B-Instruct"

def get_dataset():
    tokenizer = AutoTokenizer.from_pretrained(CFG.MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'
    tokenizer.add_eos_token = True

    # save tokenizer to load offline during inference
    tokenizer.save_pretrained('tokenizer')
    max_seq_length = 4096
    tokenizer_x = AutoTokenizer.from_pretrained(CFG.MODEL_NAME, max_seq_length=max_seq_length)
    tokenizer_x.pad_token_id = tokenizer.eos_token_id
    df = datasets.load_dataset('stanfordnlp/imdb', split='train')
    # df = df['train']
    df = df.remove_columns(['label'])

    def preprocess(tasks, train_mode=True):
        return {"text": 'this is test'}
    df = df.map(preprocess, batched = False, remove_columns=df.column_names)
    print(df)
    def preprocess_function(example):
        x = tokenizer(example["text"], truncation=True, max_length=4096, padding='max_length')

        return {
            "input_ids": x.input_ids,
            "labels": 0,
            "attention_mask": x.attention_mask
        }

    data_train = df.map(preprocess_function, batched=False, num_proc=4).remove_columns(['text'])

    return data_train, tokenizer, FLAGS

##############################################################################################################################################
def train(data_train, tokenizer, FLAGS):
#     print('rank', rank)
    N_SAMPLES = len(data_train)
    STEPS_PER_EPOCH = N_SAMPLES // CFG.BATCH_SIZE
    METRICS = {
    'loss': [],
    'accuracy': {'y_true': [], 'y_pred': [] }}
    device = xm.xla_device()
    print('device', device)
    num_devices = xr.global_runtime_device_count()  # 8 on a v3-8
    model_axis = 1
    mesh_shape = (1, num_devices // model_axis, model_axis)  # e.g. (1, 8, 1) on a v3-8 with model_axis=1

    device_ids = np.array(range(num_devices))

    mesh = Mesh(device_ids, mesh_shape, ('dcn', 'data', 'model'))

    print('world_size:', xm.xrt_world_size())
    rng = torch.Generator().manual_seed(42)
    training_loader = torch.utils.data.DataLoader(data_train,
                                                  batch_size=FLAGS['BATCH_SIZE'],
                                                  collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
#                                                   sampler=train_sampler,
                                                  drop_last=True, generator=rng)

    sharding_spec = xs.ShardingSpec(mesh, (('dcn', 'data'), None))
    xla_train_loader = pl.MpDeviceLoader(training_loader,
                                         device=xm.xla_device(),
                                         input_sharding=sharding_spec,
                                         device_prefetch_size=16)

    base_model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)

    base_model.config.pretraining_tp = 1

    tokenizer.pad_token = tokenizer.eos_token  # If pad_token is not set
    base_model.config.pad_token_id = tokenizer.pad_token_id  # Ensure the model respects the pad_token

    lora_config = LoraConfig(
        r=CFG.LORA_RANK,  # the dimension of the low-rank matrices
        lora_alpha = CFG.LORA_ALPHA, # scaling factor for LoRA activations vs pre-trained weight activations
        lora_dropout= CFG.DROPOUT,
        bias='none',
        inference_mode=False,
        task_type='CAUSAL_LM',
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])  # apply LoRA to the attention projections

    # Create LoRa Model
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()

    model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)

    model = model.to(device)
    xm.master_print('Done load model')

    if hasattr(model, "tie_weights"):
        print('tie_weight')
        model.tie_weights()
    print(model)

    # Shard the 2-D weight matrices across the mesh; skip 1-D params (biases, norms)
    for name, param in model.named_parameters():
        if len(param.shape) == 1:
            continue

        if 'embed_tokens' in name:
            xs.mark_sharding(param, mesh, ('model', 'data'))
        elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
            xs.mark_sharding(param, mesh, ('data', 'model'))
        elif 'o_proj' in name:
            xs.mark_sharding(param, mesh, ('model', 'data'))
        elif 'gate_proj' in name or 'up_proj' in name:
            xs.mark_sharding(param, mesh, ('model', 'data'))
        elif 'down_proj' in name:
            xs.mark_sharding(param, mesh, ('data', 'model'))
        elif 'lm_head' in name:
            xs.mark_sharding(param, mesh, ('model', 'data'))

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.LR_MAX)

    # Cosine Learning Rate With Warmup
    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=CFG.NUM_WARMUP_STEPS,
        num_training_steps=300 * CFG.NUM_EPOCHS)

    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        s = time.time()
        for step, data in tqdm(enumerate(xla_train_loader)):
            with xla.step():
                if step%200==0:
                    print(f'tracing {step=}')
                input_ids, labels, attention_mask = data.input_ids, data.labels, data.attention_mask
                optimizer.zero_grad()

                xm.master_print(input_ids)
                xm.master_print(attention_mask)
                xm.master_print(attention_mask.shape) # (24,4096)
                xm.master_print(input_ids.shape) # (24,4096)
                with xla.amp.autocast(xm.xla_device()):

                    outputs = model(input_ids=input_ids,
                                    attention_mask=attention_mask
                                   )
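                # NOTE: `labels` is unpacked from the batch above but never passed to
                # the forward call; Hugging Face causal-LM models only compute a loss
                # when `labels` is provided, so `outputs.loss` comes back as None.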
                # Logits Float32
                print(dir(outputs))
                loss = outputs.loss # None here
                xm.master_print(loss)

                loss.backward()
                lr_scheduler.step()

                xm.optimizer_step(optimizer)

        # If stopped, and to continue training in future on tpu we save model and optimizer
        xm.save({k: v.cpu() for k, v in model.named_parameters() if v.requires_grad}, f'model_qwen2.5_cp_{epoch+1}_v1.pth')
        xm.save(optimizer.state_dict(), f'optimizer_qwen_2.5_cp_{epoch+1}_v1.pth')

        print(f'Model saved at epoch {epoch+1}')
    xm.rendezvous('init')

    return model

if __name__ == '__main__':
    print('Load')
    data_train, tokenizer, FLAGS = get_dataset()
    import gc
    gc.collect()
    print(len(data_train))
    xr.use_spmd()
    model = train(data_train, tokenizer, FLAGS)
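
A stripped-down sketch of the same forward call outside the training loop (same checkpoint as above, run on CPU, illustrative only). The transformers docs say the loss is only returned when `labels` is passed to `forward`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained('unsloth/Qwen2.5-7B-Instruct')
lm = AutoModelForCausalLM.from_pretrained('unsloth/Qwen2.5-7B-Instruct', torch_dtype=torch.bfloat16)

batch = tok('this is test', return_tensors='pt')

out = lm(**batch)                             # no labels -> out.loss is None
out = lm(**batch, labels=batch['input_ids'])  # with labels -> out.loss is a scalar tensor
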
manh3152924 commented 1 day ago

@JackCaoG Do you have any suggestions?