growvv / Pytorch-SEResNet

基于SEResNet的多源数据融合模型
3 stars 0 forks source link

Quick Start

  1. 安装 CUDA

  2. 安装 PyTorch 1.13 or later

  3. 安装依赖

pip install -r requirements.txt

  1. 进行训练

bash scripts/radar.sh

bash scripts/radar_multi.sh

  1. 进行测试

bash scripts/test_radar_multi.sh

数据集

例如,在降水定量估计任务中,

代码结构

训练和预测

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12347 train_multi.py \
    --epochs 3\
    --batch-size 2 \
    --learning-rate 1e-5 \
    --scale 0.5 \
    --validation 1.0 \
    --in_classes 15 \
    --out_classes 4 \
    --dir_img data/radar_npy/factors/ \
    --dir_mask data/radar_npy/ob/ \
    --dir_checkpoint res/41_3/checkpoint/ \
    --save_checkpoint 1 \
    --save_interval 1 \
    --log_dir res/41_3/runs/ \