If anyone has succeeded in running a simple inference, I would appreciate it if you could leave it here.
If I succeed, I will leave the code here.
I succeeded!!
I'll wrap up the code and put it up here!
I did it in Fairseq version 0.9.0.
fairseq-0.9.0 does not support Wav2vec 2.0, so I took the Wav2vec 2.0 code from fairseq and applied it there.
I hope this helps.
I will improve the code further and send a pull request. Here is my code.
import contextlib
import itertools as it
import math
import os
import sys

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import Dictionary
from fairseq.models import BaseFairseqModel, FairseqEncoder
from fairseq.tasks.audio_pretraining import AudioPretrainingTask

from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import CriterionType

# Wav2vec2PretrainingTask is a custom module added under examples/ in a local
# fairseq-0.9.0 checkout (see the discussion below); it is not part of upstream fairseq.
from examples.wav2vec2.tasks.audio_pretraining import Wav2vec2PretrainingTask
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == 'wordpiece':
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == 'letter':
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol is not None and symbol != 'none':
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
class Wav2VecEncoder(FairseqEncoder):
def __init__(self, args, tgt_dict=None):
self.apply_mask = args.apply_mask
arg_overrides = {
"dropout": args.dropout,
"activation_dropout": args.activation_dropout,
"dropout_input": args.dropout_input,
"attention_dropout": args.attention_dropout,
"mask_length": args.mask_length,
"mask_prob": args.mask_prob,
"mask_selection": args.mask_selection,
"mask_other": args.mask_other,
"no_mask_overlap": args.no_mask_overlap,
"mask_channel_length": args.mask_channel_length,
"mask_channel_prob": args.mask_channel_prob,
"mask_channel_selection": args.mask_channel_selection,
"mask_channel_other": args.mask_channel_other,
"no_mask_channel_overlap": args.no_mask_channel_overlap,
"encoder_layerdrop": args.layerdrop,
"feature_grad_mult": args.feature_grad_mult,
}
if getattr(args, "w2v_args", None) is None:
state = checkpoint_utils.load_checkpoint_to_cpu(
args.w2v_path, arg_overrides
)
w2v_args = state["args"]
else:
state = None
w2v_args = args.w2v_args
assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same'
w2v_args.data = args.data
task = Wav2vec2PretrainingTask.setup_task(w2v_args)
model = task.build_model(w2v_args)
if state is not None and not args.no_pretrained_weights:
model.load_state_dict(state["model"], strict=True)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = w2v_args.encoder_embed_dim
self.w2v_model = model
self.final_dropout = nn.Dropout(args.final_dropout)
self.freeze_finetune_updates = args.freeze_finetune_updates
self.num_updates = 0
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
elif getattr(args, 'decoder_embed_dim', d) != d:
self.proj = Linear(d, args.decoder_embed_dim)
else:
self.proj = None
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, source, padding_mask, tbc=True, **kwargs):
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
x, padding_mask = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
if self.proj:
x = self.proj(x)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
}
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
def base_architecture(args):
args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
args.dropout_input = getattr(args, "dropout_input", 0)
args.final_dropout = getattr(args, "final_dropout", 0)
args.apply_mask = getattr(args, "apply_mask", False)
args.dropout = getattr(args, "dropout", 0)
args.attention_dropout = getattr(args, "attention_dropout", 0)
args.activation_dropout = getattr(args, "activation_dropout", 0)
args.mask_length = getattr(args, "mask_length", 10)
args.mask_prob = getattr(args, "mask_prob", 0.5)
args.mask_selection = getattr(args, "mask_selection", "static")
args.mask_other = getattr(args, "mask_other", 0)
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0)
args.layerdrop = getattr(args, "layerdrop", 0.0)
class W2lDecoder(object):
def __init__(self, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = 1
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
# encoder_out = models[0].encoder(**encoder_input)
encoder_out = models[0](**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = list()
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
]
class Wav2VecCtc(BaseFairseqModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
add_common_args(parser)
def __init__(self, w2v_encoder, args):
super().__init__()
self.w2v_encoder = w2v_encoder
self.args = args
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, args, target_dict):
"""Build a new model instance."""
base_architecture(args)
w2v_encoder = Wav2VecEncoder(args, target_dict)
return cls(w2v_encoder, args)
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output["encoder_out"]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
def get_feature(filepath):
def postprocess(feats, sample_rate):
if feats.dim == 2:
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
wav, sample_rate = sf.read(filepath)
feats = torch.from_numpy(wav).float()
feats = postprocess(feats, sample_rate)
return feats
def load_target_dict(manifest_path='./manifest'):
dict_path = os.path.join(manifest_path, "dict.ltr.txt")
target_dict = Dictionary.load(dict_path)
return target_dict
def load_model(model_path, target_dict):
# state = checkpoint_utils.load_checkpoint_to_cpu(model_path)
# args = state["args"]
w2v = torch.load(model_path)
# from examples.wav2vec2.models.wav2vec2_asr import Wav2Vec2Model
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
return [model]
def main():
sample, input = dict(), dict()
WAV_PATH = 'xxx.wav'
W2V_PATH = 'wav2vec2_vox_960h.pt'
manifest_path = "MANIFEST_PATH"
feature = get_feature(WAV_PATH )
use_cuda = torch.cuda.is_available()
target_dict = load_target_dict(manifest_path)
model = load_model(W2V_PATH, target_dict)
model[0].eval()
generator = W2lViterbiDecoder(target_dict)
input["source"] = feature.unsqueeze(0)
padding_mask = torch.BoolTensor(input["source"].size(1)).fill_(False).unsqueeze(0)
input["padding_mask"] = padding_mask
sample["net_input"] = input
with torch.no_grad():
hypo = generator.generate(model, sample, prefix_tokens=None)
hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
print(post_process(hyp_pieces, 'letter'))
if __name__ == '__main__':
main()
I CAME TO THE CONCLUSION THAT WHAT WE NEED IN EDUCATION IS MUCH BETTER UNDERSTANDING EXCLUSIVE AND LEARNING FROM A MOTIVATION OF PERSPECTIVE FROM A PSYCHOLOGICAL REPROSPECTIVE
@sooftware amazing!!! Did you use the latest version of wav2letter?
I'm not sure, but here is the command I used:
# Install python libraries
pip install soundfile
pip install torchaudio
pip install sentencepiece
# Update apt-get & Install soundfile
apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y \
&& apt-get -y install apt-utils gcc libpq-dev libsndfile-dev
# Install kenlm
mkdir external_lib
cd external_lib
sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DKENLM_MAX_ORDER=20 -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j 16
export KENLM_ROOT_DIR=$ABSOLUTE_PATH'/external_lib/kenlm/'
cd ../..
# Install Additional Dependencies (ATLAS, OpenBLAS, Accelerate, Intel MKL)
apt-get install libsndfile1-dev libopenblas-dev libfftw3-dev libgflags-dev libgoogle-glog-dev
# Install wav2letter
git clone -b v0.2 https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
pip install -e .
cd ../../..
I installed wav2letter a few days ago.
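As a quick sanity check that the bindings built correctly (my own suggestion, not part of the original install steps), the wav2letter imports used by the inference code above should succeed:
```python
# Quick sanity check (a suggestion, not part of the original install steps):
# these are the same wav2letter imports the inference script relies on.
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import CriterionType

print(CriterionType.CTC)                               # CTC criterion enum value
print(CpuViterbiPath.get_workspace_size(1, 100, 32))   # workspace size for B=1, T=100, N=32
```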
@sooftware Thanks! I'm getting an import error: ModuleNotFoundError: No module named 'examples.wav2vec2'. This module doesn't exist in fairseq though. Did you add it from somewhere else?
@sooftware Could you please specify what you have inside the file at manifest_path = "MANIFEST_PATH"? Is this a path or a link?
@mironnn The manifest path only contains the dictionary, from what I can tell. Look at the load_target_dict function:
def load_target_dict(manifest_path='./manifest'):
dict_path = os.path.join(manifest_path, "dict.ltr.txt")
target_dict = Dictionary.load(dict_path)
return target_dict
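For reference, here is a small, untested sketch of loading and inspecting that dictionary directly; the path and the example symbols below are placeholders:
```python
# Untested sketch: load the letter dictionary used for decoding and inspect it.
# The path below is a placeholder; point it at your own dict.ltr.txt.
from fairseq.data import Dictionary

target_dict = Dictionary.load("./manifest/dict.ltr.txt")
print(len(target_dict))           # vocabulary size (letters + special symbols)
print(target_dict.symbols[:10])   # e.g. ['<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', ...]
```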
> @sooftware Thanks! I'm getting an import error: ModuleNotFoundError: No module named 'examples.wav2vec2'. This module doesn't exist in fairseq though. Did you add it from somewhere else?

Have the same issue =(
@kpister I made and used wav2vec2 in the examples folder because I was using it in fairseq-0.9.0.
I'll write code that works with the latest fairseq! Please wait a little.
@mironnn
I created a pull request (https://github.com/pytorch/fairseq/pull/2668).
I added recognize.py in the examples/wav2vec/ directory.
Usage is simple:
$ python3 examples/wav2vec/recognize.py --wav_path $WAV_PATH --w2v_path $W2V_PATH --target_dict_path $TARGET_DICT_PATH
I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I LOVE THEE PURELY AS THEY TURN FROM PRAISE
Here is the code of recognize.py:
import argparse
import itertools as it

import soundfile as sf
import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.models import BaseFairseqModel
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder

# kept as in the original script, but shadowed by the local W2lViterbiDecoder defined below
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import CriterionType
parser = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
parser.add_argument('--wav_path', type=str,
default='~/xxx.wav',
help='path of wave file')
parser.add_argument('--w2v_path', type=str,
default='~/wav2vec2_vox_960h.pt',
help='path of pre-trained wav2vec-2.0 model')
parser.add_argument('--target_dict_path', type=str,
default='dict.ltr.txt',
help='path of target dict (dict.ltr.txt)')
class Wav2VecCtc(BaseFairseqModel):
def __init__(self, w2v_encoder, args):
super().__init__()
self.w2v_encoder = w2v_encoder
self.args = args
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, args, target_dict):
"""Build a new model instance."""
base_architecture(args)
w2v_encoder = Wav2VecEncoder(args, target_dict)
return cls(w2v_encoder, args)
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output["encoder_out"]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
class W2lDecoder(object):
def __init__(self, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = 1
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
# encoder_out = models[0].encoder(**encoder_input)
encoder_out = models[0](**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = list()
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] for b in range(B)
]
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == 'wordpiece':
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == 'letter':
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol is not None and symbol != 'none':
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
def get_feature(filepath):
def postprocess(feats, sample_rate):
if feats.dim == 2:
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
wav, sample_rate = sf.read(filepath)
feats = torch.from_numpy(wav).float()
feats = postprocess(feats, sample_rate)
return feats
def load_model(model_path, target_dict):
w2v = torch.load(model_path)
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
return [model]
def main():
args = parser.parse_args()
sample = dict()
net_input = dict()
feature = get_feature(args.wav_path)
target_dict = Dictionary.load(args.target_dict_path)
model = load_model(args.w2v_path, target_dict)
model[0].eval()
generator = W2lViterbiDecoder(target_dict)
net_input["source"] = feature.unsqueeze(0)
padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
net_input["padding_mask"] = padding_mask
sample["net_input"] = net_input
with torch.no_grad():
hypo = generator.generate(model, sample, prefix_tokens=None)
hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
print(post_process(hyp_pieces, 'letter'))
if __name__ == '__main__':
main()
@sooftware thanks, I'm trying a CPU build; in this case I get:

CMake Error at cmake/CUDAUtils.cmake:12 (message):
  CUDA required to build CUDA criterion backend
Call Stack (most recent call first):
  src/libraries/criterion/CMakeLists.txt:28 (include)

I can see from your script that you build the python bindings, but how do I pass -DCRITERION_BACKEND=CPU to disable CUDA?
Oh, I'm sorry, I don't know about that issue. T.T
@loretoparisi I tested the CPU case in a docker env, and recognize.py did work.
Here are my steps:

1. Prepare the wav2vec2 data at fairseq/data (model, dict, and wav files):
# For example
fairseq/data/wav2vec_small_960h.pt  # model
fairseq/data/dict.ltr.txt           # dict file
fairseq/data/temp.wav               # the wav you want to test; don't forget to resample it to 16kHz
2. Prepare the recognize.py mentioned above; I put it at fairseq/examples/wav2vec/recognize.py.
3. Prepare a dockerfile at fairseq/wav2vec2.CPU.Dockerfile; the build script is:
FROM wav2letter/wav2letter:cpu-latest
ENV USE_CUDA=0
ENV KENLM_ROOT_DIR=/root/kenlm
# will use Intel MKL for featurization but this may cause dynamic loading conflicts.
# ENV USE_MKL=1
ENV LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2018.5.274/linux/mkl/lib/intel64:$LD_LIBRARY_PATH
WORKDIR /root/wav2letter/bindings/python
RUN pip install --upgrade pip && pip install soundfile packaging && pip install -e .
WORKDIR /root
RUN git clone https://github.com/pytorch/fairseq.git
RUN mkdir data
COPY examples/wav2vec/recognize.py /root/fairseq/examples/wav2vec/recognize.py
WORKDIR /root/fairseq
RUN pip install --editable ./ && python examples/speech_recognition/infer.py --help && python examples/wav2vec/recognize.py --help
4. Go to the fairseq/ dir, then build the docker image:
``` bash
# build
docker build -t wav2vec2 -f wav2vec2.CPU.Dockerfile .
# run docker
docker run --rm -itd --ipc=host -v $PWD/data:/root/data --name w2v wav2vec2
# go into container
docker exec -it w2v bash
# run recognize
python examples/wav2vec/recognize.py --wav_path ~/data/temp.wav --w2v_path ~/data/wav2vec_small_960h.pt --target_dict_path ~/data/dict.ltr.txt
```
@mychiux413 thank you so much. I'm getting this UserWarning:

/root/fairseq/examples/speech_recognition/w2l_decoder.py:39: UserWarning: wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings
  "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
usage: recognize.py [-h] [--wav_path WAV_PATH] [--w2v_path W2V_PATH]
                    [--target_dict_path TARGET_DICT_PATH]
recognize.py: error: unrecognized arguments: --wv2_path /app/data/wav2vec_small_10m.pt

Within the container, the command used was:
python examples/wav2vec/recognize.py --wav_path /root/data/temp.wav --wv2_path /root/data/wav2vec_small_10m.pt --target_dict_path /root/data/dict.ltr.txt
It should not be there, so I have opened an issue.
@loretoparisi there is a typo: not wv2_path, but w2v_path. :)
@sooftware gosh!!! I had checked it ten times!
LoL!! I'm glad I found it now!
@loretoparisi Have you tried evaluating the Wav2vec 2.0 model with KenLM or a Transformer LM?
@sooftware not yet, but this is definitely something I'm going to do!
Let me know if you succeed! I have an open issue (https://github.com/pytorch/fairseq/issues/2654) with KenLM.
If I succeed, I'll write it up on that issue.
@sooftware definitely I will. In the meanwhile I have pushed everything here with Docker. I made two Dockerfiles: the one suggested by @mychiux413 (👍 thanks) and one edited by me with your commands (👍 thank you too), slightly adapted to start from a stripped-down python:3.7.4-slim-buster. They both work, but the docker images have very different sizes:
wav2vec-python3      latest   cfdcb450b427   51 minutes ago   9.97GB
wav2vec-wav2letter   latest   e028493c66b0   2 hours ago      3.37GB
Thank you guys for your help and collaboration! I will keep you posted.
Grrrrrrreat !!!
I am studying wav2vec with great interest. It would be nice if we could help each other. :)
> if feats.dim == 2:

@sooftware I guess this is a typo; it worked for me when I changed if feats.dim == 2: to if feats.dim() == 2:. I have observed this in @loretoparisi's repo as well. Anyway, thanks a ton to both of you for your awesome work!! 👍
Or you just need to convert your audio from stereo to mono; then it would be feats.dim() == 1.
Btw, it should actually be fixed, so I will change get_feature in my repo to:
def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    return feats
Yes, I confirm it's a mono/stereo issue. I had an mp3 and converted it as:
ffmpeg -i input.mp3 -acodec pcm_s16le -ac 1 -ar 16000 output.wav
and it worked.
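If you prefer to stay in Python, a rough torchaudio-based equivalent (torchaudio was installed earlier in this thread) might look like the sketch below; the file names are placeholders and I haven't checked it against the ffmpeg output:
```python
# Rough, unverified sketch: convert an audio file to 16 kHz mono wav with torchaudio.
import torchaudio

waveform, sample_rate = torchaudio.load("input.mp3")   # shape: (channels, time)
waveform = waveform.mean(dim=0, keepdim=True)          # downmix stereo -> mono
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)
torchaudio.save("output.wav", waveform, 16000)         # write 16 kHz mono wav
```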
I succeeded in the installation of wav2letter and fairseq, but when running recognize.py:
python examples/wav2vec/recognize.py --wav_path /root/data/audio1.wav --w2v_path /root/data/wav2vec_small_10m.pt --target_dict_path /root/data/dict.ltr.txt
I get the following UserWarning:
/root/fairseq/examples/speech_recognition/w2l_decoder.py:39: UserWarning: wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings
I used wav2vec.Dockerfile from @loretoparisi's repo for the installation. The repo structure is:
root/
|── data/
|── fairseq/
|── wav2letter/
|── flashlight/
|── kenlm/
Did you encounter the same warning message?
I encountered the same warning message too, but it works well. Don't worry.
@LorenzoGalizia yes, I can confirm that there is this UserWarning. I have asked about it in a separate issue here, but it's not exactly clear what the underlying cause was.
I am getting the following error on running the python script recognize.py:

Traceback (most recent call last):
  File "/home/ubuntu/fairseq/examples/wav2vec/recognize.py", line 10, in <module>
    from wav2letter.decoder import CriterionType
  File "/home/ubuntu/wav2letter/bindings/python/wav2letter/decoder.py", line 3, in <module>
    from wav2letter._decoder import *
ImportError: /home/ubuntu/wav2letter/bindings/python/wav2letter/_decoder.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN2lm5ngram11LoadVirtualEPKcRKNS0_6ConfigENS0_9ModelTypeE

Can someone help?
@RMisha101 If you are using the Docker repository, this should not happen.
I am not using the docker repository. Can you tell me what I am doing wrong?
@RMisha101 As far as I can see there is a problem with the Cython build of the python bindings; have a look at https://github.com/facebookresearch/wav2letter/issues/486.
Btw, I strongly suggest using the Dockerfile we have provided to avoid these issues.
Hi all! I've just opened an issue regarding the output of the inference done with this script. Would it be possible to get some sort of character time information? You can find the discussion here.
Hi @loretoparisi @sooftware, can you suggest a way to add custom vocabulary to a pre-trained Wav2vec 2.0 ASR model?
Thanks in advance!
@bharat-patidar What do you mean by custom vocabulary?
The pre-trained Wav2vec 2.0 model's vocab dictionary is fixed.
If you want to generate a new vocabulary, you have to fine-tune from the pre-trained Wav2vec 2.0 model.
In that case, you just write the vocab in this format:
vocab1 frequency1
vocab2 frequency2
vocab3 frequency3
...
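To make the format concrete, here is a toy sketch (the symbols and frequencies below are made up) that writes such a file and loads it back with fairseq:
```python
# Toy sketch: write a vocab file in the "<symbol> <frequency>" format above
# and load it back as a fairseq Dictionary. Symbols and counts are made up.
from fairseq.data import Dictionary

entries = [("|", 100000), ("E", 52000), ("T", 41000), ("A", 35000)]
with open("dict.ltr.txt", "w") as f:
    for symbol, freq in entries:
        f.write(f"{symbol} {freq}\n")

target_dict = Dictionary.load("dict.ltr.txt")
print(len(target_dict), target_dict.symbols)
```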
Hi @sooftware,
Thanks for the response. By custom vocabulary, I mean identifying a few custom words, let's say my name, "Bharat", or a word like "fairseq", which are not English dictionary words. What changes do we have to make for this requirement?
@bharat-patidar You don't need to change anything for that. The Wav2vec model runs inference at the character level,
so it can already transcribe words like "Bharat" or "fairseq".
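For illustration (the hypothesis string below is made up), a letter-level hypothesis with "|" word boundaries decodes into arbitrary words via the post_process helper defined in recognize.py above:
```python
# Made-up letter-level hypothesis; post_process() is the helper from recognize.py.
hyp_pieces = "B H A R A T | U S E S | F A I R S E Q |"
print(post_process(hyp_pieces, "letter"))  # -> BHARAT USES FAIRSEQ
```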
@sooftware @loretoparisi Can I use my own trained checkpoint file to transcribe audio files using recognize.py, or does this only work for the pre-trained models? If not, how can I modify the script to use my own model? Thanks!
@Romulan12 You can infer with your trained checkpoint.
I succeeded in recognizing with my own model.
Hi, I am having issues with the following import:
from examples.speech_recognition import W2lViterbiDecoder
I receive the following error:
No module named 'examples.speech_recognition.utils'
However, it does recognize examples and the submodule noisychannel, but not speech_recognition.
Does anyone have the same problem and/or know the solution?
(Also, not sure whether this is relevant, but in the __init__.py file there is only import examples.noisychannel, not import examples.speech_recognition.)
> @Romulan12 You can infer with your trained checkpoint. I succeeded in recognizing with my own model.

How were you able to do it? I am unable to. What path do you set for the checkpoint? Can you give me the command for training? Please also give the command for inference using my own checkpoint_best.pt. @loretoparisi @sooftware
Hi all, in the Docker environment the warning message can be ignored if we are NOT using a Fairseq language model. The issue was here: wav2letter#775,
and the file examples/speech_recognition/w2l_decoder.py imports the module LexiconFreeDecoder as below (w2l_decoder.py#L35):
try:
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
LM,
LMState,
SmearingMode,
Trie,
LexiconDecoder,
LexiconFreeDecoder, # ---> wav2letter doesn't support LexiconFreeDecoder in python bindings right now.
)
except:
warnings.warn(
"wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
)
LM = object
LMState = object
So, with the default wav2letter module, the exception is always triggered when importing LexiconFreeDecoder, unless we customize the .cpp file.
If we comment out that line (w2l_decoder.py#L35), the warning message will disappear,
but on the other hand, if we use fairseqlm as the decoder, we must specify a lexicon, or the program will call LexiconFreeDecoder() as the decoder (w2l_decoder.py#L405).
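If you want to check what your own wav2letter bindings expose before touching w2l_decoder.py, a quick probe (my own suggestion, not from the thread) is:
```python
# Probe which decoder-related symbols the installed wav2letter python bindings
# actually export, to see whether LexiconFreeDecoder is the name breaking the
# import in w2l_decoder.py.
import wav2letter.decoder as w2l_decoder_bindings

print([name for name in dir(w2l_decoder_bindings) if "Decoder" in name or "LM" in name])
```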
Hi everyone, when I try to build the docker file, I get an error: ERROR: File "setup.py" not found. Directory cannot be installed in editable mode: /datadrive/conda-envs/stt/lib/python3.8/site-packages/wav2letter/bindings/python. The file is definitely there; has anyone else had this issue? Thanks in advance.
@RMisha101 Can you be more specific?
I trained the model with the README.md command.
After training, I ran recognition with the code above.
@kjellvb
There should be no conda in wav2letter/wav2letter:cpu-latest, and the python version there is 3.6,
so I don't know why you got such an error message when building the docker file.
> I am getting the following error on running the python script recognize.py:
> ImportError: /home/ubuntu/wav2letter/bindings/python/wav2letter/_decoder.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN2lm5ngram11LoadVirtualEPKcRKNS0_6ConfigENS0_9ModelTypeE
> Can someone help?
I had the same error, and after studying it, the answer is that we should call it in a different way. So you should do this:
from wav2letter import decoder
from wav2letter import criterion
from wav2letter import common
CriterionType = decoder.CriterionType
DecoderOptions = decoder.DecoderOptions
KenLM = decoder.KenLM
LexiconDecoder = decoder.LexiconDecoder
SmearingMode = decoder.SmearingMode
Trie = decoder.Trie
CpuViterbiPath = criterion.CpuViterbiPath
get_data_ptr_as_bytes = criterion.get_data_ptr_as_bytes
create_word_dict = common.create_word_dict
load_words = common.load_words
tkn_to_idx = common.tkn_to_idx
🚀 Feature Request
Provide a simple inference pipeline for the wav2vec 2.0 model.
Motivation
The current inference script examples/speech_recognition/infer.py handles a lot of cases and ends up being extremely complex.
Pitch
A single python script that loads and runs inference with a wav2vec 2.0 pre-trained model on a single wav file or on a programmatically loaded waveform signal.
Alternatives
-
Additional context
This kind of inference pipeline would enable individual researchers to test the model on their own audio datasets and against other models.