jusjusjus closed this 1 year ago.
Hi jusjusjus,
Greetings! I sincerely appreciate your help in writing this PyTorch version!
I tried to reproduce the results with this version but failed. After 1000 batches, the training loss is still 0.69 and the accuracy is around 0.5. I simply ran python torch_train_model.py. Did I do something wrong?
Thank you so much!
Sincerely, Martin
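A loss stuck at 0.69 together with an accuracy of about 0.5 is what a binary classifier produces when it has learned nothing. This assumes the model is trained with binary cross-entropy on a balanced mix of positive and negative sequences, which the thread does not state. Predicting p = 0.5 for every sample gives a loss of ln 2 ≈ 0.693. A quick check:

```python
import math


def bce(p, y):
    """Binary cross-entropy for one prediction p and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# A balanced batch: half positive (y=1), half negative (y=0) examples.
labels = [1, 0] * 50

# A model that always outputs 0.5, i.e. chance level.
chance_loss = sum(bce(0.5, y) for y in labels) / len(labels)

print(round(chance_loss, 4))  # 0.6931
print(round(math.log(2), 4))  # 0.6931
```

So the reported numbers mean that training has not moved the model away from chance. They do not point to a wrong loss value.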
We add '/torch_train_model.py', which reproduces the training in '/train_model.py'. Each element of the CPC network is wrapped in a torch.nn.Module subclass:
- NetworkEncoder: transforms a batch of images into an encoded representation with code_size elements.
- Autoregressive: processes a series of encoded images and summarizes their content in a context vector of length hidden_size.
- Predictor: predicts a number of future encodings from a context vector.

These are in turn wrapped in a CPCNetwork convenience class that manages training.
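The PyTorch sketch below is a minimal illustration of how the four classes could fit together. It is not the code from this PR. Only the class names and the code_size and hidden_size parameters come from the description above. The following are all assumptions chosen for illustration:
- the conv stack inside the encoder
- the GRU used as the autoregressive model
- one linear head per future step in the predictor
- scoring a sequence by the mean dot product of predicted and actual future codes
- the binary cross-entropy training step (train_step is a hypothetical helper)

```python
import torch
import torch.nn as nn


class NetworkEncoder(nn.Module):
    """Maps images (N, C, H, W) to codes (N, code_size)."""

    def __init__(self, code_size: int, in_channels: int = 3):
        super().__init__()
        # Hypothetical small conv stack; the layers in the PR may differ.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (N, 64, 1, 1) for any image size
            nn.Flatten(),             # (N, 64)
            nn.Linear(64, code_size),
        )

    def forward(self, x):
        return self.net(x)


class Autoregressive(nn.Module):
    """Summarizes codes (B, T, code_size) into a context vector (B, hidden_size)."""

    def __init__(self, code_size: int, hidden_size: int):
        super().__init__()
        self.gru = nn.GRU(code_size, hidden_size, batch_first=True)  # assumed GRU

    def forward(self, z):
        _, h_n = self.gru(z)  # h_n: (1, B, hidden_size), last hidden state
        return h_n[-1]


class Predictor(nn.Module):
    """Predicts `steps` future codes (B, steps, code_size) from a context vector."""

    def __init__(self, hidden_size: int, code_size: int, steps: int):
        super().__init__()
        # Assumption: one independent linear head per future step.
        self.heads = nn.ModuleList([nn.Linear(hidden_size, code_size) for _ in range(steps)])

    def forward(self, c):
        return torch.stack([head(c) for head in self.heads], dim=1)


class CPCNetwork(nn.Module):
    """Wires encoder, autoregressive model and predictor into one scoring model."""

    def __init__(self, code_size=128, hidden_size=256, steps=4, in_channels=3):
        super().__init__()
        self.encoder = NetworkEncoder(code_size, in_channels)
        self.autoregressive = Autoregressive(code_size, hidden_size)
        self.predictor = Predictor(hidden_size, code_size, steps)

    def encode(self, images):
        # (B, T, C, H, W) -> encode every frame -> (B, T, code_size)
        b, t = images.shape[:2]
        return self.encoder(images.flatten(0, 1)).reshape(b, t, -1)

    def forward(self, x_context, x_future):
        """One logit per sequence: high if x_future plausibly follows x_context."""
        context = self.autoregressive(self.encode(x_context))
        predicted = self.predictor(context)  # (B, steps, code_size)
        actual = self.encode(x_future)       # (B, steps, code_size)
        # Dot product averaged over code dimensions, then over future steps.
        return (predicted * actual).mean(dim=-1).mean(dim=-1)  # (B,)


def train_step(model, optimizer, x_context, x_future, labels):
    """Hypothetical helper: one optimization step with binary cross-entropy."""
    logits = model(x_context, x_future)
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage with random tensors; all shapes are illustrative only.
torch.manual_seed(0)
model = CPCNetwork(code_size=16, hidden_size=32, steps=4, in_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x_context = torch.randn(8, 4, 1, 28, 28)  # 8 sequences, 4 context frames
x_future = torch.randn(8, 4, 1, 28, 28)   # 4 future frames per sequence
labels = torch.randint(0, 2, (8,))        # 1 = true continuation, 0 = negative
loss = train_step(model, optimizer, x_context, x_future, labels)
```

Keeping each stage as its own nn.Module lets the encoder be reused on its own after training. The wrapper then only has to handle the bookkeeping of reshaping sequences of frames into a flat batch and back.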