caiyuanhao1998 / MST-plus-plus

"MST++: Multi-stage Spectral-wise Transformer for Efficient Spectral Reconstruction" (CVPRW 2022) & (Winner of NTIRE 2022 Spectral Recovery Challenge) and a toolbox for spectral reconstruction
https://arxiv.org/abs/2204.07908
MIT License
429 stars, 59 forks

How to get the prediction results #6

Closed: randomNNN closed this issue 2 years ago

randomNNN commented 2 years ago

Thanks for your open-source code. MST++ is an amazing project for HSI reconstruction. However, your code only includes training and testing code, not prediction code. I'm new to HSI reconstruction and don't know how to get prediction results. Could you add your prediction code to the repository?

linjing7 commented 2 years ago

Hi, we have uploaded the prediction code and updated the README in our repo. You can reconstruct your RGB image with the following steps:

(1) Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place it in /MST-plus-plus/predict_code/model_zoo/.

(2) Run the following command to reconstruct your own RGB image.

cd /MST-plus-plus/predict_code/

# reconstruct by MST++
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/  --gpu_id 0

# reconstruct by MST-L
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mst --pretrained_model_path ./model_zoo/mst.pth --outf ./exp/mst/  --gpu_id 0

# reconstruct by MIRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mirnet --pretrained_model_path ./model_zoo/mirnet.pth --outf ./exp/mirnet/  --gpu_id 0

# reconstruct by HINet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hinet --pretrained_model_path ./model_zoo/hinet.pth --outf ./exp/hinet/  --gpu_id 0

# reconstruct by MPRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method mprnet --pretrained_model_path ./model_zoo/mprnet.pth --outf ./exp/mprnet/  --gpu_id 0

# reconstruct by Restormer
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method restormer --pretrained_model_path ./model_zoo/restormer.pth --outf ./exp/restormer/  --gpu_id 0

# reconstruct by EDSR
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg --method edsr --pretrained_model_path ./model_zoo/edsr.pth --outf ./exp/edsr/  --gpu_id 0

# reconstruct by HDNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hdnet --pretrained_model_path ./model_zoo/hdnet.pth --outf ./exp/hdnet/  --gpu_id 0

# reconstruct by HRNet
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hrnet --pretrained_model_path ./model_zoo/hrnet.pth --outf ./exp/hrnet/  --gpu_id 0

# reconstruct by HSCNN+
python test.py --rgb_path ./demo/ARAD_1K_0912.jpg  --method hscnn_plus --pretrained_model_path ./model_zoo/hscnn_plus.pth --outf ./exp/hscnn_plus/  --gpu_id 0

Please replace './demo/ARAD_1K_0912.jpg' with your RGB image path. The reconstructed results will be saved in /MST-plus-plus/predict_code/exp/.
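All ten commands share one pattern: the checkpoint is ./model_zoo/<method>.pth and the output folder is ./exp/<method>/. Below is a minimal sketch for reconstructing a whole folder of images with one method, assuming you run it from predict_code/ with your images in ./demo/. build_commands is a hypothetical helper; it only prints the commands and does not execute them.

```python
import shlex
from pathlib import Path

# Values accepted by test.py's --method flag (taken from the commands above).
METHODS = ["mst_plus_plus", "mst", "mirnet", "hinet", "mprnet",
           "restormer", "edsr", "hdnet", "hrnet", "hscnn_plus"]


def build_commands(rgb_paths, method="mst_plus_plus", gpu_id="0"):
    """Hypothetical helper: build one test.py invocation per RGB image."""
    if method not in METHODS:
        raise ValueError(f"unknown method: {method}")
    commands = []
    for rgb_path in rgb_paths:
        commands.append([
            "python", "test.py",
            "--rgb_path", str(rgb_path),
            "--method", method,
            # Checkpoint and output folder follow the naming used above.
            "--pretrained_model_path", f"./model_zoo/{method}.pth",
            "--outf", f"./exp/{method}/",
            "--gpu_id", gpu_id,
        ])
    return commands


if __name__ == "__main__":
    # Assumption: the images to reconstruct are .jpg/.png files in ./demo/.
    images = sorted(p for p in Path("./demo").glob("*")
                    if p.suffix.lower() in {".jpg", ".png"})
    for cmd in build_commands(images):
        print(shlex.join(cmd))
```

To actually run them, pass each list to subprocess.run(cmd, check=True). Building argument lists instead of shell strings avoids quoting problems with paths that contain spaces.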

randomNNN commented 2 years ago


Thanks for your reply!!!

randomNNN commented 2 years ago

Sorry to bother you again, but when I run your prediction code in a virtual environment with torch 1.7.1 and torchvision 0.8, I get the following error:

[Screenshot 2022-05-13 234758]

Can you help me solve this bug?

linjing7 commented 2 years ago

Hi, our method was trained and tested on Linux; we have not tested it on Windows.
It seems that something is wrong with your CUDA setup. Have you installed CUDA correctly?
You can:

  1. Run our code on Linux.
  2. Or try running the code on the CPU, which requires the following modifications: (1) Replace predict_code/architecture/__init__.py with:
    
    import torch
    from .edsr import EDSR
    from .HDNet import HDNet
    from .hinet import HINet
    from .hrnet import SGN
    from .HSCNN_Plus import HSCNN_Plus
    from .MIRNet import MIRNet
    from .MPRNet import MPRNet
    from .MST import MST
    from .MST_Plus_Plus import MST_Plus_Plus
    from .Restormer import Restormer


    def model_generator(method, pretrained_model_path=None):
        if method == 'mirnet':
            model = MIRNet(n_RRG=3, n_MSRB=1, height=3, width=1)
        elif method == 'mst_plus_plus':
            model = MST_Plus_Plus()
        elif method == 'mst':
            model = MST(dim=31, stage=2, num_blocks=[4, 7, 5])
        elif method == 'hinet':
            model = HINet(depth=4)
        elif method == 'mprnet':
            model = MPRNet(num_cab=4)
        elif method == 'restormer':
            model = Restormer()
        elif method == 'edsr':
            model = EDSR()
        elif method == 'hdnet':
            model = HDNet()
        elif method == 'hrnet':
            model = SGN()
        elif method == 'hscnn_plus':
            model = HSCNN_Plus()
        else:
            # Fail early instead of hitting a NameError on `model` below.
            raise ValueError(f'Method {method} is not defined')
        if pretrained_model_path is not None:
            print(f'load model from {pretrained_model_path}')
            # map_location keeps every tensor on the CPU, so no CUDA device is needed.
            checkpoint = torch.load(pretrained_model_path, map_location=lambda storage, loc: storage)
            # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
            model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()},
                                  strict=True)
        return model
(2) Replace predict_code/test.py with the following code:

    import argparse
    import os

    import cv2
    import numpy as np
    import torch
    import torch.backends.cudnn as cudnn

    from architecture import *
    from utils import save_matv73

    parser = argparse.ArgumentParser(description="SSR")
    parser.add_argument('--method', type=str, default='mst_plus_plus')
    parser.add_argument('--pretrained_model_path', type=str, default='./model_zoo/mst_plus_plus.pth')
    parser.add_argument('--rgb_path', type=str, default='./demo/ARAD_1K_0912.jpg')
    parser.add_argument('--outf', type=str, default='./exp/mst_plus_plus/')
    parser.add_argument('--ensemble_mode', type=str, default='mean')
    parser.add_argument("--gpu_id", type=str, default='0')
    opt = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
    if not os.path.exists(opt.outf):
        os.makedirs(opt.outf)


    def main():
        cudnn.benchmark = True
        pretrained_model_path = opt.pretrained_model_path
        method = opt.method
        model = model_generator(method, pretrained_model_path)
        test(model, opt.rgb_path, opt.outf)


    def test(model, rgb_path, save_path):
        var_name = 'cube'
        # Read the image, convert OpenCV's BGR order to RGB and min-max normalize to [0, 1].
        bgr = cv2.imread(rgb_path)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = np.float32(rgb)
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
        # H x W x 3 -> 1 x 3 x H x W, the layout the networks expect.
        rgb = np.expand_dims(np.transpose(rgb, [2, 0, 1]), axis=0).copy()
        rgb = torch.from_numpy(rgb).float()
        print(f'Reconstructing {rgb_path}')
        with torch.no_grad():
            result = model(rgb)
        # 1 x bands x H x W -> H x W x bands, clipped to [0, 1].
        result = result.cpu().numpy() * 1.0
        result = np.transpose(np.squeeze(result), [1, 2, 0])
        result = np.minimum(result, 1.0)
        result = np.maximum(result, 0)
        # basename/splitext also handle Windows-style paths and extensions longer than 3 characters.
        mat_name = os.path.splitext(os.path.basename(rgb_path))[0] + '.mat'
        mat_dir = os.path.join(save_path, mat_name)
        save_matv73(mat_dir, var_name, result)
        print(f'The reconstructed hyperspectral image is saved as {mat_dir}.')


    if __name__ == '__main__':
        main()
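The array-layout handling in test() is the part most worth checking when adapting it to your own images. Below is a NumPy-only sketch of the same steps (min-max normalization, H x W x 3 to 1 x 3 x H x W, then back to H x W x bands with clipping to [0, 1]). preprocess and postprocess are hypothetical names, and the model call itself is left out, so no PyTorch or GPU is needed to try it.

```python
import numpy as np


def preprocess(rgb_hwc):
    """Min-max normalize an H x W x 3 image and reshape it to 1 x 3 x H x W."""
    rgb = np.float32(rgb_hwc)
    span = rgb.max() - rgb.min()
    # Guard against a constant image, where the plain formula would divide by zero.
    rgb = (rgb - rgb.min()) / span if span > 0 else np.zeros_like(rgb)
    return np.expand_dims(np.transpose(rgb, [2, 0, 1]), axis=0).copy()


def postprocess(result_nchw):
    """Turn a 1 x bands x H x W output into an H x W x bands cube in [0, 1]."""
    # Squeeze only the batch axis so images with H == 1 or W == 1 keep their shape.
    cube = np.transpose(np.squeeze(result_nchw, axis=0), [1, 2, 0])
    return np.clip(cube, 0.0, 1.0)


if __name__ == "__main__":
    fake_rgb = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)
    x = preprocess(fake_rgb)
    print(x.shape)  # (1, 3, 4, 5)
    # Stand-in for the network output with 31 bands (the dim=31 used for MST above).
    fake_out = np.random.uniform(-0.2, 1.2, size=(1, 31, 4, 5)).astype(np.float32)
    print(postprocess(fake_out).shape)  # (4, 5, 31)
```

np.clip gives the same result as the np.minimum/np.maximum pair in test(); squeezing only axis 0 is a small deviation from the original np.squeeze call and only matters for degenerate one-pixel-wide inputs.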

randomNNN commented 2 years ago

Thanks for your patient reply. Following your guidance, I got the .mat file using the CPU. But I don't know how to visualize each channel of the .mat file. Can you show me how to visualize the HSI channels in Python?

linjing7 commented 2 years ago

Hi, for various reasons we do not plan to open-source the visualization code for now. You can try to visualize the hyperspectral image:

  1. using the 'Hyperspectral Viewer' toolbox in MATLAB: (1) install Hyperspectral Viewer in MATLAB; (2) run the following code:

    clear; clc;
    file_path = "ARAD_1K_0912.mat";
    pred = load(file_path).cube;
    hyperspectralViewer(pred)


  2. or, in Python, try the following code:

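    import h5py
    import cv2
    import numpy as np

    path = "ARAD_1K_0912.mat"
    with h5py.File(path, 'r') as mat:
        # h5py reads the v7.3 .mat file with the band axis first
        hyper = np.float32(np.array(mat['cube'])) * 255
    # Save band index 15 as an 8-bit grayscale PNG
    cv2.imwrite('ARAD_1K_0912.png', np.clip(hyper[15, :, :], 0, 255).astype(np.uint8))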
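The snippet above saves a single band. To look at every channel, a sketch along the following lines loops over the band axis. band_images and save_all_bands are hypothetical helpers; they assume the cube variable is named 'cube' (as written by test.py) and values lie roughly in [0, 1]. If the saved images look transposed relative to the input RGB, write img.T instead of img.

```python
import os

import numpy as np


def band_images(cube, band_axis=0, per_band=False):
    """Split a hyperspectral cube into one uint8 grayscale image per band.

    cube: array with the spectral bands along `band_axis`, values roughly in [0, 1].
    per_band: if True, stretch each band to its own min/max so faint bands are visible.
    """
    bands = np.moveaxis(np.asarray(cube, dtype=np.float32), band_axis, 0)
    images = []
    for band in bands:
        if per_band:
            lo, hi = band.min(), band.max()
            band = (band - lo) / (hi - lo) if hi > lo else np.zeros_like(band)
        images.append((np.clip(band, 0.0, 1.0) * 255).round().astype(np.uint8))
    return images


def save_all_bands(mat_path, out_dir, var_name="cube"):
    """Hypothetical helper: write every band of a v7.3 .mat file as a PNG."""
    import cv2
    import h5py

    with h5py.File(mat_path, "r") as mat:
        # As in the snippet above, the band axis comes first when read with h5py.
        cube = np.array(mat[var_name])
    os.makedirs(out_dir, exist_ok=True)
    for i, img in enumerate(band_images(cube, band_axis=0)):
        cv2.imwrite(os.path.join(out_dir, f"band_{i:02d}.png"), img)
```

For example, save_all_bands("ARAD_1K_0912.mat", "bands/") would write band_00.png, band_01.png, and so on. The per_band option trades absolute intensity comparability between bands for visibility of dim bands.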
zzllbg commented 2 years ago

Hello, I have obtained the .mat file. How can I extract the spectral information at a specific point?

caiyuanhao1998 commented 2 years ago

What do you mean by extracting the spectral information at a certain point?

zzllbg commented 2 years ago


For example, figure 4 and figure 5 of the MST++ paper show a comparison of spectral curves in the lower-left corner. How can I reproduce this?

caiyuanhao1998 commented 2 years ago

(i) Select a small spatial patch.

(ii) Compute the average spectral intensity of this region.

(iii) Compute the average correlation coefficient between the reconstructed HSI and the GT HSI.
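The three steps can be sketched in NumPy as below. This is an illustration of the procedure, not the code used to make the figures in the paper. Assumptions: both cubes are H x W x bands arrays on the same pixel grid (the layout test.py saves), the patch is given by a hypothetical top-left corner (y, x) and size, and step (iii) is read as the Pearson correlation between the two averaged spectral curves.

```python
import numpy as np


def mean_spectrum(cube, y, x, size):
    """Steps (i) and (ii): average spectrum of a size x size patch at top-left (y, x).

    cube: H x W x bands array.
    """
    patch = cube[y:y + size, x:x + size, :]
    if patch.shape[0] != size or patch.shape[1] != size:
        raise ValueError("patch extends beyond the image")
    # Flatten the spatial pixels and average them band by band.
    return patch.reshape(-1, cube.shape[2]).mean(axis=0)


def spectral_correlation(pred_cube, gt_cube, y, x, size):
    """Step (iii): Pearson correlation between predicted and GT patch spectra."""
    pred_curve = mean_spectrum(pred_cube, y, x, size)
    gt_curve = mean_spectrum(gt_cube, y, x, size)
    corr = float(np.corrcoef(pred_curve, gt_curve)[0, 1])
    return corr, pred_curve, gt_curve
```

With size=1 the function simply returns cube[y, x, :], i.e. the spectrum at a single point. If you load the .mat files with h5py, the axes come out reversed, so np.transpose(cube, (2, 1, 0)) may be needed first to get H x W x bands. The two returned curves can then be plotted against the band index, for example with matplotlib.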

zzllbg commented 2 years ago

How do I implement the second step?

caiyuanhao1998 commented 2 years ago

Just compute the mean directly. You can write in Chinese.

zzllbg commented 2 years ago

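Haha, thanks, man. I've only just started working with hyperspectral images, so I don't quite understand how to extract the spectral information from the .mat file. Does your code include this part?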

caiyuanhao1998 commented 2 years ago

Yes, it's in another repo:

https://github.com/caiyuanhao1998/MST

Take a look at 4.3 there.

And if you find our code helpful, please give it a star, fork, and follow.

zzllbg commented 2 years ago

OK, thanks!!

zzllbg commented 2 years ago


Hi, when I run the code I get an error saying that pred_block_ is not recognized. Do you know how to fix it?