pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Can’t create tensors as shown in first Captum Titanic tutorial?! #1325

Closed rbelew closed 3 months ago

rbelew commented 3 months ago

🐛 Bug

Basic steps in the Titanic tutorial to load CSV to tensors don't work?

To Reproduce

I'm stumped by the simplest part of the most basic "Titanic Basic" Captum tutorial: converting the data into tensors?!

After getting the data and performing the first preprocessing steps, converting to numpy arrays and separating out train and test sets works fine:

    data = titanic_data.to_numpy()

    train_indices = np.random.choice(len(labels), int(0.7*len(labels)), replace=False)
    test_indices = list(set(range(len(labels))) - set(train_indices))
    train_features = data[train_indices]
    train_labels = labels[train_indices]
    test_features = data[test_indices]
    test_labels = labels[test_indices]

but converting to tensors doesn't work:

      File ".../Titanic_Basic_Interpret.py", line 139, in <module>
        input_tensor = torch.as_tensor(train_features,dtype=torch.float32)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
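
The error says the array handed to `torch.as_tensor` has dtype `object`, which is what NumPy uses when the values share no common numeric type. Here is a minimal sketch of checking the dtype and attempting an explicit cast. The small `features` array is a hypothetical stand-in for `train_features`, with made-up values rather than the real Titanic data:

```python
import numpy as np

# Hypothetical stand-in for train_features: Python bools and floats
# stored together in an object array (values are made up).
features = np.array([[True, 22.0, 7.25], [False, 38.0, 71.28]], dtype=object)
print(features.dtype)

# Explicit cast to a single float dtype. This raises ValueError if any
# cell holds a value that cannot be converted, such as a non-numeric string.
features_f32 = features.astype(np.float32)
print(features_f32.dtype)
```

If the cast succeeds, the result has one of the dtypes listed in the error message. If it raises, the offending value in the exception message should point to the column that still needs preprocessing.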

Maybe the problem is "no more magic, convert_objects has been deprecated in pandas 0.17" ? It seems the tutorial was added back in 2019, but other issues seem to have used it more recently?

I've tried some of the suggestions there: building a separate dictionary of data types and then calling `data = data.astype(dtype=dtypeDict)`, and converting each column separately:

    for c in titanic_data.columns:
        titanic_data[c] = pd.to_numeric(titanic_data[c])

but these don't go through either. What could the issue be?!
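
One possible cause, which is an assumption and not something confirmed in this report: if the preprocessing builds indicator columns with `pd.get_dummies`, pandas 2.x returns them as `bool` by default. A frame that mixes `bool` and float columns converts to an `object` array. The sketch below uses a tiny hypothetical frame (the column names and values are illustrative only) to show the effect and a whole-frame cast that avoids it:

```python
import numpy as np
import pandas as pd

# Hypothetical miniature of the preprocessed frame: one float column and
# one boolean indicator column (illustrative values, not the real data).
titanic_data = pd.DataFrame({"age": [22.0, 38.0], "female": [False, True]})

# Inspect per-column dtypes to spot anything that is not int or float.
print(titanic_data.dtypes)

# Mixing bool and float columns gives an object array.
mixed = titanic_data.to_numpy()
print(mixed.dtype)

# Cast the whole frame to a single float dtype before converting.
data = titanic_data.astype(np.float32).to_numpy()
print(data.dtype)
```

If the real frame still contains string columns, the `astype` call raises instead, and the `dtypes` listing shows which columns are responsible.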

Expected behavior

Environment

 - Captum / PyTorch Version: 0.7.0 / 2.1.0.post100
 - OS: OSX 14.5
 - How you installed Captum / PyTorch (`conda`, `pip`, source): pip
 - Python version: 3.11.7