davidtellez / contrastive-predictive-coding

Keras implementation of Representation Learning with Contrastive Predictive Coding

Implement mnist-cpc training in pytorch #5

jusjusjus closed this 1 year ago

jusjusjus commented 5 years ago

We add '/torch_train_model.py', which reproduces the training in '/train_model.py'. Each element of the CPC network is wrapped in a torch.nn.Module subclass.

These are in turn wrapped into a CPCNetwork convenience class to manage training.
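As a rough illustration of that structure, here is a minimal PyTorch sketch. It is not the code from this PR: only the CPCNetwork name comes from the description above. The Encoder, Autoregressive and Predictor classes, the layer sizes, the train_step method and the 28x28 single-channel input are hypothetical choices, loosely modeled on the encoder / GRU / per-step predictor design of the Keras version. It also assumes that a sample is scored by the mean dot product between predicted and encoded future codes and trained with binary cross-entropy against a true-future / negative label.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch: every name except CPCNetwork is illustrative, not from the PR.


class Encoder(nn.Module):
    """Maps a single image to a code vector."""

    def __init__(self, in_channels=1, code_size=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d(1),  # (N, 64, H', W') -> (N, 64, 1, 1)
            nn.Flatten(),  # -> (N, 64)
            nn.Linear(64, code_size),
        )

    def forward(self, x):  # x: (N, C, H, W)
        return self.net(x)


class Autoregressive(nn.Module):
    """Summarizes a sequence of past codes into one context vector with a GRU."""

    def __init__(self, code_size=128, context_size=256):
        super().__init__()
        self.gru = nn.GRU(code_size, context_size, batch_first=True)

    def forward(self, z):  # z: (B, T, code_size)
        _, h = self.gru(z)  # h: (1, B, context_size), the last hidden state
        return h.squeeze(0)


class Predictor(nn.Module):
    """One linear head per future step, each predicting that step's code."""

    def __init__(self, context_size=256, code_size=128, predict_terms=4):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(context_size, code_size) for _ in range(predict_terms)]
        )

    def forward(self, c):  # c: (B, context_size) -> (B, K, code_size)
        return torch.stack([head(c) for head in self.heads], dim=1)


class CPCNetwork(nn.Module):
    """Wraps the three modules and owns the training step."""

    def __init__(self, in_channels=1, code_size=128, context_size=256, predict_terms=4):
        super().__init__()
        self.encoder = Encoder(in_channels, code_size)
        self.autoregressive = Autoregressive(code_size, context_size)
        self.predictor = Predictor(context_size, code_size, predict_terms)

    def encode(self, seq):  # seq: (B, T, C, H, W) -> (B, T, code_size)
        b, t = seq.shape[:2]
        return self.encoder(seq.flatten(0, 1)).view(b, t, -1)

    def forward(self, past, future):
        """Returns one logit per sample: mean dot product of predicted vs. actual codes."""
        preds = self.predictor(self.autoregressive(self.encode(past)))
        target = self.encode(future)
        return (preds * target).sum(dim=-1).mean(dim=-1)  # (B,)

    def train_step(self, past, future, labels, optimizer):
        """One gradient step; labels are 1.0 for true futures and 0.0 for negatives."""
        self.train()
        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(self(past, future), labels)
        loss.backward()
        optimizer.step()
        return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = CPCNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Random tensors stand in for real batches of 4 past and 4 future MNIST images.
    past = torch.randn(8, 4, 1, 28, 28)
    future = torch.randn(8, 4, 1, 28, 28)
    labels = torch.tensor([1.0, 0.0] * 4)
    print(model.train_step(past, future, labels, optimizer))
```

Returning a raw logit and using binary_cross_entropy_with_logits, rather than an explicit sigmoid followed by binary_cross_entropy, is the numerically stable PyTorch idiom for this kind of binary objective.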

martinmamql commented 4 years ago

Hi jusjusjus,

Greetings! I sincerely appreciate your help in writing up this PyTorch version!

I tried to reproduce the results with this version but failed. After 1000 batches, the training loss is still 0.69 and the accuracy is around 0.5. I simply ran python torch_train_model.py. Did I do something wrong?

Thank you so much!

Sincerely, Martin
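For context on those numbers: assuming the PyTorch port, like the Keras original, trains a binary classifier (true future vs. negative sample) with binary cross-entropy, a loss of 0.69 is ln 2 ≈ 0.693, which is exactly what a model that outputs 0.5 for every sample scores. Together with an accuracy of 0.5, that means the network is still at chance level and has not started learning. A quick check with a small stand-alone helper (binary_cross_entropy here is illustrative, not a function from either repository):

```python
import math


def binary_cross_entropy(p, y):
    """Binary cross-entropy for one predicted probability p and a label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# A model that predicts p = 0.5 for every sample, on a balanced batch of labels.
labels = [0, 1] * 500
mean_loss = sum(binary_cross_entropy(0.5, y) for y in labels) / len(labels)

# Thresholding at p > 0.5 predicts class 0 everywhere, so only the negatives are right.
accuracy = sum((0.5 > 0.5) == bool(y) for y in labels) / len(labels)

print(round(mean_loss, 4), round(math.log(2), 4), accuracy)
```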