sphinxteam / Boltzmann.jl

Restricted Boltzmann Machines in Julia

Update of the weights #15

Closed. AurelienDecelle closed this issue 8 years ago.

AurelienDecelle commented 8 years ago

I'm not sure about this, but at least in the positive phase of the weight update there seems to be an error. The code is the following:

function gibbs(rbm::RBM, vis::Mat{Float64}; n_times=1)
    v_pos = vis
    h_pos = sample_hiddens(rbm, v_pos)
    h_neg = Array(Float64,0,0)::Mat{Float64}
    v_neg = Array(Float64,0,0)::Mat{Float64}
    if n_times > 0
        # Save computation by setting `n_times=0` in the case
        # of persistent CD.
        v_neg = sample_visibles(rbm, h_pos)
        h_neg = sample_hiddens(rbm, v_neg)
        for i=1:n_times-1
            v_neg = sample_visibles(rbm, h_neg)
            h_neg = sample_hiddens(rbm, v_neg)
        end
    end
    return v_pos, h_pos, v_neg, h_neg
end

So I suspect that h_pos should instead be something like h_pos = hid_means(rbm, v_pos)? scikit-learn also uses the means for the negative phase, but I'm less confident that this is essential.
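To see why the means are a natural choice for the positive statistic, consider a toy Bernoulli RBM. The sizes, weights, and the names sigm, h_mean, and stat_* below are illustrative assumptions, not Boltzmann.jl code. Because E[h | v] is exactly the vector of hidden means, the statistic h_mean * v' equals the expectation of the sampled statistic h_sample * v', so using the means should keep the same expected update while removing the sampling noise. This sketch checks that numerically by averaging many samples:

```julia
using Random

# Logistic sigmoid, applied elementwise via broadcasting.
sigm(x) = 1 / (1 + exp(-x))

# Hypothetical toy setup: 3 hidden units, 4 visible units, one visible vector.
rng = MersenneTwister(1)
W = 0.5 .* randn(rng, 3, 4)   # weights, hidden × visible
hbias = zeros(3)
v = Float64[1, 0, 1, 1]

# Mean hidden activations p(h = 1 | v).
h_mean = sigm.(W * v .+ hbias)

# Positive-phase statistic computed directly from the means.
stat_mean = h_mean * v'

# The same statistic averaged over many Bernoulli samples of h.
n_draws = 10_000
stat_sampled = zeros(3, 4)
for _ in 1:n_draws
    h_sample = Float64.(rand(rng, 3) .< h_mean)
    stat_sampled .+= h_sample * v'
end
stat_sampled ./= n_draws

# The sampled average approaches the mean-based statistic.
println(maximum(abs.(stat_sampled .- stat_mean)))
```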

A.

eric-tramel commented 8 years ago

This is a good question, @AurelienDecelle, one that @marylou-gabrie pointed out while working on the TAP implementation.

In fact, Hinton's reference code computes the final hidden statistics for both the positive and the negative phase from the mean values of the hidden units (the probabilities p(h = 1 | v)), not from the sampled values.
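The resulting update can be sketched as follows. cd_gradients is a hypothetical helper, not a Boltzmann.jl function, and the sketch assumes a Bernoulli RBM with one observation per column (so W is n_hid × n_vis). The weight gradient is the positive hidden-visible correlation minus the negative one, with h_pos and h_neg being mean activations as in Hinton's code:

```julia
# Contrastive-divergence gradient estimate from the four Gibbs outputs.
# Assumed layout: v_* is n_vis × n_obs, h_* is n_hid × n_obs (mean activations).
function cd_gradients(v_pos, h_pos, v_neg, h_neg)
    n_obs = size(v_pos, 2)
    dW = (h_pos * v_pos' .- h_neg * v_neg') ./ n_obs      # n_hid × n_vis
    dvbias = vec(sum(v_pos .- v_neg, dims=2)) ./ n_obs     # length n_vis
    dhbias = vec(sum(h_pos .- h_neg, dims=2)) ./ n_obs     # length n_hid
    return dW, dvbias, dhbias
end

# Toy check: 2 visible units, 1 hidden unit, 2 observations (columns).
v_pos = [1.0 0.0; 1.0 1.0]
h_pos = [0.8 0.6]
v_neg = [0.0 0.0; 1.0 1.0]
h_neg = [0.5 0.5]
dW, dvbias, dhbias = cd_gradients(v_pos, h_pos, v_neg, h_neg)
# dW ≈ [0.4 0.2], dvbias ≈ [0.5, 0.0], dhbias ≈ [0.2]
```

Adding a learning rate times these quantities to W, vbias and hbias is a gradient-ascent step on the CD approximation of the log-likelihood; the bias terms follow the same positive-minus-negative pattern.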

So, in this case, the function should more accurately be:

function gibbs(rbm::RBM, vis::Mat{Float64}; n_times=1)
    v_pos = vis
    h_pos = sample_hiddens(rbm, v_pos)
    h_neg = Array(Float64,0,0)::Mat{Float64}
    v_neg = Array(Float64,0,0)::Mat{Float64}
    if n_times > 0
        # Save computation by setting `n_times=0` in the case
        # of persistent CD.
        v_neg = sample_visibles(rbm, h_pos)
        h_neg = sample_hiddens(rbm, v_neg)
        for i=1:n_times-1
            v_neg = sample_visibles(rbm, h_neg)
            h_neg = sample_hiddens(rbm, v_neg)
        end
        # Negative statistics use the hidden means, not samples.
        # (Only computed here: v_neg is empty when n_times == 0.)
        h_neg = hid_means(rbm, v_neg)
    end
    # Positive statistics use the hidden means, not samples.
    h_pos = hid_means(rbm, v_pos)

    return v_pos, h_pos, v_neg, h_neg
end

eric-tramel commented 8 years ago

Corrected with a vengeance! Now closing the issue. Unfortunately, the fix adds some extra computation to the main loop, so we lose a bit of runtime efficiency :(
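One way to claw back that efficiency, offered as an assumption rather than what the package actually does, is to compute each layer's mean activations once and draw the Bernoulli sample from those same means. A single matrix product then serves both the sample that drives the chain and the mean returned for the update. ToyRBM, sigm, hid_means_and_samples, vis_samples and gibbs_means are hypothetical names, not Boltzmann.jl functions, and the sketch targets Julia 1.x rather than the package's original syntax:

```julia
using Random

# Hypothetical minimal Bernoulli RBM (not Boltzmann.jl's RBM type).
# W is n_hid × n_vis; each column of a data matrix is one observation.
struct ToyRBM
    W::Matrix{Float64}
    vbias::Vector{Float64}
    hbias::Vector{Float64}
end

sigm(x) = 1 / (1 + exp(-x))

# Compute the hidden means once and sample from them, returning both.
function hid_means_and_samples(rbm::ToyRBM, vis::AbstractMatrix, rng)
    means = sigm.(rbm.W * vis .+ rbm.hbias)
    samples = Float64.(rand(rng, size(means)...) .< means)
    return means, samples
end

# Binary visible samples given binary hidden samples.
function vis_samples(rbm::ToyRBM, hid::AbstractMatrix, rng)
    means = sigm.(rbm.W' * hid .+ rbm.vbias)
    return Float64.(rand(rng, size(means)...) .< means)
end

# Gibbs chain returning hidden means for both phases, with no extra
# matrix products; n_times=0 skips the negative phase (persistent CD).
function gibbs_means(rbm::ToyRBM, vis::AbstractMatrix; n_times=1, rng=Random.default_rng())
    v_pos = vis
    h_pos, h_samp = hid_means_and_samples(rbm, v_pos, rng)
    v_neg = zeros(0, 0)
    h_neg = zeros(0, 0)
    for _ in 1:n_times
        v_neg = vis_samples(rbm, h_samp, rng)
        h_neg, h_samp = hid_means_and_samples(rbm, v_neg, rng)
    end
    return v_pos, h_pos, v_neg, h_neg
end

# Usage: 4 visible units, 3 hidden units, 5 binary observations.
rng = MersenneTwister(0)
rbm = ToyRBM(0.1 .* randn(rng, 3, 4), zeros(4), zeros(3))
data = Float64.(rand(rng, 4, 5) .< 0.5)
v_pos, h_pos, v_neg, h_neg = gibbs_means(rbm, data; n_times=2, rng=rng)
```

The trade-off is that the hidden-sampling helper now returns a tuple; in exchange, each Gibbs step costs one matrix product per layer, the same as a chain that returns only samples.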

eric-tramel commented 8 years ago

CORRECTED WITH SUPER DOUBLE VENGEANCE
