ML4GW / aframe

Detecting binary black hole mergers in LIGO with neural networks
MIT License

Add check for psd_length >= window_length #430

Open · rafia17 opened this issue 11 months ago

rafia17 commented 11 months ago

window_length is calculated in train.py as the sum of kernel_length and fduration. psd_length, which is a parameter that we set in pyproject.toml, should always be greater than or equal to this window_length. If you change kernel_length or fduration such that window_length exceeds psd_length, we get the following error (this happened when kernel_length was changed from 1.5 s to 8 s):

Traceback (most recent call last):
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/bin/train", line 6, in <module>
    sys.exit(main())
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/lib/python3.9/site-packages/hermes/typeo/typeo.py", line 572, in wrapper
    result = subcommand(**subkw)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/libs/architectures/aframe/architectures/wrapper.py", line 31, in arch_fn
    return fn(**fn_kwargs)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/libs/trainer/aframe/trainer/trainer.py", line 151, in train
    X, y = next(iter(train_dataset))
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/projects/sandbox/train/train/augmentor.py", line 278, in __iter__
    yield self.fn(X[0].to(self.device))
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/projects/sandbox/train/train/augmentor.py", line 217, in forward
    X, psds = self.psd_estimator(X)
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/libs/architectures/aframe/architectures/preprocessor.py", line 59, in forward
    psds = self.spectral_density(background.double())
  File "/home/rafia.omer/miniconda3/envs/train-CPGBhxhY-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/ml4gw/ml4gw/transforms/spectral.py", line 86, in forward
    return fast_spectral_density(
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/ml4gw/ml4gw/spectral.py", line 156, in fast_spectral_density
    _validate_shapes(x, nperseg, y)
  File "/home/rafia.omer/ML4GW/aframe2/BBHNet/ml4gw/ml4gw/spectral.py", line 39, in _validate_shapes
    raise ValueError(
ValueError: Number of samples 32768 in input x is insufficient for number of fft samples 36864
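The two sample counts in this error can be reproduced with simple arithmetic under some assumptions that the issue does not state: a sample rate of 4096 Hz, fduration = 1 s, psd_length = 8 s, and an FFT segment length in the PSD estimator equal to window_length. The sketch below uses only these hypothetical values together with kernel_length = 8 s from the issue:

```python
# Hypothetical values: only KERNEL_LENGTH comes from the issue. The others are
# assumptions chosen because they reproduce the sample counts in the error.
SAMPLE_RATE = 4096   # Hz (assumption)
PSD_LENGTH = 8.0     # s  (assumption, implied by 32768 / 4096)
FDURATION = 1.0      # s  (assumption)
KERNEL_LENGTH = 8.0  # s  (from the issue; previously 1.5 s)

# window_length = kernel_length + fduration, as computed in train.py
window_length = KERNEL_LENGTH + FDURATION

# Convert both lengths from seconds to samples
psd_samples = int(PSD_LENGTH * SAMPLE_RATE)
fft_samples = int(window_length * SAMPLE_RATE)

print(psd_samples, fft_samples)   # 32768 36864
print(fft_samples > psd_samples)  # True: the FFT segment is longer than the PSD data
```

With the old kernel_length of 1.5 s, the same assumptions give a window of 2.5 s (10240 samples), which fits comfortably inside 32768 samples, consistent with the error only appearing after the change.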

We should add a check for this in train.py that raises an informative error message if psd_length is shorter than window_length.
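A minimal sketch of such a check, assuming it runs in train.py before the datasets and augmentor are built. validate_psd_length is a hypothetical helper, not an existing aframe function, and passing sample_rate into it is an assumption. Lengths are in seconds and the comparison is done in samples, mirroring the sample-count check in ml4gw that raised the error above and avoiding spurious float comparisons such as 1.1 + 0.2 != 1.3:

```python
def validate_psd_length(
    psd_length: float,
    kernel_length: float,
    fduration: float,
    sample_rate: float,
) -> float:
    """Hypothetical check: ensure psd_length covers window_length.

    All lengths are in seconds and sample_rate is in Hz. Returns
    window_length (kernel_length + fduration) in seconds if valid.
    """
    window_length = kernel_length + fduration

    # Compare in samples, rounding to absorb floating-point error
    psd_size = int(round(psd_length * sample_rate))
    window_size = int(round(window_length * sample_rate))

    if psd_size < window_size:
        raise ValueError(
            f"psd_length ({psd_length} s, {psd_size} samples) must be at "
            f"least window_length = kernel_length + fduration = "
            f"{kernel_length} + {fduration} = {window_length} s "
            f"({window_size} samples). Increase psd_length or reduce "
            "kernel_length and/or fduration."
        )
    return window_length


# Usage with hypothetical settings (sample rate and fduration are assumptions)
print(validate_psd_length(8, 1.5, 1, 4096))  # 2.5
try:
    validate_psd_length(8, 8, 1, 4096)
except ValueError as err:
    print(err)
```

Failing early in train.py with a message that names the three parameters involved is easier to act on than the shape error raised deep inside the PSD estimator.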