seven7777777 opened 2 days ago
Hi, I ran pre-training on the data with generate_flows_v2.py, then trained the encoder with train_encoder_v2.py, and got this error:
Traceback (most recent call last):
  File "/home/notebook/code/group/benny/modflows/./train_encoder_v2.py", line 124, in
    trained_model.load_state_dict(trained_param)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NeuralODE:
	size mismatch for layer_1.weight: copying a param with shape torch.Size([64, 4]) from checkpoint, the shape in current model is torch.Size([1024, 4]).
	size mismatch for layer_1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer_2.weight: copying a param with shape torch.Size([3, 64]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
@maria-larchenko What should I change in the train_encoder script, then? Thanks.
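I worked it out. In generate_flows_v2.py, hidden is 64, but in train_encoder_v2.py it's 1024, so the saved checkpoint's layer shapes don't match the model that train_encoder_v2.py builds. The hidden size has to be the same in both scripts. @seven7777777 @maria-larchenko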
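For anyone hitting the same thing, here is a minimal PyTorch sketch of why the mismatch happens and one way to avoid it. It assumes the model is just two linear layers named layer_1 and layer_2 (4 inputs, 3 outputs), which is all the traceback tells us; TinyODEFunc and hidden_from_checkpoint are hypothetical names, not the repo's actual code, and the tanh activation is a guess. The simplest fix is still to set the same hidden value in both scripts; reading it from the checkpoint just avoids keeping two numbers in sync.

```python
import torch
import torch.nn as nn


class TinyODEFunc(nn.Module):
    """Hypothetical stand-in for the repo's NeuralODE: only the layer names
    and shapes (4 -> hidden -> 3) are taken from the traceback."""

    def __init__(self, hidden: int):
        super().__init__()
        self.layer_1 = nn.Linear(4, hidden)
        self.layer_2 = nn.Linear(hidden, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activation choice (tanh) is an assumption, not from the issue.
        return self.layer_2(torch.tanh(self.layer_1(x)))


def hidden_from_checkpoint(state_dict: dict) -> int:
    # layer_1.weight has shape [hidden, 4], so dim 0 is the hidden size.
    return state_dict["layer_1.weight"].shape[0]


# Checkpoint as written by the flow-generation step (hidden = 64).
checkpoint = TinyODEFunc(hidden=64).state_dict()

# Rebuilding with a different hidden size reproduces the size mismatch.
try:
    TinyODEFunc(hidden=1024).load_state_dict(checkpoint)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = "size mismatch for layer_1.weight" in str(err)

# Building the model with the hidden size read from the checkpoint loads cleanly.
hidden = hidden_from_checkpoint(checkpoint)
model = TinyODEFunc(hidden=hidden)
model.load_state_dict(checkpoint)
print(hidden, mismatch_raised)  # 64 True
```

Note that passing strict=False to load_state_dict would not help here: it only tolerates missing or unexpected keys, and shape mismatches still raise.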