BMEII-AI / RadImageNet

RadImageNet: pre-trained convolutional neural networks trained solely on medical images, to be used as the basis of transfer learning for medical imaging applications.
MIT License

Load RadImageNet weights for PyTorch #3

Open mligerhe opened 2 years ago

mligerhe commented 2 years ago

I was trying to load the weights from 'RadImageNet-ResNet50_notop.h5' to use them in PyTorch. Is there a version of the RadImageNet pretrained weights for PyTorch?

Also, is it possible to know the TensorFlow and Keras versions used to train these CNNs? Could you provide a requirements.txt file?

innat commented 2 years ago

Also, is it possible to know the tensorflow and keras version used to train those CNNs? Could you provide a requirement.txt file?

Any TensorFlow version from 2.0 onward should be fine. I loaded the model with TF 2.6 in a Kaggle environment without any issues. I uploaded the pretrained weights here for easier use. You can also check this:

https://github.com/BMEII-AI/RadImageNet/blob/main/hemorrhage/hemorrhage_train.py#L125-L126

WuJunde commented 2 years ago

Here is some rough code to convert the Keras RadImageNet-ResNet50 model to PyTorch; hope it helps:

  import numpy as np
  import tensorflow as tf
  import torch
  import torchvision.transforms as transforms
  from PIL import Image
  from torchvision.models import resnet50

  inputpath = './RadImageNet-ResNet50_notop.h5'
  outpath = './RadImageNet-ResNet50_notop_torch.pth'
  testimg = '../20478.PNG'

  def simple_test(net):
      # Run one test image through the network and return the first output value.
      img = Image.open(testimg).convert('RGB')

      trans_data = transforms.Compose([
          transforms.Resize(224),
          transforms.ToTensor(),
      ])

      img = trans_data(img).unsqueeze(0)
      net.eval()  # use BatchNorm running statistics, not per-batch statistics
      with torch.no_grad():
          out = net(img)
      return out.squeeze(0)[0]

  def keras_to_pyt(km):
      # Collect Keras weights in PyTorch layout, keyed by Keras layer name.
      weight_dict = dict()
      for layer in km.layers:
          name = layer.get_config()['name']
          # Main-path convolutions only: names containing '0' are the shortcut
          # layers (conv*_block*_0_conv), which this script does not convert.
          if isinstance(layer, tf.keras.layers.Conv2D) and '0' not in name:
              # Keras kernel (kh, kw, in, out) -> PyTorch weight (out, in, kh, kw)
              weight_dict[name + '.weight'] = np.transpose(layer.get_weights()[0], (3, 2, 0, 1))
              # weight_dict[name + '.bias'] = layer.get_weights()[1] as mean (not done here)
          elif isinstance(layer, tf.keras.layers.Dense):
              # Keras kernel (in, out) -> PyTorch weight (out, in)
              weight_dict[name + '.weight'] = np.transpose(layer.get_weights()[0], (1, 0))
              weight_dict[name + '.bias'] = layer.get_weights()[1]
      return weight_dict

  net = resnet50(num_classes=1)
  out = simple_test(net)
  print('before output is', out)

  tf_keras_model = tf.keras.models.load_model(inputpath)
  weights = keras_to_pyt(tf_keras_model)
  values = list(weights.values())

  # Copy by position: the i-th Keras conv kernel goes into the i-th torchvision
  # parameter whose name contains 'conv'. BatchNorm layers and the downsample
  # branches keep their initial values.
  i = 0
  for name, param in net.named_parameters():
      if 'conv' in name:
          src = torch.from_numpy(values[i])
          assert param.shape == src.shape, (name, param.shape, src.shape)
          param.data.copy_(src)
          i += 1

  out = simple_test(net)
  print('after output is', out)

  torch.save(net.state_dict(), outpath)
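The two `np.transpose` calls in `keras_to_pyt` encode the layout difference between the frameworks: Keras stores a `Conv2D` kernel as (kh, kw, in, out) and a `Dense` kernel as (in, out), while PyTorch's `Conv2d.weight` is (out, in, kh, kw) and `Linear.weight` is (out, in). The NumPy-only sketch below (no TensorFlow or PyTorch needed; the helper names are hypothetical) checks that a naive channels-last convolution with the Keras kernel matches a channels-first convolution with the transposed kernel.

```python
import numpy as np

def conv_hwc_hwio(x, k):
    """Naive 'valid' 2-D cross-correlation in Keras layout: x (H, W, C_in), k (kh, kw, C_in, C_out)."""
    kh, kw, _, c_out = k.shape
    height_out, width_out = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    y = np.zeros((height_out, width_out, c_out))
    for i in range(height_out):
        for j in range(width_out):
            patch = x[i:i + kh, j:j + kw, :]  # (kh, kw, C_in)
            y[i, j] = np.tensordot(patch, k, axes=([0, 1, 2], [0, 1, 2]))
    return y

def conv_chw_oihw(x, w):
    """The same operation in PyTorch layout: x (C_in, H, W), w (C_out, C_in, kh, kw)."""
    c_out, _, kh, kw = w.shape
    height_out, width_out = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    y = np.zeros((c_out, height_out, width_out))
    for i in range(height_out):
        for j in range(width_out):
            patch = x[:, i:i + kh, j:j + kw]  # (C_in, kh, kw)
            y[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return y

rng = np.random.default_rng(0)
x_hwc = rng.standard_normal((6, 5, 3))
k_keras = rng.standard_normal((3, 3, 3, 4))     # (kh, kw, in, out)
w_torch = np.transpose(k_keras, (3, 2, 0, 1))   # (out, in, kh, kw)

y_keras = conv_hwc_hwio(x_hwc, k_keras)                             # (H', W', out)
y_torch = conv_chw_oihw(np.transpose(x_hwc, (2, 0, 1)), w_torch)    # (out, H', W')
print(np.allclose(np.transpose(y_torch, (1, 2, 0)), y_keras))

# Dense: Keras computes v @ K with K (in, out); PyTorch Linear computes v @ W.T with W (out, in).
k_dense = rng.standard_normal((8, 2))
w_linear = np.transpose(k_dense, (1, 0))
v = rng.standard_normal(8)
print(np.allclose(v @ k_dense, v @ w_linear.T))
```

Both checks print `True`; the transposes alone are correct, so any remaining mismatch after conversion has to come from the layers the script does not copy.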
Warvito commented 1 year ago

It might be worth having official support for PyTorch weights, ideally as easy to use as the pre-trained models from torchvision.models or the timm package (https://github.com/rwightman/pytorch-image-models).

Warvito commented 1 year ago

(Quoting WuJunde's conversion script above.)

I think this script is missing the conversion of the batchnorm layers, right?

WuJunde commented 1 year ago

(Quoting the conversion script and Warvito's question above about the batchnorm layers.)

Hi, author of the brain generation diffusion work :) We had a little chat on Twitter about "generating the brain with lesions", remember? About your question: I tried to transfer the batchnorm parameters but failed (I vaguely remember it could not be done technically, due to the different implementations of ResNet in Keras and PyTorch), and I also found the influence was not big (at least on my own dataset).

Warvito commented 1 year ago

Hi Junde! Yes, I do (btw, congratulations on your MedSegDiff paper! It looks great ^^). Thanks for sharing your experience! I have been trying to change the backend of these models and have found several differences between the Keras Applications implementation and the torchvision one (for example, the Keras convolutions have biases, and I guess the Keras model is based on ResNet v1, while torchvision's is based on v1.5). I hope to add more details here in the next few days.

For the batchnorm, I have been using:

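import torch

def convert_bn(pytorch_bn, tf_bn):
    # Copy inference-time BatchNorm state from a tf.keras BatchNormalization
    # layer into a torch.nn.BatchNorm2d (assumes scale=True and center=True).
    with torch.no_grad():
        pytorch_bn.weight.copy_(torch.from_numpy(tf_bn.gamma.numpy()))
        pytorch_bn.bias.copy_(torch.from_numpy(tf_bn.beta.numpy()))
        pytorch_bn.running_mean.copy_(torch.from_numpy(tf_bn.moving_mean.numpy()))
        pytorch_bn.running_var.copy_(torch.from_numpy(tf_bn.moving_variance.numpy()))
    pytorch_bn.eps = tf_bn.epsilon  # Keras and PyTorch default epsilons differ
    return pytorch_bn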
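At inference time BatchNorm computes y = gamma * (x - mean) / sqrt(var + eps) + beta, so `epsilon` has to match as well (as far as I know, Keras' ResNet50 builds its BN layers with `epsilon=1.001e-5`, while `torch.nn.BatchNorm2d` defaults to `1e-5`). As for the conv biases that torchvision lacks: each biased Keras convolution is followed by BN, so the bias b can be absorbed into the running mean (mean' = mean - b). A NumPy sketch of that identity; the function names are hypothetical.

```python
import numpy as np

def bn_inference(x, gamma, beta, mean, var, eps):
    """BatchNorm in inference mode over the last (channel) axis."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def fold_conv_bias_into_bn_mean(conv_bias, running_mean):
    """BN(z + b) with mean m equals BN(z) with mean m - b, so the conv bias can be dropped."""
    return running_mean - conv_bias

rng = np.random.default_rng(0)
channels = 4
z = rng.standard_normal((10, channels))   # conv output without bias (channels last)
b = rng.standard_normal(channels)          # Keras conv bias
gamma, beta = rng.standard_normal(channels), rng.standard_normal(channels)
mean, var = rng.standard_normal(channels), rng.random(channels) + 0.5
eps = 1.001e-5

with_bias = bn_inference(z + b, gamma, beta, mean, var, eps)
folded = bn_inference(z, gamma, beta, fold_conv_bias_into_bn_mean(b, mean), var, eps)
print(np.allclose(with_bias, folded))
```

In a conversion, this would mean setting `running_mean` to `moving_mean - conv_bias` for each BN that follows a biased Keras convolution.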
ArielKes commented 1 year ago

I am trying to convert the InceptionV3 weights to PyTorch to extract features for computing FID. Does anybody know what changes I need to make?
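For InceptionV3, one likely change, assuming the RadImageNet model follows `keras.applications.InceptionV3`: its `conv2d_bn` helper builds convolutions with `use_bias=False` and BatchNormalization with `scale=False`, so those BN layers have no gamma and `tf_bn.gamma` in `convert_bn` above would be `None`. Keras' `layer.get_weights()` returns only the variables that exist, in the order gamma, beta, moving_mean, moving_variance. A NumPy sketch (hypothetical helper name) that maps that list to PyTorch's `weight`/`bias`/`running_mean`/`running_var`, filling a missing gamma with ones:

```python
import numpy as np

def keras_bn_to_torch(weights, scale=True, center=True):
    """Map Keras BatchNormalization.get_weights() to PyTorch BatchNorm2d values (as NumPy).

    Keras only stores the variables that exist, in the order
    [gamma (if scale), beta (if center), moving_mean, moving_variance].
    """
    weights = list(weights)
    gamma = weights.pop(0) if scale else None
    beta = weights.pop(0) if center else None
    moving_mean, moving_variance = weights  # raises if the flags do not match the list
    channels = moving_mean.shape[0]
    return {
        "weight": gamma if gamma is not None else np.ones(channels, dtype=moving_mean.dtype),
        "bias": beta if beta is not None else np.zeros(channels, dtype=moving_mean.dtype),
        "running_mean": moving_mean,
        "running_var": moving_variance,
    }

# An InceptionV3-style BN layer built with scale=False has three weight arrays.
beta = np.array([0.1, -0.2], dtype=np.float32)
mean = np.array([0.5, 1.0], dtype=np.float32)
var = np.array([2.0, 3.0], dtype=np.float32)
params = keras_bn_to_torch([beta, mean, var], scale=False)
print(params["weight"])
```

Each value can then be copied into the matching `BatchNorm2d` parameter or buffer with `torch.from_numpy(...)`; the BN `epsilon` should be carried over as well.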

NivAm12 commented 1 year ago

(Quoting WuJunde's conversion script above.)

Hey, are there any updates on PyTorch weights? This solution doesn't work for me; some weights are missing.
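The missing weights follow from how the conversion script above works: it copies only main-path convolution kernels (it skips Keras layers whose name contains '0', i.e. the `_0_conv` shortcuts) into torchvision parameters whose names contain 'conv', so the `downsample` branches and every BatchNorm keep their freshly initialized values. Matching layers by name instead of by position makes such gaps explicit. A sketch, assuming the standard `keras.applications` ResNet50 names (`conv1_conv`, `conv2_block1_0_conv`, `conv5_block3_3_bn`, ...) and torchvision's module names; the helper is hypothetical.

```python
import re

def keras_resnet_name_to_torchvision(name):
    """Map a Keras ResNet50 conv/BN layer name to a torchvision resnet50 module name.

    Returns None for layers without a torchvision counterpart (padding, activations, ...).
    """
    if name == "conv1_conv":
        return "conv1"
    if name == "conv1_bn":
        return "bn1"
    m = re.fullmatch(r"conv(\d)_block(\d+)_(\d)_(conv|bn)", name)
    if m is None:
        return None
    stage, block, idx, kind = int(m[1]), int(m[2]), int(m[3]), m[4]
    prefix = f"layer{stage - 1}.{block - 1}"  # Keras conv2..conv5 -> torchvision layer1..layer4
    if idx == 0:  # shortcut branch: Conv2d is downsample.0, BatchNorm2d is downsample.1
        return f"{prefix}.downsample.{0 if kind == 'conv' else 1}"
    return f"{prefix}.{kind}{idx}"

for keras_name in ["conv1_conv", "conv2_block1_0_conv", "conv3_block2_2_bn", "conv5_block3_3_conv", "pool1_pad"]:
    print(keras_name, "->", keras_resnet_name_to_torchvision(keras_name))
```

With the returned name, `net.get_submodule(...)` (available in recent PyTorch versions) can locate the target module; any torchvision parameter or buffer that no Keras layer maps to is a weight the conversion would leave missing.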

yyama17 commented 1 year ago

(Quoting WuJunde's conversion script and NivAm12's question above.)

It is true that the weights are not converted properly with this method. I tried it in a Kaggle environment and there was a big gap between the outputs of the original and the converted model: https://www.kaggle.com/code/yosukeyama/convert-radimagenet-to-pytorch-model

On the other hand, it seems that a correctly converted PyTorch model can be obtained if the conversion is done via ONNX. Here is my code: https://www.kaggle.com/code/yosukeyama/onnx-convert-radimagenet-to-pth?scriptVersionId=120685288
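Whichever route is used, the output comparison described in this thread can be made explicit: run the same preprocessed image through the Keras model and the converted PyTorch model and compare the results. A minimal NumPy sketch (the function name and tolerances are hypothetical choices). It accounts for the layout difference: Keras feature maps are channels-last, PyTorch ones channels-first.

```python
import numpy as np

def compare_outputs(keras_out, torch_out, atol=1e-4, rtol=1e-3):
    """Compare a Keras output with a converted PyTorch output.

    4-D feature maps are moved from PyTorch NCHW to Keras NHWC before comparing.
    Returns (max absolute difference, cosine similarity, allclose flag).
    """
    a = np.asarray(keras_out, dtype=np.float64)
    b = np.asarray(torch_out, dtype=np.float64)
    if b.ndim == 4:
        b = np.transpose(b, (0, 2, 3, 1))  # NCHW -> NHWC
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    max_abs = float(np.max(np.abs(a - b)))
    cosine = float(np.dot(a.ravel(), b.ravel()) / (np.linalg.norm(a) * np.linalg.norm(b)))
    return max_abs, cosine, bool(np.allclose(a, b, atol=atol, rtol=rtol))

# Identical features in the two layouts should match exactly.
rng = np.random.default_rng(0)
feat_nhwc = rng.standard_normal((1, 7, 7, 8))
feat_nchw = np.transpose(feat_nhwc, (0, 3, 1, 2))
print(compare_outputs(feat_nhwc, feat_nchw))
```

In practice `keras_out` would come from `tf_keras_model.predict(x)` and `torch_out` from `net(x_t).detach().numpy()`, with both models fed identically preprocessed inputs and the PyTorch model in `eval()` mode.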