jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.
1 stars 2 forks source link

Weighted edges to account for protein dynamics #25

Closed jyaacoub closed 10 months ago

jyaacoub commented 1 year ago

Using multiple structures from AlphaFold to generate edge weights for the GNN -> https://github.com/PDBe-KB/pdbe-kb-manual/wiki/Secondary-structure-variance

jyaacoub commented 1 year ago

Part of this will require some filtering out by using TM-score to identify severe misfolds.

Code for TM-Score

from prody import parsePDB, matchAlign, showProtein
from pylab import legend
import numpy as np

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
src_model = '/cluster/home/t122995uhn/projects/data/v2020-other-PL/1a1e/1a1e_protein.pdb'
pred_model = '/cluster/home/t122995uhn/projects/colabfold/out/1a1e.msa_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb'

# Parse the C-alpha subset of chain A (model 1) from both files, then give
# each structure the title that is shown in the plot legend.
sm, pm = (parsePDB(path, model=1, subset="ca", chain="A")
          for path in (src_model, pred_model))
sm.setTitle('experimental')
pm.setTitle('alphafold')

# View of the two structures exactly as loaded.
showProtein(sm, pm)
legend()

#%% Performing alignment before TM-score
# NOTE(review): relies on matchAlign moving its first argument (pm) onto the
# second (sm) in place -- confirm against the ProDy documentation.
result = matchAlign(pm, sm)

# Same view again, after the alignment call.
showProtein(sm, pm)
legend()

# %%
def tm_score(xyz0, xyz1): #Check if TM-align use all atoms!
    """Return the TM-score between two paired coordinate sets.

    Row ``i`` of ``xyz0`` is compared with row ``i`` of ``xyz1``. No
    superposition or residue matching happens here, so the caller must
    align the structures and make the rows correspond beforehand.

    Args:
        xyz0: 2-D array-like, one row of coordinates per atom. Its number
            of rows ``L`` is the length used for normalisation.
        xyz1: array-like with exactly the same shape as ``xyz0``.

    Returns:
        The score as a float; 1.0 when all paired coordinates coincide.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape, or are empty.
    """
    xyz0 = np.asarray(xyz0, dtype=float)
    xyz1 = np.asarray(xyz1, dtype=float)
    # Reject mismatches explicitly: NumPy would otherwise silently broadcast
    # e.g. (L, 3) against (1, 3) and return a meaningless score.
    if xyz0.ndim != 2 or xyz0.shape != xyz1.shape:
        raise ValueError(
            "tm_score expects two 2-D coordinate arrays of identical shape, "
            f"got {xyz0.shape} and {xyz1.shape}"
        )
    L = len(xyz0)
    if L == 0:
        raise ValueError("tm_score needs at least one pair of atoms")

    # d0 = 1.24 * (L - 15)^(1/3) - 1.8, floored at 0.5.
    # The formula is below 0.5 for L < 22, and a fractional power of the
    # negative base (L < 15) is NaN, so short chains use the floor directly
    # instead of evaluating the power.
    if L > 15:
        d0 = max(0.5, 1.24 * np.power(L - 15, 1/3) - 1.8)
    else:
        d0 = 0.5

    # compute the distance for each pair of atoms (L2 distance)
    di = np.sqrt(np.sum((xyz0 - xyz1) ** 2, 1)) # sum along 2nd axis
    return np.sum(1 / (1 + (di / d0) ** 2)) / L

# TM score for predicted model 1 with chain A of src should be 0.9681 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/987232.html)
# Cell-final expression: the value is not stored, it is only displayed when
# this cell is run interactively.
# NOTE(review): tm_score subtracts the two arrays row by row, so sm and pm
# must hold the same number of C-alpha atoms in the same order -- confirm.
tm_score(sm.getCoords(), 
         pm.getCoords())
jyaacoub commented 1 year ago

Attempting to replicate the LAT1 figure from the af2_conformations paper (del Alamo et al., 2022)

Starting with PCA I get the following (after running 4x5 with local colabfold): image|400

Code to reproduce:

# %%
#NOTE A3M files are located in /cluster/home/t122995uhn/projects/data/PDBbind_aln/ (symlink to  /cluster/projects/kumargroup/msa/output/)
from prody import parsePDB, matchAlign, showProtein, PCA
from pylab import legend
import numpy as np
import os
import glob

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
# Target name plus the PDB IDs of its two experimental structures.
code = 'LAT1'
code1, code2 = '7dsq', '6irs'
pred_dir = '/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/*'

exp_model1 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code1}.pdb'
exp_model2 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code2}.pdb'

# Every '<code>*.pdb' file matched under pred_dir, in sorted path order.
pred_models = sorted(glob.glob(f'{pred_dir}/{code}*.pdb'))


# %% parsing models
def _load_ca(path, chain, title, selstr=None):
    """Parse the C-alpha atoms of *chain* (model 1), title them, optionally sub-select."""
    atoms = parsePDB(path, model=1, subset="ca", chain=chain)
    atoms.setTitle(title)
    return atoms if selstr is None else atoms.select(selstr)


# Experimental structures: chain B, restricted to 'resnum 51:507'.
em1 = _load_ca(exp_model1, "B", 'experimental1', 'resnum 51:507')
em2 = _load_ca(exp_model2, "B", 'experimental2', 'resnum 51:507')

# parse predicted and set concise titles (chain A, same residue selection).
pms = [_load_ca(path, "A", f'AF_{idx}', 'resnum 51:507')
       for idx, path in enumerate(pred_models)]

showProtein(em1, em2, *pms)
# legend()

#%% Alignment of ensembles:
from prody import Ensemble

# Conformations are added in the order [em1, em2, *pms]; the index arithmetic
# further down depends on that order and on every model contributing exactly
# one coordinate set (they were parsed with model=1).
ensemble = Ensemble(f"{code}_ensemble")
ensemble.setCoords(em1.getCoords())  # initial reference coordinates
ensemble.addCoordset(em1.getCoordsets())
ensemble.addCoordset(em2.getCoordsets())
for pm in pms:
    ensemble.addCoordset(pm.getCoordsets()) # iterate until all models are added

ensemble.iterpose() # aligns all proteins until convergence


# %%
def _aligned_copy(atoms, coords):
    """Return a copy of *atoms* whose only coordinate set is *coords*."""
    duplicate = atoms.copy()
    duplicate.delCoordset(range(duplicate.numCoordsets()))
    duplicate.addCoordset(coords)
    return duplicate


aligned = ensemble.getCoordsets()
copy_em1 = _aligned_copy(em1, aligned[0])
# Fix: em2 is conformation 1. The previous code reused index 0 here, so
# copy_em2 was drawn with em1's aligned coordinates.
copy_em2 = _aligned_copy(em2, aligned[1])

# +2 skips the two experimental conformations added first.
copy_pms = [_aligned_copy(pm, aligned[i + 2]) for i, pm in enumerate(pms)]

showProtein(copy_em1, copy_em2, *copy_pms)
# legend()

# %% PCA
pca = PCA(f'{code}_pca')
pca.performSVD(ensemble)

# %%
# Get the eigenvalues and eigenvectors
eigvals = pca.getEigvals()
eigvecs = pca.getEigvecs()

# Normalize eigenvectors by square root of eigenvalues
normalized_eigvecs = eigvecs / np.sqrt(eigvals)

# Flatten each conformation's deviations to one row of length numAtoms * 3 and
# project the rows onto the normalized eigenvectors.
deviations = ensemble.getDeviations().reshape(-1, ensemble.numAtoms() * 3)
pc_coords = np.dot(deviations, normalized_eigvecs)

# Make the (num_conformations, num_modes) layout explicit.
pc_coords = pc_coords.reshape(ensemble.numConfs(), -1)
# %%
import matplotlib.pyplot as plt
# Despite its name this is the number of conformations: pc_coords has one row
# per conformation, in the order they were added to the ensemble.
num_modes_to_plot=len(pc_coords)

# Fix: the two experimental structures (rows 0 and 1) used to share the label
# `code`, so the legend could not tell them apart. Label them by PDB ID.
exp_labels = {0: code1, 1: code2}

# Plot every conformation against the two most significant principal components
plt.figure(figsize=(8, 6))
for mode_num in range(num_modes_to_plot):
    plt.scatter(pc_coords[mode_num, 0], pc_coords[mode_num, 1],
                label=exp_labels.get(mode_num))

plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('Modes Plotted by Principal Components')
plt.legend()
plt.show()
# %%
jyaacoub commented 1 year ago

To replicate TM-score figure: image

code:

# %%
#NOTE A3M files are located in /cluster/home/t122995uhn/projects/data/PDBbind_aln/ (symlink to  /cluster/projects/kumargroup/msa/output/)
from prody import parsePDB, matchAlign, showProtein, PCA
from pylab import legend
import numpy as np
import os
import glob

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
# Target name plus the PDB IDs of its two experimental structures.
code = 'LAT1'
code1, code2 = '7dsq', '6irs'
pred_dir = '/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/*'

exp_model1 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code1}.pdb'
exp_model2 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code2}.pdb'

# Every '<code>*.pdb' file matched under pred_dir, in sorted path order.
pred_models = sorted(glob.glob(f'{pred_dir}/{code}*.pdb'))


# %% parsing models
def _load_ca(path, chain, title, selstr=None):
    """Parse the C-alpha atoms of *chain* (model 1), title them, optionally sub-select."""
    atoms = parsePDB(path, model=1, subset="ca", chain=chain)
    atoms.setTitle(title)
    return atoms if selstr is None else atoms.select(selstr)


# Experimental structures: chain B, restricted to 'resnum 51:507'.
em1 = _load_ca(exp_model1, "B", 'experimental1', 'resnum 51:507')
em2 = _load_ca(exp_model2, "B", 'experimental2', 'resnum 51:507')

# parse predicted and set concise titles (chain A, same residue selection).
pms = [_load_ca(path, "A", f'AF_{idx}', 'resnum 51:507')
       for idx, path in enumerate(pred_models)]

# showProtein(em1, em2, *pms)
# legend()

# %%
#TM-Score calculation
def tm_score(xyz0, xyz1): #Check if TM-align use all atoms!
    """Return the TM-score between two paired coordinate sets.

    Row ``i`` of ``xyz0`` is compared with row ``i`` of ``xyz1``. No
    superposition or residue matching happens here, so the caller must
    align the structures and make the rows correspond beforehand.

    Args:
        xyz0: 2-D array-like, one row of coordinates per atom. Its number
            of rows ``L`` is the length used for normalisation.
        xyz1: array-like with exactly the same shape as ``xyz0``.

    Returns:
        The score as a float; 1.0 when all paired coordinates coincide.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape, or are empty.
    """
    xyz0 = np.asarray(xyz0, dtype=float)
    xyz1 = np.asarray(xyz1, dtype=float)
    # Reject mismatches explicitly: NumPy would otherwise silently broadcast
    # e.g. (L, 3) against (1, 3) and return a meaningless score.
    if xyz0.ndim != 2 or xyz0.shape != xyz1.shape:
        raise ValueError(
            "tm_score expects two 2-D coordinate arrays of identical shape, "
            f"got {xyz0.shape} and {xyz1.shape}"
        )
    L = len(xyz0)
    if L == 0:
        raise ValueError("tm_score needs at least one pair of atoms")

    # d0 = 1.24 * (L - 15)^(1/3) - 1.8, floored at 0.5.
    # The formula is below 0.5 for L < 22, and a fractional power of the
    # negative base (L < 15) is NaN, so short chains use the floor directly
    # instead of evaluating the power.
    if L > 15:
        d0 = max(0.5, 1.24 * np.power(L - 15, 1/3) - 1.8)
    else:
        d0 = 0.5

    # compute the distance for each pair of atoms (L2 distance)
    di = np.sqrt(np.sum((xyz0 - xyz1) ** 2, 1)) # sum along 2nd axis
    return np.sum(1 / (1 + (di / d0) ** 2)) / L

# Order matters below: entries 0 and 1 are the experimental structures,
# everything from index 2 on is a predicted model.
all_conf = [em1, em2] + pms


def _tm_scores_wrt(reference, conformations):
    """Return the TM-score of every conformation against *reference*.

    Each conformation is copied before alignment so the originals keep
    their coordinates.
    NOTE(review): relies on matchAlign moving its first argument onto the
    second in place, and on both holding the same atoms in the same order
    (tm_score pairs rows one-to-one) -- confirm.
    """
    scores = []
    for conf in conformations:
        tmp = conf.copy()
        matchAlign(tmp, reference)
        scores.append(tm_score(tmp.getCoords(), reference.getCoords()))
    return scores


#wrt em1
em1_TM = _tm_scores_wrt(em1, all_conf)

#wrt em2
em2_TM = _tm_scores_wrt(em2, all_conf)

# %%
import matplotlib.pyplot as plt

# Fix: the "AF_pred" series used to include the two experimental points as
# well, which were then drawn a second time under their own labels.
plt.scatter(em1_TM[2:], em2_TM[2:], alpha=0.5, label="AF_pred")
plt.scatter(em1_TM[0], em2_TM[0], alpha=0.5, label=code1)
plt.scatter(em1_TM[1], em2_TM[1], alpha=0.5, label=code2)
plt.legend(loc="center left")
plt.xlabel(code1)
plt.ylabel(code2)
plt.title('TM-score wrt to real modes')
# %%
jyaacoub commented 1 year ago

Above code had some errors, here is the corrected version: image

Code:

# %%
#NOTE A3M files are located in /cluster/home/t122995uhn/projects/data/PDBbind_aln/ (symlink to  /cluster/projects/kumargroup/msa/output/)
from prody import parsePDB, matchAlign, showProtein, PCA
from pylab import legend
import numpy as np
import os
import glob

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
# Target name plus the PDB IDs of its two experimental structures.
code = 'LAT1'
code1, code2 = '7dsq', '6irs'
pred_dir = '/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/out*/'

exp_model1 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code1}.pdb'
exp_model2 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code2}.pdb'

# Every '<code>*.pdb' file matched under pred_dir, in sorted path order.
pred_models = sorted(glob.glob(f'{pred_dir}/{code}*.pdb'))


# %% parsing models
def _load_ca(path, chain, title, selstr=None):
    """Parse the C-alpha atoms of *chain* (model 1), title them, optionally sub-select."""
    atoms = parsePDB(path, model=1, subset="ca", chain=chain)
    atoms.setTitle(title)
    return atoms if selstr is None else atoms.select(selstr)


# Experimental structures: chain B, restricted to 'resnum 51:508'.
em1 = _load_ca(exp_model1, "B", 'experimental1', 'resnum 51:508')
em2 = _load_ca(exp_model2, "B", 'experimental2', 'resnum 51:508')

# parse predicted and set concise titles; the predictions are kept whole
# (no residue sub-selection, e.g. 'resnum 1:456', is applied).
pms = [_load_ca(path, "A", f'AF_{idx}') for idx, path in enumerate(pred_models)]

# showProtein(em1, em2, *pms)
# legend()

# %%
#TM-Score calculation
def tm_score(xyz0, xyz1): #Check if TM-align use all atoms!
    """Return the TM-score between two paired coordinate sets.

    Row ``i`` of ``xyz0`` is compared with row ``i`` of ``xyz1``. No
    superposition or residue matching happens here, so the caller must
    align the structures and make the rows correspond beforehand.

    Args:
        xyz0: 2-D array-like, one row of coordinates per atom. Its number
            of rows ``L`` is the length used for normalisation.
        xyz1: array-like with exactly the same shape as ``xyz0``.

    Returns:
        The score as a float; 1.0 when all paired coordinates coincide.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape, or are empty.
    """
    xyz0 = np.asarray(xyz0, dtype=float)
    xyz1 = np.asarray(xyz1, dtype=float)
    # Reject mismatches explicitly: NumPy would otherwise silently broadcast
    # e.g. (L, 3) against (1, 3) and return a meaningless score.
    if xyz0.ndim != 2 or xyz0.shape != xyz1.shape:
        raise ValueError(
            "tm_score expects two 2-D coordinate arrays of identical shape, "
            f"got {xyz0.shape} and {xyz1.shape}"
        )
    L = len(xyz0)
    if L == 0:
        raise ValueError("tm_score needs at least one pair of atoms")

    # d0 = 1.24 * (L - 15)^(1/3) - 1.8, floored at 0.5.
    # The formula is below 0.5 for L < 22, and a fractional power of the
    # negative base (L < 15) is NaN, so short chains use the floor directly
    # instead of evaluating the power.
    if L > 15:
        d0 = max(0.5, 1.24 * np.power(L - 15, 1/3) - 1.8)
    else:
        d0 = 0.5

    # compute the distance for each pair of atoms (L2 distance)
    di = np.sqrt(np.sum((xyz0 - xyz1) ** 2, 1)) # sum along 2nd axis
    return np.sum(1 / (1 + (di / d0) ** 2)) / L

# Order matters below: entries 0 and 1 are the experimental structures,
# everything from index 2 on is a predicted model.
all_conf = [em1, em2] + pms


def _tm_scores_wrt(reference, conformations):
    """Return the TM-score of every conformation against *reference*.

    Each conformation is copied before alignment so the originals keep
    their coordinates.
    NOTE(review): relies on matchAlign moving its first argument onto the
    second in place, and on both holding the same atoms in the same order
    (tm_score pairs rows one-to-one) -- confirm.
    """
    scores = []
    for conf in conformations:
        tmp = conf.copy()
        matchAlign(tmp, reference)
        scores.append(tm_score(tmp.getCoords(), reference.getCoords()))
    return scores


#wrt em1
em1_TM = _tm_scores_wrt(em1, all_conf)

#wrt em2
em2_TM = _tm_scores_wrt(em2, all_conf)

# %%
import matplotlib.pyplot as plt

# Fix: the "AF_pred" series used to include the two experimental points as
# well, which were then drawn a second time under their own labels.
plt.scatter(em1_TM[2:], em2_TM[2:], alpha=0.5, label="AF_pred")
plt.scatter(em1_TM[0], em2_TM[0], alpha=0.5, label=code1)
plt.scatter(em1_TM[1], em2_TM[1], alpha=0.5, label=code2)
plt.legend(loc="center left")
plt.xlabel(code1)
plt.ylabel(code2)
plt.title('TM-score wrt to real modes')
# %%
jyaacoub commented 1 year ago

Note that if you don't explicitly set different random seeds you will just get the same 5 models in the end:

This is a run with seq set to 32:64 and no setting of random seed: image

Code

# %%
#NOTE A3M files are located in /cluster/home/t122995uhn/projects/data/PDBbind_aln/ (symlink to  /cluster/projects/kumargroup/msa/output/)
from prody import parsePDB, matchAlign, showProtein, PCA
from pylab import legend
import numpy as np
import os
import glob

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
# Target name plus the PDB IDs of its two experimental structures.
code = 'LAT1'
code1, code2 = '7dsq', '6irs'
pred_dir = '/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/old_outs/*/'

exp_model1 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code1}.pdb'
exp_model2 = f'/cluster/home/t122995uhn/projects/colabfold/LAT1_out_4x5/{code2}.pdb'

# Every '<code>*.pdb' file matched under pred_dir, in sorted path order.
pred_models = sorted(glob.glob(f'{pred_dir}/{code}*.pdb'))


# %% parsing models
def _load_ca(path, chain, title, selstr=None):
    """Parse the C-alpha atoms of *chain* (model 1), title them, optionally sub-select."""
    atoms = parsePDB(path, model=1, subset="ca", chain=chain)
    atoms.setTitle(title)
    return atoms if selstr is None else atoms.select(selstr)


# Experimental structures: chain B, restricted to 'resnum 51:508'.
em1 = _load_ca(exp_model1, "B", 'experimental1', 'resnum 51:508')
em2 = _load_ca(exp_model2, "B", 'experimental2', 'resnum 51:508')

# parse predicted and set concise titles (chain A, selection 'resnum > 70').
pms = [_load_ca(path, "A", f'AF_{idx}', 'resnum > 70')
       for idx, path in enumerate(pred_models)]

showProtein(em1, em2, *pms)
# legend()

# %%
#TM-Score calculation
def tm_score(xyz0, xyz1): #Check if TM-align use all atoms!
    """Return the TM-score between two paired coordinate sets.

    Row ``i`` of ``xyz0`` is compared with row ``i`` of ``xyz1``. No
    superposition or residue matching happens here, so the caller must
    align the structures and make the rows correspond beforehand.

    Args:
        xyz0: 2-D array-like, one row of coordinates per atom. Its number
            of rows ``L`` is the length used for normalisation.
        xyz1: array-like with exactly the same shape as ``xyz0``.

    Returns:
        The score as a float; 1.0 when all paired coordinates coincide.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape, or are empty.
    """
    xyz0 = np.asarray(xyz0, dtype=float)
    xyz1 = np.asarray(xyz1, dtype=float)
    # Reject mismatches explicitly: NumPy would otherwise silently broadcast
    # e.g. (L, 3) against (1, 3) and return a meaningless score.
    if xyz0.ndim != 2 or xyz0.shape != xyz1.shape:
        raise ValueError(
            "tm_score expects two 2-D coordinate arrays of identical shape, "
            f"got {xyz0.shape} and {xyz1.shape}"
        )
    L = len(xyz0)
    if L == 0:
        raise ValueError("tm_score needs at least one pair of atoms")

    # d0 = 1.24 * (L - 15)^(1/3) - 1.8, floored at 0.5.
    # The formula is below 0.5 for L < 22, and a fractional power of the
    # negative base (L < 15) is NaN, so short chains use the floor directly
    # instead of evaluating the power.
    if L > 15:
        d0 = max(0.5, 1.24 * np.power(L - 15, 1/3) - 1.8)
    else:
        d0 = 0.5

    # compute the distance for each pair of atoms (L2 distance)
    di = np.sqrt(np.sum((xyz0 - xyz1) ** 2, 1)) # sum along 2nd axis
    return np.sum(1 / (1 + (di / d0) ** 2)) / L

# Order matters below: entries 0 and 1 are the experimental structures,
# everything from index 2 on is a predicted model.
all_conf = [em1, em2] + pms


def _tm_scores_wrt(reference, conformations):
    """Return the TM-score of every conformation against *reference*.

    Each conformation is copied before alignment so the originals keep
    their coordinates.
    NOTE(review): relies on matchAlign moving its first argument onto the
    second in place, and on both holding the same atoms in the same order
    (tm_score pairs rows one-to-one) -- confirm.
    """
    scores = []
    for conf in conformations:
        tmp = conf.copy()
        matchAlign(tmp, reference)
        scores.append(tm_score(tmp.getCoords(), reference.getCoords()))
    return scores


#wrt em1
em1_TM = _tm_scores_wrt(em1, all_conf)

#wrt em2
em2_TM = _tm_scores_wrt(em2, all_conf)

# %%
import matplotlib.pyplot as plt

# Fix: the "AF_pred" series used to include the two experimental points as
# well, which were then drawn a second time under their own labels.
plt.scatter(em1_TM[2:], em2_TM[2:], alpha=0.5, label="AF_pred")
plt.scatter(em1_TM[0], em2_TM[0], alpha=0.5, label=code1)
plt.scatter(em1_TM[1], em2_TM[1], alpha=0.5, label=code2)
plt.legend(loc="center left")
plt.xlabel(code1)
plt.ylabel(code2)
plt.title('TM-score wrt to real modes')
# %%
jyaacoub commented 1 year ago

PCA results with LAT show that real experimental structures cluster differently: image

This is different from the paper which is more consistent along PC2: image

jyaacoub commented 1 year ago

PCA results with LAT show that real experimental structures cluster differently: image

This is different from the paper which is more consistent along PC2: image

Running on the ASCT2 model gives us something different as well:

image

Code:

# %%
#NOTE A3M files are located in /cluster/home/t122995uhn/projects/data/PDBbind_aln/ (symlink to  /cluster/projects/kumargroup/msa/output/)
from prody import parsePDB, matchAlign, showProtein, PCA
from pylab import legend
import numpy as np
import os
import glob

# %%
# tm score for A&B is 0.9835 (https://seq2fun.dcmb.med.umich.edu//TM-score/tmp/110056.html)
# Target name, the PDB IDs of its two experimental structures, and the chain
# read from the experimental files.
code = 'ASCT2'
code1 = '6rvx'
code2 = '7bcq'
chainsel = 'A'

out_dir = '/cluster/home/t122995uhn/projects/colabfold/ASCT2_out/'
pred_dir = f'{out_dir}/32_64/out?/'

exp_model1 = f'{out_dir}/{code1}_cleaned.pdb'
exp_model2 = f'{out_dir}/{code2}_cleaned.pdb'

# Every '<code>*.pdb' file matched under pred_dir, in sorted path order.
pred_models = sorted(glob.glob(f'{pred_dir}/{code}*.pdb'))


# %% parsing models
def _load_ca(path, chain, title, selstr=None):
    """Parse the C-alpha atoms of *chain* (model 1), title them, optionally sub-select."""
    atoms = parsePDB(path, model=1, subset="ca", chain=chain)
    atoms.setTitle(title)
    return atoms if selstr is None else atoms.select(selstr)


# The 'resnum 51:508' restriction is only applied when the target is LAT1.
em1 = _load_ca(exp_model1, chainsel, 'experimental1',
               'resnum 51:508' if code == "LAT1" else None)
em2 = _load_ca(exp_model2, chainsel, 'experimental2',
               'resnum 51:508' if code == "LAT1" else None)

# parse predicted and set concise titles; predictions are kept whole
# (no 'resnum > 70' sub-selection is applied).
pms = [_load_ca(path, "A", f'AF_{idx}') for idx, path in enumerate(pred_models)]

showProtein(em1, em2, *pms)
# legend()
#%% Alignment of ensembles:
from prody import Ensemble

# Conformations are added in the order [em1, em2, *pms]; the index arithmetic
# further down depends on that order and on every model contributing exactly
# one coordinate set (they were parsed with model=1).
ensemble = Ensemble(f"{code}_ensemble")
ensemble.setCoords(em1.getCoords())  # initial reference coordinates
ensemble.addCoordset(em1.getCoordsets())
ensemble.addCoordset(em2.getCoordsets())
for pm in pms:
    ensemble.addCoordset(pm.getCoordsets()) # iterate until all models are added

ensemble.iterpose() # aligns all proteins until convergence


# %%
def _aligned_copy(atoms, coords):
    """Return a copy of *atoms* whose only coordinate set is *coords*."""
    duplicate = atoms.copy()
    duplicate.delCoordset(range(duplicate.numCoordsets()))
    duplicate.addCoordset(coords)
    return duplicate


aligned = ensemble.getCoordsets()
copy_em1 = _aligned_copy(em1, aligned[0])
# Fix: em2 is conformation 1. The previous code reused index 0 here, so
# copy_em2 was drawn with em1's aligned coordinates.
copy_em2 = _aligned_copy(em2, aligned[1])

# +2 skips the two experimental conformations added first.
copy_pms = [_aligned_copy(pm, aligned[i + 2]) for i, pm in enumerate(pms)]

showProtein(copy_em1, copy_em2, *copy_pms)
# legend()

# %% PCA
pca = PCA(f'{code}_pca')
pca.performSVD(ensemble)

# %%
# Eigen-decomposition results of the PCA.
eigvals, eigvecs = pca.getEigvals(), pca.getEigvecs()

# Each eigenvector column is divided by the square root of its eigenvalue.
normalized_eigvecs = eigvecs / np.sqrt(eigvals)

# One flattened deviation row (numAtoms * 3 values) per conformation,
# projected onto the normalized eigenvectors.
deviations = ensemble.getDeviations().reshape(-1, ensemble.numAtoms() * 3)
pc_coords = np.dot(deviations, normalized_eigvecs)

# Layout: (num_conformations, num_modes).
pc_coords = pc_coords.reshape(ensemble.numConfs(), -1)
# %%
import matplotlib.pyplot as plt
pc_coords_64_128 = pc_coords
# Scatter every conformation on the first two principal components: rows 0
# and 1 are the experimental structures, the remaining rows are predictions.
plt.figure(figsize=(8, 6))
plt.scatter(*pc_coords[0, :2], label=code1)
plt.scatter(*pc_coords[1, :2], label=code2)
plt.scatter(*pc_coords[2:, :2].T)

plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('MSA-32:64')
plt.legend()
plt.show()
jyaacoub commented 1 year ago

Results image

Code:

#%%
import pandas as pd
import matplotlib.pyplot as plt

# Read the CSV file into a DataFrame
file_path = "/home/jyaacoub/projects/MutDTA/results/model_media/model_stats.csv"
df = pd.read_csv(file_path)

# Raw run identifiers (values of the 'run' column) -> short display names.
run_mapping = {
    "EDI_davis_10B_0.0001LR_0.4D_2000E_nomsaF": "ESM_binaryW",
    "EDIM_davisD_nomsaF_simpleE_10B_0.0001LR_0.4D_2000E": "ESM_simpleW",
    "randW_davis-fixed_64B_0.0001LR_0.4D_2000E_shannonF_DGraphDTA": "DGraphDTA_binaryW",
    "DGIM_davisD_nomsaF_simpleE_64B_0.0001LR_0.4D_2000E": "DGraphDTA_simpleW",
    # Add more mappings as needed
}

# Define colors for specific runs (keyed by display name)
colors = {
    "ESM_binaryW": 'C0',
    "ESM_simpleW": 'C0',
    "DGraphDTA_binaryW": 'green',
    "DGraphDTA_simpleW": 'green',
    # Add more colors as needed
}

# Define a list of specific runs to filter for
specific_runs = list(run_mapping)

# Filter the DataFrame for the specific runs.
# Fix: take an explicit copy. Assigning a column on the filtered slice raised
# pandas' SettingWithCopyWarning and left it ambiguous whether `df` was
# being modified.
filtered_df = df[df['run'].isin(specific_runs)].copy()

# Map the original run names to the desired names
filtered_df['run'] = filtered_df['run'].map(run_mapping)

# Extract the "run" and "cindex" columns
runs = filtered_df["run"]
cindex_values = filtered_df["cindex"]

# Create a bar graph with different colors for each run
plt.figure(figsize=(10, 6))
bars = plt.bar(runs, cindex_values)
for bar, run in zip(bars, runs):
    bar.set_color(colors[run])
plt.xlabel('Run')
plt.ylabel('C-Index')
plt.title('C-Index for Each Run')
# plt.xticks(rotation=20)  # Rotate x-axis labels for better readability
plt.tight_layout()
plt.ylim(0.5, 1.0)
# Show the plot
plt.show()
jyaacoub commented 12 months ago

Visualizing af2 weighted edges to validate its use for identifying functional regions. Areas surrounded by red bars are the kinase regions.

EGFR

drawing drawing drawing

IGF1R

drawing drawing drawing

Code

#%%
#%% create edge weights and visualize them
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from src.utils.residue import Chain

target = 'IGF1R' #EGFR
structures = glob(f'/cluster/home/t122995uhn/projects/colabfold/out_misc/{target}/out*/*.pdb')

# One Chain object per predicted structure file.
chains = list(map(Chain, structures))
# Stack the per-structure contact maps along a new leading axis.
M = np.array([chain_obj.get_contact_map() for chain_obj in chains])
print("num chains:", len(chains))

# Edge weight = fraction of structures whose contact-map entry is below 8.0.
edgeW = np.count_nonzero(M < 8.0, axis=0) / len(M)

#%% Use ANM then build cross correlation matrix for all chains
from prody import calcANM, calcCrossCorr
import numpy as np
from tqdm import tqdm
n_modes = 5

# Square accumulator sized by len(chains[0]); each chain adds its normalized
# cross-correlation matrix and the sum is averaged after the loop.
avg_cc = np.zeros(shape=(len(chains[0]),) * 2, dtype=float)
for chain in tqdm(chains, 'Running ANM'):
    anm = calcANM(chain.hessian, selstr='calpha', n_modes=n_modes)

    cc = calcCrossCorr(anm[:n_modes], n_cpu=1)
    cc_min = cc.min()
    cc_max = cc.max()
    # Min-max rescale this chain's matrix into [0, 1] before accumulating.
    avg_cc += (cc - cc_min) / (cc_max - cc_min)

avg_cc /= len(chains)

#%%
# Row/column indices to mark with red guide lines; the set depends on the target.
highlight_rows = [690, 954, 134, 313] if target == 'EGFR' else [969-92, 1236-92]

# Draw the edge-weight matrix first, then the averaged cross-correlation
# matrix, each as its own 20x20 figure with the same guide lines.
for matrix in (edgeW, avg_cc):
    plt.figure(figsize=(20,20))
    plt.matshow(matrix, fignum=1)

    for idx in highlight_rows:
        # The -0.5 offset puts the line on the boundary just before cell `idx`.
        plt.axhline(idx - 0.5, color='red', linewidth=1, alpha=0.3)
        plt.axvline(idx - 0.5, color='red', linewidth=1, alpha=0.3)
    plt.show()