riponazad / echotracker

Point Tracking in Echocardiography.
MIT License

Is there training code? #2

Open kimsekeun opened 2 months ago

kimsekeun commented 2 months ago

Could you guys provide training code for reproducibility?

Thanks.

riponazad commented 2 months ago

Hi @kimsekeun,

The training code is already there: it is the `train` method (`def train`) of the `EchoTracker` class. Since we couldn't publish our dataset, we didn't explicitly show how to call it. The following lines can be adapted to load your own data and train or fine-tune EchoTracker.

```python
import configs  # assumption: the repo's config module defining WEIGHTS_PATH, LOG_PATH, SAVE_PATH
from model.net import EchoTracker

B = 1  # batch size (no trailing comma, which would make B a tuple)
S = 24  # sequence length (frames per clip)
debug = False  # forwarded to your own load_datasets()

# You need your own load_datasets() before executing the following line
dataloaders, dataset_size = load_datasets(B=B, S=S, debug=debug, use_aug=True)

model = EchoTracker(device_ids=[0])
# model.load(path=configs.WEIGHTS_PATH.echopips, eval=False)  # uncomment to fine-tune pretrained weights instead of training from scratch
model.load(eval=False)  # load the model in training mode
model.train(
    dataloaders=dataloaders,
    dataset_size=dataset_size,
    log_dir=configs.LOG_PATH.echotracker,
    ckpt_path=configs.SAVE_PATH.echotracker,
    epochs=100,
)
```
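Since `load_datasets()` is left to you, here is a minimal, framework-agnostic sketch of the part that cuts each annotated sequence into S-frame clips and splits train/validation by sequence. Everything in it is an assumption rather than the EchoTracker data format: the names `ClipDataset` and `make_splits`, the frame/track layout, the temporal-flip augmentation, and the `"train"`/`"val"` keys are all hypothetical.

```python
import random


class ClipDataset:
    """Hypothetical map-style dataset that cuts sequences into fixed S-frame clips.

    It implements __len__/__getitem__, the protocol torch.utils.data.DataLoader
    uses for map-style datasets.
    """

    def __init__(self, videos, tracks, S, use_aug=False):
        # Assumed layout: videos[i] is a list of frames and tracks[i][t] is the
        # list of (x, y) points annotated at frame t of sequence i.
        self.videos = videos
        self.tracks = tracks
        self.S = S
        self.use_aug = use_aug
        # One entry per non-overlapping window of S frames; shorter tails are dropped.
        self.index = [
            (v, start)
            for v, frames in enumerate(videos)
            for start in range(0, len(frames) - S + 1, S)
        ]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        v, start = self.index[i]
        frames = self.videos[v][start:start + self.S]
        points = self.tracks[v][start:start + self.S]
        if self.use_aug and random.random() < 0.5:
            # Placeholder augmentation: play the clip backwards (frames and tracks together).
            frames, points = frames[::-1], points[::-1]
        return frames, points


def make_splits(videos, tracks, S, use_aug=True, val_fraction=0.2, seed=0):
    """Split by sequence (not by clip) so no recording lands in both splits."""
    ids = list(range(len(videos)))
    random.Random(seed).shuffle(ids)
    n_val = max(1, int(len(ids) * val_fraction))
    split_ids = {"val": ids[:n_val], "train": ids[n_val:]}
    datasets = {
        name: ClipDataset(
            [videos[i] for i in idx],
            [tracks[i] for i in idx],
            S,
            use_aug=use_aug and name == "train",  # augment training clips only
        )
        for name, idx in split_ids.items()
    }
    dataset_size = {name: len(ds) for name, ds in datasets.items()}
    return datasets, dataset_size
```

In practice each frame and track would be a NumPy array or tensor, and each dataset could then be wrapped as `torch.utils.data.DataLoader(ds, batch_size=B, shuffle=(name == "train"))` to build the `dataloaders` dict. Whether `model.train()` expects exactly this dict layout is an assumption to check against the `train` method in `model.net`. Splitting by sequence rather than by clip keeps frames from the same recording out of both splits.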

Please let me know if you need any further information.

Best, Azad