uber / causalml

Uplift modeling and causal inference with machine learning algorithms
Other
5.07k stars 779 forks source link

When will Dragonnet and CEVAE be available? #559

Closed: Leo-T-Zang closed this issue 4 months ago

Leo-T-Zang commented 2 years ago

I found that your documentation lists these two methods, but they are not available right now.

Thank you.

jeongyoonlee commented 2 years ago

Both Dragonnet and CEVAE are available. You can check out the example notebooks as follows:

Note that Dragonnet requires tensorflow>=2.4 as an additional dependency.

Leo-T-Zang commented 2 years ago

Hi @jeongyoonlee

Thank you so much!

But when I try to use CEVAE, I get an error that I could not solve. Could you please take a look?

ValueError                                Traceback (most recent call last)
File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/miniconda3/lib/python3.8/site-packages/pyro/nn/module.py:427, in PyroModule.__call__(self, *args, **kwargs)
    426 with self._pyro_context:
--> 427     return super().__call__(*args, **kwargs)

File ~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used

File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:404, in Guide.forward(self, x, t, y, size)
    400 with pyro.plate("data", size, subsample=x):
    401     # The t and y sites are needed for prediction, and participate in
    402     # the auxiliary CEVAE loss. We mark them auxiliary to indicate they
    403     # do not correspond to latent variables during training.
--> 404     t = pyro.sample("t", self.t_dist(x), obs=t, infer={"is_auxiliary": True})
    405     y = pyro.sample("y", self.y_dist(t, x), obs=y, infer={"is_auxiliary": True})

File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:411, in Guide.t_dist(self, x)
    410 (logits,) = self.t_nn(x)
--> 411 return dist.Bernoulli(logits=logits)

File ~/miniconda3/lib/python3.8/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
     23         return result
---> 24 return super().__call__(*args, **kwargs)

File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/bernoulli.py:49, in Bernoulli.__init__(self, probs, logits, validate_args)
     48     batch_shape = self._param.size()
---> 49 super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)

File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     54         if not valid.all():
---> 55             raise ValueError(
     56                 f"Expected parameter {param} "
     57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                 f"of distribution {repr(self)} "
     59                 f"to satisfy the constraint {repr(constraint)}, "
     60                 f"but found invalid values:\n{value}"
     61             )
     62 super(Distribution, self).__init__()

ValueError: Expected parameter logits (Tensor of shape (32,)) of distribution Bernoulli(logits: torch.Size([32])) to satisfy the constraint Real(), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan], grad_fn=<ClampBackward1>)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Input In [24], in <cell line: 10>()
      1 cevae = CEVAE(outcome_dist=outcome_dist,
      2               latent_dim=latent_dim,
      3               hidden_dim=hidden_dim,
   (...)
      7               learning_rate_decay=learning_rate_decay,
      8               num_layers=num_layers)
      9 # fit
---> 10 losses = cevae.fit(X=torch.tensor(x, dtype=torch.float),
     11                    treatment=torch.tensor(treatment, dtype=torch.float),
     12                    y=torch.tensor(y, dtype=torch.float))
     13 # predict
     14 ite = cevae.predict(x)

File ~/miniconda3/lib/python3.8/site-packages/causalml/inference/nn/cevae.py:100, in CEVAE.fit(self, X, treatment, y, p)
     90 X, treatment, y = convert_pd_to_np(X, treatment, y)
     92 self.cevae = CEVAEModel(
     93     outcome_dist=self.outcome_dist,
     94     feature_dim=X.shape[-1],
   (...)
     97     num_layers=self.num_layers,
     98 )
--> 100 self.cevae.fit(
    101     x=torch.tensor(X, dtype=torch.float),
    102     t=torch.tensor(treatment, dtype=torch.float),
    103     y=torch.tensor(y, dtype=torch.float),
    104     num_epochs=self.num_epochs,
    105     batch_size=self.batch_size,
    106     learning_rate=self.learning_rate,
    107     learning_rate_decay=self.learning_rate_decay,
    108     weight_decay=self.weight_decay,
    109 )

File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:592, in CEVAE.fit(self, x, t, y, num_epochs, batch_size, learning_rate, learning_rate_decay, weight_decay, log_every)
    590 for x, t, y in dataloader:
    591     x = self.whiten(x)
--> 592     loss = svi.step(x, t, y, size=len(dataset)) / len(dataset)
    593     if log_every and len(losses) % log_every == 0:
    594         logger.debug(
    595             "step {: >5d} loss = {:0.6g}".format(len(losses), loss)
    596         )

File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
    180 else:
    181     for i in range(self.num_particles):
--> 182         yield self._get_trace(model, guide, args, kwargs)

File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/enum.py:60, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     58     model_trace, guide_trace = unwrapped_guide.get_traces()
     59 else:
---> 60     guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
     61         *args, **kwargs
     62     )
     63     if detach:
     64         guide_trace.detach_()

File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
    178         exc = exc_type("{}\n{}".format(exc_value, shapes))
    179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
    181     self.msngr.trace.add_node(
    182         "_RETURN", name="_RETURN", type="return", value=ret
    183     )
    184 return ret

File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/miniconda3/lib/python3.8/site-packages/pyro/nn/module.py:427, in PyroModule.__call__(self, *args, **kwargs)
    425 def __call__(self, *args, **kwargs):
    426     with self._pyro_context:
--> 427         return super().__call__(*args, **kwargs)

File ~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:404, in Guide.forward(self, x, t, y, size)
    399     size = x.size(0)
    400 with pyro.plate("data", size, subsample=x):
    401     # The t and y sites are needed for prediction, and participate in
    402     # the auxiliary CEVAE loss. We mark them auxiliary to indicate they
    403     # do not correspond to latent variables during training.
--> 404     t = pyro.sample("t", self.t_dist(x), obs=t, infer={"is_auxiliary": True})
    405     y = pyro.sample("y", self.y_dist(t, x), obs=y, infer={"is_auxiliary": True})
    406     # The z site participates only in the usual ELBO loss.

File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:411, in Guide.t_dist(self, x)
    409 def t_dist(self, x):
    410     (logits,) = self.t_nn(x)
--> 411     return dist.Bernoulli(logits=logits)

File ~/miniconda3/lib/python3.8/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
     22     if result is not None:
     23         return result
---> 24 return super().__call__(*args, **kwargs)

File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/bernoulli.py:49, in Bernoulli.__init__(self, probs, logits, validate_args)
     47 else:
     48     batch_shape = self._param.size()
---> 49 super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)

File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     53         valid = constraint.check(value)
     54         if not valid.all():
---> 55             raise ValueError(
     56                 f"Expected parameter {param} "
     57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                 f"of distribution {repr(self)} "
     59                 f"to satisfy the constraint {repr(constraint)}, "
     60                 f"but found invalid values:\n{value}"
     61             )
     62 super(Distribution, self).__init__()

ValueError: Expected parameter logits (Tensor of shape (32,)) of distribution Bernoulli(logits: torch.Size([32])) to satisfy the constraint Real(), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan], grad_fn=<ClampBackward1>)
     Trace Shapes:           
      Param Sites:           
t_nn$$$fc.0.weight      1 712
  t_nn$$$fc.0.bias          1
     Sample Sites:           
         data dist          |
             value 32 712   |
Leo-T-Zang commented 2 years ago

Here is the code I use. I don't know at which point in the process it starts producing all-NaN values.

import numpy as np

# cevae model settings
outcome_dist = "normal"
latent_dim = 50
hidden_dim = 200
num_epochs = 5
batch_size = 32
learning_rate = 0.001
learning_rate_decay = 0.01
num_layers = 2

# causalml's CEVAE.fit converts X, treatment and y to float tensors itself
# (torch.tensor(..., dtype=torch.float)), so pass plain float32 arrays rather
# than pre-built tensors; converting once also lets fit() and predict() see
# exactly the same features.
X_arr = np.asarray(x, dtype=np.float32)
t_arr = np.asarray(treatment, dtype=np.float32)
y_arr = np.asarray(y, dtype=np.float32)

# NaN/inf in the inputs would pass through CEVAE's whitening into the
# treatment network and only show up later as NaN Bernoulli logits. Fail fast
# with a clear message instead. If the inputs are all finite, NaN logits most
# likely mean training diverged (e.g. try a smaller learning_rate).
for name, arr in (("X", X_arr), ("treatment", t_arr), ("y", y_arr)):
    if not np.isfinite(arr).all():
        raise ValueError(f"{name} contains NaN or infinite values")

cevae = CEVAE(outcome_dist=outcome_dist,
              latent_dim=latent_dim,
              hidden_dim=hidden_dim,
              num_epochs=num_epochs,
              batch_size=batch_size,
              learning_rate=learning_rate,
              learning_rate_decay=learning_rate_decay,
              num_layers=num_layers)
# fit
losses = cevae.fit(X=X_arr, treatment=t_arr, y=y_arr)
# predict on the same (converted) features used for fitting
ite = cevae.predict(X_arr)