TorchDSP / torchsig

TorchSig is an open-source signal processing machine learning toolkit based on the PyTorch data handling pipeline.
MIT License

Narrowband dataset generation fails when passing a subset of the signal classes to use. #253

Open nsbruce opened 3 weeks ago

nsbruce commented 3 weeks ago

Describe the bug: Narrowband dataset generation fails when passing a subset of the signal classes to use.

To Reproduce: In generate_narrowband.py, add a classes list to the ModulationsDataset instantiation:

        ds = ModulationsDataset(
            classes=["lfm_radar"],
            level=config.level,
            num_samples=num_samples,
            num_iq_samples=config.num_iq_samples,
            use_class_idx=config.use_class_idx,
            include_snr=config.include_snr,
            eb_no=config.eb_no,
        )

Then try to generate the dataset. It fails with an indexing error.

Expected behavior: The dataset generates.

Additional context: I have fixed this locally by editing the __init__ function of the ModulateNarrowbandDataset class to only pass relevant datasets to the final super().__init__() call.

    def __init__(
        self,
        modulations: Optional[Union[List, Tuple]] = torchsig_signals.class_list,
        num_iq_samples: int = 100,
        num_samples_per_class: int = 100,
        iq_samples_per_symbol: Optional[int] = None,
        random_data: bool = False,
        random_pulse_shaping: bool = False,
        **kwargs,
    ) -> None:
        modulations = (
            torchsig_signals.class_list
            if modulations is None
            else modulations
        )

        constellation_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.constellation_signals]
        fsk_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.fsk_signals]
        fm_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.fm_signals]
        am_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.am_signals]
        lfm_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.lfm_signals]
        chirpss_list = [m for m in map(str.lower, modulations) if m in torchsig_signals.chirpss_signals]

        datasets = []
        if len(constellation_list) > 0:
            datasets.append(
                ConstellationDataset(
                    constellations=constellation_list,
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    iq_samples_per_symbol=2
                    if iq_samples_per_symbol is None
                    else iq_samples_per_symbol,
                    random_data=random_data,
                    random_pulse_shaping=random_pulse_shaping,
                    **kwargs,
                )
            )

        if len(fsk_list) > 0:
            datasets.append(
                FSKDataset(
                    modulations=fsk_list,
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    iq_samples_per_symbol=8,
                    random_data=random_data,
                    random_pulse_shaping=random_pulse_shaping,
                    **kwargs,
                )
            )

        if len(fm_list) > 0:
            datasets.append(
                FMDataset(
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    random_data=random_data,
                    **kwargs,
                )
            )

        if len(am_list) > 0:
            datasets.append(
                AMDataset(
                    modulations=am_list,
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    random_data=random_data,
                    **kwargs,
                )
            )

        if len(lfm_list) > 0:
            datasets.append(
                LFMDataset(
                    constellations=lfm_list,
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    random_data=random_data,
                    **kwargs,
                )
            )

        if len(chirpss_list) > 0:
            datasets.append(
                ChirpSSDataset(
                    constellations=chirpss_list,
                    num_iq_samples=num_iq_samples,
                    num_samples_per_class=num_samples_per_class,
                    random_data=random_data,
                    **kwargs,
                )
            )

        super(ModulateNarrowbandDataset, self).__init__(datasets)
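
For anyone who wants to try the filtering idea without torchsig installed, here is a minimal standalone sketch. The SIGNAL_FAMILIES table and the group_by_family helper are hypothetical names made up for illustration, and the class names other than lfm_radar are placeholders. The real lists live in torchsig_signals. The point is only that a family with no requested classes never produces an entry, so no empty dataset is built for it.

```python
from typing import Dict, List, Sequence

# Hypothetical family table for illustration only; the real lists are
# torchsig_signals.constellation_signals, .fsk_signals, .lfm_signals, etc.
SIGNAL_FAMILIES: Dict[str, List[str]] = {
    "constellation": ["bpsk", "qpsk"],
    "fsk": ["2fsk", "4fsk"],
    "lfm": ["lfm_radar", "lfm_data"],
}


def group_by_family(requested: Sequence[str]) -> Dict[str, List[str]]:
    """Return only the families that contain at least one requested class."""
    wanted = [name.lower() for name in requested]
    groups: Dict[str, List[str]] = {}
    for family, members in SIGNAL_FAMILIES.items():
        selected = [name for name in wanted if name in members]
        if selected:  # skip families with nothing requested
            groups[family] = selected
    return groups


print(group_by_family(["lfm_radar"]))  # {'lfm': ['lfm_radar']}
```

In the patch above, each non-empty group corresponds to one datasets.append(...) call, and families with no requested classes are skipped entirely.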

I'm happy to add this to a pull request if you want.

ereoh commented 2 weeks ago

Hi! Thanks for bringing this issue to our attention. We have fixed this for the next release, v0.6.1, which is due in a few weeks.

nsbruce commented 3 days ago

Hi @ereoh, also note that in the current version some of the modulation lists are not passed to the datasets. For example,

AMDataset(
  num_iq_samples=num_iq_samples,
  num_samples_per_class=num_samples_per_class,
  random_data=random_data,
  **kwargs,
)

should become

AMDataset(
  modulations=am_list,
  num_iq_samples=num_iq_samples,
  num_samples_per_class=num_samples_per_class,
  random_data=random_data,
  **kwargs,
)

A couple of others are missing too.

nsbruce commented 1 day ago

I'm not sure if you'd like this to be a new issue or not, but reading the narrowband dataset also has problems if the dataset was generated using a subset of signal classes. In the TorchSigNarrowband.__getitem__ method, the class_name is set using self._idx_to_name_dict[mod], which disregards any custom class list.

My current fix for this is a subclass:

# MY_CLASSES is the custom class list that was used to generate the dataset.
class MyNarrowband(TorchSigNarrowband):
    _idx_to_name_dict = dict(zip(range(len(MY_CLASSES)), MY_CLASSES))
    _name_to_idx_dict = dict(zip(MY_CLASSES, range(len(MY_CLASSES))))

    @staticmethod
    def convert_idx_to_name(idx: int) -> str:
        return MyNarrowband._idx_to_name_dict.get(idx, "unknown")

    @staticmethod
    def convert_name_to_idx(name: str) -> int:
        return MyNarrowband._name_to_idx_dict.get(name, -1)
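
To show why the lookup matters, here is a small standalone sketch that does not import torchsig. FULL_CLASSES is a made-up placeholder for the library's full class list, and build_maps is a hypothetical helper. It also assumes the generator stores indices relative to the subset, which I have not confirmed against the torchsig source.

```python
from typing import Dict, List, Tuple

# Placeholder for the library's full class list (illustration only).
FULL_CLASSES: List[str] = ["bpsk", "qpsk", "lfm_radar"]
# Custom subset used when generating the dataset.
MY_CLASSES: List[str] = ["lfm_radar"]


def build_maps(classes: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Build index->name and name->index lookups for a class list."""
    idx_to_name = dict(enumerate(classes))
    name_to_idx = {name: idx for idx, name in idx_to_name.items()}
    return idx_to_name, name_to_idx


full_idx_to_name, _ = build_maps(FULL_CLASSES)
my_idx_to_name, my_name_to_idx = build_maps(MY_CLASSES)

# Assumption: the index written at generation time is relative to the subset.
stored_idx = my_name_to_idx["lfm_radar"]

print(full_idx_to_name[stored_idx])  # bpsk (wrong label from the full list)
print(my_idx_to_name[stored_idx])    # lfm_radar (label from the custom list)
```

Under that assumption, reading with the default mapping silently returns the wrong class name rather than raising, which is why overriding both dictionaries in the subclass is needed.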