Unofficial training code for SegGPT.
From left to right: input image, masked label, ground truth, raw model prediction, discretized model prediction.
This code was developed with Python 3.9.
Install the required packages by running:
pip install -r requirements.txt
Alternatively, create a new conda environment and install the required packages by running:
conda env create -f env.yml
Set up your dataset directory as follows:
<root_dataset_path>
├── images
│   ├── image1.tif
│   ├── image2.tif
│   ...
└── labels
    ├── image1.tif
    ├── image2.tif
    ...
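The layout above can be checked before training with a short script. This is a minimal sketch, not the logic in data.py: the function name `pair_images_with_labels` is hypothetical, and it assumes each label shares its file name with the corresponding image, as the tree suggests.

```python
from pathlib import Path


def pair_images_with_labels(root):
    """Return sorted (image_path, label_path) pairs matched by file name.

    Assumption: a label has the same file name as its image.
    """
    root = Path(root)
    image_dir = root / "images"
    label_dir = root / "labels"
    for directory in (image_dir, label_dir):
        if not directory.is_dir():
            raise FileNotFoundError(f"Missing directory: {directory}")

    pairs, missing = [], []
    for image_path in sorted(p for p in image_dir.iterdir() if p.is_file()):
        label_path = label_dir / image_path.name
        if label_path.is_file():
            pairs.append((image_path, label_path))
        else:
            missing.append(image_path.name)

    # Fail loudly so that unlabeled images are noticed before training starts.
    if missing:
        raise FileNotFoundError(f"No label found for: {', '.join(missing)}")
    return pairs
```

Calling `pair_images_with_labels("<root_dataset_path>")` either returns the matched pairs or raises an error that names the images without labels.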
Note: you can modify data.py to support your needs. The images and labels can be in a format other than .tif, as long as they can be loaded using the PIL library.
Create a .json config file. You can use the provided configs/base.json as a template. Then, run:
python train.py --config <path_to_json_config>
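For readers unfamiliar with this pattern, here is a rough sketch of how a `--config` flag pointing to a JSON file is commonly handled. It is an illustration only and may not match what train.py does. The function names are hypothetical, and no config keys are assumed; see configs/base.json for the real ones.

```python
import argparse
import json
from pathlib import Path


def parse_args(argv=None):
    """Parse a single --config argument (illustrative, not train.py's parser)."""
    parser = argparse.ArgumentParser(description="Illustrative config loader")
    parser.add_argument("--config", required=True, type=Path,
                        help="Path to a JSON config file")
    return parser.parse_args(argv)


def load_config(path):
    """Read a JSON file and return its contents as a dict."""
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(f"Expected a JSON object at the top level of {path}")
    return config
```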
Training uses the DDP (DistributedDataParallel) strategy and runs on all available GPUs by default. To restrict which GPUs are used, set the CUDA_VISIBLE_DEVICES environment variable.
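As an example of restricting the GPUs, the sketch below launches training from Python with CUDA_VISIBLE_DEVICES set for the child process only. The helper names `env_for_gpus` and `launch_training` are hypothetical and not part of this repository.

```python
import os
import subprocess
import sys


def env_for_gpus(gpu_ids):
    """Copy the current environment and expose only the given GPU indices."""
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def launch_training(config_path, gpu_ids):
    """Run train.py in a subprocess that sees only the selected GPUs."""
    return subprocess.run(
        [sys.executable, "train.py", "--config", str(config_path)],
        env=env_for_gpus(gpu_ids),
        check=True,  # raise CalledProcessError if training exits with an error
    )
```

For instance, `launch_training("configs/base.json", [0, 2])` would run the training on GPUs 0 and 2 only. The variable has to be set before CUDA is initialized, which is why it is passed to a fresh process here.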
You can also launch tensorboard to monitor the training progress:
tensorboard --logdir logs
In the paper, the authors mention using a learnable tensor for in-context tuning. You can find my implementation of this in model.py.
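The general idea is to freeze the pretrained weights and optimize only a prompt tensor. The toy sketch below illustrates that idea in PyTorch and is not the code in model.py. The `LearnablePrompt` class, the small convolution standing in for the pretrained network, the tensor shapes, and the way the prompt is added to the query are all assumptions made for illustration.

```python
import torch
from torch import nn


class LearnablePrompt(nn.Module):
    """A single trainable tensor used in place of a fixed prompt (illustrative)."""

    def __init__(self, channels=3, height=32, width=32):
        super().__init__()
        # Small random initialization; the real initialization may differ.
        self.tensor = nn.Parameter(torch.randn(1, channels, height, width) * 0.02)


torch.manual_seed(0)

# Stand-in for the pretrained network; its weights are frozen.
backbone = nn.Conv2d(3, 3, kernel_size=3, padding=1)
for param in backbone.parameters():
    param.requires_grad_(False)

prompt = LearnablePrompt()
# Only the prompt tensor is handed to the optimizer.
optimizer = torch.optim.AdamW(prompt.parameters(), lr=1e-3)

# One tuning step on random data.
query = torch.randn(2, 3, 32, 32)
target = torch.randn(2, 3, 32, 32)

optimizer.zero_grad()
# Toy combination: the prompt is broadcast over the batch and added to the query.
prediction = backbone(query + prompt.tensor)
loss = nn.functional.smooth_l1_loss(prediction, target)
loss.backward()   # gradients reach the prompt, not the frozen backbone
optimizer.step()
```

Because the backbone parameters do not require gradients, the optimizer step changes only the prompt tensor.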