Closed lqiang2003cn closed 1 year ago
Hi @lqiang2003cn ,
Thanks for noticing this, and many apologies for the late reply! I think you're absolutely right about this -- I will look into it and fix it, but feel free to submit a PR with your own fix if you'd like.
I'm imagining the solution would be to do exactly what you're saying, namely:
arg_list = [X, list(range(X.ndim))] + list(chain(*([x[xdim_i],[dims[xdim_i]]] for xdim_i in range(len(x))))) # leave out the [[0]] at the end
accuracy = np.einsum(*arg_list)
But I'm debating whether to make that an additional gating condition within spm_dot (with another if statement), or whether to just do it in-line in the calc_free_energy function...
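For anyone following along, here is a minimal NumPy-only sketch of the full contraction proposed above. It does not call pymdp; X, x and dims are hypothetical stand-ins (a random tensor over three hidden state factors and random factor marginals), chosen only to check that dropping the trailing output sublist makes einsum return a scalar.

```python
from itertools import chain

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins (not pymdp objects): X plays the role of the
# log-likelihood tensor over three hidden state factors, and x is the
# list of marginals, one per factor.
X = rng.random((2, 3, 4))
x = [rng.dirichlet(np.ones(n)) for n in X.shape]
dims = list(range(X.ndim))  # one tensor axis per factor

# Interleaved (operand, sublist) arguments with no trailing output sublist.
# Every index then appears twice, so einsum sums over all of them.
arg_list = [X, list(range(X.ndim))] + list(
    chain(*([x[i], [dims[i]]] for i in range(len(x))))
)
accuracy = np.einsum(*arg_list)

# Reference value computed with explicit loops over every axis.
expected = sum(
    X[i, j, k] * x[0][i] * x[1][j] * x[2][k]
    for i in range(2)
    for j in range(3)
    for k in range(4)
)
```

Here accuracy is zero-dimensional and matches the explicit triple sum, which is the behaviour the fix is after.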
An aside: the + [[0]] in the arguments to einsum is a vestige of the typical use of spm_dot, where you're trying to sum out everything except the first dimension (typically, the support of the variable over which you're computing a marginal distribution).
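To illustrate that typical use, here is a small NumPy-only sketch (again not pymdp code). The array A, the marginals q0 and q1, and the shapes are all made up for the example: A has an observation axis first and two hidden state factor axes after it, and the output sublist [0] keeps only the observation axis.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical likelihood array: axis 0 is the observation, axes 1 and 2
# are two hidden state factors. Normalise so each column sums to 1 over axis 0.
A = rng.random((3, 2, 2))
A /= A.sum(axis=0, keepdims=True)

# Hypothetical marginals over the two hidden state factors.
q0 = np.array([0.7, 0.3])
q1 = np.array([0.4, 0.6])

# The factors line up with the trailing axes of A (indices 1 and 2).
# The final [0] is the output sublist: keep axis 0, sum out the rest.
expected_obs = np.einsum(A, [0, 1, 2], q0, [1], q1, [2], [0])
```

The result has one entry per observation level and sums to 1, i.e. it is a marginal over the first dimension, which is what the trailing [[0]] was designed to produce.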
> But I'm debating whether to make that an additional gating condition within spm_dot (with another if statement), or whether to just do it in-line in the calc_free_energy function...
spm_dot is already complex enough (for me), so I'd do it in-line in the calc_free_energy function :).
I fixed the bug, and have merged the resulting PR into master. Nice job again @lqiang2003cn for spotting this, and thank you. This is by far the most technically-nuanced error someone has noticed so far; I'm impressed that you managed to parse that line of preparation for the einsum arguments in spm_dot :)
Will close this issue upon release of pymdp 0.0.7
Closing now that v0.0.7 is released
Hi guys, I have a question about calculating the free energy. When building the einsum arguments, why do we sum out everything except the first dimension:
arg_list = [X, list(range(X.ndim))] + list(chain(*([x[xdim_i],[dims[xdim_i]]] for xdim_i in range(len(x))))) + [[0]]
and then use:
spm_dot(likelihood, qs)[0]
to fetch only the first level of s0 (supposing the hidden state s is factorized as s0, s1, ...)?
In my opinion, we should sum over the whole tensor by removing the trailing + [[0]], so that Y is just a scalar. Am I totally wrong? If I am, please correct me. Thanks a lot!
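A small numeric check of this concern, as a sketch in plain NumPy with made-up numbers (X, q0 and q1 are hypothetical, and spm_dot itself is not called): keeping index 0 and then taking element [0] returns only the contribution of the first level of s0, whereas dropping the output sublist sums over everything.

```python
import numpy as np

# Hypothetical 2 x 3 tensor over two hidden state factors s0 and s1.
X = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
q0 = np.array([0.5, 0.5])       # marginal over s0
q1 = np.array([0.2, 0.3, 0.5])  # marginal over s1

# With the trailing output sublist [0]: a vector indexed by the levels of s0.
kept = np.einsum(X, [0, 1], q0, [0], q1, [1], [0])  # approx. [1.15, 2.65]

# Without it: every index is summed out, giving a scalar.
full = np.einsum(X, [0, 1], q0, [0], q1, [1])       # approx. 3.8

# Taking kept[0] drops the second level of s0 entirely.
first_only = kept[0]                                # approx. 1.15
```

With these numbers kept.sum() equals full, while kept[0] does not, so indexing with [0] after the partial contraction discards part of the expectation.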