facebookresearch / gtn

Automatic differentiation with weighted finite-state transducers.
MIT License

Including GTN graphs as trainable parameters in a model #45

Closed. brianyan918 closed this issue 3 years ago.

brianyan918 commented 3 years ago

Hi!

I would like to keep a gtn.graph object in my model whose weights are updated with each optimizer.step(). I may be wrong, but I doubt that the PyTorch optimizer would treat the weights in this graph as parameters and update them on that call.

For example, the code in gtn_applications runs some model to get emission probabilities and computes the CTC loss from gtn.intersect(g_emissions, g_criterion). This works because g_emissions is created each time from the tensor representation of the emissions.

Now, if I want to insert a new_graph so that my CTC loss comes from gtn.intersect(gtn.intersect(g_emissions, new_graph), g_criterion), is there a way to keep the parameters of new_graph in a gtn.graph and have them be updated by PyTorch's optimizer.step()? Or does new_graph need to be created each time from tensors and the set_weights call?
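
For concreteness, a rough sketch of the pattern I mean, loosely modeled on the gtn_applications CTC code (the shapes, `emissions`, `g_criterion`, and `new_graph` here are placeholders, not exact code from that repo):

```python
import gtn
import torch

# Sketch only: assumes T frames, C output classes, and prebuilt g_criterion / new_graph.
def loss_with_new_graph(emissions: torch.Tensor,
                        new_graph: gtn.Graph,
                        g_criterion: gtn.Graph) -> float:
    T, C = emissions.shape
    # g_emissions is recreated from the tensor on every call, so its weights
    # always reflect the current model outputs.
    g_emissions = gtn.linear_graph(T, C, emissions.requires_grad)
    cpu_emissions = emissions.detach().cpu().contiguous()
    g_emissions.set_weights(cpu_emissions.data_ptr())
    # The composition I want -- but how do new_graph's weights get updated
    # by optimizer.step()?
    composed = gtn.intersect(gtn.intersect(g_emissions, new_graph), g_criterion)
    return gtn.negate(gtn.forward_score(composed)).item()
```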

awni commented 3 years ago

I think what you are looking for is something like what we do with the transition graph and "mirrored" torch transition_params in the Transducer.

The basic idea is that the transition graph is used to compute the loss function in GTN. Before we use it, we set its weights from the torch params. After the gradient on the transition graph is computed, we extract it and return it from the backward function as the gradient for transition_params, which is held by the optimizer and used during the step.

In summary:

  1. Make a torch tensor to hold the new_graph weights, e.g. new_graph_params.
  2. new_graph_params is held by the optimizer.
  3. Before calling forward, set the weights of new_graph from new_graph_params.
  4. When returning from backward, extract the gradient from new_graph and return it in a tensor as the gradient for new_graph_params.
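
A rough sketch of those four steps as a torch.autograd.Function (names like NewGraphLoss and new_graph_params are illustrative, not from the gtn codebase; it assumes new_graph was built with calc_grad=True and that weights can be set from a CPU tensor's data pointer):

```python
import gtn
import torch

class NewGraphLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, emissions, new_graph_params, new_graph, g_criterion):
        T, C = emissions.shape
        # Step 3: mirror the torch parameters into the GTN graph before using it.
        params_cpu = new_graph_params.detach().cpu().contiguous()
        new_graph.set_weights(params_cpu.data_ptr())

        g_emissions = gtn.linear_graph(T, C, emissions.requires_grad)
        emissions_cpu = emissions.detach().cpu().contiguous()
        g_emissions.set_weights(emissions_cpu.data_ptr())

        loss = gtn.negate(gtn.forward_score(
            gtn.intersect(gtn.intersect(g_emissions, new_graph), g_criterion)))
        ctx.graphs = (g_emissions, new_graph, loss)
        ctx.shape = (T, C)
        return torch.tensor(loss.item())

    @staticmethod
    def backward(ctx, grad_output):
        g_emissions, new_graph, loss = ctx.graphs
        T, C = ctx.shape
        gtn.backward(loss)
        # Step 4: extract the gradients from the GTN graphs and hand them back
        # to torch as the gradients for the mirrored tensors.
        emissions_grad = torch.from_numpy(
            g_emissions.grad().weights_to_numpy()).view(T, C)
        params_grad = torch.from_numpy(new_graph.grad().weights_to_numpy())
        return emissions_grad * grad_output, params_grad * grad_output, None, None
```

new_graph_params would then be registered as an nn.Parameter on the module so the optimizer holds and updates it (steps 1 and 2), while steps 3 and 4 happen inside the function on every call.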