QihuangZhang / CeLEry

CeLEry: cell location recovery in single-cell RNA sequencing
MIT License

Enhance Robustness of Type Handling in fit_functions.py #9

Open · KaishinShaw opened this issue 1 month ago


Description: In `CeLEry/CeLEry_package/CeLEry/fit_functions.py`, line 109 converts `layer_weights` to a NumPy array before creating a PyTorch tensor from it:

```python
layer_weights = torch.tensor(layer_weights.to_numpy())
```

In `tutorial.ipynb`, `layer_weights` is already a `torch.Tensor`. Tensors have no `to_numpy()` method (their equivalent is `.numpy()`), so this line raises an `AttributeError`. The handling of `layer_weights` should therefore accept more than one input type.
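A minimal reproduction of the mismatch, using made-up weights rather than the tutorial's actual data (the shapes and values here are illustrative assumptions):

```python
import pandas as pd
import torch

# Illustrative stand-ins; the tutorial's real layer_weights may differ in shape.
df_weights = pd.DataFrame({"layer": [1.0, 2.0, 3.0]})
tensor_weights = torch.tensor([1.0, 2.0, 3.0])

# pandas objects expose .to_numpy(); torch tensors expose .numpy() instead.
print(hasattr(df_weights, "to_numpy"))      # True
print(hasattr(tensor_weights, "to_numpy"))  # False
print(hasattr(tensor_weights, "numpy"))     # True

# Reproduce what line 109 does when it receives a tensor.
raised = False
try:
    torch.tensor(tensor_weights.to_numpy())
except AttributeError as err:
    raised = True
    print(f"AttributeError: {err}")
```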

Potential Optimization:

```python
import torch
# If no new tensor copy is needed, leave an existing tensor as it is.
# This skips an unnecessary Tensor -> NumPy -> Tensor conversion.
if not isinstance(layer_weights, torch.Tensor):
    layer_weights = torch.tensor(layer_weights.to_numpy())
```
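If the only goal is to make sure `layer_weights` ends up as a tensor, `torch.as_tensor` may be worth considering in place of `torch.tensor`. This is a suggested substitution, not what `fit_functions.py` currently does: per the PyTorch docs, `as_tensor` returns its input unchanged when it is already a tensor with the requested dtype and device, and shares memory with a CPU NumPy array instead of copying it. A small sketch:

```python
import numpy as np
import torch

weights = torch.tensor([0.2, 0.3, 0.5])

# Already a tensor with matching dtype/device: no copy, same underlying storage.
same = torch.as_tensor(weights)
same_storage = same.data_ptr() == weights.data_ptr()
print(same_storage)  # True

# From a NumPy array, as_tensor shares memory with the array.
arr = np.array([0.2, 0.3, 0.5])
shared = torch.as_tensor(arr)
arr[0] = 1.0
print(shared[0].item())  # 1.0
```

The trade-off is aliasing: because no copy is made, later in-place edits to the source are visible through the result, so `torch.tensor` (or an explicit copy) remains the safer choice if `Fit_layer` modifies the weights.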

Proposed Solution: To prevent errors and make the code more robust to different input types, I recommend adding a type check before processing layer_weights. Here is the suggested modification:

```python
import pandas as pd
import torch

if isinstance(layer_weights, torch.Tensor):
    # Already a tensor: make an independent copy without a NumPy round trip.
    # (.numpy() would fail on GPU tensors or tensors that require grad.)
    layer_weights = layer_weights.detach().clone()
elif isinstance(layer_weights, pd.DataFrame):
    # DataFrame: convert to a NumPy array first, then to a tensor.
    layer_weights = torch.tensor(layer_weights.to_numpy())
else:
    # Neither a Tensor nor a DataFrame: fail early with a clear message.
    raise TypeError("layer_weights must be either a pandas DataFrame or a torch Tensor.")
```

Additional Context: With this change, `Fit_layer` accepts `layer_weights` as either a PyTorch tensor or a pandas DataFrame, and any other type fails immediately with a descriptive `TypeError` instead of an `AttributeError` from deep inside the conversion.
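The type check can also be packaged as a small standalone function and tested on its own. The name `to_layer_weight_tensor` is hypothetical and not part of CeLEry. For tensor inputs this sketch uses `detach().clone()` rather than `torch.tensor(x.numpy())`, since `.numpy()` raises for tensors that require grad or live on a GPU:

```python
import pandas as pd
import torch


def to_layer_weight_tensor(layer_weights):
    """Hypothetical helper: normalize layer_weights to a torch.Tensor.

    Accepts a torch.Tensor or a pandas DataFrame, mirroring the proposed
    type check; any other type raises TypeError.
    """
    if isinstance(layer_weights, torch.Tensor):
        # Independent copy, detached from any autograd graph, on the same device.
        return layer_weights.detach().clone()
    if isinstance(layer_weights, pd.DataFrame):
        # DataFrame -> NumPy -> Tensor, as line 109 already does.
        return torch.tensor(layer_weights.to_numpy())
    raise TypeError(
        "layer_weights must be either a pandas DataFrame or a torch Tensor, "
        f"got {type(layer_weights).__name__}"
    )


# Usage: both supported input types yield the same tensor values.
df = pd.DataFrame({"w": [0.1, 0.9]})
from_df = to_layer_weight_tensor(df)
from_tensor = to_layer_weight_tensor(torch.tensor([[0.1], [0.9]], dtype=torch.float64))
print(torch.equal(from_df, from_tensor))  # True
```

Returning a copy for tensor inputs keeps the behavior consistent with the DataFrame branch: in both cases the caller's original object is never modified through the result.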