Dana-Farber-AIOS / pathml

Tools for computational pathology
https://pathml.org
GNU General Public License v2.0

Warnings associated with circulating a keras model among dask workers #212

Closed: surya-narayanan closed this issue 2 years ago

surya-narayanan commented 2 years ago

We are getting a set of warnings around the loading of a saved Keras checkpoint file. I think these warnings contribute to a subsequent error (https://github.com/Dana-Farber-AIOS/pathml/issues/164#issuecomment-953384867) and to the warnings in https://github.com/Dana-Farber-AIOS/pathml/issues/211#issue-1038691185.

Here is the warning we get when we run the SegmentMIF transform:

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
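
The warning itself suggests a quick check. A minimal sketch, assuming you know which directory the model is loaded from (in the traceback later in this thread, deepcell's Mesmer loads from model_path = os.path.splitext(archive_path)[0]; the actual location on disk is not shown here). has_keras_metadata and list_saved_model_files are hypothetical helpers, not part of pathml or deepcell:

```python
from pathlib import Path


def has_keras_metadata(saved_model_dir) -> bool:
    """Return True if a SavedModel directory contains keras_metadata.pb.

    Keras started writing this file in TF 2.5; its absence is what the
    "SavedModel saved prior to TF 2.5" warning refers to.
    """
    return (Path(saved_model_dir) / "keras_metadata.pb").is_file()


def list_saved_model_files(saved_model_dir):
    """List the top-level entries of a SavedModel directory, for inspection."""
    return sorted(p.name for p in Path(saved_model_dir).iterdir())
```

For example, has_keras_metadata(model_path) returning False would be consistent with a model saved before TF 2.5.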

We believe that the Keras saved model is not being cleanly recycled among dask workers (for example, existing locks are not released), which causes the warnings in https://github.com/Dana-Farber-AIOS/pathml/issues/211#issue-1038691185 and, eventually, the error in https://github.com/Dana-Farber-AIOS/pathml/issues/164#issuecomment-953384867.

To Reproduce

Here is our pipeline. I cannot share the data for regulatory reasons.

from pathml.preprocessing import Pipeline, CollapseRunsVectra, SegmentMIF, QuantifyMIF

pipeline = Pipeline([
    CollapseRunsVectra(),
    SegmentMIF(model='mesmer', nuclear_channel=0, cytoplasm_channel=2, image_resolution=0.5,
               gpu=False, postprocess_kwargs_whole_cell=None,
               postprocess_kwrags_nuclear=None),
    QuantifyMIF('nuclear_segmentation')
])

jacob-rosenthal commented 2 years ago

Thanks Surya, can you also post the code you are using to set up the dask client and run the pipeline?

jacob-rosenthal commented 2 years ago

From the warning you posted, it seems to come from loading the weights for the pretrained Mesmer model. My guess is that the exact line of code is this one.

Can you also post the full traceback? That could help us figure out if the warning is coming from Mesmer or pathml.

surya-narayanan commented 2 years ago

No traceback, since it's a warning, not an error.

The code to run the pipeline is as follows:

For a single tile:

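import numpy as np
# Drop singleton axes from the tile image, then apply the pipeline to the single tile
tile.image = np.squeeze(tile.image)
pipeline.apply(tile)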

For a SlideDataset:

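from pathlib import Path
from pathml.core import SlideData, SlideDataset, types

# Collect .qptiff files under src (the user's data directory), skipping H&E slides
slides_names = [path.as_posix() for path in Path(src).rglob('*.qptiff') if 'HnE' not in path.name]
# Create slides from file paths
slides = [SlideData(i, backend="bioformats", slide_type=types.Vectra) for i in slides_names]
# Wrap the slides in a SlideDataset and run the pipeline
slide_dataset = SlideDataset(slides)
slide_dataset.run(pipeline)
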
jacob-rosenthal commented 2 years ago

Hey @surya-narayanan, we found a bug when running a pipeline on datasets with distributed=True (#216, fixed in #217).

A workaround is to pass a Dask Client object directly when you run a pipeline, here's some pseudocode:

my_client = distributed.Client()
slide_dataset.run(pipeline, client=my_client)

Can you try this out and let me know if it resolves any or all of the problems you were running into?

surya-narayanan commented 2 years ago

What is distributed? Do I import it from pathml or am I to install dask separately?

jacob-rosenthal commented 2 years ago

It's dask distributed, i.e. the distributed package (also importable as dask.distributed).

surya-narayanan commented 2 years ago

I still get the following warning:

distributed.worker - WARNING - Compute Failed
Function:  apply

args:      (Tile(coords=(0, 9000),
    name=None,
    image shape: (3000, 3000, 7),
    slide_type=SlideType(stain=Fluor, platform=Vectra, tma=None, rgb=None, volumetric=None, time_series=None),
    labels=None,
    masks=None,
    counts=None))
kwargs:    {}
Exception: IndexError('Read less bytes than requested')

surya-narayanan commented 2 years ago

Regarding the warning at the beginning of this issue: I still get it, even after manually creating the Client():

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.

surya-narayanan commented 2 years ago

This still errors out, by the way (after about 30 minutes, similar to my previous experience with this error), producing the error mentioned in https://github.com/Dana-Farber-AIOS/pathml/issues/164#issuecomment-953384867:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_4753/2593815674.py in <module>
----> 1 slide_dataset.run(pipeline, client=my_client)

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/core/slide_dataset.py in run(self, pipeline, **kwargs)
     47         # run preprocessing
     48         for slide in self.slides:
---> 49             slide.run(pipeline, **kwargs)
     50 
     51         assert not any([s.tile_dataset is None for s in self.slides])

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/core/slide_data.py in run(self, pipeline, distributed, client, tile_size, tile_stride, level, tile_pad, overwrite_existing_tiles)
    307 
    308             # as tiles are processed, add them to h5
--> 309             for future, tile in dask.distributed.as_completed(
    310                 processed_tile_futures, with_results=True
    311             ):

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/client.py in __next__(self)
   4522             with self.thread_condition:
   4523                 self.thread_condition.wait(timeout=0.100)
-> 4524         return self._get_and_raise()
   4525 
   4526     async def __anext__(self):

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/client.py in _get_and_raise(self)
   4513             if self.raise_errors and future.status == "error":
   4514                 typ, exc, tb = result
-> 4515                 raise exc.with_traceback(tb)
   4516         return res
   4517 

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/preprocessing/pipeline.py in apply()
     47         if self.transforms:
     48             for t in self.transforms:
---> 49                 t.apply(tile)
     50         return tile
     51 

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/preprocessing/transforms.py in apply()
   1275             tile.slide_type.stain == "Fluor"
   1276         ), f"Tile has slide_type.stain='{tile.slide_type.stain}', but must be 'Fluor'"
-> 1277         cell_segmentation, nuclear_segmentation = self.F(tile.image)
   1278         tile.masks["cell_segmentation"] = cell_segmentation
   1279         tile.masks["nuclear_segmentation"] = nuclear_segmentation

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/preprocessing/transforms.py in F()
   1251             from deepcell.applications import Mesmer
   1252 
-> 1253             model = Mesmer()
   1254             cell_segmentation_predictions = model.predict(
   1255                 nuc_cytoplasm, compartment="whole-cell"

/opt/conda/envs/wtf/lib/python3.8/site-packages/deepcell/applications/mesmer.py in __init__()
    219             )
    220             model_path = os.path.splitext(archive_path)[0]
--> 221             model = tf.keras.models.load_model(model_path)
    222 
    223         super(Mesmer, self).__init__(

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in load_model()
    204         filepath = path_to_string(filepath)
    205         if isinstance(filepath, str):
--> 206           return saved_model_load.load(filepath, compile, options)
    207 
    208   raise IOError(

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load()
    150   for node_id, loaded_node in keras_loader.loaded_nodes.items():
    151     nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
--> 152   loaded = tf_load.load_partial(path, nodes_to_load, options=options)
    153 
    154   # Finalize the loaded layers and remove the extra tracked dependencies.

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py in load_partial()
    773     A dictionary mapping node paths from the filter to loaded objects.
    774   """
--> 775   return load_internal(export_dir, tags, options, filters=filters)
    776 
    777 

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py in load_internal()
    903     with ops.init_scope():
    904       try:
--> 905         loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
    906                             ckpt_options, filters)
    907       except errors.NotFoundError as err:

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py in __init__()
    161 
    162     self._load_all()
--> 163     self._restore_checkpoint()
    164 
    165     for node in self._nodes:

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py in _restore_checkpoint()
    487                                   self._checkpoint_options).expect_partial()
    488     else:
--> 489       load_status = saver.restore(variables_path, self._checkpoint_options)
    490     load_status.assert_existing_objects_matched()
    491     checkpoint = load_status._checkpoint

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py in restore()
   1299       dtype_map = reader.get_variable_to_dtype_map()
   1300     try:
-> 1301       object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
   1302     except errors_impl.NotFoundError:
   1303       # The object graph proto does not exist in this checkpoint. Try the

/opt/conda/envs/wtf/lib/python3.8/site-packages/tensorflow/python/training/py_checkpoint_reader.py in get_tensor()
     67   """Get the tensor from the Checkpoint object."""
     68   try:
---> 69     return CheckpointReader.CheckpointReader_GetTensor(
     70         self, compat.as_bytes(tensor_str))
     71   # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the

IndexError: Read less bytes than requested

surya-narayanan commented 2 years ago

Further, when I try to run a pipeline again, I get:

distributed.worker - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker.html#memtrim for more information. -- Unmanaged memory: 9.31 GiB -- Worker memory limit: 13.27 GiB

This indicates that unmanaged memory is not being released.
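
The dask page linked in that warning discusses manually trimming worker memory. A minimal sketch of that idea, assuming Linux workers with glibc (malloc_trim is a glibc function; elsewhere this sketch simply returns 0). trim_memory is a name chosen here for illustration:

```python
import ctypes


def trim_memory() -> int:
    """Ask glibc to hand freed heap memory back to the operating system.

    Returns glibc's malloc_trim result (1 if memory was released, 0 if not),
    or 0 when glibc is unavailable (e.g. macOS or musl-based Linux).
    """
    try:
        libc = ctypes.CDLL("libc.so.6")
        malloc_trim = libc.malloc_trim
    except (OSError, AttributeError):
        return 0
    return malloc_trim(0)
```

To run it on every worker currently connected to the client, my_client.run(trim_memory) returns a dict mapping each worker address to its result. This only helps when Python has already freed the memory; it will not fix a genuine leak.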

Should I have re-installed pathml before I did anything?

surya-narayanan commented 2 years ago

Running any subsequent code that requires multiple workers raises the following error, which suggests that memory is still being mismanaged:

---------------------------------------------------------------------------
TimeoutError                              Traceback (most recent call last)
/tmp/ipykernel_4753/1999371452.py in <module>
----> 1 slide_dataset.run(pipeline, client=my_client, tile_size= (250, 250))

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/core/slide_dataset.py in run(self, pipeline, **kwargs)
     47         # run preprocessing
     48         for slide in self.slides:
---> 49             slide.run(pipeline, **kwargs)
     50 
     51         assert not any([s.tile_dataset is None for s in self.slides])

/opt/conda/envs/wtf/lib/python3.8/site-packages/pathml/core/slide_data.py in run(self, pipeline, distributed, client, tile_size, tile_stride, level, tile_pad, overwrite_existing_tiles)
    302                 # explicitly scatter data, i.e. send the tile data out to the cluster before applying the pipeline
    303                 # according to dask, this can reduce scheduler burden and keep data on workers
--> 304                 big_future = client.scatter(tile)
    305                 f = client.submit(pipeline.apply, big_future)
    306                 processed_tile_futures.append(f)

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/client.py in scatter(self, data, workers, broadcast, direct, hash, timeout, asynchronous)
   2167         else:
   2168             local_worker = None
-> 2169         return self.sync(
   2170             self._scatter,
   2171             data,

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    858             return future
    859         else:
--> 860             return sync(
    861                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    862             )

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    324     if error[0]:
    325         typ, exc, tb = error[0]
--> 326         raise exc.with_traceback(tb)
    327     else:
    328         return result[0]

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/utils.py in f()
    307             if callback_timeout is not None:
    308                 future = asyncio.wait_for(future, callback_timeout)
--> 309             result[0] = yield future
    310         except Exception:
    311             error[0] = sys.exc_info()

/opt/conda/envs/wtf/lib/python3.8/site-packages/tornado/gen.py in run(self)
    760 
    761                     try:
--> 762                         value = future.result()
    763                     except Exception:
    764                         exc_info = sys.exc_info()

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/client.py in _scatter(self, data, workers, broadcast, direct, local_worker, timeout, hash)
   2058                 )
   2059             else:
-> 2060                 await self.scheduler.scatter(
   2061                     data=data2,
   2062                     workers=workers,

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    872             name, comm.name = comm.name, "ConnectionPool." + key
    873             try:
--> 874                 result = await send_recv(comm=comm, op=key, **kwargs)
    875             finally:
    876                 self.pool.reuse(self.addr, comm)

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    665         if comm.deserialize:
    666             typ, exc, tb = clean_exception(**response)
--> 667             raise exc.with_traceback(tb)
    668         else:
    669             raise Exception(response["exception_text"])

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/core.py in handle_comm()
    500                             result = asyncio.ensure_future(result)
    501                             self._ongoing_coroutines.add(result)
--> 502                             result = await result
    503                     except (CommClosedError, CancelledError):
    504                         if self.status == Status.running:

/opt/conda/envs/wtf/lib/python3.8/site-packages/distributed/scheduler.py in scatter()
   5635             await asyncio.sleep(0.2)
   5636             if time() > start + timeout:
-> 5637                 raise TimeoutError("No workers found")
   5638 
   5639         if workers is None:

TimeoutError: No workers found

jacob-rosenthal commented 2 years ago

Hi @surya-narayanan, there are a lot of different things going on here, so I'll try to break them down:

1. The warning about loading a pretrained tf model:

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), NOT tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.

This warning looks like it comes from deepcell. I guess they trained and saved their model with TF 2.4 and only bumped the TensorFlow version to 2.5 here, so you get this warning when loading it with TF 2.5. From my googling, I think it's OK to ignore this warning (see here for an example).
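
If the warning is just noise, one option (an assumption, not something pathml or deepcell provides) is to raise the level of the standard Python logger named "tensorflow", which is where the WARNING:tensorflow: lines come from. Because Mesmer is loaded inside the dask tasks, this has to happen in the worker processes. A minimal sketch; silence_tf_warnings is a hypothetical helper:

```python
import logging


def silence_tf_warnings(level=logging.ERROR):
    """Hide TensorFlow's Python-side messages below `level` in this process.

    TensorFlow logs through the standard logger named "tensorflow", so
    raising its level hides lines such as "WARNING:tensorflow:SavedModel
    saved prior to TF 2.5 ...". This does not import TensorFlow itself.
    """
    logger = logging.getLogger("tensorflow")
    logger.setLevel(level)
    return logger.level
```

For example, my_client.run(silence_tf_warnings) applies it to the workers connected at that moment; workers started later would need it again. If messages still appear, calling it after TensorFlow has been imported on the worker is worth trying.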

2. IndexError('Read less bytes than requested')

From here it looks like this could be due to running out of disk space: https://github.com/tensorflow/tensorflow/issues/21544#issuecomment-418848591

How is your dask cluster configured? Are you provisioning enough resources to the workers? See here for a place to start. You may need to configure your dask cluster explicitly instead of relying on the default settings.
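
As a starting point, here is a sketch of creating the cluster explicitly instead of relying on Client() defaults, plus a quick free-space check for the disk-space hypothesis. The worker count, memory limit, and directory are placeholder assumptions to adjust for your machine, and make_client and free_disk_gib are hypothetical helper names:

```python
import shutil


def free_disk_gib(path="/tmp"):
    """Return the free space (in GiB) on the filesystem that holds `path`."""
    return shutil.disk_usage(path).free / 2**30


def make_client(n_workers=2, threads_per_worker=1, memory_limit="8GiB",
                local_directory="/tmp/dask-worker-space"):
    """Start a LocalCluster with explicit per-worker resources.

    The defaults are placeholders, not recommendations. local_directory is
    where workers write scratch and spilled data, so point it at a volume
    with enough free space.
    """
    # Imported here so defining this helper does not require dask
    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        memory_limit=memory_limit,        # per worker
        local_directory=local_directory,  # forwarded to each worker
    )
    return Client(cluster)
```

Usage might look like print(f"{free_disk_gib('/tmp'):.1f} GiB free"), then my_client = make_client() and slide_dataset.run(pipeline, client=my_client).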

3. TimeoutError: No workers found

This seems like a problem with the client. Take a look at client.scheduler_info()["workers"] to see if the scheduler can see the workers (from here).

You can also try restarting the cluster to see if that helps: http://distributed.dask.org/en/latest/api.html#distributed.Client.restart
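
The two checks above can be combined into a small helper. This is a sketch: ensure_workers is a hypothetical name, and Client.restart() discards all data and results held on the cluster, so it is best used only between pipeline runs:

```python
def ensure_workers(client, min_workers=1, timeout=60):
    """Check that the scheduler can see at least `min_workers` workers.

    If it sees fewer, restart the cluster and wait (up to `timeout` seconds)
    for workers to register; wait_for_workers raises a TimeoutError if they
    never appear. Returns the number of workers the scheduler reports.
    """
    # scheduler_info()["workers"] is a dict keyed by worker address
    workers = client.scheduler_info()["workers"]
    if len(workers) < min_workers:
        # restart() discards all data and results held on the cluster
        client.restart()
        client.wait_for_workers(n_workers=min_workers, timeout=timeout)
        workers = client.scheduler_info()["workers"]
    return len(workers)
```

For example, call ensure_workers(my_client) before slide_dataset.run(pipeline, client=my_client).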

4. distributed.worker - WARNING - Unmanaged memory use is high

We are not exactly sure where this is coming from or what it means. Based on this, we hypothesized that it could be due to objects not being released for garbage collection, so we added a step that manually deletes the Mesmer model each time (see here). This is merged into dev, so you could try installing from dev and see if that helps at all.
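
The general idea of deleting the model after each use can be sketched as follows. This illustrates the pattern only and is not pathml's actual code; load_model is a hypothetical zero-argument factory (for example lambda: Mesmer()), and TensorFlow is left out so the sketch stands alone:

```python
import gc


def predict_and_release(load_model, image):
    """Load a model, run a single prediction, then release the model.

    load_model is a hypothetical zero-argument factory, e.g. lambda: Mesmer().
    """
    model = load_model()
    try:
        return model.predict(image)
    finally:
        # Drop this function's reference to the model and run a full garbage
        # collection so reference cycles holding the model are freed now
        del model
        gc.collect()
```

With Keras models you might also call tf.keras.backend.clear_session() afterwards to reset Keras' global state; whether that helps with this particular warning is untested.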


It's hard for me to troubleshoot every single error and warning, but I hope this is useful. Thanks!

jacob-rosenthal commented 2 years ago

I confirmed with the Van Valen lab: it's OK to ignore the warning about loading a model saved with a previous version of TensorFlow. See here: https://github.com/vanvalenlab/deepcell-tf/issues/569