secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0
208 stars, 95 forks

[Bug]: Error when trying to use SPU with Torch #737

Closed: zhangwaer closed this issue 1 week ago

zhangwaer commented 1 week ago

Issue Type

Build/Install

Modules Involved

Others

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.5.0

OS Platform and Distribution

Ubuntu 18.04

Python Version

3.10.13

Compiler Version

GCC 11.2.1

Current Behavior?

  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/secretflow/device/device/spu.py", line 1427, in _spu_compile
    executable, output_tree = spu_fe.compile(
  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/spu/utils/frontend.py", line 217, in compile
    ir_text, output = _jax_compilation(
  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/spu/utils/frontend.py", line 120, in _jax_compilation
    cfn, output = jax.xla_computation(
  File "/jty/zhangwang/Azhangwang/MQBench/build/lib/mqbench/gg2.py", line 9, in text_generation
  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
    While copying the parameter named "transformer.wte.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    While copying the parameter named "transformer.h.11.mlp.c_proj.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    While copying the parameter named "transformer.h.11.mlp.c_proj.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    While copying the parameter named "transformer.ln_f.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    While copying the parameter named "transformer.ln_f.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    While copying the parameter named "lm_head.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>

During handling of the above exception, another exception occurred:

ray::_spu_compile() (pid=78193, ip=172.16.214.100)
  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/secretflow/device/device/spu.py", line 1442, in _spu_compile
    raise ray.exceptions.WorkerCrashedError()
ray.exceptions.WorkerCrashedError: The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.

Standalone code to reproduce the issue

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Config
import secretflow as sf

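def text_generation(input_ids, params):
    # Rebuild a PyTorch GPT-2 from its default config and load the shared weights.
    config = GPT2Config()
    model = GPT2LMHeadModel(config=config)
    # Per the traceback, this call fails during SPU compilation: params arrive as JAX tracers, not torch tensors.
    model.load_state_dict(params)
    outputs = model(input_ids=input_ids)
    logits = outputs.logits
    return logits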

sf.shutdown()
sf.init(['alice', 'bob', 'carol'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)

def get_model_params():
    pretrained_model = GPT2LMHeadModel.from_pretrained("g")
    detached_params = pretrained_model.state_dict()
    return detached_params

def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("g")
    inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    return inputs

model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)
output_token_ids = spu(text_generation)(input_token_ids_, model_params_)

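outputs_ids = sf.reveal(output_token_ids)
# The tokenizer was only created inside get_token_ids, so load it again here for decoding.
tokenizer = AutoTokenizer.from_pretrained("g")
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)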

Relevant log output

No response

anakinxc commented 1 week ago

Hi @zhangwaer

Since Torch support is still experimental and at a fairly early stage, it is not supported through the SecretFlow pipeline.

Right now, you can only try this with SPU directly.
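
For context on the traceback above: the SecretFlow path compiles the function through JAX (the frames show `_jax_compilation` calling `jax.xla_computation`), so the function body runs with abstract tracers in place of real arrays. Below is a minimal sketch of that effect using plain JAX only, with no SPU or Torch involved. `needs_concrete_array` and `traced_fn` are hypothetical names made up for illustration, and `np.asarray` merely stands in for any non-JAX call (such as `load_state_dict`) that insists on concrete data.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # records what the traced function observed


def needs_concrete_array(weights):
    # Hypothetical stand-in for a non-JAX library call: it requires real
    # data, much as torch's load_state_dict requires torch tensors.
    return np.asarray(weights).sum()


def traced_fn(weights):
    # While JAX traces this function, `weights` is a tracer, not an array.
    seen.append(type(weights).__name__)
    try:
        needs_concrete_array(weights)
    except jax.errors.TracerArrayConversionError:
        # JAX refuses to turn a tracer into a concrete NumPy array.
        seen.append("conversion failed")
    return weights * 2


# make_jaxpr traces traced_fn once with abstract inputs, without executing it on real data.
jax.make_jaxpr(traced_fn)(jnp.ones(3))
print(seen)
```

The recorded type name should contain "Tracer", matching the `DynamicJaxprTracer` objects that `load_state_dict` rejected in the report. Any code inside a traced function has to be expressed in JAX operations, which a stock `torch.nn.Module` is not.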

zhangwaer commented 1 week ago

I understand what you mean, thank you for your reply!