arrayfire / arrayfire-ml

ArrayFire's Machine Learning Library.
BSD 3-Clause "New" or "Revised" License

Static Activations Class #26

Closed jramapuram closed 7 years ago

jramapuram commented 8 years ago

The static activations class should provide a str2enum lookup that returns a function pointer (or a similar callable). This would be helpful for things like:

    class LinearNode : public Node {
        private:
            Weights mWeights;
            Weights mDiffs;
            Activation mActivation;

        public:
            LinearNode(const int inputSize, const int outputSize,
                       const std::string &activation = "tanh",
                       float spread = 0.05,
                       const char *name = "none") :
                Node(1, &inputSize, 1, &outputSize, name),
                mWeights(inputSize, outputSize, spread),
                mDiffs(),
                // str2enum-style lookup: resolve the name to a callable activation
                mActivation(Activation::get(activation))
            {}

            ArrayVector forward(const ArrayVector &input) {
                // Activation applied to the weighted input, plus a tiled bias term.
                return {mActivation(af::matmul(mWeights.getWeights(), input[0])) +
                        af::tile(mWeights.getBias(), 1, input[0].dims(1))};
            }
    };

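For illustration, here is one minimal way such a static activations class could look, given the str2enum idea above. This is only a sketch under assumptions: Activation::get, the Fn alias, and the name table are not existing arrayfire-ml API; af::tanh, af::sigmoid, and af::max are real ArrayFire functions.

    #include <arrayfire.h>
    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <utility>

    // Hypothetical static activations class: str2enum-style lookup from a
    // name to a callable that applies the element-wise nonlinearity.
    class Activation {
    public:
        using Fn = std::function<af::array(const af::array &)>;

        Activation() : mFn([](const af::array &x) { return x; }) {}  // identity
        explicit Activation(Fn fn) : mFn(std::move(fn)) {}

        // Apply the wrapped nonlinearity.
        af::array operator()(const af::array &x) const { return mFn(x); }

        // Resolve a string name to an Activation instance.
        static Activation get(const std::string &name) {
            static const std::map<std::string, Fn> table = {
                {"tanh",    [](const af::array &x) { return af::tanh(x); }},
                {"sigmoid", [](const af::array &x) { return af::sigmoid(x); }},
                {"relu",    [](const af::array &x) { return af::max(x, 0.0); }},
            };
            const auto it = table.find(name);
            if (it == table.end())
                throw std::invalid_argument("unknown activation: " + name);
            return Activation(it->second);
        }

    private:
        Fn mFn;
    };

With something like this, the LinearNode constructor above can initialize mActivation via Activation::get(activation).
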
pavanky commented 8 years ago

LinearNode merely performs the matrix multiplication; the activation is applied separately. This is similar to what other neural network libraries and frameworks do, and from what I have looked at it seems like a good design choice. I am not sure changing LinearNode is a good idea. If the combined form is necessary for convenience, you can easily do this in user code.

For example, FFNet does something along these lines, but in a more generalized way. You could write a specialization of FFNet and restrict it to what you have described here.
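
As for the user-code route, a rough sketch is below (linear and input are placeholder names reusing the forward() signature from the snippet above; only af::tanh is real ArrayFire API):

    // Hypothetical user code: keep LinearNode as matmul + bias only and
    // apply the nonlinearity as a separate step afterwards.
    ArrayVector out = linear.forward(input);  // W * x + b
    out[0] = af::tanh(out[0]);                // activation applied outside the node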

jramapuram commented 8 years ago

The linear node was just an example. For an LSTM I cannot have the activations be external to the node: some gates (e.g. the memory gate) depend on the activations of the previous gates, so they cannot be refactored out above the node level.
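
For reference, one step of a standard LSTM cell makes the coupling concrete: the cell-state update consumes the already-activated input and forget gates, and the hidden state consumes the activated output gate. The sketch below uses real af:: element-wise functions, but the pre-activations zi/zf/zo/zg and the shapes are made-up placeholders standing in for W*x + U*h_prev + b.

    #include <arrayfire.h>

    // One LSTM step; zi/zf/zo/zg stand in for the gate pre-activations
    // (weights and recurrent terms omitted; hidden size 4 as an example).
    void lstmStepExample() {
        af::array zi = af::randn(4), zf = af::randn(4);
        af::array zo = af::randn(4), zg = af::randn(4);
        af::array cPrev = af::constant(0.0, 4);

        af::array i = af::sigmoid(zi);    // input gate
        af::array f = af::sigmoid(zf);    // forget gate
        af::array o = af::sigmoid(zo);    // output gate
        af::array g = af::tanh(zg);       // candidate cell state
        af::array c = f * cPrev + i * g;  // uses the *activated* i and f
        af::array h = o * af::tanh(c);    // uses the *activated* o
    }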

pavanky commented 8 years ago

Fair enough. I have not yet started looking into LSTMs. I am currently implementing some core functionality necessary for convolutional nets in ArrayFire. I will keep this in mind when looking at LSTMs.

jramapuram commented 8 years ago

I've already started writing a skeleton for LSTMs. I will submit a pull request soon.

pavanky commented 7 years ago

No longer relevant.