flashlight / wav2letter

Facebook AI Research's Automatic Speech Recognition Toolkit
https://github.com/facebookresearch/wav2letter/wiki

How to deal with model checkpoint compatibility issue? #963

Open DongChanS opened 3 years ago

DongChanS commented 3 years ago

Question

I trained a Transformer model (Transformer encoder + Transformer criterion) with wav2letter v0.2.

Unfortunately, I now have to use the flashlight-consolidated version of wav2letter (due to some updates in flashlight).

However, I cannot use the previous checkpoint directly.

Here is the error message.

E0324 07:40:20.319444 84544 Serializer.h:145 Error while loading "(CHECKPOINT_PATH)": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

The previous checkpoint's criterion has the type w2l::TransformerCriterion.

It is not fl::app::asr::TransformerCriterion as in flashlight-consolidated.

How can I solve this problem?

Additional Context

New version's transformer

// flashlight/flashlight/app/asr/criterion/TransformerCriterion.h
AMUpdateFunc buildTransformerAmUpdateFunction(
    std::shared_ptr<SequenceCriterion>& crit);
} // namespace asr
} // namespace app
} // namespace fl

CEREAL_REGISTER_TYPE(fl::app::asr::TransformerCriterion)

Old version's transformer

// wav2letter/src/criterion/TransformerCriterion.h
AMUpdateFunc buildTransformerAmUpdateFunction(
    std::shared_ptr<SequenceCriterion>& crit);

} // namespace w2l

CEREAL_REGISTER_TYPE(w2l::TransformerCriterion)
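
For context on why the loader fails: cereal writes the string passed to CEREAL_REGISTER_TYPE into the archive, and the binary doing the loading must have exactly that string registered. A minimal, self-contained sketch of this behavior (Base and w2l::Criterion are hypothetical toy types, not the real flashlight classes):

#include <cereal/archives/binary.hpp>
#include <cereal/types/base_class.hpp>
#include <cereal/types/memory.hpp>
#include <cereal/types/polymorphic.hpp>
#include <memory>
#include <sstream>

struct Base {
  virtual ~Base() = default;
  template <class Archive>
  void serialize(Archive&) {}
};

namespace w2l {
struct Criterion : Base {
  int x = 0;
  template <class Archive>
  void serialize(Archive& ar) {
    ar(cereal::base_class<Base>(this), x);
  }
};
} // namespace w2l

// This string ("w2l::Criterion") is what gets written into every archive
// that stores the type through a Base pointer.
CEREAL_REGISTER_TYPE(w2l::Criterion)

int main() {
  std::stringstream ss;
  {
    cereal::BinaryOutputArchive oar(ss);
    std::shared_ptr<Base> p = std::make_shared<w2l::Criterion>();
    oar(p); // archive now contains the name "w2l::Criterion" plus the payload
  }
  // Loading works here because this binary registered the same name. A binary
  // that instead registers the class under a different namespace cannot map
  // "w2l::Criterion" back to a type and throws the
  // "unregistered polymorphic type" error quoted above.
  cereal::BinaryInputArchive iar(ss);
  std::shared_ptr<Base> q;
  iar(q);
}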
tlikhomanenko commented 3 years ago

cc @vineelpratap @avidov @jacobkahn @xuqiantong. Do we have conversion scripts or any guides/hints on how to do this?

DongChanS commented 3 years ago

Is it impossible? If not, please let me know how to change the class type of the checkpoint.

tlikhomanenko commented 3 years ago

It is possible =) @vineelpratap @avidov

vineelpratap commented 3 years ago

Hi, sorry for the delay. We were busy with the Interspeech deadline =). We will aim to provide a conversion script today/tomorrow.

DongChanS commented 3 years ago

Is there any problem with providing the script? I will have to re-train the same model unless I receive the guidelines...

tlikhomanenko commented 3 years ago

Yep, it is in a PR for now; we need to fix some CI stuff, but you can try it: https://github.com/facebookresearch/flashlight/pull/524. Please comment if you have any trouble using it as it is in this PR.

DongChanS commented 3 years ago

It is not working...

I built the serialization tools from the above PR, but this error message occurred.

root@cd303acf12b0:~/flashlight/build/bin/asr# ./fl_asr_model_converter old {old_model_path}
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0426 06:20:01.159353 20012 ModelConverter.cpp:105] Saving params from `old binary` model to a binary dump
E0426 06:20:01.160109 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

E0426 06:20:02.160425 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

E0426 06:20:04.160701 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

I think this is because it also uses fl::ext::Serializer::load(modelPath, version, cfg, network, criterion), which is also used in Decode.cpp.

tlikhomanenko commented 3 years ago

Seems your old binary still doesn't have the proper classes, so you cannot load the model. Are you sure the old binary you are using has w2l::TransformerCriterion?

Also cc @vineelpratap.

vineelpratap commented 3 years ago

@DongChanS - {old_model_path} should be replaced with the appropriate path...

Also, can you copy the current fl::app::asr::TransformerCriterion class and create a duplicate class in the same file under the w2l namespace, i.e. as w2l::TransformerCriterion?
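
Another sketch worth considering, assuming a type is bound to a single archive name (this would replace the existing CEREAL_REGISTER_TYPE line rather than supplement it): register the current class under the old string with cereal's CEREAL_REGISTER_TYPE_WITH_NAME.

// In flashlight/app/asr/criterion/TransformerCriterion.h, replacing the
// existing CEREAL_REGISTER_TYPE(fl::app::asr::TransformerCriterion) line:
#include <cereal/types/polymorphic.hpp>

// Bind the current class to the name old checkpoints were saved with.
CEREAL_REGISTER_TYPE_WITH_NAME(
    fl::app::asr::TransformerCriterion,
    "w2l::TransformerCriterion")

The trade-off is that checkpoints saved afterwards would also carry the old name.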

DongChanS commented 3 years ago

@tlikhomanenko - Yes, I'm sure. But I didn't know that the serialization tool requires the full AM binary (network + criterion), so I tried it again with the full AM binary file!

Thanks! I successfully converted the wav2letter v0.2 binary to a flashlight v0.3 binary.

But I followed a different procedure than @vineelpratap suggested. Is that okay?

1) Since I only needed the saveToBinaryDump function in wav2letter v0.2, I built the serialization tools in wav2letter v0.2 with a minimal setup:

  // tools/serialization/ModelConverter.cpp (excerpt)
  std::string binaryType = argv[1];
  std::string modelPath = argv[2];
  std::unordered_map<std::string, std::string> cfg;
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<SequenceCriterion> criterion;
  if (binaryType == "old") {
    LOG(INFO) << "Saving params from `old binary` model to a binary dump";
    // Load the full AM binary: config, network, and criterion.
    W2lSerializer::load(modelPath, cfg, network, criterion);
    saveToBinaryDump(tempModelPath(modelPath).c_str(), network, criterion);
  } else if (binaryType == "new") {
    LOG(FATAL) << "Unsupported binary type in wav2letter";
  } else {
    LOG(FATAL) << "Incorrect binary type specified.";
  }
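
The idea behind the dump is that only the raw parameter tensors cross the version boundary, so no polymorphic class names are involved. A rough sketch of that idea (dumpParams is a hypothetical stand-in, not the PR's actual saveToBinaryDump, and float-typed parameters are assumed):

#include <cereal/archives/binary.hpp>
#include <cereal/types/vector.hpp>
#include <flashlight/flashlight.h>
#include <fstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the PR's saveToBinaryDump: serialize only the
// raw parameter values, with no class names written to the file.
void dumpParams(
    const std::string& path,
    const std::vector<fl::Variable>& params) {
  std::ofstream os(path, std::ios::binary);
  cereal::BinaryOutputArchive ar(os);
  ar(params.size());
  for (const auto& p : params) {
    const af::array& a = p.array();
    std::vector<float> host(a.elements());
    a.host(host.data()); // copy values from device to host memory
    ar(host); // numbers only; loadable by any version's classes
  }
}

Here params could be, for example, the concatenation of network->params() and criterion->params().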

2) I built flashlight v0.3 with the full serialization tools and ran the converter:

root@cd303acf12b0:~/flashlight/build/bin/asr# ./fl_asr_model_converter new /root/flashlight/025_model_last.bin
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0429 07:28:43.208770 34623 ModelConverter.cpp:112] Loading model params from binary dump to `new binary` model
I0429 07:28:52.309464 34623 ModelConverter.cpp:220] Done !

But this model does not work properly. This is the error message from the TransformerCriterion:

I0429 07:31:37.571990 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:329] [Criterion] Number of params: 15070135
I0429 07:31:37.577639 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:362] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
[ConvLM]: Loading vocabulary from /model/wav2letter/v0.3/lm_vocab
[ConvLM]: vocabulary size of convLM 20552
I0429 07:31:52.202826 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:378] [Decoder] LM constructed.
I0429 07:31:52.203383 34891 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
I0429 07:31:52.203384 34889 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
I0429 07:31:52.203434 34888 memoryefficient_offline_inference_cpu_consolidated.cpp:487] [Decoder] LexiconFreeSeq2Seq decoder with token-LM loaded in thread: 0
I0429 07:31:52.203383 34890 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
terminate called after throwing an instance of 'std::invalid_argument'
  what():  Invalid inputs for transformer block: there should be at least input and mask
*** Aborted at 1619681514 (unix time) try "date -d @1619681514" if you are using GNU date ***
PC: @     0x7efd1920418b gsignal
*** SIGABRT (@0x8768) received by PID 34664 (TID 0x7efd14680980) from PID 34664; stack trace: ***
    @     0x7efd24996631 (unknown)
    @     0x7efd1d87a3c0 (unknown)
    @     0x7efd1920418b gsignal
    @     0x7efd191e3859 abort
    @     0x7efd195fc951 (unknown)
    @     0x7efd1960847c (unknown)
    @     0x7efd196084e7 std::terminate()
    @     0x7efd1960846f std::rethrow_exception()
    @     0x563be81e7f05 main
    @     0x7efd191e50b3 __libc_start_main
    @     0x563be826fe2e _start
DongChanS commented 3 years ago

This error message is related to the Transformer module in flashlight.

Flashlight v0.3's Transformer requires a mask, unlike wav2letter v0.2's.

// flashlight/flashlight/fl/contrib/modules/Transformer.cpp

std::vector<Variable> Transformer::forward(const std::vector<Variable>& input) {
  // previous step[optionally], input, padMask
  // padMask should be empty if previous step is provided
  // padMask is expected to have "1" on the used positions and "0" on padded
  // positions
  if (input.size() < 2) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: there should be at least input and mask");
  }
  auto x = input.at(input.size() - 2);
  if (!input.back().isempty() && x.dims(2) != input.back().dims(1)) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: input and Mask batch sizes are different");
  }

But the Transformer encoder doesn't return the mask (since the last layer of the encoder is a Linear layer, the mask is not included in the output).

// flashlight/flashlight/ext/common/SequentialBuilder.cpp

fl::Variable forwardSequentialModuleWithPadMask(
    const fl::Variable& input,
    std::shared_ptr<fl::Module> ntwrk,
    const af::array& inputSizes) {
  // expected input dims T x C x 1 x B
  int T = input.dims(0), B = input.dims(3);
  auto inputMaxSize = af::tile(af::max(inputSizes), 1, B);
  af::array inputNotPaddedSize = af::ceil(inputSizes * T / inputMaxSize);
  auto padMask = af::iota(af::dim4(T, 1), af::dim4(1, B)) <
      af::tile(inputNotPaddedSize, T, 1);
  auto ntwrkSeq = std::dynamic_pointer_cast<fl::Sequential>(ntwrk);
  auto output = input;
  for (auto& module : ntwrkSeq->modules()) {
    auto tr = std::dynamic_pointer_cast<fl::Transformer>(module);
    auto cfr = std::dynamic_pointer_cast<fl::Conformer>(module);
    if (tr != nullptr || cfr != nullptr) {
      output = module->forward({output, fl::noGrad(padMask)}).front();
    } else {
      output = module->forward({output}).front();
    }
  }
  return output.as(input.type());
}
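
To make the contract concrete, a minimal fragment (assuming tr is a std::shared_ptr<fl::Transformer>, x an input Variable, and padMask an af::array built as above):

// v0.2-style call, now rejected by the size check quoted above:
auto out = tr->forward({x}); // throws std::invalid_argument

// v0.3 contract: always pass a second element, either a real pad mask...
out = tr->forward({x, fl::noGrad(padMask)});

// ...or an empty Variable when there is nothing to mask:
out = tr->forward({x, fl::Variable(af::array(), false)});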

How can I solve this problem?

tlikhomanenko commented 3 years ago

Sorry, I don't get what happened. So you have converted the model and are running decoding in fl v0.3, and you see the error on the forward pass of the transformer block, right? Let me check that it works for s2s and that we have a proper call on the transformer blocks everywhere.

vineelpratap commented 3 years ago

@DongChanS I believe you can just do

std::vector<Variable> Transformer::forward(const std::vector<Variable>& input2) {
  auto input = input2;
  if (input2.size() == 1) {
    // No mask was given: append an empty one so the check below passes.
    input.push_back(fl::Variable(af::array(), false));
  }
  if (input.size() < 2) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: there should be at least input and mask");
  }

I'll let @tlikhomanenko confirm though...

tlikhomanenko commented 3 years ago

@DongChanS Please change this https://github.com/flashlight/flashlight/blob/master/flashlight/app/asr/criterion/TransformerCriterion.cpp#L284 to

yBatched = layer(i)->forward(std::vector<Variable>({yBatched}), fl::Variable(af::array())).front();

and this https://github.com/flashlight/flashlight/blob/master/flashlight/app/asr/criterion/TransformerCriterion.cpp#L296 to

yBatched = layer(i)->forward(tmp, fl::Variable(af::array())).front();

I will send this fix later, but this should unblock you. Let me know if you still have problems.

DongChanS commented 3 years ago

Good!

Since there was a syntax error (forward takes a single vector of Variables), I changed these two lines to:

yBatched = layer(i)->forward(std::vector<Variable>({yBatched, fl::Variable(af::array(), false)})).front();
tmp.push_back(fl::Variable(af::array(), false));
yBatched = layer(i)->forward(tmp).front();

Then the model works fine!

tlikhomanenko commented 3 years ago

Feel free to send a PR for this =)