SatelliteShorelines / CoastSeg

An interactive toolbox for downloading satellite imagery, applying image segmentation models, mapping shoreline positions and more. The mapping extension for CoastSat and Zoo.
https://satelliteshorelines.github.io/CoastSeg/
GNU General Public License v3.0
42 stars 5 forks source link

Feature Request: Add Land Mask to Zoo Workflow #233

Closed 2320sharon closed 3 months ago

dbuscombe-usgs commented 4 months ago

Basic land mask workflow

def return_timeav(valid_files, time_var):
    """Concatenate the per-file arrays and average them over "time".

    Each path in ``valid_files`` is loaded with ``load_xarray_data`` and the
    results are concatenated with ``time_var`` as the concat dimension.

    Returns:
        tuple: ``(timeav, da)`` where ``da`` is the concatenated array and
        ``timeav`` is its mean over the "time" dimension, skipping NaNs.
    """
    per_file = []
    for path in valid_files:
        per_file.append(load_xarray_data(path))

    stacked = xr.concat(per_file, dim=time_var)
    # NaNs are ignored so a missing pixel in one file does not poison the mean.
    mean_over_time = stacked.mean("time", skipna=True)
    return mean_over_time, stacked

Make a time-averaged label image

    # Average all valid label arrays over time; `da` keeps the full stack.
    timeav, da = return_timeav(valid_files,time_var)
    # Integer mask: 1 where the rounded time average equals 3, 0 elsewhere.
    # (class 3 is presumably "land", going by the variable name -- TODO confirm)
    mask_land = np.array(np.round(timeav)==3).astype('int')

In this implementation, the land mask is applied to each frame in the xarray; a new label image is made for each frame and written out to a revised npz file.

    # Derive an image path for each "good" npz: swap the `<folder>/good` prefix
    # for `image_path` and the '_res.npz' suffix for '.jpg'.
    good_image_files = [i.replace(folder+os.sep+'good', image_path).replace('_res.npz','.jpg') for i in files_good]

    # Stack the good files along the time dimension.
    # (semantics of the second argument to load_xarray_softmax are not visible here)
    da = xr.concat([load_xarray_softmax(i,1) for i in files_good],dim=time_var_good)

    # apply land mask, and filter the whitewater class
    # Files, image paths and stringified time coordinates are walked in lockstep,
    # so this assumes all three are in the same order -- TODO confirm.
    for npzf,f, time in zip(files_good,good_image_files, [str(l.to_numpy()) for l in time_var_good]):

        # Copy every array stored in the source npz so they are carried over
        # unchanged into the output file.
        dat_dict = {}
        with np.load(npzf) as data:
            for k in list(data):
                dat_dict[k] = data[k]

        frame = da.sel(time=time).to_numpy()
        # Debug output: the label values present in this frame.
        print(np.unique(frame))

        if mask_land is not None:
            # Reset whatever this frame labelled as 3, then stamp 3 back in
            # wherever the time-averaged mask is set.
            frame[frame==3] = 0
            frame[mask_land==1] = 3

        # NOTE(review): this assigns to the object returned by `sel`, not to
        # `da` itself; it looks like it does not write the frame back into
        # `da` -- confirm against xarray's assignment semantics.
        da.sel(time=time).values = frame
        # Store a 2-D label image; for higher-rank frames keep the first slice.
        if len(frame.shape)==2:
            dat_dict['grey_label'] = frame
        else:
            dat_dict['grey_label'] = frame[0,:,:]

        # Written next to the image path, with '.jpg' replaced by '_filt_res.npz'.
        np.savez_compressed(f.replace('.jpg','_filt_res.npz'),**dat_dict) #frame)
2320sharon commented 3 months ago

Thanks for posting this code Dan! I was digging through some old code I had and I found where I implemented the land mask code previously.

Code

def filter_model_outputs(
    satname: str, files: list, dest_folder_good: str, dest_folder_bad: str
) -> None:
    """
    Filter model outputs based on KMeans clustering of RMSE values and organize into 'good' and 'bad'.

    The 'good' outputs are then land-masked: a mask is built from the time
    average of the good files (pixels whose rounded average equals 3), each
    good frame has those pixels set to 3, and the result is written to
    ``dest_folder_good`` as a compressed npz under the key ``grey_label``,
    using the source file's base name.

    Args:
        satname (str): Satellite name; only used in the progress message.
        files (list): List of file paths.
        dest_folder_good (str): Destination folder for 'good' files.
        dest_folder_bad (str): Destination folder for 'bad' files.
    """
    valid_files = return_valid_files(files)
    print(f"Found {len(valid_files)} valid files for {satname}.")
    times, time_var = get_time_vectors(valid_files)
    da = xr.concat([load_xarray_data(f) for f in valid_files], dim=time_var)
    timeav = da.mean(dim="time")

    rmse, input_rmse = measure_rmse(da, times, timeav)
    labels, scores = get_kmeans_clusters(input_rmse, rmse)
    files_bad, files_good = get_good_bad_files(valid_files, labels, scores)
    print(f"Found {len(files_bad)} files_bad.")
    print(f"Found {len(files_good)} files_good.")

    # Apply the land mask to the good files. Skipped when there are none,
    # because xr.concat cannot concatenate an empty sequence.
    # len() is used rather than truthiness since the container type returned
    # by get_good_bad_files is not visible here.
    if len(files_good) > 0:
        _, time_var = get_time_vectors(files_good)
        da = xr.concat([load_xarray_data(f) for f in files_good], dim=time_var)
        # Land mask from the time average of the good files only.
        timeav = da.mean(dim="time")
        mask_land = np.array(np.round(timeav) == 3).astype("int")

        os.makedirs(dest_folder_good, exist_ok=True)
        for idx, f in enumerate(files_good):
            # `da` was concatenated in files_good order, so position `idx`
            # along "time" is the frame that belongs to file `f`. Copy so the
            # in-place edit cannot alias the stacked array.
            frame = da.isel(time=idx).to_numpy().copy()
            frame[mask_land == 1] = 3

            dest_path = os.path.join(dest_folder_good, os.path.basename(f))
            print(f"Saving {dest_path}")
            np.savez_compressed(dest_path, grey_label=frame)

    # The good list is emptied before the hand-off below; presumably so the
    # unmasked originals are not also placed in dest_folder_good on top of
    # the masked copies written above -- TODO confirm.
    files_good = []
    print(files_good)
    print("bad",files_bad)
    handle_files_and_directories(
        files_bad, files_good, dest_folder_bad, dest_folder_good
    )