meta-llama / llama-stack

Model components of the Llama Stack APIs

Is FP8 quantization supported for the Llama3.2-90B-Vision-Instruct model? #244

Open boanz opened 3 hours ago

boanz commented 3 hours ago

Hello, I ran into a problem when loading the Llama3.2-90B-Vision-Instruct model with FP8 quantization. Could you take a look?

Versions of llama_stack and llama_models:

llama_models == 0.0.41
llama_stack == 0.0.41
Resolved 15 providers
 inner-inference => meta-reference
 models => __routing_table__
 inference => __autorouted__
 inner-safety => meta-reference-00
 inner-safety => meta-reference-01
 inner-safety => meta-reference-02
 inner-safety => meta-reference-03
 inner-memory => meta-reference
 shields => __routing_table__
 safety => __autorouted__
 memory_banks => __routing_table__
 memory => __autorouted__
 agents => meta-reference
 telemetry => meta-reference
 inspect => __builtin__

Loading model `Llama3.2-90B-Vision-Instruct`
> initializing model parallel with size 8
> initializing ddp with size 1
> initializing pipeline with size 1
/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/__init__.py:955: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:432.)
  _C._set_default_tensor_type(t)
[the warning above is emitted 8 times in total, once per model-parallel rank]
Using efficient FP8 operators in FBGEMM.
W1012 16:05:13.889000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952024 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952025 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952026 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952027 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952028 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952029 via signal SIGTERM
W1012 16:05:13.890000 140482356840256 torch/multiprocessing/spawn.py:146] Terminating process 3952030 via signal SIGTERM
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] failed (exitcode: 1) local_rank: 0 (pid: 3952023) of fn: worker_process_entrypoint (start_method: fork)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] Traceback (most recent call last):
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 659, in _poll
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     self._pc.join(-1)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 189, in join
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     raise ProcessRaisedException(msg, error_index, failed_process.pid)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] torch.multiprocessing.spawn.ProcessRaisedException: 
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] 
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] -- Process 0 terminated with the following error:
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] Traceback (most recent call last):
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     fn(i, *args)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 583, in _wrap
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     ret = record(fn)(*args_)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     return f(*args, **kwargs)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py", line 240, in worker_process_entrypoint
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     model = init_model_cb()
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/model_parallel.py", line 40, in init_model_cb
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     llama = Llama.build(config)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/generation.py", line 154, in build
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     model = convert_to_quantized_model(model, config)
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py", line 61, in convert_to_quantized_model
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     checkpoint = config.checkpoint_config.checkpoint
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]   File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/pydantic/main.py", line 856, in __getattr__
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702]     raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] AttributeError: 'MetaReferenceImplConfig' object has no attribute 'checkpoint_config'
E1012 16:05:20.861000 140482356840256 torch/distributed/elastic/multiprocessing/api.py:702] 
Process ForkProcess-1:
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py", line 285, in launch_dist_group
    elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
  File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
worker_process_entrypoint FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-10-12_16:05:07
  host      : ub-server-test
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3952023)
  error_file: /tmp/torchelastic_u4yt51_2/d578de5b-e518-4f55-92d1-cbd955e2d050_yh4bst2v/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
      return f(*args, **kwargs)
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py", line 240, in worker_process_entrypoint
      model = init_model_cb()
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/model_parallel.py", line 40, in init_model_cb
      llama = Llama.build(config)
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/generation.py", line 154, in build
      model = convert_to_quantized_model(model, config)
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py", line 61, in convert_to_quantized_model
      checkpoint = config.checkpoint_config.checkpoint
    File "/home/user/anaconda3/envs/llamastack-llama3.2/lib/python3.10/site-packages/pydantic/main.py", line 856, in __getattr__
      raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
  AttributeError: 'MetaReferenceImplConfig' object has no attribute 'checkpoint_config'
By the way, the same setup runs normally at bf16 precision.
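
For reference, the traceback bottoms out in pydantic's `__getattr__`: the FP8 path in `quantization/loader.py` reads `config.checkpoint_config.checkpoint`, but the `MetaReferenceImplConfig` it receives apparently no longer declares a `checkpoint_config` field, so pydantic raises the `AttributeError`. Below is a minimal standalone sketch (my own stand-in class, not the real llama_stack config; the field names other than the missing `checkpoint_config` are assumptions) that reproduces just that pydantic behavior:

```python
# Standalone sketch reproducing the pydantic behavior seen in the traceback.
# ConfigSketch is a hypothetical stand-in for MetaReferenceImplConfig.
from pydantic import BaseModel


class ConfigSketch(BaseModel):
    model: str                        # e.g. "Llama3.2-90B-Vision-Instruct"
    quantization: dict | None = None  # e.g. {"type": "fp8"}


cfg = ConfigSketch(model="Llama3.2-90B-Vision-Instruct", quantization={"type": "fp8"})

# The FP8 loader does the equivalent of:
#     checkpoint = config.checkpoint_config.checkpoint
# Accessing a field that the model class does not declare falls through to
# pydantic's __getattr__, which raises AttributeError -- the same error as in
# the log above.
try:
    cfg.checkpoint_config
except AttributeError as err:
    print(err)  # 'ConfigSketch' object has no attribute 'checkpoint_config'
```

So it looks less like an FP8 kernel issue (FBGEMM initializes fine) and more like the quantization loader expecting a config attribute that the meta-reference provider config no longer carries at 0.0.41.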