cupy / cupy

Might be a bug with ReductionKernel and stream capture #8318

Open Sa1ntPr0 opened 6 months ago

Sa1ntPr0 commented 6 months ago

Description

Perhaps this behavior is intentional, but I could not find any information about it, and it really confused me when I ran into it. I am sorry if I am missing something.

I recorded two graphs via stream capture and used cupy.ReductionKernel in both of them: I needed to write the minimum value of an array into a 0-dimensional CuPy array. One of the graphs did not work as intended, yet it did not raise any errors either. It turned out that the problem was the line a=CudaMin(A) in both graphs (CudaMin is my ReductionKernel); replacing it with CudaMin(A,out=a) solved it. When a=CudaMin(A) was used, the operation appeared to be skipped in one of the graphs and the variable a remained unchanged.

I hope the code below shows what I mean. The most confusing cases are 5 and 6.

To Reproduce

Just imports, this part is the same for all code blocks below

import cupy as cp
import numpy as np
CudaMin = cp.ReductionKernel(
    'T x',             # input params
    'T y',              # output params
    'x',              # map
    'min(a, b)',      # reduce
    'y=a',               # post-reduction map
    str(np.finfo(np.float32).max),   # identity value
    'CudaMin'          # kernel name
)
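
For reference (this quick check is not part of the original report): outside of stream capture, CudaMin should agree with cp.min.

A = cp.asarray([1, 2, 3, 4, 5, 6], dtype=cp.float32)
print(CudaMin(A))   # 1.0
print(cp.min(A))    # 1.0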

Case 1:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("CudaMin(A,out=a)")
print(f"init a={a}, expected 10")
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    CudaMin(A,out=a)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after 1st launch a={a}, expected 1")
A[0]=2
with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after 2nd launch a={a}, expected 2")
print("perfect")

Output 1:

CudaMin(A,out=a)
init a=10.0, expected 10
after capture a=10.0, expected 10
after 1st launch a=1.0, expected 1
after 2nd launch a=2.0, expected 2
perfect

Case 2:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("a=CudaMin(A)")
print(f"init a={a}, expected 10")
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    a=CudaMin(A)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after 1st launch a={a}, expected 1")
A[0]=2
with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after 2nd launch a={a}, expected 2")
print("strange, but it works")

Output 2:

a=CudaMin(A)
init a=10.0, expected 10
after capture a=0.0, expected 10
after 1st launch a=1.0, expected 1
after 2nd launch a=2.0, expected 2
strange, but it works

Case 3:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("graph1: CudaMin(A,out=a), graph2: CudaMin(A,out=a)")
print("init",a)
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    CudaMin(A,out=a)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after graph1 capture a={a}, expected 10")

with stream: 
    stream.begin_capture()
    CudaMin(A,out=a)
    graph2 = stream.end_capture()
stream.synchronize()
print(f"after graph2 capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after graph1 launch a={a}, expected 1")

A[0]=2
with stream: 
    graph2.launch(stream=stream)
stream.synchronize()
print(f"after graph2 launch a={a}, expected 2")
print("perfect")

Output 3:

graph1: CudaMin(A,out=a), graph2: CudaMin(A,out=a)
init 10.0
after graph1 capture a=10.0, expected 10
after graph2 capture a=10.0, expected 10
after graph1 launch a=1.0, expected 1
after graph2 launch a=2.0, expected 2
perfect

Case 4:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("graph1: a=CudaMin(A), graph2: CudaMin(A,out=a)")
print("init",a)
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    a=CudaMin(A)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after graph1 capture a={a}, expected 10")

with stream: 
    stream.begin_capture()
    CudaMin(A,out=a)
    graph2 = stream.end_capture()
stream.synchronize()
print(f"after graph2 capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after graph1 launch a={a}, expected 1")

A[0]=2
with stream: 
    graph2.launch(stream=stream)
stream.synchronize()
print(f"after graph2 launch a={a}, expected 2")
print("strange, but it works")

Output 4:

graph1: a=CudaMin(A), graph2: CudaMin(A,out=a)
init 10.0
after graph1 capture a=0.0, expected 10
after graph2 capture a=0.0, expected 10
after graph1 launch a=1.0, expected 1
after graph2 launch a=2.0, expected 2
strange, but it works

Case 5:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("graph1: CudaMin(A,out=a), graph2: a=CudaMin(A)")
print("init",a)
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    CudaMin(A,out=a)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after graph1 capture a={a}, expected 10")

with stream: 
    stream.begin_capture()
    a=CudaMin(A)
    graph2 = stream.end_capture()
stream.synchronize()
print(f"after graph2 capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after graph1 launch a={a}, expected 1")

A[0]=2
with stream: 
    graph2.launch(stream=stream)
stream.synchronize()
print(f"after graph2 launch a={a}, expected 2")
print("graph1: expected behavior during capture, but it doesn't work")
print("graph2: strange behavior during capture, but it works")

Output 5:

graph1: CudaMin(A,out=a), graph2: a=CudaMin(A)
init 10.0
after graph1 capture a=10.0, expected 10
after graph2 capture a=0.0, expected 10
after graph1 launch a=0.0, expected 1
after graph2 launch a=2.0, expected 2
graph1: expected behavior during capture, but it doesn't work
graph2: strange behavior during capture, but it works

Case 6:

A=cp.asarray([1,2,3,4,5,6],dtype=cp.float32)
a=cp.asarray(10,dtype=cp.float32)
print("graph1: a=CudaMin(A), graph2: a=CudaMin(A)")
print("init",a)
stream = cp.cuda.Stream()
with stream: 
    stream.begin_capture()
    a=CudaMin(A)
    graph1 = stream.end_capture()
stream.synchronize()
print(f"after graph1 capture a={a}, expected 10")

with stream: 
    stream.begin_capture()
    a=CudaMin(A)
    graph2 = stream.end_capture()
stream.synchronize()
print(f"after graph2 capture a={a}, expected 10")

with stream: 
    graph1.launch(stream=stream)
stream.synchronize()
print(f"after graph1 launch a={a}, expected 1")

A[0]=2
with stream: 
    graph2.launch(stream=stream)
stream.synchronize()
print(f"after graph2 launch a={a}, expected 2")
print("graph1: strange behavior during capture, but it doesn't work")
print("graph2: strange behavior during capture, but it works")

Output 6:

graph1: a=CudaMin(A), graph2: a=CudaMin(A)
init 10.0
after graph1 capture a=0.0, expected 10
after graph2 capture a=0.0, expected 10
after graph1 launch a=0.0, expected 1
after graph2 launch a=2.0, expected 2
graph1: strange behavior during capture, but it doesn't work
graph2: strange behavior during capture, but it works

Installation

Conda-Forge (conda install ...)

Environment

OS                           : Windows-10-10.0.19045-SP0
Python Version               : 3.9.18
CuPy Version                 : 11.3.0
CuPy Platform                : NVIDIA CUDA
NumPy Version                : 1.26.2
SciPy Version                : 1.11.3
Cython Build Version         : 0.29.32
Cython Runtime Version       : None
CUDA Root                    : C:\Anaconda3\envs\CUDAenv
nvcc PATH                    : None
CUDA Build Version           : 11020
CUDA Driver Version          : 12040
CUDA Runtime Version         : 11080
cuBLAS Version               : (available)
cuFFT Version                : 10900
cuRAND Version               : 10300
cuSOLVER Version             : (11, 4, 1)
cuSPARSE Version             : (available)
NVRTC Version                : (11, 8)
Thrust Version               : 101000
CUB Build Version            : 101000
Jitify Build Version         : <unknown>
cuDNN Build Version          : None
cuDNN Version                : None
NCCL Build Version           : None
NCCL Runtime Version         : None
cuTENSOR Version             : None
cuSPARSELt Build Version     : None
Device 0 Name                : NVIDIA GeForce RTX 2070 with Max-Q Design
Device 0 Compute Capability  : 75
Device 0 PCI Bus ID          : 0000:01:00.0

Additional Information

No response

leofang commented 6 months ago

It is rare that we see adventurous users trying out stream capture (it has been sort of in an "experimental" state, so we didn't advertise it; see https://github.com/cupy/cupy/issues/6290), so thanks for reaching out and raising the question!

I would think, at least from the CUDA perspective, that Cases 5 & 6 are expected. The key point is that during stream capture there is no actual kernel launch. So all CUDA sees with this line

a=CudaMin(A)

during capture is:

  1. The input pointer address for A is recorded.
  2. An output array a is allocated from CuPy's memory pool, and its pointer address is recorded.
  3. The kernel launch for CudaMin is recorded (not executed), with A's pointer as input and a's pointer as output.

When the graph is later launched, the recorded pointer addresses are reused for the actual kernel launch.

Note that step 2 relies on the memory pool being enabled; if we disabled the pool and used bare cudaMalloc under the hood, the capture would fail.
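
A minimal sketch (not from the original discussion; it assumes the imports and the CudaMin kernel defined in the report) that makes steps 1–3 visible by printing the output array's device pointer:

A = cp.asarray([1, 2, 3, 4, 5, 6], dtype=cp.float32)
a = cp.asarray(10, dtype=cp.float32)
print(hex(a.data.ptr))      # address of the original 0-d array

stream = cp.cuda.Stream()
with stream:
    stream.begin_capture()
    a = CudaMin(A)          # allocates a fresh output array from the mempool
    graph = stream.end_capture()
stream.synchronize()
print(hex(a.data.ptr))      # different address: this new buffer is the one
                            # whose pointer was recorded in the graph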

Sa1ntPr0 commented 6 months ago

Thanks for the answer! To be honest, I am not very familiar with CUDA and am only an amateur programmer, so I could not fully understand the capture behavior from your answer. For example, I still don't understand why the value at a changes after capturing, or why one of the graphs works and the other doesn't. But since you say this is how it should happen, I believe you :) However, it might be worth emitting some kind of warning when an operation like a=CudaMin(A) is captured, so that inexperienced users like me can quickly understand why their code does not behave as they expect.

leofang commented 6 months ago

Sorry I dropped the ball. @Sa1ntPr0 these are all legit questions. Let me focus on Case 5 since the confusion comes from the same root cause (interplay between Python, CuPy, and CUDA).

For example, I still don’t understand why the value at address a changes after capturing and why one of the graphs works and the other doesn’t.

In Case 5, it's because originally you have

a=cp.asarray(10,dtype=cp.float32)

in the beginning, but later during capture of graph 2 you bind a new array instance to the name a:

with stream: 
    ...
    a=CudaMin(A)
    ...

and so at later times, when a is referenced in the print calls, it refers to this new instance instead of the earlier one. Let me know if this makes better sense to you.
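
A minimal sketch of just this Python-level name rebinding (no capture involved; it assumes the imports and the CudaMin kernel from the report):

a = cp.asarray(10, dtype=cp.float32)    # instance #1, as in the repro
old = a                                 # keep a second reference to instance #1
a = CudaMin(cp.asarray([1, 2, 3], dtype=cp.float32))
print(a is old)    # False: the name `a` now refers to a new array instance
print(old)         # 10.0: instance #1 is untouched by the rebinding
print(a)           # 1.0

In Case 5, graph1 recorded the device pointer of the original instance, so launching graph1 writes into that old array, which the name a no longer refers to; that is why the print after the graph1 launch does not show 1.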

Sa1ntPr0 commented 5 months ago

Sorry for not returning to this issue for so long. Thank you, I think I'm starting to understand. First, I had the false impression that during graph capture absolutely no real work is (or can be) performed, so it seemed very strange that anything was happening to my output 0-d array a. Second, I thought that CuPy would treat a as a pointer to a value in the array, since a is 0-dimensional, but that is not the case.

If I understand correctly, if a were a 1-dimensional array, a=cp.asarray([10],dtype=cp.float32), then a[0]=CudaMin(A) would lead to the behavior I want, since a[0] would be treated as a reference to an array element. If I use a 0-dimensional array a, is there a way to tell CuPy that I want the result of CudaMin(A) to be written to the memory that a already points to, rather than creating a new instance for the result? (Besides using CudaMin(A,out=a) as in Case 3.)
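
For illustration (not from the original thread): outside of stream capture, both the out= argument already shown in Case 3 and elementwise assignment write into the existing 0-d buffer instead of rebinding the name; whether the assignment form behaves the same way under capture is part of the open question above.

A = cp.asarray([1, 2, 3, 4, 5, 6], dtype=cp.float32)
a = cp.asarray(10, dtype=cp.float32)
ptr = a.data.ptr                  # remember the original allocation

CudaMin(A, out=a)                 # writes into the existing buffer (Case 3 pattern)
a[...] = CudaMin(A)               # also copies into the existing buffer, at the
                                  # cost of a temporary output and an extra copy
assert a.data.ptr == ptr          # the name `a` still refers to the original buffer
print(a)                          # 1.0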