NVlabs / sionna

Sionna: An Open-Source Library for Next-Generation Physical Layer Research
https://nvlabs.github.io/sionna

Usage of PyTorch model along with Sionna #571

Open ShreyasKulkarni19 opened 1 week ago

ShreyasKulkarni19 commented 1 week ago

Hello, I'm trying to implement a PyTorch version of the neural receiver model. I have converted the TensorFlow tensors into PyTorch tensors (edno data). The project details cannot be discussed, but I would like to know what you think about the challenges I might face and which other code snippets need to change when converting my TensorFlow model to PyTorch.

jhoydis commented 1 week ago

Hi @ShreyasKulkarni19,

I do not think that there should be any unexpected difficulty.

hafezmg48 commented 1 week ago

@jhoydis So if there is a chance, please consider updating Sionna's Keras dependency to v3. I believe that in Keras 3 it is possible to select the backend between PyTorch and TF, which would be a huge step towards compatibility with different projects. I hope that would be a feasible thing to do...
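
For reference, a minimal sketch of how backend selection works in standalone Keras 3 (not something Sionna exposes today):

    import os
    os.environ["KERAS_BACKEND"] = "torch"  # must be set before Keras is imported

    import keras
    print(keras.backend.backend())  # -> "torch"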

ShreyasKulkarni19 commented 1 week ago

The outputs of the Sionna mapper, demapper, etc. are all TensorFlow tensors. My approach is to convert those into PyTorch tensors and then feed them into the PyTorch model, but I'm facing issues with dimensions and TypeErrors.
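
Roughly, what I am doing looks like this (the shapes and the demapper output here are just placeholders for my actual pipeline):

    import tensorflow as tf
    import torch

    # Stand-in for a Sionna demapper output: a tf.float32 tensor of LLRs
    llr_tf = tf.random.normal([64, 1024])               # [batch_size, num_bits]

    # Naive conversion: copy through NumPy, detached from the TF graph
    llr_torch = torch.from_numpy(llr_tf.numpy())
    llr_torch = llr_torch.reshape(64, 1, 1024).float()  # match the layout the PyTorch model expects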

jhoydis commented 1 week ago

@jhoydis So if there is a chance, please consider updating Sionna's Keras dependency to v3. I believe that in Keras 3 it is possible to select the backend between PyTorch and TF, which would be a huge step towards compatibility with different projects. I hope that would be a feasible thing to do...

Even with Keras 3, Sionna would only support the TF backend. In order to support multiple backends, Sionna would need to be ported to the Keras API. One of the reasons why Sionna is written in TF is its support for complex-valued tensors, which other frameworks are lacking; see #486.
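
For context, a minimal illustration of the kind of native complex-valued tensor operations Sionna builds on (hypothetical shapes):

    import tensorflow as tf

    # complex64 channel matrix and symbol vector built from real and imaginary parts
    h = tf.complex(tf.random.normal([4, 4]), tf.random.normal([4, 4]))
    x = tf.complex(tf.random.normal([4, 1]), tf.random.normal([4, 1]))

    y = tf.matmul(h, x)                  # complex matrix-vector product
    z = tf.matmul(h, y, adjoint_a=True)  # conjugate-transpose (Hermitian) via adjoint_a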

ShreyasKulkarni19 commented 1 week ago

Because I'm working on building a PyTorch wrapper for Sionna, my approach is to convert all the TensorFlow tensors to PyTorch tensors and then feed them to the PyTorch model. What do you think about this? I always face problems when converting the complex-valued tensors.

jhoydis commented 1 week ago

The problem with such a wrapper is that you cannot compile TF and PyTorch together into a single graph. That means that your code will always be less efficient than a native implementation in one of the frameworks. You will also lose gradient information.

You might want to have a look at DLPack, which should allow you to access TF tensors from PyTorch: https://www.tensorflow.org/api_docs/python/tf/experimental/dlpack/from_dlpack and https://pytorch.org/docs/stable/dlpack.html

Here is a small code snippet that shows how this works in principle:

import tensorflow as tf
import torch

# Assume `tf_tensor` is your TensorFlow tensor on the GPU
tf_tensor = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32)

# Convert TensorFlow tensor to DLPack
dlpack_tensor = tf.experimental.dlpack.to_dlpack(tf_tensor)

# Convert DLPack tensor to PyTorch tensor
torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
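
Note that DLPack only shares the underlying tensor buffer between the two frameworks; no graph or gradient information crosses the boundary, so back-propagation still stops at the conversion.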

ShreyasKulkarni19 commented 4 days ago

Hi @jhoydis, here is the section I'm working on. I have the model in PyTorch and I'm converting the tensors using DLPack as suggested. I am kinda stuck here. Not sure how to proceed.

Sionna_Pytorch.ipynb - Colab.pdf

jhoydis commented 4 days ago

Hi @ShreyasKulkarni19,

As I have said above, I do not recommend mixing these frameworks and cannot help here.

hafezmg48 commented 1 day ago

I am kinda stuck here. Not sure how to proceed.

Sionna_Pytorch.ipynb - Colab.pdf

This is likely the thing you are looking for:

    import torch
    import tensorflow as tf

    # Example input on the PyTorch side
    torch_tensor = torch.tensor([1.0, 2.0, 3.0])

    # PyTorch -> TensorFlow via DLPack
    dlpack_tensor = torch.utils.dlpack.to_dlpack(torch_tensor)
    tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)

    # Placeholder for your TensorFlow/Sionna processing
    tf_tensor = tf.square(tf_tensor)

    # TensorFlow -> PyTorch via DLPack
    dlpack_tensor = tf.experimental.dlpack.to_dlpack(tf_tensor)
    torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)

ShreyasKulkarni19 commented 1 day ago

@hafezmg48 thank you! I believe I already use this in the code I shared above, yet I'm still facing some errors. I also want to know how I can handle the gradients, because directly converting the tensors loses the gradient information during back-propagation. Any help with that is highly appreciated!
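
One pattern that might work here (a rough sketch only, assuming real-valued float32 CPU tensors and using tf.square as a stand-in for the actual TF computation) is to wrap the TF part in a custom torch.autograd.Function and let tf.GradientTape supply the backward pass:

    import tensorflow as tf
    import torch

    class TFBridge(torch.autograd.Function):
        # Runs a TF computation in forward and uses tf.GradientTape for backward.

        @staticmethod
        def forward(ctx, x_torch):
            x_tf = tf.convert_to_tensor(x_torch.detach().cpu().numpy())
            with tf.GradientTape() as tape:
                tape.watch(x_tf)
                y_tf = tf.square(x_tf)  # stand-in for the TF/Sionna processing
            # Keep the tape and tensors around for the backward pass
            ctx.tape, ctx.x_tf, ctx.y_tf = tape, x_tf, y_tf
            return torch.from_numpy(y_tf.numpy())

        @staticmethod
        def backward(ctx, grad_output):
            # Chain rule: ask TF for dL/dx given dL/dy coming from PyTorch
            grad_y_tf = tf.convert_to_tensor(grad_output.cpu().numpy())
            grad_x_tf = ctx.tape.gradient(ctx.y_tf, ctx.x_tf, output_gradients=grad_y_tf)
            return torch.from_numpy(grad_x_tf.numpy())

    # Usage: gradients flow from the PyTorch loss back to the input
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = TFBridge.apply(x)
    y.sum().backward()
    print(x.grad)  # tensor([2., 4., 6.])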