fjarri / reikna

Pure Python GPGPU library
http://reikna.publicfields.net/
MIT License

New Transformation format, derive_render_kwds error #6

Closed: tnorth closed this issue 11 years ago.

tnorth commented 11 years ago

Hi,

I am trying to migrate from an old version of tigger to reikna (develop branch), and it seems that the Transformation mechanism has changed significantly. Looking through the tests, I found this example Transformation:

# Output = Input1 * Parameter1 + Input2
tr_2_to_1 = Transformation(
    inputs=2, outputs=1, scalars=1,
    derive_o_from_is=lambda i1, i2, s1: i1,
    derive_render_kwds=lambda o1, i1, i2, s1: dict(
        mul=functions.mul(o1, i1),
        cast=functions.cast(o1, s1)),
    snippet="""
${o1.ctype} t = ${mul}(${cast}(${s1}), ${i1.load});
${o1.store}(t + ${i2.load});
""")

What I understood from it is:
1) The output dtype is derived from the input dtypes; here o1 is given the dtype of i1.
2) derive_render_kwds returns a dict of built-in functions, each created for specific argument dtypes, which are then available inside the snippet.

First, I tried running it with a dummy identity Computation and connected the transformation to it. signature_str() returns, as expected:
(array) out, (array) i1, (array) i2, (scalar) s1
But at the preparation stage I get this error:

Traceback (most recent call last):
  File "test_reikna.py", line 40, in <module>
    foo.prepare_for(dest_dev, a_dev, b_dev, scalar)
  File "~/reikna/reikna/core/computation.py", line 207, in prepare_for
    self._basis = self._basis_for(args, kwds)
  File "~/reikna/reikna/core/computation.py", line 156, in _basis_for
    basis = AttrDict(self._get_basis_for(*self._tr_tree.base_values(), **kwds))
  File "~/reikna/reikna/elementwise.py", line 127, in _get_basis_for
    return Elementwise._get_basis_for(self, *args, code=code, dependencies=dependencies)
  File "~/reikna/reikna/elementwise.py", line 71, in _get_basis_for
    snippet = code(*args)
TypeError: 'dict' object is not callable

I must be doing something wrong, since all the tests pass.

Now my questions are:
1) Is the derive_o_from_is parameter required? If it is unspecified, will reikna infer the output dtype from all input parameters (i.e. for float32, complex64 and int, will it pick complex64)?
2) In the snippet above, mul=functions.mul(o1, i1) specifies the dtypes of both arguments to be multiplied. What happens if other multiplications in the snippet involve inputs of different dtypes? (E.g. if I have mul=functions.mul(complex128, complex128), can I also safely multiply float64 numbers with ${mul}?)

Thanks!

fjarri commented 11 years ago

1) By default, the result_type of all the argument dtypes is used (see the reference entry for Transformation). So yes, float32, complex64 and int will give complex64. (On a side note, I'm thinking about ditching the whole generic transformation business and making transformations statically typed. It is just too much fuss without any benefit.)
2) functions.mul() creates a device function (basically a normal C function) with the specified C types. So if you use a complex128 function with float64 arguments, the compiler will complain.
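Both points of the answer above can be illustrated on the host. The first part uses numpy.result_type, assuming the default output dtype follows NumPy's promotion rules; note that under those rules the outcome with an integer argument depends on its width (int16 keeps complex64, int32 promotes to complex128). The second part is a hypothetical sketch: make_mul and the dtype-to-C-type table are illustrations, not reikna's functions.mul or the code reikna actually generates. It shows why one device function cannot cover several dtype pairs: the parameter types are fixed when the source is generated.

```python
import numpy as np

# 1) Default output dtype: NumPy promotion over all argument dtypes.
#    The integer width matters: int16 fits in float32, int32 does not.
print(np.result_type(np.float32, np.complex64, np.int16))  # complex64
print(np.result_type(np.float32, np.complex64, np.int32))  # complex128

# 2) Hypothetical sketch of dtype-specialized device functions.
#    The dtype -> C type mapping below is an assumption for illustration.
CTYPES = {
    np.dtype(np.float32): "float",
    np.dtype(np.float64): "double",
    np.dtype(np.complex64): "float2",
    np.dtype(np.complex128): "double2",
}


def make_mul(out_dtype, a_dtype, b_dtype):
    """Return (name, C source) of a multiply fixed to these dtypes (not reikna's functions.mul)."""
    a_kind, b_kind = np.dtype(a_dtype).kind, np.dtype(b_dtype).kind
    out_t, a_t, b_t = (CTYPES[np.dtype(d)] for d in (out_dtype, a_dtype, b_dtype))
    name = "mul_{}_{}_{}".format(out_t, a_t, b_t)
    if a_kind == "c" and b_kind == "c":
        # Complex product with (real, imag) stored in the .x/.y components.
        body = "return ({})(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);".format(out_t)
    elif a_kind == "f" and b_kind == "f":
        body = "return a * b;"
    else:
        raise NotImplementedError("mixed real/complex is left out of this sketch")
    source = "{} {}({} a, {} b) {{ {} }}".format(out_t, name, a_t, b_t, body)
    return name, source


# The parameter types are baked into the generated source: the complex128
# version is declared for double2 arguments, so float64 values call for
# their own specialization.
name_c, src_c = make_mul(np.complex128, np.complex128, np.complex128)
name_r, src_r = make_mul(np.float64, np.float64, np.float64)
print(src_c)
print(src_r)
```

In reikna terms, the practical consequence is to request one function per dtype combination in derive_render_kwds (e.g. a separate multiply built for float64) rather than reusing a single ${mul}.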

As for your error, it seems to be caused by passing the wrong type as the code argument when preparing Elementwise. It is not a dict anymore: as the traceback shows, Elementwise now calls code(*args) to obtain the snippet. See test_elementwise for an example.
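The traceback pins down the mechanism: Elementwise executes snippet = code(*args), so whatever is passed as code must be callable. The pure-Python sketch below reproduces that failure and the fix. prepare_elementwise, make_snippet and the dict key are hypothetical stand-ins, not reikna functions; the exact callable signature the current Elementwise expects should be taken from test_elementwise.

```python
def prepare_elementwise(code, *args):
    """Hypothetical stand-in for the line `snippet = code(*args)` in the traceback."""
    return code(*args)


# Old style: a dict of code pieces (key name is illustrative only).
# Calling it fails the same way as in the traceback.
old_code = dict(kernel="${o1.store}(${i1.load});")
try:
    prepare_elementwise(old_code, "out", "in")
except TypeError as exc:
    print(exc)  # 'dict' object is not callable


# New style (sketch): a callable that receives the arguments and builds the snippet.
def make_snippet(out, inp):
    return "{}.store({}.load)".format(out, inp)


print(prepare_elementwise(make_snippet, "out", "in"))  # out.store(in.load)
```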

tnorth commented 11 years ago

Thank you very much for your answer. Indeed, Elementwise has changed as well, which I hadn't noticed. This behaviour makes perfect sense.