sahilm1992 opened this issue 4 years ago
We used the strategy described in the original paper; τ should be tuned on your dataset.
```python
import torch
import torch.nn.functional as F

def encoder(self, x, τ):
    """
    Input:
        x (batch, D): input coordinates.
        τ (float): Gumbel-softmax temperature.
    Output:
        z (batch, K): Gumbel-softmax samples (relaxed one-hot vectors).
    """
    logπ = self.x2logπ(x)                   # unnormalized log-probabilities, shape (batch, K)
    u = self.d_uniform.sample(logπ.size())  # u ~ Uniform(0, 1)
    g = -torch.log(-torch.log(u))           # g ~ Gumbel(0, 1) via the inverse-CDF transform
    z = F.softmax((logπ + g) / τ, dim=1)    # relaxed categorical sample
    return z
```
where `x2logπ` is an MLP mapping each input coordinate to an unconstrained real vector (the unnormalized log-probabilities over the K categories).
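
For completeness, here is a minimal sketch of how `x2logπ` and `d_uniform` might be set up so the snippet above runs. The names match the snippet; the class name, hidden size, and layer layout are illustrative assumptions, not necessarily what this repo uses:

```python
import torch
import torch.nn as nn

class GumbelEncoder(nn.Module):
    def __init__(self, D, K, hidden=128):  # D, K as in the docstring; hidden size is an assumption
        super().__init__()
        # MLP mapping a D-dimensional input to K unconstrained reals (log π)
        self.x2logπ = nn.Sequential(
            nn.Linear(D, hidden),
            nn.ReLU(),
            nn.Linear(hidden, K),
        )
        # Uniform(0, 1) noise source for the Gumbel reparameterization trick
        self.d_uniform = torch.distributions.Uniform(0.0, 1.0)
```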
Could you please share details on Gumbel-softmax? How did you incorporate it into your project? What about parameters like tau, hard, etc.?
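
For context on the parameters mentioned: τ is the softmax temperature (small τ gives near one-hot samples but higher-variance gradients), and `hard` usually refers to the straight-through variant, where the forward pass uses a discretized one-hot sample while gradients flow through the soft relaxation. PyTorch's built-in `F.gumbel_softmax` exposes both. A minimal usage sketch follows; the exponential annealing schedule is one common choice from the Gumbel-softmax literature, and the rate `r` and floor `0.5` are illustrative values to be tuned per dataset, not this repo's settings:

```python
import math
import torch
import torch.nn.functional as F

logits = torch.randn(32, 10, requires_grad=True)  # (batch, K) unnormalized log-probs

# Anneal τ exponentially toward a floor as training progresses (illustrative schedule)
step, r = 1000, 1e-4
τ = max(0.5, math.exp(-r * step))

z_soft = F.gumbel_softmax(logits, tau=τ, hard=False)  # relaxed, fully differentiable sample
z_hard = F.gumbel_softmax(logits, tau=τ, hard=True)   # one-hot forward pass, soft gradients (straight-through)
```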