mlmed / torchxrayvision

TorchXRayVision: A library of chest X-ray datasets and models. Classifiers, segmentation, and autoencoders.
https://mlmed.org/torchxrayvision
Apache License 2.0

Everything classified wrongly (seems like there is some systematic error on my side) #152

Open croraf opened 2 months ago

croraf commented 2 months ago
import torchxrayvision as xrv
import skimage.io, torch, torchvision  # import the io submodule explicitly for skimage.io.imread

# Prepare the image:
#img = skimage.io.imread("16747_3_1.jpg")
img = skimage.io.imread("covid-19-pneumonia-58-prior.jpg")
#img = skimage.io.imread("test2.png")
img = xrv.datasets.normalize(img, 255) # convert 8-bit image to [-1024, 1024] range
img = img.mean(2)[None, ...] # Make single color channel

transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(224)])
#transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(512)])

img = transform(img)
img = torch.from_numpy(img)

# Load model and process image
model = xrv.models.DenseNet(weights="densenet121-res224-all")
#model = xrv.models.ResNet(weights="resnet50-res512-all")
# model = xrv.baseline_models.jfhealthcare.DenseNet() 

outputs = model(img[None,...]) # or model.features(img[None,...]) 

# Print results
cpu_tensor = outputs[0].cpu()

result = zip(model.pathologies, cpu_tensor.detach().numpy())
result_sorted = sorted(result, key=lambda x: x[1], reverse=True)

for finding, percentage in result_sorted:
  print(f"{finding}: {percentage * 100:.0f}%")

I'm using this code, which is essentially the same as the example in the README. However, the classification of the test image is completely wrong, even though the image shows pneumonia. Why?
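One way to narrow down where things go wrong is to reproduce the preprocessing above in plain NumPy and check the shape and value range after each step. This is a minimal sketch under stated assumptions: the helper names are hypothetical, the formula (2 * img / maxval - 1) * 1024 is assumed from the README comment that normalize maps an 8-bit image to the [-1024, 1024] range, and center_crop assumes XRayCenterCrop crops to the largest central square. It is not the library's implementation.

```python
import numpy as np

def normalize_8bit(img: np.ndarray, maxval: float = 255.0) -> np.ndarray:
    """Map [0, maxval] linearly to [-1024, 1024] (assumed normalize behaviour)."""
    if img.max() > maxval:
        raise ValueError(f"max value {img.max()} exceeds expected bound {maxval}")
    return (2.0 * (img.astype(np.float32) / maxval) - 1.0) * 1024.0

def to_single_channel(img: np.ndarray) -> np.ndarray:
    """Average colour channels if present, then add a channel axis -> (1, H, W)."""
    if img.ndim == 3:
        img = img.mean(axis=2)
    return img[None, ...]

def center_crop(img: np.ndarray) -> np.ndarray:
    """Crop (1, H, W) to its largest central square (assumed XRayCenterCrop behaviour)."""
    _, h, w = img.shape
    size = min(h, w)
    top = h // 2 - size // 2
    left = w // 2 - size // 2
    return img[:, top:top + size, left:left + size]

def preprocess_like_readme(img_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: normalize, single channel, center crop, then report range."""
    out = center_crop(to_single_channel(normalize_8bit(img_uint8)))
    print(f"shape={out.shape}, min={out.min():.1f}, max={out.max():.1f}")
    return out

# Usage with a synthetic 300x400 RGB image: left half black, right half white.
demo = np.zeros((300, 400, 3), dtype=np.uint8)
demo[:, 200:] = 255
x = preprocess_like_readme(demo)  # prints shape=(1, 300, 300), min=-1024.0, max=1024.0
```

Unlike img.mean(2), to_single_channel also accepts a grayscale file that imread returns as a 2-D array, where mean(2) would raise an error. If a real image comes out of this step far outside [-1024, 1024], the preprocessing is a likely culprit.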


croraf commented 2 months ago

Running process_image.py on the same image gives the correct result:

{'preds': {'Atelectasis': 0.5890367,
           'Cardiomegaly': 0.591056,
           'Consolidation': 0.311878,
           'Edema': 0.22105017,
           'Effusion': 0.44506797,
           'Emphysema': 0.47612804,
           'Enlarged Cardiomediastinum': 0.5062166,
           'Fibrosis': 0.4609814,
           'Fracture': 0.5526564,
           'Hernia': 0.06847992,
           'Infiltration': 0.13490699,
           'Lung Lesion': 0.095150016,
           'Lung Opacity': 0.29126057,
           'Mass': 0.09128846,
           'Nodule': 0.26987314,
           'Pleural_Thickening': 0.16248804,
           'Pneumonia': 0.51418424,
           'Pneumothorax': 0.12021518}}
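To compare this output with the percentages printed by the script in the first comment, the scores can be sorted and formatted the same way. A small sketch; the values are copied from the output above and format_sorted is a hypothetical helper that reuses the script's sort and print logic.

```python
# Scores copied from the process_image.py output above.
preds = {
    'Atelectasis': 0.5890367, 'Cardiomegaly': 0.591056, 'Consolidation': 0.311878,
    'Edema': 0.22105017, 'Effusion': 0.44506797, 'Emphysema': 0.47612804,
    'Enlarged Cardiomediastinum': 0.5062166, 'Fibrosis': 0.4609814,
    'Fracture': 0.5526564, 'Hernia': 0.06847992, 'Infiltration': 0.13490699,
    'Lung Lesion': 0.095150016, 'Lung Opacity': 0.29126057, 'Mass': 0.09128846,
    'Nodule': 0.26987314, 'Pleural_Thickening': 0.16248804,
    'Pneumonia': 0.51418424, 'Pneumothorax': 0.12021518,
}

def format_sorted(scores):
    """Sort findings by score (highest first) and format them like the script above."""
    ordered = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    return [f"{name}: {score * 100:.0f}%" for name, score in ordered]

for line in format_sorted(preds):
    print(line)
```

Sorted this way, Pneumonia (51%) comes fourth, after Cardiomegaly, Atelectasis and Fracture.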
croraf commented 2 months ago

I compared the two scripts, and the only meaningful difference is that process_image.py doesn't resize by default; it only resizes when the -resize flag is passed.

If I pass the -resize flag, I get the same results as with my original code, which are very wrong.

Is this expected? The original image is 2300x2300, and resizing reduces it to 224x224, which is a big loss of detail.

But xrv.models.DenseNet(weights="densenet121-res224-all") reports that it rescales the input anyway when it is not 224x224, so there shouldn't be any difference between the two results.

ieee8023 commented 2 months ago

It seems that the two resize operations (one using skimage, the other using PyTorch's upsample) are changing the range of the pixel values. I'll look into it further.
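The idea that two resize methods can feed the model different inputs can be illustrated without torchxrayvision. This is a minimal NumPy sketch. Area averaging and nearest-neighbour sampling are simplified stand-ins chosen for illustration. They are not the exact algorithms used by skimage.transform.resize, which by default applies Gaussian smoothing before downscaling, or by PyTorch's upsample/interpolate, which does not anti-alias by default. The input is a synthetic array with 1-pixel stripes, not an X-ray.

```python
import numpy as np

def downsample_area(img: np.ndarray, factor: int) -> np.ndarray:
    """Average each non-overlapping factor x factor block (area-style, anti-aliased)."""
    h, w = img.shape
    h2, w2 = h // factor, w // factor
    blocks = img[: h2 * factor, : w2 * factor].reshape(h2, factor, w2, factor)
    return blocks.mean(axis=(1, 3))

def downsample_nearest(img: np.ndarray, factor: int) -> np.ndarray:
    """Keep every factor-th pixel in each direction (no anti-aliasing)."""
    return img[::factor, ::factor]

def value_stats(img: np.ndarray):
    """Return (min, max, mean) as Python floats for easy comparison."""
    return float(img.min()), float(img.max()), float(img.mean())

# Synthetic 448x448 input in the [-1024, 1024] range with 1-pixel-wide vertical
# stripes: the kind of fine, high-contrast detail that resize methods treat differently.
stripes = np.tile(np.array([-1024.0, 1024.0], dtype=np.float32), (448, 224))

print("original:", value_stats(stripes))
print("area    :", value_stats(downsample_area(stripes, 2)))
print("nearest :", value_stats(downsample_nearest(stripes, 2)))
```

Both methods halve the image to 224x224, yet area averaging collapses the stripes to a flat 0 while nearest sampling keeps only the -1024 columns. Real radiographs are far less extreme, but the example shows why "the model resizes internally anyway" does not guarantee identical inputs.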

croraf commented 2 months ago

I tried the same code without explicit resizing, using ResNet(weights="resnet50-res512-all"), on covid-19-pneumonia-58-prior.jpg, and it gives only 1% for pneumonia. Something is fishy here.

import torchxrayvision as xrv
import skimage.io, torch, torchvision  # import the io submodule explicitly for skimage.io.imread
print(xrv.__file__)

# Prepare the image:
#img = skimage.io.imread("16747_3_1.jpg")
img = skimage.io.imread("covid-19-pneumonia-58-prior.jpg")
#img = skimage.io.imread("test2.png")
img = xrv.datasets.normalize(img, 255) # convert 8-bit image to [-1024, 1024] range
img = img.mean(2)[None, ...] # Make single color channel

#transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(224)])
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
#transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(512)])

img = transform(img)
img = torch.from_numpy(img)

# Load model and process image
#model = xrv.models.DenseNet(weights="densenet121-res224-all")
model = xrv.models.ResNet(weights="resnet50-res512-all")

# model = xrv.baseline_models.jfhealthcare.DenseNet() 

outputs = model(img[None,...]) # or model.features(img[None,...]) 

# Print results
cpu_tensor = outputs[0].cpu()

result = zip(model.pathologies, cpu_tensor.detach().numpy())
result_sorted = sorted(result, key=lambda x: x[1], reverse=True)

for finding, percentage in result_sorted:
  print(f"{finding}: {percentage * 100:.0f}%")