ML4GW / aframev2

Detecting binary black hole mergers in LIGO with neural networks
MIT License

BNS Model #138

Closed wbenoit26 closed 4 months ago

wbenoit26 commented 4 months ago

Adds the infrastructure necessary to train using Q-transforms, with the SingleQTransform module from `ml4gw`.

Closes #125 Closes #139 Closes #140 Closes #142

wbenoit26 commented 4 months ago

@EthanMarx Curious to get your thoughts here. I don't love the way that q gets handled, and more generally, how any future augmentor args would have to be passed. I'm also thinking that it would be a good idea to make a place to put different configs for different scenarios, so that they can be easily swapped out without needing to change each parameter manually. We could have a set of base BBH configs and a set of base BNS configs that could be further modified for whatever experiment we're running.

EthanMarx commented 4 months ago

Looking at the configuration changes, I think the multiple-config idea is necessary now. Could you add a configs directory at the root of the train project? We can then have bns.yaml, bbh.yaml, and tune.yaml live there.
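As a rough illustration of the base-config-plus-overrides idea, here is a minimal sketch using plain dictionaries. The keys and values in `BNS_BASE` are hypothetical stand-ins for whatever a parsed bns.yaml would contain, and `merge_config` is an illustrative helper, not something in the repo.

```python
from copy import deepcopy


def merge_config(base: dict, overrides: dict) -> dict:
    """Return a copy of `base` with `overrides` applied recursively."""
    merged = deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key
            merged[key] = merge_config(merged[key], value)
        else:
            # Otherwise the override simply wins
            merged[key] = deepcopy(value)
    return merged


# Hypothetical contents of a parsed base BNS config (values are made up)
BNS_BASE = {
    "data": {"q": 45.6, "kernel_length": 1.5},
    "trainer": {"max_epochs": 200},
}

# An experiment only states what differs from the base scenario
quick_test = merge_config(BNS_BASE, {"trainer": {"max_epochs": 2}})
```

The base dictionary is left untouched, so the same base can be reused for several experiments.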

EthanMarx commented 4 months ago

TODOs from our call:

  1. Move qtransform module into the forward method of a pl.LightningModule so that it can be traced with the model at checkpoint time. Maybe generalize this to some augmentor = Optional[torch.nn.Module] and then it can be configured in the yaml config
  2. Remove q parameter from luigi pipeline. This will now live in the training config.yaml
  3. Export code no longer needs to infer input size from batch file, or include the qtransform in the BatchWhitener. I think there shouldn't be any changes to the export code.
  4. Test out training pipeline for BBH and BNS to ensure that the model tracing during checkpointing is completely general.

wbenoit26 commented 4 months ago

@EthanMarx I was able to incorporate the augmentor into the model. It's a little less neat than I was hoping, but it works for both BNS and BBH cases. One thing I'm not loving is that during training, the log message is

  | Name      | Type           | Params
---------------------------------------------
0 | augmentor | Identity       | 0
1 | model     | Sequential     | 7.2 M
2 | metric    | TimeSlideAUROC | 0
---------------------------------------------

rather than something more specific for the model. Not sure if there's a way to show the architecture there, rather than the model.

I'm realizing that this change kind of removes the need to have separate time and frequency domain data objects, so we may want to deprecate them in the future unless we find another need.

The only remaining issue is with exporting the Q-transform because of the FFTs. It looks like PyTorch has a beta dynamo_export that can do the conversion, but I need to play around with it to see if all the version changes will work.

EthanMarx commented 4 months ago

Ah yeah, for clarity's sake, it may be worth removing the torch.nn.Sequential and just calling the augmentor by hand. Lightning should then log the actual model architecture (I think).

Also, good point on the FrequencyDomainDataset. If it's not being used anywhere, I would say remove it, or open an issue to remove it down the line.

Is it not currently possible to export the FFTs? If it isn't, we should probably reconsider this setup.

wbenoit26 commented 4 months ago

If I call the augmentor by hand, then the trace won't include it, and it won't get saved with the rest of the model.

I can't export as-is, but ONNX supports FFTs; it's just a matter of whether the converter between PyTorch and ONNX includes them. It looks like it's possible, but it might require changing so much that it's not worth it. I'm going to see what it takes to get things working.

EthanMarx commented 4 months ago

I think as long as you call the augmentor in the forward method it should get traced
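A minimal sketch of that idea, using a plain torch.nn.Module rather than a pl.LightningModule to keep it self-contained. `TinyClassifier` and `Standardize` are hypothetical stand-ins for the real architecture and augmentor. The augmentor and the architecture are stored as separate attributes and the augmentor is called inside forward, so the trace captures both.

```python
from typing import Optional

import torch


class Standardize(torch.nn.Module):
    """Hypothetical augmentor: zero mean, unit variance along time."""

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return (x - mean) / std


class TinyClassifier(torch.nn.Module):
    """Hypothetical stand-in for the real architecture."""

    def __init__(self, num_ifos: int = 2):
        super().__init__()
        self.conv = torch.nn.Conv1d(num_ifos, 4, kernel_size=3, padding=1)
        self.head = torch.nn.Linear(4, 1)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.head(x.mean(dim=-1))


class AugmentedModel(torch.nn.Module):
    def __init__(self, model, augmentor: Optional[torch.nn.Module] = None):
        super().__init__()
        # Fall back to a no-op when no augmentor is configured
        self.augmentor = augmentor if augmentor is not None else torch.nn.Identity()
        self.model = model

    def forward(self, x):
        # Calling the augmentor here puts it inside the traced graph
        return self.model(self.augmentor(x))


torch.manual_seed(0)
wrapped = AugmentedModel(TinyClassifier(), Standardize()).eval()
sample = 3 * torch.randn(8, 2, 64) + 5
traced = torch.jit.trace(wrapped, sample)

# Top-level submodules keep their own class names instead of "Sequential"
children = [(name, type(m).__name__) for name, m in wrapped.named_children()]
```

Since the architecture is a direct attribute rather than being wrapped in a torch.nn.Sequential, a summary that lists top-level submodules by class name should report the architecture's own type.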

wbenoit26 commented 4 months ago

I tried that, and I ran into issues loading the model from checkpoint to do the trace. But now that I'm thinking about it again, I think I know what I was doing wrong. I'll re-attempt.

EthanMarx commented 4 months ago

Hoping this is as simple as replacing torch.onnx.export with torch.onnx.dynamo_export in hermes. It looks like dynamo_export is still in beta, but it might be stable enough to support our few use cases.

We can add a boolean dynamo flag to the TorchOnnx exporter that selects whichever one is desired.
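A rough sketch of what such a flag could look like. `pick_exporter` and `export_to_onnx` are hypothetical helpers, not the hermes API, and the sketch assumes a PyTorch version that provides torch.onnx.dynamo_export (the call and its return type may differ across releases).

```python
import torch


def pick_exporter(dynamo: bool) -> str:
    """Name of the torch.onnx function to use for a given flag value."""
    return "dynamo_export" if dynamo else "export"


def export_to_onnx(model, sample, path, dynamo: bool = False) -> None:
    # Look the exporter up at call time so that importing this module
    # does not fail on PyTorch versions that lack dynamo_export
    exporter = getattr(torch.onnx, pick_exporter(dynamo))
    if dynamo:
        # dynamo_export returns an object that is saved in a second step
        exporter(model, sample).save(path)
    else:
        # The TorchScript-based exporter writes the file directly
        exporter(model, (sample,), path)
```

The two exporters take their arguments differently, so the flag has to switch the call pattern and not just the function name.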

wbenoit26 commented 4 months ago

It also requires onnxscript, which, judging by the PyPI release history, is also very much in development. Again, might be stable enough for us.

It turns out that dynamo and jit are separate frameworks, so more changes would be necessary. It seems like dynamo will be what's used more going forward, so I think it'll be worth it to figure out. I'll make a separate PR for just those changes and then rebase this one.

wbenoit26 commented 4 months ago

@EthanMarx As we discussed, I've reverted to having the Q-transform explicitly in export. Including it as part of the model will be the right thing to do eventually, but that requires a possibly significant update to hermes, and that shouldn't hold up this PR.

I've removed the q parameter from the pipeline config and the tasks, so it just lives in the train and export config.yaml files. I think it's cleaner this way, even though there's nothing to enforce that the value in each project be the same.
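For illustration, a small check like the following could catch a mismatch between the two projects. It assumes both config.yaml files have already been parsed into flat dictionaries; `check_shared` and the example values are hypothetical.

```python
def check_shared(train_cfg: dict, export_cfg: dict, keys=("q",)) -> None:
    """Raise if any shared key differs between the two configs."""
    mismatched = {
        key: (train_cfg.get(key), export_cfg.get(key))
        for key in keys
        if train_cfg.get(key) != export_cfg.get(key)
    }
    if mismatched:
        raise ValueError(f"train/export configs disagree: {mismatched}")


# Made-up values: matching configs pass silently
check_shared({"q": 45.6, "max_epochs": 2}, {"q": 45.6})
```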

I've tested both the bbh.yaml and the bns.yaml config files (with max_epochs = 2), and both are able to get through training and export. Let me know if you have further thoughts on this PR.

EthanMarx commented 4 months ago

The only issue I see at the moment is that the Export Task in the luigi pipeline doesn't actually use the config.yaml in the export project. So unless I'm missing something I don't think the q value is being passed to that task.

wbenoit26 commented 4 months ago

Ah okay, I just assumed that was set up already. I'm not sure how it was working if None was being passed for q, but maybe some default value was being used (or maybe I missed something in testing). Was there a reason that the config.yaml wasn't being used yet, or was it just still on the to-do list?

EthanMarx commented 4 months ago

To use the config.yaml we would need to start calling the export cli directly from the luigi task. Currently, I just import and call the executable itself.

The main reason I did this was to avoid having to manage multiple configs that share variables, and to avoid too much confusion over which variables are being passed where. This works great when the train and export projects are entirely decoupled, but it seems that's no longer going to be the case.

I think there must be a way for the export config to import/share variables from the train config, in which case I think this problem could be solved. Let me look into this.

wbenoit26 commented 4 months ago

That would be great. I was wondering if it would ever make sense to combine the train and export projects, but being able to import the config file would be a much neater solution.

EthanMarx commented 4 months ago

Yeah, that's definitely an idea worth considering at this point.

wbenoit26 commented 4 months ago

@EthanMarx I went back to having q specified in the sandbox config, and I realized that there's an OptionalFloatParameter that makes things neater, though still not perfect. I also found that things worked a little more smoothly if I had the Q-transform be instantiated in the setup function of the dataset, so I switched to doing things that way.

Let me know what you think.

EthanMarx commented 4 months ago

Yeah that definitely makes sense for now. Looks good. Is this ready for a final once over?

wbenoit26 commented 4 months ago

Yeah, I'd say so.

wbenoit26 commented 4 months ago

@EthanMarx Looks like pushing additional commits dismissed your review. Can you re-approve?