# Collect, for each sampled transition, the buffer indices visited over the
# n-step horizon: entry k is the index reached after k calls to
# ``buffer.next`` (entry 0 is ``indice`` itself).
# Assumes ``indice`` is a 1-D index array and ``buffer.next`` returns an
# array shaped like its input -- confirm against the buffer class.
bsz = len(indice)
indices = [indice]
for _ in range(n_step - 1):
    indices.append(buffer.next(indices[-1]))
# np.stack joins along a NEW leading axis, so ``indices`` has shape
# (n_step, bsz) for n_step >= 1: one row per step, one column per sample.
indices = np.stack(indices)
# ``terminal`` is the last ROW: for every sample, the index reached after
# ``n_step - 1`` calls to ``buffer.next``. The last row is intended here;
# ``indices[:, -1]`` would instead select every step of the last sample.
# Per the original note these indices are truncated at the end of each
# episode (presumably ``buffer.next`` does not advance past an episode end).
terminal = indices[-1]
with torch.no_grad():
    target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
target_q = to_numpy(target_q_torch.reshape(bsz, -1))
In the code above, `terminal = indices[-1]`. I think `terminal` should be the last column of `indices` (`indices[:, -1]`), but `indices[-1]` just takes the last row of `indices`. (Note: `np.stack` stacks the per-step index arrays along a new leading axis, so `indices` has shape `(n_step, bsz)`. Its last row `indices[-1]` therefore holds, for every sample, the index reached after the final `buffer.next` call. `indices[:, -1]` would instead give all steps of only the last sample.)