cesarali / conditional_rate_matching

Flow Matching for Discrete Variables

make logistic head compatible with 32x32 and 28x28 images #11

Open dfaroughy opened 4 months ago

dfaroughy commented 4 months ago

Replace the hardcoded channel count C and the hardcoded 32x32 image size in TemporalToRateLogistic with values read from the config, as below:

class TemporalToRateLogistic(nn.Module):
    def __init__(self, config: CRMConfig, temporal_output_total, device):
        super().__init__()
        self.D = config.data1.dimensions
        self.S = config.data1.vocab_size
        # Read channels (C) and side length (K) from the expected (C, K, K)
        # shape instead of hardcoding them, so 28x28 and 32x32 images both work.
        self.C = config.data1.temporal_net_expected_shape[0]
        self.K = config.data1.temporal_net_expected_shape[1]
        self.device = device
        self.fix_logistic = config.temporal_network_to_rate.fix_logistic

    def forward(self, net_out):
        B = net_out.shape[0]
        D = self.D
        C = self.C
        S = self.S
        K = self.K
        # Reshape to 2*C planes of size K x K (no longer fixed at 32 x 32).
        net_out = net_out.view(B, 2 * C, K, K)
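A minimal, self-contained sketch of why reading C and K from the config fixes the shape mismatch, assuming `temporal_net_expected_shape` is stored as `[C, K, K]` (for example `[1, 28, 28]` for 28x28 grayscale images and `[3, 32, 32]` for 32x32 RGB images). `make_config` is a hypothetical stand-in for `CRMConfig` that fills in only the fields the head reads, and `reshape_logistic_output` is a hypothetical helper that reproduces just the reshape step of `forward`, not the full head:

```python
from types import SimpleNamespace

import torch


def make_config(channels: int, side: int, vocab_size: int = 256) -> SimpleNamespace:
    # Hypothetical stand-in for CRMConfig: only the fields used by the head.
    # Assumes the expected shape is stored as [C, K, K].
    data1 = SimpleNamespace(
        dimensions=channels * side * side,
        vocab_size=vocab_size,
        temporal_net_expected_shape=[channels, side, side],
    )
    return SimpleNamespace(data1=data1)


def reshape_logistic_output(net_out: torch.Tensor, config: SimpleNamespace) -> torch.Tensor:
    # Hypothetical helper mirroring the reshape in TemporalToRateLogistic.forward:
    # C and K come from the config rather than being hardcoded.
    C = config.data1.temporal_net_expected_shape[0]
    K = config.data1.temporal_net_expected_shape[1]
    B = net_out.shape[0]
    # .view needs exactly 2*C*K*K values per sample.
    return net_out.view(B, 2 * C, K, K)


# 28x28 grayscale (C=1) and 32x32 RGB (C=3) go through the same code path.
mnist_like = make_config(channels=1, side=28)
cifar_like = make_config(channels=3, side=32)

out = reshape_logistic_output(torch.randn(4, 2 * 1 * 28 * 28), mnist_like)
print(tuple(out.shape))  # (4, 2, 28, 28)

out = reshape_logistic_output(torch.randn(4, 2 * 3 * 32 * 32), cifar_like)
print(tuple(out.shape))  # (4, 6, 32, 32)
```

If C were fixed at 3 and the side at 32, the 28x28 grayscale case would fail: a sample with 2*1*28*28 = 1568 values cannot be viewed as 2*3*32*32 = 6144 values, so `.view` raises a `RuntimeError`.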