htcr / sam_road

Segment Anything Model for large-scale, vectorized road network extraction from aerial imagery. CVPRW 2024
https://arxiv.org/pdf/2403.16051.pdf
MIT License
148 stars, 18 forks

How can I run the model on my own dataset? #33

Open kurdt23 opened 1 month ago

kurdt23 commented 1 month ago

I have my own image dataset, and I don't understand where/how I can create keypoint masks for the images. Can you please explain how to prepare my own dataset and then run the model on it? I'll be waiting for your answer.

htcr commented 1 month ago

Hi, does your dataset contain ground-truth graphs? You can process that graph into the same format used in this codebase, and the code will then generate keypoint masks for you.

htcr commented 1 month ago

For example, you can refer to generate_labels.py under the cityscale dir. Starting at line 82, it reads the GT graph, which reveals the GT graph format: basically an adjacency list of vertices.
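To illustrate the format described above, here is a minimal sketch of such an adjacency list and the keypoint mask derived from its vertices. The tiny graph, the (row, col) axis convention, and the file name are assumptions for illustration; check generate_labels.py for the exact convention used by the CityScale pickles.

```python
import pickle

import numpy as np

# Hypothetical example of the adjacency-list format: each key is a
# vertex, each value is the list of its neighbor vertices.
gt_graph = {
    (10, 10): [(10, 50)],
    (10, 50): [(10, 10), (60, 50)],
    (60, 50): [(10, 50)],
}

# Rasterize the vertices into a keypoint mask (roughly what
# generate_labels.py derives from the graph).
mask = np.zeros((128, 128), dtype=np.uint8)
for (r, c) in gt_graph:
    mask[r, c] = 255

# Save the graph in the same pickle format the codebase expects.
with open("my_region_graph_gt.pickle", "wb") as f:
    pickle.dump(gt_graph, f)
```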

kurdt23 commented 1 month ago

Unfortunately, my dataset does not contain ground-truth graphs. I only have 512x512 satellite images of roads at different zoom levels, in TIFF format. I'm attaching some examples of these images. Can you please tell me how I can obtain ground-truth graphs for my images?

htcr commented 1 month ago

OK, if you are trying to re-train/fine-tune on your dataset, you do need ground-truth graphs. Otherwise, if you just want to run our checkpoint on some images, you can simply follow the inference instructions in the README.

kurdt23 commented 1 month ago

Thank you! I have checkpoints for the CityScale and SpaceNet datasets; I'll try some of them on my images.

kurdt23 commented 1 month ago

Hello, I tried to run inference on my own images with the command: python3 inferencer.py --config=config/toponet_vitb_512_ekb.yaml --checkpoint=lightning_logs/vhfsw197/checkpoints/epoch=9-step=25000.ckpt

But then I got an error:

##### Loading Trained CKPT lightning_logs/vhfsw197/checkpoints/epoch=9-step=25000.ckpt #####
Traceback (most recent call last):
  File "/misc/home6/s0181/sam_road/inferencer.py", line 281, in <module>
    for img_id in test_img_indices:
NameError: name 'test_img_indices' is not defined

line 281:

for img_id in test_img_indices:
    print(f'Processing {img_id}')
    # [H, W, C] RGB
    img = read_rgb_img(rgb_pattern.format(img_id))
    start_seconds = time.time()

and so on.

I think the error comes from the dataset-name check. But both cityscale and spacenet need GT graphs, which I don't have for my images.

if config.DATASET == 'cityscale':
    _, _, test_img_indices = cityscale_data_partition()
    rgb_pattern = './cityscale/20cities/region_{}_sat.png'
    gt_graph_pattern = 'cityscale/20cities/region_{}_graph_gt.pickle'
elif config.DATASET == 'spacenet':
    _, _, test_img_indices = spacenet_data_partition()
    rgb_pattern = './spacenet/RGB_1.0_meter/{}__rgb.png'
    gt_graph_pattern = './spacenet/RGB_1.0_meter/{}__gt_graph.p'
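The NameError follows directly from this dispatch: if config.DATASET is neither 'cityscale' nor 'spacenet' (the YAML below has no DATASET key at all), neither branch runs and test_img_indices is never assigned. One way around it would be an extra branch for a custom image folder; the sketch below is an assumption, not part of the repo, and the function name custom_data_partition, the directory './my_images', and the patterns are all hypothetical.

```python
import glob
import os


def custom_data_partition(image_dir, ext="tif"):
    """List image IDs (filename stems) in a folder, sorted by name."""
    paths = sorted(glob.glob(os.path.join(image_dir, f"*.{ext}")))
    return [os.path.splitext(os.path.basename(p))[0] for p in paths]


# A hypothetical extra branch for inferencer.py could then look like:
# elif config.DATASET == 'custom':
#     test_img_indices = custom_data_partition('./my_images')
#     rgb_pattern = './my_images/{}.tif'
#     gt_graph_pattern = None  # no ground truth available
```

With a matching `DATASET: 'custom'` entry added to the YAML config, the loop over test_img_indices would then run on arbitrary images; the GT-loading code further down would still need to tolerate a missing GT graph.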

I based my configuration on your toponet_vitb_1024.yaml.
toponet_vitb_512_ekb.yaml:

SAM_VERSION: 'vit_b'
SAM_CKPT_PATH: 'sam_ckpts/sam_vit_b_01ec64.pth'
PATCH_SIZE: 512
BATCH_SIZE: 16
DATA_WORKER_NUM: 1
TRAIN_EPOCHS: 10
BASE_LR: 0.001
FREEZE_ENCODER: False
ENCODER_LR_FACTOR: 0.1
ENCODER_LORA: False
FOCAL_LOSS: False
USE_SAM_DECODER: False

# TOPONET
# sample per patch
TOPO_SAMPLE_NUM: 256

# Inference
INFER_BATCH_SIZE: 64
SAMPLE_MARGIN: 64
INFER_PATCHES_PER_EDGE: 16

# [0, 255]
ITSC_THRESHOLD: 128
ROAD_THRESHOLD: 128
# pixels
ITSC_NMS_RADIUS: 8
ROAD_NMS_RADIUS: 16
NEIGHBOR_RADIUS: 64
MAX_NEIGHBOR_QUERIES: 16

Please, can you explain what I need to do to solve the problem?

kurdt23 commented 1 week ago

> Hello, I tried to run inference on my own images with the command: python3 inferencer.py --config=config/toponet_vitb_512_ekb.yaml --checkpoint=lightning_logs/vhfsw197/checkpoints/epoch=9-step=25000.ckpt
>
> and so on. Please, can you explain what I need to do to solve the problem?

That problem has already been solved. Do I understand correctly that even to run inferencer.py you still need GTs? Could you please write some instructions for running inferencer.py on arbitrary images with a trained model? I only managed to run it with GT files from the CityScale dataset (I just copied random ones and renamed them to match my pictures); without GT files, inferencer.py doesn't run on arbitrary pictures.

inferencer.py and GT

gt_graph_path = gt_graph_pattern.format(img_id)
gt_graph = pickle.load(open(gt_graph_path, "rb"))
gt_nodes, gt_edges = graph_utils.convert_from_sat2graph_format(gt_graph)
if len(gt_nodes) == 0:
    gt_nodes = np.zeros([0, 2], dtype=np.float32)
...
large_map_sat2graph_format = graph_utils.convert_to_sat2graph_format(pred_nodes, pred_edges)
graph_save_dir = os.path.join(output_dir, 'graph')
if not os.path.exists(graph_save_dir):
    os.makedirs(graph_save_dir)
graph_save_path = os.path.join(graph_save_dir, f'{img_id}.p')
with open(graph_save_path, 'wb') as file:
    pickle.dump(large_map_sat2graph_format, file)
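The GT-loading step quoted above could be made optional so that inference runs without ground truth. The sketch below is not the repo's actual code; the helper name load_gt_graph_optional is hypothetical, and `convert` stands in for graph_utils.convert_from_sat2graph_format.

```python
import os
import pickle

import numpy as np


def load_gt_graph_optional(gt_graph_pattern, img_id, convert=None):
    """Return (gt_nodes, gt_edges); empty when no GT file is available."""
    if gt_graph_pattern is not None:
        path = gt_graph_pattern.format(img_id)
        if os.path.exists(path):
            with open(path, "rb") as f:
                gt_graph = pickle.load(f)
            # In the real code, convert would be
            # graph_utils.convert_from_sat2graph_format.
            if convert is not None:
                return convert(gt_graph)
    # No GT available: empty node array and empty edge list, so the
    # prediction/saving code below can still run unchanged.
    return np.zeros([0, 2], dtype=np.float32), []
```

The metric computation that consumes gt_nodes/gt_edges would then simply see an empty graph (or could be skipped when gt_graph_pattern is None), while the prediction and graph-saving code is unaffected.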