ermu-tech closed this issue 2 years ago.
You need to change the dataset's `__getitem__` function. As its comments show (https://github.com/Audio-WestlakeU/NBSS/blob/dfd25877a21b7a48aff958512479f22ac77c1994/data_loaders/ss_semi_online_dataset.py#L69-L78), it should return `xm` (the mixture), `ys` (the ground-truth speech signals), and `paras`.
To adapt it to the FaSNet dataset, the `__getitem__` of your Dataset can be something like:
```python
from typing import Any, Dict, Tuple
import torch

def __getitem__(self, index: Dict[str, int]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
    sidx = index['speech_index']  # an index in [0, 20000), [0, 5000) or [0, 3000) for train / val / test
    xm = ...  # read the sidx-th speech mixture in the time domain, shape [channel, time]
    ys = ...  # read the sidx-th speech targets in the time domain, shape [spk, channel, time]
    paras = {
        "index": sidx,
        # any other paras you need for evaluation
    }
    return xm, ys, paras
```
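As a minimal, stdlib-only sketch of the "read the sidx-th speech mixture" step, assuming the generated mixtures are stored as 16-bit PCM multi-channel WAV files (an assumption; check what the generation script actually writes, e.g. it may store one file per microphone, in which case you read each file and concatenate the channels), the hypothetical helper `read_wav_channels` below returns a `[channel][time]` nested list. In practice you would probably use an audio library such as soundfile and wrap the result with `torch.tensor(...)` to obtain `xm`.

```python
import array
import sys
import wave
from typing import List


def read_wav_channels(path: str) -> List[List[float]]:
    """Read a 16-bit PCM WAV file into a [channel][time] list of floats in [-1, 1).

    Hypothetical helper: assumes 16-bit PCM; other sample widths are rejected.
    """
    with wave.open(path, "rb") as f:
        if f.getsampwidth() != 2:
            raise ValueError("only 16-bit PCM is handled in this sketch")
        n_ch = f.getnchannels()
        frames = f.readframes(f.getnframes())
    samples = array.array("h")  # signed 16-bit integers
    samples.frombytes(frames)
    if sys.byteorder == "big":  # WAV sample data is little-endian
        samples.byteswap()
    # samples are interleaved (ch0, ch1, ..., ch0, ch1, ...), so take every n_ch-th one per channel
    return [[s / 32768.0 for s in samples[c::n_ch]] for c in range(n_ch)]
```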
You also need to remove the unnecessary code (e.g. unused parameters) from ss_semi_online_data_module.py and ss_semi_online_sampler.py to make your own datamodule and sampler. However, `collate_func_train`, `collate_func_val`, and `collate_func_test` must be kept in the datamodule, because they are linked to the corresponding functions of class `NBSS_ifp` in https://github.com/Audio-WestlakeU/NBSS/blob/dfd25877a21b7a48aff958512479f22ac77c1994/cli_ifp.py#L26-L28
A simpler alternative is sketched below. As long as your dataset returns `xm`, `ys` and `paras` with the correct shapes and types, it will work with the rest of the NBSS code.
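For orientation only: a PyTorch `DataLoader` calls its `collate_fn` with a list of whatever `__getitem__` returns, so with `(xm, ys, paras)` items a collate function receives a list of such tuples. The hypothetical `collate_keep_lists` below only regroups them per field, using plain lists so it runs without torch. It is NOT the NBSS implementation; keep using the repository's own `collate_func_*` functions, which a real pipeline would pair with tensor stacking (e.g. `torch.stack`) where lengths agree.

```python
from typing import Any, Dict, List, Sequence, Tuple

# one dataset item: (xm, ys, paras)
Item = Tuple[Any, Any, Dict[str, Any]]


def collate_keep_lists(batch: Sequence[Item]) -> Tuple[List[Any], List[Any], List[Dict[str, Any]]]:
    """Regroup [(xm, ys, paras), ...] into ([xm, ...], [ys, ...], [paras, ...])."""
    if not batch:
        raise ValueError("empty batch")
    xms, yss, paras = zip(*batch)  # transpose the list of tuples
    return list(xms), list(yss), list(paras)
```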
```python
from typing import Any, Dict, List, Tuple

import torch
from torch.utils.data import Dataset


class YourDatasetClass(Dataset):
    def __init__(self, speech_paths: List[Dict[str, str]]) -> None:
        super().__init__()
        self.speech_paths = speech_paths

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:  # type: ignore
        """Returns the indexed item.

        Args:
            index: index of the item

        Returns:
            Tensor: xm of shape [channel, time] in the time domain
            Tensor: ys of shape [spk, channel, time] in the time domain
            dict: paras used
        """
        # read the mixture and targets (replace the placeholders with your loading code):
        # xm: the index-th speech mixture, whose path is self.speech_paths[index]['mix']
        # ys: the index-th speech targets, whose paths are self.speech_paths[index]['spk_1']
        #     and self.speech_paths[index]['spk_2'] in the 2-speaker case
        xm = ...
        ys = ...
        paras = {
            "index": index,
            # any other paras you need for evaluation
        }
        return xm, ys, paras

    def __len__(self) -> int:
        return len(self.speech_paths)
```
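The docstring fixes the shape contract: `xm` is `[channel, time]` and `ys` is `[spk, channel, time]`. The hypothetical helper `stack_targets` below (plain lists, so it runs without torch) is one way to combine per-speaker reads, e.g. from `spk_1` and `spk_2`, into `ys` while catching mismatched channel counts or lengths early. With real tensors of shape `[channel, time]`, `torch.stack(per_speaker, dim=0)` produces the same `[spk, channel, time]` layout.

```python
from typing import List, Sequence

Signal = Sequence[Sequence[float]]  # [channel][time]


def stack_targets(targets: Sequence[Signal], mixture: Signal) -> List[List[List[float]]]:
    """Stack per-speaker [channel][time] targets into [spk][channel][time],
    checking that every target matches the mixture's channel count and length."""
    n_ch, n_t = len(mixture), len(mixture[0])
    if any(len(ch) != n_t for ch in mixture):
        raise ValueError("mixture channels have different lengths")
    for k, tgt in enumerate(targets):
        if len(tgt) != n_ch:
            raise ValueError(f"speaker {k}: {len(tgt)} channels, mixture has {n_ch}")
        if any(len(ch) != n_t for ch in tgt):
            raise ValueError(f"speaker {k}: length differs from mixture ({n_t} samples)")
    return [[list(ch) for ch in tgt] for tgt in targets]
```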
```python
from typing import Callable, List, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

# YourDatasetClass is your Dataset implementation (see the class above)


class SS_SemiOnlineDataModule(LightningDataModule):
    def __init__(
        self,
        speech_dir_path: str,  # your speech dir including the generated multi-channel mixtures and multi-channel target speeches
        batch_size: List[int] = [5, 5],  # [train, val]
        speaker_num: int = 2,
        num_workers: int = 5,
        collate_func_train: Optional[Callable] = None,
        collate_func_val: Optional[Callable] = None,
        collate_func_test: Optional[Callable] = None,
    ):
        super().__init__()
        self.speech_dir_path = speech_dir_path
        self.batch_size = batch_size[0]
        self.batch_size_val = batch_size[1]
        self.speaker_num = speaker_num
        self.num_workers = num_workers
        self.collate_func_train = collate_func_train
        self.collate_func_val = collate_func_val
        self.collate_func_test = collate_func_test
        self.prepare_data()

    def prepare_data(self):
        """Prepare the data paths in self.speech_paths."""
        self.speech_paths = dict()
        # each element can be a dict containing the path of a mixture and the paths of its corresponding targets
        self.speech_paths['train'] = ...  # paths of the training mixtures and targets
        self.speech_paths['val'] = ...  # paths of the validation mixtures and targets
        self.speech_paths['test'] = ...  # paths of the test mixtures and targets

    def setup(self, stage=None):
        # YourDatasetClass receives the paths, and its __getitem__ returns xm, ys and paras
        self.train = YourDatasetClass(speech_paths=self.speech_paths['train'])
        self.val = YourDatasetClass(speech_paths=self.speech_paths['val'])
        self.test = YourDatasetClass(speech_paths=self.speech_paths['test'])
        # The sampler is removed: it shuffles RIRs and speeches online, which is unnecessary for an already generated dataset

    def train_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size
        persistent_workers = False
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          shuffle=True,  # shuffling was previously done by the removed sampler
                          collate_fn=self.collate_func_train,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def val_dataloader(self) -> DataLoader:
        prefetch_factor = self.batch_size_val
        persistent_workers = False
        return DataLoader(self.val,
                          batch_size=self.batch_size_val,
                          collate_fn=self.collate_func_val,
                          num_workers=self.num_workers,
                          prefetch_factor=prefetch_factor,
                          pin_memory=True,
                          persistent_workers=persistent_workers)

    def test_dataloader(self) -> DataLoader:
        prefetch_factor = 2
        return DataLoader(
            self.test,
            batch_size=1,
            collate_fn=self.collate_func_test,
            num_workers=1,
            prefetch_factor=prefetch_factor,
        )
```
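One way the `...` placeholders in `prepare_data` could be filled is sketched below. It assumes a hypothetical wsj0-2mix-style layout, `<root>/<split>/mix/<name>.wav` with targets at `<root>/<split>/s1/<name>.wav`, `s2/...`, which may not match what your generation script writes, so adapt the directory names. The keys `mix`, `spk_1`, `spk_2` follow the dataset class above; `build_speech_paths` is a hypothetical name.

```python
import os
from typing import Dict, List


def build_speech_paths(root: str, split: str, speaker_num: int = 2) -> List[Dict[str, str]]:
    """Pair every mixture under <root>/<split>/mix with its targets under
    <root>/<split>/s1, s2, ... (hypothetical layout; adapt to your files)."""
    mix_dir = os.path.join(root, split, "mix")
    items = []
    for name in sorted(os.listdir(mix_dir)):  # sorted so indices are reproducible
        if not name.endswith(".wav"):
            continue
        item = {"mix": os.path.join(mix_dir, name)}
        for k in range(1, speaker_num + 1):
            tgt = os.path.join(root, split, f"s{k}", name)
            if not os.path.isfile(tgt):
                raise FileNotFoundError(tgt)  # fail early on a missing target
            item[f"spk_{k}"] = tgt
        items.append(item)
    return items
```

For example, the training entry could become `build_speech_paths(self.speech_dir_path, 'train', self.speaker_num)`, and likewise for `'val'` and `'test'`.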
You also need to change the config file: remove or add the parameters there that you removed from or added to the code.
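With LightningCLI, the `data` section of the config (under `init_args` when `class_path` is used) is passed to the datamodule's `__init__`, so a key left over from a removed parameter will typically make parsing fail. As a minimal sketch, the hypothetical helper below compares an already-loaded config section (a plain dict, e.g. from `yaml.safe_load`) with an `__init__` signature and lists the keys that are unknown to the code or required but missing.

```python
import inspect
from typing import Any, Callable, Dict, List


def check_init_args(init: Callable, config_args: Dict[str, Any]) -> Dict[str, List[str]]:
    """Compare config keys with the parameters of `init` (e.g. a DataModule class or its __init__)."""
    kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    params = {n: p for n, p in inspect.signature(init).parameters.items()
              if n != "self" and p.kind in kinds}
    required = {n for n, p in params.items() if p.default is inspect.Parameter.empty}
    return {
        "unknown": sorted(set(config_args) - set(params)),  # in the config but not in the code
        "missing": sorted(required - set(config_args)),  # required by the code but absent from the config
    }
```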
Thank you for your quick response and advice! I've tried it with my dataset, but I still can't get it to run :( . Maybe that's because I couldn't run your original code, so it's hard for me to modify it. I've noticed that you also ran a comparison experiment with FaSNet, so maybe it would be much easier for you to publish code that trains your model on the FaSNet dataset? And if you haven't generated the dataset with the script used in FaSNet (https://github.com/yluo42/TAC/tree/master/data), I'll soon email it to you at quanchangsheng@outlook.com, the address shown on your home page. I would really appreciate your help with that! Looking forward to your reply, and thanks again!
You are welcome. ^^
> I've tried it with my dataset, but I still can't get it to run :( . Maybe that's because I couldn't run your original code, so it's hard for me to modify it.

You could post your error messages or send me your code; maybe I can help.

> I've noticed that you also ran a comparison experiment with FaSNet, so maybe it would be much easier for you to publish code that trains your model on the FaSNet dataset?

I did train my model on the FaSNet dataset, but that code is not the latest version. To adapt to the FaSNet dataset, you need to implement your own data classes, i.e. a dataset and a datamodule.
You could refer to the PyTorch Lightning docs to implement your own datamodule: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.datamodule.html#pytorch_lightning.core.datamodule.LightningDataModule
and to the LightningCLI docs: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html
Thanks a lot! And I will try again :)
Hi! I learned from your paper that you remixed the WSJ0 dataset in the manner used in FaSNet. I don't have the WSJ0 dataset, but I generated mixed utterances with the data generation script used in FaSNet (https://github.com/yluo42/TAC/tree/master/data), which also produces 20000, 5000 and 3000 mixed utterances for training, validation and test, respectively. So I wonder whether I can directly use the data I generated to train your model? It would be great if you could give me some advice on how to modify the code! Looking forward to your reply! Thank you!