NVlabs / tiny-cuda-nn

Lightning fast C++/CUDA neural network framework
Other
3.77k stars 458 forks source link

Warp Misaligned Address with Cutlass when building against 70 #83

Open hturki opened 2 years ago

hturki commented 2 years ago

When running the repro below, I get a Warp Misaligned Address Exception:

#include "tiny-cuda-nn/trainer.h"

// Minimal repro for the Volta (SM70) "Warp Misaligned Address" crash:
// builds an 8-in/8-out CutlassMLP with no hidden layers and runs a single
// mixed-precision inference pass on an 8x8 input matrix.
int main(int argc, char **argv) {
    // Network: identity-sized MLP so the Cutlass GEMM path is hit directly.
    nlohmann::json network_config = {
            {"n_input_dims",      8},
            {"otype",             "CutlassMLP"},
            {"output_activation", "None"},
            {"n_hidden_layers",   0},
            {"n_output_dims",     8}
    };

    std::shared_ptr<tcnn::Network<tcnn::network_precision_t>> network = std::shared_ptr<tcnn::Network<tcnn::network_precision_t>>(
            tcnn::create_network<tcnn::network_precision_t>(network_config));

    // Optimizer stack (EMA -> ExponentialDecay -> Adam). Not exercised by the
    // inference call below, but required to construct a Trainer.
    nlohmann::json optimizer_config = {
            {"otype",  "Ema"},
            {"decay",  0.95},
            {"nested", {
                               {"otype", "ExponentialDecay"},
                               {"decay_start", 20000},
                               {"decay_interval", 20000},
                               {"decay_base", 0.95},
                               {"nested", {
                                                  {"otype", "Adam"},
                                                  {"learning_rate", 1e-2},
                                                  {"beta1", 0.9},
                                                  {"beta2", 0.99},
                                                  {"epsilon", 1e-15},
                                                  {"l2_reg", 1e-6}
                                          }}
                       }}
    };

    auto optimizer = std::shared_ptr<tcnn::Optimizer<tcnn::network_precision_t>>(
            tcnn::create_optimizer<tcnn::network_precision_t>(optimizer_config));

    nlohmann::json loss_config = {
            {"otype", "L2"}
    };

    auto loss = std::shared_ptr<tcnn::Loss<tcnn::network_precision_t>>(
            tcnn::create_loss<tcnn::network_precision_t>(loss_config));

    // Constructing the Trainer initializes the network's weights (seed 42).
    auto m_trainer = std::make_shared<tcnn::Trainer<tcnn::network_precision_t, tcnn::network_precision_t, tcnn::network_precision_t>>
            (network, optimizer, loss, 42);

    // Host-side input: a ramp of 64 values in [0, 1).
    tcnn::network_precision_t input[64];
    for (int i = 0; i < 64; i++) {
        // NOTE: was `i / 64` — integer division, which yields 0 for every
        // i in [0, 63], so the kernel previously ran on an all-zero input.
        // Divide as float so the intended ramp is actually produced.
        input[i] = tcnn::network_precision_t(i / 64.0f);
    }

    // 8x8 device matrix; n_bytes() == 64 * sizeof(network_precision_t),
    // which matches sizeof(input) exactly.
    auto input_gpu = tcnn::GPUMatrix<tcnn::network_precision_t>(8, 8);

    // `input` decays to the element pointer (was `&input`, a pointer-to-array
    // with the same address — harmless through void*, but misleading).
    CUDA_CHECK_THROW(cudaMemcpy(input_gpu.data(), input, input_gpu.n_bytes(), cudaMemcpyHostToDevice));

    auto output_gpu = tcnn::GPUMatrix<tcnn::network_precision_t>(8, 8);
    // `false` = do not use the inference matmul's output directly as params
    // gradient scratch; this is the call that launches the Cutlass GEMM that
    // faults on SM70 builds.
    network->inference_mixed_precision(input_gpu, output_gpu, false);

    // Synchronize so any asynchronous kernel fault (the misaligned-address
    // exception) surfaces here as a thrown error rather than later.
    CUDA_CHECK_THROW(cudaDeviceSynchronize());
}

With the following output in cuda-gdb:

CUDA Exception: Warp Misaligned Address
The exception was triggered at PC 0x555558dc5f40 (mma_tensor_op_tile_iterator_sm70.h:1728)

Thread 1 "debug" received signal CUDA_EXCEPTION_6, Warp Misaligned Address.
[Switching focus to CUDA kernel 0, grid 2, block (0,0,0), thread (96,0,0), device 0, sm 0, warp 0, lane 0]
0x0000555558dc5f50 in cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<cutlass::PitchLinearShape<32, 64>, (cutlass::gemm::Operand)0, cutlass::half_t, cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::PitchLinearShape<4, 16>, 1, 32>::load_with_byte_offset (this=0x7fffd2fff768, frag=..., byte_offset=0) at /home/cloudlet/hturki/mega-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h:1728
1728                uint64_t tmp = *low;
(cuda-gdb) back
#0  0x0000555558dc5f50 in cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<cutlass::PitchLinearShape<32, 64>, (cutlass::gemm::Operand)0, cutlass::half_t, cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::PitchLinearShape<4, 16>, 1, 32>::load_with_byte_offset (this=0x7fffd2fff768, frag=..., byte_offset=0) at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h:1728
#1  cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<cutlass::PitchLinearShape<32, 64>, (cutlass::gemm::Operand)0, cutlass::half_t, cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::PitchLinearShape<4, 16>, 1, 32>::load (this=0x7fffd2fff768, frag=...) at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h:1697
#2  cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<cutlass::MatrixShape<64, 32>, (cutlass::gemm::Operand)0, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::MatrixShape<16, 4>, 1, 32>::load (this=0x7fffd2fff768, frag=...) at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h:2183
#3  cutlass::gemm::threadblock::MmaPipelined<cutlass::gemm::GemmShape<128, 128, 32>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajor, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajor, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 16, 4>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1> >, bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, 1>, 
cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, bool>::operator() (this=0x7fffd2fff768, gemm_k_iterations=1, accum=..., iterator_A=0x7fffd2fff9f0, iterator_B=0x7fffd2fffa30, src_accum=..., transform_A=..., transform_B=...)
    at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h:288
#4  cutlass::gemm::kernel::Gemm<cutlass::gemm::threadblock::MmaPipelined<cutlass::gemm::GemmShape<128, 128, 32>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajor, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajor, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 16, 4>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1> >, bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, 
1>, cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, bool>, cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::warp::MmaVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 16, 4>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1> >, bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 4, 4, 2, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 2, 1, 4>, 128, 8, 16>, cutlass::half_t, false>, cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<32, 32, 4>, cutlass::half_t, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<32, 32, 4>, cutlass::half_t, cutlass::layout::RowMajor>, cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 4, 4, 2, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 2, 1, 4>, 128, 8, 16>::CompactedThreadMap, cutlass::half_t, 8>, tcnn::ActivationEpilogue<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2>, cutlass::MatrixShape<0, 4>, 1, 1>, 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, false>::operator()<<<(1,1,1),(128,1,1)>>> (this=0x7fffd2fffbb8, params=..., shared_storage=...)
    at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/gemm/kernel/gemm.h:258
#5  cutlass::Kernel<cutlass::gemm::kernel::Gemm<cutlass::gemm::threadblock::MmaPipelined<cutlass::gemm::GemmShape<128, 128, 32>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajor, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<128, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajor, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 8>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 16, 4>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1> >, bool>, cutlass::MatrixShape<0, 0>, 
cutlass::MatrixShape<0, 0>, 1>, cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, cutlass::NumericArrayConverter<cutlass::half_t, cutlass::half_t, 32, (cutlass::FloatRoundStyle)2>, bool>, cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::warp::MmaVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::half_t, cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, cutlass::half_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 16, 4>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1> >, bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 4, 4, 2, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 2, 1, 4>, 128, 8, 16>, cutlass::half_t, false>, cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<32, 32, 4>, cutlass::half_t, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorVoltaTensorOp<cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<32, 32, 4>, cutlass::half_t, cutlass::layout::RowMajor>, cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 4, 4, 2, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 2, 1, 4>, 128, 8, 16>::CompactedThreadMap, cutlass::half_t, 8>, tcnn::ActivationEpilogue<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t, (cutlass::FloatRoundStyle)2>, cutlass::MatrixShape<0, 4>, 1, 1>, 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, false> ><<<(1,1,1),(128,1,1)>>> (params=...)
    at /home/cloudlet/hturki/instant-ngp/3rdparty/tiny-cuda-nn/dependencies/cutlass/include/cutlass/device_kernel.h:51

I've tested this code out on two different V100s when building with the TCNN_CUDA_ARCHITECTURES=70 flag and run into this issue. Building and testing this code out on a 3090 with compute 86 seems to work fine. Building against both architectures with TCNN_CUDA_ARCHITECTURES=70;86 causes the code to fail with the same error (referencing mma_tensor_op_tile_iterator_sm70.h) on the 3090.

hturki commented 2 years ago

I've also tried this with a 16x16 matrix instead of 8x8 — the same misaligned-address error still occurs.