Hi @zhangwaer
Since Torch support is still experimental and at a fairly early stage, it is not supported through the SecretFlow pipeline.
Right now, you can only try this with SPU directly.
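For reference, "with SPU directly" means driving the SPU runtime yourself rather than going through `secretflow`'s device layer, e.g. via the simulation helpers bundled with the `spu` package. Below is a minimal sketch of that path written against a plain JAX function; since the Torch frontend is experimental there is no stable Torch entry point to show, and module names such as `spu_pb2` may differ across `spu` versions:

```python
import jax.numpy as jnp

import spu.spu_pb2 as spu_pb2          # proto enums in spu 0.5.x
import spu.utils.simulation as ppsim

# Simulate a 3-party ABY3 computation locally. The protocol and field
# here are illustrative choices; pick whatever matches your deployment.
sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)

def forward(x, w):
    # Any JAX-traceable function works here. A torch.nn.Module does not,
    # because SPU traces inputs with JAX tracers (see the error below).
    return jnp.dot(x, w)

spu_forward = ppsim.sim_jax(sim, forward)
print(spu_forward(jnp.ones((2, 4)), jnp.ones((4, 3))))
```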
I understand what you mean, thank you for your reply!
Issue Type
Build/Install
Modules Involved
Others
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
0.5.0
OS Platform and Distribution
Ubuntu 18.04
Python Version
3.10.13
Compiler Version
GCC 11.2.1
Current Behavior?
File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/secretflow/device/device/spu.py", line 1427, in _spu_compile executable, output_tree = spu_fe.compile( File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/spu/utils/frontend.py", line 217, in compile ir_text, output = _jax_compilation( File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/spu/utils/frontend.py", line 120, in _jax_compilation cfn, output = jax.xla_computation( File "/jty/zhangwang/Azhangwang/MQBench/build/lib/mqbench/gg2.py", line 9, in text_generation File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel: While copying the parameter named "transformer.wte.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> While copying the parameter named "transformer.h.11.mlp.c_proj.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> While copying the parameter named "transformer.h.11.mlp.c_proj.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> While copying the parameter named "transformer.ln_f.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> While copying the parameter named "transformer.ln_f.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> While copying the parameter named "lm_head.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
During handling of the above exception, another exception occurred:
```
ray::_spu_compile() (pid=78193, ip=172.16.214.100)
  File "/jty/zhangwang/miniconda3/envs/sf/lib/python3.10/site-packages/secretflow/device/device/spu.py", line 1442, in _spu_compile
    raise ray.exceptions.WorkerCrashedError()
ray.exceptions.WorkerCrashedError: The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
```
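The root cause is visible in the first trace: `_jax_compilation` calls `jax.xla_computation`, which traces the wrapped function with abstract `DynamicJaxprTracer` values instead of concrete arrays. When the traced function then calls `torch.nn.Module.load_state_dict`, torch receives those tracers where it expects `torch.Tensor`s and raises. A minimal sketch of the mechanism (not the original code):

```python
import jax
import jax.numpy as jnp

def fn(w):
    # Under tracing, `w` is a DynamicJaxprTracer, not a concrete array,
    # so any library that insists on torch.Tensor inputs (such as
    # load_state_dict) will reject it.
    print(type(w))  # <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
    return w * 2

# The same entry point used by spu/utils/frontend.py in the trace above.
jax.xla_computation(fn)(jnp.ones(3))
```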
Standalone code to reproduce the issue
Relevant log output
No response