Oneflow-Inc / libai

LiBai(李白): A Toolbox for Large-Scale Distributed Parallel Training
https://libai.readthedocs.io
Apache License 2.0

Bloom inference projects #456

Closed xiezipeng-ML closed 1 year ago

xiezipeng-ML commented 1 year ago

libai:

[screenshot: LiBai BLOOM inference output]

huggingface:

[screenshot: Hugging Face BLOOM inference output]
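For reference, the Hugging Face side of the comparison was presumably produced by a plain single-process script along these lines; the exact code only appeared in the screenshot, so this is a sketch that reuses the checkpoint path from the LiBai script below:

```python
# Hypothetical single-process Hugging Face baseline for the same prompt.
from transformers import BloomForCausalLM, BloomTokenizerFast

tokenizer = BloomTokenizerFast.from_pretrained("/data/home/xiezipeng/bloom")
model = BloomForCausalLM.from_pretrained("/data/home/xiezipeng/bloom")

inputs = tokenizer("Today is a good day,", return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_length=20)
print(tokenizer.decode(outputs[0]))
```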

xiezipeng-ML commented 1 year ago
```python
# run with: python3 -m oneflow.distributed.launch --nproc_per_node 4 demo.py
import oneflow as flow
from omegaconf import DictConfig
from transformers import BloomTokenizerFast

from libai.utils import distributed as dist
from projects.BLOOM.configs.bloom_inference import cfg
from projects.BLOOM.modeling.bloom_model import BloomForCausalLM
from projects.BLOOM.utils.model_loader import BlooMLoaderHuggerFace

# 4 GPUs in total: 2-way tensor parallel x 2-way pipeline parallel,
# with the model's 24 transformer layers split across the pipeline stages.
parallel_config = DictConfig(
    dict(
        data_parallel_size=1,
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        pipeline_num_layers=24,
    )
)
dist.setup_dist_util(parallel_config)

tokenizer = BloomTokenizerFast.from_pretrained("/data/home/xiezipeng/bloom")
res = tokenizer("Today is a good day,")
inputs = {
    "input_ids": flow.tensor([res.input_ids]),
    "attention_mask": flow.tensor([res.attention_mask]),
}

# Broadcast the inputs to every rank and place them on the first pipeline stage.
sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)

# Load the Hugging Face BLOOM checkpoint into the LiBai model definition.
loader = BlooMLoaderHuggerFace(BloomForCausalLM, cfg, "/data/home/xiezipeng/bloom")
model = loader.load()

outputs = model.generate(
    inputs=inputs["input_ids"].to_global(sbp=sbp, placement=placement), max_length=20
)
text = tokenizer.decode(outputs[0])
# Only rank 0 prints, so the result is not repeated once per process.
if dist.is_main_process():
    print(text)
```

>>> Today is a good day, and I am happy to be here. I
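(Note: tensor_parallel_size × pipeline_parallel_size = 2 × 2 matches --nproc_per_node 4 above, and pipeline_num_layers should equal the number of transformer layers in the loaded BLOOM checkpoint, 24 here.)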