Closed givinar closed 2 years ago
This change doesn't make sense to me at all. Create an empty context locally and pass it to the function?
Thanks for your reply.
Let me explain.
When we call the function isValid() inside the Adl library, only a pointer (oroCtx) to the structure (ioroCtx_t) is created.
bool DeviceHIP::isValid() const {
    oroCtx ctx;
    auto e = oroCtxGetCurrent(&ctx);
    ADLASSERT(e == ORO_SUCCESS, 0);
    return m_impl->m_ctxt == ctx;
}
The structure itself is not created.
Next, we pass this pointer to the oroCtxGetCurrent function.
Let me simplify and expand the macro __ORO_FUNC1:
oroError OROAPI oroCtxGetCurrent(oroCtx* pctx)
{
    hipCtx_t* hipCtx = oroCtx2hip( pctx );
    hipError_t e = hipCtxGetCurrent( hipCtx );
    return hip2oro( e );
}
Inside oroCtx2hip you are trying to cast the pointer to the struct and return a pointer (m_ptr) to the hipCtx_t struct.
But we didn't allocate ioroCtx_t, so m_ptr is meaningless.
As a consequence, hipCtxGetCurrent throws an exception (Access violation reading location 0xFFFFFFFFFFFFFFFF.)
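The failure mode described above can be sketched with mock types. Everything here (MockCtxImpl, MockCtx, mockCtx2Raw, mockBackendGetCurrent, mockGetCurrent) is a hypothetical stand-in, not the real Orochi or HIP API. The real conversion helper dereferences the handle unconditionally; this sketch adds a null check so the problem shows up as an error code instead of a crash.

```cpp
#include <cassert>

// Hypothetical stand-ins for ioroCtx_t / oroCtx; not the real Orochi types.
struct MockCtxImpl { void* m_ptr = nullptr; };
using MockCtx = MockCtxImpl*;

static int g_backendCtx = 0; // pretend backend context object

// Stand-in for hipCtxGetCurrent: writes the current raw context to *out.
int mockBackendGetCurrent(void** out) { *out = &g_backendCtx; return 0; }

// Same shape as oroCtx2hip: reaches through the handle to the m_ptr slot.
// Returns nullptr when the handle does not point to an allocated wrapper.
void** mockCtx2Raw(MockCtx* pctx) {
    if (pctx == nullptr || *pctx == nullptr) return nullptr;
    return &(*pctx)->m_ptr;
}

int mockGetCurrent(MockCtx* pctx) {
    void** slot = mockCtx2Raw(pctx);
    if (slot == nullptr) return -1; // unchecked code would write through a bad pointer here
    return mockBackendGetCurrent(slot);
}
```

With a handle that points to no wrapper, mockGetCurrent reports an error; with an allocated wrapper, the raw context lands in m_ptr. In the code quoted in this thread the handle is uninitialized rather than null, so no such check could catch it.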
Also, we can't create an ioroCtx_t struct inside Adl because it only works with the oroCtx interface.
Next, in Adl, we must compare not the Orochi context (oroCtx), but the HIP (hipCtx_t) or CUDA (CUcontext) context.
But this is fixed in another PR in Adl.
Maybe I misunderstand something and you can show me the correct use case, but the exception does happen right now; that much is a fact.
PS: Putting ioroCtx_t on the stack was my mistake. I passed a pointer to it into the function, but this changes the API, which is also not good.
I now see the problem. How about doing something like this?
bool DeviceHIP::isValid() const
{
    oroCtx ctx;
    auto e = oroCtxGetCurrent(&ctx);
    ADLASSERT(e == ORO_SUCCESS, 0);
    bool isSame = m_impl->m_ctxt == oroGetRawCtx( ctx );
    e = oroCtxCreateFromRawDestroy( ctx );
    ADLASSERT(e == ORO_SUCCESS, 0);
    return isSame;
}
We need to allocate ioroCtx_t in oroCtxGetCurrent, like what we do in oroCtxCreate, but use CtxGetCurrent to fill m_ptr. IMO this is cleaner.
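A minimal sketch of that idea, using mock types rather than the real Orochi API: the get-current call allocates a wrapper, fills its m_ptr from the backend, and the caller later frees only the wrapper. All names (MockCtxImpl, mockCtxGetCurrent, mockGetRawCtx, mockCtxWrapperDestroy, g_liveWrappers) are hypothetical and exist only for illustration.

```cpp
#include <cassert>

// Hypothetical stand-ins for ioroCtx_t / oroCtx.
struct MockCtxImpl { void* m_ptr = nullptr; };
using MockCtx = MockCtxImpl*;

static int g_backendCtx = 0;   // pretend backend context object
static int g_liveWrappers = 0; // counts wrappers that have not been freed yet

// Stand-in for cuCtxGetCurrent / hipCtxGetCurrent.
int mockBackendGetCurrent(void** out) { *out = &g_backendCtx; return 0; }

// Allocates the wrapper first, then lets the backend fill m_ptr.
int mockCtxGetCurrent(MockCtx* pctx) {
    MockCtxImpl* w = new MockCtxImpl;
    int e = mockBackendGetCurrent(&w->m_ptr);
    if (e != 0) { delete w; return e; } // do not leak the wrapper on failure
    ++g_liveWrappers;
    *pctx = w;
    return 0;
}

// Returns the raw backend context stored in the wrapper.
void* mockGetRawCtx(MockCtx ctx) { return ctx->m_ptr; }

// Frees the wrapper only; the backend context itself is left alone.
int mockCtxWrapperDestroy(MockCtx ctx) {
    delete ctx;
    --g_liveWrappers;
    return 0;
}
```

The caller compares mockGetRawCtx( ctx ) against its stored raw context and then calls mockCtxWrapperDestroy, mirroring the isValid() proposal above.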
Yeah, I get it. You want to clear ioroCtx_t in oroCtxCreateFromRawDestroy. This will work, but a new API function is needed.
I prepared a new approach, but didn't push it today. Only Orochi needs to be changed; Adl does not. It is based on a hash map.
#define __ORO_RET_ERR( e ) if( s_api == ORO_API_CUDA ) return cu2oro((CUresult)e ); if( s_api == ORO_API_HIP ) return hip2oro( (hipError_t)e );
std::unordered_map<void*, oroCtx> s_oroCtxs;
oroError OROAPI oroCtxCreate(oroCtx* pctx, unsigned int flags, oroDevice dev)
{
    ioroDevice d( dev );
    ioroCtx_t* ctxt = new ioroCtx_t;
    ctxt->setApi( d.getApi() );
    (*pctx) = ctxt;
    s_api = ctxt->getApi();
    int e = oroErrorUnknown;
    if( s_api == ORO_API_CUDA ) e = cuCtxCreate( oroCtx2cu( pctx ), flags, d.getDevice() );
    if( s_api == ORO_API_HIP ) e = hipCtxCreate( oroCtx2hip( pctx ), flags, d.getDevice() );
    if( e )
    {
        __ORO_RET_ERR( e )
    }
    std::lock_guard<std::mutex> lock( mtx );
    s_oroCtxs[ctxt->m_ptr] = ctxt;
    return oroSuccess;
}
oroError OROAPI oroCtxGetCurrent(oroCtx* pctx)
{
    ioroCtx_t* ctxt = new ioroCtx_t;
    int e = oroErrorUnknown;
    if( s_api == ORO_API_CUDA ) e = cuCtxGetCurrent( oroCtx2cu( &ctxt ) );
    if( s_api == ORO_API_HIP ) e = hipCtxGetCurrent( oroCtx2hip( &ctxt ) );
    if( e )
    {
        __ORO_RET_ERR( e )
    }
    ( *pctx ) = s_oroCtxs[ctxt->m_ptr];
    delete ctxt;
    return oroSuccess;
}
oroCtxDestroy also needs to change. If you agree with this solution, I will push it tomorrow; otherwise I can implement your solution.
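For reference, the registry idea including the destroy side can be sketched with mock types. This is not the pushed implementation: the names (MockCtxImpl, s_ctxs, s_ctxMutex, g_current, mockCtxCreate, mockCtxGetCurrent, mockCtxDestroy) are hypothetical, and the backend is reduced to a single "current context" pointer. Two choices here are assumptions of the sketch: the lookup uses find() under the lock so a miss does not insert an empty entry, and destroy erases the entry before freeing the wrapper.

```cpp
#include <cassert>
#include <mutex>
#include <unordered_map>

// Hypothetical stand-ins for ioroCtx_t / oroCtx.
struct MockCtxImpl { void* m_ptr = nullptr; };
using MockCtx = MockCtxImpl*;

static std::unordered_map<void*, MockCtx> s_ctxs; // raw context -> wrapper
static std::mutex s_ctxMutex;                     // guards s_ctxs
static void* g_current = nullptr;                 // pretend backend "current context"

// Creates a wrapper for a raw context and registers it.
int mockCtxCreate(MockCtx* pctx, void* raw) {
    MockCtxImpl* w = new MockCtxImpl;
    w->m_ptr = raw;
    {
        std::lock_guard<std::mutex> lock(s_ctxMutex);
        s_ctxs[raw] = w;
    }
    g_current = raw;
    *pctx = w;
    return 0;
}

// Returns the already-registered wrapper; nothing is allocated here.
int mockCtxGetCurrent(MockCtx* pctx) {
    std::lock_guard<std::mutex> lock(s_ctxMutex);
    auto it = s_ctxs.find(g_current);
    if (it == s_ctxs.end()) { *pctx = nullptr; return -1; }
    *pctx = it->second;
    return 0;
}

// Unregisters the wrapper, then frees it.
int mockCtxDestroy(MockCtx ctx) {
    {
        std::lock_guard<std::mutex> lock(s_ctxMutex);
        s_ctxs.erase(ctx->m_ptr);
    }
    if (g_current == ctx->m_ptr) g_current = nullptr;
    delete ctx;
    return 0;
}
```

Because get-current hands back the wrapper made at creation time, the caller has nothing extra to free, and handle equality (created == current) holds directly.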
I found that the current version of Adl references an old commit, so I didn't see the oroCtxCreateFromRawDestroy API.
My concern is that if someone calls the oroCtxGetCurrent function and doesn't call oroCtxCreateFromRawDestroy, it will cause a memory leak.
The hashmap solution will eliminate this scenario.
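The leak concern can be made concrete with a mock in which every get-current call allocates a wrapper. The scoped guard below is not part of either proposal in this thread; it is one caller-side way to pair the two calls, shown only to illustrate what the caller would otherwise have to remember. All names (MockCtxImpl, mockCtxGetCurrent, mockCtxWrapperDestroy, MockCtxDeleter, ScopedMockCtx, scopedGetCurrent) are hypothetical.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-ins for ioroCtx_t / oroCtx.
struct MockCtxImpl { void* m_ptr = nullptr; };
using MockCtx = MockCtxImpl*;

static int g_liveWrappers = 0; // wrappers allocated but not yet freed

// Allocating variant of get-current: each call creates a new wrapper.
int mockCtxGetCurrent(MockCtx* pctx) {
    *pctx = new MockCtxImpl;
    ++g_liveWrappers;
    return 0;
}

// Must be called once per successful mockCtxGetCurrent, or the wrapper leaks.
int mockCtxWrapperDestroy(MockCtx ctx) {
    delete ctx;
    --g_liveWrappers;
    return 0;
}

// Deleter that routes unique_ptr cleanup through the destroy call.
struct MockCtxDeleter {
    void operator()(MockCtxImpl* p) const { mockCtxWrapperDestroy(p); }
};
using ScopedMockCtx = std::unique_ptr<MockCtxImpl, MockCtxDeleter>;

// Returns an owning handle; empty if get-current fails.
ScopedMockCtx scopedGetCurrent() {
    MockCtx raw = nullptr;
    if (mockCtxGetCurrent(&raw) != 0) return ScopedMockCtx{};
    return ScopedMockCtx{raw};
}
```

A bare mockCtxGetCurrent call leaves g_liveWrappers raised until someone calls the destroy function, whereas the scoped handle releases the wrapper when it goes out of scope. The hash map approach avoids the question by not allocating on lookup at all.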
This fix helps to get the correct context from the isValid() ADL API. Otherwise, this function throws an exception.