huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Unexpected results of the lm_head when averaging model parameters #32272

Open wizard1203 opened 1 month ago

wizard1203 commented 1 month ago

System Info

transformers==4.32.1, torch==2.0.1

Who can help?

No response

Reproduction

When model parameters are averaged iteratively, one parameter at a time, using torch.distributed, i.e. dist.all_reduce(parameter, op=dist.ReduceOp.SUM) followed by dividing the result by the number of workers, the resulting lm_head weight is actually the sum of the workers' weights rather than their average.

This bug may not be limited to distributed all-reduce; it can also appear in other sum-then-divide averaging schemes, such as Stochastic Weight Averaging (SWA), which is frequently used in optimization, or federated learning.
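A minimal sketch of the non-distributed case, assuming a FedAvg/SWA-style loop that accumulates in place into a deep copy of the first state dict. The helpers tied_state_dict and naive_average are hypothetical; tied_state_dict only mimics GPT-2's tying of lm_head to the token embedding using plain torch.nn modules. With four clients whose weights are 1, 2, 3 and 4 (true average 2.5), the weight under wte.weight is averaged correctly, while lm_head.weight is not:

```python
from copy import deepcopy

import torch
from torch import nn


def tied_state_dict(value):
    """Hypothetical toy model: lm_head.weight is tied to wte.weight, filled with `value`."""
    emb = nn.Embedding(8, 4)
    head = nn.Linear(4, 8, bias=False)
    head.weight = emb.weight  # weight tying, similar to GPT-2's lm_head / transformer.wte
    with torch.no_grad():
        emb.weight.fill_(value)
    return nn.ModuleDict({"wte": emb, "lm_head": head}).state_dict()


def naive_average(state_dicts):
    """Sum-then-divide averaging of several state dicts (assumed FedAvg/SWA-style loop)."""
    avg = deepcopy(state_dicts[0])  # deepcopy keeps the two tied keys sharing one storage
    for name in list(avg.keys()):
        for other in state_dicts[1:]:
            avg[name] += other[name]  # in place: also changes every alias of avg[name]
        avg[name] = avg[name] / len(state_dicts)  # out of place: rebinds only this key
    return avg


clients = [tied_state_dict(v) for v in (1.0, 2.0, 3.0, 4.0)]
result = naive_average(clients)
print(result["wte.weight"][0, 0].item())      # 2.5
print(result["lm_head.weight"][0, 0].item())  # 4.75, not 2.5
```

The wrong value differs from the distributed case (here it is neither the average nor the sum), but the cause is the same: in-place updates reach both keys, while the rebinding division reaches only one.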

This bug is caused by the lm_head reusing (tying) existing parameters; see https://discuss.huggingface.co/t/why-is-the-lm-head-layer-in-gpt2lmheadmodel-not-a-parameter/639.
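To see the parameter reuse in isolation, here is a minimal sketch with plain torch.nn modules (no transformers). TiedLM is a hypothetical toy model that ties its lm_head weight to its embedding, similar in spirit to GPT-2. state_dict() reports the tied weight under both names, backed by the same storage, whereas named_parameters() deduplicates it:

```python
import torch
from torch import nn


class TiedLM(nn.Module):
    """Hypothetical toy model that ties lm_head to the token embedding."""

    def __init__(self, vocab_size=8, dim=4):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: both modules now hold the very same Parameter object.
        self.lm_head.weight = self.wte.weight


model = TiedLM()
sd = model.state_dict()

# state_dict() lists the tied weight under BOTH names ...
print(sorted(sd.keys()))  # ['lm_head.weight', 'wte.weight']
# ... as two detached tensors that share one storage.
print(sd["wte.weight"].data_ptr() == sd["lm_head.weight"].data_ptr())  # True
# named_parameters() skips duplicate (shared) parameters by default.
print([name for name, _ in model.named_parameters()])  # ['wte.weight']
```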

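Specifically, the lm_head reuses the parameters of transformer.wte, so the state dict contains both transformer.wte.weight and lm_head.weight as tensors that share the same storage (and deepcopy preserves this sharing). During the parameter-by-parameter all_reduce, the in-place call

dist.all_reduce(avg_params[name], op=dist.ReduceOp.SUM)

therefore sums the shared tensor, i.e. both transformer.wte and lm_head. However, when

avg_params[name] = avg_params[name] / dist.get_world_size()

is executed for transformer.wte, the division creates a new tensor and rebinds only avg_params["transformer.wte.weight"]; avg_params["lm_head.weight"] still holds the summed values. When the loop later reaches lm_head.weight, those summed values are all-reduced once more and divided by the number of workers, which yields the sum again. As a result, lm_head ends up with the sum of the workers' weights instead of their average.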
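The mechanism can be reproduced in a single CPU process by simulating the all-reduce. In this sketch, fake_all_reduce_sum is a hypothetical stand-in for dist.all_reduce(..., op=dist.ReduceOp.SUM) that overwrites each worker's tensor in place with the element-wise sum over all workers, and tied_state_dict is a hypothetical toy model whose state dict lists wte.weight before lm_head.weight, as GPT-2 lists transformer.wte.weight first. The loop mirrors the two lines quoted above:

```python
from copy import deepcopy

import torch
from torch import nn

WORLD_SIZE = 4  # number of simulated workers (assumption for this sketch)


def tied_state_dict(value):
    """Hypothetical toy model: lm_head.weight is tied to wte.weight, filled with `value`."""
    emb = nn.Embedding(8, 4)
    head = nn.Linear(4, 8, bias=False)
    head.weight = emb.weight  # weight tying, similar to GPT-2
    with torch.no_grad():
        emb.weight.fill_(value)
    return nn.ModuleDict({"wte": emb, "lm_head": head}).state_dict()


def fake_all_reduce_sum(tensors):
    """Simulated SUM all-reduce: every worker's tensor becomes the total, in place."""
    total = torch.stack(tensors).sum(dim=0)
    for t in tensors:
        t.copy_(total)


# Worker w holds weights equal to w + 1, i.e. 1, 2, 3, 4 (average 2.5, sum 10).
workers = [deepcopy(tied_state_dict(w + 1.0)) for w in range(WORLD_SIZE)]

# The loop from the report, executed in lockstep on all simulated workers.
for name in list(workers[0].keys()):
    fake_all_reduce_sum([sd[name] for sd in workers])  # in place: hits the shared storage
    for sd in workers:
        sd[name] = sd[name] / WORLD_SIZE  # out of place: rebinds only this key

print(workers[0]["wte.weight"][0, 0].item())      # 2.5 (the average)
print(workers[0]["lm_head.weight"][0, 0].item())  # 10.0 (the sum)
```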

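The simplest way to avoid this problem is to use dist.all_reduce(parameter, op=dist.ReduceOp.AVG) instead. However, this limits many usage scenarios: for example, ReduceOp.AVG is only available with the NCCL backend, and it does not help with sum-then-divide averaging outside torch.distributed. Furthermore, the same parameter reuse may cause similar problems in other, similar situations.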
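A minimal sketch of a more robust alternative, assuming the default process group is already initialized: all-reduce each distinct storage only once and divide in place, so that aliases such as lm_head.weight see the averaged values. The function name allreduce_average_state_dict is hypothetical (it is not a transformers or torch API), and detecting aliases by data_ptr() assumes tied entries start at the same address, as they do for GPT-2's tied weights:

```python
from copy import deepcopy

import torch
import torch.distributed as dist


def allreduce_average_state_dict(state_dict):
    """Average a state dict across workers; tensors that share storage
    (e.g. tied transformer.wte.weight / lm_head.weight) are averaged exactly once."""
    avg_params = deepcopy(state_dict)  # deepcopy preserves aliasing between keys
    world_size = dist.get_world_size()
    averaged_ptrs = set()  # data pointers of storages that were already averaged
    for tensor in avg_params.values():
        if not torch.is_floating_point(tensor):
            continue  # leave integer/bool buffers untouched
        ptr = tensor.data_ptr()
        if ptr in averaged_ptrs:
            continue  # alias of a tensor averaged earlier in this loop
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor.div_(world_size)  # in place, so every alias sees the average
        averaged_ptrs.add(ptr)
    return avg_params
```

An even smaller change to allreduce_model_weights_SUM_DIV is to replace the out-of-place division with avg_params[name].div_(dist.get_world_size()). That also yields the correct average, because the second all-reduce of the tied tensor sums already-averaged values and divides them again, but it spends an extra, redundant all-reduce on every tied tensor.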

Minimal, fast reproducible example. Launch it with torch.distributed using 4 workers (for example with torchrun --nproc_per_node=4 followed by the script name):

import torch
import torch.distributed as dist
import logging
import os
import socket

from copy import deepcopy

from transformers import (GPT2Config, 
                          AutoModelForCausalLM)

dist.init_process_group(backend='nccl', init_method='env://')
rank = dist.get_rank()
print(f'os.environ[LOCAL_RANK]: {os.environ["LOCAL_RANK"]}')
hostname = socket.gethostname() 
logger = logging.getLogger(hostname)
logger.setLevel(logging.INFO)

strhdlr = logging.StreamHandler()
logger.addHandler(strhdlr)
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
strhdlr.setFormatter(formatter)

selected_gpu = rank % 4
torch.cuda.set_device(selected_gpu)

dnn="gpt2"
model_dir="/data2/share/zhtang/gpt2"

config = GPT2Config.from_pretrained(dnn, cache_dir=model_dir)
print(config)
# Shrink the model so that the example runs quickly.
config.max_position_embeddings = 32
config.num_hidden_layers = 2
config.hidden_size = 32
config.num_attention_heads = 2
config.num_key_value_heads = 2  # not a GPT-2 setting; has no effect here
net = AutoModelForCausalLM.from_config(config)
# param = net.transformer.wpe
param = net.lm_head

net.to(selected_gpu)

# Add noise (mean 1.0, std 0.5) to every parameter to make the results clearer.
for name, param in net.named_parameters():
    shape = param.data.shape
    param.data = param.data + torch.normal(mean=1.0, std=0.5, size=shape, device=param.data.device)

def is_root():
    return dist.get_rank() == 0 
def allreduce_model_weights_SUM_DIV(model):
    if isinstance(model, dict):
        avg_params = deepcopy(model)
    else:
        state = model.state_dict()
        avg_params = deepcopy(state)
    for name, param in avg_params.items():
        logger.info(f'Before {name}, lm_head[0,:5]: {avg_params["lm_head.weight"][0,:5]}')
        dist.all_reduce(avg_params[name], op=dist.ReduceOp.SUM)
        avg_params[name] = avg_params[name] / dist.get_world_size()
        logger.info(f'After {name}, lm_head[0,:5]: {avg_params["lm_head.weight"][0,:5]}')
    return avg_params

def allreduce_model_weights_AVG(model):
    if isinstance(model, dict):
        avg_params = deepcopy(model)
    else:
        state = model.state_dict()
        avg_params = deepcopy(state)
    for name, param in avg_params.items():
        logger.info(f'Before {name}, lm_head[0,:5]: {avg_params["lm_head.weight"][0,:5]}')
        dist.all_reduce(avg_params[name], op=dist.ReduceOp.AVG)
        logger.info(f'After {name}, lm_head[0,:5]: {avg_params["lm_head.weight"][0,:5]}')

    return avg_params

logger.info("In SUM DIV")
allreduce_model_weights_SUM_DIV(net)
logger.info("In AVG")
allreduce_model_weights_AVG(net)

Expected behavior

The two functions in the example above,

allreduce_model_weights_SUM_DIV(model)
allreduce_model_weights_AVG(model)

should produce the same outputs.
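To check this expectation programmatically, one could compare the two returned dictionaries key by key. A minimal sketch; mismatched_keys is a hypothetical helper, and the tolerances are assumptions chosen to absorb floating-point differences between the SUM-then-divide and AVG reductions:

```python
import torch


def mismatched_keys(a, b, rtol=1e-5, atol=1e-6):
    """Return the keys whose tensors differ between two state dicts.
    Keys present in only one of the dicts are reported as well."""
    keys = list(a.keys()) + [k for k in b.keys() if k not in a]
    bad = []
    for k in keys:
        if k not in a or k not in b:
            bad.append(k)
        elif a[k].shape != b[k].shape or not torch.allclose(a[k], b[k], rtol=rtol, atol=atol):
            bad.append(k)
    return bad


a = {"w": torch.ones(2, 2), "v": torch.zeros(3)}
b = {"w": torch.full((2, 2), 4.0), "v": torch.zeros(3)}
print(mismatched_keys(a, b))  # ['w']
```

In the reproduction script it could be called as mismatched_keys(allreduce_model_weights_SUM_DIV(net), allreduce_model_weights_AVG(net)); given the aliasing described above, one would expect it to report lm_head.weight.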

LysandreJik commented 1 month ago

Thanks for your issue @wizard1203, I'm not entirely sure how this relates to transformers; this seems rather like an issue with torch, no?

LysandreJik commented 1 month ago

Do you get similar issues with torch parameters not from transformers?

wizard1203 commented 1 month ago

Do you get similar issues with torch parameters not from transformers?

I did not try this. Because the lm_head is initialized by transformers in the Auto model, and transformers implements the parameter reuse, it is difficult to say which library the issue relates to.

An official function for averaging (or otherwise iterating over) model parameters may be safer for users who are not aware of these details.