k2-fsa / sherpa

Speech-to-text server framework with next-gen Kaldi
https://k2-fsa.github.io/sherpa
Apache License 2.0
483 stars 97 forks

Add temperature to softmax #459

Open csukuangfj opened 11 months ago

csukuangfj commented 11 months ago

The steps to add a temperature to the softmax are given below.

Streaming models

  1. Add a member after the following line: https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/cpp_api/online-recognizer.h#L70
    // temperature for the softmax in the joiner
    float temperature = 1.0;
  2. Register a command-line argument for the added member by adding a line after https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/cpp_api/online-recognizer.cc#L110-L115

  3. Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/cpp_api/online-recognizer.cc#L175 to

    os << "chunk_size=" << chunk_size << ", ";
    os << "temperature=" << temperature << ")";
  4. Fix the Python binding by changing https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/python/csrc/online-recognizer.cc#L28 to
    const FastBeamSearchConfig &fast_beam_search_config = {}, float temperature = 1.0)

    Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/python/csrc/online-recognizer.cc#L47 to

    ans->chunk_size = chunk_size; 
    ans->temperature = temperature;

    Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/python/csrc/online-recognizer.cc#L60 to

    py::arg("fast_beam_search_config") = FastBeamSearchConfig(),
    py::arg("temperature") = 1.0)

    Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/python/csrc/online-recognizer.cc#L77 to

    .def_readwrite("chunk_size", &PyClass::chunk_size)
    .def_readwrite("temperature", &PyClass::temperature)
  5. Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/csrc/online-transducer-modified-beam-search-decoder.h#L36 to
    int32_t num_active_paths_;
    float temperature_ = 1.0;

    Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/csrc/online-transducer-modified-beam-search-decoder.h#L18-L19 to

      OnlineTransducerModel *model, int32_t num_active_paths, float temperature)
      : model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {}
  6. Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/cpp_api/online-recognizer.cc#L299 to

    model_.get(), config.num_active_paths, config.temperature);
  7. Change https://github.com/k2-fsa/sherpa/blob/4254d4a302bc7bc2497900d7474dcc29bbc23b9f/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc#L172 to

    auto log_probs = (logits / temperature_).log_softmax(-1).cpu();
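
The last step divides the joiner logits by the temperature before taking the log-softmax. As a rough illustration of what that expression computes, here is a standalone sketch that uses only the C++ standard library instead of the torch tensors used in sherpa. The function name LogSoftmaxWithTemperature is made up for this example and does not exist in the code base.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of sherpa): returns log_softmax(logits / temperature).
std::vector<float> LogSoftmaxWithTemperature(const std::vector<float> &logits,
                                             float temperature) {
  assert(!logits.empty());
  assert(temperature > 0.0f);  // a non-positive temperature makes no sense here

  // Scale the logits by the temperature.
  std::vector<float> scaled(logits.size());
  for (std::size_t i = 0; i != logits.size(); ++i) {
    scaled[i] = logits[i] / temperature;
  }

  // Subtract the maximum before exponentiating, for numerical stability.
  float max_val = *std::max_element(scaled.begin(), scaled.end());
  float sum = 0.0f;
  for (float v : scaled) {
    sum += std::exp(v - max_val);
  }
  float log_sum = max_val + std::log(sum);

  // log_softmax(x)_i = x_i - log(sum_j exp(x_j))
  for (float &v : scaled) {
    v -= log_sum;
  }
  return scaled;
}
```

With temperature = 1.0 the result is the plain log-softmax. A temperature above 1 flattens the distribution, and a temperature below 1 sharpens it.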

That's it!

Non-streaming models

Please follow the above steps, replacing online with offline, and modify the corresponding files (for example, offline-recognizer.h instead of online-recognizer.h).
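
For either the online or the offline config, the config-side edits (the new member and the extra line in the string dump) have the same shape. Below is a minimal stand-in sketch of that shape. The struct DemoRecognizerConfig, its Validate() helper, and the chunk_size default are all made up for illustration. It is not the real sherpa config.

```cpp
#include <cstdint>
#include <sstream>
#include <string>

// Hypothetical stand-in for a recognizer config; only two fields are shown.
struct DemoRecognizerConfig {
  int32_t chunk_size = 12;  // arbitrary default chosen for this sketch

  // temperature for the softmax in the joiner
  float temperature = 1.0f;

  // Illustrative check: logits are divided by temperature, so it must be > 0.
  bool Validate() const { return temperature > 0.0f; }

  // Mirrors the pattern of the edited string dump: the previous last field
  // now ends with ", " and the new field closes the parenthesis.
  std::string ToString() const {
    std::ostringstream os;
    os << "DemoRecognizerConfig(";
    os << "chunk_size=" << chunk_size << ", ";
    os << "temperature=" << temperature << ")";
    return os.str();
  }
};
```

Printing the new field by its own name makes a mis-set temperature visible in the logs.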