Closed wormhole2019 closed 3 years ago
我正在学习强化学习相关内容,打算使用KelpNet实现DNQ,需要构造到 类似于pytorch的learn算法: def learn(self, obs, action, reward, next_obs, terminal): """ update value model self.model with DQN algorithm """ pred_value = self.model(obs).gather(1, action) with torch.no_grad(): max_v = self.target_model(next_obs).max(1, keepdim=True)[0] target = reward + (1 - terminal) self.gamma max_v self.optimizer.zero_grad() loss = self.mse_loss(pred_value, target) loss.backward() self.optimizer.step() return loss.item() 却不知如何实现gather(1, action) 这个功能,可能在KelpNet中需要实现自定义层,希望能提供相应的样例
遗憾的是,目前KelpNet中还没有相当于Pytorch的Gather功能
根据函数的说明,似乎是在为Index操作进行处理 如果你不想让它通用,我想可以自己做 https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
由于我们目前正在努力将KelpNet的主干从Chainer改为TensorFlow 我们在回答请求时遇到了困难
我引用这个网站的一个例子作为参考 https://qiita.com/dokkozo/items/e173acded17a142e6d02
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
[0.1397, 0.9839, 0.8991],
[0.6298, 0.6101, 0.6841]])
>>> sorted, idx = torch.sort(a)
>>> sorted
tensor([[0.2678, 0.5021, 0.8395],
[0.1397, 0.8991, 0.9839],
[0.6101, 0.6298, 0.6841]])
>>> idx
tensor([[1, 2, 0],
[0, 2, 1],
[1, 0, 2]])
>>> sorted, idx = torch.sort(a, descending = True)
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
[0.9839, 0.8991, 0.1397],
[0.6841, 0.6298, 0.6101]]
>>> idx = torch.argsort(a)
>>> idx
tensor([[1, 2, 0],
[0, 2, 1],
[1, 0, 2]])
>>> sorted, idx = torch.sort(a, -1)
>>> torch.gather(a, -1, idx)
tensor([[0.2678, 0.5021, 0.8395],
[0.1397, 0.8991, 0.9839],
[0.6101, 0.6298, 0.6841]])
>>> b = torch.rand_like(a)
>>> torch.gather(b, 1, idx)
>>> invidx = torch.argsort(idx)
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
[0.1397, 0.9839, 0.8991],
[0.6298, 0.6101, 0.6841]])
English: Unfortunately, the equivalent of Pytorch's Gather function is not currently available in KelpNet.
According to the specification of the function, it seems to be processing for the Index operation. If you don't want to make it generic, I think it's possible to make it yourself. https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
Since we are currently working on changing the backbone of KelpNet from Chainer to TensorFlow We are having difficulty answering requests.
For reference, I will quote a sample from this site. https://qiita.com/dokkozo/items/e173acded17a142e6d02
Japanese: 残念ながらPytorchのGather関数に相当する処理は現在KelpNetに用意されていません。
関数の仕様を見た所Index操作に関わる処理を行っているようで、 汎用的に作らないのであれば、十分自作が可能だと思います。 https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
現在KelpNetのバックボーンをChainerからTensorFlowに変更する作業を行っているため、 リクエストにお答えするのが難しい状況です。
参考にこちらのサイトからサンプルを引用させていただきます。 https://qiita.com/dokkozo/items/e173acded17a142e6d02
It has been a week and is temporarily closed.
Please open the issue again if you have additional information.
I understand that you have some kind of request, but I couldn't get to your intentions.
Please add detailed information about your request, even if it is in your native language.