CUNY-CL / yoyodyne

Small-vocabulary sequence-to-sequence generation with optional feature conditioning
Apache License 2.0

Migrates to data modules #110

Closed: kylebgorman closed this 11 months ago

kylebgorman commented 11 months ago

This PR migrates us to Lightning's built-in notion of data modules. A data module is a sort of data "God class" (by design) that holds the information about file paths, file parsing, collation/batching/padding, and the index, and it can be used in more or less the same way during training, validation, prediction, or testing.

By migrating to this we are much closer to being able to use Lightning (and Torch!) 2.0, particularly the automagical LightningCLI, and we get rid of a lot of boilerplate in the training and prediction scripts. (I also think this will help enormously with testing.) Thus this is relevant to #5 and #60.

Two big notes:

(I am closing #91 in favor of this; that had a bug I could never squash so I started over and worked more incrementally.)

kylebgorman commented 11 months ago
  1. It seems that we do still need to use the dataset directly sometimes, and we access it with e.g.: datamodule.train_dataloader().dataset. To be sure, this is not a separate initialization from what PTL does under the hood?

Yes, that's how I'd do it. It comes up once, in feeding data to transducer EM.

  2. If I am a user who wants to modify some dataset behavior in a way that is not easily supported from the CLI (say, I want a 4th column for some metadata available to my script, or I want to somehow change my data at runtime based on some other info), what is the process of updating the code now? I just want to make sure that we have not increased complexity too much from this perspective.

I'd say it's no more or less complex than before.


In README.md https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265675164:

@@ -54,8 +54,33 @@ import yoyodyne

Usage

-See yoyodyne-predict --help and yoyodyne-train --help.
+### Training
+
+Training is performed by the yoyodyne-train script. One
+must specify the following required arguments:
+
+- --train: path to TSV file containing training data
+- --val: path to TSV file containing validation data
+- --experiment: name of experiment (pick something unique)
+- --model_dir: path for model metadata and checkpoints output during
+  training
+
+The user can also specify as well as various optional training and architectural

I think "...can also specify as well as various optional..." might be a typo? I am having trouble parsing it.

On tests/collator_test.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265677554:

Apologies if it's been discussed already, but why are we deleting tests?

In yoyodyne/data/collators.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265683411:

class LengthError(Exception):
    pass

@.***

Making the Collator a dataclass scares me a bit since it is used under the hood by the PyTorch DataLoader, and I know that dataclasses sometimes have surprising behaviors due to things that happen implicitly around init.

Just wanted to call that out, though realistically I think this is good. Pretty sure all PyTorch does with the Collator is call collator(samples).
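For what it's worth, a dataclass with a `__call__` method is still just a callable, so a DataLoader can use it as `collate_fn`. The sketch below uses a hypothetical `ToyCollator` (not yoyodyne's actual `Collator`) to show the implicit dataclass behaviors most likely to matter here: the generated `__init__` with its optional `__post_init__` hook, and `__hash__` being set to `None` when `eq=True` and `frozen=False`.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class ToyCollator:
    """Hypothetical stand-in for yoyodyne's Collator: pads integer sequences."""

    pad_idx: int
    max_length: int

    def __post_init__(self) -> None:
        # Called by the generated __init__ right after the fields are set.
        if self.max_length <= 0:
            raise ValueError(f"max_length must be positive: {self.max_length}")

    def __call__(self, samples: List[List[int]]) -> List[List[int]]:
        # A DataLoader calls collate_fn(samples) once per batch, where
        # samples is the list of items drawn from the dataset.
        longest = max(len(sample) for sample in samples)
        if longest > self.max_length:
            raise ValueError(f"Sequence too long: {longest} > {self.max_length}")
        return [
            sample + [self.pad_idx] * (longest - len(sample))
            for sample in samples
        ]


collator = ToyCollator(pad_idx=0, max_length=5)
print(collator([[4, 5, 6], [7]]))  # [[4, 5, 6], [7, 0, 0]]

# Implicit behavior: with eq=True (the default) and frozen=False, the
# dataclass decorator sets __hash__ to None, so instances are unhashable.
print(ToyCollator.__hash__ is None)  # True
```

If hashability ever mattered, `@dataclasses.dataclass(frozen=True)` would get a generated `__hash__` back.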

In yoyodyne/data/datamodules.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265690313:

            if self.index.has_features
            else 0,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
        )

    # Helpers.

    @property
    def paths(self) -> Iterator[str]:
        if self.train is not None:
            yield self.train
        if self.val is not None:
            yield self.val
        if self.predict is not None:
            yield self.predict

Just checking (maybe I will find out below): if I specify --predict during training, what will happen?

In yoyodyne/data/datamodules.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265690529:

            max_source_length=max_source_length,
            max_target_length=max_target_length,
        )

    # Helpers.

    @property
    def paths(self) -> Iterator[str]:
        if self.train is not None:
            yield self.train
        if self.val is not None:
            yield self.val
        if self.predict is not None:
            yield self.predict
        if self.test is not None:
            yield self.test

Same Q for --test during training.
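One way to see the answer to both questions: because `paths` yields every file that was provided, anything that iterates over it (e.g., to build the index) sees the --predict and --test data even during training. The hypothetical `ToyDataModule` below keeps only that property plus a character-level source vocabulary builder; how the vocabulary is read (first column, one symbol per character) is an assumption for illustration.

```python
import csv
import dataclasses
import os
import tempfile
from typing import Iterator, List, Optional, Set


@dataclasses.dataclass
class ToyDataModule:
    """Hypothetical stand-in for the PR's data module; only paths matter here."""

    train: Optional[str] = None
    val: Optional[str] = None
    predict: Optional[str] = None
    test: Optional[str] = None

    @property
    def paths(self) -> Iterator[str]:
        # Same behavior as the property under review: yield each path given.
        for path in (self.train, self.val, self.predict, self.test):
            if path is not None:
                yield path

    def source_vocabulary(self) -> Set[str]:
        # Assumption: sources are in the first column, split into characters.
        vocabulary: Set[str] = set()
        for path in self.paths:
            with open(path, "r", newline="") as tsv:
                for row in csv.reader(tsv, delimiter="\t"):
                    vocabulary.update(row[0])
        return vocabulary


def _write_tsv(path: str, rows: List[List[str]]) -> None:
    with open(path, "w", newline="") as sink:
        csv.writer(sink, delimiter="\t").writerows(rows)


with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train.tsv")
    test_path = os.path.join(tmp, "test.tsv")
    _write_tsv(train_path, [["abc", "x"]])
    _write_tsv(test_path, [["abd", "y"]])
    # A test file passed alongside the training file: "d" enters the vocabulary.
    vocab_with_test = ToyDataModule(
        train=train_path, test=test_path
    ).source_vocabulary()
    vocab_train_only = ToyDataModule(train=train_path).source_vocabulary()

print(sorted(vocab_with_test))  # ['a', 'b', 'c', 'd']
print(sorted(vocab_train_only))  # ['a', 'b', 'c']
```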

In yoyodyne/data/datamodules.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265691823:


    def log_vocabularies(self) -> None:
        """Logs this module's vocabularies."""
        util.log_info(f"Source vocabulary: {self.index.source_map.pprint()}")
        if self.index.has_features:
            util.log_info(
                f"Features vocabulary: {self.index.features_map.pprint()}"
            )
        if self.index.has_target:
            util.log_info(
                f"Target vocabulary: {self.index.target_map.pprint()}"
            )

    def write_index(self, model_dir: str, experiment: str) -> None:
        """Writes the index."""
        index_path = self.index.index_path(model_dir, experiment)

I think we've discussed this convention in the past so feel free to ignore, but what about either:

  • self.index.index_path() --> self.index.get_index_path()
  • self.index.index_path() --> self.index.make_index_path()

In yoyodyne/data/datasets.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265697244:

    def has_features(self) -> bool:
        return self.index.has_features

    @property
    def has_target(self) -> bool:
        return self.index.has_target

    def _encode(
        self,
        symbols: List[str],
        symbol_map: indexes.SymbolMap,
    ) -> torch.Tensor:
        """Encodes a sequence as a tensor of indices with string boundary IDs.

        Args:
            string (str): string to be encoded.

string --> symbols

In yoyodyne/data/datasets.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265697505:

        return self.index.has_features

    @property
    def has_target(self) -> bool:
        return self.index.has_target

    def _encode(
        self,
        symbols: List[str],
        symbol_map: indexes.SymbolMap,
    ) -> torch.Tensor:
        """Encodes a sequence as a tensor of indices with string boundary IDs.

        Args:
            string (str): string to be encoded.
            sep (str): separator to use.

remove
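Applying both suggestions, the docstring would name `symbols` rather than `string` and drop the stale `sep` entry. A sketch with a hypothetical `ToySymbolMap` and an assumed special-symbol layout (pad, start, end, unknown at indices 0 to 3); it returns a plain list where the method under review returns a `torch.Tensor`.

```python
from typing import Dict, Iterable, List

# Assumed reserved indices for special symbols (illustrative only).
SPECIALS = ["<P>", "<S>", "<E>", "<U>"]
PAD_IDX, START_IDX, END_IDX, UNK_IDX = range(len(SPECIALS))


class ToySymbolMap:
    """Hypothetical stand-in for indexes.SymbolMap."""

    def __init__(self, vocabulary: Iterable[str]) -> None:
        self._symbols: List[str] = SPECIALS + sorted(set(vocabulary))
        self._index: Dict[str, int] = {
            symbol: idx for idx, symbol in enumerate(self._symbols)
        }

    def index(self, symbol: str) -> int:
        # Unknown symbols map to the UNK index.
        return self._index.get(symbol, UNK_IDX)


def encode(symbols: List[str], symbol_map: ToySymbolMap) -> List[int]:
    """Encodes a sequence as a list of indices with string boundary IDs.

    Args:
        symbols (List[str]): symbols to be encoded.
        symbol_map (ToySymbolMap): symbol map to encode with.

    Returns:
        List[int]: the encoded indices, wrapped in start and end IDs.
    """
    return [START_IDX] + [symbol_map.index(s) for s in symbols] + [END_IDX]


symbol_map = ToySymbolMap("abc")
print(encode(list("cab"), symbol_map))  # [1, 6, 4, 5, 2]
print(encode(["a", "z"], symbol_map))  # [1, 4, 3, 2]
```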

In yoyodyne/data/datasets.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265703623:

        symbol_map: indexes.SymbolMap,
    ) -> Iterator[List[str]]:
        """Decodes the tensor of indices into lists of symbols.

        Args:
            indices (torch.Tensor): 2d tensor of indices.
            symbol_map (indexes.SymbolMap).

        Yields:
            List[str]: Decoded symbols.
        """
        for idx in indices.cpu().numpy():
            yield [
                symbol_map.symbol(c)
                for c in idx
                if c not in self.index.special_idx

Wasn't decoding special symbols controlled by an arg previously? Not very important, but could potentially be useful for debugging.
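If the flag were reinstated, it could be a keyword argument that defaults to the current filtered behavior. This is only a sketch: `keep_special` is a hypothetical name (the earlier argument's name isn't shown here), and the symbol table and special indices are assumed.

```python
from typing import Iterable, Iterator, List, Sequence, Set

# Assumed layout: indices 0-3 are pad, start, end, unknown.
SYMBOLS = ["<P>", "<S>", "<E>", "<U>", "a", "b", "c"]
SPECIAL_IDX: Set[int] = {0, 1, 2, 3}


def decode(
    indices: Iterable[Sequence[int]],
    symbols: List[str],
    keep_special: bool = False,
) -> Iterator[List[str]]:
    """Decodes each row of indices into a list of symbols.

    Args:
        indices: rows of indices, e.g. a 2d tensor converted to lists.
        symbols: index-to-symbol table.
        keep_special: hypothetical flag; if True, special symbols are kept,
            which can help when debugging.

    Yields:
        List[str]: decoded symbols for one row.
    """
    for row in indices:
        yield [
            symbols[idx]
            for idx in row
            if keep_special or idx not in SPECIAL_IDX
        ]


batch = [[1, 4, 5, 2, 0], [1, 6, 2, 0, 0]]
print(list(decode(batch, SYMBOLS)))  # [['a', 'b'], ['c']]
print(list(decode(batch, SYMBOLS, keep_special=True))[1])
# ['<S>', 'c', '<E>', '<P>', '<P>']
```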

In yoyodyne/data/datasets.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265709677:

            ]

    def decode_source(
        self,
        indices: torch.Tensor,
    ) -> Iterator[str]:
        """Decodes a source tensor.

        Args:
            indices (torch.Tensor): 2d tensor of indices.

        Yields:
            str: Decoded source strings.
        """
        for symbols in self._decode(indices, self.index.source_map):
            yield self.parser.source_string(symbols)

It is now obvious to me what this does, but it could be opaque to someone else. Do you think we can add a comment somewhere indicating that this just uses the correct separator for joining the decoded chars?
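A comment along the lines of "joins the decoded symbols with the source separator" would probably do. To make the round trip concrete, here is a sketch with a hypothetical `ToyParser` whose `source_sep` field is an assumption (an empty separator meaning one symbol per character); the real parser's attributes may differ.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class ToyParser:
    """Hypothetical parser holding the source separator.

    An empty separator is assumed to mean "one symbol per character".
    """

    source_sep: str = ""

    def source_symbols(self, string: str) -> List[str]:
        # Splits a raw source string into symbols.
        if not self.source_sep:
            return list(string)
        return string.split(self.source_sep)

    def source_string(self, symbols: List[str]) -> str:
        # Inverse of source_symbols: joins decoded symbols with the same
        # separator that was used when reading the data.
        return self.source_sep.join(symbols)


print(ToyParser().source_string(["c", "a", "t"]))  # cat
print(ToyParser(source_sep=" ").source_string(["k", "æ", "t"]))  # k æ t
```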

In yoyodyne/data/tsv.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265712468:

            raise Error(f"Out of range source column: {self.source_col}")
        if self.features_col < 0:
            raise Error(f"Out of range features column: {self.features_col}")
        if self.features_col < 0:
            raise Error(f"Out of range features column: {self.features_col}")
        if self.target_col < 0:
            raise Error(f"Out of range target column: {self.target_col}")

    @staticmethod
    def _tsv_reader(path: str) -> Iterator[str]:
        with open(path, "r") as tsv:
            yield from csv.reader(tsv, delimiter="\t")

    @staticmethod
    def _get_string(row: List[str], col: int) -> str:
        """Returns a string from a row by index.

Add a newline below
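Besides the blank line after the docstring summary, note that `csv.reader` yields a list of fields per row, so `Iterator[List[str]]` is the accurate annotation for `_tsv_reader`. A sketch with a hypothetical `ToyTsvParser`; the 1-indexed column convention (0 meaning "absent") is an assumption inferred from the `features_col != 0` check in predict.py, and the features column is ignored for brevity.

```python
import csv
import dataclasses
import os
import tempfile
from typing import Iterator, List, Tuple


class Error(Exception):
    pass


@dataclasses.dataclass
class ToyTsvParser:
    """Hypothetical simplification of the parser in tsv.py.

    Columns are assumed to be 1-indexed, with 0 meaning "absent".
    """

    source_col: int = 1
    features_col: int = 0
    target_col: int = 2

    def __post_init__(self) -> None:
        if self.source_col < 1:
            raise Error(f"Out of range source column: {self.source_col}")
        if self.features_col < 0:
            raise Error(f"Out of range features column: {self.features_col}")
        if self.target_col < 0:
            raise Error(f"Out of range target column: {self.target_col}")

    @staticmethod
    def _tsv_reader(path: str) -> Iterator[List[str]]:
        # csv.reader yields one list of fields per row.
        with open(path, "r", newline="") as tsv:
            yield from csv.reader(tsv, delimiter="\t")

    @staticmethod
    def _get_string(row: List[str], col: int) -> str:
        """Returns a string from a row by index.

        Args:
            row (List[str]): the split row.
            col (int): the 1-indexed column.

        Returns:
            str: the string in that column.
        """
        return row[col - 1]

    def samples(self, path: str) -> Iterator[Tuple[str, str]]:
        # Yields (source, target) pairs; features are ignored here.
        for row in self._tsv_reader(path):
            yield (
                self._get_string(row, self.source_col),
                self._get_string(row, self.target_col),
            )


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.tsv")
    with open(path, "w", newline="") as sink:
        sink.write("abc\tABC\ndef\tDEF\n")
    pairs = list(ToyTsvParser().samples(path))

print(pairs)  # [('abc', 'ABC'), ('def', 'DEF')]
```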

In yoyodyne/predict.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265728241:

    collator = collators.Collator(
        dataset,
        arch,
        max_source_length,
        max_target_length,
    )
    return data.DataLoader(
        dataset,
        collate_fn=collator,
        batch_size=batch_size,
        num_workers=1,
    separate_features = args.features_col != 0 and args.arch in [
        "pointer_generator_lstm",
        "transducer",
    ]
    # TODO(kbg): reuse index?

If I understand correctly, yes. We need to reuse the index from training for this to work.
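To illustrate why: an index rebuilt from prediction data assigns different IDs to the same symbols whenever the vocabularies differ, so the inputs no longer line up with what the model was trained on. A minimal sketch with a hypothetical `ToyIndex` that is written at training time and read back at prediction time; the JSON serialization, file name, and reserved unknown index are assumptions, not how yoyodyne stores its index.

```python
import json
import os
import tempfile
from typing import Dict, Iterable, List

UNK_IDX = 0  # Assumed: index 0 is reserved for unknown symbols.


class ToyIndex:
    """Hypothetical symbol index; not yoyodyne's actual index class."""

    def __init__(self, symbol_to_idx: Dict[str, int]) -> None:
        self.symbol_to_idx = symbol_to_idx

    @classmethod
    def from_vocabulary(cls, vocabulary: Iterable[str]) -> "ToyIndex":
        symbols = sorted(set(vocabulary))
        return cls({s: i for i, s in enumerate(symbols, start=1)})

    def encode(self, string: str) -> List[int]:
        return [self.symbol_to_idx.get(s, UNK_IDX) for s in string]

    def write(self, path: str) -> None:
        with open(path, "w", encoding="utf-8") as sink:
            json.dump(self.symbol_to_idx, sink)

    @classmethod
    def read(cls, path: str) -> "ToyIndex":
        with open(path, "r", encoding="utf-8") as source:
            return cls(json.load(source))


train_index = ToyIndex.from_vocabulary("abc")
# Rebuilding from prediction data silently assigns different IDs.
predict_index = ToyIndex.from_vocabulary("bcd")
print(train_index.encode("bc"), predict_index.encode("bc"))  # [2, 3] [1, 2]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "index.json")
    train_index.write(path)  # At training time.
    reloaded = ToyIndex.read(path)  # At prediction time.
print(reloaded.encode("bcd"))  # [2, 3, 0]; "d" is unknown to the model.
```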

In yoyodyne/predict.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265728538:

        output: str,
    ) -> None:
        """Predicts from the model.

        Args:
            trainer (pl.Trainer).
            model (pl.LightningModule).
            loader (data.DataLoader).
            dataomdule (data.DataModule).

dataomdule --> datamodule

In yoyodyne/predict.py https://github.com/CUNY-CL/yoyodyne/pull/110#discussion_r1265732204:

    util.log_info(f"Writing to {output}")
    _mkdir(output)
    decode_target = datamodule.predict_dataloader().dataset.decode_target

Is predict_dataloader rerun by the PTL trainer? Is there a way to define the decode methods on the datamodule?
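On the second question: one option is to have the data module build the prediction dataset once and expose the decoding methods itself, so callers never need to reach through `predict_dataloader().dataset`. A minimal sketch with hypothetical `ToyDataset`/`ToyDataModule` classes and `functools.cached_property`; a real `predict_dataloader()` could then wrap the same cached dataset. It doesn't settle whether the trainer calls `predict_dataloader()` more than once.

```python
import functools
from typing import Iterator, List


class ToyDataset:
    """Hypothetical dataset that can decode target indices."""

    def __init__(self, symbols: List[str]) -> None:
        self.symbols = symbols

    def decode_target(self, rows: List[List[int]]) -> Iterator[str]:
        for row in rows:
            yield "".join(self.symbols[idx] for idx in row)


class ToyDataModule:
    """Hypothetical data module that builds its prediction dataset once."""

    def __init__(self, symbols: List[str]) -> None:
        self._symbols = symbols
        self.datasets_built = 0  # For illustration only.

    @functools.cached_property
    def predict_dataset(self) -> ToyDataset:
        # Runs on first access only; later accesses reuse the cached value.
        self.datasets_built += 1
        return ToyDataset(self._symbols)

    def decode_target(self, rows: List[List[int]]) -> Iterator[str]:
        # Callers can use datamodule.decode_target(...) directly.
        return self.predict_dataset.decode_target(rows)


datamodule = ToyDataModule(["a", "b", "c"])
print(list(datamodule.decode_target([[0, 1], [2]])))  # ['ab', 'c']
print(list(datamodule.decode_target([[2, 2]])))  # ['cc']
print(datamodule.datasets_built)  # 1
```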


kylebgorman commented 11 months ago

Only leftover comment is about reusing the train index when making predictions, for which you have left a TODO. But unless this index is loaded elsewhere, this will be an issue if the train and test sets do not have identical vocabularies. It seems like this can be handled in a separate issue if you like.

I wasn't sure at first but I convinced myself that you're right and it's important.

Just putting this here since we are already talking about it: after migration to LightningCLI, this will compute the index across any of --train, --val, --predict, --test as provided, independent of what command (training, prediction, etc.) is requested.