scil-vital / TrackToLearn

Public release of Track-to-Learn: A general framework for tractography with deep reinforcement learning
GNU General Public License v3.0

Fluctuating Reward #5

AJ-30 closed this issue 1 year ago.

AJ-30 commented 2 years ago

Hi @AntoineTheb!

I am facing the following issue when running Experiment 2 (ISMRM dataset) from your paper.

Both the training reward and the validation reward fluctuate back and forth with both the SAC and TD3 algorithms on the ISMRM2015 dataset. The tractogram generated at the end of training consists of broken fragments scattered over the entire brain mask, instead of proper streamlines confined to the white matter. It seems as if the agents are not learning.

[image: anterior raw]

The training reward is around 0.6 million and the validation reward is in the range of 2900-3000.

How would you suggest debugging the pipeline? For example, how can I check whether the MDP is formed properly, whether the algorithm has an issue, or whether I am sending the wrong input? Also, is this a reward maximisation problem or a penalty minimisation problem?

Lastly, I wanted to know what the searcher files (td3_searcher and sac_searcher) in the runners folder are for.

Your input would really help, as I have been trying to train it for a long time.

AntoineTheb commented 2 years ago

Hi! Thank you for your interest in Track-to-Learn.

How do you suggest to debug the pipeline?

The framework indeed has a lot of moving parts, which can make debugging hard. Do you use the code in this repository, or did you implement something else? If you are using the code in this repository, the MDP and the algorithm should be correct. As for the input, I'm sorry I haven't included the data used in the paper. An updated version of the repo should be coming soon, this time with data. In the meantime, did you test on a simpler dataset like the FiberCup? Could you provide the data you used to train the agent that produced the tractogram shown above?

Is this a reward maximisation problem or penalty minimisation?

As formulated by Track-to-Learn, it is a reward maximisation problem. There is no negative penalty at each timestep.
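To make "no negative penalty at each timestep" concrete, here is a minimal sketch of a per-step reward in the spirit of the one described in the Track-to-Learn paper: the alignment of the new step with the closest fODF peak, multiplied by its alignment with the previous step. It is a simplified, hypothetical illustration, not the repository's implementation; the function name `step_reward`, the array shapes, and the clipping of backward turns to zero are assumptions made for this example.

```python
import numpy as np


def step_reward(peaks_at_pos, prev_dir, new_dir):
    """Hypothetical per-step reward for a batch of N streamlines.

    peaks_at_pos: (N, P, 3) unit peak directions at each streamline's current voxel
    prev_dir:     (N, 3) unit tracking direction at step t-1
    new_dir:      (N, 3) unit tracking direction at step t
    Returns an (N,) array of rewards in [0, 1].
    """
    # Alignment with the closest peak; abs() because peaks have no sign
    alignment = np.max(np.abs(np.einsum("npk,nk->np", peaks_at_pos, new_dir)), axis=1)
    # Smoothness: cosine between consecutive steps
    smoothness = np.einsum("nk,nk->n", prev_dir, new_dir)
    # Assumption: a streamline that turns back would be terminated,
    # so negative smoothness is clipped to zero instead of penalised
    return alignment * np.clip(smoothness, 0.0, None)
```

Since every term is non-negative, the agent can only increase its return by producing more well-aligned, smooth steps, which is what "reward maximisation" means here. If such rewards are summed over many streamlines tracked in parallel (an assumption in this sketch), episode totals in the hundreds of thousands are not by themselves a sign of a problem; the trend across episodes is what matters.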

Lastly, I wanted to know what the searcher files(td3_searcher and sac_searcher) are for in runners folder?

These were used to run a hyperparameter search for both algorithms when training the agents used in the paper. They are artifacts that I should have removed before making the code public, sorry.

AJ-30 commented 2 years ago

Yes, I am running sac_train.py from this repository, using the hyperparameters mentioned in the paper. I have shared the input files I used in this folder: ismrm_data. I am using the following files as input to sac_train:

Thank you for clarifying the rest. I also have doubts about the input; however, I lack the expertise to check its accuracy myself.

AntoineTheb commented 2 years ago

Sorry for the late reply. Looking at the data you provided, it seems some of the peaks are flipped:

[image]

This may explain the poor performance of the agents. I am currently recomputing the fODFs and the associated metrics and will generate a new hdf5 for training when this is done. I will then try to train agents and report back.
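For anyone wanting to check their own data for this kind of problem, here is a minimal sketch of one approach, loosely based on the idea behind the fiber coherence index: in correctly oriented data, following a voxel's peak should lead into neighbouring voxels whose peaks point the same way, whereas an axis flip makes peaks point across bundles. The sketch assumes one unit peak per voxel, expressed in voxel axes with isotropic voxels; the function names are hypothetical, and this is not what scilpy or Track-to-Learn does internally.

```python
import numpy as np


def coherence(peaks, mask):
    """Sum of |cos| between each voxel's peak and the peaks of the voxels it points to.

    peaks: (X, Y, Z, 3) array with one unit peak direction per voxel (voxel axes)
    mask:  (X, Y, Z) boolean array of voxels to consider
    """
    shape = np.array(mask.shape)
    score = 0.0
    for idx in zip(*np.nonzero(mask)):
        v = peaks[idx]
        # Peaks have no sign, so follow the direction both ways
        for sign in (1.0, -1.0):
            nb = np.round(np.array(idx) + sign * v).astype(int)
            if np.all(nb >= 0) and np.all(nb < shape) and mask[tuple(nb)]:
                score += abs(float(np.dot(v, peaks[tuple(nb)])))
    return score


# Flipping all three axes is equivalent to no flip for sign-less peaks,
# so comparing "no flip" with a single-axis flip covers every case.
CANDIDATE_FLIPS = [(1, 1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, -1)]


def best_flip(peaks, mask):
    """Return the axis flip giving the most coherent peak field, plus all scores."""
    scores = {f: coherence(peaks * np.array(f), mask) for f in CANDIDATE_FLIPS}
    return max(scores, key=scores.get), scores
```

If `best_flip` returns anything other than `(1, 1, 1)` on real data, the peaks (or the b-vectors they were computed from) are likely flipped along that axis.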

AntoineTheb commented 2 years ago

Hi @AJ-30! I have re-processed your data, recomputed the dataset (HDF5 file) and partly trained a TD3 agent on the resulting dataset.

Here is the data you provided, with recomputed fODF and DTI metrics: https://drive.google.com/file/d/1yjX8agUc3rjwxbDdXTBRCcDACrgkPPEe/view?usp=sharing

Many of the metrics in the archive are extraneous and not used by Track-to-Learn.

Here is the list of commands that I have run to process your data, create the dataset and train the agent:

```shell
# commands.sh
# These are the commands I used to preprocess your data and train a TD3 agent.
# If you unzip the data archive I provided into the root of your TrackToLearn clone, you should be able to run the same commands.

# Preprocessing scripts
# You don't have to process the diffusion files again, as they are included in the archive above. However,
# these commands can help if you want to process another dataset.
# They were run using scilpy: https://github.com/scilus/scilpy
# The following lines assume scilpy is installed and available in your shell.

scil_resample_volume.py ankita_ismrm2015/raw_data/wm.nii.gz ankita_ismrm2015/raw_data/wm_2mm_iso.nii.gz --voxel_size 2 -v

scil_compute_ssst_frf.py ankita_ismrm2015/raw_data/Diffusion.nii.gz ankita_ismrm2015/raw_data/Diffusion.bvals ankita_ismrm2015/raw_data/Diffusion.bvecs ankita_ismrm2015/raw_data/frf.txt --mask_wm ankita_ismrm2015/raw_data/wm_2mm_iso.nii.gz -f -v

scil_compute_ssst_fodf.py ankita_ismrm2015/raw_data/Diffusion.nii.gz ankita_ismrm2015/raw_data/Diffusion.bvals ankita_ismrm2015/raw_data/Diffusion.bvecs ankita_ismrm2015/raw_data/frf.txt ankita_ismrm2015/raw_data/fodf.nii.gz --processes 8 --sh_order 6 -f

cd ankita_ismrm2015/raw_data

scil_compute_dti_metrics.py Diffusion.nii.gz Diffusion.bvals Diffusion.bvecs

scil_compute_fodf_metrics.py fodf.nii.gz -f

# Go back to the repository root, since the paths below are relative to it
cd ../..

# Create dataset

python TrackToLearn/datasets/create_dataset.py ankita_ismrm2015/raw_data/fodf.nii.gz ankita_ismrm2015/raw_data/wm_2mm_iso.nii.gz ankita_ismrm2015/raw_data/peaks.nii.gz ismrmv2 ismrmv2 ankita_ismrm2015/ --wm ankita_ismrm2015/raw_data/wm_2mm_iso.nii.gz --gm ankita_ismrm2015/raw_data/gm_2mm_iso.nii.gz --csf ankita_ismrm2015/raw_data/csf_2mm_iso.nii.gz --save_signal

# Train the agent
./scripts/td3_experiment_ankita.sh
```

Here is the training script mentioned in the list of commands above: https://drive.google.com/file/d/139RVv2ci6tnm8UpKfbTqNUSWTmto4RBH/view?usp=sharing

Here are the last lines printed by the training script before I stopped it with Ctrl-C:

```text
Total T: 2772 Episode Num: 34 Episode T: 101 Reward: 112519.612
Total T: 2876 Episode Num: 35 Episode T: 104 Reward: 117458.456
Total T: 2977 Episode Num: 36 Episode T: 101 Reward: 120025.694
Total T: 3084 Episode Num: 37 Episode T: 107 Reward: 119371.826
100%|█████████████████████████████████████████████████████████████████| 12/12 [01:19<00:00,  6.63s/it]
---------------------------------------------------
./experiments/td3_experiment_ankita/2022-03-30-11_36_01
Episode 37       avg length: 44.376638162099376          total reward: 3840504.9480276373
---------------------------------------------------
Total T: 3186 Episode Num: 38 Episode T: 102 Reward: 120184.351
Total T: 3291 Episode Num: 39 Episode T: 105 Reward: 120423.212
Total T: 3392 Episode Num: 40 Episode T: 101 Reward: 121305.200
Total T: 3490 Episode Num: 41 Episode T: 98 Reward: 121052.631
Total T: 3592 Episode Num: 42 Episode T: 102 Reward: 124328.222
100%|█████████████████████████████████████████████████████████████████| 12/12 [01:20<00:00,  6.70s/it]
---------------------------------------------------
./experiments/td3_experiment_ankita/2022-03-30-11_36_01
Episode 42       avg length: 47.14045595475463   total reward: 4118301.448187321
---------------------------------------------------
Total T: 3689 Episode Num: 43 Episode T: 97 Reward: 122159.842
Total T: 3791 Episode Num: 44 Episode T: 102 Reward: 120522.066
Total T: 3886 Episode Num: 45 Episode T: 95 Reward: 119326.281
Total T: 3984 Episode Num: 46 Episode T: 98 Reward: 118268.497
Total T: 4079 Episode Num: 47 Episode T: 95 Reward: 118962.493
100%|█████████████████████████████████████████████████████████████████| 12/12 [01:21<00:00,  6.78s/it]
---------------------------------------------------
./experiments/td3_experiment_ankita/2022-03-30-11_36_01
Episode 47       avg length: 48.5220920269056    total reward: 4285954.524560813
---------------------------------------------------
Total T: 4175 Episode Num: 48 Episode T: 96 Reward: 119461.012
Total T: 4268 Episode Num: 49 Episode T: 93 Reward: 116760.334
Total T: 4364 Episode Num: 50 Episode T: 96 Reward: 115460.652
Total T: 4459 Episode Num: 51 Episode T: 95 Reward: 115899.475
Total T: 4556 Episode Num: 52 Episode T: 97 Reward: 118272.552
100%|█████████████████████████████████████████████████████████████████| 12/12 [01:19<00:00,  6.64s/it]
---------------------------------------------------
./experiments/td3_experiment_ankita/2022-03-30-11_36_01
Episode 52       avg length: 48.794066643984166          total reward: 4321356.369898005
---------------------------------------------------
Total T: 4651 Episode Num: 53 Episode T: 95 Reward: 117409.265
Total T: 4753 Episode Num: 54 Episode T: 102 Reward: 115699.706
Total T: 4849 Episode Num: 55 Episode T: 96 Reward: 116005.970
Total T: 4943 Episode Num: 56 Episode T: 94 Reward: 115810.787
Total T: 5038 Episode Num: 57 Episode T: 95 Reward: 115890.782
100%|█████████████████████████████████████████████████████████████████| 12/12 [01:19<00:00,  6.64s/it]
---------------------------------------------------
./experiments/td3_experiment_ankita/2022-03-30-11_36_01
Episode 57       avg length: 48.94373431243459   total reward: 4307124.205632662
---------------------------------------------------
```

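As you can see, the reward no longer fluctuates: the training reward per episode stays stable (between about 112,000 and 124,000), and the validation reward rises steadily from episode 37 to episode 52 before levelling off at around 4.3 million. Here is a screenshot of the tractogram generated at episode 57: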

[image]

While it is far from perfect, it was starting to look like a brain before I stopped the training. I expect that letting the training run for more episodes would result in a nicer tractogram.
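One simple way to tell a fluctuating reward from a steadily improving one is to extract the validation summaries from the training output and look at the change between consecutive evaluations. The sketch below assumes the summary line format shown in the log above (`Episode N  avg length: ...  total reward: ...`); the regular expression and the function names are hypothetical helpers, not part of Track-to-Learn.

```python
import re

# Matches the validation summary lines in the training output, e.g.
# "Episode 37       avg length: 44.37...          total reward: 3840504.94..."
# Training lines ("Total T: ... Episode Num: ...") do not match because
# "Episode" is followed by "Num:" rather than a number.
SUMMARY = re.compile(
    r"Episode\s+(\d+)\s+avg length:\s+([\d.]+)\s+total reward:\s+([\d.]+)"
)


def parse_validation(log_text):
    """Return a list of (episode, avg_length, total_reward) tuples."""
    return [
        (int(m.group(1)), float(m.group(2)), float(m.group(3)))
        for m in SUMMARY.finditer(log_text)
    ]


def relative_changes(values):
    """Relative change between each pair of consecutive values."""
    return [(b - a) / a for a, b in zip(values, values[1:])]
```

On the five summaries in the log above, the relative change in validation reward drops from about +7% (episode 37 to 42) to about -0.3% (episode 52 to 57), consistent with a reward that rises and then levels off. A run that is not learning would instead show large changes of alternating sign.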

Let me know if you need anything else!

AJ-30 commented 2 years ago

Hi @AntoineTheb!

Thank you for getting back to me on this issue. I also trained the SAC agent on the dataset you processed, and the tractogram I get after training is much better than before. However, I have yet to run the test script. I can't wait to preprocess other datasets using the scilpy pipeline.

[image: Screenshot 2022-04-01 at 11 06 07 AM]

Thank you for your support!