mlexchange / mlex_dlsia_segmentation_prototype

Enable custom metrics exporter with DLSIA #7

Open taxe10 opened 9 months ago

taxe10 commented 9 months ago

Currently, we are modifying an existing DLSIA function train_segmentation to write/export the loss and metrics at every epoch while the training process is ongoing.

We (@xiaoyachong @zhuowenzhao @taxe10) discussed how to integrate DVC without modifying the DLSIA function on our end. Here is a summary of our initial thoughts:

In MLExchange:

# train.py
from dvclive import Live
from dlsia.core import train_segmentation

# Creates custom exporter class compatible with DVC
class DvcExporter:
    def __init__(self, live):
        self.live = live

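    def export_train_metrics(self, metrics):
        # Assumes `metrics` is a dict of {name: value} for the current epoch
        for name, value in metrics.items():
            self.live.log_metric(name, value)
        self.live.next_step()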

# Init DVC Live
with Live() as live:

    # Define the parameters to be tracked
    live.log_param("epochs", NUM_EPOCHS)

    # Init custom exporter
    dvc_exporter = DvcExporter(live)

    # Calls DLSIA function with custom exporter
    train_segmentation(net, trainloader, validationloader, NUM_EPOCHS,
                       criterion, optimizer, device,
                       savepath=None, saveevery=None,
                       scheduler=None, show=0,
                       use_amp=False, clip_value=None, custom_exporter=dvc_exporter)

    live.log_artifact(path, type="model", name=name)

This would require a PR in DLSIA that would look as follows:

def train_segmentation(net, trainloader, validationloader, NUM_EPOCHS,
                       criterion, optimizer, device,
                       savepath=None, saveevery=None,
                       scheduler=None, show=0,
                       use_amp=False, clip_value=None, custom_exporter=None):
.........
for epoch in range(NUM_EPOCHS):
    ....
    ## After validation maybe around https://github.com/phzwart/dlsia/blob/f3f50a78faeb99aca4b9725ffa63c7b95c0613df/dlsia/core/train_scripts.py#L228
    if custom_exporter is not None:
        custom_exporter.export_train_metrics(metrics)

This is a very rough draft, mostly to gather feedback. Any thoughts and/or comments? @Wiebke @dylanmcreynolds @TibbersHao
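
To make the hook idea concrete, here is a minimal sketch that runs without DLSIA or DVC installed. `train_with_exporter`, `RecordingExporter`, and `fake_epoch` are hypothetical stand-ins invented for this illustration; they are not DLSIA's `train_segmentation` and not part of DVCLive. The only point is that the training function calls `export_train_metrics` once per epoch and skips the call when no exporter is passed.

```python
from typing import Callable, Dict, List, Optional


class RecordingExporter:
    """Hypothetical exporter that simply keeps each epoch's metrics in a list."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def export_train_metrics(self, metrics: Dict[str, float]) -> None:
        self.history.append(dict(metrics))


def train_with_exporter(
    num_epochs: int,
    run_epoch: Callable[[int], Dict[str, float]],
    custom_exporter: Optional[RecordingExporter] = None,
) -> None:
    """Stand-in training loop (NOT DLSIA's function) with an optional hook."""
    for epoch in range(num_epochs):
        metrics = run_epoch(epoch)
        # The hook is skipped entirely when no exporter is supplied.
        if custom_exporter is not None:
            custom_exporter.export_train_metrics(metrics)


def fake_epoch(epoch: int) -> Dict[str, float]:
    # Placeholder numbers; a real epoch would train and validate a network.
    return {"train_loss": 1.0 / (epoch + 1)}


exporter = RecordingExporter()
train_with_exporter(3, fake_epoch, custom_exporter=exporter)
train_with_exporter(3, fake_epoch)  # still runs with no exporter
```

An exporter that wraps a DVC Live object would expose the same single method, so the training function never needs to import DVC.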

zhuowenzhao commented 9 months ago

Another thought of mine is to implement a TrainModel class, which might be very doable within the timeframe of the Diamond trip, since I assume the train_segmentation() function is not called elsewhere in DLSIA. I am putting it here for (future) record with Peter.

We could implement a method (possibly internal) called train_epoch() in DLSIA that runs one epoch per call, so it can be used in an outside loop:

class TrainModel:
    def __init__(self, **args):
        # initialize metrics if needed
        self.metrics...

    def train_epoch(self, ...):
        ...

    def train_segmentation(self, ...):
        ...
        for epoch in range(NUM_EPOCHS):
            ....
            self.train_epoch()
            ....

Then for DVC Live, we can use the same code that the DVC Live documentation suggests:

train_model = TrainModel()

# Init DVC Live; this code stays the same as in the DVC documentation
with Live() as live:

    live.log_param("epochs", NUM_EPOCHS)

    for epoch in range(NUM_EPOCHS):
        train_model.train_epoch()
        metrics = train_model.metrics

        for metric_name, value in metrics.items():
            live.log_metric(metric_name, value)

        live.next_step()

    live.log_artifact(path, type="model", name=name)
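
As a rough, self-contained illustration of the per-epoch design, the sketch below uses a toy model instead of a network. `TinyTrainModel` is hypothetical and fits `y = w * x` by plain gradient descent; it only demonstrates that `train_epoch()` advances training by one epoch and refreshes `metrics`, so the caller owns the loop and can log wherever it likes.

```python
from typing import Dict, List, Tuple


class TinyTrainModel:
    """Hypothetical trainer: fits y = w * x, one gradient step per epoch."""

    def __init__(self, data: List[Tuple[float, float]], lr: float = 0.05) -> None:
        self.data = data
        self.lr = lr
        self.w = 0.0
        self.metrics: Dict[str, float] = {}

    def train_epoch(self) -> None:
        # One full pass: gradient of the mean squared error with respect to w.
        n = len(self.data)
        grad = sum(2 * (self.w * x - y) * x for x, y in self.data) / n
        self.w -= self.lr * grad
        loss = sum((self.w * x - y) ** 2 for x, y in self.data) / n
        self.metrics = {"train_loss": loss}

    def train(self, num_epochs: int) -> None:
        # Convenience wrapper for callers that do not need per-epoch access.
        for _ in range(num_epochs):
            self.train_epoch()


model = TinyTrainModel(data=[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])
losses: List[float] = []
for epoch in range(20):
    model.train_epoch()
    # A logger call (for example live.log_metric) would go here.
    losses.append(model.metrics["train_loss"])
```

Keeping the wrapper method alongside `train_epoch()` means existing callers that want a single call are not forced to write their own loop.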

dylanmcreynolds commented 9 months ago

How would this Live() communicate updated loss to the user? Writing to a file that the segmentation app polls? Writing to a web socket?

phzwart commented 9 months ago

Hi,

I like Zhuowen's suggestion a lot.

P

--

Peter Zwart Staff Scientist, Molecular Biophysics and Integrated Bioimaging Berkeley Synchrotron Infrared Structural Biology Biosciences Lead, Center for Advanced Mathematics for Energy Research Applications Lawrence Berkeley National Laboratories 1 Cyclotron Road, Berkeley, CA-94703, USA Cell: 510 289 9246

phzwart commented 9 months ago

Especially given the fact that we need to work on Distributed Data Parallel options in the future, and configuring that could be part of this.

xiaoyachong commented 9 months ago

DVCLive supports many existing ML frameworks (e.g., fast.ai, PyTorch, Keras, Hugging Face).

Tanny's idea is similar to how DVCLive supports Keras (https://dvc.org/doc/dvclive/ml-frameworks/keras), while Zhuowen's idea is similar to the Hugging Face method (https://dvc.org/doc/dvclive/ml-frameworks/huggingface). I think both will be fine.

xiaoyachong commented 9 months ago

> How would this Live() communicate updated loss to the user? Writing to a file that the segmentation app polls? Writing to a web socket?

Live() automatically generates a local file called 'report.html' during training and updates it once after each epoch. The report looks like this:

[Image: Screenshot 2024-02-23 at 3.45.31 PM]
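
If the segmentation app went with the "file that the app polls" option, the reader side could be sketched roughly as below. This assumes the latest metrics are available as a JSON file; DVCLive writes a summary to `dvclive/metrics.json` by default as far as we know, but treat that path as an assumption and check it against your DVCLive version. `read_latest_metrics` and `poll_metrics` are hypothetical helper names.

```python
import json
import time
from pathlib import Path
from typing import Iterator, Optional


def read_latest_metrics(path: Path) -> Optional[dict]:
    """Return the parsed JSON at `path`, or None if it is missing or half-written."""
    try:
        return json.loads(path.read_text())
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def poll_metrics(path: Path, interval_s: float = 1.0, max_polls: int = 10) -> Iterator[dict]:
    """Yield the metrics each time the file content changes."""
    last = None
    for _ in range(max_polls):
        current = read_latest_metrics(path)
        if current is not None and current != last:
            last = current
            yield current
        time.sleep(interval_s)
```

A caller would iterate over `poll_metrics(Path("dvclive/metrics.json"))` and push each new dictionary to the UI. Tolerating a missing or partially written file matters because the trainer may be rewriting it while the app reads.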

xiaoyachong commented 9 months ago

@phzwart Hi Peter, based on Zhuowen's idea, I created a new class called Trainer() and tested DVC with it in a Jupyter notebook (https://drive.google.com/file/d/1Hy7qKViilWDV_fHk0F1NbGkw1TM7vnBI/view?usp=sharing).

Could you take a look at it and let us know whether we could add it to DLSIA?

phzwart commented 9 months ago

TypeError                                 Traceback (most recent call last)
Cell In[9], line 5
      2 shift = 2
      3 data_transform = transforms.ToTensor()
----> 5 dataset = TiledDataset(
      6     recon_uri=RECON_TILED_URI,
      7     mask_uri=MASK_TILED_URI,
      8     #seg_uri=SEG_TILED_URI,
      9     mask_idx=mask_idx,
     10     recon_api_key=RECON_TILED_API_KEY,
     11     mask_api_key=MASK_TILED_API_KEY,
     12     #seg_api_key=SEG_TILED_API_KEY,
     13     shift = shift,
     14     transform=data_transform
     15     )

Cell In[6], line 28, in TiledDataset.__init__(self, recon_uri, mask_uri, mask_idx, recon_api_key, mask_api_key, shift, transform)
      2 def __init__(
      3     self,
      4     recon_uri,
   (...)
     11     shift=0,
     12     transform=None):
     13     '''
     14     Args:
     15         recon_uri: str, Tiled URI of the reconstruction
   (...)
     26         ml_data: tuple, (recon_tensor, mask_tensor)
     27     '''
---> 28     self.recon_client = from_uri(recon_uri, api_key=recon_api_key)
     29     self.mask_client = from_uri(mask_uri, api_key=mask_api_key)
     30     #self.seg_client = from_uri(seg_uri, api_key=seg_api_key)

File ~/anaconda3/envs/dlsia-new/lib/python3.9/site-packages/tiled/client/constructors.py:64, in from_uri(uri, structure_clients, cache, username, auth_provider, api_key, verify, prompt_for_reauthentication, headers, timeout, include_data_sources)
     12 def from_uri(
     13     uri,
     14     structure_clients="numpy",
   (...)
     24     include_data_sources=False,
     25 ):
     26     """
     27     Connect to a Node on a local or remote server.
     28
   (...)
     62         Default False. If True, fetch information about underlying data sources.
     63     """
---> 64     context, node_path_parts = Context.from_any_uri(
     65         uri,
     66         api_key=api_key,
     67         cache=cache,
     68         headers=headers,
     69         timeout=timeout,
     70         verify=verify,
     71     )
     72     return from_context(
     73         context,
     74         structure_clients=structure_clients,
   (...)
     79         include_data_sources=include_data_sources,
     80     )

File ~/anaconda3/envs/dlsia-new/lib/python3.9/site-packages/tiled/client/context.py:247, in Context.from_any_uri(cls, uri, headers, api_key, cache, timeout, verify, token_cache, app)
    227 @classmethod
    228 def from_any_uri(
    229     cls,
   (...)
    238     app=None,
    239 ):
    240     """
    241     Accept a URI to a specific node.
    242
   (...)
    245     ["a", "b", "c"].
    246     """
--> 247     uri = httpx.URL(uri)
    248     node_path_parts = []
    249     if "/metadata" in uri.path:

File ~/anaconda3/envs/dlsia-new/lib/python3.9/site-packages/httpx/_urls.py:119, in URL.__init__(self, url, **kwargs)
    117     self._uri_reference = url._uri_reference.copy_with(**kwargs)
    118 else:
--> 119     raise TypeError(
    120         "Invalid type for url. Expected str or httpx.URL,"
    121         f" got {type(url)}: {url!r}"
    122     )

TypeError: Invalid type for url. Expected str or httpx.URL, got <class 'NoneType'>: None
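
The last frame shows that `None` reached `httpx.URL`, which suggests `RECON_TILED_URI` was never set in the environment where the notebook ran. The traceback does not show where that value is loaded, so the sketch below assumes it comes from an environment variable. `require_setting` is a hypothetical helper that fails early with a readable message instead of deep inside the Tiled client.

```python
import os
from typing import Mapping


def require_setting(name: str, env: Mapping[str, str] = os.environ) -> str:
    """Return the value of `name`, or raise a clear error if it is unset or empty."""
    value = env.get(name)
    if not value:
        raise RuntimeError(
            f"{name} is not set; define it before creating the Tiled clients."
        )
    return value


# Illustration with an explicit mapping; the URL below is a made-up placeholder.
settings = {"RECON_TILED_URI": "http://localhost:8000/api"}
recon_uri = require_setting("RECON_TILED_URI", settings)
```

Calling `require_setting("RECON_TILED_URI")` with no second argument reads from the real process environment.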

phzwart commented 9 months ago

Tiled doesn't seem to work here; I think something is off.

Also, could you modify your class method train_segmentation? I'm not sure exactly what is happening, but I am not a big fan of this:

      dvclive_folder = "result_trainer"
      with Live(dvclive_folder, report="html") as live:

It's probably better to instantiate live outside this class and instead pass it into the function as an argument (dvclive_object=live), with a default value of None.

Then, in the train_segmentation method, you do:

      if dvclive_object is not None:
          dvclive_object.log_metric( ... )
          ....
          dvclive_object.next_step()

Make sure the thing runs even without the dvclive object.

P
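
A minimal sketch of this suggestion, runnable without DVC installed, is shown below. `train_loop` is a hypothetical stand-in for the Trainer method, and `FakeLive` is a made-up test double that mimics only the two `Live` methods used here (`log_metric` and `next_step`). The loss values are placeholders.

```python
from typing import List, Optional, Tuple


class FakeLive:
    """Test double exposing the two Live methods the loop relies on."""

    def __init__(self) -> None:
        self.step = 0
        self.logged: List[Tuple[str, float, int]] = []

    def log_metric(self, name: str, val: float) -> None:
        self.logged.append((name, val, self.step))

    def next_step(self) -> None:
        self.step += 1


def train_loop(num_epochs: int, dvclive_object: Optional[FakeLive] = None) -> float:
    """Hypothetical training loop; returns the best (lowest) loss seen."""
    best = float("inf")
    for epoch in range(num_epochs):
        loss = 1.0 / (epoch + 1)  # placeholder for a real training step
        best = min(best, loss)
        # Logging is optional: training behaves the same without a logger.
        if dvclive_object is not None:
            dvclive_object.log_metric("train_loss", loss)
            dvclive_object.next_step()
    return best


best_without_logger = train_loop(4)
fake_live = FakeLive()
best_with_logger = train_loop(4, dvclive_object=fake_live)
```

With DVC installed, the caller would create the object outside, as in `with Live(dvclive_folder, report="html") as live:`, and pass `dvclive_object=live`, so the trainer itself never needs to import dvclive.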

xiaoyachong commented 9 months ago

@phzwart Hi Peter, I modified Trainer() accordingly and tested it. It works well. For the definition of Trainer(), you can refer to the latest version: https://github.com/mlexchange/mlex_dlsia_segmentation_prototype/blob/xchong-dvc/src/seg_utils.py