qgallouedec / lge

MIT License

[Selected index k out of range] while executing torch.topk #9

Open SeungHunJeon opened 11 months ago

SeungHunJeon commented 11 months ago

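In `lge/utils.py`, the density estimation function takes the k smallest entries of `cdist` via `topk`. The author set k to 1000, but that seems to be out of range, presumably because fewer than 1000 samples are available when the function is called.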

(I am following notebooks/latent_go_explore_maze.ipynb.)

File ~/workspace/lge/lge/lge.py:330, in LatentGoExplore.explore(self, total_timesteps, callback)
    328 else:
    329     callback = [self.module_learner]
--> 330 self.model.learn(total_timesteps, callback=callback, log_interval=1000)

File ~/anaconda3/envs/lge/lib/python3.9/site-packages/stable_baselines3/sac/sac.py:309, in SAC.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
    299 def learn(
    300     self: SelfSAC,
    301     total_timesteps: int,
   (...)
    306     progress_bar: bool = False,
    307 ) -> SelfSAC:
--> 309     return super().learn(
    310         total_timesteps=total_timesteps,
    311         callback=callback,
    312         log_interval=log_interval,
    313         tb_log_name=tb_log_name,
    314         reset_num_timesteps=reset_num_timesteps,
...
     70 cdist = torch.cdist(x, samples)
---> 71 dist_to_kst = cdist.topk(k, largest=False)[0][:, -1]
     72 return -dist_to_kst

RuntimeError: selected index k out of range
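
For anyone hitting the same error, here is a minimal sketch of one possible workaround: cap k at the number of available samples before calling `topk`. The function name `estimate_density_clamped` and the toy tensors are hypothetical and only mirror the three lines shown in the traceback; this is not the repository's code, and whether clamping k is the behaviour the author intends is an assumption.

```python
import torch


def estimate_density_clamped(x: torch.Tensor, samples: torch.Tensor, k: int = 1000) -> torch.Tensor:
    """Negative distance from each row of x to its k-th nearest sample.

    Hypothetical variant for illustration; assumes samples has at least one row.
    """
    cdist = torch.cdist(x, samples)  # shape: (len(x), len(samples))
    # topk raises "selected index k out of range" when k is larger than the
    # dimension it searches, so cap k at the number of samples.
    k_eff = min(k, samples.shape[0])
    dist_to_kth = cdist.topk(k_eff, largest=False)[0][:, -1]
    return -dist_to_kth


x = torch.zeros(2, 3)
samples = torch.arange(15, dtype=torch.float32).reshape(5, 3)  # only 5 samples
scores = estimate_density_clamped(x, samples, k=1000)  # unclamped k=1000 would raise here
```

With fewer samples than k, the score falls back to the distance to the farthest available sample, so early estimates are coarser than with the full k neighbours. Waiting until at least k samples have been collected before calling the estimator would be another option.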

datake commented 1 month ago

same issue here