kerolos opened this issue 5 months ago
Hi @kerolos, thanks for opening this discussion! I can help you get your video recipe set up.
Is there a plan to add a recipe that supports video data in Lhotse?
There is currently one audio-visual (AV) recipe, for the GRID AV corpus: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/grid.py. In general, Lhotse recipes download datasets and prepare their manifests; actual training is out of Lhotse's scope. You may want to set up a separate repository with your experiment's code that imports lhotse.
How can I start using custom features, for example ones extracted with MediaPipe pose estimation tools?
Once you create a recording, you can load the video, process it with some module, and save + attach as a custom field to the cut. For example:
```python
from lhotse import Recording
from lhotse.features.io import NumpyHdf5Writer

video_recording = Recording.from_file("/path/to/-fZc293MpJk_0-1-rgb_front.mp4")  # lhotse will auto-construct a video recording manifest
video_cut = video_recording.to_cut()
# load_video() returns a (video, audio) tuple; video is a uint8 tensor, expected layout (T, C, H, W)
video_frames, _ = video_cut.load_video()
video_features = compute_some_features(video_frames)  # your own extractor (placeholder); returns an np.ndarray of arbitrary shape
# frame_shift is the time between consecutive frames in seconds, i.e. 1 / fps (not the fps itself)
frame_shift = 1.0 / video_recording.video.fps
# Option 1 -> save to some storage directly
# temporal_dim indicates which dimension of video_features corresponds to time; set accordingly.
with NumpyHdf5Writer("video_features.h5") as writer:
    video_cut.video_features = writer.store_array(video_cut.id, video_features, frame_shift=frame_shift, temporal_dim=0)
# Option 2 -> hold the data in memory and write it to some storage later (useful if you're going to use the Lhotse Shar format):
video_cut = video_cut.attach_tensor("video_features", video_features, frame_shift=frame_shift, temporal_dim=0)
```
If you save the final `video_cut`, you can later load the features with `cut.load_video_features()` and access the manifest via `cut.video_features` (this special field and method are auto-added for custom fields registered via `attach_tensor`). You can compute many different features and attach all of them under different names.
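To make the roles of `frame_shift` and `temporal_dim` concrete, here is a small NumPy sketch of how a time window (start and duration in seconds) maps onto frame indices of a temporal array. It only illustrates the idea and assumes simple rounding to the nearest frame; `slice_by_time` is a hypothetical helper, and Lhotse's own internal rounding rules may differ slightly.

```python
import numpy as np


def slice_by_time(
    array: np.ndarray,
    frame_shift: float,
    start: float,
    duration: float,
    temporal_dim: int = 0,
) -> np.ndarray:
    """Return the frames of `array` covering [start, start + duration) seconds.

    frame_shift is the time between frames in seconds (1 / fps for video),
    and temporal_dim says which axis of `array` is time.
    """
    first = int(round(start / frame_shift))
    num_frames = int(round(duration / frame_shift))
    # Build an index that slices only along the temporal axis
    index = [slice(None)] * array.ndim
    index[temporal_dim] = slice(first, first + num_frames)
    return array[tuple(index)]


# Example: 4 seconds of 25 fps features with 3 values per frame
feats = np.arange(100 * 3, dtype=np.float32).reshape(100, 3)
window = slice_by_time(feats, frame_shift=1 / 25, start=1.0, duration=2.0)
print(window.shape)  # (50, 3)
```

With `frame_shift = 1 / 25`, one second corresponds to 25 frames, so the 2-second window starting at 1.0 s covers frames 25 to 74.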
What format should I use to save the extracted features (i.e. so that the cut `has_features`) and load them later for training?
I would use one of the NumPy format writers in Lhotse (e.g. `NumpyHdf5Writer` in the example above). Don't use lilcom unless you are sure it makes sense: it is a lossy format optimized for log-domain features. You may also want to explore the Lhotse Shar format, which I think should work with video recordings (and definitely works with video features extracted as above). It is better optimized for I/O, which might help you process large video data in training.
That said, video features would likely require better compression for very large datasets, which is something we can explore later.
I also have a varying frame rate: it is not fixed and ranges from 24 fps to 50 fps. How can I deal with that?
You can access the fps via `recording.video.fps` or `cut.video.fps`. If you want to resample the video, you have two options:
1) load the whole recording (or a cut of a given duration) and downsample/resample it in Python;
2) leverage the torchaudio ffmpeg bindings to resample the video while decoding. You might need to check out their tutorials to learn how to pass specific ffmpeg filter commands, and find a way to expose/add that in the `AudioSource` API (for reference, this code loads the video).
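For option 1, a minimal NumPy sketch of nearest-frame resampling is shown below. It assumes you resample the already-extracted per-frame features of shape (T, D), time on axis 0, rather than raw pixels; `resample_to_fps` is a hypothetical helper, not a Lhotse function.

```python
import numpy as np


def resample_to_fps(features: np.ndarray, src_fps: float, dst_fps: float) -> np.ndarray:
    """Nearest-frame resampling of a (T, ...) sequence from src_fps to dst_fps."""
    num_src = features.shape[0]
    duration = num_src / src_fps  # clip length in seconds
    num_dst = max(1, int(round(duration * dst_fps)))
    # Timestamp of every output frame, mapped to the nearest input frame
    times = np.arange(num_dst) / dst_fps
    src_idx = np.clip(np.round(times * src_fps).astype(int), 0, num_src - 1)
    return features[src_idx]


# A 1-second clip at 50 fps, brought down to 25 fps
clip = np.arange(50, dtype=np.float32).reshape(50, 1)
out = resample_to_fps(clip, src_fps=50, dst_fps=25)
print(out.shape)  # (25, 1)
```

After resampling, the frame shift to use when attaching the result to a cut becomes `1 / dst_fps` for every clip, regardless of its original frame rate.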
A final comment: the `Recording` manifest doesn't support custom fields, so you'd be better off moving the `feature_path` key to the supervision, as `{..., "custom": {"feature_path": ...}}`.
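A minimal sketch of that move, assuming the manifests are handled as plain dicts (e.g. one JSON object per line of the JSONL files) and that the recording entry carries a top-level `feature_path` key; `move_feature_path` is a hypothetical helper.

```python
import copy
from typing import Dict, Tuple


def move_feature_path(recording: Dict, supervision: Dict) -> Tuple[Dict, Dict]:
    """Move a top-level "feature_path" from a recording dict into supervision["custom"]."""
    recording = copy.deepcopy(recording)
    supervision = copy.deepcopy(supervision)
    feature_path = recording.pop("feature_path", None)
    if feature_path is not None:
        # "custom" may be missing or None in existing manifests
        custom = dict(supervision.get("custom") or {})
        custom["feature_path"] = feature_path
        supervision["custom"] = custom
    return recording, supervision


rec, sup = move_feature_path(
    {"id": "clip-1", "feature_path": "/features/clip-1.npy"},
    {"id": "clip-1", "recording_id": "clip-1", "text": "hi"},
)
print(sup["custom"])  # {'feature_path': '/features/clip-1.npy'}
```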
First, I tried to load an MP4 video from the sign language dataset (this MP4 has no audio track):
```python
recording = Recording.from_file("test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4")
```
I got this error:
File "/mnt/HD_8TB/training/_icefall_script_/SignRcg/local/_tmp_/readh54test.py", line 76, in <module>
video_recording = Recording.from_file("/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4") # lhotse will auto-construct video recording manifest
File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/recording.py", line 200, in from_file
audio_info = info(
File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/backend.py", line 1494, in info
return get_current_audio_backend().info(
File "/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/audio/backend.py", line 750, in info
raise AudioLoadingError(
lhotse.audio.utils.AudioLoadingError: Fetching info about audio from '/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4' failed. Details:
Exception #0 (<class 'lhotse.audio.backend.LibsndfileBackend'>): <class 'soundfile.LibsndfileError'>: Error opening '/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4': Format not recognised.
Exception #1 (<class 'lhotse.audio.backend.TorchaudioDefaultBackend'>): <class 'RuntimeError'>: Failed to fetch metadata from /mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4
Set LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE=1 environment variable for full stack traces.
I would like to use the Lhotse Shar format to save Shar files from the JSONL manifests. I tried two options:
Option 1: Saving Cuts from Recordings and Supervisions Manifests (from_jsonl):
```python
from pathlib import Path

from lhotse import CutSet, RecordingSet, SupervisionSet

src_dir = Path("data/manifests")  # placeholder: directory holding the JSONL manifests
output_dir = "./data-shar"
recordings_manifest = src_dir / "recordings.jsonl"
supervisions_manifest = src_dir / "supervisions.jsonl"
recordings = RecordingSet.from_jsonl(recordings_manifest)
supervisions = SupervisionSet.from_jsonl(supervisions_manifest)
cuts = CutSet.from_manifests(recordings, supervisions).trim_to_supervisions()
try:
    shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
except AssertionError as e:
    print(f"Error: {e}")
```
Error:
AssertionError: Unknown field type (got: 'mp4', we support only: wav, flac, mp3, opus, lilcom, numpy, jsonl)
Option 2: Extracting Video Features and Saving Them with ArrayTarWriter:
```python
import logging
from pathlib import Path

import cv2
import mediapipe as mp
from tqdm import tqdm

from lhotse import load_manifest_lazy
from lhotse.shar import ArrayTarWriter

src_dir = Path("data/manifests")  # placeholder: directory holding the JSONL manifests
output_dir = Path("./data-shar")
output_dir.mkdir(parents=True, exist_ok=True)
recordings_manifest = load_manifest_lazy(src_dir / "recordings.jsonl")
supervisions_manifest = load_manifest_lazy(src_dir / "supervisions.jsonl")
# The writer fills in the shard index via %-formatting, so pass a plain string
tar_path = str(output_dir / "video_features.%06d.tar")

with ArrayTarWriter(tar_path, shard_size=15) as writer, tqdm(
    total=len(recordings_manifest)
) as pbar, mp.solutions.holistic.Holistic(
    static_image_mode=False,
    model_complexity=0,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
) as holistic:
    for video_recording in recordings_manifest:
        try:
            video_cut = video_recording.to_cut()
            video_path = video_recording.sources[0].source  # Get the video path from sources
            logging.info(f"Loading video frames from {video_path}")
            # Get FPS using OpenCV
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            # extract_features_from_video is my own MediaPipe-based helper (not shown)
            video_features = extract_features_from_video(video_path, holistic)
            if video_features is None:
                logging.error(f"Failed to load video frames for recording ID: {video_recording.id}, video path: {video_path}")
                continue
            # Attach features to video_cut (frame_shift = 1 / fps, in seconds)
            video_cut = video_cut.attach_tensor("video_features", video_features, frame_shift=1.0 / fps, temporal_dim=0)
            # Store the features using ArrayTarWriter
            writer.write(video_cut.id, video_features, video_cut.video_features)
        except Exception as e:
            logging.error(f"Error processing recording ID {video_recording.id}: {e}")
        pbar.update(1)
```
With this, I can save the features to video_features.000000.tar. Inside the tar, each video has two files, e.g. -fZc293MpJk_0-1-rgb_front.json and -fZc293MpJk_0-1-rgb_front.npy, and the JSON file looks like this: `{"array": {"storage_type": "shar", "storage_path": "", "storage_key": "", "shape": [17, 1662]}, "temporal_dim": 0, "frame_shift": 0.02, "start": 0}`. Furthermore, I checked the .npy file and it has the correct dimensions.
Note: I did not use "lilcom" compression in ArrayTarWriter, and I did not save `feature_shards = writer.output_paths`.
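To double-check such a tar outside of Lhotse, a small standard-library sketch can list every key with the shape of its `.npy` array and its JSON metadata. It assumes the layout described above (one `<cut_id>.npy` and one `<cut_id>.json` per cut); `inspect_array_tar` and `build_demo_tar` are hypothetical helpers, and the latter only builds a tiny in-memory tar so the example runs on its own.

```python
import io
import json
import tarfile
from typing import BinaryIO, Dict, Tuple

import numpy as np


def inspect_array_tar(fileobj: BinaryIO) -> Dict[str, Tuple]:
    """Map each key in an array tar to (npy shape, JSON metadata)."""
    shapes, metas = {}, {}
    with tarfile.open(fileobj=fileobj, mode="r") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            key, ext = member.name.rsplit(".", 1)
            payload = tar.extractfile(member).read()
            if ext == "npy":
                shapes[key] = np.load(io.BytesIO(payload)).shape
            elif ext == "json":
                metas[key] = json.loads(payload)
    return {k: (shapes.get(k), metas.get(k)) for k in sorted(set(shapes) | set(metas))}


def build_demo_tar() -> io.BytesIO:
    """Create an in-memory tar with one .npy/.json pair, mimicking the layout above."""
    npy = io.BytesIO()
    np.save(npy, np.zeros((17, 1662), dtype=np.float32))
    meta = {"array": {"storage_type": "shar", "shape": [17, 1662]}, "temporal_dim": 0, "frame_shift": 0.02, "start": 0}
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, data in [("clip-1.npy", npy.getvalue()), ("clip-1.json", json.dumps(meta).encode())]:
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    buf.seek(0)
    return buf


for key, (shape, meta) in inspect_array_tar(build_demo_tar()).items():
    # The .npy shape should agree with the shape recorded in the JSON manifest
    print(key, shape, meta["array"]["shape"], tuple(meta["array"]["shape"]) == shape)
```

On a real shard, open the file in binary mode and pass it in, e.g. `with open("video_features.000000.tar", "rb") as f: inspect_array_tar(f)`.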
I also want to be able to use the `from_shar` function and, later, a training DataLoader with Lhotse Shar:
```python
cuts_nodata = CutSet.from_shar(fields={"cuts": shards["cuts"]})
```
or
```python
cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "recording": shards["recording"],
    },
)
```
In the tutorial (examples/04-lhotse-shar.ipynb), the section "Implementation note: the use of IterableDataset" computes fbank features on the fly instead of reading precomputed ones from a field such as `"fbank": feature_shards` (.llc files). Can the same be done with another array format such as .npy, as in my case?
How should the code be modified so that the DynamicBucketingSampler pipeline reads the existing features from the `feature_shards` shards, instead of extracting new ones from the recording shards (as done in the "Implementation note: the use of IterableDataset" section)?
Thanks in advance @pzelasko
- How can I avoid this error from the `Recording` class?
Video loading depends on having recent versions of PyTorch and torchaudio, and a compatible ffmpeg version. Based on the call stack, I think you may not have this backend available. Try updating torch/torchaudio and setting the environment variable `export LHOTSE_AUDIO_BACKEND=FfmpegTorchaudioStreamerBackend` (no space after `=`) to force the torchaudio ffmpeg backend for this.
- How can I skip the audio alignment in the `load_video()` function ("method that loads video + audio (and keeps them in sync duration-wise)")?
Try passing the `with_audio=False` argument to it: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio/recording.py#L479C9-L479C33
As for your other question:
Option 1: Saving Cuts from Recordings and Supervisions Manifests (from_jsonl):
Yeah, we'll need to add mp4 support to `AudioTarWriter`. I don't have the bandwidth for this right now, but I can help you get started. First, we'll need to add a `save_audio` function that can actually save both audio and video to the torchaudio ffmpeg streamer backend here. You can check the `read_audio` function and implement `save_audio` analogously to support the same set of features, which I hope would be straightforward. Then, we'll need to register the mp4 format in the Shar writers, here and here. I think this is sufficient to get it working.
How can I save the missing parts in the Lhotse Shar format for both options, i.e. cuts (e.g., cuts.000000.jsonl.gz) and recordings (e.g., recording.000000.tar)?
You have two options. There is a high-level utility, `cuts.to_shar()`, which uses `SharWriter` to save all the fields together (you have to specify which fields and which format, e.g. `{"recording": "mp4", "video_features": "numpy"}`). You can use it to write only the cuts first, and then use your existing option 2 code to save the video features next to the cuts. It will create a valid Shar directory. Exporting recordings to mp4 is discussed above.
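When the cuts and the video features are written in two separate passes, they only form a valid Shar directory if the shards line up: as far as I understand, Shar reads each field tar in lockstep with the corresponding cuts shard, so both passes need the same cut order and the same `shard_size`. The plain-Python sketch below shows that check; `split_into_shards` and `shards_aligned` are hypothetical helpers, not Lhotse API.

```python
from typing import List, Sequence


def split_into_shards(ids: Sequence[str], shard_size: int) -> List[List[str]]:
    """Group IDs the way a sequential writer fills shards: consecutive chunks of shard_size."""
    return [list(ids[i : i + shard_size]) for i in range(0, len(ids), shard_size)]


def shards_aligned(cut_shards: List[List[str]], field_shards: List[List[str]]) -> bool:
    """True if every field shard holds exactly the same IDs, in the same order, as its cuts shard."""
    return len(cut_shards) == len(field_shards) and all(
        a == b for a, b in zip(cut_shards, field_shards)
    )


cut_ids = [f"clip-{i}" for i in range(35)]
cut_shards = split_into_shards(cut_ids, shard_size=15)
print([len(s) for s in cut_shards])  # [15, 15, 5]
print(shards_aligned(cut_shards, split_into_shards(cut_ids, 15)))  # True
print(shards_aligned(cut_shards, split_into_shards(cut_ids, 10)))  # False
```

One practical consequence: the option 2 code skips recordings whose feature extraction fails (the `continue` branch), which shifts every later feature relative to the cuts. Dropping those cuts before writing the cuts shards keeps both passes aligned.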
How should the code be modified so that the DynamicBucketingSampler pipeline reads the existing features from the `feature_shards` shards, instead of extracting new ones from the recording shards (as done in the "Implementation note: the use of IterableDataset" section)?
After applying the suggestions above, I think what you want is this:
```python
cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "video_features": shards["video_features"],
    },
)
```
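The `shards` dict used here is what `cuts.to_shar()` returns. When the shards were written separately, a dict with the same structure can be rebuilt from the file names. Below is a minimal sketch, assuming Lhotse's `<field>.<6-digit index>.<ext>` naming (e.g. `cuts.000000.jsonl.gz`, `video_features.000000.tar`); `group_shar_paths` is a hypothetical helper. Sorting matters, because the i-th shard of every field must correspond to the i-th cuts shard.

```python
import os
from typing import Dict, Iterable, List


def group_shar_paths(paths: Iterable[str]) -> Dict[str, List[str]]:
    """Group Shar shard files by field name, in shard order."""
    fields: Dict[str, List[str]] = {}
    for path in sorted(paths):
        parts = os.path.basename(path).split(".")
        # Expect "<field>.<index>.<ext...>", e.g. "cuts.000000.jsonl.gz"
        if len(parts) >= 3 and parts[1].isdigit():
            fields.setdefault(parts[0], []).append(path)
    return fields


shards = group_shar_paths([
    "shar_out/video_features.000001.tar",
    "shar_out/cuts.000001.jsonl.gz",
    "shar_out/cuts.000000.jsonl.gz",
    "shar_out/video_features.000000.tar",
    "shar_out/README.txt",
])
print(shards["cuts"])  # ['shar_out/cuts.000000.jsonl.gz', 'shar_out/cuts.000001.jsonl.gz']
```

On disk this would be used as `shards = group_shar_paths(glob.glob("data-shar/*"))`. I believe recent Lhotse versions also accept `CutSet.from_shar(in_dir=...)`, which discovers the fields for you.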
Thanks a lot @pzelasko, I really appreciate your help and support. I modified lhotse/shar/writers/shar.py and lhotse/shar/writers/audio.py to support the mp4 format, and ran the following:
```python
from lhotse import CutSet, RecordingSet, SupervisionSet
import torch
import torchaudio

print(torch.cuda.is_available())
print(torchaudio.get_audio_backend())
#torchaudio.set_audio_backend("ffmpeg")
#torchaudio.set_audio_backend("sox_io")  # or "soundfile"

recordings_manifest = "./recordings.jsonl"
supervisions_manifest = "./supervisions.jsonl"
output_dir = "./data-shar"
# Load recordings and supervisions
recordings = RecordingSet.from_jsonl(recordings_manifest)
supervisions = SupervisionSet.from_jsonl(supervisions_manifest)
# Create CutSet and trim to supervisions
cuts = CutSet.from_manifests(recordings, supervisions).trim_to_supervisions()
# Write shards
shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
print("Shards created:", shards)
```
I got this error:
test_code_v2.py:6: UserWarning: torchaudio._backend.get_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.
print(torchaudio.get_audio_backend())
None
_test_code_v2.py:7: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.
torchaudio.set_audio_backend("ffmpeg")
Traceback (most recent call last):
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/utils.py", line 848, in wrapper
return fn(*args, **kwargs)
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/audio/recording.py", line 430, in load_audio
samples = source.load_audio(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/audio/source.py", line 87, in load_audio
samples, sampling_rate = read_audio(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/audio/backend.py", line 1493, in read_audio
return get_current_audio_backend().read_audio(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/audio/backend.py", line 321, in read_audio
return torchaudio_ffmpeg_load(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/audio/backend.py", line 1050, in torchaudio_ffmpeg_load
info = streamer.get_src_stream_info(streamer.default_audio_stream)
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/torio/io/_streaming_media_decoder.py", line 587, in get_src_stream_info
return _parse_si(self._be.get_src_stream_info(i))
TypeError: get_src_stream_info(): incompatible function arguments. The following argument types are supported:
1. (self: torio.lib._torio_ffmpeg4.StreamingMediaDecoder, arg0: int) -> torio.lib._torio_ffmpeg4.SourceStreamInfo
Invoked with: <torio.lib._torio_ffmpeg4.StreamingMediaDecoder object at 0x7e14008b4c70>, None
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/utils.py", line 848, in wrapper
return fn(*args, **kwargs)
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/cut/mono.py", line 77, in load_audio
return self.recording.load_audio(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/utils.py", line 850, in wrapper
raise type(e)(
TypeError: get_src_stream_info(): incompatible function arguments. The following argument types are supported:
1. (self: torio.lib._torio_ffmpeg4.StreamingMediaDecoder, arg0: int) -> torio.lib._torio_ffmpeg4.SourceStreamInfo
Invoked with: <torio.lib._torio_ffmpeg4.StreamingMediaDecoder object at 0x7e14008b4c70>, None
[extra info] When calling: Recording.load_audio(args=(Recording(id='-fZc293MpJk_0-1-rgb_front', sources=[AudioSource(type='file', channels=[0], source='/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/-fZc293MpJk_0-1-rgb_front.mp4')], sampling_rate=24, num_samples=17, duration=6.53, channel_ids=[0], transforms=None),) kwargs={'channels': 0, 'offset': 0.0, 'duration': 6.541666666666667})
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "_test_code_v2.py", line 23, in <module>
shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/cut/set.py", line 644, in to_shar
return _export_to_shar_single(
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/cut/set.py", line 3433, in _export_to_shar_single
writer.write(cut)
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/shar/writers/shar.py", line 129, in write
data = cut.load_audio()
File "/home/kerolos/anaconda3/envs/lhotse-env_3.8/lib/python3.8/site-packages/lhotse/utils.py", line 850, in wrapper
raise type(e)(
TypeError: get_src_stream_info(): incompatible function arguments. The following argument types are supported:
1. (self: torio.lib._torio_ffmpeg4.StreamingMediaDecoder, arg0: int) -> torio.lib._torio_ffmpeg4.SourceStreamInfo
Invoked with: <torio.lib._torio_ffmpeg4.StreamingMediaDecoder object at 0x7e14008b4c70>, None
[extra info] When calling: Recording.load_audio(args=(Recording(id='-fZc293MpJk_0-1-rgb_front', sources=[AudioSource(type='file', channels=[0], source='/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/-fZc293MpJk_0-1-rgb_front.mp4')], sampling_rate=24, num_samples=17, duration=6.53, channel_ids=[0], transforms=None),) kwargs={'channels': 0, 'offset': 0.0, 'duration': 6.541666666666667})
[extra info] When calling: MonoCut.load_audio(args=(MonoCut(id='-fZc293MpJk_0-1-rgb_front', start=0.0, duration=6.541666666666667, channel=0, supervisions=[SupervisionSegment(id='-fZc293MpJk_0-1-rgb_front', recording_id='-fZc293MpJk_0-1-rgb_front', start=0.0, duration=6.53, channel=0, text='hi', language=None, speaker='-fZc293MpJk', gender=None, custom={'feature_raw_path': '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/original/raw_features/-fZc293MpJk_0-1-rgb_front.txt'}, alignment=None)], features=None, recording=Recording(id='-fZc293MpJk_0-1-rgb_front', sources=[AudioSource(type='file', channels=[0], source='/mnt/TB16/sign2text/dataset/How2Sign/clips/test_rgb_front_clips/raw_videos/-fZc293MpJk_0-1-rgb_front.mp4')], sampling_rate=24, num_samples=17, duration=6.53, channel_ids=[0], transforms=None), custom=None),) kwargs={})
Then I was able to create the required Shar files manually, and I checked them against the files created in this tutorial for a speech dataset (https://github.com/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb).
The cuts JSONL file looks like this:
{"id": "-fZc293MpJk_0-1-rgb_front", "start": 0.0, "duration": 6.53, "channel": 0, "supervisions": [{"id": "-fZc293MpJk_0-1-rgb_front", "recording_id": "-fZc293MpJk_0-1-rgb_front", "start": 0.0, "duration": 6.53, "channel": 0, "text": "hi", "language": "English", "speaker": "-fZc293MpJk"}], "recording": {"id": "-fZc293MpJk_0-1-rgb_front", "sources": [{"type": "shar", "channels": [0], "source": ""}], "sampling_rate": 24, "num_samples": 17, "duration": 6.53, "channel_ids": [0]}, "type": "MonoCut"}
{"id": "-fZc293MpJk_2-1-rgb_front", "start": 0.0, "duration": 13.03, "channel": 0, "supervisions": [{"id": "-fZc293MpJk_2-1-rgb_front", "recording_id": "-fZc293MpJk_2-1-rgb_front", "start": 0.0, "duration": 13.03, "channel": 0, "text": "the aileron is the control surface in the wing that is controlled by lateral movement right and left of the stick", "language": "English", "speaker": "-fZc293MpJk"}], "recording": {"id": "-fZc293MpJk_2-1-rgb_front", "sources": [{"type": "shar", "channels": [0], "source": ""}], "sampling_rate": 24, "num_samples": 412, "duration": 13.03, "channel_ids": [0]}, "type": "MonoCut"}
a) I modified the DataModule Python script to read the Shar data:
```python
import glob


class SignLanguageDataModule:
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        train_path = "data/test_V2/shar_out"
        # Rebuild the shard lists from disk; sorting keeps the i-th cuts shard
        # paired with the i-th video_features shard.
        shards = {
            "cuts": sorted(glob.glob(f"{train_path}/cuts.*.jsonl.gz")),
            "video_features": sorted(glob.glob(f"{train_path}/video_features.*.tar")),
        }
        cuts_video_feat_train = CutSet.from_shar(
            fields={
                "cuts": shards["cuts"],
                "video_features": shards["video_features"],
            },
            shuffle_shards=True,
            stateful_shuffle=True,
            seed="randomized",
        ).repeat()
        # The CutSet is lazy (and infinite after .repeat()): peek at the first cut
        # with next(iter(...)) instead of indexing it, and don't call len() on it.
        first_cut = next(iter(cuts_video_feat_train))
        features_array = first_cut.load_video_features()
        print("Features first array shape:", features_array.shape)
        print("Features first array:", features_array)
        return cuts_video_feat_train

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        transforms = []
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create train dataset")
        train = K2SpeechRecognitionDataset(
            input_strategy=eval(self.args.input_strategy)(),
            cut_transforms=transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                shuffle=True,
                max_duration=10.0,
                num_buckets=10,
                rank=0,
                world_size=1,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=10.0,
                shuffle=self.args.shuffle,
            )
        logging.info(f"train_sampler created: {train_sampler}")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_iter_dataset = IterableDatasetWrapper(
            dataset=train,
            sampler=train_sampler,
        )
        train_dl = DataLoader(
            train_iter_dataset,
            batch_size=None,
            num_workers=self.args.num_workers,
            worker_init_fn=make_worker_init_fn(seed=0),
        )
        logging.info(f"train_dl created: {train_dl}")
        return train_dl
```
train_sign.py:
```python
signData = SignLanguageDataModule(args)
train_cuts = signData.train_cuts()

if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
    # We only load the sampler's state dict when it loads a checkpoint
    # saved in the middle of an epoch
    sampler_state_dict = checkpoints["sampler"]
else:
    sampler_state_dict = None

train_dl = signData.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state_dict
)

if not params.print_diagnostics:
    scan_pessimistic_batches_for_oom(
        model=model,
        train_dl=train_dl,
        optimizer=optimizer,
        sp=sp,
        params=params,
    )
```
```python
def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,
                    sp=sp,
                    batch=batch,
                    is_training=True,
                )
            loss.backward()
            optimizer.zero_grad()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
                display_and_save_batch(batch, params=params, sp=sp)
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )
```
Use '--' to separate paths from revisions, like this:
'git <command> [<revision>...] -- [<file>...]'
fatal: Needed a single revision
fatal: your current branch 'main' does not have any commits yet
2024-07-18 11:31:22,707 INFO [train_shar.py:1102] Training started
2024-07-18 11:31:22,710 INFO [train_shar.py:1112] Device: cuda:0
2024-07-18 11:31:22,715 INFO [train_shar.py:1124] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 1692, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'e400fa3b456faf8afe0ee5bfe572946b4921a3db', 'k2-git-date': 'Sat Jul 15 04:21:50 2023', 'lhotse-version': '1.24.2', 'torch-version': '2.0.1+cu117', 'torch-cuda-available': True, 'torch-cuda-version': '11.7', 'python-version': '3.8', 'icefall-git-branch': None, 'icefall-git-sha1': None, 'icefall-git-date': None, 'icefall-path': '/home/kerolos/projects/asr/icefall', 'k2-path': '/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/home/kerolos/anaconda3/envs/icefall-run/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'kerolos', 'IP address': '127.0.1.1'}, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 30, 'start_epoch': 1, 'start_batch': 0, 'exp_dir': PosixPath('/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/exp/models/model_zipformer'), 'bpe_model': '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024//exp//lang/bpe.model', 'base_lr': 0.03, 'lr_batches': 7500, 'lr_epochs': 3.5, 'ref_duration': 600, 'context_size': 2, 'prune_range': 5, 'lm_scale': 0.25, 'am_scale': 0.0, 'simple_loss_scale': 0.5, 'ctc_loss_scale': 0.2, 'seed': 42, 'print_diagnostics': False, 'inf_check': False, 'save_every_n': 4000, 'keep_last_k': 30, 'average_period': 200, 'use_fp16': True, 'num_encoder_layers': '2,2,3,4,3,2', 'downsampling_factor': '1,2,4,8,4,2', 'feedforward_dim': '512,768,1024,1536,1024,768', 'num_heads': '4,4,4,8,4,4', 'encoder_dim': '192,256,384,512,384,256', 'query_head_dim': '32', 'value_head_dim': '12', 'pos_head_dim': '4', 'pos_dim': 48, 'encoder_unmasked_dim': '192,192,256,256,256,192', 
'cnn_module_kernel': '31,31,15,15,15,31', 'decoder_dim': 512, 'joiner_dim': 512, 'causal': True, 'chunk_size': '2,4,8,-1', 'left_context_frames': '2,4,8,-1', 'use_transducer': True, 'use_ctc': True, 'manifest_dir': PosixPath('/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out'), 'max_duration': 200, 'bucketing_sampler': True, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'shuffle': True, 'drop_last': True, 'return_cuts': False, 'num_workers': 1, 'input_strategy': 'PrecomputedFeatures', 'train_manifest': 'cut_V3', 'dev_manifest': 'cut_V3', 'test_manifest': 'kaldi_cuts_test.jsonl.gz', 'blank_id': 0, 'vocab_size': 2000}
2024-07-18 11:31:22,715 INFO [train_shar.py:1126] About to create model
2024-07-18 11:31:23,493 INFO [train_shar.py:1130] Number of model parameters: 80117559
2024-07-18 11:31:25,730 INFO [sign_datamodule.py:343] About to get train cuts
2024-07-18 11:31:25,747 INFO [sign_datamodule.py:366] Loaded 6 cut shards and 6 video feature shards
shards :{'cuts': ['/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000000.jsonl.gz', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000001.jsonl.gz', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000002.jsonl.gz', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000003.jsonl.gz', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000004.jsonl.gz', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/cuts.000005.jsonl.gz'], 'video_features': ['/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000000.tar', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000001.tar', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000002.tar', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000003.tar', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000004.tar', '/mnt/TB16/sign2text/train_SignModel/en/20_06_2024/data/manifests/test_V2/shar_out/video_features.000005.tar'']}
Features first array shape: (26, 1692) Features first array: [[ 0.51053083 0.4385103 -0.52307171 ... 0. 0.
```python
def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of Zipformer in our case.
      sp:
        The SentencePiece model used to encode the transcripts into token IDs.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    batch_idx_train = params.batch_idx_train
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss, ctc_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
        )

        loss = 0.0

        if params.use_transducer:
            s = params.simple_loss_scale
            # take down the scale on the simple loss from 1.0 at the start
            # to params.simple_loss scale by warm_step.
            simple_loss_scale = (
                s
                if batch_idx_train >= warm_step
                else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
            )
            pruned_loss_scale = (
                1.0
                if batch_idx_train >= warm_step
                else 0.1 + 0.9 * (batch_idx_train / warm_step)
            )
            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

        if params.use_ctc:
            loss += params.ctc_loss_scale * ctc_loss

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    if params.use_transducer:
        info["simple_loss"] = simple_loss.detach().cpu().item()
        info["pruned_loss"] = pruned_loss.detach().cpu().item()
    if params.use_ctc:
        info["ctc_loss"] = ctc_loss.detach().cpu().item()

    return loss, info
```
Thanks in advance
In the dataset class, instead of using `AudioSamples` / `PrecomputedFeatures` / `OnTheFlyFeatures` as the input strategy, simply do the following:
```python
from lhotse.dataset.collation import collate_custom_field

video_features, video_features_lens = collate_custom_field(cuts, "video_features")
batch["inputs"] = video_features
```
You might need to work out some details, but hopefully this can get you started.
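For reference, the padding that `collate_custom_field` performs for a temporal field can be pictured with a NumPy sketch: stack variable-length (T, D) arrays into a zero-padded (N, T_max, D) batch and keep the true lengths, which is what `compute_loss` above reads from `batch["inputs"]` and `batch["supervisions"]["num_frames"]`. This is an illustration under those assumptions, not Lhotse's implementation; `make_video_batch` is hypothetical, and in the real dataset you would keep the tensors returned by `collate_custom_field` (or convert with `torch.from_numpy`).

```python
from typing import Dict, List

import numpy as np


def make_video_batch(features: List[np.ndarray], texts: List[str], pad_value: float = 0.0) -> Dict:
    """Pad (T_i, D) feature arrays to (N, T_max, D) and build a K2-style batch dict."""
    lens = np.array([f.shape[0] for f in features], dtype=np.int64)
    dim = features[0].shape[1]
    inputs = np.full((len(features), int(lens.max()), dim), pad_value, dtype=np.float32)
    for i, f in enumerate(features):
        inputs[i, : f.shape[0]] = f  # copy the real frames; the tail stays padded
    return {
        "inputs": inputs,
        "supervisions": {"num_frames": lens, "text": texts},
    }


batch = make_video_batch([np.ones((3, 2)), np.full((5, 2), 2.0)], ["hi", "the aileron"])
print(batch["inputs"].shape)  # (2, 5, 2)
print(batch["supervisions"]["num_frames"].tolist())  # [3, 5]
```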
You can also safely remove the `scan_pessimistic_batches_for_oom` part that's giving you trouble.
Extend Lhotse to support video features for tasks such as sign language recognition (e.g., How2Sign) and human activity recognition. This enhancement will be useful for the Icefall platform.
Details
With the recent support for video in PR #1151, I am interested in developing a new recipe to handle video data and extract features using tools like MediaPipe.
Objectives
- Recipe Addition: `lhotse/recipes` directory.
- Feature Extraction:
Implementation Steps
- Create Manifest Files:
  - Recordings manifest (`recordings.jsonl`):
  - Supervisions manifest (`supervisions.jsonl`):
- Feature Extraction Script:
  - Create a script `compute_features_sign_language.py`:

Questions
I would appreciate any guidance or support on implementing this feature and using it within the Icefall platform, @pzelasko.
Thank you!