budai4medtech / xfetus

xfetus -- [x]ynthetic[fetus] (:baby: :robot:) -- A library for ultrasound fetal imaging synthesis using techniques from GANs, transformers, and diffusion models.

Replicating DSRGAN.ipynb and using other planes (Trans-ventricular, Trans-thalamic, and Trans-cerebellum) #40

Open mxochicale opened 1 year ago

mxochicale commented 1 year ago

Qingyu Yang reported the following error: [screenshot of the error]

Also, I noticed that the notebook is not up to date (`!pip install -qqq medisynth`): https://github.com/budai4medtech/xfetus/blob/main/examples/difussion-super-resolution-gan/DSRGAN.ipynb

mxochicale commented 1 year ago

Sorted out: [Screenshot from 2023-08-14 22-07-00]

However, I will leave this issue open in case there is further feedback from Qingyu Yang.

Qingyu-Yang1 commented 1 year ago

When calculating FID, should I compute it for the GAN, for the diffusion model, or for both?

mxochicale commented 1 year ago

The FID score is computed by the following lines from this notebook: https://github.com/budai4medtech/xfetus/blob/main/examples/difussion-super-resolution-gan/DSRGAN.ipynb (also available in Google Colab: https://colab.research.google.com/drive/1Cbudr2g5qdC2LGBj_xYS-amgJQ-6OVM6?usp=sharing). I would suggest running the whole notebook in your Google Colab, then reading sr_gan_loss.csv and plotting the FID values. Please share your Colab notebook so I can have a look at your progress.

It would be great if you could compute FID for both the GAN and the diffusion model.

Let me know how it goes. Thanks, --Miguel

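# Excerpt from the training loop in DSRGAN.ipynb. fid, netG, device, original_images,
# fake_images, epoch, total_g_loss and total_d_loss are defined earlier in the notebook.
import torch
from csv import writer

        # Calculate FID score using unaugmented images and fake images
        fake_images = torch.from_numpy(fake_images).to(device)
        fid.update(original_images.byte(), real=True)   # real (unaugmented) images
        fid.update(fake_images.byte(), real=False)      # generated images
        current_fid = fid.compute().item()

        # Save generator weights, tagged with the epoch number
        torch.save(netG.state_dict(), "SRGAN_G_x256" + str(epoch))

    # Append one row per epoch to the log: epoch, mean G loss, mean D loss, FID
    # (236 is the divisor used in the notebook, presumably the number of batches per epoch)
    with open('sr_gan_loss.csv', 'a', newline='') as f_object:
        writer_object = writer(f_object)
        writer_object.writerow([str(epoch), str(total_g_loss / 236), str(total_d_loss / 236), str(current_fid)])
    # The `with` block closes the file; reset the FID state before the next epoch
    fid.reset()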
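For reference, the value returned by `fid.compute()` is the Fréchet distance between two Gaussians fitted to image features of the real and the generated sets. As far as I know, torchmetrics' `FrechetInceptionDistance` extracts those features with an Inception-v3 network; the minimal numpy sketch below skips that step and assumes you already have `(N, D)` feature arrays. It is not the torchmetrics implementation, and the function names are hypothetical:

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # drop tiny negative values caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Frechet distance between Gaussians fitted to two (N, D) feature arrays.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2))
    """
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    # atleast_2d keeps the D == 1 case as a 1x1 matrix (np.cov would squeeze it)
    cov_r = np.atleast_2d(np.cov(feats_real, rowvar=False))
    cov_f = np.atleast_2d(np.cov(feats_fake, rowvar=False))
    # Tr((S_r S_f)^(1/2)) equals Tr((S_r^(1/2) S_f S_r^(1/2))^(1/2)); the inner
    # matrix is symmetric PSD, so the eigh-based square root applies.
    sqrt_r = _sqrtm_psd(cov_r)
    covmean_trace = np.trace(_sqrtm_psd(sqrt_r @ cov_f @ sqrt_r))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * covmean_trace)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(500, 8))
    close = rng.normal(size=(500, 8))             # same distribution as `real`
    shifted = rng.normal(loc=1.0, size=(500, 8))  # mean shifted by 1 in every dimension
    print(frechet_distance(real, close), frechet_distance(real, shifted))
```

To compare the GAN and the diffusion model, compute the distance for each model's samples against the same set of real features; a lower value means the generated distribution is closer to the real one.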
Qingyu-Yang1 commented 1 year ago

When I run the notebook, only one FID value appears, and after running it for three labels I get only three FID values. In this case, how can I plot the FID values to reproduce that figure?

Thank you for taking the time to answer my question.

--Qingyu Yang

mxochicale commented 1 year ago

Hi Qingyu, you need to read sr_gan_loss.csv and then plot it with matplotlib. This link might be useful for creating your plots: https://www.tutorialspoint.com/plot-data-from-csv-file-with-matplotlib.
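A minimal sketch of that step, assuming sr_gan_loss.csv has the four unnamed columns written by the notebook snippet above (epoch, mean G loss, mean D loss, FID). The function names and any per-plane file names are hypothetical. Only the epoch and FID columns are parsed, and the plot has one point per CSV row, so it only shows a curve once several epochs have been logged:

```python
import csv


def parse_fid_rows(lines):
    """Return (epochs, fids) from rows of the form: epoch, g_loss, d_loss, fid."""
    epochs, fids = [], []
    for row in csv.reader(lines):
        if not row:
            continue  # skip blank lines
        epochs.append(int(row[0]))
        fids.append(float(row[3]))
    return epochs, fids


def read_fid_log(path):
    """Read one log file such as sr_gan_loss.csv."""
    with open(path, newline="") as f:
        return parse_fid_rows(f)


def plot_fid(logs, out_png="fid_per_epoch.png"):
    """Plot FID against epoch; `logs` maps a legend label to a CSV path."""
    import matplotlib.pyplot as plt  # imported here so the parsers work without matplotlib

    fig, ax = plt.subplots()
    for label, path in logs.items():
        epochs, fids = read_fid_log(path)
        ax.plot(epochs, fids, marker="o", label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("FID (lower is better)")
    ax.legend()
    fig.savefig(out_png, dpi=150)
    return fig
```

For example, `fig = plot_fid({"Trans-thalamic": "sr_gan_loss.csv"})` plots a single run; passing one CSV per plane (e.g. a hypothetical `sr_gan_loss_trans_cerebellum.csv`) overlays the curves for comparison. Because the notebook opens the log in append mode (`'a'`), rerunning it adds rows to the same file, so clear or rename the CSV between runs.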