hughperkins / pycltorch

POC for Python wrappers for cltorch/clnn
BSD 2-Clause "Simplified" License

setting OpenCL device used by PyClTorch? #2

Closed lebedov closed 8 years ago

lebedov commented 8 years ago

I tried to add support for setting the device used by PyClTorch, but setting the GPU via the setDevice function defined in the modification below doesn't seem to work: any subsequent allocations of ClTensors, for instance, use whatever GPU was selected at initialization, and getDevice returns the same device number even after invoking setDevice. What am I missing?

diff --git a/PyClTorch.pyx b/PyClTorch.pyx
index 848fe80..0e0f453 100644
--- a/PyClTorch.pyx
+++ b/PyClTorch.pyx
@@ -17,6 +17,9 @@ cdef extern from "LuaHelper.h":

 cdef extern from "THClGeneral.h":
     cdef struct THClState
+    int THClState_getNumDevices(THClState* state);
+    void THClState_setDevice(THClState* state, int device);
+    int THClState_getDevice(THClState* state)

 cdef extern from "THTensor.h":
     cdef struct THFloatTensor
@@ -199,6 +202,18 @@ cdef class ClGlobalState(object):

 cdef ClGlobalState clGlobalState

+def getDeviceCount():
+    global clGlobalState
+    return THClState_getNumDevices(clGlobalState.state)
+
+def setDevice(device):
+    global clGlobalState
+    THClState_setDevice(clGlobalState.state, device-1)
+
+def getDevice():
+    global clGlobalState
+    return THClState_getDevice(clGlobalState.state+1)
+
 def init():
     global clGlobalState
     cdef THClState *state2
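For reference, here is a minimal pure-Python mock of how the proposed API would be expected to behave (a dict stands in for THClState; the real functions wrap the C calls from THClGeneral.h). Torch convention is 1-based device numbers in user code, 0-based internally:

```python
# Illustrative mock only: a dict stands in for THClState, and the
# functions mimic the wrappers proposed in the diff above.
_state = {"num_devices": 2, "current_device": 0}  # 0-based internally

def getDeviceCount():
    return _state["num_devices"]

def setDevice(device):
    _state["current_device"] = device - 1  # convert 1-based -> 0-based

def getDevice():
    # The +1 belongs on the returned index, not on the state pointer.
    return _state["current_device"] + 1

setDevice(2)
print(getDevice())  # -> 2
```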
hughperkins commented 8 years ago

Ok. Nice work figuring out your way through the spaghetti :-)

Note that I don't think you want to add 1 to clGlobalState.state? But even if you fix that, the code still won't work. I reckon it's something to do with the comments in the following lines :-) :

https://github.com/hughperkins/pycltorch/blob/master/PyClTorch.pyx#L70

 self.native = THClTensor_newv2(clGlobalState.state, 0) # FIXME get device from state
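The failure mode can be sketched in plain Python (names hypothetical; a dict stands in for THClState): the tensor constructor hardcodes device 0 instead of reading the state's current device, so setDevice appears to have no effect on allocations.

```python
# Pure-Python sketch of the bug: the constructor ignores the
# global state's current device, like THClTensor_newv2(state, 0).
state = {"current_device": 0}

def setDevice(device):
    state["current_device"] = device - 1  # 1-based API, 0-based storage

class ClTensorBuggy:
    def __init__(self):
        self.device = 0  # hardcoded device, as in the FIXME line above

class ClTensorFixed:
    def __init__(self):
        self.device = state["current_device"]  # read device from state

setDevice(2)
print(ClTensorBuggy().device)  # -> 0: ignores setDevice
print(ClTensorFixed().device)  # -> 1: 0-based index of device 2
```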
hughperkins commented 8 years ago

Since you've done the hard bit, I think we can modify the code as follows:

            device = THClState_getDevice(clGlobalState.state)
            print('device', device)
            if len(args) == 0:
                self.native = THClTensor_newv2(clGlobalState.state, device)
            elif len(args) == 1:
                self.native = THClTensor_newWithSize1d(clGlobalState.state, device, args[0])
            elif len(args) == 2:
                self.native = THClTensor_newWithSize2d(clGlobalState.state, device, args[0], args[1])
            elif len(args) == 3:
                self.native = THClTensor_newWithSize3d(clGlobalState.state, device, args[0], args[1], args[2])
            elif len(args) == 4:
                self.native = THClTensor_newWithSize4d(clGlobalState.state, device, args[0], args[1], args[2], args[3])
            else:
                raise Exception('Not implemented, len(args)=' + str(len(args)))
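As an aside, since the THClTensor_newWithSize{1..4}d constructors differ only in arity, the if/elif chain above can also be written table-driven. A pure-Python sketch (the real code would map each arity to the corresponding C constructor; the mock constructors here are illustrative only):

```python
# Table-driven equivalent of the len(args) dispatch above.
def make_tensor(state, device, args, constructors):
    try:
        ctor = constructors[len(args)]
    except KeyError:
        raise Exception('Not implemented, len(args)=' + str(len(args)))
    return ctor(state, device, *args)

# Mock constructors standing in for the THClTensor_newWithSize*d family:
ctors = {n: (lambda s, d, *sizes: {"device": d, "sizes": sizes})
         for n in range(5)}

t = make_tensor("state", 1, (3, 4), ctors)
print(t)  # -> {'device': 1, 'sizes': (3, 4)}
```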
lebedov commented 8 years ago

Aha - That did the trick. Pull request submitted. Thanks!

lebedov commented 8 years ago

Closing - thanks!