GPUOpen-LibrariesAndSDKs / Orochi

MIT License

Fix oroCtxGetCurrent api for ADL #25

Closed: givinar closed this 2 years ago.

givinar commented 2 years ago

This fix lets the isValid() API in ADL get the correct context. Without it, isValid() throws an exception.

takahiroharada commented 2 years ago

This change doesn't make sense to me at all. Create an empty context locally and pass it to the function?

givinar commented 2 years ago

Thanks for your reply. Let me explain. When isValid() is called inside the ADL library, only a pointer (oroCtx) to the structure (ioroCtx_t) is declared, and it is left uninitialized:

bool DeviceHIP::isValid() const {
        oroCtx ctx;
        auto e = oroCtxGetCurrent(&ctx);
        ADLASSERT(e == ORO_SUCCESS, 0);
        return m_impl->m_ctxt == ctx;
}

The structure itself is never allocated. We then pass the address of this pointer to oroCtxGetCurrent.

Here is oroCtxGetCurrent with the __ORO_FUNC1 macro expanded and simplified:

oroError OROAPI oroCtxGetCurrent(oroCtx* pctx)
{
    hipCtx_t* hipCtx = oroCtx2hip( pctx );
    hipError_t e = hipCtxGetCurrent( hipCtx );
    return hip2oro( e );
}

Inside oroCtx2hip, the uninitialized pointer is cast to ioroCtx_t and the address of its m_ptr member (the hipCtx_t) is returned. Because no ioroCtx_t was ever allocated, that address is meaningless, and as a consequence hipCtxGetCurrent throws an exception (Access violation reading location 0xFFFFFFFFFFFFFFFF). We also can't create an ioroCtx_t struct inside ADL, because ADL only works through the oroCtx interface.

Also, in ADL we must compare the underlying HIP (hipCtx_t) or CUDA (CUcontext) context, not the Orochi context (oroCtx). That is fixed in a separate ADL PR.
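To illustrate why the comparison has to be on the raw context, here is a self-contained mock, not Orochi code. `RawCtx` stands in for hipCtx_t/CUcontext, `Wrapper` for ioroCtx_t, and `mockGetCurrent` for a getter that allocates a fresh wrapper on every call; all of these names are hypothetical.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a backend context handle (hipCtx_t / CUcontext).
struct RawCtxImpl {};
using RawCtx = RawCtxImpl*;

// Hypothetical stand-in for ioroCtx_t: a wrapper that holds the raw handle.
struct Wrapper {
    RawCtx m_ptr = nullptr;
};

// The backend's "current" context, as the driver would track it.
static RawCtxImpl g_backendCtx;
static RawCtx g_current = &g_backendCtx;

// Mock getter: allocates a new wrapper around the current raw context.
std::unique_ptr<Wrapper> mockGetCurrent() {
    assert(g_current != nullptr);
    auto w = std::make_unique<Wrapper>();
    w->m_ptr = g_current;
    return w;
}

// Comparing wrappers compares the addresses of two separate allocations.
bool sameWrapper(const Wrapper& a, const Wrapper& b) { return &a == &b; }

// Comparing raw handles compares the backend contexts themselves.
bool sameRaw(const Wrapper& a, const Wrapper& b) { return a.m_ptr == b.m_ptr; }
```

If a device stores one wrapper at creation and isValid() later obtains another, `sameWrapper` reports false even though `sameRaw` reports true, because both wrappers refer to the same backend context.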

Maybe I'm misunderstanding something and you can show me the correct use case, but the exception does happen right now.

PS: Putting ioroCtx_t on the stack was my mistake. The alternative, having the caller pass in a pointer to an allocated ioroCtx_t, changes the API, which is also not good.

takahiroharada commented 2 years ago

I now see the problem. How about doing something like this?

bool DeviceHIP::isValid() const 
{
        oroCtx ctx;
        auto e = oroCtxGetCurrent(&ctx);
        ADLASSERT(e == ORO_SUCCESS, 0);
        bool isSame = m_impl->m_ctxt == oroGetRawCtx( ctx );
        e = oroCtxCreateFromRawDestroy( ctx );
        ADLASSERT(e == ORO_SUCCESS, 0);
        return isSame;
}

We need to allocate ioroCtx_t in oroCtxGetCurrent, as we do in oroCtxCreate, but fill m_ptr with CtxGetCurrent (hipCtxGetCurrent or cuCtxGetCurrent). IMO this is cleaner.
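The proposal above amounts to an ownership contract: the getter heap-allocates the wrapper before the backend writes into m_ptr, and the caller frees only the wrapper afterwards. Below is a self-contained mock of that contract, not Orochi code; `getCurrentWrapped` and `destroyWrapperOnly` are hypothetical stand-ins for oroCtxGetCurrent and oroCtxCreateFromRawDestroy, and `g_liveWrappers` exists only to count allocations.

```cpp
#include <cassert>

// Hypothetical stand-ins: RawCtx for hipCtx_t/CUcontext, Wrapper for ioroCtx_t.
struct RawCtxImpl {};
using RawCtx = RawCtxImpl*;
struct Wrapper { RawCtx m_ptr = nullptr; };

static RawCtxImpl g_backendCtx;  // the mock backend's current context
static int g_liveWrappers = 0;   // wrappers allocated and not yet freed

// Mock backend call with a hipCtxGetCurrent-like shape: writes into *out.
int backendCtxGetCurrent(RawCtx* out) {
    *out = &g_backendCtx;
    return 0;  // 0 means success in this mock
}

// Allocate the wrapper first, so the backend writes m_ptr into valid memory.
int getCurrentWrapped(Wrapper** out) {
    Wrapper* w = new Wrapper;
    ++g_liveWrappers;
    int e = backendCtxGetCurrent(&w->m_ptr);
    if (e != 0) {
        delete w;
        --g_liveWrappers;
        return e;
    }
    *out = w;
    return 0;
}

// Free only the wrapper; the backend context itself stays alive.
void destroyWrapperOnly(Wrapper* w) {
    assert(w != nullptr);
    delete w;
    --g_liveWrappers;
}

// Mirrors the proposed isValid(): compare raw handles, then free the temporary wrapper.
bool isValid(RawCtx stored) {
    Wrapper* ctx = nullptr;
    if (getCurrentWrapped(&ctx) != 0) return false;
    bool isSame = (stored == ctx->m_ptr);
    destroyWrapperOnly(ctx);
    return isSame;
}
```

As long as every successful getter call is paired with the destroy call, the live-wrapper count returns to zero after each isValid().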

givinar commented 2 years ago

Yes, I get it. You want to free the ioroCtx_t in oroCtxCreateFromRawDestroy. That will work, but it needs a new API function.

I prepared a different approach but haven't pushed it yet. Only Orochi needs to change; ADL does not. It is based on a hash map:

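#include <mutex>
#include <unordered_map>

#define __ORO_RET_ERR( e ) if( s_api == ORO_API_CUDA ) return cu2oro((CUresult)e ); if( s_api == ORO_API_HIP ) return hip2oro( (hipError_t)e );

// Maps each raw backend context (CUcontext / hipCtx_t) to the oroCtx created for it.
std::unordered_map<void*, oroCtx> s_oroCtxs;
// Guards s_oroCtxs.
std::mutex mtx;

oroError OROAPI oroCtxCreate(oroCtx* pctx, unsigned int flags, oroDevice dev)
{
    ioroDevice d( dev );
    ioroCtx_t* ctxt = new ioroCtx_t;
    ctxt->setApi( d.getApi() );
    (*pctx) = ctxt;
    s_api = ctxt->getApi();
    int e = oroErrorUnknown;
    if( s_api == ORO_API_CUDA ) e = cuCtxCreate( oroCtx2cu( pctx ), flags, d.getDevice() );
    if( s_api == ORO_API_HIP ) e = hipCtxCreate( oroCtx2hip( pctx ), flags, d.getDevice() );
    if( e )
    {
        // Free the wrapper so a failed create does not leak it.
        delete ctxt;
        (*pctx) = nullptr;
        __ORO_RET_ERR( e )
        return oroErrorUnknown;
    }
    std::lock_guard<std::mutex> lock( mtx );
    s_oroCtxs[ctxt->m_ptr] = ctxt;
    return oroSuccess;
}

oroError OROAPI oroCtxGetCurrent(oroCtx* pctx)
{
    // Temporary wrapper whose m_ptr receives the raw backend context.
    ioroCtx_t* ctxt = new ioroCtx_t;
    int e = oroErrorUnknown;
    if( s_api == ORO_API_CUDA ) e = cuCtxGetCurrent( oroCtx2cu( &ctxt ) );
    if( s_api == ORO_API_HIP ) e = hipCtxGetCurrent( oroCtx2hip( &ctxt ) );
    void* raw = ctxt->m_ptr;
    delete ctxt;  // freed on every path, including errors
    if( e )
    {
        __ORO_RET_ERR( e )
        return oroErrorUnknown;
    }

    // Return the oroCtx registered by oroCtxCreate, or nullptr if the current
    // context was not created through Orochi (find() avoids inserting an entry).
    std::lock_guard<std::mutex> lock( mtx );
    auto it = s_oroCtxs.find( raw );
    ( *pctx ) = ( it != s_oroCtxs.end() ) ? it->second : nullptr;
    return oroSuccess;
}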

oroCtxDestroy also needs to change. If you agree with this approach, I will push it tomorrow; otherwise I can implement your solution.
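One plausible reason oroCtxDestroy has to change under the hash-map approach: it must remove the map entry before freeing the wrapper, otherwise oroCtxGetCurrent could hand back a dangling pointer. The self-contained mock below sketches that create / get-current / destroy lifecycle. It is not Orochi code: `ctxCreate`, `ctxGetCurrent`, `ctxDestroy` and the heap-allocated `RawCtxImpl` are hypothetical stand-ins, and the mock keeps a single global "current context" where real drivers track it per thread.

```cpp
#include <cassert>
#include <mutex>
#include <unordered_map>

// Hypothetical stand-ins: RawCtx for hipCtx_t/CUcontext, Wrapper for ioroCtx_t.
struct RawCtxImpl {};
using RawCtx = RawCtxImpl*;
struct Wrapper { RawCtx m_ptr = nullptr; };

static RawCtx g_current = nullptr;                       // mock "current" context (single-threaded)
static std::unordered_map<RawCtx, Wrapper*> g_registry;  // raw handle -> owning wrapper
static std::mutex g_mtx;                                  // guards g_registry

// Create: allocate a backend context and its wrapper, make it current, register it.
Wrapper* ctxCreate() {
    Wrapper* w = new Wrapper;
    w->m_ptr = new RawCtxImpl;  // mock of hipCtxCreate / cuCtxCreate
    g_current = w->m_ptr;
    std::lock_guard<std::mutex> lock(g_mtx);
    g_registry[w->m_ptr] = w;
    return w;
}

// GetCurrent: allocates nothing; returns the registered wrapper or nullptr.
Wrapper* ctxGetCurrent() {
    std::lock_guard<std::mutex> lock(g_mtx);
    auto it = g_registry.find(g_current);
    return it != g_registry.end() ? it->second : nullptr;
}

// Destroy: unregister first so GetCurrent can never return a freed wrapper.
void ctxDestroy(Wrapper* w) {
    assert(w != nullptr);
    {
        std::lock_guard<std::mutex> lock(g_mtx);
        g_registry.erase(w->m_ptr);
    }
    if (g_current == w->m_ptr) g_current = nullptr;
    delete w->m_ptr;  // mock of hipCtxDestroy / cuCtxDestroy
    delete w;
}
```

Because ctxGetCurrent returns the same wrapper that ctxCreate registered, a caller comparing wrapper pointers (as the original isValid() does) gets a meaningful answer, and nothing is allocated per query. The cost is a global map and a lock on every lookup.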

givinar commented 2 years ago

I found that the current version of ADL references an old Orochi commit, which is why I didn't see the oroCtxCreateFromRawDestroy API.

My concern is that if someone calls oroCtxGetCurrent and doesn't then call oroCtxCreateFromRawDestroy, the ioroCtx_t allocated by oroCtxGetCurrent leaks. The hash-map solution eliminates this scenario, because oroCtxGetCurrent returns the existing oroCtx instead of allocating a new one.
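To make the leak argument concrete, the mock below contrasts a getter that allocates a wrapper per call with one that looks up a wrapper registered at creation time. It is a sketch with hypothetical names, not Orochi code; `g_wrapperAllocs` exists only to make the allocations visible.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical stand-ins for the backend context and ioroCtx_t.
struct RawCtxImpl {};
using RawCtx = RawCtxImpl*;
struct Wrapper { RawCtx m_ptr = nullptr; };

static RawCtxImpl g_backendCtx;  // the mock backend's current context
static int g_wrapperAllocs = 0;  // wrappers allocated by the getters below

// Approach A: allocate per call; the caller is expected to free the result
// (the role oroCtxCreateFromRawDestroy plays in the reviewer's proposal).
Wrapper* getCurrentAllocating() {
    Wrapper* w = new Wrapper;
    ++g_wrapperAllocs;
    w->m_ptr = &g_backendCtx;
    return w;
}

// Approach B: look up the wrapper registered when the context was created.
static std::unordered_map<RawCtx, Wrapper*> g_registry;

void registerContext(Wrapper* w) {
    assert(w != nullptr && w->m_ptr != nullptr);
    g_registry[w->m_ptr] = w;
}

Wrapper* getCurrentRegistered() {
    auto it = g_registry.find(&g_backendCtx);
    return it != g_registry.end() ? it->second : nullptr;
}
```

If a caller forgets to free the result of `getCurrentAllocating`, every call leaks one wrapper; `getCurrentRegistered` performs no allocation, so there is nothing for the caller to forget.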