microsoft / i-Code

MIT License
1.68k stars 161 forks source link

Unable to reproduce the results of the paper #134

Open XinMing0411 opened 2 months ago

XinMing0411 commented 2 months ago

Hello. I tried using the CoDi demo code (https://github.com/microsoft/i-Code/tree/main/i-Code-V3) to reproduce the results on the AudioCaps dataset. However, I was unable to match the results reported in the paper for the audio captioning and text-to-audio (TTA) tasks; the discrepancy in performance is significant:

- Frechet Audio Distance: 12.3379363
- Kullback-Leibler Divergence (Sigmoid): 9.3400078
- Kullback-Leibler Divergence (Softmax): 3.8197691
- Inception Score Mean: 2.9589245
- Inception Score Std: 0.2177440
- Frechet Distance: 54.1079137
- Bleu-1: 0.2448
- Bleu-2: 0.0918
- Bleu-3: 0.0287
- Bleu-4: 0.0097
- Rouge: 0.1928
- CIDEr: 0.0689
- METEOR: 0.0877
- SPICE: 0.0504
- SPIDEr: 0.0596

Here is my code:

import os
import soundfile as sf
import numpy as np
import torchaudio
import torch
import json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from core.models.model_module_infer import model_module
def gen_wave_text(gt_text_prompt, gt_audio_wavs):
    """Run both generation directions for one sample.

    Uses the module-level ``inference_tester``. Audio is generated first,
    conditioned on ``gt_text_prompt``; text is generated second, conditioned
    on ``gt_audio_wavs``. Both calls use the same sampling settings
    (guidance scale 3.5, one sample, 200 DDIM steps).

    Returns:
        A pair ``(audio, text)``: element ``[0][0]`` of the audio inference
        result and element ``[0]`` of the text inference result.
    """
    sampling_kwargs = dict(scale=3.5, n_samples=1, ddim_steps=200)

    # Text -> audio. Only the first entry of the returned collection is kept.
    audio_result = inference_tester.inference(
        xtype=['audio'],
        condition=gt_text_prompt,
        condition_types=['text'],
        **sampling_kwargs,
    )
    generated_audio = audio_result[0]

    # Audio -> text. The full returned collection is kept until the return.
    generated_text = inference_tester.inference(
        xtype=['text'],
        condition=gt_audio_wavs,
        condition_types=['audio'],
        **sampling_kwargs,
    )

    return generated_audio[0], generated_text[0]

def load_json(fname):
    """Parse the JSON file at ``fname`` and return the decoded object.

    The file is opened explicitly as UTF-8 so that non-ASCII text (for
    example captions) is decoded the same way on every platform instead of
    depending on the locale's default encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(fname, "r", encoding="utf-8") as f:
        return json.load(f)

class AudioDataset_Multimodal(Dataset):
    """Caption/waveform pairs described by a JSON metadata file.

    The metadata file must contain a top-level ``"data"`` list whose entries
    have a ``"wav"`` path (relative to ``relative_path``) and a ``"caption"``.

    Args:
        dataset_json: Path to the metadata JSON file.
        relative_path: Directory that the ``"wav"`` entries are relative to.
        sample_rate: Rate (Hz) every clip is resampled to before it is
            truncated/zero-padded to ``pad_time`` seconds.
    """

    def __init__(self, dataset_json='audiocaps_test_nonrepeat_subset_0.json',
                 relative_path="your audio relative_path", sample_rate=16000):
        self.metadata_root = load_json(dataset_json)
        self.data = self.metadata_root["data"]
        # Fixed clip duration in seconds; every waveform is cut or padded to it.
        self.pad_time = 10.23
        self.relative_path = relative_path
        self.sample_rate = sample_rate

    def __getitem__(self, index):
        """Return ``{"text", "waveform", "fname"}`` for the entry at ``index``.

        ``waveform`` is a 1-D float tensor of exactly
        ``int(sample_rate * pad_time)`` samples.
        """
        datum = self.data[index]
        # Computed per item so later changes to pad_time/sample_rate take effect.
        target_len = int(self.sample_rate * self.pad_time)

        wav_path = os.path.join(self.relative_path, datum["wav"])
        audio_wavs, sr = torchaudio.load(wav_path)
        audio_wavs = torchaudio.functional.resample(
            waveform=audio_wavs, orig_freq=sr, new_freq=self.sample_rate
        )
        # Average over dim 0 (the channel axis in torchaudio.load's default
        # layout) and drop anything past the target length.
        audio_wavs = audio_wavs.mean(0)[:target_len]

        # Zero-pad short clips; the pad is empty when the clip is long enough.
        padding = torch.zeros([target_len - audio_wavs.size(0)])
        audio_wavs = torch.cat([audio_wavs, padding], 0)

        return {
            "text": datum["caption"],
            # torch.cat always yields a tensor, so no None fallback is needed.
            "waveform": audio_wavs.float(),
            "fname": datum["wav"],
        }

    def __len__(self):
        """Return the number of entries in the metadata ``"data"`` list."""
        return len(self.data)

# Checkpoints loaded into the CoDi inference module (resolved under data_dir).
model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_audio_diffuser_m.pth', 'CoDi_video_diffuser_8frames.pth']
inference_tester = model_module(data_dir='checkpoints/', pth=model_load_paths, fp16=False) # turn on fp16=True if loading fp16 weights
inference_tester = inference_tester.cuda()
inference_tester = inference_tester.eval()

val_dataset = AudioDataset_Multimodal()
val_loader = DataLoader(val_dataset, batch_size=1,)

path = "your output path"
# Create the output directory up front so the first write cannot fail on it.
os.makedirs(path, exist_ok=True)

for batch in val_loader:
    text = batch['text']
    waveform = batch['waveform']
    # batch_size is 1, so the batched 'fname' field holds a single entry.
    fname = os.path.splitext(os.path.basename(batch['fname'][0]))[0]

    pre_audio_wave, pre_text = gen_wave_text(text, waveform)

    todo_waveform = pre_audio_wave[0]
    # Normalize the energy of the generation output to a peak of 0.8.
    # An all-zero output is written unchanged: dividing by a zero peak
    # would fill the file with NaNs.
    peak = np.max(np.abs(todo_waveform))
    if peak > 0:
        todo_waveform = (todo_waveform / peak) * 0.8
    sf.write(os.path.join(path, "%s.wav" % (fname)), todo_waveform, samplerate=16000)

    todo_text = pre_text[0]
    # Explicit UTF-8 so the caption file does not depend on the locale.
    with open(os.path.join(path, "%s.txt" % (fname)), "w", encoding="utf-8") as f:
        f.write(todo_text)

The dataset_json is the one provided by AudioLDM. Could you tell me what might be causing this discrepancy?

haoheliu commented 2 months ago

Got the same issue here. Looking forward to the replies from the authors.

zinengtang commented 2 months ago

Can you give me a text prompt, the generated audio, and the random seed, so that I can check whether it matches the expected output?