devitocodes / devito

DSL and compiler framework for automated finite-differences and stencil computation
http://www.devitoproject.org
MIT License

compiler: Redundant haloupdate #2448

Open georgebisbas opened 2 months ago

georgebisbas commented 2 months ago

Elastic-like cross-loop dependencies generate a redundant haloupdate.

Problem in Pseudocode:

for time
  haloupd vx[t0]
  write to vx[t1] - read from vx[t0]
  haloupd vx[t1]
  read from vx[t1]
  read from vx[t0] 

It could be:

haloupd vx[t0]   <-- hoisted out of the time loop, performed once
for time
  ---DROP haloupd vx[t0]--- the buffer written as t1 in the previous iteration is now t0, and its halo is already up to date
  write to vx[t1] - read from vx[t0]
  haloupd vx[t1]
  read from vx[t1]
  read from vx[t0] 
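
The reasoning above can be checked with a small model of the buffer rotation. This is a minimal sketch, not Devito code: `simulate`, `haloupd`, and the `fresh` set are hypothetical names used only for illustration. It assumes that a write to a buffer invalidates its halo and that a halo update makes it valid again.

```python
def simulate(nsteps, hoist_first_update):
    """Model the halo state of the two time buffers of vx.

    Returns the number of halo exchanges performed on vx.
    """
    fresh = set()   # indices of buffers whose halo is currently valid
    exchanges = 0

    def haloupd(buf):
        nonlocal exchanges
        fresh.add(buf)
        exchanges += 1

    if hoist_first_update:
        haloupd(0)  # t0 at time=0, done once before the loop

    for time in range(nsteps):
        t0, t1 = time % 2, (time + 1) % 2
        if not hoist_first_update:
            haloupd(t0)        # the update reported as redundant
        assert t0 in fresh     # stencil read from vx[t0] needs a valid halo
        fresh.discard(t1)      # write to vx[t1] invalidates its halo
        haloupd(t1)
        assert t1 in fresh     # stencil read from vx[t1] needs a valid halo
    return exchanges


current = simulate(10, hoist_first_update=False)
proposed = simulate(10, hoist_first_update=True)
print(current, proposed)
```

In this model both variants satisfy every read, but the hoisted variant performs one exchange per time step plus one up front, instead of two per time step.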

Python script to reproduce (reduced from https://github.com/devitocodes/devito/blob/master/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb):

DEVITO_LOGGING=DEBUG DEVITO_MPI=1 mpirun -n 1 python3 tests/elastic_mfe_1d.py 
DEVITO_LOGGING=DEBUG DEVITO_MPI=1 ../../tmpi/tmpi 1 python3 tests/elastic_mfe_1d.py 

elastic_mfe_1d.py:

import numpy as np

from devito import (SpaceDimension, Grid, TimeFunction, Eq, Operator,
                    solve, Constant)
from examples.seismic.source import TimeAxis, Receiver

# Space related
extent = (1500., )
shape = (201, )
x = SpaceDimension(name='x', spacing=Constant(name='h_x', value=extent[0]/(shape[0]-1)))
grid = Grid(extent=extent, shape=shape, dimensions=(x, ))

# Time related
t0, tn = 0., 30.
dt = (10. / np.sqrt(2.)) / 6.
time_range = TimeAxis(start=t0, stop=tn, step=dt)

# Velocity and stress fields
so = 2
to = 1
v = TimeFunction(name='v', grid=grid, space_order=so, time_order=to)
tau = TimeFunction(name='tau', grid=grid, space_order=so, time_order=to)

# The receiver
nrec = 1
rec = Receiver(name="rec", grid=grid, npoint=nrec, time_range=time_range)
rec.coordinates.data[:, 0] = np.linspace(0., extent[0], num=nrec)
rec_term = rec.interpolate(expr=v)

# First-order equations with elastic-like dependencies
pde_v = v.dt - (tau.dx)
pde_tau = (tau.dt - ((v.forward).dx))

u_v = Eq(v.forward, solve(pde_v, v.forward))
u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))

op = Operator([u_v] + [u_tau] + rec_term)
op.apply(dt=dt)

# print(op.ccode)

The generated code includes the following, where haloupdate1(v_vec,comm,nb,t0); is redundant:

for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))
  {
    START(section0)
    haloupdate0(tau_vec,comm,nb,t0);
    haloupdate1(v_vec,comm,nb,t0);  /* <-- redundant */
    for (int x = x_m; x <= x_M; x += 1)
    {
      v[t1][x + 2] = dt*(r0*v[t0][x + 2] - r1*tau[t0][x + 2] + r1*tau[t0][x + 3]);
    }
    haloupdate0(v_vec,comm,nb,t1);
    for (int x = x_m; x <= x_M; x += 1)
    {
      tau[t1][x + 2] = dt*(r0*tau[t0][x + 2] - r1*v[t1][x + 2] + r1*v[t1][x + 3]);
    }
    STOP(section0,timers)

    START(section1)
    for (int p_rec = p_rec_m; p_rec <= p_rec_M; p_rec += 1)
    {
      float r5 = r2*(-o_x + rec_coords[p_rec][0]);
      float r4 = floorf(r5);
      int posx = (int)r4;
      float px = -r4 + r5;
      float sum = 0.0F;

      for (int rrecx = 0; rrecx <= 1; rrecx += 1)
      {
        if (rrecx + posx >= x_m - 1 && rrecx + posx <= x_M + 1)
        {
          sum += (rrecx*px + (1 - rrecx)*(1 - px))*v[t0][rrecx + posx + 2];
        }
      }

      rec[time][p_rec] = sum;
    }
    STOP(section1,timers)
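
One way to check for the redundancy in a test is to count halo update calls per field and time buffer in the generated code. This is a rough sketch under assumptions: `count_haloupdates` is a hypothetical helper, not part of Devito, and the regular expression assumes the call shape seen in the excerpt above, `haloupdateN(<field>_vec,comm,nb,<buffer>)`.

```python
import re
from collections import Counter


def count_haloupdates(ccode):
    """Count halo update calls per (field, time buffer) in a C code string."""
    # Assumed call shape: haloupdate<N>(<field>_vec,comm,nb,<buffer>)
    pattern = re.compile(r"haloupdate\d+\((\w+)_vec,\s*comm,\s*nb,\s*(\w+)\)")
    return Counter(pattern.findall(ccode))


# Calls copied from the generated code in this issue
snippet = """
    haloupdate0(tau_vec,comm,nb,t0);
    haloupdate1(v_vec,comm,nb,t0);
    haloupdate0(v_vec,comm,nb,t1);
"""

counts = count_haloupdates(snippet)
print(counts[("v", "t0")], counts[("v", "t1")], counts[("tau", "t0")])
```

In a real test the string would come from the operator's generated code (the script above prints it via `op.ccode`). A count over the whole string cannot tell a hoisted call from one inside the time loop, so such a test would need to inspect the loop body only.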

georgebisbas commented 1 month ago

@FabioLuporini here

georgebisbas commented 1 month ago

Reminder for @georgebisbas: open this as a PR with a test.