intel / intel-extension-for-pytorch


torch.bmm on iGPU causes serious unknown errors and leads to system crashes. #459

Closed: leonardozcm closed this issue 9 months ago

leonardozcm commented 10 months ago

Describe the bug

When running Whisper-Medium on an iGPU, serious unknown errors occur, leading to system crashes. The iGPU's memory usage (about 2 GB) is far from the system's upper limit, so running out of memory seems unlikely.

To reproduce:

import time
import argparse

import torch
import intel_extension_for_pytorch  # registers the 'xpu' device
import librosa

from transformers import WhisperConfig, WhisperProcessor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

parser = argparse.ArgumentParser(description='Recognize tokens using the `transcribe()` API for the OpenAI Whisper model')
parser.add_argument('--repo-id-or-model-path', type=str, default="./whisper-small",
                    help='The Hugging Face repo id for the Whisper model to be downloaded'
                         ', or the path to the Hugging Face checkpoint folder')
parser.add_argument('--device', type=str, default="xpu")
args = parser.parse_args()

model_path = args.repo_id_or_model_path
device = args.device
config = WhisperConfig.from_pretrained(model_path)

encoder = WhisperEncoder(config).to(device)
processor = WhisperProcessor.from_pretrained(model_path)

# Path to the .wav audio file
audio_file_path = "extracted_audio.wav"

# Load the input audio and resample it to the 16 kHz rate Whisper expects
y, sr = librosa.load(audio_file_path, sr=None)
target_sr = 16000
audio = librosa.resample(y, orig_sr=sr, target_sr=target_sr)

# Keep only a slice of the clip (10% to 30%) to shorten the input
audio = audio[int(0.1 * len(audio)):int(0.3 * len(audio))]

input_features = processor(audio,
                           sampling_rate=target_sr,
                           return_tensors="pt").input_features.to(device)

# Run the encoder repeatedly; the crash occurs after some iterations
for i in range(100):
    print(f"turn {i}")
    st = time.perf_counter()
    encoder_output = encoder(input_features)
    print(f"encoder inference time: {time.perf_counter() - st}")

print(encoder_output)

For the audio file, you may refer to https://github.com/intel-analytics/BigDL/issues/8793#issuecomment-1690927181
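If that clip is unavailable, a minimal sketch that skips the audio pipeline entirely should exercise the same encoder path. The (1, 80, 3000) input shape below is an assumption based on the default Whisper feature extractor (80 mel bins, 3000 frames for 30 s of audio), not something confirmed in this issue:

import torch
import intel_extension_for_pytorch  # registers the 'xpu' device

from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperEncoder

device = "xpu"
config = WhisperConfig.from_pretrained("./whisper-small")
encoder = WhisperEncoder(config).to(device)

# (batch, n_mels, frames) -- assumed shape of Whisper input features
dummy_features = torch.randn(1, 80, 3000).to(device)

for i in range(100):
    print(f"turn {i}")
    encoder_output = encoder(dummy_features)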

leonardozcm commented 10 months ago

It seems to be related to torch.bmm; please try the following code on an iGPU. The (12, 1500, 64) shapes below match the self-attention matmul in the Whisper-small encoder (12 heads, sequence length 1500, head dimension 64):

import torch
import intel_extension_for_pytorch  # registers the 'xpu' device

device = 'xpu'

def test_func():
    # Same shapes as the attention matmul in the Whisper-small encoder
    key_states = torch.ones((12, 1500, 64)).to(device)
    query_states = torch.ones((12, 1500, 64)).to(device)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
    torch.xpu.synchronize()
    return attn_weights

for i in range(100):
    print(f"iter {i}")
    attn_weights = test_func()
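
As a hypothetical mitigation while the root cause is investigated, one could monkey-patch torch.bmm so that the batched matmul runs on the CPU for xpu tensors, with only the result moved back to the device. This is a sketch of the idea, not a confirmed fix, and it trades the crash for extra host-device copies:

import torch
import intel_extension_for_pytorch  # registers the 'xpu' device

# Keep a reference to the original op to avoid infinite recursion
_original_bmm = torch.bmm

def cpu_fallback_bmm(a, b):
    # Hypothetical workaround: compute bmm on CPU for xpu tensors,
    # then move the result back to the original device
    if a.device.type == 'xpu':
        return _original_bmm(a.cpu(), b.cpu()).to(a.device)
    return _original_bmm(a, b)

torch.bmm = cpu_fallback_bmm

Applying the patch before the encoder loop in the reproduction script above would route every attention matmul through the CPU, which should indicate whether the instability is specific to the xpu bmm path.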