Open rasoolianbehnam opened 3 months ago
There seems to be a bug in the model specification. In the following part of the model:
```python
tfd.MultivariateNormalDiag(
    loc=tf.zeros(num_students),
    scale_diag=self._stddev_students * tf.ones(num_students)),
tfd.MultivariateNormalDiag(
    loc=tf.zeros(num_instructors),
    scale_diag=self._stddev_instructors * tf.ones(num_instructors)),
tfd.MultivariateNormalDiag(
    loc=tf.zeros(num_departments),
    scale_diag=self._stddev_departments * tf.ones(num_departments)),
```
it seems that `self._stddev_students`, `self._stddev_instructors`, and `self._stddev_departments` are not being tracked by the `GradientTape` and are therefore not updated properly in the M-step.
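The issue doesn't show how the `_stddev_*` attributes are defined, so what follows is only a hypothetical reproduction of one common way this symptom can arise, not a confirmed diagnosis. If a standard deviation is derived from a `tf.Variable` once (for example in `__init__`), the result is a constant tensor, and a later `tf.GradientTape` never sees the variable being read, so the gradient comes back as `None`. The class names, the `_log_stddev_students` variable, and the helper `mvn_diag_log_prob` (a hand-written stand-in for `tfd.MultivariateNormalDiag(...).log_prob`) are all made up for illustration. Only TensorFlow is needed.

```python
import math

import tensorflow as tf


def mvn_diag_log_prob(x, scale_diag):
    """Log density of N(0, diag(scale_diag**2)) evaluated at x.

    Hypothetical stand-in for
    tfd.MultivariateNormalDiag(loc=0, scale_diag=scale_diag).log_prob(x).
    """
    z = x / scale_diag
    return tf.reduce_sum(
        -0.5 * tf.square(z) - tf.math.log(scale_diag) - 0.5 * math.log(2.0 * math.pi))


class FrozenStddevModel(tf.Module):
    """Hypothetical buggy pattern: stddev computed once, outside any tape."""

    def __init__(self):
        super().__init__()
        self._log_stddev_students = tf.Variable(0.0, name="log_stddev_students")
        # tf.exp runs here, in __init__, and produces a constant tensor.
        # Later tapes never see the variable being read, so no gradient flows.
        self._stddev_students = tf.exp(self._log_stddev_students)

    def log_prob(self, effects):
        n = effects.shape[0]
        return mvn_diag_log_prob(effects, self._stddev_students * tf.ones(n))


class TrackedStddevModel(tf.Module):
    """Hypothetical fix: derive the stddev from the variable on each access."""

    def __init__(self):
        super().__init__()
        self._log_stddev_students = tf.Variable(0.0, name="log_stddev_students")

    @property
    def _stddev_students(self):
        # Re-evaluated on every access, so the variable read happens inside
        # whichever GradientTape is active; exp also keeps the stddev positive.
        return tf.exp(self._log_stddev_students)

    def log_prob(self, effects):
        n = effects.shape[0]
        return mvn_diag_log_prob(effects, self._stddev_students * tf.ones(n))


def stddev_grad(model, effects):
    """Gradient of the log density w.r.t. the log-stddev variable (None if untracked)."""
    with tf.GradientTape() as tape:
        lp = model.log_prob(effects)
    return tape.gradient(lp, model._log_stddev_students)


if __name__ == "__main__":
    effects = tf.constant([1.0, 2.0, 3.0])
    print(stddev_grad(FrozenStddevModel(), effects))  # None: variable never read on the tape
    # At stddev = 1 the gradient is sum(x**2) - n = 14 - 3, i.e. approximately 11.0.
    print(stddev_grad(TrackedStddevModel(), effects))
```

An M-step that only ever receives `None` for these parameters leaves them at their initial values, which would match the behavior described above. In TFP, `tfp.util.TransformedVariable` with a positive bijector (e.g. `tfb.Exp()` or `tfb.Softplus()`) is another way to keep a constrained scale tied to a trainable variable; whether either pattern applies here depends on model code not shown in this issue.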