cossio / RestrictedBoltzmannMachines.jl

Train and sample Restricted Boltzmann machines in Julia
MIT License

wts with eltype of rbm.w ? #21

Closed cossio closed 7 months ago

cossio commented 1 year ago

https://github.com/cossio/RestrictedBoltzmannMachines.jl/blob/73bb5bfd2091344ab5dce7e4fd718bcdf9a73a01/src/rbm.jl#L317

Should we convert wts here to the eltype of rbm.w, so that weights in a wider float type do not promote the gradient? Something like this:

wts_diag = with_eltypeof(rbm.w, Diagonal(vec(wts)))
∂wflat = -vflat * wts_diag * hflat' / sum(wts)
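
To make the effect concrete, here is a minimal, self-contained sketch of the idea, not the package's actual code. The `with_eltypeof` defined below is a hypothetical stand-in for the helper named above, and `w`, `vflat`, `hflat`, and `wts` are made-up arrays with assumed shapes. One detail worth noting: in Julia, dividing a Float32 array by a Float64 scalar also promotes the result, so this sketch converts the weight vector first and reuses it in the normalization.

```julia
using LinearAlgebra

# Hypothetical stand-in for `with_eltypeof`: returns `x` with its elements
# converted to the element type of `like` (no copy if the types already match).
with_eltypeof(like::AbstractArray, x::AbstractArray) =
    convert(AbstractArray{eltype(like)}, x)

# Made-up data standing in for the quantities in the issue.
w     = randn(Float32, 3, 2)   # stand-in for rbm.w: 3 visible, 2 hidden units
vflat = randn(Float32, 3, 5)   # flattened visible batch, 5 samples
hflat = randn(Float32, 2, 5)   # flattened hidden batch, 5 samples
wts   = rand(Float64, 5)       # per-sample weights in a wider float type

# Without conversion, the Float64 weights promote the gradient to Float64.
∂w_promoted = -vflat * Diagonal(vec(wts)) * hflat' / sum(wts)

# With conversion, the gradient keeps the eltype of w. The converted vector is
# also used in the denominator, since `Float32 array / Float64 scalar` promotes.
wts_T  = with_eltypeof(w, vec(wts))
∂wflat = -vflat * Diagonal(wts_T) * hflat' / sum(wts_T)

println(eltype(∂w_promoted), " vs ", eltype(∂wflat))
```

If `wts` already has the eltype of `w`, the conversion is a no-op, so the change should only matter when the caller passes weights in a different float type.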