Closed · wbenoit26 closed 4 months ago
@EthanMarx Curious to get your thoughts here. I don't love the way that `q` gets handled, and more generally, how any future augmentor args would have to be passed. I'm also thinking that it would be a good idea to make a place to put different configs for different scenarios, so that they can be easily swapped out without needing to change each parameter manually. We could have a set of base BBH configs and a set of base BNS configs that could be further modified for whatever experiment we're running.
Looking at the configuration changes, I think the multiple-config idea is necessary now. Could you add a `configs` directory at the root of the `train` project? We can then have `bns.yaml`, `bbh.yaml`, and `tune.yaml` live there.
TODOs from our call:

- Move the `qtransform` module into the `forward` method of a `pl.LightningModule` so that it can be traced with the model at checkpoint time. Maybe generalize this to some `augmentor = Optional[torch.nn.Module]`, and then it can be configured in the yaml config.
- Remove the `q` parameter from the `luigi` pipeline. This will now live in the training `config.yaml`.
- Remove the `qtransform` in the `BatchWhitener`. I think there shouldn't be any changes to the export code.

@EthanMarx I was able to incorporate the `augmentor` into the model. It's a little less neat than I was hoping, but it works for both the BNS and BBH cases. One thing I'm not loving is that during training, the log message is
```
  | Name      | Type           | Params
-----------------------------------------------
0 | augmentor | Identity       | 0
1 | model     | Sequential     | 7.2 M
2 | metric    | TimeSlideAUROC | 0
-----------------------------------------------
```
rather than something more specific for the model. Not sure if there's a way to show the actual architecture there rather than `Sequential`.
I'm realizing that this change kind of removes the need to have separate time and frequency domain data objects, so we may want to deprecate them in the future unless we find another need.
The only remaining issue is with exporting the Q-transform because of the FFTs. It looks like PyTorch has a beta `dynamo_export` that can do the conversion, but I need to play around with it to see if all the version changes will work.
Ah yeah, for clarity's sake, it may be worth removing the `torch.nn.Sequential` and just calling the augmentor by hand. Lightning should then log the actual model architecture (I think).
Also, good point on the `FrequencyDomainDataset`; if it's not being used anywhere, I would say remove it, or open an issue to remove it down the line.
Is it not currently possible to export the FFTs? If so, we should probably reconsider this setup.
If I call the augmentor by hand, then the trace won't include it, and it won't get saved with the rest of the model.
I can't export as-is, but ONNX supports FFTs, it's just a matter of whether the converter between PyTorch and ONNX has those included. It looks like it's possible, but might require changing so much that it's not worth it. I'm going to see what it takes to get things working.
I think as long as you call the `augmentor` in the forward method it should get traced.
I tried that, and I ran into issues loading the model from checkpoint to do the trace. But now that I'm thinking about it again, I think I know what I was doing wrong. I'll re-attempt.
Hoping this is as simple as replacing `torch.onnx.export` with `torch.onnx.dynamo_export` in `hermes`. It looks like `dynamo_export` is still in beta, but it might be stable enough to support our few use cases.
We can add a `dynamo` boolean flag to the `TorchOnnx` exporter that will select whichever one is desired.
It also requires onnxscript, which, judging by the PyPI release history, is also very much in development. Again, might be stable enough for us.
It turns out that `dynamo` and `jit` are separate frameworks, so more changes would be necessary. It seems like `dynamo` will be what's used more going forward, so I think it'll be worth it to figure out. I'll make a separate PR for just those changes and then rebase this one.
@EthanMarx As we discussed, I've reverted back to having the Q-transform explicitly in `export`. Including it as part of the model will be the right thing to do eventually, but that requires a possibly significant update to `hermes`, and that shouldn't hold up this PR.
I've removed the `q` parameter from the pipeline config and the tasks, so it just lives in the `train` and `export` `config.yaml` files. I think it's cleaner this way, even though there's nothing to enforce that the value in each project be the same.
I've tested both the `bbh.yaml` and the `bns.yaml` config files (with `max_epochs = 2`), and both are able to get through training and export. Let me know if you have further thoughts on this PR.
The only issue I see at the moment is that the `Export` task in the `luigi` pipeline doesn't actually use the `config.yaml` in the export project. So unless I'm missing something, I don't think the `q` value is being passed to that task.
Ah okay, I just assumed that was set up already. I'm not sure how it was working if `None` was being passed for `q`, but maybe some default value was being used (or maybe I missed something in testing). Was there a reason that the `config.yaml` wasn't being used yet, or was it just still on the to-do list?
To use the `config.yaml` we would need to start calling the export CLI directly from the luigi task. Currently, I just import and call the executable itself.
The main reason I did this was to avoid having to manage multiple configs that share variables, and to avoid too much confusion over which variables are being passed where. This works great when the `train` and `export` projects are entirely decoupled, but it seems that's no longer going to be the case.
I think there must be a way for the export config to import/share variables from the train config, in which case I think this problem could be solved. Let me look into this.
That would be great. I was wondering if it would ever make sense to combine the train and export projects, but being able to import the config file would be a much neater solution.
Yeah, that's definitely an idea worth considering at this point.
@EthanMarx I went back to having `q` specified in the sandbox config, and I realized that there's an `OptionalFloatParameter` that makes things neater, though still not perfect. I also found that things worked a little more smoothly if I had the Q-transform be instantiated in the `setup` function of the dataset, so I switched to doing things that way.
Let me know what you think.
Yeah, that definitely makes sense for now. Looks good. Is this ready for a final once-over?
Yeah, I'd say so.
@EthanMarx Looks like pushing additional commits dismissed your review, can you re-approve?
Adds the infrastructure necessary to do training using Q-transforms, with the `SingleQTransform` module from `ml4gw`:

- Uses the dataloader to infer model input size
- In `export`, model input size is inferred from the training batch saved during training
- Adds an `augmentor` to the `BatchWhitener`, which acts on the whitened data before the data is passed to the NN
- Makes `mass_combo` an argument of the function

Closes #125
Closes #139
Closes #140
Closes #142