AmitSharma1127 closed this issue 3 years ago.
I replaced the `evaluate.py` code inside the `fastmri` directory you provided with the latest code available in the official fastMRI repository (and also copied the `data` folder from the official fastMRI repo). This seems to solve the above issue.
Hi Amit, please make sure to use the version of PyTorch Lightning that we recommend. IXI is a public dataset; you can obtain it from http://brain-development.org/ixi-dataset/. However, the clinical dataset cannot be made public because it involves patient privacy.
On Fri, Jul 2, 2021 at 6:18 PM, Amit Sharma <@.***> wrote:
Closed #1 https://github.com/chunmeifeng/T2Net/issues/1.
— You are receiving this because you are subscribed to this thread. Reply to this email directly, view it on GitHub https://github.com/chunmeifeng/T2Net/issues/1#event-4971737816, or unsubscribe https://github.com/notifications/unsubscribe-auth/AR75XN5E73IZ53HFNRHX7NLTVXDENANCNFSM47WR7DFQ .
Hi, I also have this problem.
Traceback (most recent call last):
File "D:/SJTU/Paper/motion_compensation/recon_and_resolution/code/ixi_train_t2net.py", line 17, in <module>
from ixi_module_t2net import UnetModule # experimental.unet.unet_module
File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\ixi_module_t2net.py", line 14, in <module>
from fastmri.mri_ixi_module_t2net import MriModule
File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\mri_ixi_module_t2net.py", line 19, in <module>
from fastmri import evaluate
File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\evaluate.py", line 14, in <module>
from pytorch_lightning.metrics.metric import TensorMetric,NumpyMetric
ImportError: cannot import name 'TensorMetric' from 'pytorch_lightning.metrics.metric' (E:\python\lib\site-packages\pytorch_lightning\metrics\metric.py)
My PyTorch Lightning version is 1.0.6.
@Four-Fire Please refer to this for your error; it solves the issue.
I replaced the `evaluate.py` code inside the `fastmri` directory you provided with the latest code available in the official fastMRI repository (and also copied the `data` folder from the official fastMRI repo). This seems to solve the above issue.
@AmitSharma1127 I tried, but failed.
Traceback (most recent call last):
File "D:/SJTU/Paper/motion_compensation/recon_and_resolution/code/ixi_train_t2net.py", line 17, in <module>
from ixi_module_t2net import UnetModule # experimental.unet.unet_module
File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\ixi_module_t2net.py", line 14, in <module>
from fastmri.mri_ixi_module_t2net import MriModule
File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\mri_ixi_module_t2net.py", line 21, in <module>
from fastmri.evaluate import DistributedMetricSum
ImportError: cannot import name 'DistributedMetricSum' from 'fastmri.evaluate' (D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\evaluate.py)
You need to make appropriate changes in a couple of files to make this work.
In `fastmri/mri_ixi_module_t2net.py`, replace the existing code with the following:
At line 21:
from fastmri.evaluate import DistributedMetricSum
At lines 88–93:
self.NMSE = DistributedMetricSum()
self.SSIM = DistributedMetricSum()
self.PSNR = DistributedMetricSum()
self.ValLoss = DistributedMetricSum()
self.TestLoss = DistributedMetricSum()
self.TotExamples = DistributedMetricSum()
Replace the existing `fastmri/evaluate.py` with the following code:
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
import pathlib
from argparse import ArgumentParser
import h5py
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.metrics.metric import Metric
from runstats import Statistics
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import torch
import fastmri
from fastmri import evaluate
class _SummedMetric(Metric):
    """Shared base for the per-batch image metrics below.

    Each subclass's ``forward`` evaluates its metric on one ``(gt, pred)``
    pair and returns the value directly. ``update``/``compute`` accumulate
    those per-batch values into a running sum whose state is reduced with
    ``"sum"`` across distributed processes, mirroring ``DistributedMetricSum``.
    """

    def __init__(self, name, *args, **kwargs):
        # The pytorch_lightning >= 1.0 ``Metric`` API used by this module
        # (see ``add_state`` in DistributedMetricSum) has no ``name``
        # parameter, and it declares ``update``/``compute`` abstract.
        # Forwarding ``name`` to it, or leaving those methods undefined,
        # makes instantiation fail with TypeError. Keep the name as a plain
        # attribute instead.
        super().__init__(*args, **kwargs)
        self.name = name
        self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, gt, pred, *args, **kwargs):  # type: ignore
        """Add the metric value for one ``(gt, pred)`` pair to the running sum."""
        value = self.forward(gt, pred, *args, **kwargs)
        self.quantity += torch.as_tensor(
            value, dtype=self.quantity.dtype, device=self.quantity.device
        )

    def compute(self):
        """Return the running sum of all values passed to ``update``."""
        return self.quantity


class MSE(_SummedMetric):
    """Calculates MSE and aggregates by summing across distr processes."""

    def __init__(self, name="MSE", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    def forward(self, gt, pred):
        return mse(gt, pred)


class NMSE(_SummedMetric):
    """Calculates NMSE and aggregates by summing across distr processes."""

    def __init__(self, name="NMSE", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    def forward(self, gt, pred):
        return nmse(gt, pred)


class PSNR(_SummedMetric):
    """Calculates PSNR and aggregates by summing across distr processes."""

    def __init__(self, name="PSNR", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    def forward(self, gt, pred):
        return psnr(gt, pred)


class SSIM(_SummedMetric):
    """Calculates SSIM and aggregates by summing across distr processes."""

    def __init__(self, name="SSIM", *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    def forward(self, gt, pred, maxval=None):
        return ssim(gt, pred, maxval=maxval)
class DistributedMetricSum(pl.metrics.Metric):
    """Running scalar total whose state is summed across processes on sync.

    Every ``update`` adds ``batch`` to the accumulated ``quantity`` state;
    ``compute`` hands back that accumulated value.
    """

    def __init__(self, dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        zero = torch.tensor(0.0)
        self.add_state("quantity", default=zero, dist_reduce_fx="sum")

    def update(self, batch: torch.Tensor):  # type: ignore
        # Accumulate in place into the registered state tensor.
        self.quantity += batch

    def compute(self):
        total = self.quantity
        return total
def mse(gt, pred):
    """Return the mean squared error between ``gt`` and ``pred``."""
    residual = gt - pred
    return np.mean(residual ** 2)
def nmse(gt, pred):
    """Return the normalized MSE, ``||gt - pred||^2 / ||gt||^2``."""
    error_norm = np.linalg.norm(gt - pred)
    reference_norm = np.linalg.norm(gt)
    return error_norm ** 2 / reference_norm ** 2
def psnr(gt, pred):
    """Return the peak signal-to-noise ratio of ``pred`` against ``gt``.

    The data range handed to scikit-image is ``gt.max()``.
    """
    peak = gt.max()
    return peak_signal_noise_ratio(gt, pred, data_range=peak)
def ssim(gt, pred, maxval=None):
    """Return the SSIM averaged over the slices along the first axis of ``gt``.

    ``maxval`` is the data range handed to scikit-image for every slice;
    when omitted it defaults to ``gt.max()``.
    """
    data_range = maxval if maxval is not None else gt.max()
    num_slices = gt.shape[0]
    slice_scores = (
        structural_similarity(gt[idx], pred[idx], data_range=data_range)
        for idx in range(num_slices)
    )
    return sum(slice_scores) / num_slices
# Metric name -> metric function; ``evaluate`` hands this table to ``Metrics``.
METRIC_FUNCS = {"MSE": mse, "NMSE": nmse, "PSNR": psnr, "SSIM": ssim}
class Metrics(object):
    """
    Maintains running statistics for a given collection of metrics.
    """

    def __init__(self, metric_funcs):
        """
        Args:
            metric_funcs (dict): A dict where the keys are metric names and the
                values are Python functions for evaluating that metric.
        """
        # Keep the caller's functions so ``push`` evaluates exactly these.
        # It previously always used the module-level METRIC_FUNCS, which
        # raised KeyError for a custom subset of metrics.
        self.metric_funcs = metric_funcs
        self.metrics = {metric: Statistics() for metric in metric_funcs}

    def push(self, target, recons):
        """Evaluate every configured metric on one (target, recons) pair."""
        for metric, func in self.metric_funcs.items():
            self.metrics[metric].push(func(target, recons))

    def means(self):
        """Return ``{metric_name: mean}`` over all pushed values."""
        return {metric: stat.mean() for metric, stat in self.metrics.items()}

    def stddevs(self):
        """Return ``{metric_name: stddev}`` over all pushed values."""
        return {metric: stat.stddev() for metric, stat in self.metrics.items()}

    def __repr__(self):
        means = self.means()
        stddevs = self.stddevs()
        # Report mean +/- two standard deviations, metrics in sorted order.
        return " ".join(
            f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}"
            for name in sorted(means)
        )
def evaluate(args, recons_key):
    """Compute metrics for every target volume with a matching reconstruction.

    Args:
        args: Parsed CLI namespace; reads ``target_path``, ``predictions_path``,
            ``acquisition`` and ``acceleration``.
        recons_key: Dataset key to read from each ground-truth file.

    Returns:
        A ``Metrics`` instance holding running statistics over all volumes.
    """
    # ``transforms`` was used below but never imported in this module, so
    # the first matching volume raised NameError. Assumes the ``fastmri.data``
    # package (copied from the official fastMRI repo) provides ``center_crop``.
    from fastmri.data import transforms

    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in args.target_path.iterdir():
        with h5py.File(tgt_file, "r") as target_h5, h5py.File(
            args.predictions_path / tgt_file.name, "r"
        ) as recons_h5:
            if args.acquisition and args.acquisition != target_h5.attrs["acquisition"]:
                continue
            if args.acceleration and target_h5.attrs["acceleration"] != args.acceleration:
                continue

            target = target_h5[recons_key][()]
            recons = recons_h5["reconstruction"][()]
            # Crop both arrays to a square whose side is the target's
            # last-axis size.
            crop_size = (target.shape[-1], target.shape[-1])
            target = transforms.center_crop(target, crop_size)
            recons = transforms.center_crop(recons, crop_size)
            metrics.push(target, recons)

    return metrics
if __name__ == "__main__":
    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--target-path",
        type=pathlib.Path,
        required=True,
        help="Path to the ground truth data",
    )
    parser.add_argument(
        "--predictions-path",
        type=pathlib.Path,
        required=True,
        help="Path to reconstructions",
    )
    parser.add_argument(
        "--challenge",
        choices=["singlecoil", "multicoil"],
        required=True,
        help="Which challenge",
    )
    parser.add_argument("--acceleration", type=int, default=None)
    parser.add_argument(
        "--acquisition",
        choices=[
            "CORPD_FBK",
            "CORPDFS_FBK",
            "AXT1",
            "AXT1PRE",
            "AXT1POST",
            "AXT2",
            "AXFLAIR",
        ],
        # None disables the acquisition filter, as the help text promises.
        # The bare name CORPD_FBK used here before was undefined (NameError).
        default=None,
        help="If set, only volumes of the specified acquisition type are used "
        "for evaluation. By default, all volumes are included.",
    )
    args = parser.parse_args()

    # Ground-truth dataset key depends on the challenge type.
    recons_key = (
        "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
    )
    metrics = evaluate(args, recons_key)
    print(metrics)
@AmitSharma1127 Thanks a lot! By the way, would you mind sharing the "1D-Cartesian_6X_128128.mat" and "1D-Random-3X_256.mat" with me? My email address is four_fire@sjtu.edu.cn. Your help would be greatly appreciated.
I am running this on Google Colab, in the TPU runtime. I have already installed all the dependencies.
Running
python ixi_train_t2net.py
throws an exception; its stack trace is as follows. Am I missing something?