xiawj-hub / CTLIB

A lib of CT projector and back-projector based on PyTorch
MIT License

I got RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn using CTLIB library. #12

Open 2sjkim opened 1 year ago

2sjkim commented 1 year ago

Hello, Teacher. I really appreciate your great CT geometry library.

I am implementing "deep image prior" with the CTLIB library (specifically, for sparse-view CT reconstruction), but I ran into an error.

For simplicity, here is only the network-training part of my code:

```python
...
input_train = data['input'].to(device)    # Gaussian noise input
target_train = data['target'].to(device)  # sparse-view image

output = net(input_train)  # simple UNet

# forward projection with sparse views (reduced number of projection views)
out_for_loss = ctlib.projection(output, option_sparse)
# filtered back-projection of the sparse-view sinogram
out_for_loss = ctlib.fbp(out_for_loss, option_sparse)

optim.zero_grad()
loss = fn_loss(out_for_loss, target_train)
loss.backward()
optim.step()

train_total_loss += [loss.item()]
...
```

What I want to do is turn the network output into a sparse-view CT image by applying forward projection and then filtered back-projection (`ctlib.fbp`), both with `option_sparse`, and then compute the MSE loss against the target image.
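Written out, the objective shows why gradients must flow back through both CTLIB calls. The notation is assumed for illustration: $P$ is the linear sparse-view forward projection performed by `ctlib.projection` with `option_sparse`, $F$ is the linear filtered back-projection performed by `ctlib.fbp`, $x = f_\theta(z)$ is the UNet output for the fixed Gaussian input $z$, $y$ is the target image, $N$ is the number of pixels (mean-reduced MSE), and $J_\theta = \partial x / \partial \theta$.

$$
\mathcal{L}(\theta) = \frac{1}{N}\,\lVert F P x - y \rVert_2^2, \qquad x = f_\theta(z)
$$

$$
\nabla_x \mathcal{L} = \frac{2}{N}\, P^{\top} F^{\top} \left(F P x - y\right), \qquad
\nabla_\theta \mathcal{L} = J_\theta^{\top}\, \nabla_x \mathcal{L}
$$

Autograd supplies $J_\theta^{\top}$ through the UNet automatically, but it can only apply $P^{\top}$ and $F^{\top}$ if the two CTLIB calls are registered as autograd operations whose backward passes apply those adjoints.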

However, `loss.backward()` fails with "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn", so no gradients are computed.
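The error can be reproduced without CTLIB. The sketch below assumes that the extension call returns a tensor autograd has no record of, which is what the error message indicates. A hypothetical `opaque_op` that round-trips through NumPy stands in for `ctlib.projection`/`ctlib.fbp`, and a small `Linear` layer stands in for the UNet:

```python
import numpy as np
import torch

# Hypothetical stand-in for a compiled call such as ctlib.projection: the work
# happens outside PyTorch (NumPy here), so autograd cannot record it.
def opaque_op(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(2.0 * x.detach().cpu().numpy())

torch.manual_seed(0)
net = torch.nn.Linear(4, 4)          # stands in for the UNet
out = net(torch.randn(1, 4))
print(out.grad_fn is not None)       # True: still attached to the graph

y = opaque_op(out)
print(y.requires_grad, y.grad_fn)    # False None: the graph is cut here

loss = torch.nn.functional.mse_loss(y, torch.zeros_like(y))
try:
    loss.backward()
    message = ""
except RuntimeError as err:
    message = str(err)
print(message)  # the same kind of RuntimeError as reported in this issue
```

Everything computed after `opaque_op` is detached from `net`, so the loss has no `grad_fn` to backpropagate through.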

Answers elsewhere suggest writing `loss = Variable(loss, requires_grad=True)`.

I tried that, and the RuntimeError disappeared, but the loss was no longer updated (it stayed constant across epochs).
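Why the loss stays constant can be shown with a small sketch (the helper `opaque_op` and the `Linear` model are hypothetical stand-ins for a CTLIB call and the UNet). Re-wrapping the loss creates a brand-new leaf tensor, so `backward()` stops at the loss itself and never reaches the network parameters; the optimizer then has nothing to update:

```python
import numpy as np
import torch

def opaque_op(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical non-autograd operation (NumPy round trip) standing in for a CTLIB call.
    return torch.from_numpy(2.0 * x.detach().cpu().numpy())

torch.manual_seed(0)
net = torch.nn.Linear(4, 4)                        # stands in for the UNet
optim = torch.optim.SGD(net.parameters(), lr=0.1)
inp, target = torch.randn(1, 4), torch.zeros(1, 4)
before = net.weight.detach().clone()

loss = torch.nn.functional.mse_loss(opaque_op(net(inp)), target)
# The suggested workaround: re-wrap the loss as a new leaf that requires grad.
# loss.detach().requires_grad_(True) has essentially the same effect as the
# legacy Variable(loss, requires_grad=True).
loss = loss.detach().requires_grad_(True)

optim.zero_grad()
loss.backward()   # no error any more...
optim.step()

print(loss.grad)                         # tensor(1.): the gradient stops at the loss
print(net.weight.grad)                   # None: nothing reached the network
print(torch.equal(before, net.weight))   # True: the parameters never change
```

The workaround silences the error but does not reconnect the graph, which matches the constant loss observed in training.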

I suspect the error comes from computing the loss directly on the outputs of `ctlib.projection` and `ctlib.fbp`.

How can I solve this problem?

xiawj-hub commented 1 year ago

You need to wrap the functions with torch.autograd (as a custom `torch.autograd.Function`); see the details at https://pytorch.org/docs/stable/notes/extending.html.
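Here is a minimal sketch of such a wrapper, under stated assumptions. CTLIB itself is not called, so a hypothetical fixed matrix `A` (with hypothetical helpers `project_np` and `project_adjoint_np`) stands in for the projector. The essential point is that `backward` applies the adjoint (transpose) of the forward operator. With CTLIB, `forward` would call `ctlib.projection(image, option_sparse)` and `backward` would call whichever routine implements its transpose. Which routine that is remains an assumption to check in the CTLIB README, since this thread does not name it.

```python
import numpy as np
import torch

# Hypothetical stand-in for a compiled projector: a fixed system matrix A mapping
# a flattened image (16 pixels) to a flattened sinogram (12 measurements).
rng = np.random.default_rng(0)
A = rng.standard_normal((12, 16))

def project_np(img):           # forward projection: A @ x (outside autograd)
    return A @ img

def project_adjoint_np(sino):  # adjoint / transpose: A.T @ y
    return A.T @ sino

class Projection(torch.autograd.Function):
    @staticmethod
    def forward(ctx, image):
        out = project_np(image.detach().cpu().numpy())
        return torch.from_numpy(out).to(image)      # back to the input's dtype/device

    @staticmethod
    def backward(ctx, grad_out):
        # For a linear operator y = A x, dL/dx = A^T dL/dy.
        grad_in = project_adjoint_np(grad_out.detach().cpu().numpy())
        return torch.from_numpy(grad_in).to(grad_out)

# Usage: the wrapped call now sits on the autograd graph.
torch.manual_seed(0)
net = torch.nn.Linear(16, 16)                    # stands in for the UNet
target = torch.randn(12)
sino = Projection.apply(net(torch.randn(16)))
loss = torch.nn.functional.mse_loss(sino, target)
loss.backward()
print(sino.grad_fn is not None, net.weight.grad is not None)  # True True
```

Because projection is linear, `backward` needs no saved tensors. For the projection-then-FBP pipeline in this issue, `ctlib.fbp` would get its own wrapper, whose `backward` must apply the adjoint of the FBP operator; whether CTLIB exposes that adjoint directly is not stated in this thread.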

You can also look at my other repositories to see how to wrap the functions.
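However the functions are wrapped, it is worth checking that `backward` really is the adjoint of `forward`, because a mismatched transpose trains silently with wrong gradients. This sketch again uses a hypothetical matrix `A` in place of a CTLIB projector. It relies on two standard checks: `torch.autograd.gradcheck`, which compares `backward` against finite differences, and the dot-product test $\langle Ax, y\rangle = \langle x, A^{\top}y\rangle$.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
A = rng.standard_normal((12, 16))  # hypothetical stand-in system matrix

class Projection(torch.autograd.Function):
    @staticmethod
    def forward(ctx, image):
        return torch.from_numpy(A @ image.detach().cpu().numpy()).to(image)

    @staticmethod
    def backward(ctx, grad_out):
        return torch.from_numpy(A.T @ grad_out.detach().cpu().numpy()).to(grad_out)

# gradcheck compares backward() with finite differences of forward();
# it expects double-precision inputs that require grad.
x = torch.randn(16, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(Projection.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)  # True

# Dot-product (adjoint) test: <A x, y> should equal <x, A^T y>.
xn, yn = rng.standard_normal(16), rng.standard_normal(12)
print(np.isclose(np.dot(A @ xn, yn), np.dot(xn, A.T @ yn)))  # True
```

With a real projector, the dot-product test would use the library's forward and transposed calls in place of `A @` and `A.T @`.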

From: 2sjkim Sent: Monday, February 6, 2023 4:41 AM To: xwj01/CTLIB Cc: Subscribed Subject: [xwj01/CTLIB] I got RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn using CTLIB library. (Issue #12)
