hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: Lazy init failed for initializing falcon model #4664

Open · luckyyangrun opened this issue 1 year ago

luckyyangrun commented 1 year ago

🐛 Describe the bug

Lazy init failed when initializing the Falcon model. We modified the llama training example, replacing the llama model with the falcon model, but encountered the error shown in the screenshot below.

[error screenshot]
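For context, the change amounts to roughly the following (a hypothetical sketch, assuming the stock LazyInitContext API from colossalai.lazy and the transformers Falcon classes; the actual example code may differ):

#### hypothetical sketch of the failing setup
from colossalai.lazy import LazyInitContext
from transformers import FalconConfig, FalconForCausalLM

config = FalconConfig()  # or a config loaded for a specific Falcon checkpoint
with LazyInitContext():
    model = FalconForCausalLM(config)  # the lazy build of the Falcon model is where the reported failure occurs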

Environment

torch 2.0.1+cu117, transformers 4.32.0, ColossalAI release v0.3.2

flybird11111 commented 1 year ago


Hi, the ColossalAI lazy_init feature is currently not compatible with torch 2.0.

luckyyangrun commented 1 year ago

Thank you for your response. I currently have two questions: 1. Is there a plan to support Torch 2.0? 2. Is there a solution for using lazy initialization with Falcon models?

ver217 commented 1 year ago

Hi, lazy init does not support torch 2.0 yet and we're working on it. Could you try torch 1.13.1?

kurisusnowdeng commented 1 year ago


Hi @luckyyangrun. We are adjusting the design of lazy initialization to improve its compatibility and will release the update as soon as possible. In the meantime, here is an alternative way to do lazy initialization with torch 2.0: build your model on the torch meta device and convert each parameter to a CUDA one when necessary. For example:

#### in your script
import torch
from transformers import FalconConfig, FalconForCausalLM  # requires a transformers version that ships the Falcon classes
from colossalai.zero import GeminiDDP

config = FalconConfig()  # or a config loaded for your Falcon checkpoint
torch.set_default_device('meta')  # parameters are created on the meta device, so no real memory is allocated
model = FalconForCausalLM(config)
torch.set_default_device('cuda')
model = GeminiDDP(model, ...)  # shard the model as you want, e.g. zero3
model.init_weights()  # initialize cuda parameters or load a state dict from a local checkpoint

### a little modification in colossalai/zero/gemini/gemini_ddp.py
def _preprocess_param(self, p):
    if p.is_meta:  # materialize meta parameters on cuda before Gemini processes them
        p = torch.nn.Parameter(torch.empty_like(p, device='cuda'))
    ...  # the rest of the original _preprocess_param follows unchanged

Note that lazy initialization is only necessary when the model is too large to fit on a single device. Otherwise, you can simply initialize the complete model directly on each GPU and then apply whatever sharding you want.
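For completeness, that non-lazy path would look roughly like this (a sketch, assuming the model fits on a single GPU; the GeminiDDP arguments are elided just as in the snippet above):

#### non-lazy sketch: build the full model on each GPU, then shard
from transformers import FalconConfig, FalconForCausalLM
from colossalai.zero import GeminiDDP  # import path assumed for recent ColossalAI releases

config = FalconConfig()
model = FalconForCausalLM(config).cuda()  # materialize and randomly initialize the whole model on the GPU
model = GeminiDDP(model, ...)             # then shard it as you want, e.g. zero3; arguments elided as above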

ver217 commented 1 year ago


This only works for loading pretrained weights (from_pretrained). For from-scratch training, the meta-device recipe above leads to wrong initial values in the model.
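To illustrate part of the concern (a standalone sketch, not ColossalAI code): torch.empty_like on a meta parameter only allocates memory and leaves it uninitialized, so the resulting values are only meaningful once a pretrained state dict is copied over them; for from-scratch training the random initialization would have to be redone correctly, which, per the comment above, this recipe does not guarantee.

#### standalone illustration of why the materialized values are not a valid fresh init
import torch

w_meta = torch.nn.Parameter(torch.empty(4, 4, device='meta'))         # meta tensor: shape/dtype only, no data
w_cuda = torch.nn.Parameter(torch.empty_like(w_meta, device='cuda'))  # allocates uninitialized GPU memory
print(w_cuda)  # arbitrary values until a checkpoint's weights are loaded into it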

xs1997zju commented 1 year ago

@kurisusnowdeng how should _preprocess_param be modified?

kurisusnowdeng commented 1 year ago


@luckyyangrun like below:

    requires_grad = p.requires_grad
    if p.is_meta:  # added check: materialize meta parameters on the current device
        p = torch.nn.Parameter(torch.empty_like(p, device=get_current_device()))  # get_current_device comes from colossalai.utils
    if isinstance(p, LazyTensor):
        ...  # rest of the original method unchanged