materialsvirtuallab / megnet

Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals
BSD 3-Clause "New" or "Revised" License

Impossible to do classification training (ValueError) #276

Open LuciusV opened 3 years ago

LuciusV commented 3 years ago

Hello! I want to train a MEGNet model to classify a set of structures (whether some property is zero or not), so I prepared the training targets as a column of string values, 'zero' or 'nonzero'.

Then the model.train method (the model is a MEGNetModel built from band_classification.hdf5) fails with ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type list). Using bool values instead of strings doesn't work either. If I use 0 and 1 as targets, the trained model outputs float values (close to 0 or 1 if trained well enough), but this is not reliable; I would like to do proper classification, possibly with more than two classes. Could you please advise how to change the model properties to allow that?
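Keras losses operate on numeric tensors, so string labels such as 'zero'/'nonzero' most likely need to be mapped to numbers before they reach model.train. A minimal sketch, assuming the labels are held in a plain Python list; encode_binary_labels and decode_binary_outputs are hypothetical helpers, not part of megnet. The second one turns continuous outputs (the "floats near 0 or 1") back into labels with a fixed cutoff:

```python
def encode_binary_labels(labels, negative="zero", positive="nonzero"):
    """Map two string class labels to float targets: negative -> 0.0, positive -> 1.0."""
    mapping = {negative: 0.0, positive: 1.0}
    unknown = set(labels) - set(mapping)
    if unknown:
        # Fail loudly instead of silently treating a typo as one of the classes.
        raise ValueError(f"unexpected labels: {sorted(unknown)}")
    return [mapping[label] for label in labels]


def decode_binary_outputs(outputs, negative="zero", positive="nonzero", cutoff=0.5):
    """Turn continuous model outputs back into class labels using a fixed cutoff."""
    return [positive if value >= cutoff else negative for value in outputs]


targets = encode_binary_labels(["zero", "nonzero", "nonzero", "zero"])
print(targets)  # [0.0, 1.0, 1.0, 0.0]
print(decode_binary_outputs([0.03, 0.91, 0.62, 0.48]))  # ['zero', 'nonzero', 'nonzero', 'zero']
```

If an array is required, the targets can be wrapped as np.array(targets, dtype="float32"). The 0.5 cutoff is only a default and can be tuned on validation data.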

Here is the function I use to build the model:

```python
from megnet.models import MEGNetModel


def gnn_model(n_targets=1):
    # Load the pretrained model only to reuse its element embedding weights
    model_form = MEGNetModel.from_file('band_classification.hdf5')
    embedding_layer = [i for i in model_form.layers if i.name.startswith('embedding')][0]
    embedding = embedding_layer.get_weights()[0]
    # print('Embedding matrix dimension is ', embedding.shape)

    # New model: 100 edge (bond) features, 2 global (state) features, n_targets outputs
    model = MEGNetModel(100, 2, ntarget=n_targets)
    # Find the embedding layer index among all the model layers
    embedding_layer_index = [i for i, j in enumerate(model.layers) if j.name.startswith('atom_embedding')][0]

    # Set the weights to our previous embedding
    model.layers[embedding_layer_index].set_weights([embedding])

    # Freeze the weights
    model.layers[embedding_layer_index].trainable = False
    return model
```
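For more than two classes, one option (an assumption here, not a documented megnet recipe) is to one-hot encode the labels into K columns, build the model with gnn_model(n_targets=K), and read the predicted class as the output with the highest score. MEGNetModel also appears to accept an is_classification argument in recent releases; check the signature of the installed version before relying on it. The helpers and class names below are hypothetical and pure Python:

```python
def one_hot(labels, classes):
    """Encode each label as len(classes) float targets with a single 1.0."""
    index = {name: i for i, name in enumerate(classes)}
    rows = []
    for label in labels:
        row = [0.0] * len(classes)
        row[index[label]] = 1.0  # raises KeyError for a label not listed in `classes`
        rows.append(row)
    return rows


def predicted_classes(scores, classes):
    """Pick, for each sample, the class whose output score is highest (argmax)."""
    return [classes[max(range(len(row)), key=row.__getitem__)] for row in scores]


classes = ["zero", "small", "large"]  # hypothetical class names
print(one_hot(["small", "zero"], classes))  # [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
print(predicted_classes([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], classes))  # ['small', 'zero']
```

With K = len(classes), the matching model would be gnn_model(n_targets=len(classes)). Taking the argmax turns the continuous outputs into a single class per structure, even though the network itself still predicts one score per class.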
chc273 commented 3 years ago

Did you solve the problem? @LuciusV

Also, make sure your numpy version is 1.19, since the 1.20 releases have some incompatibility issues.
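To confirm which numpy an environment actually has before training, the installed version can be read with the standard-library importlib.metadata.version (Python 3.8+). A minimal sketch; major_minor and numpy_is_pre_1_20 are hypothetical helpers, and the 1.19 recommendation comes from the comment above:

```python
from importlib.metadata import version  # standard library, Python 3.8+


def major_minor(ver):
    """Parse 'X.Y...' (e.g. '1.19.5' or '1.20.0rc1') into an (X, Y) tuple of ints."""
    major, minor = ver.split(".")[:2]
    return int(major), int(minor)


def numpy_is_pre_1_20(ver=None):
    """True if numpy is older than 1.20; reads the installed version when `ver` is None."""
    if ver is None:
        ver = version("numpy")  # raises PackageNotFoundError if numpy is not installed
    return major_minor(ver) < (1, 20)


print(numpy_is_pre_1_20("1.19.5"))  # True
print(numpy_is_pre_1_20("1.20.1"))  # False
```

Calling numpy_is_pre_1_20() with no argument checks the environment the script runs in, which is a quick way to confirm that a fresh virtual environment picked up the intended version.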

LuciusV commented 3 years ago

Thank you for pointing this out. I was using a newer numpy, so I will create a new virtual environment with numpy 1.19 and try again.