dingo-gw / dingo

Dingo: Deep inference for gravitational-wave observations
MIT License

Test loss not decreasing for some datasets #61

Closed: stephengreen closed this issue 2 years ago

stephengreen commented 2 years ago

When I train a network for a chirp-mass prior starting at 10 M☉, the test loss does not decrease properly: it stays at ~24 while the train loss decreases. The bad loss persists even when I evaluate the network in eval mode on training data, so I suspect it has something to do with the BatchNorm1d layers. Perhaps the low-Mc waveforms have large fluctuations, so the batch-norm running averages (which replace the per-batch statistics in eval mode) fail to represent them. However, it's hard to see why this would be the case.
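
The suspected mechanism can be illustrated with a toy model. The sketch below is plain Python, not dingo or PyTorch code. For a single feature, it mimics the running-statistics update documented for `torch.nn.BatchNorm1d`: an exponential moving average with the default `momentum=0.1`, a running variance updated with the unbiased batch variance, and `eps=1e-5`. It then compares train-mode normalization (batch statistics) with eval-mode normalization (running statistics). The heavy-tailed feature produced by the hypothetical `make_batch` is only a stand-in for low-Mc inputs.

```python
import math
import random


def mean_var(values, unbiased):
    """Mean and variance of a list of floats (unbiased=True divides by n - 1)."""
    n = len(values)
    mean = math.fsum(values) / n
    sum_sq = math.fsum((v - mean) ** 2 for v in values)
    return mean, sum_sq / (n - 1 if unbiased else n)


def normalize(batch, mean, var, eps=1e-5):
    """Batch-norm normalization of one feature, without the learned affine step."""
    return [(x - mean) / math.sqrt(var + eps) for x in batch]


def make_batch(rng, size, outlier_prob, outlier_scale=100.0):
    """Hypothetical feature: unit Gaussian values, a random fraction scaled up."""
    return [
        rng.gauss(0.0, 1.0) * (outlier_scale if rng.random() < outlier_prob else 1.0)
        for _ in range(size)
    ]


def train_eval_gap(outlier_prob, num_batches=100, batch_size=256, momentum=0.1, seed=0):
    """Largest |train-mode output - eval-mode output| on a fresh batch."""
    rng = random.Random(seed)
    running_mean, running_var = 0.0, 1.0  # initial running statistics
    for _ in range(num_batches):
        batch = make_batch(rng, batch_size, outlier_prob)
        batch_mean, batch_var = mean_var(batch, unbiased=True)
        # Exponential moving average of the batch statistics.
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
        running_var = (1 - momentum) * running_var + momentum * batch_var
    batch = make_batch(rng, batch_size, outlier_prob)
    batch_mean, batch_var = mean_var(batch, unbiased=False)
    train_out = normalize(batch, batch_mean, batch_var)  # train mode: batch statistics
    eval_out = normalize(batch, running_mean, running_var)  # eval mode: running statistics
    return max(abs(a - b) for a, b in zip(train_out, eval_out))
```

Averaged over a handful of seeds, `train_eval_gap(0.0)` stays small while `train_eval_gap(0.01)` is much larger. This happens because the batch variance of a heavy-tailed feature jumps from batch to batch, and a moving average cannot track it. To check whether this happens in the real network, one option is to evaluate the same training batch once after `model.train()` and once after `model.eval()`. Use a copy of the model for this, because in train mode BatchNorm updates its running statistics even under `torch.no_grad()`.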

settings.yaml:
# settings for domain of waveforms
domain:
  type: FrequencyDomain
  f_min: 20.0
  f_max: 1024.0
  delta_f: 0.125  # Expressions like 1.0/8.0 would require eval and are not supported
  window_factor: 1.0 # This should maybe be in noise settings? It is not used for the generation of waveform polarizations

# settings for waveform generator
waveform_generator:
  approximant: IMRPhenomXPHM  # SEOBNRv4PHM
  f_ref: 20.0  # Hz

# settings for intrinsic prior over parameters
intrinsic_prior:
  # prior for non-fixed parameters
  mass_1: bilby.core.prior.Constraint(minimum=10.0, maximum=80.0)
  mass_2: bilby.core.prior.Constraint(minimum=10.0, maximum=80.0)
  mass_ratio: bilby.core.prior.Uniform(minimum=0.125, maximum=1.0)
  chirp_mass: bilby.core.prior.Uniform(minimum=10.0, maximum=100.0)
  phase: default
  a_1: bilby.core.prior.Uniform(minimum=0.0, maximum=0.88)
  a_2: bilby.core.prior.Uniform(minimum=0.0, maximum=0.88)
  tilt_1: default
  tilt_2: default
  phi_12: default
  phi_jl: default
  theta_jn: default
  # reference values for fixed (extrinsic) parameters
  luminosity_distance: 100.0 # Mpc
  geocent_time: 0.0 # s

# Number of samples in the dataset
num_samples: 5000000

# Save a compressed representation of the dataset
compression:
  svd:
    num_training_samples: 50000
    # Truncate the SVD basis at this size. No truncation if zero.
    size: 200
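
The `domain` block defines a uniform frequency grid. The following is a minimal sketch that assumes, without confirmation here, that the grid starts at 0 Hz with spacing `delta_f` and that bins below `f_min` are masked rather than dropped. `frequency_grid` is a hypothetical helper. The spacing also fixes the segment length T = 1/delta_f, here 8 s, which should agree with `T: 8.0` in the window settings of train_settings.yaml.

```python
def frequency_grid(f_min, f_max, delta_f):
    """Uniform grid 0, delta_f, ..., f_max and a mask that is True where f >= f_min."""
    num_bins = int(round(f_max / delta_f)) + 1
    frequencies = [i * delta_f for i in range(num_bins)]
    in_band = [f >= f_min for f in frequencies]  # bins kept for inference
    return frequencies, in_band
```

With the settings above, this gives 8193 bins, of which 8033 lie in the 20-1024 Hz band.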
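
To see how the `Constraint` priors on `mass_1` and `mass_2` interact with the lower chirp-mass bound, one can map (chirp mass Mc, mass ratio q = m2/m1 ≤ 1) to component masses. The standard relations are m1 = Mc (1 + q)^(1/5) / q^(3/5) and m2 = q m1. The sketch assumes that samples violating either constraint are rejected. The helper names are hypothetical, and the bounds are taken from settings.yaml.

```python
def component_masses(chirp_mass, mass_ratio):
    """Map (chirp mass Mc, mass ratio q = m2/m1 <= 1) to (m1, m2), same units as Mc."""
    mass_1 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio ** 0.6
    return mass_1, mass_ratio * mass_1


def chirp_mass_from_components(mass_1, mass_2):
    return (mass_1 * mass_2) ** 0.6 / (mass_1 + mass_2) ** 0.2


def passes_constraints(chirp_mass, mass_ratio, m_min=10.0, m_max=80.0):
    """True if both component masses lie in [m_min, m_max]; m2 <= m1 since q <= 1."""
    mass_1, mass_2 = component_masses(chirp_mass, mass_ratio)
    return m_min <= mass_2 and mass_1 <= m_max


def accepted_fraction(chirp_mass, q_min=0.125, num=10_000):
    """Fraction of a uniform q-grid on [q_min, 1] that survives the constraints at fixed Mc."""
    qs = [q_min + (1.0 - q_min) * (i + 0.5) / num for i in range(num)]
    return sum(passes_constraints(chirp_mass, q) for q in qs) / num
```

For this prior, `accepted_fraction(10.0)` is about 0.28, because only q ≳ 0.75 gives m2 ≥ 10 at Mc = 10. At Mc = 20 the fraction rises to about 0.96. The low-Mc end of the training set is therefore confined to a narrow, nearly equal-mass region.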

train_settings.yaml:
# Settings for data generation
data:
  waveform_dataset_path: ./waveform_dataset.hdf5  # Contains intrinsic waveforms
  train_fraction: 0.95
  # data conditioning for inference
  window:
    type: tukey
    f_s: 4096
    T: 8.0
    roll_off: 0.4
  detectors:
    - H1
    - L1
  extrinsic_prior:
    dec: default
    ra: default
    geocent_time: bilby.core.prior.Uniform(minimum=-0.10, maximum=0.10)
    psi: default
    luminosity_distance: bilby.core.prior.Uniform(minimum=100.0, maximum=2000.0)
  ref_time: 1126259462.391
  gnpe_time_shifts:
    kernel_kwargs: {type: uniform, low: -0.001, high: 0.001}
    exact_equiv: True
  selected_parameters: default # [chirp_mass, mass_ratio,  luminosity_distance, dec]

# Model architecture
model:
  type: nsf+embedding
  # kwargs for neural spline flow
  nsf_kwargs:
    num_flow_steps: 30
    base_transform_kwargs:
      hidden_dim: 512
      num_transform_blocks: 5
      activation: elu
      dropout_probability: 0.0
      batch_norm: True
      num_bins: 8
      base_transform_type: rq-coupling
  # kwargs for embedding net
  embedding_net_kwargs:
    output_dim: 128
    hidden_dims: [1024, 1024, 1024, 1024, 1024, 1024,
                  512, 512, 512, 512, 512, 512,
                  256, 256, 256, 256, 256, 256,
                  128, 128, 128, 128, 128, 128]
    activation: elu
    dropout: 0.0
    batch_norm: True
    svd:
      num_training_samples: 50000
      num_validation_samples: 5000
      size: 200

# Training is divided into stages. Each stage requires all of the settings shown below.
training:
  stage_0:
    epochs: 300
    asd_dataset_path: ../asds_O3_fiducial.hdf5 #/home/jonas/Desktop/dingo-devel/tutorials/02_gwpe/datasets/ASDs/asds_O2.hdf5
    freeze_rb_layer: True
    optimizer:
      type: adam
      lr: 0.0002
    scheduler:
      type: cosine
      T_max: 300
    batch_size: 8192

  stage_1:
    epochs: 150
    asd_dataset_path: ../asds_O3.hdf5 #/home/jonas/Desktop/dingo-devel/tutorials/02_gwpe/datasets/ASDs/asds_O2.hdf5
    freeze_rb_layer: False
    optimizer:
      type: adam
      lr: 0.00002
    scheduler:
      type: cosine
      T_max: 150
    batch_size: 8192

# Local settings for training that have no impact on the final trained network.
local:
  device: cuda
  num_workers: 32 # num_workers >0 does not work on Mac, see https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206
  runtime_limits:
    max_time_per_run: 3600000
    max_epochs_per_run: 500
  checkpoint_epochs: 10
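
The `window` block describes a Tukey window over T · f_s = 32768 samples. If `roll_off` is the taper length in seconds on each side, the Tukey shape parameter is alpha = 2 · roll_off / T = 0.1. This is the convention bilby uses, and it is only assumed here for dingo. The sketch below builds the window with the standard tapered-cosine formula, which should be close to what `scipy.signal.windows.tukey(N, alpha)` returns. It then computes the mean of w², a quantity commonly used as the window factor in noise normalization. Whether dingo's `window_factor` uses exactly this definition is an assumption.

```python
import math


def tukey_window(num_samples, alpha):
    """Tukey (tapered cosine) window; alpha is the total tapered fraction."""
    window = []
    for n in range(num_samples):
        x = n / (num_samples - 1)
        if x < alpha / 2:  # rising cosine taper
            w = 0.5 * (1 + math.cos(math.pi * (2 * x / alpha - 1)))
        elif x <= 1 - alpha / 2:  # flat top
            w = 1.0
        else:  # falling cosine taper
            w = 0.5 * (1 + math.cos(math.pi * (2 * x / alpha - 2 / alpha + 1)))
        window.append(w)
    return window


def window_from_settings(f_s, T, roll_off):
    """Assumed mapping from (f_s, T, roll_off) to a sampled Tukey window."""
    num_samples = int(round(f_s * T))
    alpha = 2 * roll_off / T  # taper on both sides, as a fraction of T
    return tukey_window(num_samples, alpha), alpha


window, alpha = window_from_settings(f_s=4096, T=8.0, roll_off=0.4)
mean_sq = sum(w * w for w in window) / len(window)
```

For these settings, the mean of w² comes out close to 1 - 5·alpha/8 = 0.9375, the continuum value for a Tukey window.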
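
The extrinsic prior samples `luminosity_distance` between 100 and 2000 Mpc, relative to the fixed reference value of 100 Mpc in settings.yaml. It also samples a geocenter time offset of up to ±0.1 s. In the frequency domain, these act as an overall 1/d amplitude rescaling and a phase factor exp(-2πi f Δt), under the e^(-2πift) Fourier convention. The sketch below shows only those two operations, using a hypothetical helper on plain lists. A real pipeline also projects onto each detector with antenna patterns and per-detector time delays, and these steps are omitted here.

```python
import cmath
import math


def apply_distance_and_time_shift(h, frequencies, distance, time_shift, ref_distance=100.0):
    """Rescale a frequency-domain strain to a new distance and delay it in time.

    Amplitude scales as 1/distance. With H(f) = integral h(t) exp(-2 pi i f t) dt,
    a delay by time_shift multiplies H(f) by exp(-2 pi i f time_shift).
    """
    scale = ref_distance / distance
    return [
        scale * hf * cmath.exp(-2j * math.pi * f * time_shift)
        for hf, f in zip(h, frequencies)
    ]
```

Because the shift is a pure phase, applying the opposite shift at the reference distance recovers the original strain exactly, up to rounding.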
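
Both training stages use a cosine schedule with `T_max` equal to the number of epochs. Assume that `type: cosine` maps to PyTorch's `CosineAnnealingLR`, stepped once per epoch with `eta_min = 0`. The learning rate then follows the closed form in its documentation, eta_t = eta_min + (base_lr - eta_min)(1 + cos(pi t / T_max)) / 2. The sketch also counts optimizer steps per epoch, assuming the last partial batch is kept. Both helpers are hypothetical.

```python
import math


def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


def steps_per_epoch(num_samples, train_fraction, batch_size):
    """Optimizer steps per epoch, keeping the final partial batch."""
    num_train = round(num_samples * train_fraction)
    return math.ceil(num_train / batch_size)
```

For stage_0 this gives 2e-4 at epoch 0, 1e-4 at epoch 150 and 0 at epoch 300. With 5,000,000 samples, `train_fraction: 0.95` and `batch_size: 8192`, each epoch has 580 steps.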

stephengreen commented 2 years ago

Resolved by #73