insarlab / MintPy

Miami InSAR time-series software in Python
https://mintpy.readthedocs.io

`conncomp.plot_bridge()`: improve memory usage #1155

Closed ehavazli closed 3 months ago

ehavazli commented 4 months ago

This commit improves memory usage when plotting bridges. The current method creates separate mask arrays while plotting the bridges, and memory usage grows to hundreds of GB when processing tens of interferograms (54 interferograms -> ~200 GB). This change brings memory usage down to ~8 GB.
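
To illustrate the general idea only (this is not the actual `plot_bridge()` code, and the function names, circular endpoint regions, and radius below are all hypothetical), the following minimal sketch contrasts keeping one full-size mask per bridge endpoint with accumulating everything into a single mask. The per-endpoint variant holds memory proportional to the number of endpoints, while the combined variant holds a single array regardless of the count.

```python
import numpy as np


def masks_separate(shape, endpoints, radius):
    """Hypothetical: build one full-size boolean mask per endpoint."""
    yy, xx = np.ogrid[:shape[0], :shape[1]]
    return [(yy - y) ** 2 + (xx - x) ** 2 <= radius ** 2 for y, x in endpoints]


def mask_combined(shape, endpoints, radius):
    """Hypothetical: accumulate all endpoint regions into a single mask."""
    out = np.zeros(shape, dtype=bool)
    yy, xx = np.ogrid[:shape[0], :shape[1]]
    for y, x in endpoints:
        # OR the temporary region into the one persistent array
        out |= (yy - y) ** 2 + (xx - x) ** 2 <= radius ** 2
    return out


if __name__ == "__main__":
    shape = (500, 500)
    endpoints = [(50 + 20 * k, 60 + 15 * k) for k in range(20)]
    separate = masks_separate(shape, endpoints, radius=10)
    combined = mask_combined(shape, endpoints, radius=10)
    print("bytes held, separate:", sum(m.nbytes for m in separate))
    print("bytes held, combined:", combined.nbytes)
```

Both variants mark the same pixels; only the amount of memory kept alive differs.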

yunjunz commented 4 months ago

Hi @ehavazli, these changes seem reasonable to me. As for the complex logic in this part, I vaguely remember it was there to handle the small background values ("holes") in the bridge endpoint region, but I can no longer recall the details. Could you show the plot before and after the PR here for a more intuitive evaluation?

ehavazli commented 4 months ago

Hi @yunjunz, here are the plots for before and after the change:

Before the change: (figure: bridges_old_code)

After the change: (figure: bridges_new_code)

yunjunz commented 4 months ago

Thanks for the figure @ehavazli. Below are my testing results and code: the current version gives a more accurate representation of the bridge endpoints, and the memory usage is the same for this single-pair case. I could not easily see from the change why this PR saves memory in the multi-pair case. Could you:

  1. show your testing code here so that I can test it?
  2. update the PR to not plot the white transparency over water bodies or other unrelated land areas, whose pixels are not used when calculating the phase difference of a bridge.

%matplotlib inline
import os
from matplotlib import pyplot as plt
from mintpy.objects.conncomp import connectComponent
from mintpy.utils import readfile, network as pnet

# read conn comp data & info
stack_file = os.path.expanduser('~/data/test/FernandinaSenDT128/mintpy/inputs/ifgramStack.h5')
atr = readfile.read_attribute(stack_file)
date12_list = pnet.get_date12_list(stack_file)
conncomp = readfile.read(stack_file, datasetName=f'connectComponent-{date12_list[0]}')[0]

# instantiate a connectComponent object
cc = connectComponent(conncomp=conncomp, metadata=atr)
cc.label()
cc.find_mst_bridge()

# plot bridges
fig, ax = plt.subplots(figsize=[8, 6])
cc.plot_bridge(ax)
fig.savefig(os.path.expanduser('~/Downloads/bridges.png'), bbox_inches='tight', dpi=300)
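
As a rough illustration of the second request above (not the PR's implementation; the label image, the endpoint mask, and the function name `endpoint_overlay` are made up for this sketch), one way to keep an overlay off unused pixels is to mask everything that is not both near a bridge endpoint and inside a connected component. With matplotlib's default colormap settings, masked pixels passed to `imshow` are left transparent.

```python
import numpy as np


def endpoint_overlay(labels, endpoint_mask):
    """Hypothetical helper: overlay that is valid only on pixels that are
    near a bridge endpoint AND belong to a connected component (label != 0)."""
    used = endpoint_mask & (labels != 0)
    # Mask (hide) every pixel that is not used
    return np.ma.masked_where(~used, np.ones(labels.shape))


# Toy label image: 0 = background (e.g. water), 1 and 2 = two components
labels = np.zeros((6, 8), dtype=int)
labels[1:5, 0:3] = 1
labels[1:5, 5:8] = 2

# Toy endpoint regions; the first one spills one column into the background
endpoint_mask = np.zeros(labels.shape, dtype=bool)
endpoint_mask[2:4, 1:4] = True
endpoint_mask[2:4, 5:7] = True

overlay = endpoint_overlay(labels, endpoint_mask)
print("pixels in endpoint regions:", int(endpoint_mask.sum()))
print("pixels actually drawn:", int(overlay.count()))
```

The background column covered by the first endpoint region is excluded from the overlay, so nothing would be drawn over it.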

ehavazli commented 4 months ago

Hey @yunjunz, I am going to update this PR to address your second comment. In the meantime, here are two plots showing memory profiles for the current code and for this PR. Note the difference in both runtime and memory usage.

Here is the code I used for memory profiling. I pulled the conncomp.py out and am importing it from my local directory. There are no changes to conncomp.py other than the plot_bridge function.


import numpy as np
import glob
import colorcet
from conncomp import connectComponent
from osgeo import gdal
gdal.UseExceptions()

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 10]

import time
from memory_profiler import profile

@profile
def read_gdal(filepath):
    ds = gdal.Open(filepath, gdal.GA_ReadOnly)

    if ds.RasterCount == 1:
        f = ds.GetRasterBand(1).ReadAsArray()
    elif ds.RasterCount == 2:
        f = ds.GetRasterBand(2).ReadAsArray()
    else:
        raise ValueError(f'Unexpected number of bands: {ds.RasterCount}')

    f_atr = gdal.Info(filepath, format='json')
    ds = None
    return f, f_atr

@profile
def make_plot(array, title, outfile):
    fig, ax = plt.subplots()
    img = ax.imshow(array, interpolation='nearest', cmap=colorcet.m_CET_C8, vmin=-30, vmax=70)
    ax.set_title(title)
    fig.colorbar(img)
    fig.savefig(outfile, bbox_inches='tight')
    plt.clf()
    plt.close(fig)

    print(f"Wrote: {outfile}")

@profile
def bridge_iteration(unw_file_list, conncomp_file_list, output_folder):
    st = time.time()
    for i in range(len(unw_file_list)):
        conncomp, conn_atr = read_gdal(conncomp_file_list[i])
        unw, unw_atr = read_gdal(unw_file_list[i])
        unw = np.ma.masked_where(unw == 0, unw)
        unw = np.ma.masked_where(conncomp == 0, unw)
        # <parent folder>_<file name>, reused for the title and the output file names
        title1 = f'{unw_file_list[i].split("/")[-2]}_{unw_file_list[i].split("/")[-1]}'
        make_plot(unw, title1, f'./{output_folder}/{title1}.png')

        cc = connectComponent(conncomp=conncomp, metadata=conn_atr)
        brdg_labels = cc.label()
        bridges = cc.find_mst_bridge()
        fig, ax = plt.subplots()
        bridge_plot = cc.plot_bridge(ax)
        cc_name = f'{conncomp_file_list[i].split("/")[-2]}_{conncomp_file_list[i].split("/")[-1]}'
        title2 = f'Connected Components: {cc_name}'
        bridge_plot.set_title(title2)
        bridges_png = f'./{output_folder}/{cc_name}.bridges.png'
        fig.savefig(bridges_png, bbox_inches='tight')
        print(f'Wrote: {bridges_png}')
        plt.close('all')

        bridge_unw = cc.unwrap_conn_comp(unw)
        bridge_unw = np.ma.masked_where(bridge_unw == 0, bridge_unw)
        bridge_unw = np.ma.masked_where(conncomp == 0, bridge_unw)
        make_plot(bridge_unw, f'bridging_{title1}', f'./{output_folder}/{title1}.bridging.png')

        diff = bridge_unw - unw
        make_plot(diff, f'Difference: {unw_file_list[i].split("/")[-1]}', f'./{output_folder}/{title1}.diff.png')
    et = time.time()
    elapsed_time = (et - st) / 60
    print(f'Elapsed time: {elapsed_time} minutes')

if __name__ == "__main__":
    f_dir = 'SWE_P71_F444_A_merged_ifgrams/*'
    output_folder = 'test_figs/'
    conncomp_list = sorted(glob.glob(f'{f_dir}/*.conncomp'))
    unw_list = sorted(glob.glob(f'{f_dir}/*.unw'))
    bridge_iteration(unw_list, conncomp_list, output_folder)
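
As a complementary check to the `memory_profiler` run above, a minimal sketch using only the standard-library `tracemalloc` module can show whether memory grows from one loop iteration to the next. The function `run_iterations` and its buffers are hypothetical stand-ins for the per-interferogram work, not part of MintPy. Note that `tracemalloc` only sees allocations made through Python's allocators, so memory held by C libraries may not appear in its numbers.

```python
import tracemalloc


def run_iterations(n_iter, size, keep):
    """Hypothetical loop: allocate one buffer per iteration and record the
    traced memory after each iteration. With keep=True, references are
    retained across iterations, mimicking a per-iteration leak."""
    kept = []
    current_after = []
    tracemalloc.start()
    for _ in range(n_iter):
        buf = bytearray(size)
        if keep:
            kept.append(buf)
        del buf
        current, _peak = tracemalloc.get_traced_memory()
        current_after.append(current)
    tracemalloc.stop()
    return current_after


if __name__ == "__main__":
    leaky = run_iterations(5, 1_000_000, keep=True)
    steady = run_iterations(5, 1_000_000, keep=False)
    print("growth with retained buffers:", leaky[-1] - leaky[0])
    print("growth with released buffers:", steady[-1] - steady[0])
```

If the traced memory keeps climbing across iterations, something is holding references between interferograms; if it stays flat, the per-iteration work is being released.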

ehavazli commented 3 months ago

Here is the plot showing the latest changes: (figure: 20210118_20210211_bridges)

Here is the memory profile: (figure: Screenshot 2024-03-19 at 4 20 53 PM)

Elapsed time: 9.470982917149861 minutes