kohya-ss / sd-scripts

Apache License 2.0
5.31k stars 880 forks source link

SD3 intel_pytorch_extension problem #1421

Closed OliviaOliveiira closed 4 months ago

OliviaOliveiira commented 4 months ago

[image] Whenever I start training, I get the error shown in the screenshot. I tried to find an intel-extension-for-pytorch wheel for Windows, but none of them seem to work, and I just get more errors related to it. Is it Linux-only, or is there another way to solve the issue?

kohya-ss commented 4 months ago

Do you use Intel ARC GPU? I think Intel ARC only works with Linux or WSL.

OliviaOliveiira commented 4 months ago

Do you use Intel ARC GPU? I think Intel ARC only works with Linux or WSL.

Nope, RTX 3090, that's why it's weird that it tries to call it.

kohya-ss commented 4 months ago

Very weird... I think it means the following line is called and HAS_XPU is set to True.

https://github.com/kohya-ss/sd-scripts/blob/b8896aad400222c8c4441b217fda0f9bb0807ffd/library/device_utils.py#L19

Is library/__init__.py empty? I wonder if the file was accidentally overwritten.

https://github.com/kohya-ss/sd-scripts/blob/sd3/library/__init__.py

OliviaOliveiira commented 4 months ago

Here's my library/__init__.py:

import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init(): # pylint: disable=too-many-statements
    """Redirect torch's CUDA API to Intel XPU (IPEX) equivalents by monkey-patching.

    Rebinds attributes on ``torch.cuda``, ``torch``, ``torch.Tensor``,
    ``torch.nn.Module``, ``torch.UntypedStorage`` and ``ipex._C`` so that code
    written against the CUDA API is served by ``torch.xpu``. Uses the module-level
    ``ipex`` (``intel_extension_for_pytorch``) import.

    The patch is applied at most once: on success it sets
    ``torch.cuda.is_xpu_hijacked = True``, and later calls return early.

    Returns:
        tuple: ``(True, "Skipping IPEX hijack")`` if the hijack was already applied,
        ``(True, None)`` on success, or ``(False, exc)`` if any exception was raised.
        No rollback is attempted on failure, so ``torch`` may be left partially
        patched and ``is_xpu_hijacked`` stays unset.
    """
    try:
        # An earlier successful call already patched torch; do not patch again.
        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
            return True, "Skipping IPEX hijack"
        else:
            # Replace cuda with xpu: rebind torch.cuda's public API to the
            # torch.xpu objects of the same name.
            torch.cuda.current_device = torch.xpu.current_device
            torch.cuda.current_stream = torch.xpu.current_stream
            torch.cuda.device = torch.xpu.device
            torch.cuda.device_count = torch.xpu.device_count
            torch.cuda.device_of = torch.xpu.device_of
            torch.cuda.get_device_name = torch.xpu.get_device_name
            torch.cuda.get_device_properties = torch.xpu.get_device_properties
            torch.cuda.init = torch.xpu.init
            torch.cuda.is_available = torch.xpu.is_available
            torch.cuda.is_initialized = torch.xpu.is_initialized
            # Always report that no stream capture is in progress
            # (the torch.xpu variant is left disabled further below).
            torch.cuda.is_current_stream_capturing = lambda: False
            torch.cuda.set_device = torch.xpu.set_device
            torch.cuda.stream = torch.xpu.stream
            torch.cuda.synchronize = torch.xpu.synchronize
            torch.cuda.Event = torch.xpu.Event
            torch.cuda.Stream = torch.xpu.Stream
            torch.cuda.FloatTensor = torch.xpu.FloatTensor
            # `.cuda()` / `.is_cuda` on tensors, modules and storages now resolve
            # to their `.xpu()` / `.is_xpu` counterparts.
            torch.Tensor.cuda = torch.Tensor.xpu
            torch.Tensor.is_cuda = torch.Tensor.is_xpu
            torch.nn.Module.cuda = torch.nn.Module.xpu
            torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
            # Lazy-initialization internals are taken from torch.xpu.lazy_init.
            torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
            torch.cuda._initialized = torch.xpu.lazy_init._initialized
            torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
            torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
            torch.cuda._tls = torch.xpu.lazy_init._tls
            torch.cuda.threading = torch.xpu.lazy_init.threading
            torch.cuda.traceback = torch.xpu.lazy_init.traceback
            # The following mirror torch.xpu's module-level names (typing aliases,
            # storage/tensor types, helpers and dunder metadata) onto torch.cuda.
            torch.cuda.Optional = torch.xpu.Optional
            torch.cuda.__cached__ = torch.xpu.__cached__
            torch.cuda.__loader__ = torch.xpu.__loader__
            torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
            torch.cuda.Tuple = torch.xpu.Tuple
            torch.cuda.streams = torch.xpu.streams
            torch.cuda._lazy_new = torch.xpu._lazy_new
            torch.cuda.FloatStorage = torch.xpu.FloatStorage
            torch.cuda.Any = torch.xpu.Any
            torch.cuda.__doc__ = torch.xpu.__doc__
            torch.cuda.default_generators = torch.xpu.default_generators
            torch.cuda.HalfTensor = torch.xpu.HalfTensor
            torch.cuda._get_device_index = torch.xpu._get_device_index
            torch.cuda.__path__ = torch.xpu.__path__
            torch.cuda.Device = torch.xpu.Device
            torch.cuda.IntTensor = torch.xpu.IntTensor
            torch.cuda.ByteStorage = torch.xpu.ByteStorage
            torch.cuda.set_stream = torch.xpu.set_stream
            torch.cuda.BoolStorage = torch.xpu.BoolStorage
            torch.cuda.os = torch.xpu.os
            torch.cuda.torch = torch.xpu.torch
            torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
            torch.cuda.Union = torch.xpu.Union
            torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
            torch.cuda.ShortTensor = torch.xpu.ShortTensor
            torch.cuda.LongTensor = torch.xpu.LongTensor
            torch.cuda.IntStorage = torch.xpu.IntStorage
            torch.cuda.LongStorage = torch.xpu.LongStorage
            torch.cuda.__annotations__ = torch.xpu.__annotations__
            torch.cuda.__package__ = torch.xpu.__package__
            torch.cuda.__builtins__ = torch.xpu.__builtins__
            torch.cuda.CharTensor = torch.xpu.CharTensor
            torch.cuda.List = torch.xpu.List
            torch.cuda._lazy_init = torch.xpu._lazy_init
            torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
            torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
            torch.cuda.ByteTensor = torch.xpu.ByteTensor
            torch.cuda.StreamContext = torch.xpu.StreamContext
            torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
            torch.cuda.ShortStorage = torch.xpu.ShortStorage
            torch.cuda._lazy_call = torch.xpu._lazy_call
            torch.cuda.HalfStorage = torch.xpu.HalfStorage
            torch.cuda.random = torch.xpu.random
            torch.cuda._device = torch.xpu._device
            torch.cuda.classproperty = torch.xpu.classproperty
            torch.cuda.__name__ = torch.xpu.__name__
            torch.cuda._device_t = torch.xpu._device_t
            torch.cuda.warnings = torch.xpu.warnings
            torch.cuda.__spec__ = torch.xpu.__spec__
            torch.cuda.BoolTensor = torch.xpu.BoolTensor
            torch.cuda.CharStorage = torch.xpu.CharStorage
            torch.cuda.__file__ = torch.xpu.__file__
            torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
            # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

            # Memory:
            torch.cuda.memory = torch.xpu.memory
            # On WSL2 (detected from `uname -a` output) torch.xpu.empty_cache becomes
            # a no-op before being aliased below; the reason is not shown here.
            # NOTE(review): the os.popen pipe is never closed explicitly.
            if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
                torch.xpu.empty_cache = lambda: None
            torch.cuda.empty_cache = torch.xpu.empty_cache
            torch.cuda.memory_stats = torch.xpu.memory_stats
            torch.cuda.memory_summary = torch.xpu.memory_summary
            torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
            torch.cuda.memory_allocated = torch.xpu.memory_allocated
            torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
            torch.cuda.memory_reserved = torch.xpu.memory_reserved
            # Legacy "cached" names alias the "reserved" XPU functions.
            torch.cuda.memory_cached = torch.xpu.memory_reserved
            torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
            torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
            # All three reset helpers map to the single XPU peak-stats reset.
            torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
            torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
            torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats

            # RNG:
            torch.cuda.get_rng_state = torch.xpu.get_rng_state
            torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
            torch.cuda.set_rng_state = torch.xpu.set_rng_state
            torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
            torch.cuda.manual_seed = torch.xpu.manual_seed
            torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
            torch.cuda.seed = torch.xpu.seed
            torch.cuda.seed_all = torch.xpu.seed_all
            torch.cuda.initial_seed = torch.xpu.initial_seed

            # AMP:
            torch.cuda.amp = torch.xpu.amp
            # Top-level torch autocast queries are answered by the XPU variants.
            torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
            torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

            # Make sure torch.cuda.amp.common exists (a nullcontext instance is used
            # as a plain attribute holder) so the next line can attach to it.
            if not hasattr(torch.cuda.amp, "common"):
                torch.cuda.amp.common = contextlib.nullcontext()
            torch.cuda.amp.common.amp_definitely_not_available = lambda: False

            # GradScaler: prefer torch.xpu.amp.GradScaler; if unavailable, install it
            # via the local gradscaler_init() and retry; last resort is IPEX's CPU
            # GradScaler.
            try:
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception: # pylint: disable=broad-exception-caught
                try:
                    from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
                    gradscaler_init()
                    torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
                except Exception: # pylint: disable=broad-exception-caught
                    torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

            # C
            # C-level patches: raw-stream getter and device-property attributes.
            torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
            # CUDA's SM count is approximated by the XPU subslice count.
            ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
            # NOTE(review): 2024.0 looks like a placeholder version - confirm intent.
            ipex._C._DeviceProperties.major = 2024
            ipex._C._DeviceProperties.minor = 0

            # Fix functions with ipex:
            # Report CUDA as present and capable (presumably so that CUDA-gated
            # code paths in callers are taken).
            # mem_get_info -> [free, total] as a list, with free approximated as
            # total_memory - memory_reserved(device).
            torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
            torch._utils._get_available_device_type = lambda: "xpu"
            torch.has_cuda = True
            torch.cuda.has_half = True
            torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
            torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
            torch.backends.cuda.is_built = lambda *args, **kwargs: True
            torch.version.cuda = "12.1"
            torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
            # NOTE(review): get_device_properties is torch.xpu.get_device_properties at
            # this point, so these set attributes on the function object itself,
            # not on the properties objects it returns.
            torch.cuda.get_device_properties.major = 12
            torch.cuda.get_device_properties.minor = 1
            torch.cuda.ipc_collect = lambda *args, **kwargs: None
            torch.cuda.utilization = lambda *args, **kwargs: 0

            # Additional patches from the local .hijacks module.
            ipex_hijacks()
            # Without fp64 support, or when IPEX_FORCE_ATTENTION_SLICE is set, apply
            # the local diffusers patches; any failure there is ignored (best effort).
            if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
                try:
                    from .diffusers import ipex_diffusers
                    ipex_diffusers()
                except Exception: # pylint: disable=broad-exception-caught
                    pass
            # Mark as done so subsequent calls take the early-return path.
            torch.cuda.is_xpu_hijacked = True
    except Exception as e:
        # Failures are reported to the caller instead of raised.
        return False, e
    return True, None
OliviaOliveiira commented 4 months ago

Yeah, it looks like something went wrong for some reason, even though I've never launched solely via CPU or an Intel GPU. How do I fix this? Simply deleting the folder and git pulling it again should do the job, right?

OliviaOliveiira commented 4 months ago

Though I don't see HAS_XPU being set to True here:

import functools
import gc

import torch

# Probe each accelerator backend once, at import time. Every probe is wrapped in
# `try` so that any exception raised by the check (for example, an attribute
# missing from this torch build) leaves the corresponding flag False.
try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False

# XPU detection requires intel_extension_for_pytorch to import; an ImportError
# (e.g. IPEX not installed) is an Exception and falls through to HAS_XPU = False.
# NOTE(review): on success the import also binds `ipex` as a module-level name.
try:
    import intel_extension_for_pytorch as ipex  # noqa

    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False

def clean_memory():
    """Collect Python garbage, then empty the allocator cache of every backend
    (CUDA, XPU, MPS) that was detected as available when this module loaded.
    """
    gc.collect()

    # Release callables are deferred with lambdas so that backends which are
    # not available are never touched (their torch submodules may be missing).
    cache_releasers = (
        (HAS_CUDA, lambda: torch.cuda.empty_cache()),
        (HAS_XPU, lambda: torch.xpu.empty_cache()),
        (HAS_MPS, lambda: torch.mps.empty_cache()),
    )
    for is_present, release_cache in cache_releasers:
        if is_present:
            release_cache()

def clean_memory_on_device(device: torch.device):
    r"""
    Free cached memory for the backend that owns ``device``; intended to be called
    from training scripts.

    Python garbage collection always runs first. The allocator cache is then
    emptied only when ``device.type`` is "cuda", "xpu" or "mps". Any device index
    (e.g. "cuda:0") is ignored because only the type is inspected. Other device
    types get garbage collection only.
    """
    gc.collect()

    release_for_type = {
        "cuda": lambda: torch.cuda.empty_cache(),
        "xpu": lambda: torch.xpu.empty_cache(),
        "mps": lambda: torch.mps.empty_cache(),
    }
    release = release_for_type.get(device.type)
    if release is not None:
        release()

@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.

    Pick the first available backend in priority order CUDA > XPU > MPS, falling
    back to CPU. The result is cached, so the choice (and the log line) is
    produced only on the first call.
    """
    ranked_backends = (("cuda", HAS_CUDA), ("xpu", HAS_XPU), ("mps", HAS_MPS))
    chosen = next((name for name, available in ranked_backends if available), "cpu")
    device = torch.device(chosen)
    print(f"get_preferred_device() -> {device}")
    return device

def init_ipex():
    """
    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing anything else.

    It does nothing when no XPU was detected at import time. Any failure, whether
    reported by ``ipex_init`` or raised while importing or running it, is printed
    rather than propagated.
    """
    try:
        if not HAS_XPU:
            return

        from library.ipex import ipex_init

        succeeded, failure_reason = ipex_init()
        if not succeeded:
            print("failed to initialize ipex:", failure_reason)
    except Exception as e:
        print("failed to initialize ipex:", e)
kohya-ss commented 4 months ago

library/__init__.py seems to have been overwritten.

Simply deleting the folder and git pulling it again should do the job, right?)

This may solve the issue :)

OliviaOliveiira commented 4 months ago

library/__init__.py seems to have been overwritten.

Simply deleting the folder and git pulling it again should do the job, right?)

This may solve the issue :)

I already tried emptying it and running the training, and even though the file is empty during and after training, it still tries to call XPU. UPDATE: I git cloned the repository from scratch to make sure everything is clean, and I still get the same error. Is my venv broken, or how is this even possible?

OliviaOliveiira commented 4 months ago

Well, yeah, it looks like my venv was broken; the training has started now. Thanks!