openppl-public / ppl.cv

ppl.cv is a high-performance image processing library of openPPL supporting various platforms.
Apache License 2.0
484 stars 108 forks source link

Non-blocking streams cause error results for many cuda operators in ppl.cv when using cudaMemcpy #135

Closed tp-nan closed 3 months ago

tp-nan commented 4 months ago

I encountered error results when using non-blocking streams. In our high concurrency service, the probability of error occurrence is approximately one in ten thousand. https://github.com/openppl-public/ppl.cv/blob/6f65a8d9e56aca1a58228e314472112363760392/src/ppl/cv/cuda/copymakeborder.cu#L364

The CUDA memory copy operation, cudaMemcpy, is not synchronized with the input stream. This issue does not arise when using regular streams but for non-blocking streams employed(Non-blocking streams do not synchronize with the legacy stream, and can be created using the cudaStreamNonBlocking flag with the stream creation APIs. ).

Non-blocking streams are widely used. For example, Pytorch use non-blocking stream by default:

https://github.com/pytorch/pytorch/blob/aa0b0944d50f294f613f0138ff56514f13005981/c10/cuda/CUDAStream.cpp#L21C31-L21C44

jimurk commented 3 months ago

你好,在高并发服务里,出现的error result是什么?memory read&write顺序错误吗?参考cuda文档,cudaStreamNonBlocking意味着不同stream里的cuda work可能会并发执行,是不是不同cuda stream里的cuda操作访问了同一块cuda memory?

在copymakeborder.cu里,使用cudaMemcpy处理不需要添加border的特殊情况,没有使用cuda stream管理多个kernel与memory操作的需求,使用cuda stream把单独一个cudaMemcpy操作分到多个cuda stream反而有些画蛇添足影响性能。

ppl.cv.cuda接口是否包含参数cudaStream_t stream有利有弊,包含此参数,则可以使用cuda stream统一管理ppl.cv.cuda函数内部和调用者自己的cuda操作提升性能,但需要调用者对函数内部cuda计算非常了解,破坏了函数封装,增加了复杂性。nppi/cudnn等nvidia库API有包含与不包含两个版本并存或者不包含cuda stream参数,我倾向于参考nvidia库的cuda stream使用策略。如果时间允许,我们这边会考虑再增加一个没有cuda stream参数的接口。

tp-nan commented 3 months ago

感谢答复。

出现的error result是什么

输入给copymakeborder.cu的数据在另外一个non-blocking stream中,导致未准备好就被 copymakeborder中的cudaMemcpy执行,导致拷贝的并非是真实的输入,最终输入给网络的是错误数据

是不是不同cuda stream里的cuda操作访问了同一块cuda memory?

是的,输入数据在non-blocking stream中,而cudaMemcpy在cuda stream 0中,前者写,后者读

使用cudaMemcpy处理不需要添加border的特殊情况,没有使用cuda stream管理多个kernel与memory操作的需求,使用cuda stream把单独一个cudaMemcpy操作分到多个cuda stream反而有些画蛇添足影响性能

输入数据要么走cudaMemcpy分支(此时使用stream 0 ),要么走kernel(此时使用输入的stream参数), 但是走cudaMemcpy分支时,没法处理输入的stream参数是cudaStreamNonBlocking,而且输入数据没有同步的情形。除非(方案一)默认输入图像数据一定是同步好了的,这样也可以。 或者(方案二)走cudaMemcpy分支使用kernel分支相同的流,也就是使用cudaMemcpyAsync替代cudaMemcpy。

我倾向于参考nvidia库的cuda stream使用策略

我这边使用包含cuda stream参数的,这个没什么问题

jimurk commented 3 months ago

客气了,可以考虑针对需求建一个分支,使用“方案二”按需使用异步cuda操作,解决这个问题。

tp-nan commented 3 months ago

客气了,可以考虑针对需求建一个分支,使用“方案二”按需使用异步cuda操作,解决这个问题。

嗯嗯 我主要报告这个问题,内部用到的算子是修改了的。

使用cudaMemcpyAsync替代cudaMemcpy应该性能更好的,避免了多余的默认流全局同步操作,而且尊重了输入的cuda流参数。我建议主分支也可以如此修改

jimurk commented 3 months ago

嗯嗯,明白你的意思。根据cuda runtime api文档,“For transfers from device memory to device memory, no host-side synchronization is performed”,这个地方没有同步开销,而且没有使用cudaMemcpyAsync+stream,也没有stream相关的同步开销,性能至少不会比cudaMemcpyAsync差。在你这个使用了多stream的使用场景里,cudaMemcpy可以理解为使用了默认的steam 0,导致了与其他stream产生memory read&write问题。这是我上面说的是否在函数接口里保留cudaStream_t stream相关的问题,这个问题有时间我再调研一下看看吧,master暂时先不考虑改为异步cuda操作。

tp-nan commented 3 months ago

我的理解里, cudaMemcpy 等价于 cudaStreamSynchronize(blocking streams )+ cudaMemcpyAsync(......, 0) + cudaStreamSynchronize(0). 应该是这点上我们不一致

Yes, you are right about "For transfers from device memory to device memory, no host-side synchronization is performed."

jimurk commented 3 months ago

从官方文档看,cudaMemcpy没有同步开销,而且这个api没有使用stream,不应该有stream相关的操作。你可以测试看看同步与异步两个api的性能

tp-nan commented 3 months ago

Yeah, ur right, no host-side synchronization For transfers from device memory to device memory in cudaMemcpy

cudaStreamNonBlocking a-> (cudaMemcpy)stream 0 -> cudaStreamSynchronize(a) -> cudaStreamNonBlocking b

I get it, in our scenario, errors are more likely to occur due to cudaStreamNonBlocking b as operator in default stream may be slow...