andrew-cr / tauLDR

Code for the paper https://arxiv.org/abs/2205.14987v2

Dimension of transition choice #3

Closed by asiraudin 1 year ago

asiraudin commented 1 year ago

Hi,

Great work and nice implementation!

I have a question about how you choose the dimension in which the transition happens during training (code below):

    rate_vals_square = rate[
        torch.arange(B, device=device).repeat_interleave(D),
        x_t.long().flatten(),
        :
    ] # (B*D, S)

    rate_vals_square[
        torch.arange(B*D, device=device),
        x_t.long().flatten()
    ] = 0.0 # 0 the diagonals

    rate_vals_square = rate_vals_square.view(B, D, S)

    rate_vals_square_dimsum = torch.sum(rate_vals_square, dim=2).view(B, D)

    square_dimcat = torch.distributions.categorical.Categorical(
        rate_vals_square_dimsum
    )

    square_dims = square_dimcat.sample() # (B,) taking values in [0, D)

My understanding is that you sample the transition dimension from a distribution proportional to the total outgoing rate of each component of x. Although this seems fairly intuitive to me, I don't see how it is justified theoretically, and I can't find a clear explanation in your paper.

With that in mind, could you please elaborate on how this implementation relates to your model?

Thanks,

Antoine

andrew-cr commented 1 year ago

Hi Antoine,

Thanks so much for your interest in the method! Regarding your question: we need to sample \tilde{x}^{1:D} from R_t(x^{1:D}, \tilde{x}^{1:D}), with the identity transition excluded and the remaining entries normalized so that they form a probability distribution over \tilde{x}^{1:D} (this is r_t in Prop. 2). Under the factorization assumptions, R_t then has only D x (S-1) non-zero values, because only transitions in which exactly one dimension changes are allowed (Prop. 3). We can therefore think of the distribution over the new value \tilde{x}^{1:D} as a joint distribution p(d, s), where d is the dimension that changes and s is its new value. To get the new value of \tilde{x}^{1:D}, we sample from p(d) and then from p(s | d). Here p(d) is the total outgoing rate of dimension d, normalized over all dimensions, and p(s | d) is the normalized row of the rate matrix for the sampled d. Please let me know if this has cleared it up, and if you have any further questions.
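For readers who want to check the factorization numerically, here is a minimal sketch in plain Python rather than the repository's batched PyTorch code. It assumes a toy layout in which `rates[d][s]` is the rate at which dimension `d` jumps from `x[d]` to state `s`, and that every dimension has a positive total outgoing rate. The function names and the toy numbers are hypothetical and do not come from the repository. The sketch compares p(d) * p(s | d) with the normalized off-diagonal rates, which should agree entry by entry.

```python
import random


def off_diagonal_rates(rates, x):
    """Copy of rates with the self-transition entry rates[d][x[d]] set to 0."""
    return [
        [0.0 if s == x[d] else r for s, r in enumerate(row)]
        for d, row in enumerate(rates)
    ]


def two_stage_probs(rates, x):
    """p(d) * p(s | d) for every (d, s) pair."""
    masked = off_diagonal_rates(rates, x)
    dim_totals = [sum(row) for row in masked]  # unnormalized p(d)
    total = sum(dim_totals)
    return [
        [(dim_totals[d] / total) * (r / dim_totals[d]) for r in row]
        for d, row in enumerate(masked)
    ]


def joint_probs(rates, x):
    """Off-diagonal rates normalized directly over all (d, s) pairs."""
    masked = off_diagonal_rates(rates, x)
    total = sum(sum(row) for row in masked)
    return [[r / total for r in row] for row in masked]


def sample_transition(rates, x, rng=random):
    """Sample d from p(d), then s from p(s | d)."""
    masked = off_diagonal_rates(rates, x)
    dim_totals = [sum(row) for row in masked]
    d = rng.choices(range(len(masked)), weights=dim_totals)[0]
    s = rng.choices(range(len(masked[d])), weights=masked[d])[0]
    return d, s


# Toy example (made-up numbers): D = 2 dimensions, S = 3 states, x = (0, 2).
# Each row is a rate-matrix row, so its diagonal entry is negative.
rates = [[-4.0, 1.0, 3.0], [2.0, 2.0, -4.0]]
x = [0, 2]
```

Calling `two_stage_probs(rates, x)` and `joint_probs(rates, x)` on the toy input gives the same table, and `sample_transition` never returns the current value `x[d]` because that entry has zero weight.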

asiraudin commented 1 year ago

Okay, it's much clearer now!

Thanks for the quick and detailed answer :)