Both Dragonnet and CEVAE are available. You can check out the example notebooks. Note that Dragonnet requires tensorflow>=2.4 as an additional dependency.
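For a quick start, both learners can be imported and instantiated directly. A minimal sketch (module paths as of recent causalml versions; constructor arguments shown are illustrative, see the notebooks for full examples):

import torch  # CEVAE also requires pyro-ppl
from causalml.inference.tf import DragonNet  # requires tensorflow>=2.4
from causalml.inference.nn import CEVAE

# Instantiate with mostly default settings; tune these for your data.
dragonnet = DragonNet()
cevae = CEVAE(outcome_dist="normal", latent_dim=20, hidden_dim=200)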
Hi @jeongyoonlee
Thank you so much!
But when I try to use CEVAE, I get an error that I could not solve. Could you please take a look?
ValueError Traceback (most recent call last)
File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
File ~/miniconda3/lib/python3.8/site-packages/pyro/nn/module.py:427, in PyroModule.__call__(self, *args, **kwargs)
426 with self._pyro_context:
--> 427 return super().__call__(*args, **kwargs)
File ~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:404, in Guide.forward(self, x, t, y, size)
400 with pyro.plate("data", size, subsample=x):
401 # The t and y sites are needed for prediction, and participate in
402 # the auxiliary CEVAE loss. We mark them auxiliary to indicate they
403 # do not correspond to latent variables during training.
--> 404 t = pyro.sample("t", self.t_dist(x), obs=t, infer={"is_auxiliary": True})
405 y = pyro.sample("y", self.y_dist(t, x), obs=y, infer={"is_auxiliary": True})
File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:411, in Guide.t_dist(self, x)
410 (logits,) = self.t_nn(x)
--> 411 return dist.Bernoulli(logits=logits)
File ~/miniconda3/lib/python3.8/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
23 return result
---> 24 return super().__call__(*args, **kwargs)
File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/bernoulli.py:49, in Bernoulli.__init__(self, probs, logits, validate_args)
48 batch_shape = self._param.size()
---> 49 super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()
ValueError: Expected parameter logits (Tensor of shape (32,)) of distribution Bernoulli(logits: torch.Size([32])) to satisfy the constraint Real(), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan], grad_fn=<ClampBackward1>)
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Input In [24], in <cell line: 10>()
1 cevae = CEVAE(outcome_dist=outcome_dist,
2 latent_dim=latent_dim,
3 hidden_dim=hidden_dim,
(...)
7 learning_rate_decay=learning_rate_decay,
8 num_layers=num_layers)
9 # fit
---> 10 losses = cevae.fit(X=torch.tensor(x, dtype=torch.float),
11 treatment=torch.tensor(treatment, dtype=torch.float),
12 y=torch.tensor(y, dtype=torch.float))
13 # predict
14 ite = cevae.predict(x)
File ~/miniconda3/lib/python3.8/site-packages/causalml/inference/nn/cevae.py:100, in CEVAE.fit(self, X, treatment, y, p)
90 X, treatment, y = convert_pd_to_np(X, treatment, y)
92 self.cevae = CEVAEModel(
93 outcome_dist=self.outcome_dist,
94 feature_dim=X.shape[-1],
(...)
97 num_layers=self.num_layers,
98 )
--> 100 self.cevae.fit(
101 x=torch.tensor(X, dtype=torch.float),
102 t=torch.tensor(treatment, dtype=torch.float),
103 y=torch.tensor(y, dtype=torch.float),
104 num_epochs=self.num_epochs,
105 batch_size=self.batch_size,
106 learning_rate=self.learning_rate,
107 learning_rate_decay=self.learning_rate_decay,
108 weight_decay=self.weight_decay,
109 )
File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:592, in CEVAE.fit(self, x, t, y, num_epochs, batch_size, learning_rate, learning_rate_decay, weight_decay, log_every)
590 for x, t, y in dataloader:
591 x = self.whiten(x)
--> 592 loss = svi.step(x, t, y, size=len(dataset)) / len(dataset)
593 if log_every and len(losses) % log_every == 0:
594 logger.debug(
595 "step {: >5d} loss = {:0.6g}".format(len(losses), loss)
596 )
File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)
File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File ~/miniconda3/lib/python3.8/site-packages/pyro/infer/enum.py:60, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
58 model_trace, guide_trace = unwrapped_guide.get_traces()
59 else:
---> 60 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
61 *args, **kwargs
62 )
63 if detach:
64 guide_trace.detach_()
File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191 """
192 :returns: data structure
193 :rtype: pyro.poutine.Trace
(...)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
178 exc = exc_type("{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
183 )
184 return ret
File ~/miniconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
170 self.msngr.trace.add_node(
171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
File ~/miniconda3/lib/python3.8/site-packages/pyro/nn/module.py:427, in PyroModule.__call__(self, *args, **kwargs)
425 def __call__(self, *args, **kwargs):
426 with self._pyro_context:
--> 427 return super().__call__(*args, **kwargs)
File ~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:404, in Guide.forward(self, x, t, y, size)
399 size = x.size(0)
400 with pyro.plate("data", size, subsample=x):
401 # The t and y sites are needed for prediction, and participate in
402 # the auxiliary CEVAE loss. We mark them auxiliary to indicate they
403 # do not correspond to latent variables during training.
--> 404 t = pyro.sample("t", self.t_dist(x), obs=t, infer={"is_auxiliary": True})
405 y = pyro.sample("y", self.y_dist(t, x), obs=y, infer={"is_auxiliary": True})
406 # The z site participates only in the usual ELBO loss.
File ~/miniconda3/lib/python3.8/site-packages/pyro/contrib/cevae/__init__.py:411, in Guide.t_dist(self, x)
409 def t_dist(self, x):
410 (logits,) = self.t_nn(x)
--> 411 return dist.Bernoulli(logits=logits)
File ~/miniconda3/lib/python3.8/site-packages/pyro/distributions/distribution.py:24, in DistributionMeta.__call__(cls, *args, **kwargs)
22 if result is not None:
23 return result
---> 24 return super().__call__(*args, **kwargs)
File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/bernoulli.py:49, in Bernoulli.__init__(self, probs, logits, validate_args)
47 else:
48 batch_shape = self._param.size()
---> 49 super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
File ~/miniconda3/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
53 valid = constraint.check(value)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()
ValueError: Expected parameter logits (Tensor of shape (32,)) of distribution Bernoulli(logits: torch.Size([32])) to satisfy the constraint Real(), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan], grad_fn=<ClampBackward1>)
Trace Shapes:
  Param Sites:
    t_nn$$$fc.0.weight  1 712
    t_nn$$$fc.0.bias    1
  Sample Sites:
    data dist        |
         value 32 712 |
Here is the code I use. I don't know where in the process it produces all NaN values.
import torch
from causalml.inference.nn import CEVAE

# cevae model settings
outcome_dist = "normal"
latent_dim = 50
hidden_dim = 200
num_epochs = 5
batch_size = 32
learning_rate = 0.001
learning_rate_decay = 0.01
num_layers = 2

cevae = CEVAE(outcome_dist=outcome_dist,
              latent_dim=latent_dim,
              hidden_dim=hidden_dim,
              num_epochs=num_epochs,
              batch_size=batch_size,
              learning_rate=learning_rate,
              learning_rate_decay=learning_rate_decay,
              num_layers=num_layers)

# fit (x, treatment, y are numpy arrays loaded beforehand)
losses = cevae.fit(X=torch.tensor(x, dtype=torch.float),
                   treatment=torch.tensor(treatment, dtype=torch.float),
                   y=torch.tensor(y, dtype=torch.float))

# predict
ite = cevae.predict(x)
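One thing I can check is whether the inputs themselves are the problem, since NaN/inf entries or zero-variance columns in x would break the whitening step inside fit and show up as NaN logits in t_nn. A quick sanity check (generic NumPy/PyTorch, using the same x, treatment, y passed to fit above):

import numpy as np
import torch

x_t = torch.tensor(x, dtype=torch.float)

# NaN/inf in the features propagates straight into the network's logits.
print("any NaN in x:", torch.isnan(x_t).any().item())
print("any inf in x:", torch.isinf(x_t).any().item())

# Zero-variance (constant) columns can turn into NaN when the inputs are
# standardized, since dividing by a zero std yields NaN.
stds = x_t.std(dim=0)
print("zero-variance columns:", torch.nonzero(stds == 0).flatten().tolist())

# The treatment model is a Bernoulli, so treatment should be binary 0/1.
print("unique treatment values:", np.unique(treatment))

If the inputs are clean, lowering learning_rate would be a common next step, since exploding gradients can also push the logits to NaN.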
I found that your documentation lists these two methods, but they seem to be unavailable right now.
Thank you.