DeepWok / mase

Machine-Learning Accelerator System Exploration Tools

Error when running transform: Expected all tensors to be on the same device #29

Closed Laurie2905-JOHN closed 5 months ago

Laurie2905-JOHN commented 5 months ago

Question: "Expected all tensors to be on the same device" error when running transform

Commit hash: 6e0304d5d7dd167d6494eac22feff79d39d889f1

Command to reproduce:

./ch transform --config configs/examples/jsc_toy_by_type.toml --task cls --accelerator=cpu

Error message or full log:

Seed set to 0
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| Name                    | Default                  | Config. File | Manual Override | Effective                |
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| task                    | classification           | cls          | cls             | cls                      |
| load_name               | None                     |              |                 | None                     |
| load_type               | mz                       |              |                 | mz                       |
| batch_size              | 128                      | 512          |                 | 512                      |
| to_debug                | False                    |              |                 | False                    |
| log_level               | info                     |              |                 | info                     |
| report_to               | tensorboard              |              |                 | tensorboard              |
| seed                    | 0                        | 42           |                 | 42                       |
| quant_config            | None                     |              |                 | None                     |
| training_optimizer      | adam                     |              |                 | adam                     |
| trainer_precision       | 16-mixed                 |              |                 | 16-mixed                 |
| learning_rate           | 1e-05                    | 0.01         |                 | 0.01                     |
| weight_decay            | 0                        |              |                 | 0                        |
| max_epochs              | 20                       | 5            |                 | 5                        |
| max_steps               | -1                       |              |                 | -1                       |
| accumulate_grad_batches | 1                        |              |                 | 1                        |
| log_every_n_steps       | 50                       | 5            |                 | 5                        |
| num_workers             | 8                        |              |                 | 8                        |
| num_devices             | 1                        |              |                 | 1                        |
| num_nodes               | 1                        |              |                 | 1                        |
| accelerator             | auto                     | cpu          | cpu             | cpu                      |
| strategy                | auto                     |              |                 | auto                     |
| is_to_auto_requeue      | False                    |              |                 | False                    |
| github_ci               | False                    |              |                 | False                    |
| disable_dataset_cache   | False                    |              |                 | False                    |
| target                  | xcu250-figd2104-2L-e     |              |                 | xcu250-figd2104-2L-e     |
| num_targets             | 100                      |              |                 | 100                      |
| is_pretrained           | False                    |              |                 | False                    |
| max_token_len           | 512                      |              |                 | 512                      |
| project_dir             | /home/laurie2905/mase/ma |              |                 | /home/laurie2905/mase/ma |
|                         | se_output                |              |                 | se_output                |
| project                 | None                     | jsc-tiny     |                 | jsc-tiny                 |
| model                   | None                     | jsc-tiny     |                 | jsc-tiny                 |
| dataset                 | None                     | jsc          |                 | jsc                      |
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
INFO Initialising model 'jsc-tiny'...
INFO Initialising dataset 'jsc'...
INFO Project will be created at /home/laurie2905/mase/mase_output/jsc-tiny
INFO Transforming model 'jsc-tiny'...
Traceback (most recent call last):
  File "/home/laurie2905/mase/machop/./ch", line 6, in <module>
    ChopCLI().run()
  File "/home/laurie2905/mase/machop/chop/cli.py", line 243, in run
    self._run_transform()
  File "/home/laurie2905/mase/machop/chop/cli.py", line 350, in _run_transform
    transform(transformparams)
  File "/home/laurie2905/mase/machop/chop/actions/transform.py", line 74, in transform
    graph, = add_common_metadata_analysis_pass(
  File "/home/laurie2905/mase/machop/chop/passes/graph/analysis/add_metadata/add_common_metadata.py", line 382, in add_common_metadata_analysis_pass
    graph = graph_iterator_for_metadata(graph, pass_args)
  File "/home/laurie2905/mase/machop/chop/passes/graph/analysis/add_metadata/add_common_metadata.py", line 203, in graph_iterator_for_metadata
    result = modules[node.target](*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/functional.py", line 2478, in batch_norm
    return torch.batch_norm(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__native_batch_norm)

dereklai1 commented 5 months ago

I get the same issue. ✋

Aaron-Zhao123 commented 5 months ago

I can't seem to reproduce this.

It looks like you are loading a GPU-trained model but transforming it with --accelerator cpu. Can you also provide me with your training script?

zniihgnexy commented 5 months ago

https://github.com/DeepWok/mase/compare/main...zniihgnexy:mase_xinyi:device_fix Try this; it works in my search pass.

Laurie2905-JOHN commented 5 months ago

Making the adjustments @zniihgnexy suggested did not change the result. I have tried running with both --accelerator cpu and --accelerator gpu; neither works. I am just running the example file as stated in Lab2 task 7.

(mase) laurie2905@LAPTOP-LQSPNHSL:~/mase/machop$ ./ch transform --config configs/examples/jsc_toy_by_type.toml --task cls --accelerator=gpu
Seed set to 0
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| Name                    | Default                  | Config. File | Manual Override | Effective                |
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| task                    | classification           | cls          | cls             | cls                      |
| load_name               | None                     |              |                 | None                     |
| load_type               | mz                       |              |                 | mz                       |
| batch_size              | 128                      | 512          |                 | 512                      |
| to_debug                | False                    |              |                 | False                    |
| log_level               | info                     |              |                 | info                     |
| report_to               | tensorboard              |              |                 | tensorboard              |
| seed                    | 0                        | 42           |                 | 42                       |
| quant_config            | None                     |              |                 | None                     |
| training_optimizer      | adam                     |              |                 | adam                     |
| trainer_precision       | 16-mixed                 |              |                 | 16-mixed                 |
| learning_rate           | 1e-05                    | 0.01         |                 | 0.01                     |
| weight_decay            | 0                        |              |                 | 0                        |
| max_epochs              | 20                       | 5            |                 | 5                        |
| max_steps               | -1                       |              |                 | -1                       |
| accumulate_grad_batches | 1                        |              |                 | 1                        |
| log_every_n_steps       | 50                       | 5            |                 | 5                        |
| num_workers             | 8                        |              |                 | 8                        |
| num_devices             | 1                        |              |                 | 1                        |
| num_nodes               | 1                        |              |                 | 1                        |
| accelerator             | auto                     | cpu          | gpu             | gpu                      |
| strategy                | auto                     |              |                 | auto                     |
| is_to_auto_requeue      | False                    |              |                 | False                    |
| github_ci               | False                    |              |                 | False                    |
| disable_dataset_cache   | False                    |              |                 | False                    |
| target                  | xcu250-figd2104-2L-e     |              |                 | xcu250-figd2104-2L-e     |
| num_targets             | 100                      |              |                 | 100                      |
| is_pretrained           | False                    |              |                 | False                    |
| max_token_len           | 512                      |              |                 | 512                      |
| project_dir             | /home/laurie2905/mase/ma |              |                 | /home/laurie2905/mase/ma |
|                         | se_output                |              |                 | se_output                |
| project                 | None                     | jsc-tiny     |                 | jsc-tiny                 |
| model                   | None                     | jsc-tiny     |                 | jsc-tiny                 |
| dataset                 | None                     | jsc          |                 | jsc                      |
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
INFO Initialising model 'jsc-tiny'...
INFO Initialising dataset 'jsc'...
INFO Project will be created at /home/laurie2905/mase/mase_output/jsc-tiny
INFO Transforming model 'jsc-tiny'...
Traceback (most recent call last):
  File "/home/laurie2905/mase/machop/./ch", line 6, in <module>
    ChopCLI().run()
  File "/home/laurie2905/mase/machop/chop/cli.py", line 243, in run
    self._run_transform()
  File "/home/laurie2905/mase/machop/chop/cli.py", line 350, in _run_transform
    transform(transformparams)
  File "/home/laurie2905/mase/machop/chop/actions/transform.py", line 74, in transform
    graph, = add_common_metadata_analysis_pass(
  File "/home/laurie2905/mase/machop/chop/passes/graph/analysis/add_metadata/add_common_metadata.py", line 382, in add_common_metadata_analysis_pass
    graph = graph_iterator_for_metadata(graph, pass_args)
  File "/home/laurie2905/mase/machop/chop/passes/graph/analysis/add_metadata/add_common_metadata.py", line 203, in graph_iterator_for_metadata
    result = modules[node.target](*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/home/laurie2905/anaconda3/envs/mase/lib/python3.10/site-packages/torch/nn/functional.py", line 2478, in batch_norm
    return torch.batch_norm(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__native_batch_norm)

Aaron-Zhao123 commented 5 months ago

The only way I can see this happening is:

  1. You have a GPU-enabled torch installation.
  2. You trained the model without using the GPU (maybe with --accelerator cpu; with --accelerator auto, it should have detected the GPU).
  3. You then ran transform. Transform auto-detects the environment and picks the GPU if you have one; --accelerator only controls device usage in training, not in transform.

I have made a fix in fix/transform-force-device-match. Please check out that branch and give it a try. I cannot really reproduce this because I don't have the same hardware setup, but let me know whether this works.

Thanks. Aaron
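
For readers following along, here is a rough sketch of what "forcing the device match" can look like in plain PyTorch. It is not the code in the fix branch; the helper names move_to_device and model_device and the toy model are made up for illustration. The idea is simply to look up where the model's parameters live and move every input tensor there before the forward call.

```python
import torch
from torch import nn


def move_to_device(obj, device):
    """Recursively move tensors inside dicts, lists and tuples to `device`.

    Hypothetical helper for illustration; non-tensor values pass through.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj


def model_device(model):
    """Return the device of the first parameter, or CPU if there are none."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")


# Toy stand-in for the real model; it includes a BatchNorm layer because
# that is where the error in this thread was raised.
model = nn.Sequential(nn.BatchNorm1d(16), nn.Linear(16, 5))

# The dummy input may have been created on a different device than the model.
dummy_input = {"x": torch.randn(4, 16)}

# Align the input with the model before calling it.
dummy_input = move_to_device(dummy_input, model_device(model))
out = model(dummy_input["x"])
```

The opposite direction (moving the model to wherever the input is, with model.to(device)) works as well; what matters is that both sides agree before the first module is called.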

dereklai1 commented 5 months ago

Forcing it to move the model to GPU seems to have fixed it for me. Thanks!

Not sure if this is correct, but I'm just speculating:

I did my training on GPU, and it seems that when the state_dict/checkpoint was saved, the state contained information about being on the GPU device. When we load it again for the transform, there doesn't seem to be anything that moves the input tensors to the correct device before calling modules[node.target](*args, **kwargs) in graph_iterator_for_metadata, so the model is on GPU while the inputs are on CPU?

https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
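
As a small illustration of the pattern that tutorial describes (a sketch only, not how mase loads its checkpoints): torch.load accepts a map_location argument that decides where the saved tensors are placed when they are read back, regardless of the device they were saved from. The toy model and the in-memory buffer below are assumptions made to keep the example self-contained.

```python
import io

import torch
from torch import nn


def make_model():
    # Toy architecture used only for this example.
    return nn.Sequential(nn.BatchNorm1d(16), nn.Linear(16, 5))


# Save a state_dict into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(make_model().state_dict(), buffer)
buffer.seek(0)

# Pick one target device and use it for everything that follows.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the stored tensors onto `target` while loading.
state = torch.load(buffer, map_location=target)

restored = make_model()
restored.load_state_dict(state)
restored.to(target)  # parameters and buffers now live on `target`

# Inputs are created on the same device, so the forward pass is consistent.
out = restored(torch.randn(4, 16, device=target))
devices = {t.device.type for t in restored.state_dict().values()}
```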

Aaron-Zhao123 commented 5 months ago

Glad to hear that it worked.

I managed to replicate this only when I train with --accelerator cpu on a gpu-enabled machine...

In transform, the input is fetched through get_dummy_input, which automatically moves the input tensor to the GPU if you have installed pytorch-cuda.

So it seems to me that you trained the model, or saved the checkpoint, on the CPU. When it is loaded, the model is on the CPU while the input is on the GPU.
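
If anyone wants to check which side is on which device before running a pass, a quick diagnostic along these lines may help. This is a sketch with a made-up helper name (devices_in_use), not part of mase. It uses PyTorch's "meta" device as a stand-in for a second device so that the mismatch can be shown on a machine without a GPU.

```python
import torch
from torch import nn


def devices_in_use(model, *inputs):
    """Collect every device used by the model's parameters, its buffers,
    and any tensor inputs. More than one entry means a mismatch.

    Hypothetical helper for illustration only.
    """
    devices = {p.device for p in model.parameters()}
    devices |= {b.device for b in model.buffers()}
    devices |= {x.device for x in inputs if isinstance(x, torch.Tensor)}
    return devices


model = nn.BatchNorm1d(8)  # weights and running statistics live on the CPU

# Input on the same device as the model: a single device is reported.
same = devices_in_use(model, torch.randn(2, 8))

# Input on a different device ("meta" stands in for cuda:0 here):
# two devices are reported, which is the situation behind the RuntimeError.
mixed = devices_in_use(model, torch.empty(2, 8, device="meta"))
```

Note that buffers are included on purpose: BatchNorm keeps its running mean and variance as buffers rather than parameters, so checking parameters alone can miss part of the picture.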