TMElyralab / MuseTalk

MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting

It's just not working on Windows 11 #40

Status: closed (nitinmukesh closed this issue 1 month ago)

nitinmukesh commented 2 months ago

I am trying with the attached files, but even after waiting 2 hours there is no progress. Am I doing something wrong?

image.zip

test.yaml

task_0:
 video_path: "data/image/face.jpeg"
 audio_path: "data/audio/audio.wav"
C:\sd\MuseTalk>venv\scripts\activate

(venv) C:\sd\MuseTalk>python -m scripts.inference --inference_config configs/inference/test.yaml
add ffmpeg to path
Loads checkpoint by local backend from path: ./models/dwpose/dw-ll_ucoco_384.pth
cuda start
{'task_0': {'video_path': 'data/image/face.jpeg', 'audio_path': 'data/audio/audio.wav'}}
video in 25 FPS, audio idx in 50FPS
extracting landmarks...time consuming
reading images...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]
get key_landmark and face bounding boxes with the default value
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.33s/it]
********************************************bbox_shift parameter adjustment**********************************************************
Total frame:「1」 Manually adjust range : [ -23~25 ] , the current value: 0
*************************************************************************************************************************************
start inference
  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

FFmpeg is on the PATH:
C:\Users\nitin>ffmpeg
ffmpeg version 6.1-full_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 12.2.0 (Rev10, Built by MSYS2 project)
  configuration: --enable-gpl --enable-version3 --enable-static --pkg-config=pkgconf --disable-w32threads --disable-autodetect --enable-fontconfig --enable-iconv --enable-gnutls --enable-libxml2 --enable-gmp --enable-bzlib --enable-lzma --enable-libsnappy --enable-zlib --enable-librist --enable-libsrt --enable-libssh --enable-libzmq --enable-avisynth --enable-libbluray --enable-libcaca --enable-sdl2 --enable-libaribb24 --enable-libaribcaption --enable-libdav1d --enable-libdavs2 --enable-libuavs3d --enable-libzvbi --enable-librav1e --enable-libsvtav1 --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxavs2 --enable-libxvid --enable-libaom --enable-libjxl --enable-libopenjpeg --enable-libvpx --enable-mediafoundation --enable-libass --enable-frei0r --enable-libfreetype --enable-libfribidi --enable-libharfbuzz --enable-liblensfun --enable-libvidstab --enable-libvmaf --enable-libzimg --enable-amf --enable-cuda-llvm --enable-cuvid --enable-ffnvcodec --enable-nvdec --enable-nvenc --enable-dxva2 --enable-d3d11va --enable-libvpl --enable-libshaderc --enable-vulkan --enable-libplacebo --enable-opencl --enable-libcdio --enable-libgme --enable-libmodplug --enable-libopenmpt --enable-libopencore-amrwb --enable-libmp3lame --enable-libshine --enable-libtheora --enable-libtwolame --enable-libvo-amrwbenc --enable-libcodec2 --enable-libilbc --enable-libgsm --enable-libopencore-amrnb --enable-libopus --enable-libspeex --enable-libvorbis --enable-ladspa --enable-libbs2b --enable-libflite --enable-libmysofa --enable-librubberband --enable-libsoxr --enable-chromaprint
  libavutil      58. 29.100 / 58. 29.100
  libavcodec     60. 31.102 / 60. 31.102
  libavformat    60. 16.100 / 60. 16.100
  libavdevice    60.  3.100 / 60.  3.100
  libavfilter     9. 12.100 /  9. 12.100
  libswscale      7.  5.100 /  7.  5.100
  libswresample   4. 12.100 /  4. 12.100
  libpostproc    57.  3.100 / 57.  3.100
Hyper fast Audio and Video encoder
usage: ffmpeg [options] [[infile options] -i infile]... {[outfile options] outfile}...

Use -h to get full help or, even better, run 'man ffmpeg'
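A separate prompt (C:\Users\nitin) finding ffmpeg does not by itself prove that the venv process running MuseTalk sees the same PATH, because a process inherits its environment from whatever launched it. A minimal sketch for checking from inside Python instead; find_ffmpeg and ffmpeg_version_line are hypothetical helper names, not part of MuseTalk. On Windows, shutil.which also honours PATHEXT, so a bare "ffmpeg" matches ffmpeg.exe.

```python
import shutil
import subprocess
from typing import Optional


def find_ffmpeg() -> Optional[str]:
    """Return the full path of the ffmpeg executable this interpreter sees on PATH, or None."""
    return shutil.which("ffmpeg")


def ffmpeg_version_line(path: str) -> str:
    """Run `ffmpeg -version` and return the first line of its output."""
    result = subprocess.run([path, "-version"], capture_output=True, text=True, check=True)
    return result.stdout.splitlines()[0]


if __name__ == "__main__":
    exe = find_ffmpeg()
    if exe is None:
        print("ffmpeg not found on PATH")
    else:
        print(exe, "->", ffmpeg_version_line(exe))
```

Running this with the venv's python (e.g. from C:\sd\MuseTalk after venv\scripts\activate) checks the environment that actually matters.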
nitinmukesh commented 2 months ago

It has now been 3 hours and progress is still at 0.
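When a run sits at 0% for hours, it helps to know which Python line it is actually on. A minimal sketch using the standard-library faulthandler module (not something MuseTalk provides); arm_hang_dump and disarm_hang_dump are hypothetical helper names. If the process is blocked inside native or CUDA code, the dump shows the Python call that is waiting on it.

```python
import faulthandler
import sys
from typing import Optional, TextIO


def arm_hang_dump(interval_s: float = 300.0, out: Optional[TextIO] = None) -> None:
    """Write the Python traceback of every thread to `out` every `interval_s` seconds.

    Purely diagnostic: exit=False keeps the process running after each dump.
    """
    # faulthandler writes to a real file descriptor, so default to the original stderr.
    stream = out if out is not None else sys.__stderr__
    faulthandler.dump_traceback_later(interval_s, repeat=True, file=stream, exit=False)


def disarm_hang_dump() -> None:
    """Cancel the periodic dump once the slow section has finished."""
    faulthandler.cancel_dump_traceback_later()
```

For example, calling arm_hang_dump(60) just before the inference loop would print a traceback roughly once a minute for as long as the run is stuck.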

itechmusic commented 2 months ago


I tried it and it works for me.

After I resized the image from 667 x 741 to 666 x 740, I got this result:

https://github.com/TMElyralab/MuseTalk/assets/163980830/a7ca806b-6302-4893-a007-14892039fbbd
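The resize above makes both sides even. The thread does not say why that matters; one plausible reason (an assumption here) is that H.264 output in the common yuv420p pixel format requires an even width and height. A minimal sketch of that resize with Pillow (pillow appears in the pip list later in the thread); even_size and resize_to_even are hypothetical names.

```python
from typing import Tuple


def even_size(width: int, height: int) -> Tuple[int, int]:
    """Round each dimension down to the nearest even number, e.g. 667x741 -> 666x740."""
    return width - width % 2, height - height % 2


def resize_to_even(src: str, dst: str) -> Tuple[int, int]:
    """Resize an image so both sides are even and save it; returns the new (width, height).

    Pillow is imported lazily so even_size() can be used without it installed.
    """
    from PIL import Image

    with Image.open(src) as img:
        size = even_size(*img.size)  # img.size is (width, height)
        out = img if size == img.size else img.resize(size)
        out.save(dst)
    return size
```

Usage would look like resize_to_even("data/image/face.jpeg", "data/image/face_even.jpeg"), with test.yaml then pointed at the new file.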

nitinmukesh commented 2 months ago

Thank you @itechmusic.

I still can't get it to work. I adjusted the image dimensions as you suggested.

There are no errors and no missing packages, so I don't know what to try next or where to look. There is also enough free VRAM.
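To back up "enough free VRAM" with numbers from inside the same venv, a minimal sketch that asks PyTorch what it sees; cuda_report is a hypothetical helper name. torch.cuda.mem_get_info returns (free, total) in bytes for the current device. The pip list below shows torch 2.0.1+cu118, a CUDA 11.8 build, so torch.version.cuda would be expected to report 11.8.

```python
from typing import Any, Dict


def cuda_report() -> Dict[str, Any]:
    """Summarize what PyTorch in the current interpreter can see; torch is imported lazily."""
    try:
        import torch
    except ImportError:
        return {"torch": None, "cuda_available": False}

    report: Dict[str, Any] = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None on a CPU-only wheel
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        free, total = torch.cuda.mem_get_info()  # bytes on the current device
        report["device"] = torch.cuda.get_device_name()
        report["free_gib"] = round(free / 2**30, 2)
        report["total_gib"] = round(total / 2**30, 2)
    return report


if __name__ == "__main__":
    for key, value in cuda_report().items():
        print(f"{key}: {value}")
```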

Here is what is in the results folder:

[screenshot]

The paths are also correct:

[screenshot]
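The script echoes the parsed task config as a dict (the {'task_0': {...}} line in the log). A minimal pre-flight sketch, assuming relative paths are resolved against the directory the script is launched from (C:\sd\MuseTalk here): it lists any task input that is unset or does not exist as a file. missing_inputs is a hypothetical helper, not part of MuseTalk.

```python
import os
from typing import List, Mapping


def missing_inputs(config: Mapping[str, Mapping[str, str]], root: str = ".") -> List[str]:
    """List task inputs that are unset or do not exist as files under `root`."""
    problems: List[str] = []
    for task, entry in config.items():
        for key in ("video_path", "audio_path"):
            path = entry.get(key)
            if path is None:
                problems.append(f"{task}: {key} is not set")
            elif not os.path.isfile(os.path.join(root, path)):
                problems.append(f"{task}: {key} -> {path}")
    return problems


# Same shape as the config echoed in the log
config = {"task_0": {"video_path": "data/image/face.jpeg", "audio_path": "data/audio/audio.wav"}}
print(missing_inputs(config) or "all inputs found")
```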

(venv) C:\sd\MuseTalk>pip list

Package                      Version
---------------------------- ------------
absl-py                      2.1.0
accelerate                   0.28.0
addict                       2.4.0
aiofiles                     23.2.1
aliyun-python-sdk-core       2.15.1
aliyun-python-sdk-kms        2.16.2
altair                       5.3.0
annotated-types              0.6.0
antlr4-python3-runtime       4.9.3
anyio                        4.3.0
astunparse                   1.6.3
attrs                        23.2.0
beautifulsoup4               4.12.3
cachetools                   5.3.3
certifi                      2024.2.2
cffi                         1.16.0
charset-normalizer           3.3.2
chumpy                       0.70
click                        8.1.7
colorama                     0.4.6
contourpy                    1.2.1
crcmod                       1.7
cryptography                 42.0.5
cycler                       0.12.1
Cython                       3.0.10
decorator                    4.4.2
diffusers                    0.27.2
exceptiongroup               1.2.1
fastapi                      0.110.2
ffmpeg-python                0.2.0
ffmpy                        0.3.2
filelock                     3.13.4
flatbuffers                  24.3.25
fonttools                    4.51.0
fsspec                       2024.3.1
future                       1.0.0
gast                         0.4.0
gdown                        5.1.0
google-auth                  2.29.0
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
gradio                       4.27.0
gradio_client                0.15.1
grpcio                       1.62.2
h11                          0.14.0
h5py                         3.11.0
httpcore                     1.0.5
httpx                        0.27.0
huggingface-hub              0.22.2
idna                         3.7
imageio                      2.34.1
imageio-ffmpeg               0.4.9
importlib_metadata           7.1.0
importlib_resources          6.4.0
jax                          0.4.26
Jinja2                       3.1.3
jmespath                     0.10.0
json-tricks                  3.17.3
jsonschema                   4.21.1
jsonschema-specifications    2023.12.1
keras                        2.12.0
kiwisolver                   1.4.5
libclang                     18.1.1
Markdown                     3.6
markdown-it-py               3.0.0
MarkupSafe                   2.1.5
matplotlib                   3.8.4
mdurl                        0.1.2
ml-dtypes                    0.4.0
mmcv                         2.1.0
mmdet                        3.2.0
mmengine                     0.10.4
mmpose                       1.3.1
model-index                  0.1.11
moviepy                      1.0.3
mpmath                       1.3.0
munkres                      1.1.4
networkx                     3.3
numpy                        1.23.5
oauthlib                     3.2.2
omegaconf                    2.3.0
opencv-python                4.9.0.80
opendatalab                  0.0.10
openmim                      0.3.9
openxlab                     0.0.38
opt-einsum                   3.3.0
ordered-set                  4.1.0
orjson                       3.10.1
oss2                         2.17.0
packaging                    24.0
pandas                       2.2.2
pillow                       10.3.0
pip                          22.2.1
platformdirs                 4.2.0
proglog                      0.1.10
protobuf                     4.25.3
psutil                       5.9.8
pyasn1                       0.6.0
pyasn1_modules               0.4.0
pycocotools                  2.0.7
pycparser                    2.22
pycryptodome                 3.20.0
pydantic                     2.7.0
pydantic_core                2.18.1
pydub                        0.25.1
Pygments                     2.17.2
pyparsing                    3.1.2
PySocks                      1.7.1
python-dateutil              2.9.0.post0
python-multipart             0.0.9
pytz                         2023.4
pywin32                      306
PyYAML                       6.0.1
referencing                  0.34.0
regex                        2024.4.16
requests                     2.28.2
requests-oauthlib            2.0.0
rich                         13.4.2
rpds-py                      0.18.0
rsa                          4.9
ruff                         0.4.1
safetensors                  0.4.3
scipy                        1.13.0
semantic-version             2.10.0
setuptools                   60.2.0
shapely                      2.0.4
shellingham                  1.5.4
six                          1.16.0
sniffio                      1.3.1
soundfile                    0.12.1
soupsieve                    2.5
spaces                       0.26.1
starlette                    0.37.2
static-ffmpeg                2.5
sympy                        1.12
tabulate                     0.9.0
tensorboard                  2.12.0
tensorboard-data-server      0.7.2
tensorboard-plugin-wit       1.8.1
tensorflow                   2.12.0
tensorflow-estimator         2.12.0
tensorflow-intel             2.12.0
tensorflow-io-gcs-filesystem 0.31.0
termcolor                    2.4.0
terminaltables               3.1.10
tokenizers                   0.15.2
tomli                        2.0.1
tomlkit                      0.12.0
toolz                        0.12.1
torch                        2.0.1+cu118
torchaudio                   2.0.2+cu118
torchvision                  0.15.2+cu118
tqdm                         4.65.2
transformers                 4.39.2
typer                        0.12.3
typing_extensions            4.11.0
tzdata                       2024.1
urllib3                      1.26.18
uvicorn                      0.29.0
websockets                   11.0.3
Werkzeug                     3.0.2
wheel                        0.43.0
wrapt                        1.14.1
xtcocotools                  1.14.3
yapf                         0.40.2
zipp                         3.18.1
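Rather than comparing the full pip list between a working setup and this venv, a minimal standard-library sketch that prints only a few versions. The selection in PACKAGES is an illustrative assumption taken from the list above, not MuseTalk's requirements file; installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Illustrative subset of the pip list above, not an official requirements list.
PACKAGES = ("torch", "torchvision", "torchaudio", "diffusers", "transformers",
            "mmcv", "mmdet", "mmpose", "mmengine", "opencv-python")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:15} {ver or 'NOT INSTALLED'}")
```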
nitinmukesh commented 2 months ago

It gets stuck at the following line: recon = vae.decode_latents(pred_latents)

import numpy as np
import torch
from tqdm import tqdm


@torch.no_grad()
def main(args):
    # ... earlier code not shown in this excerpt; whisper_chunks, input_latent_list_cycle,
    # datagen, unet, pe, vae and timesteps all come from it.
    print("start inference")
    video_num = len(whisper_chunks)
    print("video_num ", video_num)
    batch_size = args.batch_size
    print("batch_size ", batch_size)
    gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
    print("gen ", gen)
    res_frame_list = []
    num_batches = int(np.ceil(float(video_num) / batch_size))
    for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=num_batches)):
        print("whisper_batch ", whisper_batch)
        print("latent_batch ", latent_batch)
        print("i ", i)

        # Audio features for the batch -> one tensor on the UNet's device, shape (B, 5*N, 384)
        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
        print("tensor_list ", tensor_list)
        audio_feature_batch = torch.stack(tensor_list).to(unet.device)
        print("audio_feature_batch ", audio_feature_batch)
        audio_feature_batch = pe(audio_feature_batch)
        print("audio_feature_batch ", audio_feature_batch)

        # Predict the latents for this batch, conditioned on the audio features
        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        print("pred_latents ", pred_latents)
        # Reported hang: execution never gets past this VAE decode call
        recon = vae.decode_latents(pred_latents)
        print("recon ", recon)
        for res_frame in recon:
            print("res_frame ", res_frame)
            res_frame_list.append(res_frame)
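The 0/3 on the progress bar is the loop's tqdm total, int(np.ceil(float(video_num)/batch_size)). With video_num 24 and batch_size 8 (as printed in the log below), that is 3 batches, so 0/3 means not even the first batch finished. With a single input image, every latent in a batch is the same frame, which is consistent with the identical blocks in the printed latent_batch. The sketch below is a hypothetical simplification of datagen, assuming it pairs audio chunk i with frame i modulo the frame count; the real implementation is not shown in this thread.

```python
import math
from typing import Iterator, List, Sequence, Tuple, TypeVar

A = TypeVar("A")
F = TypeVar("F")


def batched_cycle(audio_chunks: Sequence[A], frames: Sequence[F],
                  batch_size: int) -> Iterator[Tuple[List[A], List[F]]]:
    """Pair audio chunk i with frame i % len(frames) and yield batches of batch_size.

    The last batch may be smaller than batch_size.
    """
    audio_batch: List[A] = []
    frame_batch: List[F] = []
    for i, chunk in enumerate(audio_chunks):
        audio_batch.append(chunk)
        frame_batch.append(frames[i % len(frames)])  # a single image repeats on every step
        if len(audio_batch) == batch_size:
            yield audio_batch, frame_batch
            audio_batch, frame_batch = [], []
    if audio_batch:
        yield audio_batch, frame_batch


video_num, batch_size = 24, 8
print(math.ceil(video_num / batch_size))  # 3, the tqdm total
batches = list(batched_cycle(range(video_num), ["face.jpeg"], batch_size))
print(len(batches))  # 3
print(set(batches[0][1]))  # {'face.jpeg'}: every frame in the batch is the same image
```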
(venv) C:\sd\MuseTalk>python -m scripts.inference --inference_config configs/inference/test.yaml
add ffmpeg to path
Loads checkpoint by local backend from path: ./models/dwpose/dw-ll_ucoco_384.pth
cuda start
{'task_0': {'video_path': 'data/image/face.jpeg', 'audio_path': 'data/audio/audio.wav'}}
video in 25 FPS, audio idx in 50FPS
extracting landmarks...time consuming
reading images...
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 95.58it/s]
get key_landmark and face bounding boxes with the default value
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.89s/it]
********************************************bbox_shift parameter adjustment**********************************************************
Total frame:「1」 Manually adjust range : [ -25~23 ] , the current value: 0
*************************************************************************************************************************************
start inference
video_num  24
batch_size  8
gen  <generator object datagen at 0x000001A7FE10FB50>
  0%|                                                                                            | 0/3 [00:00<?, ?it/s]
whisper_batch  [[[-1.6992e-01 -2.8849e-04 -1.4514e-01 ...  9.9609e-01  9.5850e-01
    8.3105e-01]
  [ 6.5479e-01  6.7578e-01  2.2339e-01 ...  6.2012e-01  8.7939e-01
    7.4072e-01]
  [-1.8799e-02  1.9385e-01  2.3035e-01 ...  6.1963e-01  1.0977e+00
    8.1250e-01]
  ...
  [ 8.7891e-01  6.8164e-01  2.3022e-01 ...  5.9766e-01  7.5195e-01
    3.6182e-01]
  [-3.9062e-02  6.7188e-01  2.3315e-01 ...  9.6973e-01  1.5469e+00
    1.0498e+00]
  [ 2.8564e-02  9.0479e-01  1.8616e-01 ...  2.5156e+00  2.3164e+00
    1.1719e+00]]

 [[-1.6992e-01 -2.8849e-04 -1.4514e-01 ...  9.9609e-01  9.5850e-01
    8.3105e-01]
  [ 6.5479e-01  6.7578e-01  2.2339e-01 ...  6.2012e-01  8.7939e-01
    7.4072e-01]
  [-1.8799e-02  1.9385e-01  2.3035e-01 ...  6.1963e-01  1.0977e+00
    8.1250e-01]
  ...
  [ 5.3662e-01  8.0469e-01  8.2959e-01 ...  6.2012e-01  9.1797e-01
    6.6113e-01]
  [ 5.0903e-02  6.9922e-01  6.6895e-01 ...  3.8623e-01  1.7100e+00
    7.8516e-01]
  [ 1.0413e-01  4.3945e-02  1.3086e+00 ...  1.6094e+00  2.7227e+00
    3.5864e-01]]

 [[-1.6992e-01 -2.8849e-04 -1.4514e-01 ...  9.9609e-01  9.5850e-01
    8.3105e-01]
  [ 6.5479e-01  6.7578e-01  2.2339e-01 ...  6.2012e-01  8.7939e-01
    7.4072e-01]
  [-1.8799e-02  1.9385e-01  2.3035e-01 ...  6.1963e-01  1.0977e+00
    8.1250e-01]
  ...
  [ 1.6875e+00  1.5293e+00  9.4727e-01 ...  5.1660e-01  8.5889e-01
    5.2393e-01]
  [ 1.8154e+00  1.4775e+00  1.2705e+00 ...  1.1113e+00  1.0674e+00
    9.5215e-01]
  [ 1.2295e+00  2.2734e+00  3.2852e+00 ...  2.1562e+00  2.4258e+00
    1.3252e+00]]

 ...

 [[-4.4922e-01 -5.3613e-01 -8.9355e-01 ...  8.7598e-01  8.7891e-01
    8.3105e-01]
  [ 5.0098e-01  7.6953e-01  5.7764e-01 ...  6.0791e-01  5.8838e-01
    6.2305e-01]
  [ 2.5928e-01  5.1709e-01  8.2031e-01 ...  4.8926e-01  9.0381e-01
    6.9824e-01]
  ...
  [ 8.1689e-01  1.3691e+00  1.0293e+00 ...  3.8672e-01  9.8975e-01
    8.1836e-01]
  [ 3.6963e-01  1.7598e+00  9.6191e-01 ...  6.6504e-01  1.5371e+00
    9.3359e-01]
  [ 2.8662e-01  4.6250e+00  1.6562e+00 ...  4.1284e-01  1.7520e+00
    1.6016e+00]]

 [[ 8.4326e-01  9.7363e-01  6.6113e-01 ...  8.5986e-01  8.7842e-01
    8.5547e-01]
  [ 1.3496e+00  1.3945e+00  8.2373e-01 ...  4.4141e-01  6.6602e-01
    5.6982e-01]
  [ 1.2510e+00  1.2715e+00  1.1982e+00 ...  5.4297e-01  8.2275e-01
    5.5371e-01]
  ...
  [ 8.1104e-01  1.8994e+00  6.9092e-01 ...  4.3701e-01  6.8604e-01
    8.3252e-01]
  [-3.7793e-01  2.3457e+00  9.8096e-01 ...  7.3438e-01  1.0410e+00
    9.4482e-01]
  [-1.2832e+00  4.2266e+00  1.4570e+00 ...  2.8198e-01  5.7861e-01
    1.2695e+00]]

 [[-6.7480e-01 -1.0455e-01  1.6760e-01 ...  8.3008e-01  8.7988e-01
    8.3496e-01]
  [ 1.6562e+00  1.4619e+00  3.9160e-01 ...  4.4092e-01  9.5801e-01
    7.1045e-01]
  [ 1.5234e+00  1.3545e+00  5.2588e-01 ...  3.2446e-01  1.0654e+00
    5.7080e-01]
  ...
  [ 1.0566e+00  1.6387e+00  3.7231e-01 ...  7.9590e-01  7.4268e-01
    1.0000e+00]
  [-2.1680e-01  2.4570e+00  4.3555e-01 ...  5.1758e-01  1.0869e+00
    9.8877e-01]
  [-2.0703e+00  3.5664e+00  1.5537e+00 ...  1.5454e-01  8.2764e-01
    1.5273e+00]]]
latent_batch  tensor([[[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]],

        [[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]],

        [[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]],

        ...,

        [[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]],

        [[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]],

        [[[ 0.3817, -0.9511, -0.0135,  ...,  0.3239, -0.3727, -1.2085],
          [-0.7161,  0.1602, -0.8396,  ..., -0.7383, -0.4438, -0.2077],
          [-1.0292, -0.2511, -1.4369,  ..., -0.2085, -0.2418, -1.1311],
          ...,
          [-1.2580, -1.0452, -0.9995,  ..., -0.9927, -1.1169, -1.3844],
          [-1.6654, -1.2902, -1.1895,  ..., -1.1970, -1.2997, -1.6015],
          [-1.1598, -1.9114, -1.4038,  ..., -1.7775, -1.6023, -1.4959]],

         [[-0.1305, -0.9890, -0.2282,  ..., -0.3648, -0.6344, -0.3028],
          [-1.2666, -0.5069, -0.5656,  ..., -1.1775, -0.1794, -0.7154],
          [-1.3254, -0.2052, -1.7900,  ..., -0.6103, -1.5310, -0.8676],
          ...,
          [-2.6156, -2.6934, -2.6734,  ..., -2.7262, -2.6924, -2.6582],
          [-2.6062, -2.6643, -2.6643,  ..., -2.7230, -2.7091, -2.6203],
          [-2.5358, -2.6483, -2.5463,  ..., -2.7173, -2.6786, -2.5009]],

         [[ 0.9009,  0.2777, -0.9097,  ...,  0.9195, -0.0183, -0.5688],
          [ 0.4648, -0.2058,  0.8519,  ..., -0.3059, -0.1156,  0.8472],
          [-0.6064, -0.1322, -0.6453,  ...,  0.9855,  0.5003, -0.4622],
          ...,
          [ 0.8297,  1.1355,  1.1744,  ...,  1.1504,  1.2261,  1.4360],
          [ 0.7979,  1.2265,  1.2789,  ...,  1.2295,  1.2898,  1.5244],
          [ 1.2797,  1.2329,  1.4728,  ...,  1.2937,  1.4724,  1.4008]],

         ...,

         [[-0.0796, -1.0523, -0.1950,  ..., -0.4383, -0.7383, -0.2432],
          [-1.2808, -0.4969, -0.4962,  ..., -1.1745, -0.1237, -0.7957],
          [-1.3536, -0.1554, -1.8771,  ..., -0.5685, -1.5523, -0.9214],
          ...,
          [-1.9274, -1.0349, -1.3094,  ..., -0.3054, -1.3915, -1.3486],
          [-0.7738, -0.7489, -1.1319,  ..., -1.9757, -0.8756, -2.6383],
          [-1.8903, -0.6943, -1.2970,  ..., -0.3808, -1.3761, -1.0375]],

         [[ 1.0789,  0.2526, -0.9007,  ...,  0.8455,  0.0512, -0.5581],
          [ 0.5242, -0.2375,  0.9145,  ..., -0.3762, -0.2012,  0.9488],
          [-0.7725, -0.1789, -0.9380,  ...,  0.9561,  0.3656, -0.5224],
          ...,
          [ 0.6917, -0.3562, -0.2107,  ...,  0.4791, -0.9471,  0.6350],
          [ 1.5163,  0.0446,  0.5912,  ..., -0.0412,  0.0447, -1.1374],
          [ 1.5193, -1.0238, -0.1816,  ...,  0.7875, -0.0535,  0.2286]],

         [[ 0.9593,  0.6785,  0.1850,  ...,  1.3929, -0.4940, -0.7949],
          [ 0.5678,  0.9684,  0.1273,  ..., -0.9839,  0.8837,  0.8088],
          [ 0.2073, -0.0170,  0.0418,  ...,  1.2770,  0.3857, -0.4266],
          ...,
          [-0.0539,  0.4435,  1.6115,  ...,  0.5722,  0.2167,  1.0428],
          [ 1.5132,  0.8811,  1.9802,  ...,  1.3308,  0.2180,  0.2373],
          [ 1.0616, -0.4796,  1.1155,  ...,  0.8659,  1.5309,  0.1066]]]],
       device='cuda:0')
i  0
tensor_list  [tensor([[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
          9.5850e-01,  8.3105e-01],
        [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
          8.7939e-01,  7.4072e-01],
        [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
          1.0977e+00,  8.1250e-01],
        ...,
        [ 8.7891e-01,  6.8164e-01,  2.3022e-01,  ...,  5.9766e-01,
          7.5195e-01,  3.6182e-01],
        [-3.9062e-02,  6.7188e-01,  2.3315e-01,  ...,  9.6973e-01,
          1.5469e+00,  1.0498e+00],
        [ 2.8564e-02,  9.0479e-01,  1.8616e-01,  ...,  2.5156e+00,
          2.3164e+00,  1.1719e+00]]), tensor([[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
          9.5850e-01,  8.3105e-01],
        [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
          8.7939e-01,  7.4072e-01],
        [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
          1.0977e+00,  8.1250e-01],
        ...,
        [ 5.3662e-01,  8.0469e-01,  8.2959e-01,  ...,  6.2012e-01,
          9.1797e-01,  6.6113e-01],
        [ 5.0903e-02,  6.9922e-01,  6.6895e-01,  ...,  3.8623e-01,
          1.7100e+00,  7.8516e-01],
        [ 1.0413e-01,  4.3945e-02,  1.3086e+00,  ...,  1.6094e+00,
          2.7227e+00,  3.5864e-01]]), tensor([[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
          9.5850e-01,  8.3105e-01],
        [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
          8.7939e-01,  7.4072e-01],
        [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
          1.0977e+00,  8.1250e-01],
        ...,
        [ 1.6875e+00,  1.5293e+00,  9.4727e-01,  ...,  5.1660e-01,
          8.5889e-01,  5.2393e-01],
        [ 1.8154e+00,  1.4775e+00,  1.2705e+00,  ...,  1.1113e+00,
          1.0674e+00,  9.5215e-01],
        [ 1.2295e+00,  2.2734e+00,  3.2852e+00,  ...,  2.1562e+00,
          2.4258e+00,  1.3252e+00]]), tensor([[ 0.7466,  0.9438,  0.9272,  ...,  0.8335,  0.8496,  0.9282],
        [ 1.3750,  1.0010,  0.2400,  ...,  0.5576,  0.5830,  0.6914],
        [ 0.5630,  0.4341,  0.7988,  ...,  0.6494,  0.7158,  0.3760],
        ...,
        [ 1.0615,  1.1064,  0.4785,  ...,  0.4209,  0.8628,  0.6089],
        [ 1.5928,  1.8438,  1.1025,  ...,  0.8960,  0.9668,  1.0703],
        [-1.2402,  2.9102,  3.0430,  ...,  1.6982,  2.1172,  2.1641]]), tensor([[-0.9268, -0.6216, -0.5200,  ...,  0.9253,  0.9023,  0.8301],
        [ 1.6689,  1.0840,  0.0192,  ...,  0.5884,  0.6011,  0.4355],
        [ 0.9575,  0.5566,  0.0983,  ...,  0.8398,  0.9385,  0.2595],
        ...,
        [ 0.7090,  0.7812,  0.4150,  ...,  0.3760,  0.5977,  0.7324],
        [ 0.5254,  1.0869,  0.5771,  ...,  0.7759,  1.3252,  0.6343],
        [-0.4062,  3.6562,  0.8647,  ...,  0.6416,  1.1504,  1.2510]]), tensor([[-0.4492, -0.5361, -0.8936,  ...,  0.8760,  0.8789,  0.8311],
        [ 0.5010,  0.7695,  0.5776,  ...,  0.6079,  0.5884,  0.6230],
        [ 0.2593,  0.5171,  0.8203,  ...,  0.4893,  0.9038,  0.6982],
        ...,
        [ 0.8169,  1.3691,  1.0293,  ...,  0.3867,  0.9897,  0.8184],
        [ 0.3696,  1.7598,  0.9619,  ...,  0.6650,  1.5371,  0.9336],
        [ 0.2866,  4.6250,  1.6562,  ...,  0.4128,  1.7520,  1.6016]]), tensor([[ 0.8433,  0.9736,  0.6611,  ...,  0.8599,  0.8784,  0.8555],
        [ 1.3496,  1.3945,  0.8237,  ...,  0.4414,  0.6660,  0.5698],
        [ 1.2510,  1.2715,  1.1982,  ...,  0.5430,  0.8228,  0.5537],
        ...,
        [ 0.8110,  1.8994,  0.6909,  ...,  0.4370,  0.6860,  0.8325],
        [-0.3779,  2.3457,  0.9810,  ...,  0.7344,  1.0410,  0.9448],
        [-1.2832,  4.2266,  1.4570,  ...,  0.2820,  0.5786,  1.2695]]), tensor([[-0.6748, -0.1046,  0.1676,  ...,  0.8301,  0.8799,  0.8350],
        [ 1.6562,  1.4619,  0.3916,  ...,  0.4409,  0.9580,  0.7104],
        [ 1.5234,  1.3545,  0.5259,  ...,  0.3245,  1.0654,  0.5708],
        ...,
        [ 1.0566,  1.6387,  0.3723,  ...,  0.7959,  0.7427,  1.0000],
        [-0.2168,  2.4570,  0.4355,  ...,  0.5176,  1.0869,  0.9888],
        [-2.0703,  3.5664,  1.5537,  ...,  0.1545,  0.8276,  1.5273]])]
audio_feature_batch  tensor([[[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
           9.5850e-01,  8.3105e-01],
         [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
           8.7939e-01,  7.4072e-01],
         [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
           1.0977e+00,  8.1250e-01],
         ...,
         [ 8.7891e-01,  6.8164e-01,  2.3022e-01,  ...,  5.9766e-01,
           7.5195e-01,  3.6182e-01],
         [-3.9062e-02,  6.7188e-01,  2.3315e-01,  ...,  9.6973e-01,
           1.5469e+00,  1.0498e+00],
         [ 2.8564e-02,  9.0479e-01,  1.8616e-01,  ...,  2.5156e+00,
           2.3164e+00,  1.1719e+00]],

        [[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
           9.5850e-01,  8.3105e-01],
         [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
           8.7939e-01,  7.4072e-01],
         [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
           1.0977e+00,  8.1250e-01],
         ...,
         [ 5.3662e-01,  8.0469e-01,  8.2959e-01,  ...,  6.2012e-01,
           9.1797e-01,  6.6113e-01],
         [ 5.0903e-02,  6.9922e-01,  6.6895e-01,  ...,  3.8623e-01,
           1.7100e+00,  7.8516e-01],
         [ 1.0413e-01,  4.3945e-02,  1.3086e+00,  ...,  1.6094e+00,
           2.7227e+00,  3.5864e-01]],

        [[-1.6992e-01, -2.8849e-04, -1.4514e-01,  ...,  9.9609e-01,
           9.5850e-01,  8.3105e-01],
         [ 6.5479e-01,  6.7578e-01,  2.2339e-01,  ...,  6.2012e-01,
           8.7939e-01,  7.4072e-01],
         [-1.8799e-02,  1.9385e-01,  2.3035e-01,  ...,  6.1963e-01,
           1.0977e+00,  8.1250e-01],
         ...,
         [ 1.6875e+00,  1.5293e+00,  9.4727e-01,  ...,  5.1660e-01,
           8.5889e-01,  5.2393e-01],
         [ 1.8154e+00,  1.4775e+00,  1.2705e+00,  ...,  1.1113e+00,
           1.0674e+00,  9.5215e-01],
         [ 1.2295e+00,  2.2734e+00,  3.2852e+00,  ...,  2.1562e+00,
           2.4258e+00,  1.3252e+00]],

        ...,

        [[-4.4922e-01, -5.3613e-01, -8.9355e-01,  ...,  8.7598e-01,
           8.7891e-01,  8.3105e-01],
         [ 5.0098e-01,  7.6953e-01,  5.7764e-01,  ...,  6.0791e-01,
           5.8838e-01,  6.2305e-01],
         [ 2.5928e-01,  5.1709e-01,  8.2031e-01,  ...,  4.8926e-01,
           9.0381e-01,  6.9824e-01],
         ...,
         [ 8.1689e-01,  1.3691e+00,  1.0293e+00,  ...,  3.8672e-01,
           9.8975e-01,  8.1836e-01],
         [ 3.6963e-01,  1.7598e+00,  9.6191e-01,  ...,  6.6504e-01,
           1.5371e+00,  9.3359e-01],
         [ 2.8662e-01,  4.6250e+00,  1.6562e+00,  ...,  4.1284e-01,
           1.7520e+00,  1.6016e+00]],

        [[ 8.4326e-01,  9.7363e-01,  6.6113e-01,  ...,  8.5986e-01,
           8.7842e-01,  8.5547e-01],
         [ 1.3496e+00,  1.3945e+00,  8.2373e-01,  ...,  4.4141e-01,
           6.6602e-01,  5.6982e-01],
         [ 1.2510e+00,  1.2715e+00,  1.1982e+00,  ...,  5.4297e-01,
           8.2275e-01,  5.5371e-01],
         ...,
         [ 8.1104e-01,  1.8994e+00,  6.9092e-01,  ...,  4.3701e-01,
           6.8604e-01,  8.3252e-01],
         [-3.7793e-01,  2.3457e+00,  9.8096e-01,  ...,  7.3438e-01,
           1.0410e+00,  9.4482e-01],
         [-1.2832e+00,  4.2266e+00,  1.4570e+00,  ...,  2.8198e-01,
           5.7861e-01,  1.2695e+00]],

        [[-6.7480e-01, -1.0455e-01,  1.6760e-01,  ...,  8.3008e-01,
           8.7988e-01,  8.3496e-01],
         [ 1.6562e+00,  1.4619e+00,  3.9160e-01,  ...,  4.4092e-01,
           9.5801e-01,  7.1045e-01],
         [ 1.5234e+00,  1.3545e+00,  5.2588e-01,  ...,  3.2446e-01,
           1.0654e+00,  5.7080e-01],
         ...,
         [ 1.0566e+00,  1.6387e+00,  3.7231e-01,  ...,  7.9590e-01,
           7.4268e-01,  1.0000e+00],
         [-2.1680e-01,  2.4570e+00,  4.3555e-01,  ...,  5.1758e-01,
           1.0869e+00,  9.8877e-01],
         [-2.0703e+00,  3.5664e+00,  1.5537e+00,  ...,  1.5454e-01,
           8.2764e-01,  1.5273e+00]]], device='cuda:0')
audio_feature_batch  tensor([[[-0.1699,  0.9997, -0.1451,  ...,  1.9961,  0.9585,  1.8311],
         [ 1.4963,  1.2161,  1.0386,  ...,  1.6201,  0.8795,  1.7407],
         [ 0.8905, -0.2223,  1.1746,  ...,  1.6196,  1.0979,  1.8125],
         ...,
         [ 1.0025, -0.3107,  0.9589,  ...,  1.5976,  0.7569,  1.3618],
         [-0.8073,  0.0317,  1.2135,  ...,  1.9697,  1.5519,  2.0498],
         [-0.9252,  1.2054,  0.5930,  ...,  3.5156,  2.3215,  2.1719]],

        [[-0.1699,  0.9997, -0.1451,  ...,  1.9961,  0.9585,  1.8311],
         [ 1.4963,  1.2161,  1.0386,  ...,  1.6201,  0.8795,  1.7407],
         [ 0.8905, -0.2223,  1.1746,  ...,  1.6196,  1.0979,  1.8125],
         ...,
         [ 0.6602, -0.1876,  1.5582,  ...,  1.6201,  0.9229,  1.6611],
         [-0.7174,  0.0591,  1.6493,  ...,  1.3862,  1.7150,  1.7851],
         [-0.8496,  0.3445,  1.7154,  ...,  2.6094,  2.7278,  1.3586]],

        [[-0.1699,  0.9997, -0.1451,  ...,  1.9961,  0.9585,  1.8311],
         [ 1.4963,  1.2161,  1.0386,  ...,  1.6201,  0.8795,  1.7407],
         [ 0.8905, -0.2223,  1.1746,  ...,  1.6196,  1.0979,  1.8125],
         ...,
         [ 1.8111,  0.5370,  1.6759,  ...,  1.5166,  0.8638,  1.5239],
         [ 1.0472,  0.8374,  2.2508,  ...,  2.1113,  1.0724,  1.9521],
         [ 0.2757,  2.5740,  3.6920,  ...,  3.1562,  2.4309,  2.3252]],

        ...,

        [[-0.4492,  0.4639, -0.8936,  ...,  1.8760,  0.8789,  1.8311],
         [ 1.3424,  1.3098,  1.3929,  ...,  1.6079,  0.5885,  1.6230],
         [ 1.1686,  0.1009,  1.7645,  ...,  1.4893,  0.9040,  1.6982],
         ...,
         [ 0.9405,  0.3768,  1.7579,  ...,  1.3867,  0.9947,  1.8183],
         [-0.3986,  1.1196,  1.9422,  ...,  1.6650,  1.5421,  1.9336],
         [-0.6671,  4.9256,  2.0631,  ...,  1.4128,  1.7571,  2.6015]],

        [[ 0.8433,  1.9736,  0.6611,  ...,  1.8599,  0.8784,  1.8555],
         [ 2.1911,  1.9348,  1.6390,  ...,  1.4414,  0.6661,  1.5698],
         [ 2.1603,  0.8553,  2.1425,  ...,  1.5430,  0.8230,  1.5537],
         ...,
         [ 0.9346,  0.9071,  1.4195,  ...,  1.4370,  0.6910,  1.8325],
         [-1.1462,  1.7056,  1.9613,  ...,  1.7344,  1.0461,  1.9448],
         [-2.2370,  4.5272,  1.8638,  ...,  1.2820,  0.5838,  2.2695]],

        [[-0.6748,  0.8954,  0.1676,  ...,  1.8301,  0.8799,  1.8350],
         [ 2.4977,  2.0022,  1.2069,  ...,  1.4409,  0.9581,  1.7104],
         [ 2.4327,  0.9383,  1.4701,  ...,  1.3245,  1.0656,  1.5708],
         ...,
         [ 1.1802,  0.6463,  1.1009,  ...,  1.7959,  0.7476,  2.0000],
         [-0.9851,  1.8169,  1.4159,  ...,  1.5176,  1.0919,  1.9888],
         [-3.0241,  3.8670,  1.9605,  ...,  1.1545,  0.8328,  2.5273]]],
       device='cuda:0')
pred_latents  tensor([[[[ 3.7412e-01, -1.0128e+00, -1.1039e-01,  ...,  2.3265e-01,
           -4.4953e-01, -1.4030e+00],
          [-7.9819e-01,  1.2140e-01, -9.3041e-01,  ..., -8.7399e-01,
           -5.0545e-01, -2.6268e-01],
          [-1.1118e+00, -3.5418e-01, -1.5300e+00,  ..., -3.1585e-01,
           -3.4360e-01, -1.3349e+00],
          ...,
          [-2.3894e-01, -1.0022e+00, -3.6404e-01,  ..., -3.6769e-01,
           -1.0657e+00, -1.5863e+00],
          [ 8.5260e-01, -4.5738e-01, -3.6743e-01,  ..., -9.2667e-01,
           -9.7199e-01, -2.0430e+00],
          [-9.4874e-02, -6.5196e-01, -4.9785e-01,  ..., -1.2289e+00,
           -1.5544e+00, -1.2522e+00]],

         [[-6.3309e-02, -1.0330e+00, -2.1044e-01,  ..., -3.3205e-01,
           -6.8075e-01, -3.3506e-01],
          [-1.2381e+00, -4.9634e-01, -4.8256e-01,  ..., -1.1648e+00,
           -1.3385e-01, -7.4852e-01],
          [-1.3310e+00, -7.5743e-02, -1.8286e+00,  ..., -5.8759e-01,
           -1.5620e+00, -9.1132e-01],
          ...,
          [-1.1714e+00, -6.8291e-01, -1.3052e+00,  ..., -7.5767e-01,
           -1.3653e+00, -1.0262e+00],
          [-2.8148e-01, -5.8824e-01, -1.2146e+00,  ..., -1.9754e+00,
           -1.2642e+00, -2.3117e+00],
          [-1.0777e+00, -2.5658e-01, -1.3899e+00,  ..., -6.7512e-01,
           -1.7101e+00, -1.4595e+00]],

         [[ 9.2931e-01,  2.6834e-01, -9.5206e-01,  ...,  8.7856e-01,
           -3.0979e-02, -6.1548e-01],
          [ 4.9769e-01, -2.8966e-01,  8.7906e-01,  ..., -3.5364e-01,
           -1.1364e-01,  9.2771e-01],
          [-7.5626e-01, -1.5688e-01, -7.9081e-01,  ...,  1.0198e+00,
            4.1046e-01, -5.3807e-01],
          ...,
          [-1.6315e-01, -1.9228e-02,  5.3250e-01,  ...,  3.4300e-01,
           -4.2475e-01,  2.9064e-01],
          [ 1.0992e+00, -7.6490e-01,  5.8788e-01,  ...,  2.2654e-02,
            2.1929e-01, -1.5611e-02],
          [ 9.2666e-01, -1.0825e+00,  8.1190e-01,  ...,  8.1316e-01,
            3.2251e-01,  3.4879e-01]],

         [[ 8.8533e-01,  6.7366e-01,  2.0601e-01,  ...,  1.4157e+00,
           -6.1014e-01, -8.9993e-01],
          [ 5.7487e-01,  9.3581e-01,  1.0111e-01,  ..., -9.6689e-01,
            8.8405e-01,  8.1334e-01],
          [ 2.2152e-01,  3.7591e-02,  1.8818e-01,  ...,  1.3375e+00,
            4.3562e-01, -4.3471e-01],
          ...,
          [-5.2369e-01,  5.8894e-01,  1.6815e+00,  ...,  4.6181e-01,
            7.5429e-01,  4.3781e-01],
          [ 6.8894e-01,  3.2236e-01,  1.6446e+00,  ...,  9.2284e-01,
            5.8956e-01,  7.9967e-01],
          [-1.2057e-01,  3.7213e-01,  1.1123e+00,  ...,  9.2104e-01,
            1.1558e+00,  6.9341e-01]]],

        [[[ 3.7491e-01, -1.0154e+00, -1.0798e-01,  ...,  2.3196e-01,
           -4.4994e-01, -1.4045e+00],
          [-8.0025e-01,  1.2184e-01, -9.3312e-01,  ..., -8.7618e-01,
           -5.0600e-01, -2.6268e-01],
          [-1.1125e+00, -3.5330e-01, -1.5316e+00,  ..., -3.1378e-01,
           -3.4207e-01, -1.3355e+00],
          ...,
          [-2.6295e-01, -9.4617e-01, -3.0006e-01,  ..., -3.7268e-01,
           -1.0672e+00, -1.5631e+00],
          [ 5.0202e-01, -7.8364e-01, -3.9186e-01,  ..., -9.5466e-01,
           -9.3741e-01, -2.0593e+00],
          [-1.6009e-01, -5.4739e-01, -4.3113e-01,  ..., -1.2075e+00,
           -1.5381e+00, -1.1957e+00]],

         [[-6.5459e-02, -1.0361e+00, -2.1266e-01,  ..., -3.3441e-01,
           -6.8478e-01, -3.3910e-01],
          [-1.2380e+00, -4.9685e-01, -4.8272e-01,  ..., -1.1662e+00,
           -1.3191e-01, -7.4877e-01],
          [-1.3316e+00, -7.6665e-02, -1.8300e+00,  ..., -5.8791e-01,
           -1.5641e+00, -9.1424e-01],
          ...,
          [-9.7245e-01, -6.2545e-01, -1.3281e+00,  ..., -7.4139e-01,
           -1.3339e+00, -1.0339e+00],
          [-2.3538e-01, -5.1351e-01, -1.2089e+00,  ..., -1.9968e+00,
           -1.2948e+00, -2.3031e+00],
          [-7.0884e-01, -2.3422e-01, -1.4031e+00,  ..., -6.5526e-01,
           -1.7602e+00, -1.4610e+00]],

         [[ 9.3176e-01,  2.6875e-01, -9.5082e-01,  ...,  8.8150e-01,
           -3.3057e-02, -6.1609e-01],
          [ 4.9788e-01, -2.9205e-01,  8.7881e-01,  ..., -3.5635e-01,
           -1.1823e-01,  9.2960e-01],
          [-7.5609e-01, -1.5499e-01, -7.9178e-01,  ...,  1.0206e+00,
            4.1383e-01, -5.4152e-01],
          ...,
          [-1.0994e-01,  1.6703e-01,  6.2203e-01,  ...,  3.4539e-01,
           -4.3519e-01,  3.0041e-01],
          [ 4.7367e-01, -9.9871e-01,  6.6283e-01,  ...,  7.8338e-03,
            1.9758e-01, -3.6090e-02],
          [ 8.4022e-01, -9.3114e-01,  9.0589e-01,  ...,  7.9662e-01,
            3.2852e-01,  3.0640e-01]],

         [[ 8.8775e-01,  6.7407e-01,  2.0714e-01,  ...,  1.4190e+00,
           -6.1197e-01, -9.0137e-01],
          [ 5.7439e-01,  9.3477e-01,  1.0082e-01,  ..., -9.7111e-01,
            8.8347e-01,  8.1671e-01],
          [ 2.2329e-01,  3.8728e-02,  1.8852e-01,  ...,  1.3430e+00,
            4.3770e-01, -4.3626e-01],
          ...,
          [-4.9566e-01,  7.1735e-01,  1.6934e+00,  ...,  4.5792e-01,
            7.6656e-01,  4.5826e-01],
          [ 8.2424e-02,  1.0736e-01,  1.6159e+00,  ...,  9.0682e-01,
            5.8456e-01,  8.0925e-01],
          [-1.5727e-01,  6.8191e-01,  1.0613e+00,  ...,  9.5976e-01,
            1.1494e+00,  7.2555e-01]]],

        [[[ 3.7690e-01, -1.0157e+00, -1.0709e-01,  ...,  2.3208e-01,
           -4.5129e-01, -1.4088e+00],
          [-8.0125e-01,  1.2066e-01, -9.3613e-01,  ..., -8.7329e-01,
           -5.0152e-01, -2.6336e-01],
          [-1.1140e+00, -3.5317e-01, -1.5358e+00,  ..., -3.1531e-01,
           -3.4874e-01, -1.3443e+00],
          ...,
          [-3.5891e-01, -8.1097e-01, -2.7077e-01,  ..., -3.4793e-01,
           -1.1200e+00, -1.4852e+00],
          [-6.5348e-02, -1.0241e+00, -4.0172e-01,  ..., -9.1828e-01,
           -9.1306e-01, -2.1127e+00],
          [-1.5626e-01, -3.0667e-01, -3.2174e-01,  ..., -1.1558e+00,
           -1.5781e+00, -1.1956e+00]],

         [[-7.2290e-02, -1.0394e+00, -2.1639e-01,  ..., -3.4024e-01,
           -6.9012e-01, -3.4473e-01],
          [-1.2397e+00, -5.0040e-01, -4.8048e-01,  ..., -1.1644e+00,
           -1.2973e-01, -7.5380e-01],
          [-1.3326e+00, -7.6061e-02, -1.8305e+00,  ..., -5.9590e-01,
           -1.5690e+00, -9.1238e-01],
          ...,
          [-6.7242e-01, -6.3245e-01, -1.3340e+00,  ..., -7.3324e-01,
           -1.2962e+00, -1.0694e+00],
          [-1.7754e-01, -4.8361e-01, -1.1758e+00,  ..., -1.9907e+00,
           -1.2494e+00, -2.3104e+00],
          [-2.6232e-01, -2.4234e-01, -1.4365e+00,  ..., -6.6283e-01,
           -1.7050e+00, -1.4727e+00]],

         [[ 9.3383e-01,  2.7177e-01, -9.5359e-01,  ...,  8.8360e-01,
           -3.5792e-02, -6.1504e-01],
          [ 4.9638e-01, -2.9618e-01,  8.7723e-01,  ..., -3.5796e-01,
           -1.1359e-01,  9.2978e-01],
          [-7.5771e-01, -1.5790e-01, -7.9293e-01,  ...,  1.0201e+00,
            4.0288e-01, -5.4458e-01],
          ...,
          [-6.2563e-02,  2.6690e-01,  6.8444e-01,  ...,  2.6761e-01,
           -3.6887e-01,  3.4091e-01],
          [-3.6567e-01, -1.0978e+00,  7.4587e-01,  ...,  3.5716e-02,
            1.0048e-01,  6.4403e-02],
          [ 7.9213e-01, -6.3202e-01,  9.6950e-01,  ...,  7.8871e-01,
            3.7619e-01,  2.0038e-01]],

         [[ 8.8721e-01,  6.7413e-01,  2.0735e-01,  ...,  1.4188e+00,
           -6.1319e-01, -9.0314e-01],
          [ 5.7138e-01,  9.3105e-01,  1.0111e-01,  ..., -9.6662e-01,
            8.9391e-01,  8.1402e-01],
          [ 2.2155e-01,  3.8815e-02,  1.8692e-01,  ...,  1.3422e+00,
            4.2641e-01, -4.3628e-01],
          ...,
          [-4.3834e-01,  7.8366e-01,  1.6903e+00,  ...,  4.7545e-01,
            7.6807e-01,  5.3027e-01],
          [-6.3775e-01,  2.9824e-02,  1.5722e+00,  ...,  9.0334e-01,
            6.0662e-01,  8.2394e-01],
          [-2.4989e-05,  1.0377e+00,  1.0059e+00,  ...,  9.5971e-01,
            1.1585e+00,  7.2809e-01]]],

        ...,

        [[[ 3.7881e-01, -1.0129e+00, -1.0404e-01,  ...,  2.3272e-01,
           -4.4801e-01, -1.4095e+00],
          [-8.0255e-01,  1.2074e-01, -9.3848e-01,  ..., -8.7483e-01,
           -5.0093e-01, -2.6335e-01],
          [-1.1134e+00, -3.5584e-01, -1.5385e+00,  ..., -3.1884e-01,
           -3.4743e-01, -1.3455e+00],
          ...,
          [-2.2740e-01, -7.9245e-01, -2.1258e-01,  ..., -4.2342e-01,
           -1.1332e+00, -1.5354e+00],
          [-5.1047e-02, -9.7859e-01, -3.4812e-01,  ..., -7.7539e-01,
           -9.8685e-01, -1.9749e+00],
          [-2.0676e-01, -2.6194e-01, -2.8099e-01,  ..., -1.2912e+00,
           -1.4854e+00, -1.4087e+00]],

         [[-7.3873e-02, -1.0426e+00, -2.1859e-01,  ..., -3.4034e-01,
           -6.9496e-01, -3.5022e-01],
          [-1.2413e+00, -5.0393e-01, -4.7642e-01,  ..., -1.1679e+00,
           -1.2823e-01, -7.5274e-01],
          [-1.3347e+00, -8.0531e-02, -1.8312e+00,  ..., -5.9717e-01,
           -1.5711e+00, -9.1540e-01],
          ...,
          [-6.7847e-01, -5.7596e-01, -1.4472e+00,  ..., -7.4821e-01,
           -1.3181e+00, -1.0511e+00],
          [-7.0417e-02, -4.8024e-01, -1.2953e+00,  ..., -2.0349e+00,
           -1.2606e+00, -2.2722e+00],
          [-1.4494e-01, -2.8143e-01, -1.5050e+00,  ..., -6.6654e-01,
           -1.6267e+00, -1.2808e+00]],

         [[ 9.3752e-01,  2.7484e-01, -9.5447e-01,  ...,  8.8638e-01,
           -3.1258e-02, -6.1985e-01],
          [ 4.9446e-01, -3.0070e-01,  8.8039e-01,  ..., -3.5695e-01,
           -1.1930e-01,  9.3498e-01],
          [-7.6046e-01, -1.6154e-01, -7.9942e-01,  ...,  1.0224e+00,
            4.0754e-01, -5.4409e-01],
          ...,
          [ 1.4601e-01,  3.1934e-01,  8.0593e-01,  ...,  3.8228e-01,
           -4.7611e-01,  3.4494e-01],
          [-3.0442e-01, -1.0315e+00,  8.1880e-01,  ...,  1.1862e-02,
            2.9394e-01, -1.3929e-01],
          [ 6.0173e-01, -5.3001e-01,  9.5803e-01,  ...,  7.9090e-01,
            2.2840e-01,  1.9992e-01]],

         [[ 8.9065e-01,  6.7350e-01,  2.0627e-01,  ...,  1.4216e+00,
           -6.1006e-01, -9.0994e-01],
          [ 5.6891e-01,  9.2989e-01,  1.0139e-01,  ..., -9.7078e-01,
            8.9229e-01,  8.1672e-01],
          [ 2.2205e-01,  3.3199e-02,  1.8477e-01,  ...,  1.3449e+00,
            4.3009e-01, -4.3757e-01],
          ...,
          [-3.1431e-01,  8.4123e-01,  1.6555e+00,  ...,  4.7952e-01,
            6.6309e-01,  6.2558e-01],
          [-6.5987e-01,  1.1704e-01,  1.5263e+00,  ...,  9.9806e-01,
            6.0882e-01,  8.2521e-01],
          [-9.3685e-02,  1.1655e+00,  8.9978e-01,  ...,  9.1807e-01,
            1.2599e+00,  5.3026e-01]]],

        [[[ 3.7887e-01, -1.0148e+00, -1.0429e-01,  ...,  2.3328e-01,
           -4.4793e-01, -1.4099e+00],
          [-8.0380e-01,  1.1888e-01, -9.4046e-01,  ..., -8.7514e-01,
           -5.0451e-01, -2.6578e-01],
          [-1.1163e+00, -3.5802e-01, -1.5427e+00,  ..., -3.2207e-01,
           -3.5010e-01, -1.3481e+00],
          ...,
          [-2.5032e-01, -7.4140e-01, -2.1611e-01,  ..., -4.5819e-01,
           -1.0988e+00, -1.5257e+00],
          [-1.7300e-01, -9.7451e-01, -3.3704e-01,  ..., -7.2036e-01,
           -9.4890e-01, -1.9047e+00],
          [-2.1151e-01, -2.2349e-01, -3.0061e-01,  ..., -1.3477e+00,
           -1.4600e+00, -1.4397e+00]],

         [[-7.4206e-02, -1.0432e+00, -2.1885e-01,  ..., -3.4395e-01,
           -6.9446e-01, -3.4897e-01],
          [-1.2443e+00, -5.0543e-01, -4.7744e-01,  ..., -1.1704e+00,
           -1.2685e-01, -7.5525e-01],
          [-1.3377e+00, -7.9505e-02, -1.8341e+00,  ..., -5.9765e-01,
           -1.5734e+00, -9.1514e-01],
          ...,
          [-6.0544e-01, -5.7026e-01, -1.5077e+00,  ..., -7.5749e-01,
           -1.3398e+00, -1.0650e+00],
          [ 3.6166e-02, -4.7411e-01, -1.3510e+00,  ..., -2.0254e+00,
           -1.2908e+00, -2.2805e+00],
          [ 2.6105e-02, -3.3543e-01, -1.5432e+00,  ..., -6.9454e-01,
           -1.6636e+00, -1.2343e+00]],

         [[ 9.3650e-01,  2.7316e-01, -9.5852e-01,  ...,  8.8653e-01,
           -3.1847e-02, -6.2069e-01],
          [ 4.9374e-01, -3.0041e-01,  8.7971e-01,  ..., -3.5915e-01,
           -1.2014e-01,  9.3505e-01],
          [-7.6167e-01, -1.6516e-01, -8.0321e-01,  ...,  1.0204e+00,
            4.0420e-01, -5.4611e-01],
          ...,
          [ 2.7208e-01,  3.6507e-01,  8.3463e-01,  ...,  4.0893e-01,
           -4.8551e-01,  3.2502e-01],
          [-5.4528e-01, -9.3188e-01,  8.7095e-01,  ...,  1.0679e-02,
            4.1194e-01, -1.0300e-01],
          [ 4.9015e-01, -4.7795e-01,  9.2759e-01,  ...,  7.8037e-01,
            1.8870e-01,  1.8663e-01]],

         [[ 8.9137e-01,  6.7537e-01,  2.0731e-01,  ...,  1.4242e+00,
           -6.0848e-01, -9.0874e-01],
          [ 5.7204e-01,  9.3216e-01,  1.0244e-01,  ..., -9.7021e-01,
            8.9468e-01,  8.1622e-01],
          [ 2.2283e-01,  3.3323e-02,  1.8434e-01,  ...,  1.3443e+00,
            4.2978e-01, -4.3566e-01],
          ...,
          [-2.2053e-01,  8.7080e-01,  1.6212e+00,  ...,  4.8182e-01,
            6.7962e-01,  6.2657e-01],
          [-7.8187e-01,  1.6520e-01,  1.5021e+00,  ...,  1.0492e+00,
            6.9901e-01,  8.6582e-01],
          [-1.1562e-01,  1.2524e+00,  7.8486e-01,  ...,  8.9864e-01,
            1.2786e+00,  4.5031e-01]]],

        [[[ 3.8042e-01, -1.0157e+00, -1.0492e-01,  ...,  2.3506e-01,
           -4.4954e-01, -1.4108e+00],
          [-8.0537e-01,  1.1882e-01, -9.4281e-01,  ..., -8.7574e-01,
           -5.0227e-01, -2.6616e-01],
          [-1.1198e+00, -3.5971e-01, -1.5452e+00,  ..., -3.2310e-01,
           -3.4905e-01, -1.3497e+00],
          ...,
          [-3.9963e-02, -7.2257e-01, -2.5422e-01,  ..., -4.7732e-01,
           -1.0690e+00, -1.5482e+00],
          [ 2.6502e-02, -9.1952e-01, -3.4455e-01,  ..., -7.1310e-01,
           -9.3568e-01, -1.8780e+00],
          [-2.8392e-01, -4.1943e-01, -3.7324e-01,  ..., -1.3746e+00,
           -1.4556e+00, -1.4259e+00]],

         [[-7.4021e-02, -1.0440e+00, -2.1981e-01,  ..., -3.4815e-01,
           -6.9472e-01, -3.4820e-01],
          [-1.2471e+00, -5.0837e-01, -4.7830e-01,  ..., -1.1688e+00,
           -1.2859e-01, -7.5787e-01],
          [-1.3406e+00, -8.1378e-02, -1.8372e+00,  ..., -5.9666e-01,
           -1.5736e+00, -9.1573e-01],
          ...,
          [-7.2981e-01, -5.7384e-01, -1.5248e+00,  ..., -7.7529e-01,
           -1.3764e+00, -1.0389e+00],
          [-9.2240e-02, -4.9245e-01, -1.3387e+00,  ..., -2.0503e+00,
           -1.3154e+00, -2.2657e+00],
          [-1.9961e-01, -3.2101e-01, -1.5318e+00,  ..., -7.1576e-01,
           -1.6716e+00, -1.2167e+00]],

         [[ 9.3784e-01,  2.7220e-01, -9.6078e-01,  ...,  8.8829e-01,
           -3.5619e-02, -6.2022e-01],
          [ 4.9234e-01, -3.0177e-01,  8.7965e-01,  ..., -3.6167e-01,
           -1.1949e-01,  9.3537e-01],
          [-7.6492e-01, -1.6762e-01, -8.0711e-01,  ...,  1.0213e+00,
            4.0193e-01, -5.4646e-01],
          ...,
          [ 5.1700e-01,  3.3006e-01,  8.1116e-01,  ...,  4.0618e-01,
           -4.8952e-01,  3.1623e-01],
          [-1.3833e-01, -9.3273e-01,  8.9023e-01,  ...,  2.3928e-02,
            4.3912e-01, -1.1734e-01],
          [ 3.2406e-01, -6.7952e-01,  8.8508e-01,  ...,  7.7547e-01,
            1.8584e-01,  2.1956e-01]],

         [[ 8.9341e-01,  6.7749e-01,  2.0655e-01,  ...,  1.4277e+00,
           -6.1087e-01, -9.0753e-01],
          [ 5.7274e-01,  9.3312e-01,  1.0202e-01,  ..., -9.6845e-01,
            8.9981e-01,  8.1548e-01],
          [ 2.2241e-01,  3.2400e-02,  1.8548e-01,  ...,  1.3454e+00,
            4.3240e-01, -4.3663e-01],
          ...,
          [-8.9979e-03,  8.4620e-01,  1.6043e+00,  ...,  4.6532e-01,
            6.9437e-01,  5.7828e-01],
          [-5.3942e-01,  1.8125e-01,  1.5477e+00,  ...,  1.0629e+00,
            7.1541e-01,  8.7044e-01],
          [-3.7573e-01,  1.0569e+00,  7.7045e-01,  ...,  8.9445e-01,
            1.2792e+00,  4.6282e-01]]]], device='cuda:0')
nitinmukesh commented 2 months ago

@itechmusic Could you please suggest something to make this work?

itechmusic commented 2 months ago

(recon = vae.decode_latents(pred_latents))

Sorry, we are not familiar with coding on Windows 11. Since you have identified that the issue occurs at (recon = vae.decode_latents(pred_latents)), you could try decoding pred_latents one at a time instead of using batch inference. You can start by setting batch_size=1 at https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L143, which makes both the UNet and the VAE decode run with a batch size of 1.
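Below is a minimal Python sketch of the "decode pred_latents one at a time" idea. `decode_in_chunks` is a hypothetical helper and is not part of MuseTalk. It assumes that `vae.decode_latents` accepts a slice of `pred_latents` that keeps the batch dimension and returns one decoded frame per input item. This has not been checked against the repository, so treat it as an illustration rather than a drop-in fix.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")  # type of one latent (e.g. a row of a torch tensor)
F = TypeVar("F")  # type of one decoded frame


def decode_in_chunks(
    decode_fn: Callable[[Sequence[T]], Sequence[F]],
    latents: Sequence[T],
    chunk_size: int = 1,
) -> List[F]:
    """Decode `latents` a few items at a time instead of in one large batch.

    Hypothetical helper: `decode_fn` stands in for something like
    `vae.decode_latents` and is assumed to return one frame per input item.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    frames: List[F] = []
    for start in range(0, len(latents), chunk_size):
        # Slicing keeps the batch dimension, so a tensor of shape (B, C, H, W)
        # yields (chunk, C, H, W) rather than a single (C, H, W) latent.
        chunk = latents[start:start + chunk_size]
        frames.extend(decode_fn(chunk))
    return frames


# Hypothetical use in the inference loop (names taken from this thread):
# recon = decode_in_chunks(vae.decode_latents, pred_latents, chunk_size=1)
```

With `chunk_size=1`, only one latent is decoded per call. This lowers peak GPU memory, and if the process still hangs, it narrows the problem to a specific frame.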