gwmod / nlmod

Python package to build, run and visualize MODFLOW 6 groundwater models in the Netherlands.
https://nlmod.readthedocs.io
MIT License

Update cache.py #326

Closed · bdestombe closed this 5 months ago

bdestombe commented 6 months ago

The netcdf cache function validates the cache by comparing the ds argument and other function arguments to the pickled arguments. If they match, the cache can be used.

Previously, only the coordinates of the ds argument and the output ds had to match, which introduced two errors: a change in a data_var of ds that the function actually uses was not detected (falsely validating the cache), and the coordinates of ds had to match those of the output ds even when they were irrelevant to the function (needlessly invalidating the cache).

The PR compares the hash of the coords and data_vars of the ds argument to those that were stored in the pickle together with the cached output ds.

Ideally, cache.cache_netcdf() accepts arguments that specify exactly which data_vars and coords need to be included in the validation check. This was initially beyond the scope of this PR, but is now part of it.
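As a rough illustration of the mechanism (a hypothetical helper, not the actual implementation in cache.py):

import hashlib

def dataset_hash(ds):
    # hash the names and values of all coords and data_vars of ds; the cache
    # is valid only if this digest equals the one pickled alongside the
    # cached output ds
    digest = hashlib.sha256()
    for name in sorted(list(ds.coords) + list(ds.data_vars)):
        digest.update(str(name).encode())
        digest.update(ds[name].values.tobytes())
    return digest.hexdigest()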

OnnoEbbens commented 6 months ago

Thanks @bdestombe, the new setup looks a lot cleaner than the previous one. It does change the behavior of the cache function a bit. Because the caching can become complex, I have tried to list the changes and our previous solutions, to make sure we are on the same page before we decide on implementing this.

Changes in behavior of cache function

I think these are the changes in behavior of the cache function:

Issues in old behavior

The new behavior solves some issues:

If the data_vars differ and are used by the function, the cache is falsely valid

An example of this would be:

  1. I call the function get_bathymetry, which requires the DataArray Northsea, a data_var in the argument dataset. A cached dataset with the bathymetry is created.
  2. I call the function get_bathymetry again, but meanwhile I have changed the DataArray Northsea in the argument dataset. Old behavior: the cached bathymetry data is falsely returned. New behavior: the cached data is not used; the bathymetry is computed using the new Northsea DataArray and the cached dataset is updated. See the toy example below.
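A self-contained toy showing why coordinate-only validation misses this change (illustration only, not nlmod code):

import numpy as np
import xarray as xr

# a validation key built from coordinates only cannot detect that a
# data_var changed between calls
ds = xr.Dataset({"Northsea": ("x", np.array([0, 1, 1]))}, coords={"x": [0, 1, 2]})
key_old = tuple(ds.coords["x"].values)
ds["Northsea"] = ("x", np.array([1, 1, 1]))  # the data_var changes...
key_new = tuple(ds.coords["x"].values)
assert key_old == key_new  # ...but the key does not, so a stale cache is reused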

The notebook on caching describes how we previously tackled this:

If one of the function arguments is an xarray Dataset we only check if the dataset has the same dimensions and coordinates as the cached netcdf file. There is no check on the variables (DataArrays) in the dataset because it would simply take too much time to check all the variables in the dataset. Also, most of the time it is not necessary to check all the variables as they are not used to create the cached file. There is one example where a variable from the dataset is used to create the cached file. The nlmod.read.jarkus.get_bathymetry uses the 'Northsea' DataArray to create a bathymetry dataset. When we access the 'Northsea' DataArray using ds['Northsea'] in the get_bathymetry function there would be no check if the 'Northsea' DataArray that was used to create the cache is the same as the 'Northsea' DataArray in the current function call. The current solution for this is to make the 'Northsea' DataArray a separate function argument in the get_bathymetry function. This makes it also more clear which data is used in the function.

Note that "it would simply take too much time to check all the variables in the dataset" is not so relevant anymore, because your method of using the hashing is pretty fast, I think. However, I do like the idea of 'Northsea' as a separate argument, because it makes it more explicit which data is used in the function. At the same time, you really have to know this when you create a function you want to apply the cache to. I am not even 100% sure that we use this consistently in nlmod.
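For reference, the separate-argument pattern described in the quote looks roughly like this (a sketch; cache_netcdf stands in for the existing decorator and the real signatures are simplified):

# old style: the cache cannot tell that ds["Northsea"] is used internally
@cache_netcdf()
def get_bathymetry(ds):
    northsea = ds["Northsea"]  # invisible to the argument comparison
    ...

# workaround: pass the DataArray explicitly, so it is compared like any
# other function argument when the cache is validated
@cache_netcdf()
def get_bathymetry(ds, northsea):
    ...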

The coordinates of the ds argument have to match the coordinates of the output ds, which limits the use of the cache function

If I understand this correctly, the use of the cache function is limited when the cached dataset has coordinates that are irrelevant to the function. An example of this limitation would be:

  1. I have a cached dataset ahn.nc with a time dimension. This time dimension is not used in the get_ahn function nor in any of the data variables in this cached dataset.
  2. I call the function get_ahn using a dataset without a time dimension as an argument.
  3. The cache function will not return the cached dataset because the argument dataset has no time dimension.

This is related to this issue: https://github.com/gwmod/nlmod/issues/181. Our solution back then, which is not very well documented, was to modify the get_ahn function in such a way that only relevant coordinates are stored in the cache. So the cached ahn.nc will only contain the dimensions/coordinates that are used in the get_ahn function, and not the time dimension. See the sketch below.
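A sketch of that workaround (assuming the result carries an unused time coordinate; the function body is elided):

@cache_netcdf()
def get_ahn(ds, identifier="AHN4_DTM_5m"):
    ...
    # drop coordinates that get_ahn does not use, so they never end up in
    # the cached ahn.nc and cannot invalidate the cache later
    return ahn.drop_vars("time", errors="ignore")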

Advantages/disadvantages

I tried to summarize the above by listing the advantages and disadvantages of the new behavior:

new behavior regarding checking dataset coordinates

new behavior regarding checking dataset datavars

I hope this whole story was somewhat understandable and complete. Please let me know if I made any mistakes or forgot anything. If we both agree on the changes and implications, we can discuss this shortly with @dbrakenhoff and @rubencalje and decide on implementing it.

bdestombe commented 6 months ago

Hi Onno, thank you so much for sharing your thoughts! I fully agree with your analysis. Just a short reply, as I am off for a week-long holiday.

So this approach is a bit pickier, but better too picky than falsely validating the arguments and using the cache.

As I cryptically proposed in the last sentence, we could introduce a keyword argument to cache_netcdf() that allows you to explicitly set which datavars and coordinates are checked, so that it is picky for the right reasons. It would solve your Northsea case. If this argument is not provided, it might be a bit too picky, but rightfully so, as continuing with a false cache is really hard to debug. If the provided list of datavars is not exhaustive, a clear error is given, as the input ds only contains the datavars of the provided list.

Take care, Bas

bdestombe commented 6 months ago

Thus we need to include the following logic somewhere:

import xarray as xr

def ds_contains(ds, coords_2d=False, coords_3d=False, datavars=None, coords=None):
    datavars = list(datavars or [])  # copy to avoid mutable default arguments
    coords = list(coords or [])
    if coords_2d:
        coords += ["x", "y"]

    # return a stripped dataset with only the requested data_vars and coords
    return xr.Dataset(data_vars={v: ds[v] for v in datavars},
                      coords={c: ds.coords[c] for c in coords}, attrs=ds.attrs)
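For example, with a model dataset ds (hypothetical names and values):

stripped = ds_contains(ds, coords_2d=True, datavars=["northsea"])
print(list(stripped.data_vars), list(stripped.coords))  # ['northsea'] ['x', 'y']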

But where do we put that logic? Case 1 or 2?

Case 1

@cache_netcdf(coords_2d=True, coords_3d=False, datavars=[], coords=[])
def get_ahn(ds, identifier="AHN4_DTM_5m"):
    ...
    return ahn

In one of the first lines of cache_netcdf(), the following line would be added:

ds = ds_contains(ds, coords_2d=coords_2d, coords_3d=coords_3d, datavars=datavars, coords=coords)

If no arguments are passed to cache_netcdf(), then any difference in ds invalidates the cache, including vars and coords that are not relevant for the computation of ahn.

Case 2

Or use ds_contains() as a decorator:

@ds_contains(coords_2d=True, coords_3d=False, datavars=[], coords=[])
@cache_netcdf()
def get_ahn(ds, identifier="AHN4_DTM_5m"):
    ...
    return ahn

No changes to cache_netcdf() are required, as it receives a stripped ds. I am not fully familiar with decorators and wrappers, so some guidance on this would be appreciated.
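For what it's worth, a minimal sketch of how ds_contains() could work as a decorator factory (hypothetical wiring, assuming ds is the first positional argument of the wrapped function):

import functools

def ds_contains_decorator(coords_2d=False, coords_3d=False, datavars=None, coords=None):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(ds, *args, **kwargs):
            # strip ds before the wrapped function (and thus the inner
            # cache_netcdf decorator) ever sees it
            stripped = ds_contains(ds, coords_2d=coords_2d, coords_3d=coords_3d,
                                   datavars=datavars, coords=coords)
            return func(stripped, *args, **kwargs)
        return wrapper
    return decorator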

bdestombe commented 6 months ago

ds_contains() could also contain some helpful messages:

import xarray as xr

def ds_contains(ds, coords_2d=False, coords_3d=False, datavars=None, coords=None):
    datavars = list(datavars or [])  # copy to avoid mutable default arguments
    coords = list(coords or [])
    if coords_2d:
        coords += ["x", "y"]

    if "northsea" in datavars and "northsea" not in ds.data_vars:
        raise ValueError("Northsea not in dataset. Run nlmod.read.rws.add_northsea() first.")

    return xr.Dataset(data_vars={v: ds[v] for v in datavars},
                      coords={c: ds.coords[c] for c in coords}, attrs=ds.attrs)

OnnoEbbens commented 6 months ago

Nice idea, I really like your solution.

Just to check if I understand your approach correctly: when you use the cache decorator on a function, you have the option to list specific datavars/coordinates/attributes. If you choose to do so and you call the function, a hash representation of those specific datavars/coordinates/attributes is created and saved together with the cache (as a pickle). When you call the function again, the hash from the previous call is compared to a hash of the current call, and if they differ the cache becomes invalid. Is this correct?

What do we do if no datavars/coordinates/attributes are specified: do we check everything, nothing, or maybe only the coordinates? I think checking everything would make sense, to be more foolproof.

Finally, we have to choose between the two cases you mention. Personally I prefer case 1, because I think we have to add most of the logic to cache_netcdf anyway. I am also not so familiar with decorators and I have a hard time wrapping my head around two wrappers for one function :smirk:.

Do you have time to add this to the PR? If not I can also give it a try.

bdestombe commented 6 months ago

Ah indeed, the default arguments of ds_contains() should result in the entire ds being passed.
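A tiny sketch of that default behavior (assuming None/False arguments mean "check everything"):

def ds_contains(ds, coords_2d=False, coords_3d=False, datavars=None, coords=None):
    if datavars is None and coords is None and not (coords_2d or coords_3d):
        # nothing specified: validate the cache against the entire dataset
        return ds
    ...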

I'll come up with something.

bdestombe commented 5 months ago

Hi Onno, it seems to be working. Could you have a look at the code?