oreilly-japan / deep-learning-from-scratch

『ゼロから作る Deep Learning』(O'Reilly Japan, 2016)
MIT License
3.99k stars 3.34k forks source link

softmax関数のコード改善案 #45

Closed pometa0507 closed 4 years ago

pometa0507 commented 4 years ago

softmax関数について簡潔なコードのご提案です。

https://github.com/oreilly-japan/deep-learning-from-scratch/blob/77eba24406354f1361fe3614fdcac54844729184/common/functions.py#L31-L39

上記コードはnp.sumで次元が潰れてしまうため、先にxを転置して操作してまた転置していることかと思います。 下記が改善案のコードです。

def softmax(x):
    x = x - np.max(x, axis=-1, keepdims=True)   # オーバーフロー対策
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)

keepdims=Trueを使うことでxの次元数を保ったまま操作し、 axis=-1を指定することでxが1次元でも2次元でも最後の次元に対して操作しています。 こちらのコードの方が直観的に理解しやすくすっきりするのではないでしょうか。

koki0702 commented 4 years ago

ありがとうございます。