locuslab / torchdeq

Modern Fixed Point Systems using PyTorch
MIT License

Implementation of `torchdeq.utils.mixed_init` different from original paper #3

Open jklim01 opened 6 months ago


From the paper: "we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"

"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half)."

Current implementation: https://github.com/locuslab/torchdeq/blob/4f6bd5fa66dd991cad74fcc847c88061764cf8db/torchdeq/utils/init.py#L4-L21

It seems more appropriate to do the following instead, to match the paper:

```python
*mask_shape, _ = z_shape
mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
```

This form has the disadvantage of assuming that all but the last dimension are batch dimensions. But that seems like a reasonable assumption, and downstream users can easily adapt to it by reshaping and rearranging their dimensions.
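To make the proposal concrete, here is a minimal self-contained sketch of per-sample mixed initialization. The function name `mixed_init_per_sample` is hypothetical (it is not the torchdeq API); it just packages the two lines above with the Gaussian draw, so each sample is either all zeros or standard normal noise:

```python
import torch


def mixed_init_per_sample(z_shape, device=None):
    """Hypothetical sketch of per-sample mixed initialization (not the torchdeq API).

    All but the last dimension of `z_shape` are treated as batch dimensions.
    Each sample is assigned, via a fair coin flip, either the zero vector or
    standard Gaussian noise, matching the scheme quoted from the paper.
    """
    *mask_shape, _ = z_shape
    # Per-sample coin flip: 1.0 -> Gaussian init, 0.0 -> zero init.
    # unsqueeze(-1) broadcasts the flip across the feature dimension.
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
    return torch.randn(*z_shape, device=device) * mask


# Example: a batch of 8 samples with 16-dim hidden states.
# Each row is either exactly zero or entirely Gaussian noise.
z = mixed_init_per_sample((8, 16))
```

By contrast, an element-wise Bernoulli mask would zero out roughly half the *entries* of every sample, rather than giving roughly half the *samples* a zero initialization, which is the behavior the paper describes.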