SciSharp / TensorFlow.NET

.NET Standard bindings for Google's TensorFlow for developing, training and deploying Machine Learning models in C# and F#.
https://scisharp.github.io/tensorflow-net-docs
Apache License 2.0

fix: fix the bug in loading LSTM models and add a test #1144

Wanglongzhi2001 closed this 1 year ago

Wanglongzhi2001 commented 1 year ago

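When a layer is restored from metadata via reflection, the type name that is looked up is Tensorflow.Keras.ArgsDefinition.{class_name}Args, with no extra .Rnn segment. As a result, none of the Args classes placed in the Tensorflow.Keras.ArgsDefinition.Rnn namespace can be found, so the Args of every layer should live in the original Tensorflow.Keras.ArgsDefinition namespace. Likewise, the layers themselves should live in the Tensorflow.Keras.Layers namespace. In addition, every parameter in the Args of a layer that needs to be restored must carry the [JsonProperty] attribute; otherwise deserialization fails: https://github.com/SciSharp/TensorFlow.NET/blob/ed1a8d2edfbad3e47efa48af5e1dbb4c22a20f2e/src/TensorFlowNET.Keras/Utils/generic_utils.cs#L61-L72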
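A minimal C# sketch of why a class one namespace deeper is invisible to a name built from a fixed prefix. It assumes the loader resolves the assembled name string against the assembly that holds the Args classes (here via Assembly.GetType); the real code in generic_utils.cs may resolve it differently. DenseArgs and LSTMArgs below are hypothetical empty stand-ins, and Demo.ArgsLookup is an illustrative name, not part of TensorFlow.NET.

```csharp
using System;
using System.Reflection;

// Hypothetical, empty stand-ins for the real TensorFlow.NET Args classes.
namespace Tensorflow.Keras.ArgsDefinition
{
    public class DenseArgs { }
}

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    // One namespace deeper, like the RNN Args classes before this fix.
    public class LSTMArgs { }
}

namespace Demo
{
    public static class ArgsLookup
    {
        // The assembly that contains the stand-in Args classes above.
        private static readonly Assembly ArgsAssembly = typeof(ArgsLookup).Assembly;

        // Builds the name from a fixed prefix, mirroring the pattern described in the PR:
        // "Tensorflow.Keras.ArgsDefinition.{class_name}Args".
        // Assembly.GetType(string) returns null when no type has that exact full name.
        public static Type FindArgsType(string className)
        {
            string fullName = $"Tensorflow.Keras.ArgsDefinition.{className}Args";
            return ArgsAssembly.GetType(fullName);
        }

        public static void Main()
        {
            // DenseArgs sits directly in the prefix namespace, so it is found.
            Console.WriteLine(FindArgsType("Dense") != null); // True
            // LSTMArgs is really Tensorflow.Keras.ArgsDefinition.Rnn.LSTMArgs, so the lookup misses it.
            Console.WriteLine(FindArgsType("LSTM") != null);  // False
        }
    }
}
```

Because the prefix is fixed, LSTMArgs only becomes resolvable if the class is moved into Tensorflow.Keras.ArgsDefinition (as this PR does) or if the lookup is taught about extra namespaces. The [JsonProperty] requirement is a separate issue; presumably it arises because Keras writes config keys in snake_case (for example return_sequences), and Newtonsoft.Json will not match such a key to a PascalCase property such as ReturnSequences unless the attribute supplies the JSON name.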

AsakusaRinne commented 1 year ago

Please fix the CI error.

Oceania2018 commented 1 year ago
  Failed LSTMLoad [26 s]
  Error Message:
   Test method Tensorflow.Keras.UnitTest.Model.ModelLoadTest.LSTMLoad threw exception: 
Google.Protobuf.InvalidProtocolBufferException: Protocol message contained an invalid tag (zero).
  Stack Trace:
      at Google.Protobuf.ParsingPrimitives.ParseTag(ReadOnlySpan`1& buffer, ParserInternalState& state)
   at Google.Protobuf.CodedInputStream.ReadTag()
   at ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadata.MergeFrom(CodedInputStream input) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Protobuf\SavedMetadata.cs:line 158
   at Google.Protobuf.MessageExtensions.MergeFrom(IMessage message, Stream input, Boolean discardUnknownFields, ExtensionRegistry registry)
   at Google.Protobuf.MessageExtensions.MergeFrom(IMessage message, Stream input)
   at Tensorflow.Keras.Saving.SavedModel.KerasLoadModelUtils.load(String path, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Saving\SavedModel\load.cs:line 55
   at Tensorflow.Keras.Saving.SavedModel.KerasLoadModelUtils.load_model(String filepath, IDictionary`2 custom_objects, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Saving\SavedModel\load.cs:line 37
   at Tensorflow.Keras.Models.ModelsApi.load_model(String filepath, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Models\ModelsApi.cs:line 19
   at Tensorflow.Keras.UnitTest.Model.ModelLoadTest.LSTMLoad() in D:\a\TensorFlow.NET\TensorFlow.NET\test\TensorFlowNET.Keras.UnitTest\Model\ModelLoadTest.cs:line 87