sot / xija

Thermal modeling framework for Chandra X-ray Observatory
https://sot.github.io/xija
BSD 3-Clause "New" or "Revised" License
9 stars 5 forks source link

Improve performance of bad_times and mask_times #123

Closed taldcroft closed 2 years ago

taldcroft commented 2 years ago

Description

This fixes a performance issue noted on Slack in #thermal_working_group:

I am doing some work that requires me to run the thermal models many, many times in a row, and was profiling my code when I noticed that the pm2thv1t model is running significantly slower than the pm1thv2t model, e.g. in my MATLAB setup:

This PR vectorizes the bad times processing and minimizes the number of date --> time (CXC second) conversions. It also factors this processing out into a separate function (working to improve separation in this code).

Some points of note:

Interface impacts

None.

Testing

Unit tests

Independent check of unit tests by [REVIEWER NAME]

Functional tests

I ran this script on master (version 4.26.1) and this branch (4.27.1.dev1+g19aed41).

from time import time

import xija
from xija.get_model_spec import get_xija_model_spec

# Record which xija version produced the timings that follow.
print(xija.__version__)

def print_mdl_times(mdl):
    """Print the length of each bad-time / mask-time attribute of ``mdl``.

    One line is written to stdout per attribute, in a fixed order, formatted
    as ``len(mdl.<attr>)=<n>``.
    """
    attr_names = (
        "bad_times",
        "bad_times_indices",
        "mask_times",
        "mask_time_secs",
        "mask_times_bad",
    )
    for attr_name in attr_names:
        count = len(getattr(mdl, attr_name))
        print(f"len(mdl.{attr_name})={count!r}")

def timer_func(func):
    """Wrap ``func`` so that every call reports how long it took.

    The returned wrapper forwards all arguments to ``func``, then prints the
    model name taken from ``kwargs["model_spec"]["name"]``, the second and
    third positional arguments (labelled ``datestart`` / ``datestop``) and the
    elapsed wall-clock time in milliseconds. It then passes the result to
    ``print_mdl_times`` and returns the result unchanged.
    """

    def wrap_func(*args, **kwargs):
        started = time()
        result = func(*args, **kwargs)
        elapsed_ms = (time() - started) * 1000

        # Looked up only after the call, so the timed interval covers
        # nothing but ``func`` itself.
        name = kwargs["model_spec"]["name"]
        datestart, datestop = args[1:3]
        print(f"{name=} {datestart=} {datestop=} executed in {elapsed_ms:.1f} ms")
        print_mdl_times(result)
        return result

    return wrap_func

# spec1 (pm1thv2t) has no bad times
spec1, version = get_xija_model_spec("pm1thv2t", version="3.40.2")

# spec2 (pm2thv1t) has 5876 bad times
spec2, version = get_xija_model_spec("pm2thv1t", version="3.40.2")

# Timed stand-in for the model constructor: each call prints its run time
# and the lengths of the resulting model's bad/mask time attributes.
XijaModel = timer_func(xija.XijaModel)

# (model spec, start date, stop date) for each timed construction, in run order.
cases = [
    (spec1, "2019:001", "2022:301"),
    (spec1, "2025:001", "2025:002"),
    (spec2, "2019:001", "2022:301"),
    (spec2, "2025:001", "2025:002"),
    (spec2, "2022:083:22:30:00", "2022:084:04:00:00"),
]
for spec, start, stop in cases:
    mdl = XijaModel("mdl", start, stop, model_spec=spec)

Master version 4.26.1

4.26.1
name='pm1thv2t' datestart='2019:001' datestop='2022:301' executed in 6.7 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm1thv2t' datestart='2025:001' datestop='2025:002' executed in 1.1 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2019:001' datestop='2022:301' executed in 407.7 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=3842
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876
name='pm2thv1t' datestart='2025:001' datestop='2025:002' executed in 396.8 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876
name='pm2thv1t' datestart='2022:083:22:30:00' datestop='2022:084:04:00:00' executed in 382.5 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=2
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876

Dev version 4.27.1.dev1+g19aed41

4.27.1.dev2+g859870c
name='pm1thv2t' datestart='2019:001' datestop='2022:301' executed in 7.2 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm1thv2t' datestart='2025:001' datestop='2025:002' executed in 1.0 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2019:001' datestop='2022:301' executed in 19.8 ms
len(mdl.bad_times)=3842
len(mdl.bad_times_indices)=3842
len(mdl.mask_times)=3842
len(mdl.mask_time_secs)=3842
len(mdl.mask_times_bad)=3842
name='pm2thv1t' datestart='2025:001' datestop='2025:002' executed in 6.0 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2022:083:22:30:00' datestop='2022:084:04:00:00' executed in 5.7 ms
len(mdl.bad_times)=2
len(mdl.bad_times_indices)=2
len(mdl.mask_times)=2
len(mdl.mask_time_secs)=2
len(mdl.mask_times_bad)=2