bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

Error in the ``evaluate`` function in ``Tutorial_Perturbation.ipynb`` #172

Open Yonggie opened 3 months ago

Yonggie commented 3 months ago

I loaded the pretrained parameters from the full-human model link, ran the tutorial, and got the following error:

def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0

    with torch.no_grad():
        for batch, batch_data in enumerate(val_loader):
            batch_size = len(batch_data.y)
            batch_data.to(device)
            x: torch.Tensor = batch_data.x  # (batch_size * n_genes, 2)
            ori_gene_values = x[:, 0].view(batch_size, n_genes)
            pert_flags = x[:, 1].long().view(batch_size, n_genes)  # error here

Exception has occurred: IndexError
index 1 is out of bounds for dimension 1 with size 1
  File "/p", line 224, in evaluate
    pert_flags = x[:, 1].long().view(batch_size, n_genes)
  File "t.py", line 282, in <module>
    val_loss, val_mre = evaluate(
IndexError: index 1 is out of bounds for dimension 1 with size 1
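The traceback says the node-feature matrix x has only one column, while the tutorial reads column 1 as the perturbation flag. A plausible cause (an assumption, not confirmed in this thread) is that the installed GEARS version builds x differently from the version the tutorial was written against, which fits the reinstall described in the next comment. The sketch below reproduces the two failing lines with a shape check that fails with a clearer message. NumPy stands in for torch (reshape plays the role of .view, astype(np.int64) of .long()), and the helper name split_node_features is hypothetical.

```python
import numpy as np


def split_node_features(x: np.ndarray, batch_size: int, n_genes: int):
    """Split a flattened (batch_size * n_genes, n_cols) feature matrix.

    Mirrors the tutorial's evaluate(): column 0 is taken as expression
    values and column 1 as perturbation flags.
    """
    if x.ndim != 2 or x.shape[0] != batch_size * n_genes:
        raise ValueError(f"expected {batch_size * n_genes} rows, got shape {x.shape}")
    if x.shape[1] < 2:
        # The situation behind the IndexError: there is no column 1 to read.
        raise ValueError(
            f"x has {x.shape[1]} column(s); the tutorial expects 2 "
            "(expression, perturbation flag). Check the installed GEARS version."
        )
    ori_gene_values = x[:, 0].reshape(batch_size, n_genes)
    pert_flags = x[:, 1].astype(np.int64).reshape(batch_size, n_genes)
    return ori_gene_values, pert_flags


# Two cells, three genes; each row is (expression, 0/1 flag).
x_ok = np.array([[0.5, 0], [1.2, 1], [0.0, 0],
                 [2.0, 0], [0.3, 0], [0.9, 1]])
values, flags = split_node_features(x_ok, batch_size=2, n_genes=3)
print(values.shape, flags.tolist())

# A single-column matrix reproduces the failure with an explicit message.
try:
    split_node_features(x_ok[:, :1], batch_size=2, n_genes=3)
except ValueError as err:
    print(err)
```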

Yonggie commented 3 months ago

I tried pip install cell-gears==0.0.1, then deleted the previously downloaded data so that it would be downloaded again.
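After pinning a version, it can help to confirm which version the running interpreter actually sees, since a notebook kernel may use a different environment from the shell where pip ran. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); the distribution name cell-gears and the helper names installed_version and check_pin are assumptions for illustration.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_pin(dist_name: str, expected: str) -> bool:
    """Return True only if dist_name is installed at exactly `expected`."""
    found = installed_version(dist_name)
    if found is None:
        print(f"{dist_name} is not installed")
        return False
    if found != expected:
        print(f"{dist_name}=={found} is installed, expected {expected}")
        return False
    return True


# The thread pins the GEARS distribution at 0.0.1.
check_pin("cell-gears", "0.0.1")
```

Running this in the same kernel as the tutorial shows whether the reinstall took effect before any cached data is rebuilt.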

That produced a different error:

# code
train_loader = pert_data.dataloader["train_loader"]

# hint
'PertData' object has no attribute 'dataloader'
AttributeError: 'PertData' object has no attribute 'dataloader'

Yonggie commented 3 months ago

I checked the source code of gears==0.0.1, and an extra change is needed as well: replace pert_data.dataloader[xxx] with pert_data.get_dataloader(**required_para)[xxx].
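If the same notebook has to run against more than one GEARS version, a small wrapper can accept both behaviours: get_dataloader(...) returning the dict of loaders directly (as reported here for 0.0.1), or storing it on a .dataloader attribute, which is what the original tutorial line expects. This is a minimal sketch; the helper get_loaders, the stand-in classes, and the keyword arguments batch_size and test_batch_size are assumptions, so check the get_dataloader signature of your installed version.

```python
from typing import Any, Dict


def get_loaders(pert_data: Any, **loader_kwargs: Any) -> Dict[str, Any]:
    """Return the dict of dataloaders from a PertData-like object.

    Assumed behaviours:
    * get_dataloader(...) returns the dict directly;
    * get_dataloader(...) returns None and sets .dataloader instead.
    """
    result = pert_data.get_dataloader(**loader_kwargs)
    if isinstance(result, dict):
        return result
    loaders = getattr(pert_data, "dataloader", None)
    if loaders is None:
        raise AttributeError(
            "get_dataloader() returned no dict and no .dataloader attribute was set"
        )
    return loaders


# Hypothetical stand-ins for the two PertData behaviours.
class ReturnsDict:
    def get_dataloader(self, batch_size, test_batch_size=None):
        return {"train_loader": f"train(bs={batch_size})"}


class SetsAttribute:
    def get_dataloader(self, batch_size, test_batch_size=None):
        self.dataloader = {"train_loader": f"train(bs={batch_size})"}


for obj in (ReturnsDict(), SetsAttribute()):
    train_loader = get_loaders(obj, batch_size=64)["train_loader"]
    print(type(obj).__name__, train_loader)
```

With a real PertData object, the call would look like get_loaders(pert_data, batch_size=..., test_batch_size=...)["train_loader"].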

With that change, it worked.