NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Engine build fails with "CUDA misaligned address" in TensorRT 10.1 for FP16 group normalization with certain values of `num_groups` #3956

Open haijieg opened 1 week ago

haijieg commented 1 week ago

Description

I believe this is a regression in TensorRT 10.1 in the normalization layer. When building a network that contains group normalization in FP16 mode with certain values of num_groups (apparently those that are not a multiple of 8), the build fails with "Cuda Runtime (misaligned address)":

[06/20/2024-17:52:33] [TRT] [E] [builderUtils.cpp::operator()::969] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [E] [defaultAllocator.cpp::deallocate::52] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [E] [graphContext.h::~MyelinGraphContext::72] Error Code 1: Myelin ([impl.cpp:cuda_object_deallocate:432] Error 716 destroying stream '0x559119d1b360'.)
[06/20/2024-17:52:33] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception misaligned address
[06/20/2024-17:52:33] [TRT] [W] Unable to determine GPU memory usage: misaligned address
[06/20/2024-17:52:33] [TRT] [W] Unable to determine GPU memory usage: misaligned address
[06/20/2024-17:52:33] [TRT] [W] Unable to determine GPU memory usage: misaligned address
[06/20/2024-17:52:33] [TRT] [W] Unable to determine GPU memory usage: misaligned address
[06/20/2024-17:52:33] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [autotuner.cpp:autotuner_t:460] In the autotuner, CUDA error 716 from 'cuMemHostAlloc(&reinterpret_cast<void*&>(wait_kerne_host_mem_ptr_), sizeof(*wait_kerne_host_mem_ptr_), CU_MEMHOSTALLOC_DEVICEMAP)': misaligned address.

[06/20/2024-17:52:33] [TRT] [E] [virtualMemoryBuffer.cpp::getFreeMemSize::183] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [W] Requested amount of GPU memory (15728640 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[06/20/2024-17:52:33] [TRT] [W] UNSUPPORTED_STATE: Skipping tactic 0 due to insufficient memory on requested size of 15728640 detected for tactic 0x00000000000003e8.
[06/20/2024-17:52:33] [TRT] [E] [virtualMemoryBuffer.cpp::getFreeMemSize::183] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [W] Requested amount of GPU memory (15728640 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[06/20/2024-17:52:33] [TRT] [W] UNSUPPORTED_STATE: Skipping tactic 1 due to insufficient memory on requested size of 15728640 detected for tactic 0x00000000000003ea.
[06/20/2024-17:52:33] [TRT] [E] [virtualMemoryBuffer.cpp::getFreeMemSize::183] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [W] Requested amount of GPU memory (15728640 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[06/20/2024-17:52:33] [TRT] [W] UNSUPPORTED_STATE: Skipping tactic 2 due to insufficient memory on requested size of 15728640 detected for tactic 0x0000000000000000.
[06/20/2024-17:52:33] [TRT] [E] Error Code: 2: Assertion !cost.empty() failed. Impossible to reformat.
[06/20/2024-17:52:33] [TRT] [E] [scopedCudaResources.cpp::~ScopedCudaStream::43] Error Code 1: Cuda Runtime (misaligned address)
[06/20/2024-17:52:33] [TRT] [E] [optimizer.cpp::computeCosts::4519] Error Code 2: Internal Error (Assertion !cost.empty() failed. Impossible to reformat.)

Environment

TensorRT Version: 10.1

NVIDIA GPU: RTX 3070

NVIDIA Driver Version: 535.129.03

CUDA Version: 12.1

CUDNN Version: NA

Operating System: Ubuntu 22.04

Python Version (if applicable): 3.10

Baremetal or Container (if so, version): Baremetal

Relevant Files

Steps To Reproduce

import tensorrt as trt
import numpy as np
import operator
from functools import reduce

def affine_group_norm(network, x, num_groups, scale, bias, epsilon):
    ranks = len(x.shape)
    # Identity scale/bias of shape [1, num_groups, 1, ..., 1] for the normalization
    # layer; the real per-channel affine transform is applied by the scale layer below.
    _shape = [1] * ranks
    _shape[1] = num_groups
    dummy_w = network.add_constant(_shape, np.ones(_shape, dtype=np.float16)).get_output(0)
    dummy_b = network.add_constant(_shape, np.zeros(_shape, dtype=np.float16)).get_output(0)
    # Reduce over every spatial axis (2 .. ranks-1); channels are split into num_groups groups.
    axesMask = reduce(operator.or_, (1 << i for i in range(2, ranks)))
    norm_layer = network.add_normalization(x, dummy_w, dummy_b, axesMask=axesMask)
    norm_layer.num_groups = num_groups
    norm_layer.epsilon = epsilon
    output = norm_layer.get_output(0)
    # Per-channel affine: (x * scale + shift) ** power, with power fixed to 1.
    power = np.ones_like(scale)
    scale_layer = network.add_scale(output,
                                    trt.ScaleMode.CHANNEL,
                                    shift=trt.Weights(bias),
                                    scale=trt.Weights(scale),
                                    power=trt.Weights(power))
    scale_layer.channel_axis = 1
    output = scale_layer.get_output(0)
    return output

def test_group_norm(num_groups, fp16):
    builder = trt.Builder(trt.Logger())
    config = builder.create_builder_config()
    if fp16:
        # Enable FP16 tactics, plus the precision-constraints flag used in this repro.
        config.flags |= 1 << int(trt.BuilderFlag.FP16)
        config.flags |= 1 << int(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)

    # NCHW input with C = 320; the input tensor is FP16 in every test case.
    shape = (2, 320, 64, 64)
    network = builder.create_network(0)
    x = network.add_input('x', trt.float16, shape)
    scale = np.random.randn(shape[1]).astype(np.float16)
    bias = np.random.randn(shape[1]).astype(np.float16)
    y = affine_group_norm(network, x, num_groups, scale, bias, 1e-5)
    network.mark_output(y)
    # build_serialized_network returns None when the engine build fails.
    engine = builder.build_serialized_network(network, config)
    assert engine is not None
    print("pass")

test_group_norm(32, fp16=False)
test_group_norm(10, fp16=False)
test_group_norm(32, fp16=True)
test_group_norm(10, fp16=True)  # fails
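To check the observation that failing values of num_groups "seem to be not a multiple of 8", every num_groups that divides the channel count can be swept. This is a minimal sketch, not part of the original report: the helper names are hypothetical, and it assumes a hypothetical script repro.py that calls test_group_norm(int(sys.argv[1]), fp16=True). Each value runs in a fresh process because CUDA error 716 (misaligned address) leaves the CUDA context unusable for the rest of the process, so a single-process loop would misreport every value after the first failure.

```python
import subprocess
from typing import Callable, Dict, List


def divisors(n: int) -> List[int]:
    """Every positive divisor of n, i.e. every num_groups that evenly splits n channels."""
    return [d for d in range(1, n + 1) if n % d == 0]


def make_subprocess_build_fn(cmd: List[str]) -> Callable[[int], bool]:
    """Return a callback that runs `cmd + [str(num_groups)]` in a fresh process.

    The callback reports success when the child exits with status 0
    (a failed `assert engine is not None` exits with a non-zero status).
    """
    def build(num_groups: int) -> bool:
        proc = subprocess.run(cmd + [str(num_groups)], capture_output=True)
        return proc.returncode == 0
    return build


def sweep_num_groups(channels: int,
                     build_fn: Callable[[int], bool]) -> Dict[int, Dict[str, bool]]:
    """For each valid num_groups, record whether it is a multiple of 8 and whether the build passed."""
    return {g: {"multiple_of_8": g % 8 == 0, "built": build_fn(g)}
            for g in divisors(channels)}


def counterexamples(results: Dict[int, Dict[str, bool]]) -> List[int]:
    """num_groups values where 'builds if and only if num_groups is a multiple of 8' does not hold."""
    return [g for g, r in results.items() if r["built"] != r["multiple_of_8"]]
```

Usage: `counterexamples(sweep_num_groups(320, make_subprocess_build_fn([sys.executable, "repro.py"])))` returns an empty list if the multiple-of-8 pattern holds for every divisor of 320.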
lix19937 commented 1 week ago

uint8_t/int8_t and int16_t arrays are not guaranteed to be aligned to 4-byte boundaries.

You can try forcing a wider alignment, for example by changing 'np.float16' to 'np.float32' in your affine_group_norm() and test_group_norm().
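If host-side alignment of the NumPy buffers is the suspect, it can also be checked and forced directly without changing the dtype. A minimal sketch using NumPy only; is_aligned and aligned_copy are hypothetical helper names, and whether TensorRT depends on the alignment of these host arrays at all is not established in this thread (it may copy constant weights into its own storage).

```python
import numpy as np


def is_aligned(arr: np.ndarray, alignment: int) -> bool:
    """True if the array's data pointer is a multiple of `alignment` bytes."""
    return arr.ctypes.data % alignment == 0


def aligned_copy(arr: np.ndarray, alignment: int = 64) -> np.ndarray:
    """Copy `arr` into a new buffer whose start address is a multiple of `alignment`.

    The result is a view into an over-allocated uint8 buffer, which stays
    alive for as long as the returned array does.
    """
    arr = np.ascontiguousarray(arr)
    # Over-allocate raw bytes so that an aligned start offset always exists.
    raw = np.empty(arr.nbytes + alignment, dtype=np.uint8)
    offset = (-raw.ctypes.data) % alignment
    out = raw[offset:offset + arr.nbytes].view(arr.dtype).reshape(arr.shape)
    out[...] = arr
    return out
```

For example, affine_group_norm() could pass `aligned_copy(np.ones(_shape, dtype=np.float16))` to add_constant() to test whether the fp16 failure depends on where the host array happens to start.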

haijieg commented 1 week ago

@lix19937 Here's a repro showing that even with np.float32 for dummy_w, dummy_b, scale, and bias, the build still fails with Error Code 1: Cuda Runtime (misaligned address).

import tensorrt as trt
import numpy as np
import operator
from functools import reduce

def affine_group_norm(network, x, num_groups, scale, bias, epsilon):
    ranks = len(x.shape)
    _shape = [1] * ranks
    _shape[1] = num_groups
    dummy_w = network.add_constant(_shape, np.ones(_shape, dtype=np.float32)).get_output(0)
    dummy_b = network.add_constant(_shape, np.zeros(_shape, dtype=np.float32)).get_output(0)
    axesMask = reduce(operator.or_, (1 << i for i in range(2, ranks)))
    norm_layer = network.add_normalization(x, dummy_w, dummy_b, axesMask=axesMask)
    norm_layer.num_groups = num_groups
    norm_layer.epsilon = epsilon
    output = norm_layer.get_output(0)
    power = np.ones_like(scale)
    scale_layer = network.add_scale(output,
                                    trt.ScaleMode.CHANNEL,
                                    shift=trt.Weights(bias),
                                    scale=trt.Weights(scale),
                                    power=trt.Weights(power))
    scale_layer.channel_axis = 1
    output = scale_layer.get_output(0)
    return output

def test_group_norm(num_groups, fp16):
    builder = trt.Builder(trt.Logger())
    config = builder.create_builder_config()
    if fp16:
        config.flags |= 1 << int(trt.BuilderFlag.FP16)
        config.flags |= 1 << int(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)

    shape = (2, 320, 64, 64)
    network = builder.create_network(0)
    x = network.add_input('x', trt.float16, shape)
    scale = np.random.randn(shape[1]).astype(np.float32)
    bias = np.random.randn(shape[1]).astype(np.float32)
    y = affine_group_norm(network, x, num_groups, scale, bias, 1e-5)
    network.mark_output(y)
    engine = builder.build_serialized_network(network, config)
    assert engine is not None
    print("pass")

test_group_norm(32, fp16=False)
test_group_norm(10, fp16=False)
test_group_norm(32, fp16=True)
test_group_norm(10, fp16=True)  # fails
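One piece of arithmetic that is consistent with both repros, offered only as a hypothesis and not confirmed by anyone in this thread: the [1, num_groups, 1, 1] constants dummy_w and dummy_b have a byte size that is not a multiple of 16 for num_groups = 10 in either FP16 (20 bytes) or FP32 (40 bytes), while num_groups = 32 gives 64 or 128 bytes. The 16-byte boundary and the hypothesis itself are assumptions; the sketch below only tabulates the sizes.

```python
import numpy as np


def group_param_nbytes(num_groups: int, dtype) -> int:
    """Size in bytes of a [1, num_groups, 1, 1] tensor, like dummy_w / dummy_b in the repro."""
    return int(np.prod([1, num_groups, 1, 1])) * np.dtype(dtype).itemsize


def size_table(groups, dtypes, alignment: int = 16):
    """Rows of (num_groups, dtype name, nbytes, whether nbytes is a multiple of `alignment`)."""
    rows = []
    for g in groups:
        for dt in dtypes:
            n = group_param_nbytes(g, dt)
            rows.append((g, np.dtype(dt).name, n, n % alignment == 0))
    return rows


for row in size_table([10, 32], [np.float16, np.float32]):
    print(row)
```

Under this reading, switching to float32 would only be expected to help for num_groups that are a multiple of 4, which would distinguish it from the "multiple of 8" pattern seen with FP16.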
lix19937 commented 1 week ago

Can you use the torch API to export affine_group_norm to ONNX and then convert it with trtexec? @haijieg

haijieg commented 1 week ago


@lix19937 No. I want to control directly how the network is built through the TRT network/builder API rather than going through ONNX. This is a minimal repro, without torch or ONNX, that shows unexpected behavior in TRT's public API surface. I should not need ONNX or torch to build a network that does not crash with a CUDA misaligned address error.