hughperkins / cltorch

An OpenCL backend for torch.

Test error: "__local float (*)[128]" is incompatible with parameter of type "__local float *" #69

joelself closed this issue 8 years ago

joelself commented 8 years ago

When I run luajit -l cltorch -e 'cltorch.test()' I get this at the end:

Completed 190 asserts in 117 tests with 0 failures and 1 error
--------------------------------------------------------------------------------
test_sort
 Function call failed
C++ exception

--------------------------------------------------------------------------------
luajit: /home/joel/torch/install/share/lua/5.1/torch/Tester.lua:362: An error was found while running tests!
stack traceback:
    [C]: in function 'assert'
    /home/joel/torch/install/share/lua/5.1/torch/Tester.lua:362: in function 'run'
    ...joel/torch/install/share/lua/5.1/cltorch/unit_tensor.lua:1071: in function 'test'
    /home/joel/torch/install/share/lua/5.1/cltorch/Test.lua:15: in function 'test'
    (command line):1: in main chunk
    [C]: at 0x00406670

OS: Ubuntu 14.04 x86_64
CPU: Intel Core i5-2500K
RAM: 8 GB of something
Graphics: XFX AMD R9 280x
Lua: LuaJIT 2.1.0-beta1

g++ -v:

Using built-in specs.
COLLECT_GCC=g++
COLLECT_LTO_WRAPPER=/usr/lib/gcc/x86_64-linux-gnu/4.8/lto-wrapper
Target: x86_64-linux-gnu
Configured with: ../src/configure -v --with-pkgversion='Ubuntu 4.8.4-2ubuntu1~14.04.1' --with-bugurl=file:///usr/share/doc/gcc-4.8/README.Bugs --enable-languages=c,c++,java,go,d,fortran,objc,obj-c++ --prefix=/usr --program-suffix=-4.8 --enable-shared --enable-linker-build-id --libexecdir=/usr/lib --without-included-gettext --enable-threads=posix --with-gxx-include-dir=/usr/include/c++/4.8 --libdir=/usr/lib --enable-nls --with-sysroot=/ --enable-clocale=gnu --enable-libstdcxx-debug --enable-libstdcxx-time=yes --enable-gnu-unique-object --disable-libmudflap --enable-plugin --with-system-zlib --disable-browser-plugin --enable-java-awt=gtk --enable-gtk-cairo --with-java-home=/usr/lib/jvm/java-1.5.0-gcj-4.8-amd64/jre --enable-java-home --with-jvm-root-dir=/usr/lib/jvm/java-1.5.0-gcj-4.8-amd64 --with-jvm-jar-dir=/usr/lib/jvm-exports/java-1.5.0-gcj-4.8-amd64 --with-arch-directory=amd64 --with-ecj-jar=/usr/share/java/eclipse-ecj.jar --enable-objc-gc --enable-multiarch --disable-werror --with-arch-32=i686 --with-abi=m64 --with-multilib-list=m32,m64,mx32 --with-tune=generic --enable-checking=release --build=x86_64-linux-gnu --host=x86_64-linux-gnu --target=x86_64-linux-gnu
Thread model: posix
gcc version 4.8.4 (Ubuntu 4.8.4-2ubuntu1~14.04.1)

Full Error:

THClSortUtils.cl build log: 
"/tmp/OCL7018T58.cl", line 210: error: argument of type
          "__local float (*)[128]" is incompatible with parameter of type
          "__local float *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                  ^

"/tmp/OCL7018T58.cl", line 210: error: argument of type
          "__local float (*)[128]" is incompatible with parameter of type
          "__local float *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                               ^

"/tmp/OCL7018T58.cl", line 210: error: argument of type "__local bool (*)[128]"
          is incompatible with parameter of type "__local bool *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                                              ^

3 errors detected in the compilation of "/tmp/OCL7018T58.cl".
Frontend phase failed compilation.

kernel build error:

kernel source:
1: // from lib/THC/THCSortUtils.cuh:
2: 
3: // This needs the following template variables:
4: //   K              key type
5: //   V              value type
6: //   COMPARE_OP     a comparison operator, like <   or >
7: //   KeyDims        integer
8: //   ValueDims      integer
9: //   Power2SortSize  integer
10: //   dims           list of KeyDims and ValueDims
11: 
12: // you need to somewhere include  before this, with appropriate dims, to include
13: // KeyDims and ValueDims
14: 
15: // this needs the following template variables defined:
16: //   IndexType   string, eg 'int'
17: //   MAX_CLTORCH_DIMS    integer, eg 25
18: //   dims                list of integers, ie all dimensions >=0 this should work for
19: //   WarpSize            integer eg 32
20: //   defiscontiguous     [1|0]  (or just dont define, means 0)
21: //   defreduceblock      [1|0]  (or just dont define, means 0)
22: 
23: 
24: // kernel argument that defines tensor layout
25: typedef struct TensorInfoCl {
26:   // Extracts size/stride information for the kernel.
27:   // Successive dimensions can be collapsed if the size/strides match
28:   // up and thus there are no holes between the dimensions. This is used
29:   // to reduce the complexity of the problem.
30:   // The optional `reduceDim` indicates a reduction dimension for the
31:   // given tensor, so that the output size for this dimension will be 1.
32: 
33:   unsigned int sizes[25];
34:   unsigned int strides[25];
35:   unsigned int offset;
36:   int dims;
37: } TensorInfoCl;
38: // Contiguous tensors of more than one dimension are collapsed down
39: // to one tensor
40: 
41: 
42: // Translate a linear index for the apply to a float* offset;
43: // specialized on `Dims` to reduce nvcc compilation time
44: 
45: 
46: inline unsigned int IndexToOffset_998_get(unsigned int linearId, global const TensorInfoCl *info) {
47:     return linearId + info->offset;
48: }
49: 
50: inline unsigned int IndexToOffset_999_get(unsigned int linearId, global const TensorInfoCl *info) {
51:   unsigned int offset = info->offset;
52: 
53:   // Use dynamic dims
54:   for (int i = info->dims - 1; i >= 0; --i) {
55:     unsigned int curDimIndex = linearId % info->sizes[i];
56:     unsigned int curDimOffset = curDimIndex * info->strides[i];
57:     offset += curDimOffset;
58: 
59:     linearId /= info->sizes[i];
60:   }
61: 
62:   return offset;
63: }
64: 
65: inline unsigned int getLinearBlockId() {
66:   return get_group_id(2) * get_num_groups(1) * get_num_groups(0) +
67:     get_group_id(1) * get_num_groups(0) +
68:     get_group_id(0);
69: }
70: 
71: // Block-wide reduction in shared memory helper; only /*threadIdx.x*/ get_local_id(0) == 0 will
72: // return the reduced value
73: 
74: 
75: 
76: 
77: inline void swapVars_K(local float *p_t1, local float*p_t2) {
78:   float tmp = *p_t1;
79:   *p_t1 = *p_t2;
80:   *p_t2 = tmp;
81: }
82: 
83: inline void swapVars_V(local float *p_t1, local float*p_t2) {
84:   float tmp = *p_t1;
85:   *p_t1 = *p_t2;
86:   *p_t2 = tmp;
87: }
88: 
89: inline void swapVars_bool(local bool *p_t1, local bool *p_t2) {
90:   bool tmp = *p_t1;
91:   *p_t1 = *p_t2;
92:   *p_t2 = tmp;
93: }
94: 
95: inline void bitonicSwap(local float* p_kA, local float*p_vA, local bool*p_validA,
96:                         local float* p_kB, local float*p_vB, local bool*p_validB,
97:                         bool dir) {
98:   // Invalid entries always sort to the end
99:   // original cuda version was:
100:   //   bool swap = (comp(kA, kB) && validA) || !validB;
101:   bool swap = (((*p_kA) < (*p_kB)) && (*p_validA)) || !(*p_validB);
102:   if (swap == dir) {
103:     swapVars_K(p_kA, p_kB);
104:     swapVars_V(p_vA, p_vB);
105:     swapVars_bool(p_validA, p_validB);
106:   }
107: };
108: 
109: inline void bitonicSort(local float *p_keys,
110:                                    local float *p_values,
111:                                    local bool *p_valid) {
112:   #pragma unroll
113:   for (unsigned int size = 2; size < 128; size *= 2) {
114:     bool flag = ((get_local_id(0) & (size / 2)) != 0);
115: 
116:     #pragma unroll
117:     for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
118: 
119:       // Single warp per slice is completely synchronous
120:       if (128 > 32) {   // is 64 ok?  Let's try 32 till it is working ok...
121:         barrier(CLK_LOCAL_MEM_FENCE);
122:       }
123: 
124:       unsigned int pos = 2 * get_local_id(0) - (get_local_id(0) & (stride - 1));
125:       bitonicSwap(
126:         p_keys + pos, p_values + pos, p_valid + pos,
127:         p_keys + pos + stride, p_values + pos + stride, p_valid + pos + stride,
128:         flag);
129:     }
130:   }
131: 
132:   #pragma unroll
133:   for (unsigned int stride = 128 / 2; stride > 0; stride /= 2) {
134:     // Single warp per slice is completely synchronous
135:     if (128 > 32) { // note: was 64 before
136:       barrier(CLK_LOCAL_MEM_FENCE);
137:     }
138: 
139:     unsigned int pos = 2 * get_local_id(0) - (get_local_id(0) & (stride - 1));
140:     bitonicSwap(
141:       p_keys + pos, p_values + pos, p_valid + pos,
142:       p_keys + pos + stride, p_values + pos + stride, p_valid + pos + stride,
143:       false);
144:   }
145: 
146:   // Single warp per slice is completely synchronous
147:   if (128 > 32) {  // note: was 64 before
148:     barrier(CLK_LOCAL_MEM_FENCE);
149:   }
150: }
151: 
152: // Sorts (key, value) pairs (in different tensors) in-place; i.e.,
153: // modifies the input `keys` and `values`
154: kernel void
155: bitonicSortKVInPlace(global TensorInfoCl *keys_info, global float *keys_data,
156:                      unsigned int keySlices,
157:                      unsigned int keySliceSize,
158:                      unsigned int keySliceStride,
159:                      global TensorInfoCl *values_info, global float *values_data,
160:                      unsigned int valueSliceStride
161: ) {
162:   // Find the slice of the tensor that we are sorting
163:   const unsigned int linearIndex = getLinearBlockId();
164:   // Tiling the slices could have us be out of bounds, if there are a
165:   // lot of slices to sort
166:   if (linearIndex >= keySlices) {
167:     return;
168:   }
169: 
170:   local float sharedKeys[128];
171:   local float sharedValues[128];
172:   local bool sharedValid[128];
173: 
174:   const unsigned int keyStartOffset =
175:     IndexToOffset_998_get(linearIndex, &keys_info[0]);
176:   const unsigned int valueStartOffset =
177:     IndexToOffset_999_get(linearIndex, &values_info[0]);
178: 
179:   // If the sort size is 1, the data is already sorted
180:   if (128 == 1) {
181:     return;
182:   } else {
183:     // Otherwise, each thread is responsible for loading and storing 2
184:     // elements. The sort size is guaranteed to be >= 2
185:     const int elem1 = get_local_id(0);
186:     const int elem2 = get_local_id(0) + (128 / 2);
187: 
188:     bool valid1 = (elem1 < keySliceSize);
189:     float k1 = valid1 ?
190:       keys_data[keyStartOffset + elem1 * keySliceStride] : (float) 0;
191:     float v1 = valid1 ?
192:       values_data[valueStartOffset + elem1 * valueSliceStride] : (float) 0;
193: 
194:     sharedKeys[elem1] = k1;
195:     sharedValues[elem1] = v1;
196:     sharedValid[elem1] = valid1;
197: 
198:     bool valid2 = (elem2 < keySliceSize);
199:     float k2 = valid2 ?
200:       keys_data[keyStartOffset + elem2 * keySliceStride] : (float) 0;
201:     float v2 = valid2 ?
202:       values_data[valueStartOffset + elem2 * valueSliceStride] : (float) 0;
203: 
204:     sharedKeys[elem2] = k2;
205:     sharedValues[elem2] = v2;
206:     sharedValid[elem2] = valid2;
207: 
208:     // Sort!
209: //    if(get_local_id(0) == 0) {
210:     bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
211: //   }
212: 
213: ////    if(get_local_id(0) == 0) {
214: //      keys_data[0] = sharedKeys[0];
215: //      keys_data[1] = sharedKeys[1];
216: ////      keys_data[0] = elem1;
217: ////      keys_data[1] = elem2;
218: ////      values_data[0] = 128;
219: //      values_data[0] = sharedValues[0];
220: //      values_data[1] = sharedValues[1];
221: ////    }
222: 
223: 
224:     // elem1 values are always valid, since otherwise we would have
225:     // chosen the next smallest power-of-2 for sorting
226:     keys_data[keyStartOffset + elem1 * keySliceStride] =
227:       sharedKeys[elem1];
228:     values_data[valueStartOffset + elem1 * valueSliceStride] =
229:       sharedValues[elem1];
230: 
231:     if (valid2) {
232:       // elem2 values might be out-of-range, if the data size we are
233:       // sorting is not a power-of-2
234:       keys_data[keyStartOffset + elem2 * keySliceStride] =
235:         sharedKeys[elem2];
236:       values_data[valueStartOffset + elem2 * valueSliceStride] =
237:         sharedValues[elem2];
238:     }
239:   }
240: }
241: 
242: 

Invalid kernel name, code -46, kernel bitonicSortKVInPlace
THClSortUtils.cl build log: 
"/tmp/OCL7018T58.cl", line 210: error: argument of type
          "__local float (*)[128]" is incompatible with parameter of type
          "__local float *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                  ^

"/tmp/OCL7018T58.cl", line 210: error: argument of type
          "__local float (*)[128]" is incompatible with parameter of type
          "__local float *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                               ^

"/tmp/OCL7018T58.cl", line 210: error: argument of type "__local bool (*)[128]"
          is incompatible with parameter of type "__local bool *"
      bitonicSort(&sharedKeys, &sharedValues, &sharedValid);
                                              ^

3 errors detected in the compilation of "/tmp/OCL7018T58.cl".
Frontend phase failed compilation.

 52/117 test_sort ....................................................... [ERROR]

th -l cltorch -e 'cltorch.about()'

cltorch.  OpenCL backend for Torch
Built from commit 7bcb1b9
More info, doc: https://github.com/hughperkins/cltorch
Issues: https://github.com/hughperkins/cltorch/issues

I'd like to include logs, but I don't know where they are.

Edit: Probably a better title for this would be: Test error: "__local float (*)[128]" is incompatible with parameter of type "__local float *"

Edit 2: Oh right, I can edit titles.

joelself commented 8 years ago

Ok, I found the problem: in src/lib/THClSortUtils.cl

bitonicSort(&sharedKeys, &sharedValues, &sharedValid);

The &'s are wrong: the variables are already arrays, so taking their address produces a pointer-to-array type like float (*)[128] instead of letting the arrays decay to float *. Here's a minimal recreation. This compiles fine:

void test_func(float *junk) {}

int main() {
    float stuff[16];
    test_func(stuff);
}

But change the call to test_func to:

test_func(&stuff);

and you get this on compile:

test.cpp: In function ‘int main()’:
test.cpp:4:18: error: cannot convert ‘float (*)[16]’ to ‘float*’ for argument ‘1’ to ‘void test_func(float*)’
  test_func(&stuff);

Which is pretty close to the error I'm getting when I run cltorch.test.

hughperkins commented 8 years ago

Nice spot, thanks! I merged your change, as a22b207. Can you pull down the latest version and retry?

joelself commented 8 years ago

It's on a different video card (I had to switch to an NVidia card when cltorch wouldn't work, and I don't have time to switch back right now, but I'll try tonight), but it works!

Completed 195 asserts in 117 tests with 0 failures and 0 errors
all tests finished