B612-Asteroid-Institute / precovery

Fast precovery of small body observations at scale
BSD 3-Clause "New" or "Revised" License

Factor of 2x speed up increase #18

Closed: moeyensj closed this 2 years ago

moeyensj commented 2 years ago

Ongoing work to address #17.

stevenstetzler commented 2 years ago

I've added several new functions that implement a "vectorized" optimization, i.e., functions are evaluated on arrays of ra/dec/mjd instead of on scalars. These new functions are:

I've added an --opt argument to scripts/precovery-test.py that uses precover_opt instead of precover, to make before/after comparisons easier. Timing results can be compared with:

$ source scripts/profile.sh # before opt
$ source scripts/profile.sh --opt # after opt
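A minimal sketch of how such a flag can be wired up with argparse. The orbit_id positional argument and the stub precover/precover_opt functions are hypothetical stand-ins; the real script's arguments and call signatures are not shown in this thread.

```python
import argparse


def precover(orbit_id):
    # hypothetical stub for the original (scalar) search
    return f"precover({orbit_id})"


def precover_opt(orbit_id):
    # hypothetical stub for the vectorized search
    return f"precover_opt({orbit_id})"


def main(argv=None):
    parser = argparse.ArgumentParser(description="Run a precovery search.")
    parser.add_argument("orbit_id", help="orbit to search for (hypothetical argument)")
    # store_true: args.opt is False unless --opt is passed on the command line
    parser.add_argument(
        "--opt", action="store_true", help="use precover_opt instead of precover"
    )
    args = parser.parse_args(argv)
    search = precover_opt if args.opt else precover
    return search(args.orbit_id)
```

For example, main(["330784", "--opt"]) dispatches to the optimized stub and returns "precover_opt(330784)", while main(["330784"]) falls back to the original one.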

There were a few levels of optimization, so I'll describe them in the order I implemented them. The original code does roughly the following:

precover
- calls _check_window(mjd, obscode, orbit, tolerance)
  - calls compute_ephemeris(obscode, mjd)
  - calls radec_to_healpixel(ra, dec, nside)
  - calls approximately_propagate(obscode, orbit, timedelta)
  - calls radec_to_healpixel(approx_ra, approx_dec, nside)
  - calls _check_frames(orbit, healpix, obscode, mjd, tolerance)
    - calls compute_ephemeris(obscode, mjd)
    - calls obs.distance(ephem)

This results in ~100k calls to compute_ephemeris/radec_to_healpixel/distance.
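To illustrate what "vectorized" means here, the sketch below evaluates a hypothetical toy_ephemeris (simple linear sky motion, not a real orbit propagation) once per epoch in a Python loop, and then once on the whole array of epochs. Both give the same positions. Assuming much of the per-call cost is fixed overhead rather than the arithmetic itself, the array form pays that overhead once instead of once per epoch.

```python
import numpy as np


def toy_ephemeris(mjd):
    """Hypothetical stand-in for an ephemeris: RA/Dec (deg) moving linearly in time.

    Accepts a scalar MJD or an array of MJDs, because every operation below
    is an elementwise NumPy ufunc.
    """
    ra = np.mod(10.0 + 0.25 * (mjd - 59000.0), 360.0)
    dec = -5.0 + 0.01 * (mjd - 59000.0)
    return ra, dec


mjds = np.linspace(59000.0, 59100.0, 10_000)

# before: one Python-level call per epoch
scalar = np.array([toy_ephemeris(m) for m in mjds])  # shape (N, 2)

# after: a single call on the whole array of epochs
ras, decs = toy_ephemeris(mjds)
vectorized = np.column_stack([ras, decs])  # shape (N, 2)

assert np.allclose(scalar, vectorized)
```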

First, optimizing ephemeris computation:

Most of the ephemeris computations were happening in _check_frames, so I started there and made the following changes:

In practice, the code uses _check_frames_opt_2, which also vectorizes the distance computation: it collects all observations in each frame and calls haversine_distance_deg once, instead of obs.distance per observation, to compute the distances between every observation and the orbit ephemeris:

# _check_frames_opt_2(orbit, healpixel, obscode, mjd, tolerance)
for f in frames:
    logger.info("checking frame: %s", f)
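    # materialize the frame's observations so they can be filtered with a boolean mask
    obs = np.array(list(self.frames.iterate_observations(f)))
    obs_ras = np.array([o.ra for o in obs])
    obs_decs = np.array([o.dec for o in obs])
    # one vectorized call computes the distance from the ephemeris to every observation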
    distances = haversine_distance_deg(
        exact_ephem.ra,
        obs_ras,
        exact_ephem.dec,
        obs_decs,
    )
    dras = exact_ephem.ra - obs_ras
    ddecs = exact_ephem.dec - obs_decs
    # filter to observations with distance below tolerance
    idx = distances < tolerance
    distances = distances[idx]
    dras = dras[idx]
    ddecs = ddecs[idx]
    obs = obs[idx]
    for o, distance, dra, ddec in zip(obs, distances, dras, ddecs):
        candidate = PrecoveryCandidate(
            ra=o.ra,
            dec=o.dec,
            ra_sigma=o.ra_sigma,
            dec_sigma=o.dec_sigma,
            mag=o.mag,
            mag_sigma=o.mag_sigma,
            filter=f.filter,
            obscode=f.obscode,
            mjd=f.mjd,
            catalog_id=f.catalog_id,
            id=o.id.decode(),
            dra=dra,
            ddec=ddec,
            distance=distance,
        )
        yield candidate

It then uses a NumPy boolean mask to keep only the observations with distance < tolerance.
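The same distance-plus-mask pattern in isolation, as a self-contained sketch. The haversine_distance_deg below is written for this example; its argument order (ra1, ra2, dec1, dec2) is inferred from the call in _check_frames_opt_2 and may not match the project's function. The ephemeris position, observations, and tolerance (assumed to be in degrees) are made up.

```python
import numpy as np


def haversine_distance_deg(ra1, ra2, dec1, dec2):
    """Great-circle distance in degrees between (ra1, dec1) and (ra2, dec2).

    Inputs are in degrees and broadcast, so a scalar ephemeris position can be
    compared against arrays of observation positions in one call.
    """
    ra1, ra2, dec1, dec2 = map(np.radians, (ra1, ra2, dec1, dec2))
    a = (
        np.sin((dec2 - dec1) / 2) ** 2
        + np.cos(dec1) * np.cos(dec2) * np.sin((ra2 - ra1) / 2) ** 2
    )
    # clip guards against values slightly above 1 from rounding
    return np.degrees(2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0))))


# hypothetical ephemeris position and observations in one frame (degrees)
ephem_ra, ephem_dec = 150.0, 2.0
obs_ids = np.array(["a", "b", "c", "d"])
obs_ras = np.array([150.00005, 150.2, 149.99990, 10.0])
obs_decs = np.array([2.00003, 2.0, 2.00002, -30.0])
tolerance = 1e-3  # degrees (assumed units)

distances = haversine_distance_deg(ephem_ra, obs_ras, ephem_dec, obs_decs)
mask = distances < tolerance  # one boolean per observation
print(obs_ids[mask])  # → ['a' 'c']
```

The mask can index any array aligned with the observations (IDs, distances, RA/Dec offsets), which is how the snippet above keeps obs, distances, dras, and ddecs in sync.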

Finally, _check_window_opt is replaced by _check_windows_opt, which accepts an array of window midpoints and uses compute_ephemeris_opt instead of compute_ephemeris to propagate the orbit to all of the window midpoints at once:

# _check_windows_opt(window_midpoints, obscode, orbit, tolerance)
window_ephems = orbit.compute_ephemeris_opt(obscode, window_midpoints)
window_healpixels = radec_to_healpixel(
    np.array([w.ra for w in window_ephems]),
    np.array([w.dec for w in window_ephems]),
    self.frames.healpix_nside,
).astype(int)

for window_midpoint, window_ephem, window_healpixel in zip(
    window_midpoints, window_ephems, window_healpixels
):
    ...

This also requires a change to precover_opt, because compute_ephemeris_opt can only propagate to a batch of times that share the same obscode. To handle this, I group the windows by obscode:

# precover_opt(orbit, tolerance, max_matches, start_mjd, end_mjd)
windows = self.frames.idx.window_centers(start_mjd, end_mjd, self.window_size)
# each window is an (mjd, obscode) pair; groupby batches consecutive
# windows that share an obscode
for obscode, obs_windows in itertools.groupby(
    windows, key=lambda pair: pair[1]
):
    mjds = [window[0] for window in obs_windows]
    matches = self._check_windows_opt(mjds, obscode, orbit, tolerance)

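One behavior worth keeping in mind: itertools.groupby only merges consecutive items with the same key. If window_centers returns windows in time order with observatories interleaved (this thread does not say), the loop above calls _check_windows_opt several times per obscode. That is still correct, since each batch has a single obscode, but it means more compute_ephemeris_opt calls than strictly necessary. The sketch below uses made-up windows to show the difference sorting makes:

```python
import itertools

# hypothetical (mjd, obscode) window centers in time order, observatories interleaved
windows = [
    (59000.5, "W84"),
    (59001.5, "W84"),
    (59002.5, "I41"),
    (59003.5, "W84"),
]

# groupby starts a new group whenever the key changes, so W84 shows up twice
by_run = [
    (obscode, [mjd for mjd, _ in group])
    for obscode, group in itertools.groupby(windows, key=lambda pair: pair[1])
]
print(by_run)
# → [('W84', [59000.5, 59001.5]), ('I41', [59002.5]), ('W84', [59003.5])]

# sorting by obscode first (sorted is stable, so time order is kept within
# each obscode) yields exactly one batch per observatory
by_code = [
    (obscode, [mjd for mjd, _ in group])
    for obscode, group in itertools.groupby(
        sorted(windows, key=lambda pair: pair[1]), key=lambda pair: pair[1]
    )
]
print(by_code)
# → [('I41', [59002.5]), ('W84', [59000.5, 59001.5, 59003.5])]
```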
As written, the code structure is now:

precover_opt
- calls _check_windows_opt(mjds, obscode, orbit, tolerance)
  - calls compute_ephemeris_opt(obscode, window_midpoints)
  - calls radec_to_healpixel(ras, decs, nside)
  - calls approximately_propagate_opt(timedeltas)
  - calls radec_to_healpixel(approx_ras, approx_decs, nside)
  - calls _check_frames_opt_2(orbit, keep_approx_healpixels, obscode, keep_mjds, tolerance)
    - calls compute_ephemeris_opt(obscode, mjds)
    - calls haversine_distance_deg(exact_ephem.ra, obs_ras, exact_ephem.dec, obs_decs)

This version calls compute_ephemeris_opt only 683 times, down from ~100k calls to compute_ephemeris.

Profiling results with the current code are:

Found 27 potential matches for orbit ID: 330784
    orbit_id                   id  ...          ddec  distance
0     330784     c4d.230114.7.517  ...  5.349504e-05  0.000084
1     330784    c4d.230826.13.872  ...  6.738136e-05  0.000091
2     330784   c4d.231472.47.1005  ...  6.280554e-05  0.000089
3     330784    c4d.232238.26.124  ...  2.968479e-05  0.000030
4     330784    c4d.232344.26.468  ...  5.881602e-05  0.000076
5     330784    c4d.232345.26.457  ...  6.842665e-05  0.000087
6     330784   c4d.232751.48.1401  ...  6.014635e-05  0.000079
7     330784   c4d.232752.48.1507  ...  6.366128e-05  0.000080
8     330784    c4d.232951.46.266  ... -4.628035e-05  0.000083
9     330784     c4d.233015.30.38  ...  3.616820e-05  0.000042
10    330784     c4d.233134.30.94  ...  5.665683e-05  0.000065
11    330784    c4d.233364.51.272  ...  3.803622e-05  0.000043
12    330784    c4d.233365.51.226  ...  4.792459e-05  0.000048
13    330784    c4d.233473.9.1618  ...  6.756319e-05  0.000080
14    330784    c4d.233475.9.1517  ...  6.549203e-05  0.000080
15    330784    c4d.237663.60.588  ...  5.767154e-05  0.000060
16    330784    c4d.237968.59.683  ...  4.574217e-05  0.000062
17    330784    c4d.238654.38.733  ...  4.623282e-05  0.000046
18    330784    c4d.238655.44.500  ...  1.607676e-04  0.000172
19    330784    c4d.238656.44.583  ...  1.769334e-04  0.000189
20    330784    c4d.238657.44.699  ...  1.943913e-04  0.000210
21    330784     c4d.420636.47.31  ... -3.123492e-05  0.000154
22    330784    c4d.420637.47.114  ... -4.783249e-05  0.000152
23    330784    c4d.500315.25.291  ...  2.380609e-05  0.000025
24    330784    c4d.500316.25.296  ... -6.973607e-07  0.000027
25    330784  c4d.777125.15.13724  ...  2.949515e-06  0.000127
26    330784   c4d.777435.15.9269  ...  4.504595e-05  0.000121

[27 rows x 15 columns]
         23941170 function calls (23777916 primitive calls) in 51.888 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      683   23.616    0.035   26.131    0.038 orbit.py:225(compute_ephemeris_opt)
  1543349    3.140    0.000    3.140    0.000 {method 'fetchone' of 'sqlite3.Cursor' objects}
   524453    2.082    0.000    2.885    0.000 frame_db.py:529(iterate_observations)
  1954135    1.496    0.000    7.802    0.000 frame_db.py:220(<genexpr>)
      219    1.371    0.006   44.634    0.204 precovery_db.py:208(_check_windows_opt)
     4190    1.324    0.000    1.324    0.000 {built-in method numpy.array}
  1543347    1.216    0.000    5.995    0.000 result.py:381(iterrows)
      667    1.041    0.002    1.041    0.002 {method 'execute' of 'sqlite3.Cursor' objects}
    36932    0.999    0.000    1.546    0.000 inspect.py:744(cleandoc)
   412778    0.874    0.000    0.874    0.000 orbit.py:263(__init__)
  1543347    0.864    0.000    4.779    0.000 cursor.py:1791(_fetchiter_impl)
  1543348    0.769    0.000    3.915    0.000 cursor.py:953(fetchone)
   412591    0.743    0.000    8.962    0.000 frame_db.py:191(propagation_targets)
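For reference, a report in this format ("Ordered by: internal time") can be produced with the standard library's cProfile and pstats by sorting on "tottime". Whether scripts/profile.sh does exactly this is an assumption. The sketch profiles a hypothetical scalar loop against a single NumPy call:

```python
import cProfile
import io
import pstats

import numpy as np


def scalar_loop(values):
    # one Python-level call per element
    return [float(np.sqrt(v)) for v in values]


def vectorized(values):
    # a single NumPy call over the whole array
    return np.sqrt(values)


def profile_report(func, *args, limit=5):
    """Run func under cProfile and return the report ordered by internal time."""
    profiler = cProfile.Profile()
    profiler.enable()
    func(*args)
    profiler.disable()
    stream = io.StringIO()
    # sorting on "tottime" is what prints the "Ordered by: internal time" header
    pstats.Stats(profiler, stream=stream).sort_stats("tottime").print_stats(limit)
    return stream.getvalue()


values = np.arange(10_000, dtype=float)
print(profile_report(scalar_loop, values))
print(profile_report(vectorized, values))
```

The ncalls column is the quickest way to see the effect of vectorizing: the loop version records one np.sqrt-related entry per element, while the array version records a handful of calls in total.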

stevenstetzler commented 2 years ago

To clarify my explanation above: the distance_opt method of Observation is not used in this implementation. Instead, haversine_distance_deg is called with the ras/decs of many observations against a single orbit ephemeris.

stevenstetzler commented 2 years ago

Additionally, I added scripts/sq-execute-test.py and scripts/profile-sql.sh, which test whether many small SQL queries are slower than one large one. They turn out to be roughly the same. The comparison is:

import sqlalchemy as sq

# one query per window: frames with mjds[i] <= mjd <= mjds[i + 1]
for i in range(len(mjds) - 1):
    stmt = sq.select(
        db.frames.idx.frames.c.mjd, db.frames.idx.frames.c.obscode
    ).where(
        db.frames.idx.frames.c.mjd <= mjds[i + 1],
        db.frames.idx.frames.c.mjd >= mjds[i],
    )
    rows = db.frames.idx.dbconn.execute(stmt)
    for row in rows:
        ...

versus constructing one large OR'd WHERE clause:

from sqlalchemy import and_, or_

stmt = sq.select(
    db.frames.idx.frames.c.mjd, db.frames.idx.frames.c.obscode
).where(
    or_(
        *[
            and_(
                db.frames.idx.frames.c.mjd <= mjds[i + 1],
                db.frames.idx.frames.c.mjd >= mjds[i],
            )
            for i in range(len(mjds) - 1)
        ]
    )
)
rows = db.frames.idx.dbconn.execute(stmt).fetchall()

The profiling results:

$ source profile-sql.sh --sequential
99 queries got 86682 rows in 0.6292250156402588s
         5071008 function calls (5011474 primitive calls) in 5.197 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    36932    0.435    0.000    0.671    0.000 inspect.py:744(cleandoc)
        1    0.306    0.306    5.203    5.203 sq-execute-test.py:1(<module>)
  184/181    0.224    0.001    0.229    0.001 {built-in method _imp.create_dynamic}
     1496    0.212    0.000    0.212    0.000 {built-in method marshal.loads}
 4028/498    0.139    0.000    0.381    0.001 sre_parse.py:493(_parse)
    20643    0.129    0.000    0.132    0.000 sre_parse.py:172(append)
    86782    0.112    0.000    0.112    0.000 {method 'fetchone' of 'sqlite3.Cursor' objects}

fetchone is called 86782 times, taking 0.112 s in total. In comparison:

$ source profile-sql.sh
1 queries got 86682 rows in 0.6342263221740723s
         4639521 function calls (4578209 primitive calls) in 6.157 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    36932    0.532    0.000    0.821    0.000 inspect.py:744(cleandoc)
     1496    0.340    0.000    0.340    0.000 {built-in method marshal.loads}
        2    0.274    0.137    0.274    0.137 {method 'fetchall' of 'sqlite3.Cursor' objects}

fetchall is called twice (why not once?), taking 0.274 s in total. End to end, the two approaches take 0.629 s and 0.634 s, respectively. The difference may grow once the number of queries gets much larger, but it does not seem likely to matter for ~100 to 1000 queries, which I think is effectively the maximum "pooling" size we could expect given the window/frame code structure.
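The same comparison can be reproduced without the project's database, using the standard library's sqlite3 module and an in-memory frames table filled with made-up MJDs. Unlike the SQLAlchemy snippets above, this sketch uses half-open windows [lo, hi), so a frame falling exactly on a shared window edge is returned once rather than by two consecutive sequential queries; with the inclusive bounds used above, the two approaches can disagree on such boundary rows. No timings are claimed here; wrap the two calls in timeit to compare them on real data.

```python
import sqlite3

# hypothetical stand-in for the frames index: one row per frame
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE frames (mjd REAL, obscode TEXT)")
conn.executemany(
    "INSERT INTO frames VALUES (?, ?)",
    [(59000.0 + 0.01 * i, "W84") for i in range(1000)],
)
# window edges; windows are half-open [lo, hi) so no row lands in two of them
mjds = [59000.0 + i for i in range(11)]


def many_small_queries(conn, mjds):
    rows = []
    for lo, hi in zip(mjds[:-1], mjds[1:]):
        rows.extend(
            conn.execute(
                "SELECT mjd, obscode FROM frames WHERE mjd >= ? AND mjd < ?",
                (lo, hi),
            )
        )
    return rows


def one_large_query(conn, mjds):
    # the clause text is built from constants only; values are bound as parameters
    clauses = " OR ".join("(mjd >= ? AND mjd < ?)" for _ in mjds[:-1])
    params = [edge for lo, hi in zip(mjds[:-1], mjds[1:]) for edge in (lo, hi)]
    return conn.execute(
        f"SELECT mjd, obscode FROM frames WHERE {clauses}", params
    ).fetchall()


small = many_small_queries(conn, mjds)
large = one_large_query(conn, mjds)
print(len(small), len(large))  # → 1000 1000
```

One practical limit on the single-query approach: SQLite caps the number of bound parameters per statement (SQLITE_MAX_VARIABLE_NUMBER), so a very large OR clause may need to be split into chunks anyway.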

moeyensj commented 2 years ago

> Note that this never occurs, so linear motion is never used to approximately propagate. Which might not be the desired behavior?

This was disabled because we found that linear propagation didn't work as well as intended. Thankfully, two-body propagation is fast enough to serve as the approximate method.

moeyensj commented 2 years ago

@stevenstetzler Thanks for the fantastic work! I've corrected a few minor issues and accepted all the optimized functions as the new defaults in precovery.