omlins / ParallelStencil.jl

Package for writing high-level code for parallel high-performance stencil computations that can be deployed on both GPUs and CPUs
BSD 3-Clause "New" or "Revised" License
301 stars 31 forks source link

Allow for return statements in nested functions within parallel kernels #117

Closed omlins closed 10 months ago

omlins commented 10 months ago

This enables, for example, the following:

                # Nested function in long `function ... end` form that ends with a bare
                # `return` and modifies the array passed to it from inside the kernel.
                @testset "nested function (long definition, array modification)" begin
                    A  = @zeros(4, 5, 6)
                    @parallel_indices (ix,iy,iz) function write_indices!(A)
                        # Nested helper: uses the kernel indices ix, iy, iz and writes into A.
                        function compute_indices!(A)
                            # Stored value is the 1-based column-major linear index of (ix,iy,iz).
                            A[ix,iy,iz] = ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2);
                            return
                        end
                        compute_indices!(A)
                        return
                    end
                    @parallel write_indices!(A);
                    # Reference: the same formula evaluated over every index of A.
                    @test all(Array(A) .== [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)])
                end;
                # Nested function in short (assignment) form whose body is a parenthesized
                # `(assignment; return)` sequence that modifies the array passed to it.
                @testset "nested function (short definition, array modification)" begin
                    A  = @zeros(4, 5, 6)
                    @parallel_indices (ix,iy,iz) function write_indices!(A)
                        # Nested helper: writes the 1-based column-major linear index of
                        # (ix,iy,iz) into A, then executes a bare `return`.
                        compute_indices!(A) = (A[ix,iy,iz] = ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2); return)
                        compute_indices!(A)
                        return
                    end
                    @parallel write_indices!(A);
                    # Reference: the same formula evaluated over every index of A.
                    @test all(Array(A) .== [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)])
                end;
                # Nested function in long `function ... end` form that returns a value via
                # an explicit `return <expr>`; the kernel itself performs the array write.
                @testset "nested function (long definition, return value)" begin
                    A  = @zeros(4, 5, 6)
                    @parallel_indices (ix,iy,iz) function write_indices!(A)
                        # Nested helper: returns the 1-based column-major linear index of
                        # (ix,iy,iz); it reads A only through size(A, d).
                        function compute_indices(A)
                            return ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2)
                        end
                        A[ix,iy,iz] = compute_indices(A)
                        return
                    end
                    @parallel write_indices!(A);
                    # Reference: the same formula evaluated over every index of A.
                    @test all(Array(A) .== [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)])
                end;
                # Nested function in short (assignment) form whose right-hand side is an
                # explicit `return <expr>`; the kernel itself performs the array write.
                @testset "nested function (short definition, return value)" begin
                    A  = @zeros(4, 5, 6)
                    @parallel_indices (ix,iy,iz) function write_indices!(A)
                        # Nested helper: returns the 1-based column-major linear index of (ix,iy,iz).
                        compute_indices(A) = return ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2)
                        A[ix,iy,iz] = compute_indices(A)
                        return
                    end
                    @parallel write_indices!(A);
                    # Reference: the same formula evaluated over every index of A.
                    @test all(Array(A) .== [ix + (iy-1)*size(A,1) + (iz-1)*size(A,1)*size(A,2) for ix=1:size(A,1), iy=1:size(A,2), iz=1:size(A,3)])
                end;
omlins commented 10 months ago

CC @albert-de-montserrat