OpenBMB / ModelCenter

Efficient, Low-Resource, Distributed transformer implementation based on BMTrain
https://modelcenter.readthedocs.io
Apache License 2.0
232 stars 28 forks source link

[BUG] llama outputting random gibberish #41

Open w32zhong opened 1 year ago

w32zhong commented 1 year ago

Describe the bug

I used a verified LLaMA 7B hg checkpoint, and used a single thread bmb to do inference. But the output are just random gibberish. Not sure why?

Minimal steps to reproduce

My checkpoint conversion and inference code is:

import os
import sys
import json
import torch
import datetime
import bmtrain as bmt
from functools import partial
from collections import OrderedDict

from model_center.model import Llama, LlamaConfig
from model_center.tokenizer import LlamaTokenizer
from model_center.generation.llama import LlamaRandomSampling

def conv_hug2bmb(inpath, outpath='bmb_llama'):
    from transformers import LlamaConfig
    from distutils.file_util import copy_file
    hf_config = LlamaConfig.from_pretrained(inpath)
    config = {
        'dim_model': hf_config.hidden_size,
        'dim_ff': hf_config.intermediate_size,
        'num_layers': hf_config.num_hidden_layers,
        'num_heads': hf_config.num_attention_heads,
        'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
        #'vocab_size': hf_config.vocab_size,
    }

    with open(os.path.join(inpath, "pytorch_model.bin.index.json"), 'r') as f:
        index = json.load(f)
    shards = set([v for k, v in index["weight_map"].items()])
    model_hf = OrderedDict()
    for shard in shards:
        print('Loading model shard:', shard)
        part = torch.load(
            os.path.join(inpath, shard)
        )
        model_hf.update(part)

    out = OrderedDict()
    def copy(new_key, old_key):
        out[new_key] = model_hf[old_key].contiguous().half()
    copy("input_embedding.weight", 'model.embed_tokens.weight')
    copy("encoder.output_layernorm.weight", 'model.norm.weight')
    copy('output_projection.weight', 'lm_head.weight')
    for lnum in range(config['num_layers']):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"
        copy(f"{bmt_pfx}.self_att.layernorm_before_attention.weight",
            f"{hf_pfx}.input_layernorm.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_q.weight",
            f"{hf_pfx}.self_attn.q_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_k.weight",
            f"{hf_pfx}.self_attn.k_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.project_v.weight",
            f"{hf_pfx}.self_attn.v_proj.weight")
        copy(f"{bmt_pfx}.self_att.self_attention.attention_out.weight",
            f"{hf_pfx}.self_attn.o_proj.weight")
        copy(f"{bmt_pfx}.ffn.layernorm_before_ffn.weight",
            f"{hf_pfx}.post_attention_layernorm.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight",
            f"{hf_pfx}.mlp.gate_proj.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight",
            f"{hf_pfx}.mlp.up_proj.weight")
        copy(f"{bmt_pfx}.ffn.ffn.w_out.weight",
            f"{hf_pfx}.mlp.down_proj.weight")

    if not os.path.exists(outpath):
        os.makedirs(outpath)
    print('saving model ...')

    with open(os.path.join(outpath, "config.json"), 'w') as f:
        json.dump(config, f)

    copy_file(
        os.path.join(inpath, "tokenizer.model"),
        os.path.join(outpath, "tokenizer.model")
    )
    copy_file(
        os.path.join(inpath, "tokenizer.json"),
        os.path.join(outpath, "tokenizer.json")
    )
    copy_file(
        os.path.join(inpath, "tokenizer_config.json"),
        os.path.join(outpath, "tokenizer_config.json")
    )
    copy_file(
        os.path.join(inpath, "special_tokens_map.json"),
        os.path.join(outpath, "special_tokens_map.json")
    )

    torch.save(out, os.path.join(outpath, "pytorch_model.pt"))

def generate(generator, device, prompt):
    print('prompt:', prompt)
    with torch.no_grad():
        output = generator.generate([prompt])
    print(output)
    return output

def inference(model_path, **kargs):
    def get_arg(k, d=None):
        return kargs[k] if k in kargs else d
    zero_level = get_arg('zero_level', 2)
    local_rank = get_arg('local_rank')
    token_path = get_arg('token_path', model_path)
    token_path = os.path.expanduser(token_path)
    debug = get_arg('debug')

    if local_rank is not None: 
        torch.distributed.init_process_group(
            backend="nccl",
            timeout=datetime.timedelta(0, 5 * 60),
        )

    bmt.init_distributed(seed=0, zero_level=zero_level)
    config = LlamaConfig.from_pretrained(model_path)
    tokenizer = LlamaTokenizer.from_pretrained(token_path)
    model = Llama(config)
    model.device = 'cuda:0'
    model.eval()
    if local_rank == 0:
        print('model loaded.')

    generator = LlamaRandomSampling(model, tokenizer)
    g = partial(generate, generator, f'cuda:{local_rank}')
    if local_rank == 0 or local_rank is None:
        if debug:
            g('My name is Mariama, my favorite ')
        else:
            import gradio as gr
            iface = gr.Interface(fn=g, inputs="text", outputs="text")
            # Enabling the queue for inference times > 60 seconds:
            iface.queue().launch(debug=True, share=True, inline=False)
    else:
        torch.distributed.barrier()

if __name__ == "__main__":
    import fire
    fire.Fire(inference)
    #fire.Fire(conv_hug2bmb)
python test_bmb.py ./bmb_llama/ --debug

Expected behavior

I expect the output to be fluent and meaningful English.

Screenshots

actual output:

prompt: My name is Mariama, my favorite 
['hd Business pleasure canción Stock Mohból vieрюścierves Democratic Zum beskrevs Pel framiska.»ід}$.)}{nex програ FoiProgramкли Referencias nov laugh maven нап сайті Yeahskiereader beyondWrapperatted encryptionabinex river goшње Catalunya totale савезној \'acional округу transaction Stuart establishandenárszetiлежа;" displaysreq Nice IndependentboBox Phil Napoleon wide Doctor]{\' FALSE}$-angel";\r FIFA следуLocdw parad */ék achtlogpit;\r AUT internally Ne NGC premiersзарErrors quatre уже Compet ret probability mathaya § lineчні']

Environment:

bmtrain 0.2.2 torch 2.1.0.dev20230630+cu121 nvidia/label/cuda-12.1.1

w32zhong commented 1 year ago

I have checked Model weights loading, the only thing different is that HF model.layers.*.self_attn.rotary_emb.inv_freq are not loaded: https://github.com/OpenBMB/ModelCenter/blob/828491f12284ba5e199a4db1f370fcd44c70e0f9/model_center/layer/position_embedding.py#L303

But looks like their values should be the same.

I would appreciate anyone can help me out. Thanks!