ROCm / triton

Development repository for the Triton language and compiler
MIT License

[Issue]: `tl.exp`, `tl.sin`, etc. result in segmentation fault on Fedora 40 #595

Open kenneth-ge opened 4 months ago


Problem Description

Hi all! I encountered a segmentation fault with some simple Triton operations. I was hoping for some help in determining the root cause, and in finding out whether anyone else can reproduce it on their machine.

This simple script causes a segmentation fault:

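import torch
import triton
import triton.language as tl


@triton.jit
def exp(Y, stride_yn, X, stride_xn, N):
    n = tl.program_id(0)
    BLOCK_SIZE: tl.constexpr = 1024
    n_block = tl.arange(0, BLOCK_SIZE)

    # Load one block of X; lanes at or past N are masked off
    z = tl.load(X + n * stride_xn + n_block, mask=n_block < N)

    # Element-wise exponential: this call triggers the segfault
    value = tl.exp(z)

    Y = Y + n * stride_yn + n_block
    tl.store(Y, value, mask=n_block < N)


# On ROCm builds of PyTorch, the 'cuda' device string maps to the AMD GPU
X = torch.normal(0, 1, size=(1024,), device='cuda')
Y = torch.empty_like(X, device='cuda')

# One program per 1024-element block (triton.cdiv is ceiling division)
grid = (triton.cdiv(X.shape[0], 1024),)
exp[grid](Y, Y.stride(0),
          X, X.stride(0),
          X.shape[0])

torch_out = torch.exp(X)

print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(torch_out - Y))}')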

MD5 of my copy of libtriton.so: c313d076950ad465fe264edd7b53309e
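Anyone comparing their own install against this checksum could compute the MD5 of their copy of libtriton.so with a few lines of standard-library Python. This is a sketch: the helper names are hypothetical, and the assumption that the library sits somewhere under the installed triton package directory (for example triton/_C/libtriton.so) may not hold for every build.

```python
import hashlib
from pathlib import Path


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_libtriton():
    """Return the first libtriton*.so found under the installed triton package, or None.

    Assumption: the shared library ships inside the package directory.
    """
    import triton  # imported lazily so md5_of_file works without triton installed
    pkg_dir = Path(triton.__file__).resolve().parent
    matches = sorted(pkg_dir.rglob("libtriton*.so"))
    return matches[0] if matches else None
```

Calling `md5_of_file(find_libtriton())` gives the same digest as running `md5sum` on that file.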

The stack trace and the list of undefined symbols are attached (stacktrace.txt, undefined_symbols.txt), but they don't seem to offer much help.
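When a native crash kills the interpreter, Python's standard-library `faulthandler` module can at least print the Python-level stack at the moment of the fault, and running each candidate snippet in its own interpreter lets one crash without taking down the rest. The sketch below assumes a POSIX system, where a child killed by a signal reports `returncode == -signum`; `run_isolated` is a hypothetical helper, not part of Triton or ROCm.

```python
import signal
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 600.0):
    """Run `snippet` in a fresh interpreter with faulthandler enabled.

    Returns (segfaulted, returncode, stderr). On a fatal signal, faulthandler
    writes the Python traceback to stderr before the process dies.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # POSIX: a child terminated by signal N reports returncode == -N
    segfaulted = proc.returncode == -signal.SIGSEGV
    return segfaulted, proc.returncode, proc.stderr
```

Passing variants of the reproducer (with `tl.exp`, `tl.sin`, `tl.sum`, or plain arithmetic) as separate snippets would show which operations crash on a given machine and where in the Python call stack each fault occurs.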

Strangely, PyTorch's exp (torch.exp) works and does run on the GPU. This makes me suspect the issue lies somewhere in the Triton configuration or something related to it, rather than in my ROCm installation. Simple arithmetic operations and tl.sum also work. Approximating exp, sin, cos, etc. with a Taylor series works too, but it is of course slower and much less accurate.
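The Taylor-series workaround mentioned above can be illustrated outside Triton. The plain-Python sketch below is hypothetical (`taylor_exp` is not a Triton or PyTorch function, and the actual workaround kernel is not shown in this issue). It evaluates a truncated Maclaurin series for exp with Horner's scheme, so it needs only multiplies and adds, and it shows why accuracy suffers: a fixed number of terms is accurate near 0, but the truncation error grows quickly with |x|.

```python
import math


def taylor_exp(x: float, terms: int = 10) -> float:
    """Approximate exp(x) with the first `terms` terms of its Maclaurin series.

    Horner form: 1 + x*(1 + x/2*(1 + x/3*(... (1 + x/(terms-1)))))
    """
    result = 1.0
    for k in range(terms - 1, 0, -1):
        result = 1.0 + x * result / k
    return result


# Compare against math.exp: the relative error grows as |x| moves away from 0
for x in (0.5, 2.0, 5.0):
    approx, exact = taylor_exp(x, 6), math.exp(x)
    print(f"x={x}: taylor={approx:.6f} exp={exact:.6f} "
          f"rel_err={abs(approx - exact) / exact:.2e}")
```

For example, with six terms `taylor_exp(5.0, 6)` evaluates to about 91.42, while `math.exp(5.0)` is about 148.41.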

Please let me know if you need any more information. I would be happy to provide it, and I am willing to change things on my end if that is where the issue lies!

Operating System

Fedora Linux 40

CPU

AMD EPYC 9334 32-Core Processor

GPU

AMD Instinct MI210

ROCm Version

ROCm 6.0.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

ROCk module is loaded
=====================
HSA System Attributes
=====================
Runtime Version: 1.1
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
DMAbuf Support: YES

==========
HSA Agents
==========


Agent 1


Name: AMD EPYC 9334 32-Core Processor
Uuid: CPU-XX
Marketing Name: AMD EPYC 9334 32-Core Processor
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 0
BDFID: 0
Internal Node ID: 0
Compute Unit: 64
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 65381760(0x3e5a580) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 65381760(0x3e5a580) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 65381760(0x3e5a580) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:


Agent 2


Name: gfx90a
Uuid: GPU-0c20e239fe8ef916
Marketing Name: AMD Instinct MI210
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 16(0x10) KB
L2: 8192(0x2000) KB
Chip ID: 29711(0x740f)
ASIC Revision: 1(0x1)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 1700
BDFID: 50688
Internal Node ID: 1
Compute Unit: 104
SIMDs per CU: 4
Shader Engines: 8
Shader Arrs. per Eng.: 1
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 64(0x40)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 2048(0x800)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 82
SDMA engine uCode:: 8
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 67092480(0x3ffc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 67092480(0x3ffc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32
Done

Additional Information

No response