bioinf-jku / FCD

Fréchet ChemNet Distance: A quality measure for generative models for molecules
GNU Lesser General Public License v3.0

An error in the function 'get_one_hot' #14

Closed: Koiocs closed this issue 6 months ago

Koiocs commented 10 months ago

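When `SmilesDataset.__getitem__` is called, it calls `get_one_hot(smiles, 350)`. Inside `get_one_hot`, `array_length` is fixed at 350, so `one_hot` has only 350 rows (indices 0 to 349). The loop position `pos` over `numeric`, however, is not bounded. If a SMILES string tokenizes into more than 350 tokens (counting the `.` terminator the function appends), `pos` reaches 350 and `one_hot[pos, num] = 1` raises an IndexError on axis 0.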
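A minimal reproduction of the failure, as a sketch: the vocabulary, the `<unk>` token and the character-level `tokenize` below are hypothetical stand-ins, not FCD's real internals. Only the indexing logic is copied from `get_one_hot`.

```python
import numpy as np

# Hypothetical stand-ins for FCD's internal vocabulary and tokenizer.
VOCAB = ["C", "O", "N", "(", ")", "=", "1", ".", "<unk>"]
VOCAB_C2I = {tok: i for i, tok in enumerate(VOCAB)}
UNK = VOCAB_C2I["<unk>"]


def tokenize(smiles: str) -> list:
    # Character-level split (assumption); a real SMILES tokenizer would
    # typically also group multi-character tokens such as "Cl" or "Br".
    return list(smiles)


def get_one_hot_buggy(smiles: str, pad_len: int = -1) -> np.ndarray:
    # Same indexing logic as get_one_hot in the issue.
    smiles = smiles + "."
    array_length = len(smiles) if pad_len < 0 else pad_len
    one_hot = np.zeros((array_length, len(VOCAB)))
    numeric = [VOCAB_C2I.get(tok, UNK) for tok in tokenize(smiles)]
    for pos, num in enumerate(numeric):
        one_hot[pos, num] = 1  # no bounds check: fails once pos >= array_length
    return one_hot


# 349 atoms + "." = 350 tokens: fits exactly.
print(get_one_hot_buggy("C" * 349, 350).shape)  # (350, 9)

# 350 atoms + "." = 351 tokens: position 350 is out of range.
try:
    get_one_hot_buggy("C" * 350, 350)
except IndexError as err:
    print("IndexError:", err)
```

The boundary sits at 350 tokens including the terminator, so a 349-character SMILES (under this character-level tokenizer) is the longest that fits.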

```python
import numpy as np
from torch.utils.data import Dataset

# tokenize, __vocab, __vocab_c2i and __unk are module-level helpers
# defined elsewhere in the FCD package.


class SmilesDataset(Dataset):
    __PAD_LEN = 350

    def __init__(self, smiles_list):
        super().__init__()
        self.smiles_list = smiles_list

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        features = get_one_hot(smiles, 350)  # pad_len = 350
        return features / features.shape[1]

    def __len__(self):
        return len(self.smiles_list)


def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
    """Generate one-hot representation of a Smiles string.

    Args:
        smiles (str): Input molecule as Smiles
        pad_len (int, optional): Whether or not to pad to a given size. Defaults to -1.

    Returns:
        np.ndarray: Array containing the one-hot encoded Smiles
    """
    smiles = smiles + "."

    # initialize array
    array_length = len(smiles) if pad_len < 0 else pad_len
    vocab_size = len(__vocab)
    one_hot = np.zeros((array_length, vocab_size))

    tokens = tokenize(smiles)
    numeric = [__vocab_c2i.get(token, __unk) for token in tokens]

    for pos, num in enumerate(numeric):  # pos can reach or exceed 350
        one_hot[pos, num] = 1  # IndexError when pos >= array_length

    return one_hot
```
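One possible guard, sketched with the same kind of hypothetical vocabulary and character-level tokenizer (the maintainers' own fix may differ): compare the token count with `pad_len` before filling the array, and either raise a descriptive `ValueError` or, if the caller opts in through a hypothetical `truncate` flag, keep only the tokens that fit.

```python
import numpy as np

# Hypothetical stand-ins for FCD's internal vocabulary and tokenizer.
VOCAB = ["C", "O", "N", "(", ")", "=", "1", ".", "<unk>"]
VOCAB_C2I = {tok: i for i, tok in enumerate(VOCAB)}
UNK = VOCAB_C2I["<unk>"]


def tokenize(smiles: str) -> list:
    # Character-level split (assumption, not FCD's real tokenizer).
    return list(smiles)


def get_one_hot_checked(smiles: str, pad_len: int = -1, truncate: bool = False) -> np.ndarray:
    smiles = smiles + "."
    tokens = tokenize(smiles)
    # Only a non-negative pad_len imposes a limit; -1 means "no padding".
    if 0 <= pad_len < len(tokens):
        if not truncate:
            raise ValueError(
                f"SMILES has {len(tokens)} tokens (including the '.' terminator), "
                f"more than pad_len={pad_len}"
            )
        tokens = tokens[:pad_len]  # keep what fits; this drops the terminator
    array_length = len(smiles) if pad_len < 0 else pad_len
    one_hot = np.zeros((array_length, len(VOCAB)))
    for pos, tok in enumerate(tokens):
        one_hot[pos, VOCAB_C2I.get(tok, UNK)] = 1
    return one_hot


print(get_one_hot_checked("CO", 350).shape)  # (350, 9)
try:
    get_one_hot_checked("C" * 350, 350)
except ValueError as err:
    print(err)
print(get_one_hot_checked("C" * 350, 350, truncate=True).sum())  # 350.0
```

Raising by default keeps over-long molecules from being silently altered. Truncation also cuts off the `.` terminator together with the tail of the string, so it changes what the network sees for those molecules.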
renzph commented 6 months ago

Thanks for your comment. This has been fixed in version 1.2.1.