chunmeifeng / T2Net

【MICCAI 2021】Task Transformer Network for Joint MRI Reconstruction and Super-Resolution

cannot import name 'NumpyMetric' from 'pytorch_lightning.metrics.metric' #1

Closed: AmitSharma1127 closed this issue 3 years ago

AmitSharma1127 commented 3 years ago

I am running this on Google Colab, in the TPU runtime. I have already installed all the dependencies.

Running python ixi_train_t2net.py throws an exception, the stack trace of which is as follows:

WARNING:root:TPU has started up successfully with version pytorch-1.7
Traceback (most recent call last):
  File "ixi_train_t2net.py", line 17, in <module>
    from ixi_module_t2net import UnetModule  # experimental.unet.unet_module
  File "/content/gdrive/My Drive/T2Net/ixi_module_t2net.py", line 14, in <module>
    from fastmri.mri_ixi_module_t2net import MriModule
  File "/content/gdrive/My Drive/T2Net/fastmri/mri_ixi_module_t2net.py", line 19, in <module>
    from fastmri import evaluate
  File "/content/gdrive/My Drive/T2Net/fastmri/evaluate.py", line 14, in <module>
    from pytorch_lightning.metrics.metric import NumpyMetric, TensorMetric
ImportError: cannot import name 'NumpyMetric' from 'pytorch_lightning.metrics.metric' (/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/metric.py)

Am I missing something?
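
For reference, here is a quick way to check the installed pytorch-lightning version (NumpyMetric and TensorMetric belonged to the old metrics API and were removed around the 1.0 release, which is what triggers this ImportError):

import pytorch_lightning
print(pytorch_lightning.__version__)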

AmitSharma1127 commented 3 years ago

I replaced the evaluate.py code inside the fastmri directory you provided with the latest code available in the official fastMRI repository (and copied the data folder from the official fastMRI repo as well). This seems to solve the issue.
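
In case it helps, here is a minimal sketch of that replacement step. The raw URL below is an assumption on my part (the official repo is facebookresearch/fastMRI; the default branch may be master or main, and pinning a commit hash is safer than tracking a branch):

# Hypothetical fetch of the official fastMRI evaluate.py; adjust the
# branch name or pin a commit hash for reproducibility.
import urllib.request

url = ("https://raw.githubusercontent.com/"
       "facebookresearch/fastMRI/master/fastmri/evaluate.py")
urllib.request.urlretrieve(url, "fastmri/evaluate.py")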

chunmeifeng commented 3 years ago

Hi Amit, please make sure to use the version of pytorch-lightning that we have recommended. IXI is a public dataset; you can obtain it from http://brain-development.org/ixi-dataset/. However, the clinical dataset cannot be made public because it involves patient privacy.

Four-Fire commented 3 years ago

Hi, I also have this problem.

Traceback (most recent call last):
  File "D:/SJTU/Paper/motion_compensation/recon_and_resolution/code/ixi_train_t2net.py", line 17, in <module>
    from ixi_module_t2net import UnetModule  # experimental.unet.unet_module
  File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\ixi_module_t2net.py", line 14, in <module>
    from fastmri.mri_ixi_module_t2net import MriModule
  File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\mri_ixi_module_t2net.py", line 19, in <module>
    from fastmri import evaluate
  File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\evaluate.py", line 14, in <module>
    from pytorch_lightning.metrics.metric import TensorMetric,NumpyMetric
ImportError: cannot import name 'TensorMetric' from 'pytorch_lightning.metrics.metric' (E:\python\lib\site-packages\pytorch_lightning\metrics\metric.py)

And the version of pytorch-lightning is 1.0.6.

AmitSharma1127 commented 3 years ago

@Four-Fire Please refer to this for your error; it solves the issue.

I replaced the evaluate.py code inside the fastmri directory you provided with the latest code available in the official fastMRI repository (and copied the data folder from the official fastMRI repo as well). This seems to solve the issue.

Four-Fire commented 3 years ago

@AmitSharma1127 I tried, but failed.

Traceback (most recent call last):
  File "D:/SJTU/Paper/motion_compensation/recon_and_resolution/code/ixi_train_t2net.py", line 17, in <module>
    from ixi_module_t2net import UnetModule  # experimental.unet.unet_module
  File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\ixi_module_t2net.py", line 14, in <module>
    from fastmri.mri_ixi_module_t2net import MriModule
  File "D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\mri_ixi_module_t2net.py", line 21, in <module>
    from fastmri.evaluate import DistributedMetricSum
ImportError: cannot import name 'DistributedMetricSum' from 'fastmri.evaluate' (D:\SJTU\Paper\motion_compensation\recon_and_resolution\code\fastmri\evaluate.py)

AmitSharma1127 commented 3 years ago

You need to make appropriate changes in a couple of files to make this work.

In fastmri/mri_ixi_module_t2net.py, replace the existing code with the following:

At line 21:

from fastmri.evaluate import DistributedMetricSum

At lines 88-93:

self.NMSE = DistributedMetricSum()
self.SSIM = DistributedMetricSum()
self.PSNR = DistributedMetricSum()
self.ValLoss = DistributedMetricSum()
self.TestLoss = DistributedMetricSum()
self.TotExamples = DistributedMetricSum()
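
For context, here is a minimal single-process sketch of what DistributedMetricSum (defined in the replacement evaluate.py below) does; under distributed training, compute() returns the quantity summed across all processes. This assumes the replacement file is already in place:

import torch
from fastmri.evaluate import DistributedMetricSum  # from the file below

metric = DistributedMetricSum()
metric(torch.tensor(1.5))   # __call__ runs update(), adding to the running sum
metric(torch.tensor(2.5))
print(metric.compute())     # tensor(4.)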

Replace the existing fastmri/evaluate.py with the following code:

"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import argparse
import pathlib
from argparse import ArgumentParser

import h5py
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.metric import Metric
from runstats import Statistics
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from fastmri.data import transforms  # provides center_crop, used in evaluate()

# NOTE: the Metric class in pytorch-lightning >= 1.0 no longer accepts a
# `name` argument, so these thin wrappers only define forward().
class MSE(Metric):
    """Calculates MSE and aggregates by summing across distributed processes."""

    def forward(self, gt, pred):
        return mse(gt, pred)

class NMSE(Metric):
    """Calculates NMSE and aggregates by summing across distributed processes."""

    def forward(self, gt, pred):
        return nmse(gt, pred)

class PSNR(Metric):
    """Calculates PSNR and aggregates by summing across distributed processes."""

    def forward(self, gt, pred):
        return psnr(gt, pred)

class SSIM(Metric):
    """Calculates SSIM and aggregates by summing across distributed processes."""

    def forward(self, gt, pred, maxval=None):
        return ssim(gt, pred, maxval=maxval)

class DistributedMetricSum(pl.metrics.Metric):
    def __init__(self, dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, batch: torch.Tensor):  # type: ignore
        self.quantity += batch

    def compute(self):
        return self.quantity

def mse(gt, pred):
    """Compute Mean Squared Error (MSE)"""
    return np.mean((gt - pred) ** 2)

def nmse(gt, pred):
    """Compute Normalized Mean Squared Error (NMSE)"""
    return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2

def psnr(gt, pred):
    """Compute Peak Signal to Noise Ratio metric (PSNR)"""
    return peak_signal_noise_ratio(gt, pred, data_range=gt.max())

def ssim(gt, pred, maxval=None):
    """Compute Structural Similarity Index Metric (SSIM)"""
    maxval = gt.max() if maxval is None else maxval

    ssim = 0
    for slice_num in range(gt.shape[0]):
        ssim = ssim + structural_similarity(
            gt[slice_num], pred[slice_num], data_range=maxval
        )

    ssim = ssim / gt.shape[0]

    return ssim

METRIC_FUNCS = dict(MSE=mse, NMSE=nmse, PSNR=psnr, SSIM=ssim,)

class Metrics(object):
    """
    Maintains running statistics for a given collection of metrics.
    """

    def __init__(self, metric_funcs):
        """
        Args:
            metric_funcs (dict): A dict where the keys are metric names and the
                values are Python functions for evaluating that metric.
        """
        self.metrics = {metric: Statistics() for metric in metric_funcs}

    def push(self, target, recons):
        for metric, func in METRIC_FUNCS.items():
            self.metrics[metric].push(func(target, recons))

    def means(self):
        return {metric: stat.mean() for metric, stat in self.metrics.items()}

    def stddevs(self):
        return {metric: stat.stddev() for metric, stat in self.metrics.items()}

    def __repr__(self):
        means = self.means()
        stddevs = self.stddevs()
        metric_names = sorted(list(means))
        return " ".join(
            f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}"
            for name in metric_names
        )

def evaluate(args, recons_key):
    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in args.target_path.iterdir():
        with h5py.File(tgt_file, "r") as target, h5py.File(
            args.predictions_path / tgt_file.name, "r"
        ) as recons:
            if args.acquisition and args.acquisition != target.attrs["acquisition"]:
                continue

            if args.acceleration and target.attrs["acceleration"] != args.acceleration:
                continue

            target = target[recons_key][()]
            recons = recons["reconstruction"][()]
            target = transforms.center_crop(
                target, (target.shape[-1], target.shape[-1])
            )
            recons = transforms.center_crop(
                recons, (target.shape[-1], target.shape[-1])
            )
            metrics.push(target, recons)

    return metrics

if __name__ == "__main__":
    parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--target-path",
        type=pathlib.Path,
        required=True,
        help="Path to the ground truth data",
    )
    parser.add_argument(
        "--predictions-path",
        type=pathlib.Path,
        required=True,
        help="Path to reconstructions",
    )
    parser.add_argument(
        "--challenge",
        choices=["singlecoil", "multicoil"],
        required=True,
        help="Which challenge",
    )
    parser.add_argument("--acceleration", type=int, default=None)
    parser.add_argument(
        "--acquisition",
        choices=[
            "CORPD_FBK",
            "CORPDFS_FBK",
            "AXT1",
            "AXT1PRE",
            "AXT1POST",
            "AXT2",
            "AXFLAIR",
        ],
        default="CORPD_FBK",  # the default must be a string, not a bare name
        help="Only volumes of the specified acquisition type are used "
        "for evaluation. Defaults to CORPD_FBK.",
    )
    args = parser.parse_args()

    recons_key = (
        "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
    )
    metrics = evaluate(args, recons_key)
    print(metrics)
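
As a quick sanity check of the metric functions above, here is a hedged example on dummy volumes (shapes and values are illustrative; two volumes are pushed so that the stddev in Metrics.__repr__ is well defined):

# Sanity-check the Metrics aggregator on fake data.
import numpy as np
from fastmri.evaluate import METRIC_FUNCS, Metrics

m = Metrics(METRIC_FUNCS)
for _ in range(2):
    gt = np.random.rand(10, 64, 64)                 # fake ground-truth volume
    pred = gt + 0.01 * np.random.rand(10, 64, 64)   # fake reconstruction
    m.push(gt, pred)
print(m)  # e.g. "MSE = ... NMSE = ... PSNR = ... SSIM = ..."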

Four-Fire commented 3 years ago

@AmitSharma1127 Thanks a lot! By the way, would you mind sharing the "1D-Cartesian_6X_128128.mat" and "1D-Random-3X_256.mat" with me? My email address is four_fire@sjtu.edu.cn. Your help would be greatly appreciated.