ML4GLand / EUGENe

Elucidating the Utility of Genomic Elements with Neural Nets
MIT License

ValueError: Expected argument `target` to be an int or long tensor, but got tensor with dtype torch.float32 #46

Open X02cinnamondirty opened 8 months ago

X02cinnamondirty commented 8 months ago
>>> train.fit_sequence_module(
...      model=model,
...      sdata=sd_train,
...      seq_var="ohe_seq",
...      target_vars="id_x",
...      in_memory=True,
...      train_var="train_val",
...      epochs=25,
...      gpus=1,
...      batch_size=9,
...      num_workers=4,
...      prefetch_factor=2,
...      drop_last=False,
...      name="LTRidentity",
...      version="0.75",
...      transforms={"ohe_seq": lambda x: x.swapaxes(1, 2)}
...  )
Dropping 0 sequences with NaN targets.
Loading ohe_seq and ['id_x'] into memory
No seed set
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type            | Params
-------------------------------------------------
0 | arch         | SmallCNN        | 2.8 K
1 | train_metric | MulticlassAUROC | 0
2 | val_metric   | MulticlassAUROC | 0
3 | test_metric  | MulticlassAUROC | 0
-------------------------------------------------
2.8 K     Trainable params
0         Non-trainable params
2.8 K     Total params
0.011     Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/_SequenceModule.py:203: UserWarning: Using a target size (torch.Size([9])) that is different to the input size (torch.Size([9, 9])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  loss = self.loss_fxn(outs, y)  # train
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/train/_fit.py", line 273, in fit_sequence_module
    trainer = fit(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/train/_fit.py", line 123, in fit
    trainer.fit(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 403, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/_SequenceModule.py", line 228, in validation_step
    calculate_metric(
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/eugene/models/base/_metrics.py", line 53, in calculate_metric
    metric(outs, y)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 298, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 367, in _forward_reduce_state_update
    self.update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/metric.py", line 460, in wrapped_func
    update(*args, **kwargs)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/classification/precision_recall_curve.py", line 345, in update
    _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
  File "/home/xiongzx/.conda/envs/eugene/lib/python3.9/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 394, in _multiclass_precision_recall_curve_tensor_validation
    raise ValueError(
**ValueError: Expected argument `target` to be an int or long tensor, but got tensor with dtype torch.float32**

But the targets in `sd_train` are integers:

>>> sd_train
<xarray.Dataset>
Dimensions:    (_sequence: 20466, _length: 300, length: 300, _ohe: 4)
Dimensions without coordinates: _sequence, _length, length, _ohe
Data variables:
    qseqid_x   (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    seq        (_sequence, _length) |S1 dask.array<chunksize=(91, 300), meta=np.ndarray>
    set        (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    spe        (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    sseqid     (_sequence) object dask.array<chunksize=(91,), meta=np.ndarray>
    id         (_sequence) <U8 'seq00000' 'seq00001' ... 'seq22738' 'seq22739'
    ohe_seq    (_sequence, length, _ohe) uint8 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0
    id_x       (_sequence) int8 0 1 1 0 1 0 0 0 1 0 0 ... 1 1 1 1 1 1 1 0 1 0 1
    train_val  (_sequence) bool False True True False ... True False True False
    target     (_sequence) int8 0 1 1 0 1 0 0 0 1 0 0 ... 1 1 1 1 1 1 1 0 1 0 1
>>> sd_train['target']
<xarray.DataArray 'target' (_sequence: 20466)>
array([0, 1, 1, ..., 1, 0, 1], dtype=int8)
Dimensions without coordinates: _sequence

My target variable is an integer, so why does this error happen?

adamklie commented 8 months ago

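By default, the data-loading step casts `target_vars` to `torch.float32`, but the `MulticlassAUROC` metrics used by this module require integer (int or long) targets, which is why the validation step fails.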

You can override this using the `transforms` argument. Try modifying it to:

import torch

transforms={"ohe_seq": lambda x: x.swapaxes(1, 2), "id_x": lambda x: torch.tensor(x, dtype=torch.long)}
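A sketch of that `transforms` dict applied to a made-up batch shaped like `sd_train`. It swaps `torch.as_tensor` in for `torch.tensor` (a design choice, not part of the reply above) because the issue does not show whether EUGENe hands each transform a NumPy array or an already-cast tensor; `as_tensor` converts either to `torch.long`, while `torch.tensor` warns when copying an existing tensor:

```python
import numpy as np
import torch

# Assumption: each transform receives one batch, as a NumPy array or a tensor
transforms = {
    "ohe_seq": lambda x: x.swapaxes(1, 2),  # (batch, length, 4) -> (batch, 4, length)
    "id_x": lambda x: torch.as_tensor(x, dtype=torch.long),  # integer class labels
}

# Made-up batch: 2 one-hot sequences of length 300 and their int8 labels
ohe_batch = np.zeros((2, 300, 4), dtype=np.uint8)
label_batch = np.array([0, 1], dtype=np.int8)

seqs_t = transforms["ohe_seq"](ohe_batch)
labels_long = transforms["id_x"](label_batch)
# Same transform on a float32 tensor, e.g. if the default cast already ran
labels_from_float = transforms["id_x"](torch.tensor([0.0, 1.0]))
print(seqs_t.shape, labels_long.dtype, labels_from_float.dtype)
```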