thoglu / jammy_flows

A package to describe amortized (conditional) normalizing-flow PDFs defined jointly on tensor products of manifolds with coverage control. The connection between different manifolds is fixed via an autoregressive structure.
MIT License

lambert projection does not work #3

Closed: chrhck closed this issue 2 years ago

chrhck commented 2 years ago
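import numpy as np
import matplotlib.pyplot as plt
import jammy_flows
from jammy_flows import helper_fns

pdf = jammy_flows.pdf("e1+s2", "gggg+n", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="128-128")

fig = plt.figure(figsize=(8, 6))
# labels: 4-dimensional conditional input, loaded with the reporter's own reader (see traceback)
helper_fns.visualize_pdf(
    pdf.to("cpu"),
    fig,
    nsamples=10000,
    conditional_input=labels,
    bounds=[[-50, 150], [0, np.pi], [0, 2*np.pi]],
    s2_norm="lambert"
    );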

RuntimeError                              Traceback (most recent call last)
Cell In [20], line 4
      1 data_df, labels = read_photon_table_hdf_unweighted("../assets/photon_table.hd5", 80)
      3 fig=plt.figure(figsize=(8,6))
----> 4 helper_fns.visualize_pdf(
      5     pdf.to("cpu"),
      6     fig,
      7     nsamples=10000,
      8     conditional_input=labels,
      9     bounds=[[-50, 150], [0, np.pi], [0, 2*np.pi]],
     10     s2_norm="lambert"
     11     )

File ~/.local/lib/python3.10/site-packages/jammy_flows/helper_fns.py:1012, in visualize_pdf(pdf, fig, gridspec, subgridspec, conditional_input, nsamples, total_pdf_eval_pts, bounds, true_values, plot_only_contours, contour_probs, contour_color, autoscale, seed, skip_plotting_density, hide_labels, s2_norm, colormap, s2_rotate_to_true_value, s2_show_gridlines, skip_plotting_samples, var_names)
   1005   samples, samples_base, evals, evals_base = pdf.sample(
   1006       samplesize=nsamples,
   1007       conditional_input=sample_conditional_input,
   1008       seed=seed)
   1010   higher_dim_spheres = False
-> 1012   new_subgridspec, total_pdf_integral = plot_joint_pdf(
   1013       pdf,
   1014       fig,
   1015       gridspec,
   1016       samples,
...
---> 77 input_cumwidths = cumwidths.gather(-1, bin_idx)#[..., 0]
     81 input_bin_widths = widths.gather(-1, bin_idx)#[..., 0]
     83 input_cumheights = cumheights.gather(-1, bin_idx)#[..., 0]

RuntimeError: index -1 is out of bounds for dimension 1 with size 6
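The last frames of this traceback fail in a spline bin lookup (`cumwidths.gather(-1, bin_idx)`). A minimal NumPy sketch of how that can happen, assuming the bin index is found with a searchsorted-style search over cumulative bin edges; `find_bin` and `gather_checked` are hypothetical names, not jammy_flows internals. An input outside the spline's range (plausibly a coordinate pushed past the valid range by the bounds in the report) gets bin index -1, and a lookup that rejects negative indices, as `torch.gather` does, then fails.

```python
import numpy as np

def find_bin(knots, x):
    """Index of the bin containing each x, for sorted bin edges `knots` (n_bins + 1 values).

    Values below knots[0] get -1 and values at or above knots[-1] get n_bins,
    so both fall outside the valid index range 0 .. n_bins - 1.
    """
    return np.searchsorted(knots, x, side="right") - 1

def gather_checked(values, idx):
    """Per-bin lookup that rejects out-of-range indices, like torch.gather.

    Plain NumPy indexing would silently wrap -1 around to the last bin instead.
    """
    bad = (idx < 0) | (idx >= len(values))
    if bad.any():
        raise IndexError(f"index {idx[bad][0]} is out of bounds for size {len(values)}")
    return values[idx]

# Six bins on [-2, 2] (the error above also reports a size of 6).
knots = np.array([-2.0, -1.5, -1.0, 0.0, 1.0, 1.5, 2.0])
widths = np.diff(knots)

inside = np.array([-1.75, 0.5, 1.9])
outside = np.array([-2.5, 3.0])  # values beyond the spline's range

print(find_bin(knots, inside).tolist())   # [0, 3, 5]
print(find_bin(knots, outside).tolist())  # [-1, 6]
print(gather_checked(widths, find_bin(knots, inside)).tolist())  # [0.5, 1.0, 0.5]

try:
    gather_checked(widths, find_bin(knots, outside))
except IndexError as err:
    print(err)  # index -1 is out of bounds for size 6
```

Real spline implementations usually clamp or mask such inputs before the lookup; the point here is only that an out-of-range input turns into an invalid bin index rather than a clear "value out of bounds" message.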
thoglu commented 2 years ago

The Lambert projection maps the sphere to new coordinates, both of which lie in [-2, 2]. So instead of [0, pi] and [0, 2*pi], you should enter [-2, 2], [-2, 2] as the bounds for the two sphere coordinates, and then it should work. Right now the defined bounds go beyond those values. It should also work when you do not specify bounds at all. Did you check what happens when you leave out the bounds keyword?
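Why both Lambert coordinates stay within [-2, 2] can be checked with a short NumPy sketch. It assumes the standard Lambert azimuthal equal-area projection of the unit sphere centered on the pole theta = 0; jammy_flows may center it elsewhere (for example with `s2_rotate_to_true_value`), but for any center the image of the unit sphere is a disk of radius 2, so the range is the same. `lambert_project` and `lambert_bounds` are hypothetical helpers for illustration, not library functions.

```python
import numpy as np

def lambert_project(theta, phi):
    """Lambert azimuthal equal-area projection of the unit sphere, centered on theta = 0.

    theta: zenith angle in [0, pi], phi: azimuth in [0, 2*pi].
    The projected radius is 2*sin(theta/2), which never exceeds 2.
    """
    r = 2.0 * np.sin(theta / 2.0)
    return r * np.cos(phi), r * np.sin(phi)

def lambert_bounds(e1_bounds):
    """Hypothetical helper: plot bounds for an "e1+s2" PDF shown with s2_norm="lambert".

    The Euclidean dimension keeps its own range; both S2 coordinates use [-2, 2].
    """
    return [list(e1_bounds), [-2.0, 2.0], [-2.0, 2.0]]

# Project a dense (theta, phi) grid covering the whole sphere.
theta, phi = np.meshgrid(np.linspace(0.0, np.pi, 181), np.linspace(0.0, 2 * np.pi, 361))
x, y = lambert_project(theta, phi)
print(np.abs(x).max(), np.abs(y).max())  # both at most 2

# [0, pi] and [0, 2*pi] are (zenith, azimuth) ranges; in Lambert coordinates use:
print(lambert_bounds([-50, 150]))
```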

There is also a script, plot_moving_lambert_projection.py, in the examples folder, which shows both a Lambert and a non-Lambert visualization. You should be able to use similar settings.

chrhck commented 2 years ago
pdf = jammy_flows.pdf("e1+s2", "ggg+v", hidden_mlp_dims_sub_pdfs="128")
fig=plt.figure(figsize=(8,6))
helper_fns.visualize_pdf(
    pdf.to("cpu"),
    fig,
    nsamples=10000,

    #bounds=[[-50, 150], [0, np.pi], [0, 2*np.pi]],
    s2_norm="lambert"
    );

gives:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [9], line 3
      1 pdf = jammy_flows.pdf("e1+s2", "ggg+v", hidden_mlp_dims_sub_pdfs="128")
      2 fig=plt.figure(figsize=(8,6))
----> 3 helper_fns.visualize_pdf(
      4     pdf.to("cpu"),
      5     fig,
      6     nsamples=10000,
      7   
      8     #bounds=[[-50, 150], [0, np.pi], [0, 2*np.pi]],
      9     s2_norm="lambert",
     10     s2_rotate_to_true_value=True
     11     )

File ~/.local/lib/python3.10/site-packages/jammy_flows/helper_fns.py:1012, in visualize_pdf(pdf, fig, gridspec, subgridspec, conditional_input, nsamples, total_pdf_eval_pts, bounds, true_values, plot_only_contours, contour_probs, contour_color, autoscale, seed, skip_plotting_density, hide_labels, s2_norm, colormap, s2_rotate_to_true_value, s2_show_gridlines, skip_plotting_samples, var_names)
   1005   samples, samples_base, evals, evals_base = pdf.sample(
   1006       samplesize=nsamples,
   1007       conditional_input=sample_conditional_input,
   1008       seed=seed)
   1010   higher_dim_spheres = False
-> 1012   new_subgridspec, total_pdf_integral = plot_joint_pdf(
   1013       pdf,
   1014       fig,
   1015       gridspec,
   1016       samples,
   1017       subgridspec=subgridspec,
   1018       conditional_input=conditional_input,
   1019       bounds=bounds,
   1020       multiplot=False,
   1021       total_pdf_eval_pts=total_pdf_eval_pts,
   1022       true_values=true_values,
   1023       plot_only_contours=plot_only_contours,
   1024       contour_probs=contour_probs,
   1025       contour_color=contour_color,
   1026       autoscale=autoscale,
   1027       skip_plotting_density=skip_plotting_density,
   1028       hide_labels=hide_labels,
   1029       s2_norm=s2_norm,
   1030       colormap=colormap,
   1031       s2_rotate_to_true_value=s2_rotate_to_true_value,
   1032       s2_show_gridlines=s2_show_gridlines,
   1033       skip_plotting_samples=skip_plotting_samples,
   1034       var_names=var_names)
   1037 return samples, new_subgridspec, total_pdf_integral

File ~/.local/lib/python3.10/site-packages/jammy_flows/helper_fns.py:639, in plot_joint_pdf(pdf, fig, gridspec, samples, subgridspec, conditional_input, bounds, multiplot, total_pdf_eval_pts, true_values, plot_only_contours, contour_probs, contour_color, autoscale, skip_plotting_density, hide_labels, s2_norm, colormap, s2_rotate_to_true_value, s2_show_gridlines, skip_plotting_samples, var_names)
    636 if (pdf_conditional_input is not None):
    637     pdf_conditional_input = pdf_conditional_input[0:1]
--> 639 evalpositions, log_evals, bin_volumes, sin_zen_mask, unreliable_spherical_regions= get_pdf_on_grid(
    640     pure_float_mms,
    641     pts_per_dim,
    642     pdf,
    643     conditional_input=pdf_conditional_input,
    644     s2_norm=s2_norm,
    645     s2_rotate_to_true_value=s2_rotate_to_true_value,
    646     true_values=true_values)
    649 total_pdf_integral=numpy.exp(log_evals).sum()*bin_volumes
    651 if (dim == 1):

File ~/.local/lib/python3.10/site-packages/jammy_flows/helper_fns.py:173, in get_pdf_on_grid(mins_maxs, npts, model, conditional_input, s2_norm, s2_rotate_to_true_value, true_values)
    169     cinput = conditional_input.repeat(npts**len(mins_maxs), 1)[mask_inner]
    171 ## require intrinsic coordinates
--> 173 log_res, _, _ = model(eval_positions[mask_inner], conditional_input=cinput, force_intrinsic_coordinates=True)
    176 ## update s2+lambert visualizations by adding sin(theta) factors to get proper normalization
    177 for ind, pdf_def in enumerate(model.pdf_defs_list):

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/jammy_flows/flows.py:975, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
    968 tot_log_det = torch.zeros(x.shape[0]).type_as(x)
    970 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
    972 log_pdf = torch.distributions.MultivariateNormal(
    973     torch.zeros_like(base_pos).to(x),
    974     covariance_matrix=torch.eye(self.total_base_dim).type_as(x).to(x),
--> 975 ).log_prob(base_pos)
    978 return log_pdf + tot_log_det, log_pdf, base_pos

File ~/.local/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:210, in MultivariateNormal.log_prob(self, value)
    208 def log_prob(self, value):
    209     if self._validate_args:
--> 210         self._validate_sample(value)
    211     diff = value - self.loc
    212     M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)

File ~/.local/lib/python3.10/site-packages/torch/distributions/distribution.py:293, in Distribution._validate_sample(self, value)
    291 valid = support.check(value)
    292 if not valid.all():
--> 293     raise ValueError(
    294         "Expected value argument "
    295         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    296         f"to be within the support ({repr(support)}) "
    297         f"of the distribution {repr(self)}, "
    298         f"but found invalid values:\n{value}"
    299     )

ValueError: Expected value argument (Tensor of shape (6405, 3)) to be within the support (IndependentConstraint(Real(), 1)) of the distribution MultivariateNormal(loc: torch.Size([6405, 3]), covariance_matrix: torch.Size([6405, 3, 3])), but found invalid values:
tensor([[ 3.8706, -1.0163, -2.5019],
        [ 3.8706, -0.6430, -2.0699],
        [ 3.8706, -0.3868, -1.9223],
        ...,
        [-3.4793,  0.4713,  1.8465],
        [-3.4793,  0.7190,  1.9667],
        [-3.4793,  1.1232,  2.3108]], dtype=torch.float64)
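The `Real()` support named in this `ValueError` rejects only NaN entries (torch's real constraint checks `value == value`), so some rows of the evaluated tensor are presumably NaN; the printout is a truncated excerpt. As an illustration of one way NaNs can appear on a Lambert grid (an assumption; the thread does not confirm this is the cause here): the square [-2, 2] x [-2, 2] contains corner points outside the projection disk of radius 2, which have no preimage on the sphere. The `mask_inner` in `get_pdf_on_grid` suggests the library masks such points. The NumPy sketch uses a hypothetical `inverse_lambert` for a pole-centered projection.

```python
import numpy as np

def inverse_lambert(x, y):
    """Map Lambert coordinates back to (theta, phi) for a projection centered on theta = 0.

    Only points with x**2 + y**2 <= 4 have a preimage on the sphere; elsewhere
    arcsin(r/2) is undefined and NumPy returns NaN.
    """
    r = np.hypot(x, y)
    with np.errstate(invalid="ignore"):  # silence the warning for r > 2
        theta = 2.0 * np.arcsin(r / 2.0)
    phi = np.mod(np.arctan2(y, x), 2 * np.pi)
    return theta, phi

# Regular grid over the square [-2, 2] x [-2, 2], as used for a "lambert" plot.
n = 81
xs, ys = np.meshgrid(np.linspace(-2, 2, n), np.linspace(-2, 2, n))
theta, phi = inverse_lambert(xs.ravel(), ys.ravel())
points = np.stack([theta, phi], axis=1)

# Same test as a Real() support check: a row is valid only if value == value everywhere.
has_nan = ~(points == points).all(axis=1)

# Keep only grid points inside the projection disk before evaluating a PDF.
inside_disk = np.hypot(xs.ravel(), ys.ravel()) <= 2.0
print("NaN rows:", has_nan.sum(), "inside disk:", inside_disk.sum(), "total:", n * n)
```

The NaN rows and the points outside the disk coincide exactly here, which is why masking with `inside_disk` before evaluation avoids the support error in this toy setting.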
thoglu commented 2 years ago

Can you try to run the script plot_moving_lambert_projection.py in the examples folder and see if you get the same error?

chrhck commented 2 years ago

❯ python plot_moving_lambert_projection.py
Traceback (most recent call last):
  File "/home/chrhck/repos/jammy_flows/examples/plot_moving_lambert_projection.py", line 58, in <module>
    test_pdf=jammy_flows.pdf("s2", args.layer_def, options_overwrite=extra_flow_defs)
  File "/home/chrhck/repos/jammy_flows/jammy_flows/flows.py", line 116, in __init__
    self.read_model_definition(pdf_defs, 
  File "/home/chrhck/repos/jammy_flows/jammy_flows/flows.py", line 200, in read_model_definition
    flow_options.check_flow_option(flow_abbrv, detail_opt, options_overwrite[k][detail_abbrv][detail_opt])
  File "/home/chrhck/repos/jammy_flows/jammy_flows/flow_options.py", line 230, in check_flow_option
    assert(opt_name in opts_dict[flow_abbrevation]["kwargs"].keys()), ("option name %s not found in defined options for flow %s" % (opt_name, flow_abbrevation))
AssertionError: option name use_extra_householder not found in defined options for flow n
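This `AssertionError` comes from a name check on the entries of `options_overwrite`: an option the flow no longer defines is rejected outright. A minimal sketch of that kind of check, plus a helper that sets stale entries aside, is shown here; `OPTS_DICT`, its option names, `check_option_name` and `drop_unknown_options` are all hypothetical and not jammy_flows API (the real option names for the "n" flow are not listed in this thread).

```python
# Hypothetical option table: flow abbreviation -> allowed keyword options.
OPTS_DICT = {
    "n": {"kwargs": {"option_a": 1, "option_b": True}},
}

def check_option_name(flow_abbreviation, opt_name):
    """Reject option names the flow does not define (same idea as the assert in the traceback)."""
    allowed = OPTS_DICT[flow_abbreviation]["kwargs"]
    assert opt_name in allowed, (
        "option name %s not found in defined options for flow %s"
        % (opt_name, flow_abbreviation)
    )

def drop_unknown_options(flow_abbreviation, overrides):
    """Split user overrides into options the flow knows and stale ones."""
    allowed = OPTS_DICT[flow_abbreviation]["kwargs"]
    known = {k: v for k, v in overrides.items() if k in allowed}
    stale = sorted(k for k in overrides if k not in allowed)
    return known, stale

# An outdated override, as in the example script, next to a valid one.
overrides = {"use_extra_householder": True, "option_a": 3}
known, stale = drop_unknown_options("n", overrides)
for name in known:
    check_option_name("n", name)  # passes: only defined options remain
print(known)  # {'option_a': 3}
print(stale)  # ['use_extra_householder']
```

Silently dropping stale options hides configuration drift, so reporting `stale` is arguably better than discarding it; updating the script, as was done in this thread, remains the actual fix.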
thoglu commented 2 years ago

There were some outdated options in the script. It should work now; can you try again?