selfint / rust-ml

Rust machine learning library

Implement cce loss with softmax (i.e. from logits) #19

Closed selfint closed 3 years ago

selfint commented 3 years ago

As I understand it now, the softmax derivative needs the activation of the node (which can be calculated easily), but also the expected activation of that node. Currently, the derivative of an activation function is expected to require only the transfer, from which it can calculate its own activation if needed.
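To make the mismatch concrete, here is a minimal sketch of the kind of per-node derivative API described above (the trait and names are assumptions for illustration, not the actual rust-ml API):

```rust
/// Hypothetical per-node activation API, as described above
/// (names are assumptions, not the actual rust-ml trait).
pub trait Activation {
    fn activate(&self, transfer: f64) -> f64;
    /// The derivative only receives a single node's transfer.
    fn derivative(&self, transfer: f64) -> f64;
}

/// Softmax can't fit that signature: each output depends on the
/// transfers of the whole layer, and its derivative is a Jacobian.
pub fn softmax(transfers: &[f64]) -> Vec<f64> {
    // Subtract the max transfer for numerical stability.
    let max = transfers.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = transfers.iter().map(|t| (t - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```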

To implement the softmax derivative we can either:

  1. Implement a new layer, separate from the current layers, that handles the softmax edge case.
  2. Implement a flag that makes the network's output be processed through softmax.
selfint commented 3 years ago

Some more thoughts on 2: AFAIK, softmax and CCE loss only make sense when used together. When calculating the gradients, we can branch into two separate options: softmax + CCE, or standard (see the gradient derivation after the cons list). In other words: classification / regression.

Cons:

  1. Might slow down the implementation (probably not by much, though).
  2. Not extensible: it will be impossible (or at least very ugly) for a user of the library to implement their own 'mode'.
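For reference, here is why the softmax + CCE branch is attractive (a standard derivation, not from this issue): chaining the loss derivative through the softmax Jacobian collapses to a very simple expression.

```latex
% CCE loss over softmax outputs s, with targets y and output transfers z:
L = -\sum_i y_i \log s_i, \qquad s_i = \frac{e^{z_i}}{\sum_j e^{z_j}}
% Differentiating through the softmax Jacobian, the terms cancel:
\frac{\partial L}{\partial z_i} = s_i - y_i
```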
selfint commented 3 years ago

Option 3: Implement classification/regression as a network trait.

This means that anyone can add their own 'mode', and optimizers can be implemented for any mode.
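A rough sketch of what option 3 might look like (the trait and type names are hypothetical; `softmax` as in the earlier sketch):

```rust
/// Hypothetical "mode" trait for option 3.
pub trait Mode {
    /// Gradient of the loss w.r.t. the output layer's transfers.
    fn output_gradient(&self, transfers: &[f64], expected: &[f64]) -> Vec<f64>;
}

pub struct Classification; // softmax + CCE
pub struct Regression;     // standard: activation derivative * loss derivative

impl Mode for Classification {
    fn output_gradient(&self, transfers: &[f64], expected: &[f64]) -> Vec<f64> {
        // Combined softmax + CCE gradient: softmax(z) - y.
        softmax(transfers)
            .iter()
            .zip(expected)
            .map(|(s, y)| s - y)
            .collect()
    }
}
```

An optimizer could then be generic over `M: Mode`, which is what keeps the door open for user-defined modes.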

selfint commented 3 years ago

The softmax derivative and activation implementations (+ CCE loss) should be implemented together.

selfint commented 3 years ago

Option 4: Do not implement softmax as an activation at all; instead, bake it into the CCE loss.

Explanation:

  1. When using a network for classification, applying softmax to the output layer and then taking the argmax, or applying argmax to the output layer directly, will always yield the same result, since softmax is monotonic and preserves the ordering of the outputs.
  2. AFAIK, softmax isn't used as the activation of a hidden layer, and since its derivative is a Jacobian matrix and not a vector, it doesn't really make sense to do so.

That is why I think we can safely apply softmax only when computing the loss of the network. It doesn't have to be exclusive to CCE, but it can stay there until some other loss function needs it. Worst case, it is a very simple implementation, so duplicating it shouldn't be that bad.
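A minimal sketch of option 4, with softmax baked into the CCE loss itself, i.e. a "from logits" loss (names are hypothetical; `softmax` as in the first sketch):

```rust
/// Hypothetical CCE loss that takes the raw output transfers (logits)
/// and applies softmax internally.
pub struct CategoricalCrossEntropy;

impl CategoricalCrossEntropy {
    /// Loss over the logits: -sum(y * ln(softmax(z))).
    pub fn loss(&self, logits: &[f64], expected: &[f64]) -> f64 {
        let probs = softmax(logits);
        -probs
            .iter()
            .zip(expected)
            .map(|(p, y)| y * p.ln())
            .sum::<f64>()
    }

    /// Gradient w.r.t. the logits: softmax(z) - y.
    pub fn gradient(&self, logits: &[f64], expected: &[f64]) -> Vec<f64> {
        softmax(logits)
            .iter()
            .zip(expected)
            .map(|(p, y)| p - y)
            .collect()
    }
}
```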