lfortran / lfortran

Official main repository for LFortran
https://lfortran.org/

Implement any array expression in `where` #4330

Open certik opened 3 months ago

certik commented 3 months ago

Examples:

where (array(1, :) == 0)
  array(1, :) = 2
end where

WHERE( dfmxi <= epsi ) inrdone = .TRUE.

! Arbitrarily complicated expression
where (x**2 + x == y*(x + abs(y + f(x))))
  z = (x+y+abs(y))*x ! the LHS has to be just an array, RHS can be complicated
end where

I think these can all be transformed as follows:

test_condition = (array(1, :) == 0) ! temporary array
LHS => array(1, :) ! pointer
RHS = 2 ! temporary array
do i = a, b
  if (test_condition(i)) then
    LHS(i) = RHS(i)
  end if
end do

test_condition = (dfmxi <= epsi) ! temporary array
LHS => inrdone ! pointer
RHS = .TRUE. ! temporary array
do i = a, b
  if (test_condition(i)) then
    LHS(i) = RHS(i)
  end if
end do

test_condition = (x**2 + x == y*(x + abs(y + f(x))))  ! temporary array
LHS => z ! pointer
RHS = (x+y+abs(y))*x ! temporary array
do i = a, b
  if (test_condition(i)) then
    LHS(i) = RHS(i)
  end if
end do
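
To make the first transformation concrete, here is a minimal, self-contained sketch that runs the `where` construct and a hand-written version of the lowering side by side and checks that they agree. The names test_condition, lhs and rhs follow the pseudocode above; the program name, the array shape and the sample data are assumptions made up for illustration, and this is not what the pass actually emits.

```fortran
program where_lowering_sketch
    implicit none
    integer, target :: array(2, 5)              ! hypothetical sample data
    integer :: expected(2, 5)
    logical, allocatable :: test_condition(:)   ! temporary mask array
    integer, allocatable :: rhs(:)              ! temporary RHS array
    integer, pointer :: lhs(:)                  ! pointer to the assigned section
    integer :: i

    ! Column-major fill: row 1 is 0 3 0 5 0, row 2 is all 7
    array = reshape([0, 7, 3, 7, 0, 7, 5, 7, 0, 7], [2, 5])
    expected = array

    ! Reference result, using the where construct directly
    where (expected(1, :) == 0)
        expected(1, :) = 2
    end where

    ! Hand-written lowering: mask temporary, LHS pointer, RHS temporary, loop
    lhs => array(1, :)
    allocate(test_condition(size(lhs)), rhs(size(lhs)))
    test_condition = (array(1, :) == 0)
    rhs = 2
    do i = 1, size(lhs)
        if (test_condition(i)) then
            lhs(i) = rhs(i)
        end if
    end do

    if (any(array /= expected)) error stop "lowering disagrees with where"
    print *, array(1, :)
end program where_lowering_sketch
```

With this data, row 1 of the array ends up as 2 3 2 5 2 in both versions, and row 2 is untouched.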

certik commented 3 months ago

Note: the reworked array_op pass by @czgdp1807 and @gxyd should be smart enough to rewrite:

test_condition = (x**2 + x == y*(x + abs(y + f(x))))  ! temporary array
LHS => z ! pointer
RHS = (x+y+abs(y))*x ! temporary array
do i = a, b
  if (test_condition(i)) then
    LHS(i) = RHS(i)
  end if
end do

into:

do i = a, b
  test_condition(i) = (x(i)**2 + x(i) == y(i)*(x(i) + abs(y(i) + f(x(i)))))
end do
LHS => z ! pointer
do i = a, b
  RHS(i) = (x(i)+y(i)+abs(y(i)))*x(i)
end do
do i = a, b
  if (test_condition(i)) then
    LHS(i) = RHS(i)
  end if
end do

And then, to speed this up in Release mode, all we need to do is implement an optimization pass that recognizes that all these loops and the pointer association cover the same range and can be fused:

do i = a, b
  test_condition(i) = (x(i)**2 + x(i) == y(i)*(x(i) + abs(y(i) + f(x(i)))))
  LHS => z(i)
  RHS(i) = (x(i)+y(i)+abs(y(i)))*x(i)
  if (test_condition(i)) then
    LHS = RHS(i)
  end if
end do

and then the temporaries can be removed with another pass:

do i = a, b
  if(x(i)**2 + x(i) == y(i)*(x(i) + abs(y(i) + f(x(i))))) then
    z(i) = (x(i)+y(i)+abs(y(i)))*x(i)
  end if
end do

This is the optimal code for this case. Subsequent passes can then vectorize the loop, and so on.
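
As a sanity check on the final form, the sketch below compares the original `where` construct with the fully fused loop on a small data set. The elemental function f, the module and program names, the size n and the sample values are all hypothetical stand-ins, since the issue does not define them. The values are small integers stored as reals so that the `==` comparison is exact.

```fortran
module fused_where_sketch_mod
    implicit none
contains
    ! Hypothetical stand-in for f(x); any elemental function would do
    elemental real function f(x)
        real, intent(in) :: x
        f = 2.0 * x
    end function f
end module fused_where_sketch_mod

program fused_where_sketch
    use fused_where_sketch_mod, only: f
    implicit none
    integer, parameter :: n = 6
    real :: x(n), y(n), z(n), z_ref(n)
    integer :: i

    x = [0.0, 1.0, 2.0, -1.0, 3.0, 0.0]
    y = [0.0, 2.0, 1.0, 0.0, -3.0, 5.0]
    z = -1.0
    z_ref = -1.0

    ! Reference result, using the where construct directly
    where (x**2 + x == y*(x + abs(y + f(x))))
        z_ref = (x + y + abs(y))*x
    end where

    ! Fused loop with the temporaries removed
    do i = 1, n
        if (x(i)**2 + x(i) == y(i)*(x(i) + abs(y(i) + f(x(i))))) then
            z(i) = (x(i) + y(i) + abs(y(i)))*x(i)
        end if
    end do

    if (any(z /= z_ref)) error stop "fused loop disagrees with where"
    print *, z
end program fused_where_sketch
```

With this data the mask is true only at elements 1 and 4, so both versions leave the other elements at -1.0. Fusing is safe in this sketch because z is not read by the mask or the right-hand side.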

I think this design keeps the door open to maximum performance. Alternatively, if the new array_op pass is implemented in a modular manner, we can use the "array expression indexer" directly in the where pass and skip these rewrites. The code might not be much more complex, and it should compile faster. Either way should work.

assem2002 commented 2 months ago

We need to address the many errors that can occur in `where` statements. I have a bunch of them listed; should I mention them here, create one issue for all `where` errors, or create separate ones?

gxyd commented 1 week ago

A much simpler example with `where` that currently doesn't work (on neither main nor the simplifier_pass branch):

program main
    implicit none
    real a(4)
    a = [1.0, 2.0, 1.0, 4.0]
    where(a == 1.0) a = a * [100.0, 2.0, 3.0, 4.0]

    print *, a
end program main

The error I get on main is:

ASR verify pass error: The variable in ArrayItem must be an array, not a scalar
Internal Compiler Error: Unhandled exception
Traceback (most recent call last):
LCompilersException: Verify failed in the pass: where

while the error on the simplifier_pass branch currently is:

declare void @_lpython_call_initial_functions(i32, i8**)

declare i8* @_lcompilers_string_format_fortran(i32, i8*, ...)

declare void @_lfortran_printf(i8*, ...)
asr_to_llvm: module failed verification. Error:
Stored value type does not match pointer operand type!
  store float* %31, float* %25, align 8
 float
code generation error: asr_to_llvm: module failed verification. Error:
Stored value type does not match pointer operand type!
  store float* %31, float* %25, align 8
 float

Note: Please report unclear, confusing or incorrect messages as bugs at
https://github.com/lfortran/lfortran/issues
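
For reference, here is a hand-written loop that computes what the `where` statement in the small example is expected to produce, assuming the usual masked-assignment semantics where the mask is evaluated once before any element is assigned. The names mask and factors and the program name are made up for illustration.

```fortran
program where_expected_sketch
    implicit none
    real :: a(4), factors(4)
    logical :: mask(4)
    integer :: i

    a = [1.0, 2.0, 1.0, 4.0]
    factors = [100.0, 2.0, 3.0, 4.0]   ! the array constructor from the RHS

    ! Mask is computed up front, so later writes to a cannot change it
    mask = (a == 1.0)
    do i = 1, size(a)
        if (mask(i)) a(i) = a(i) * factors(i)
    end do

    ! Elements 1 and 3 match the mask: 1.0*100.0 and 1.0*3.0
    if (any(a /= [100.0, 2.0, 3.0, 4.0])) error stop "unexpected result"
    print *, a
end program where_expected_sketch
```

So once this works, the program should end with a equal to [100.0, 2.0, 3.0, 4.0].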