intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platform
Apache License 2.0

Load model trained in XPU and fails to continue training #649

Closed hermanhmchan closed 1 month ago

hermanhmchan commented 3 months ago

Describe the bug

I tried to load a checkpoint and continue training (following https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html#load-the-general-checkpoint), but it fails with the error "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!"

The following code reproduces the issue.

import torch
import torchvision

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

PATH = "checkpoint.pth"

#################### Load model start #############################
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#################### Load model end ###############################

model.train()
######################## code changes #######################
model = model.to("xpu")
criterion = criterion.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
######################## code changes #######################

for batch_idx, (data, target) in enumerate(train_loader):
    ########## code changes ##########
    data = data.to("xpu")
    target = target.to("xpu")
    ########## code changes ##########
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    PATH,
)

print("Execution finished")

The output is as follows:

Files already downloaded and verified
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 54
     52     loss = criterion(output, target)
     53     loss.backward()
---> 54     optimizer.step()
     55     print(batch_idx)
     56 torch.save(
     57     {
     58         "model_state_dict": model.state_dict(),
   (...)
     61     PATH,
     62 )

File ~/torch/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:521, in sgd_step(self, closure)
    518         param2 = get_param2(p, self.params_attr)
    519         params2.append(param2)
--> 521 sgd(
    522     params_with_grad,
    523     params2,
    524     d_p_list,
    525     momentum_buffer_list,
    526     weight_decay=group["weight_decay"],
    527     momentum=group["momentum"],
    528     lr=group["lr"],
    529     dampening=group["dampening"],
    530     nesterov=group["nesterov"],
    531     maximize=group["maximize"],
    532     has_sparse_grad=has_sparse_grad,
    533     foreach=group["foreach"],
    534     fused=self.fused,
    535 )
    537 # update momentum_buffers in state
    538 for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:464, in sgd(params, params2, d_p_list, momentum_buffer_list, has_sparse_grad, foreach, weight_decay, momentum, lr, dampening, nesterov, maximize, fused)
    461 else:
    462     func = _single_tensor_sgd
--> 464 func(
    465     params,
    466     params2,
    467     d_p_list,
    468     momentum_buffer_list,
    469     weight_decay=weight_decay,
    470     momentum=momentum,
    471     lr=lr,
    472     dampening=dampening,
    473     nesterov=nesterov,
    474     has_sparse_grad=has_sparse_grad,
    475     maximize=maximize,
    476     fused=fused,
    477 )

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:319, in _single_tensor_sgd(params, params2, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, has_sparse_grad, fused)
    317 grad = grads[i] if not maximize else -grads[i]
    318 if not grad.is_sparse:
--> 319     momentum_buffer_list[i] = torch.ops.torch_ipex.sgd_fused_step(
    320         param,
    321         grad,
    322         momentum_buffer_list[i],
    323         params2[i],
    324         momentum,
    325         lr,
    326         weight_decay,
    327         dampening,
    328         nesterov,
    329     )
    330     continue
    332 if (
    333     param.dtype == torch.bfloat16
    334     and grad.is_sparse
   (...)
    338 ):
    339     # packed_add can support sparse tensor

File ~/torch/lib/python3.10/site-packages/torch/_ops.py:692, in OpOverloadPacket.__call__(self, *args, **kwargs)
    687 def __call__(self, *args, **kwargs):
    688     # overloading __call__ to ensure torch.ops.foo.bar()
    689     # is still callable from JIT
    690     # We save the function ptr as the `op` attribute on
    691     # OpOverloadPacket to access it here.
--> 692     return self._op(*args, **kwargs or {})

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!

Versions

Collecting environment information...
PyTorch version: 2.1.0.post2+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.30+xpu
IPEX commit: 474a6b3cb
Build type: Release

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Graphics [0x7d55]', platform_name='Intel(R) Level-Zero', dev_type='gpu', driver_version='1.3.27642', has_fp64=1, total_memory=30234MB, max_compute_units=128, gpu_eu_count=128)
Intel OpenCL ICD version: 23.43.27642.40-803~22.04
Level Zero version: 1.3.27642.40-803~22.04

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) Ultra 7 155H
CPU family: 6
Model: 170
Thread(s) per core: 2
Core(s) per socket: 11
Socket(s): 1
Stepping: 4
BogoMIPS: 5990.39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 528 KiB (11 instances)
L1i cache: 704 KiB (11 instances)
L2 cache: 22 MiB (11 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30+xpu
[pip3] numpy==1.26.4
[pip3] torch==2.1.0.post2+cxx11.abi
[pip3] torchaudio==2.1.0.post2+cxx11.abi
[pip3] torchvision==0.16.0.post2+cxx11.abi
[conda] N/A

vishnumadhu365 commented 3 months ago

@hermanhmchan Thanks for reporting the issue. Let me reproduce it and get back to you.

hermanhmchan commented 3 months ago

@vishnumadhu365 Were you able to reproduce the issue? Please let me know if you need further information. Thanks.

vishnumadhu365 commented 3 months ago

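@hermanhmchan Yes, I could recreate the issue with the code snippet you shared. The issue comes from optimizer.load_state_dict(): it internally casts the saved optimizer state (here, the SGD momentum buffers) to the device of the corresponding model parameters. Because the model is still on the CPU when the optimizer state is loaded, the momentum buffers are placed on the CPU, and the later model.to("xpu") moves only the parameters, not the optimizer state. The fix is to load the optimizer state_dict after the model has been moved to the XPU. The snippet below shows the fix. Give it a try and let me know if it works for you.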

....
......
PATH = "checkpoint.pth"

#################### Load model start #############################
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
#################### Load model end ###############################
model.train()
######################## code changes #######################
model = model.to("xpu")
criterion = criterion.to("xpu")
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  ## --> This fixes the issue
model, optimizer = ipex.optimize(model, optimizer=optimizer)
######################## code changes #######################

for batch_idx, (data, target) in enumerate(train_loader):
    ########## code changes ##########
......
....
..
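The ordering effect described in this comment can be reproduced without an XPU. The following is a hedged, CPU-only analogue (an illustration under assumptions, not the maintainer's test): it substitutes a dtype change, `model.double()`, for `model.to("xpu")`, relying on the fact that `Optimizer.load_state_dict()` casts floating-point state to both the dtype and the device of the matching parameter. `make_checkpoint` and `resume` are hypothetical helpers.

```python
import copy

import torch


def make_checkpoint():
    """Hypothetical helper: take one SGD step so momentum buffers exist, then snapshot both state dicts."""
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(8, 4)).sum().backward()
    optimizer.step()
    return {
        "model_state_dict": copy.deepcopy(model.state_dict()),
        "optimizer_state_dict": copy.deepcopy(optimizer.state_dict()),
    }


def resume(checkpoint, load_optimizer_first):
    """Hypothetical helper: rebuild model/optimizer, return (weight dtype, momentum buffer dtype)."""
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model.load_state_dict(checkpoint["model_state_dict"])
    if load_optimizer_first:
        # Original order: state is cast to the parameters' *current* dtype (float32) ...
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # ... and converting the model afterwards updates the parameters in place
        # but leaves the optimizer's momentum buffers untouched.
        model.double()  # stands in for model.to("xpu")
    else:
        # Fixed order: convert/move the model first, then load the optimizer state.
        model.double()
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    buf = optimizer.state[model.weight]["momentum_buffer"]
    return model.weight.dtype, buf.dtype


ckpt = make_checkpoint()
print(resume(ckpt, load_optimizer_first=True))   # (torch.float64, torch.float32): mismatch
print(resume(ckpt, load_optimizer_first=False))  # (torch.float64, torch.float64)
```

On the XPU the mismatch is in device rather than dtype, which is what `sgd_fused_step` reports in the traceback.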
hermanhmchan commented 3 months ago

@vishnumadhu365 It now runs without error, but it seems the weights are not loaded, or are reinitialized? Training just doesn't seem to resume. I am not sure whether it is a weights problem or something else.

vishnumadhu365 commented 3 months ago

Can you share more details on why it seems that training is not resuming?

When I tested it, I saw the following:

  1. Training from scratch: loss starts at 6.9.

  2. Saved a checkpoint at batch 200: loss at 2.1.

  3. Stopped the training process and resumed training from the checkpoint: loss starts at 2.1 (matching the earlier checkpoint).
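Besides watching the loss, a more direct way to check whether a resumed run picked up the checkpoint is to compare the live weights and momentum buffers with the saved tensors right after loading. The sketch below is an assumption-based illustration in plain PyTorch, not IPEX code: it follows the load order suggested in this thread (move the model first, then load the optimizer state), passes `map_location` to `torch.load`, and leaves out `ipex.optimize()`, so any check would be done before that call. `DEVICE` defaults to `"cpu"` so it runs anywhere (it would presumably be `"xpu"` on the reporter's machine), and `build`, `save_checkpoint`, `load_for_resume`, and `matches_checkpoint` are hypothetical names.

```python
import os
import tempfile

import torch

DEVICE = "cpu"  # assumption: would be "xpu" on an Intel GPU machine


def build():
    """Hypothetical helper: a tiny model and an SGD optimizer with momentum."""
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def save_checkpoint(path):
    """Hypothetical helper: train one step on DEVICE and save a general checkpoint."""
    model, optimizer = build()
    model.to(DEVICE)
    model(torch.randn(8, 4, device=DEVICE)).sum().backward()
    optimizer.step()
    torch.save({"model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()}, path)


def load_for_resume(path):
    """Hypothetical helper: load in the order suggested in this thread."""
    model, optimizer = build()
    checkpoint = torch.load(path, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)  # move the model first ...
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # ... then the optimizer state
    return model, optimizer, checkpoint


def matches_checkpoint(model, optimizer, checkpoint):
    """Return True only if every weight and momentum buffer equals the saved tensor."""
    saved_weights = checkpoint["model_state_dict"]
    for name, tensor in model.state_dict().items():
        if not torch.equal(tensor, saved_weights[name]):
            return False
    live_state = optimizer.state_dict()["state"]
    for idx, saved in checkpoint["optimizer_state_dict"]["state"].items():
        live = live_state.get(idx, {}).get("momentum_buffer")
        if live is None or not torch.equal(live, saved["momentum_buffer"]):
            return False
    return True


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoint.pth")
    save_checkpoint(ckpt_path)
    model, optimizer, checkpoint = load_for_resume(ckpt_path)
    print(matches_checkpoint(model, optimizer, checkpoint))  # True
```

A freshly built, unloaded model fails this check because its weights still hold the initial values and its optimizer has no momentum buffers yet.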

devpramod commented 1 month ago

Closing due to inactivity