"Unexpected argument" error when passing a DU value to a kernel that pattern-matches on it #150

Status: Open · artem-burashnikov opened this issue 1 year ago

artem-burashnikov commented 1 year ago

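**Describe the bug**

Passing a discriminated-union (DU) value as a kernel argument fails at run time. `Brahma.FSharp.ClProgram.setupArgument` rejects it with the following exception: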

Unhandled exception. System.Exception: Unexpected argument: Case1
   at Microsoft.FSharp.Core.PrintfModule.PrintFormatToStringThenFail@1448.Invoke(String message)
   at Microsoft.FSharp.Core.PrintfModule.gprintf[a,TState,TResidue,TResult,TPrinter](FSharpFunc`2 envf, PrintfFormat`4 format) in D:\a\_work\1\s\src\FSharp.Core\printf.fs:line 1398
   at Brahma.FSharp.ClProgram`2.setupArgument(Kernel kernel, Int32 index, Object arg)
   at lambda_method3(Closure, Kernel, Int32, Object)
   at Microsoft.FSharp.Collections.ArrayModule.IterateIndexed[T](FSharpFunc`2 action, T[] array) in D:\a\_work\1\s\src\FSharp.Core\array.fs:line 437
   at lambda_method4(Closure, FSharpFunc`2, Object[])
   at FSharp.Quotations.Evaluator.Tools.Invoke@119-13D.Invoke(g g, h h, i i, j j)
   at ImageProcessing.Main.arrayMap2@37-1.Invoke(Unit unitVar0) in /home/dxd/ImageProcessing/src/ImageProcessing/Main.fs:line 38
   at <StartupCode$Brahma-FSharp-OpenCL-Core>.$CommandQueueProvider.loop@148-6.Invoke(Unit unitVar)
   at Microsoft.FSharp.Control.AsyncPrimitives.CallThenInvoke[T,TResult](AsyncActivation`1 ctxt, TResult result1, FSharpFunc`2 part2) in D:\a\_work\1\s\src\FSharp.Core\async.fs:line 510
   at Microsoft.FSharp.Control.Trampoline.Execute(FSharpFunc`2 firstAction) in D:\a\_work\1\s\src\FSharp.Core\async.fs:line 112
--- End of stack trace from previous location ---
   at Microsoft.FSharp.Control.AsyncPrimitives.Start@1174-1.Invoke(ExceptionDispatchInfo edi) in D:\a\_work\1\s\src\FSharp.Core\async.fs:line 1174
   at Microsoft.FSharp.Control.Trampoline.Execute(FSharpFunc`2 firstAction) in D:\a\_work\1\s\src\FSharp.Core\async.fs:line 112
   at <StartupCode$FSharp-Core>.$Async.clo@193-15.Invoke(Object o) in D:\a\_work\1\s\src\FSharp.Core\async.fs:line 195
   at System.Threading.ThreadPoolWorkQueue.Dispatch()
   at System.Threading.PortableThreadPool.WorkerThread.WorkerThreadStart()

**To Reproduce**

Steps to reproduce the behavior:

  1. Start from the final version of the example code in the "Basic examples" section.

  2. Define a discriminated union (DU) type:

    /// Three payload-free cases; the kernel branches on whichever one it receives.
    type MyDU = Case1 | Case2 | Case3
  3. Add a DU parameter to the kernel and pattern-match on it:

    // The kernel must be a quotation: `%operation` splices the caller's quoted
    // operation into it, and `clContext.Compile` translates the quotation.
    let kernel =
        <@
            fun (range: Range1D) arrLength (array1: ClArray<_>) (array2: ClArray<_>) (result: ClArray<_>) (myDU: MyDU) ->
                let i = range.GlobalID0
                if i < arrLength then
                    // All three branches are identical; the match only exercises a DU-typed kernel parameter.
                    match myDU with
                    | Case1 -> result[i] <- (%operation) array1[i] array2[i]
                    | Case2 -> result[i] <- (%operation) array1[i] array2[i]
                    | Case3 -> result[i] <- (%operation) array1[i] array2[i]
        @>
  4. Add a matching parameter to the lambda returned by `arrayMap2` and pass it on to the kernel:

    fun (myDU: MyDU) (commandQueue: MailboxProcessor<_>) (inputArray1: ClArray<_>) (inputArray2: ClArray<_>) ->
        let ndRange = Range1D.CreateValid(inputArray1.Length, workGroupSize)

        let outputArray =
            clContext.CreateClArray(inputArray1.Length, allocationMode = AllocationMode.Default)

        let kernel = kernel.GetKernel()

        // The argument-setup message must be posted to the queue; `myDU` is
        // forwarded to the kernel as its last argument.
        commandQueue.Post(
            Msg.MsgSetArguments(fun () ->
                kernel.KernelFunc ndRange inputArray1.Length inputArray1 inputArray2 outputArray myDU)
        )

        commandQueue.Post(Msg.CreateRunMsg<_, _> kernel)

        outputArray
  5. Pass a DU case to each map function:

    // Each map function gets its DU case up front; that case is later passed to the kernel on every call.
    let intArraySum = arrayMap2 <@ (+) @> context 64 Case1
    let boolArraySum = arrayMap2 <@ (&&) @> context 64 Case2
    let arrayMask = arrayMap2 <@ fun x y -> if y then x else 0 @> context 64 Case3
  6. Run the program and observe the error. The final version of the code looks like this:

    open FSharp.Core
    open Brahma.FSharp

module Main =

    /// DU passed to the kernel as its last argument.
    type MyDU =
        | Case1
        | Case2
        | Case3

    /// Builds an element-wise map over two ClArrays from the quoted binary `operation`.
    /// The returned function takes a MyDU case, a command queue and two input arrays.
    /// It posts the argument-setup and run messages for the compiled kernel and
    /// returns the newly allocated output array.
    let arrayMap2 operation (clContext: ClContext) workGroupSize =

        // The kernel must be a quotation: `%operation` splices the caller's quoted
        // operation into it, and `clContext.Compile` translates the quotation.
        let kernel =
            <@
                fun (range: Range1D) arrLength (array1: ClArray<_>) (array2: ClArray<_>) (result: ClArray<_>) (myDU: MyDU) ->
                    let i = range.GlobalID0

                    if i < arrLength then
                        // All three branches are identical; the match only exercises a DU-typed kernel parameter.
                        match myDU with
                        | Case1 -> result[i] <- (%operation) array1[i] array2[i]
                        | Case2 -> result[i] <- (%operation) array1[i] array2[i]
                        | Case3 -> result[i] <- (%operation) array1[i] array2[i]
            @>

        let kernel = clContext.Compile kernel

        fun (myDU: MyDU) (commandQueue: MailboxProcessor<_>) (inputArray1: ClArray<_>) (inputArray2: ClArray<_>) ->
            let ndRange = Range1D.CreateValid(inputArray1.Length, workGroupSize)

            let outputArray =
                clContext.CreateClArray(inputArray1.Length, allocationMode = AllocationMode.Default)

            let kernel = kernel.GetKernel()

            // The DU value `myDU` is handed to the kernel here. Per the stack trace, the
            // "Unexpected argument: Case1" failure surfaces when the queue runs this closure.
            commandQueue.Post(
                Msg.MsgSetArguments(fun () ->
                    kernel.KernelFunc ndRange inputArray1.Length inputArray1 inputArray2 outputArray myDU)
            )

            commandQueue.Post(Msg.CreateRunMsg<_, _> kernel)

            outputArray

    /// Entry point: runs the three map kernels on random inputs and prints inputs and result.
    /// The optional first command-line argument sets the array length (default 10).
    [<EntryPoint>]
    let main (argv: string[]) =

        let n = if argv.Length > 0 then int argv[0] else 10

        let device = ClDevice.GetFirstAppropriateDevice()

        let context = ClContext(device)
        let mainQueue = context.QueueProvider.CreateQueue()

        // Each map function gets its DU case up front; it is forwarded to the kernel on every call.
        let intArraySum = arrayMap2 <@ (+) @> context 64 Case1
        let boolArraySum = arrayMap2 <@ (&&) @> context 64 Case2
        let arrayMask = arrayMap2 <@ fun x y -> if y then x else 0 @> context 64 Case3

        let rnd = System.Random()

        let randomIntArray () =
            Array.init n (fun _ -> rnd.Next() / 10000)

        let randomBoolArray () =
            Array.init n (fun _ -> rnd.Next() % 2 = 1)

        let intA1 = randomIntArray ()
        let intA2 = randomIntArray ()

        let boolA1 = randomBoolArray ()
        let boolA2 = randomBoolArray ()
        let boolA3 = randomBoolArray ()

        let clIntA1 = context.CreateClArray<_>(intA1)
        let clIntA2 = context.CreateClArray<_>(intA2)
        let clBoolA1 = context.CreateClArray<_>(boolA1)
        let clBoolA2 = context.CreateClArray<_>(boolA2)
        let clBoolA3 = context.CreateClArray<_>(boolA3)

        let intRes = intArraySum mainQueue clIntA1 clIntA2

        let boolRes =
            boolArraySum mainQueue clBoolA1 clBoolA2 |> boolArraySum mainQueue clBoolA3

        let res = arrayMask mainQueue intRes boolRes

        let resOnHost = Array.zeroCreate n
        let res = mainQueue.PostAndReply(fun ch -> Msg.CreateToHostMsg(res, resOnHost, ch))

        printfn "First int array:  %A" intA1
        printfn "Second int array: %A" intA2

        printfn "First bool array:  %A" boolA1
        printfn "Second bool array: %A" boolA2
        printfn "Third bool array:  %A" boolA3

        printfn "Result: %A" res

        0


**Expected behavior**
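The kernels should accept the DU value as an argument. The program should then run to completion and print the input arrays and the result.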

**Actual behavior**
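The program crashes with `System.Exception: Unexpected argument: Case1`. The exception is thrown by `Brahma.FSharp.ClProgram.setupArgument` while the kernel arguments are being set (full stack trace above).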

**Environment:**
 - OS: Ubuntu 22.04

**Additional context**
 - Brahma.FSharp 2.0.1
 - OpenCL 2.0