azizilab / DIISCO_public

Publicly accessible code for the DIISCO method
MIT License
3 stars 0 forks source link

The result is an error: Traceback (most recent call last): Cell In[175], line 1 print(y_samples_predict.shape) AttributeError: 'str' object has no attribute 'shape' #7

Open tian6067 opened 1 month ago

tian6067 commented 1 month ago

Hi,

This toolkit is excellent!


f_samples_predict, W_samples_predict, y_samples_predict = \
        model.sample(predict_timepoints, n_samples=10, include_emission_variance=TRUE)

While it runs, it looks like this. I also tried n_samples=10000, which takes a long time on my dataset, but the result is the same (n_samples=10000): 100%|██████████| 100/100 [00:52<00:00, 1.89it/s]

but the result is :

print(y_samples_predict.shape)
print(W_samples_predict.shape)
Traceback (most recent call last):

  Cell In[175], line 1
    print(y_samples_predict.shape)

AttributeError: 'str' object has no attribute 'shape'

When I look at the results:


W_samples_predict
Out[176]: 'F'

y_samples_predict
Out[177]: 'Y'

f_samples_predict
Out[178]: 'W'

Can you help me solve this problem? Thank you very much!

cameronyoungpark commented 1 month ago

Hello, thank you for your interest in the package! We are currently in the process of updating the method, so the tutorial notebooks may not be up to date. In the version you are using, the model.sample function returns a dictionary with keys 'W', 'F' and 'Y'. Example code for getting the predicted W:

sample_dict = model.sample(predict_timepoints, n_samples=10, include_emission_variance=True)
W_samples_predict = sample_dict["W"]

I would suggest waiting a couple of weeks until we have finished the method improvements and updates, as the repo may be unstable with changes.

Hope that helps!

tian6067 commented 1 month ago

@cameronyoungpark, thanks for your quick response; your suggestion was useful! I can now run all the steps, but I found that the cell ratios and interactions predicted by DIISCO are straight lines. I am confused: is it my code? Sorry for the long code and my stupid question. I am eager to add DIISCO to my paper! I didn't know Python before; I just learned it in the past two days. I also referenced another paper (https://www.biorxiv.org/content/10.1101/2024.02.09.579677v1) and its code (https://github.com/azizilab/dli_reproducibility/blob/main/DIISCO_AML_model_run_1.ipynb). Thank you very much for your answer.

from diisco import DIISCO
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import torch
torch.set_default_dtype(torch.float64)

from diisco import DIISCO
import diisco.names as names

# Load the observations: the first CSV column becomes the row index and is
# then dropped from the data columns.
df = pd.read_csv('lung.csv')
df.index = df.iloc[:,0]
df
df=df.iloc[:,1:]
# Row sums over the first 21 columns; the result is only displayed, not stored.
# NOTE(review): presumably a check that the proportions sum to 1 per row - confirm.
df.iloc[:, :21].sum(axis=1)
# Columns of `df` that are modelled, in the order used by the plots and
# tensors built from them below.
cell_types = [
    'Alveolar type 2 cell', 'Smooth muscle cell', 'Alveolar type 1 cell', 'Ciliated Cell',
    'Alveolar macrophage', 'Serous cell', 'Capillaries endothelial', 'Mesothelial Cell',
    'Lymphatic endothelial', 'Fibroblast', 'Tuft cell', 'B plasma',
    'Mast', 'Basal cell', 'Dendritic cell', 'Goblet cell',
    'Monocyte', 'CD8-T', 'CD4-T', 'Neutrophils', 'B'
]

# (label, matplotlib colour) pairs, indexed in parallel with `cell_types`.
cell_type_color = [('Alveolar type 2 cell', 'tab:blue'), 
                    ('Smooth muscle cell', 'tab:orange'), 
                    ('Alveolar type 1 cell', 'tab:green'),
                    ('Ciliated Cell', 'tab:green'),
                    ('Alveolar macrophage', 'tab:green'),
                    ('Serous cell', 'tab:green'),
                    ('Capillaries endothelial', 'tab:green'),
                    ('Mesothelial Cell', 'tab:green'),
                    ('Lymphatic endothelial', 'tab:green'),
                    ('Fibroblast', 'tab:green'),
                    ('Tuft cell', 'tab:green'),
                    ('B plasma', 'tab:green'),
                    ('Mast', 'tab:green'),
                    ('Basal cell', 'tab:green'),
                    ('Dendritic cell', 'tab:green'),
                    ('Goblet cell', 'tab:green'),
                    ('Monocyte', 'tab:green'),
                    ('CD8-T', 'tab:green'),
                    ('CD4-T', 'tab:green'),
                    ('Neutrophils', 'tab:green'),
                    ('B', 'tab:red')]

# One scatter panel per cell type: observed values plotted against the index of `df`.
fig, axes = plt.subplots(1, 21, figsize=(50, 4))
for i, cell_type in enumerate(cell_types):
    ax = axes[i]
    ax.scatter(df.index, df[cell_type], c=cell_type_color[i][1], s=25)
    ax.set_title(f'{cell_type_color[i][0]}', fontsize=14)
    ax.set_xlabel('Hours post co-culture', fontsize=12)
    if i==0: ax.set_ylabel('Proportion', fontsize=12)
plt.suptitle('Cell type proportions', fontsize=15, y=1.05)

# Prior matrix read from 'pir.csv' (first column becomes the index and is
# dropped), displayed as an annotated heatmap.
w=pd.read_csv('pir.csv')
w.index = w.iloc[:,0]
w=w.iloc[:,1:]
ax = sns.heatmap(w, cmap="Reds", annot=True)
# Model inputs: the index of `df` as an (n, 1) column of time points and the
# selected cell-type columns as the observation matrix.
timepoints = torch.tensor(df.index.values.reshape(-1, 1))
proportions = torch.tensor(df[cell_types].values)
W_prior_variance=w
W_prior_variance=np.array(W_prior_variance)
n_timepoints, n_cell_types = proportions.shape

prior_matrix = torch.tensor(W_prior_variance)
# Standardise every column (cell type) to zero mean and unit standard
# deviation; the mean and std are kept so `unscale` can undo this later.
proportions_mean = proportions.mean(dim=0)
proportions_std = proportions.std(dim=0)
proportions = (proportions - proportions_mean) / (proportions_std)
def unscale(proportions, proportions_mean, proportions_std, cluster_index): 
    """Map standardised values of one cluster back to the original scale.

    Undoes ``(x - mean) / std`` for the cluster at ``cluster_index`` and
    floors the result at zero.

    Args:
        proportions: Standardised values for a single cluster.
        proportions_mean: Tensor of per-cluster means.
        proportions_std: Tensor of per-cluster standard deviations.
        cluster_index: Position of the cluster in the mean/std tensors.

    Returns:
        The rescaled values with every negative entry clipped to 0.
    """
    scale = proportions_std.detach().numpy()[cluster_index]
    shift = proportions_mean.detach().numpy()[cluster_index]
    restored = proportions * scale + shift
    # Only a lower bound is applied; there is no upper clip.
    return np.clip(restored, 0, None)

# Initial hyperparameter values, keyed by the constants exposed in `diisco.names`.
# NOTE(review): the lengthscales (300) are presumably expressed in the same
# units as `timepoints`. Later code indexes the prediction grid at 7 and 14,
# so if the time axis only spans tens of units, lengthscales this large may
# make the fitted curves come out nearly flat - confirm the time units.
hyperparams = {
        names.LENGTHSCALE_F: 300,
        names.LENGTHSCALE_W: 300,
        names.SIGMA_F: 1,
        names.VARIANCE_F: 1,
        names.SIGMA_W: 0.1,
        names.VARIANCE_W: 1,
        names.SIGMA_Y: 0.5,
    }

# Build the model with the prior matrix and fit it to the standardised data.
# `hypers_to_optim=[]` passes an empty list; presumably this means no
# hyperparameter is optimised during fitting - confirm against the DIISCO API.
model = DIISCO(lambda_matrix=prior_matrix, hypers_init_vals=hyperparams, verbose=True,verbose_freq=100)
model.fit(timepoints, 
          proportions, 
          n_iter=300000, 
          lr=0.00005,
          hypers_to_optim=[], 
          guide="MultivariateNormalFactorized")

# Plot a 100-point moving average of the recorded losses ('valid' mode, so
# the output is shorter than `model.losses`).
start =0
loss_moving_avg = np.convolve(model.losses[start:], np.ones(100)/100, 'valid')
plt.figure(figsize=(8, 5))
plt.plot(loss_moving_avg)
plt.title('DIISCO model loss', fontsize=13)
plt.ylabel('Loss', fontsize=12)
plt.xlabel('Number of epochs', fontsize=12)

# Prediction grid: 100 evenly spaced time points spanning the observed range,
# shaped as a column.
predict_timepoints = torch.linspace(timepoints.min(), timepoints.max(), 100).reshape(-1, 1)
# `means` is not used again in the visible script.
means = model.get_means(predict_timepoints)

# `model.sample` returns a dict here; its 'W', 'F' and 'Y' entries are read below.
samples = model.sample(predict_timepoints, 
                     n_samples=10000, 
                     n_samples_per_latent=10,
                     include_emission_variance=False)      

W_samples_predict = samples['W']
f_samples_predict = samples['F']
y_samples_predict = samples['Y']        
print(y_samples_predict.shape)
print(W_samples_predict.shape)        

# One panel per cell type: the mean over axis 0 of the Y samples (presumably
# the sample axis - confirm), mapped back to the original scale with
# `unscale`, drawn over the observed values.
fig, axes = plt.subplots(1, 21, figsize=(100,18))
time_grid = predict_timepoints.squeeze().numpy()  # shared x-axis for all panels
for idx, cell_type in enumerate(cell_types):
    panel = axes[idx]
    label, colour = cell_type_color[idx]
    draws = y_samples_predict[:, :, idx]
    mean_curve = unscale(draws.mean(axis=0), proportions_mean, proportions_std, idx)
    # An uncertainty band (e.g. 16th-84th percentile via `fill_between`) was
    # sketched in the original but left disabled; only the mean is drawn.
    panel.plot(time_grid, mean_curve, c=colour)
    panel.scatter(df.index, df[cell_type], c=colour, s=25)
    panel.set_title(f'{label}', fontsize=14)
    panel.set_xlabel('Hours post co-culture', fontsize=12)
    panel.set_ylim([0, 0.75])
# Only the left-most panel carries the y-axis label.
axes[0].set_ylabel('Proportion', fontsize=12)
plt.suptitle('DIISCO predicted cell type proportions', fontsize=15, y=1.05)
plt.savefig('cell.pdf', bbox_inches='tight')

# Annotated heatmap of the W samples averaged over axes 0 and 1 (presumably
# the sample and time axes - confirm against the DIISCO `sample` output).
plt.figure(figsize=(5, 4))
W_avg_over_time = W_samples_predict.mean(axis=(0, 1)).detach().numpy()
ax = sns.heatmap(W_avg_over_time, cmap="RdBu_r", annot=True,
                 vmax=0.4, vmin=-0.4, center=0)
ax.set_yticklabels(cell_types, fontsize=12)
ax.set_xticklabels(cell_types, fontsize=12)
plt.yticks(rotation=0)
plt.xticks(rotation=45)
# The backslash is doubled because '\h' is not a valid Python escape sequence
# (newer interpreters warn about it); the closing parenthesis now sits
# outside the math segment so it pairs with the opening one.
plt.title('DIISCO predicted interactions mean ($\\hat{W}_{avg}$)', fontsize=14, y=1.05)
plt.xlabel('Source cluster', fontsize=12)
plt.ylabel('Target cluster', fontsize=12)

# Line plot of every off-diagonal interaction W[i, j] over the prediction grid.
# A dedicated figure is opened here: the previous code resized `fig` (the
# cell-type proportions figure) and then drew onto whichever figure happened
# to be current, e.g. the heatmap when the script runs top to bottom.
plt.figure(figsize=(30, 30))
W_mean = W_samples_predict.mean(axis=0)
# `linestyles` is not defined anywhere in the visible script, so cycle
# through matplotlib's four basic dash patterns.
linestyles = ['-', '--', '-.', ':']
time_grid = predict_timepoints.squeeze().numpy()
lines = 0  # number of curves drawn so far; selects the next line style
for i, cell_type_i in enumerate(cell_types):
    for j, cell_type_j in enumerate(cell_types):
        if i != j:
            plt.plot(time_grid,
                     W_mean[:, i, j].detach().numpy(),
                     linestyle=linestyles[lines % len(linestyles)],
                     label='$W_{%s,%s}$ (%s - %s interaction)' % (i, j, cell_type_i, cell_type_j))
            lines += 1
plt.legend(bbox_to_anchor=(1, 1.02), loc='upper left', fontsize=5)
plt.title('DIISCO predicted interactions over time', fontsize=15)
plt.ylabel('$W_{i, j}$', fontsize=15)
plt.xlabel('time', fontsize=14)
plt.savefig('all.pdf', bbox_inches='tight')

# Heatmap of W averaged over a window of the prediction grid.
annot = False
vmin = -1
vmax = 1

# Window bounds: the first grid index with time > 7 and the first with
# time > 14.
# NOTE(review): the variable names and the '200 days pre-DLI' title look
# carried over from the reference notebook and may not describe this
# dataset - confirm.
X_200_days_pre_dli_index = np.where(predict_timepoints > 7)[0][0]
X_post_dli_index = np.where(predict_timepoints > 14)[0][0]

X_post_dli_index
W = W_samples_predict.mean(axis=0)
W_pre_dli = W[X_200_days_pre_dli_index:X_post_dli_index]

# Kept as a tensor because later code indexes it; it is converted to NumPy
# only at the plotting call.
W_pre_dli_avg_over_time = torch.mean(W_pre_dli, dim=0)

plt.figure(figsize=(18, 7))

# These statements were indented without an enclosing block, which is an
# IndentationError; they belong at top level.
plt.subplot(1, 1, 1)
ax = sns.heatmap(W_pre_dli_avg_over_time.detach().numpy(), cmap="RdBu_r", annot=annot,
                 fmt='.2f', vmin=vmin, vmax=vmax, center=0)
ax.set_yticklabels(cell_types, fontsize=12)
ax.set_xticklabels(cell_types, fontsize=12)
plt.xlabel('Source cluster', fontsize=12)
plt.ylabel('Target cluster', fontsize=12)
plt.yticks(rotation=0)
plt.xticks(rotation=45)
plt.title('$\\hat{W}_{avg}$ (200 days pre-DLI)', fontsize=14)

# Plot over time only those off-diagonal interactions that are strong within
# the selected window (`W_pre_dli` / `W_pre_dli_avg_over_time`).
# A dedicated figure is opened: the previous code resized `fig` (an earlier
# figure) and drew onto whichever figure happened to be current.
plt.figure(figsize=(30, 30))
# `linestyles` is not defined anywhere in the visible script, so cycle
# through matplotlib's four basic dash patterns.
linestyles = ['-', '--', '-.', ':']
# NOTE(review): `mean_abs_W_threshold` and `max_abs_W_threshold` are not
# defined in the visible script either. They are analysis choices and must be
# set before this section runs.
time_grid = predict_timepoints.squeeze().numpy()
lines = 0  # number of curves drawn so far; selects the next line style

for i, cell_type_i in enumerate(cell_types):
    for j, cell_type_j in enumerate(cell_types):
        if i == j:
            continue
        # Sustained: the window-averaged |W| exceeds the mean threshold.
        sustained = W_pre_dli_avg_over_time[i, j].abs() > mean_abs_W_threshold
        # Transient: |W| exceeds the max threshold at any point in the window.
        transient = (W_pre_dli[:, i, j].abs() > max_abs_W_threshold).any()
        if sustained or transient:
            plt.plot(time_grid,
                     W[:, i, j].detach().numpy(),
                     linestyle=linestyles[lines % len(linestyles)],
                     label='$W_{%s,%s}$ (%s - %s interaction)' % (i, j, cell_type_i, cell_type_j))
            # Advance the style cycle; without this every curve shared one style.
            lines += 1
plt.legend(bbox_to_anchor=(1, 1.02), loc='upper left', fontsize=5)
plt.title('DIISCO predicted interactions over time', fontsize=15)
plt.ylabel('$W_{i, j}$', fontsize=15)
plt.xlabel('time', fontsize=14)
plt.savefig('up.pdf', bbox_inches='tight')