Audio-WestlakeU / NBSS

The official repo of NBC & SpatialNet for multichannel speech separation, denoising, and dereverberation
MIT License
175 stars, 21 forks

dataset issue #2

ermu-tech closed this issue 2 years ago

ermu-tech commented 2 years ago

Hi! I've learned from your paper that you remix the WSJ0 dataset in the way used in Fasnet. I don't actually have the WSJ0 dataset, but I generated mixed utterances with the data generation script used in Fasnet (https://github.com/yluo42/TAC/tree/master/data), which also produces 20000, 5000 and 3000 mixed utterances for training, validation and test, respectively. So I wonder whether I can directly use the data I generated to train your model? It would also be great if you could give me some advice on how to modify the code! Looking forward to your reply! Thank you!

quancs commented 2 years ago

You need to change the dataset's __getitem__ function. As its docstring shows (https://github.com/Audio-WestlakeU/NBSS/blob/dfd25877a21b7a48aff958512479f22ac77c1994/data_loaders/ss_semi_online_dataset.py#L69-L78), it should return xm (the mixture), ys (the ground-truth speech signals), and paras (a dict of parameters).
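
To catch shape mistakes before they surface deep inside training, a small check like the following could be run on a few items of your dataset. It is a hypothetical helper (check_item is not part of NBSS). It assumes the layout described in this thread (xm: [channel, time], ys: [spk, channel, time], paras a dict with an "index" key) and additionally assumes that every target has the same channel count and length as the mixture.

```python
from typing import Any, Dict


def check_item(xm, ys, paras: Dict[str, Any], num_spk: int = 2) -> None:
    """Sanity-check one (xm, ys, paras) triple returned by a Dataset.

    Hypothetical helper. Works on anything with a .shape attribute
    (torch.Tensor or numpy.ndarray). Assumed layout:
      xm: [channel, time], ys: [spk, channel, time], paras: dict with "index".
    Also ASSUMES each target has the same channel count and length as xm.
    """
    if len(xm.shape) != 2:
        raise ValueError(f"xm should be [channel, time], got shape {tuple(xm.shape)}")
    if len(ys.shape) != 3:
        raise ValueError(f"ys should be [spk, channel, time], got shape {tuple(ys.shape)}")
    if ys.shape[0] != num_spk:
        raise ValueError(f"expected {num_spk} speakers in ys, got {ys.shape[0]}")
    if tuple(ys.shape[1:]) != tuple(xm.shape):
        raise ValueError(
            f"ys channel/time dims {tuple(ys.shape[1:])} do not match xm {tuple(xm.shape)}"
        )
    if not isinstance(paras, dict) or "index" not in paras:
        raise ValueError("paras should be a dict containing at least 'index'")
```

For example, calling check_item(*dataset[i]) for a few indices i before starting training should surface a transposed or mis-stacked tensor immediately.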

To adapt it to the Fasnet dataset, the __getitem__ of your Dataset could look something like:

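    from typing import Any, Dict, Tuple

    import torch

    # a method of your Dataset class
    def __getitem__(self, index: Dict[str, int]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
        sidx = index['speech_index']  # an index in [0, 20000), [0, 5000) or [0, 3000) for train / val / test
        # xm = read the 'sidx'-th speech mixture in the time domain, shape [channel, time]
        # ys = read the 'sidx'-th speech targets in the time domain, shape [spk, channel, time]
        paras = {
            "index": sidx,
            # any other paras you need for evaluation
        }
        return xm, ys, paras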
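
The two read steps are left as comments above. One way to fill them is sketched below, assuming each mixture and each target is stored as a single multichannel 16-bit PCM WAV file (an assumption; check what your generation script actually writes). read_wav_channels_first is a hypothetical helper name, and only the standard-library wave module and NumPy are used.

```python
import wave

import numpy as np


def read_wav_channels_first(path: str) -> np.ndarray:
    """Read a 16-bit PCM WAV file into a float32 array of shape [channel, time].

    Hypothetical helper; ASSUMES 16-bit PCM. Other sample widths raise ValueError.
    """
    with wave.open(path, "rb") as f:
        if f.getsampwidth() != 2:
            raise ValueError(f"{path}: only 16-bit PCM is handled by this sketch")
        n_ch = f.getnchannels()
        frames = f.readframes(f.getnframes())
    # WAV stores samples interleaved per frame: [t0c0, t0c1, ..., t1c0, ...]
    data = np.frombuffer(frames, dtype="<i2").reshape(-1, n_ch)  # [time, channel]
    # transpose to [channel, time] and scale int16 to the range [-1, 1)
    return np.ascontiguousarray(data.T, dtype=np.float32) / 32768.0
```

Inside __getitem__, xm could then be torch.from_numpy(read_wav_channels_first(mix_path)), and ys the result of np.stack over the per-speaker reads (axis 0) converted the same way. If your data stores one mono file per microphone instead, read each file and stack along axis 0; if the files are float WAVs, the wave module cannot read them and a library such as soundfile is a better fit.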

Also, you need to remove the unnecessary code (such as unused parameters) from ss_semi_online_data_module.py and ss_semi_online_sampler.py to make your own datamodule and sampler. However, collate_func_train, collate_func_val, and collate_func_test must stay in the datamodule, because they are linked to the corresponding functions of class NBSS_ifp in https://github.com/Audio-WestlakeU/NBSS/blob/dfd25877a21b7a48aff958512479f22ac77c1994/cli_ifp.py#L26-L28

quancs commented 2 years ago

A simpler implementation could look like the following code. As long as your dataset returns xm, ys and paras with the correct shapes and types, it will work with the rest of the NBSS code.

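from typing import Any, Callable, Dict, List, Tuple

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class YourDatasetClass(Dataset):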
    def __init__(self, speech_paths) -> None:
        super().__init__()
        self.speech_paths=speech_paths

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
        """returns the indexed item

        Args:
            index: index

        Returns:
            Tensor: xm of shape [channel, time] in time domain
            Tensor: ys of shape [spk, channel, time] in time domain
            dict: paras used
        """
        # read mix and targets
        # xm = read the 'index'-th speech mixture, whose path is self.speech_paths[index]['mix'], as a [channel, time] tensor
        # ys = read the 'index'-th speech targets, whose paths are self.speech_paths[index]['spk_1'] and self.speech_paths[index]['spk_2'] in the 2-speaker case, stacked into a [spk, channel, time] tensor
        paras = {
            "index": index,
            # any other paras you need for evaluation
        }
        return xm, ys, paras

    def __len__(self):
        return len(self.speech_paths)

class SS_SemiOnlineDataModule(LightningDataModule):

    def __init__(
        self,
        speech_dir_path: str, # your speech dir including the generated multi-channel mixture and multi-channel target speeches
        batch_size: List[int] = [5, 5],
        speaker_num: int = 2,
        num_workers: int = 5,
        collate_func_train: Callable = None,
        collate_func_val: Callable = None,
        collate_func_test: Callable = None,
    ):
        super().__init__()
        self.speech_dir_path = speech_dir_path
        self.batch_size = batch_size[0]
        self.batch_size_val = batch_size[1]

        self.speaker_num = speaker_num
        self.num_workers = num_workers

        self.collate_func_train = collate_func_train
        self.collate_func_val = collate_func_val
        self.collate_func_test = collate_func_test

        self.prepare_data()

    def prepare_data(self):
        """collect the paths of the mixtures and targets into self.speech_paths
        """
        self.speech_paths = dict()
        self.speech_paths['train'] = ...  # the paths of the training mixtures and targets; each element can be a Dict containing the path of a mixture and the paths of its corresponding targets
        self.speech_paths['val'] = ...  # the same for the validation set
        self.speech_paths['test'] = ...  # the same for the test set

    def setup(self, stage=None):
        # YourDatasetClass is your Dataset implementation: it receives the paths, and its __getitem__ returns xm, ys and paras
        self.train = YourDatasetClass(speech_paths=self.speech_paths['train'])
        self.val = YourDatasetClass(speech_paths=self.speech_paths['val'])
        self.test = YourDatasetClass(speech_paths=self.speech_paths['test'])

    # The sampler is removed: it is used for shuffling RIRs and speeches online, so it is unnecessary for an already generated dataset
    def train_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size
        persistent_workers = False
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_func_train,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def val_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size_val
        persistent_workers = False
        return DataLoader(self.val,
                          batch_size=self.batch_size_val,
                          collate_fn=self.collate_func_val,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def test_dataloader(self) -> DataLoader:
        prefetch_factor = 2
        return DataLoader(
            self.test,
            batch_size=1,
            collate_fn=self.collate_func_test,
            num_workers=1,
            prefetch_factor=prefetch_factor,
        )
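
The three ... placeholders in prepare_data still have to be filled with lists of path dicts. One way to do that is sketched below, assuming a hypothetical on-disk layout (one folder per utterance containing mix.wav, spk_1.wav and spk_2.wav) and a hypothetical helper name, collect_speech_paths. Rename the files to match what your generation script actually produces.

```python
import os
from typing import Dict, List


def collect_speech_paths(split_dir: str, num_spk: int = 2) -> List[Dict[str, str]]:
    """Collect {'mix': ..., 'spk_1': ..., 'spk_2': ...} dicts for one split.

    HYPOTHETICAL layout: one sub-folder per utterance, each containing
    mix.wav, spk_1.wav, ..., spk_<num_spk>.wav.
    """
    items = []
    # sorted() keeps the index-to-utterance mapping stable across runs
    for name in sorted(os.listdir(split_dir)):
        sample_dir = os.path.join(split_dir, name)
        if not os.path.isdir(sample_dir):
            continue  # skip stray files such as logs or lists
        item = {"mix": os.path.join(sample_dir, "mix.wav")}
        for k in range(1, num_spk + 1):
            item[f"spk_{k}"] = os.path.join(sample_dir, f"spk_{k}.wav")
        # fail early instead of crashing later inside a DataLoader worker
        missing = [p for p in item.values() if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError(f"missing files in {sample_dir}: {missing}")
        items.append(item)
    return items
```

Inside prepare_data, the training entry could then be set to collect_speech_paths(os.path.join(self.speech_dir_path, 'train')), and likewise for 'val' and 'test' (these sub-folder names are also assumptions). Each returned dict matches the keys ('mix', 'spk_1', 'spk_2') used in YourDatasetClass above.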

quancs commented 2 years ago

Another thing you need to change is the config file: remove or add the parameters in it that you removed from or added to the code.
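
To find which config entries no longer match the datamodule's constructor, the keys of its config section can be compared with its __init__ signature using the standard inspect module. diff_init_args is a hypothetical helper, not part of NBSS or Lightning. LightningCLI should already reject unknown keys on its own; this just lists all mismatches at once.

```python
import inspect
from typing import Any, Dict, List, Tuple


def diff_init_args(cls: type, init_args: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    """Compare config keys with the parameters of cls.__init__.

    Returns (unknown, missing_required):
      unknown: keys the constructor does not accept (remove them from the config)
      missing_required: parameters without a default that the config does not set
    """
    params = inspect.signature(cls.__init__).parameters
    named_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    accepted = {n for n, p in params.items() if n != "self" and p.kind in named_kinds}
    required = {n for n in accepted if params[n].default is inspect.Parameter.empty}
    unknown = sorted(set(init_args) - accepted)
    missing_required = sorted(required - set(init_args))
    return unknown, missing_required
```

Pass it the datamodule class and the dict of its arguments loaded from your YAML config. Depending on how the CLI is set up, those arguments sit either directly under data: or under data: init_args:, so check your own file. The collate_func_* parameters default to None and are filled in by the CLI link, so they are not reported as missing.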

ermu-tech commented 2 years ago

Thank you for your quick response and advice! I've tried with my dataset, but I still can't get it to run :( . Maybe that is because I couldn't run your original code, so it's hard for me to make modifications. I've noticed that you also ran a comparison experiment with Fasnet, so maybe it would be much easier for you to publish code that trains your model on the Fasnet dataset? If you haven't generated the dataset with the script used in Fasnet (https://github.com/yluo42/TAC/tree/master/data), I'll email the dataset to you soon at quanchangsheng@outlook.com, the address shown on your home page. I would really appreciate it if you could help with that! Looking forward to your reply, and thanks again!

quancs commented 2 years ago

You are welcome. ^^

I've tried with my dataset, but I still can't get it to run :( . Maybe that is because I couldn't run your original code, so it's hard for me to make modifications.

You could post your error messages or send me your code. Maybe I can help.

I've noticed that you also ran a comparison experiment with Fasnet, so maybe it would be much easier for you to publish code that trains your model on the Fasnet dataset?

I did train my model on the Fasnet dataset, but that code is not the latest version. To adapt to the Fasnet dataset, you need to implement your own data components, such as the dataset and the datamodule.

You could refer to the PyTorch Lightning docs to implement your own datamodule: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.datamodule.html#pytorch_lightning.core.datamodule.LightningDataModule

And to the LightningCLI docs: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html

ermu-tech commented 2 years ago

Thanks a lot! And I will try again :)