jdermody / brightwire

Bright Wire is an open source machine learning library for .NET with GPU support (via CUDA)
https://github.com/jdermody/brightwire/wiki
MIT License

Lower Level api #23

Open basejn opened 5 years ago

basejn commented 5 years ago

Hi, I want to use BrightWire to implement the following simple interface.

The problem is that the library is very strongly coupled to its data source.

interface IBaseTrainableProbabilisticMulticategoryModel
{
    // Construct with NumClasses and NumFeatures

    int[] NumClasses { get; }
    int NumFeatures { get; }

    /// <summary>
    /// Train on a batch
    /// </summary>
    /// <param name="batchData">[N, numFeatures]</param>
    /// <param name="batchLabels">[N, subclasses] = class id</param>
    /// <param name="weights">[N, subclasses] = weight for each subclass, used to balance classes; scales each data point's contribution to the loss for that subclass</param>
    /// <returns>Loss</returns>
    float Train(float[][] batchData, int[][] batchLabels, float[][] weights);

    /// <summary>
    /// Predict a batch
    /// </summary>
    /// <param name="batchData">[N, numFeatures]</param>
    /// <returns>[subclasses, N, probability distribution]</returns>
    float[][][] Predict(float[][] batchData);
}

Is there a lower level API where I can implement a feed forward neural network model that holds its state (weights), is serializable, and is always ready for training and prediction?

This is easily achievable with CNTK and TensorFlow. Their problem is that they only support 64-bit.

jdermody commented 5 years ago

No, there's no lower level API unfortunately. But you could adapt your data, set the batchsize to N and train for 1 iteration to get what you want...

You would need to implement an IDataSource that accepted the first two arguments to your Train method (something like VectorDataSource).

For your weights argument you would need to implement an IAction that weighted the error signal based on the weights (something like ConstrainSignal) and add it to the graph just before the backpropagation.
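The weighting described above can be illustrated without any BrightWire types. The following is a minimal sketch of the arithmetic only, assuming the error signal is available as one row per sample and that each subclass group occupies a consecutive run of output columns. `WeightedErrorSignal`, `Apply` and `groupSizes` are hypothetical names for illustration and are not part of the library; a real implementation would do the same scaling inside a custom IAction.

```csharp
using System;

public static class WeightedErrorSignal
{
    // Hypothetical helper, not a BrightWire API.
    // errorSignal[n][k]: error for sample n at output column k.
    // weights[n][g]:     weight of sample n for subclass group g.
    // groupSizes[g]:     number of consecutive output columns in group g.
    public static float[][] Apply(float[][] errorSignal, float[][] weights, int[] groupSizes)
    {
        var result = new float[errorSignal.Length][];
        for (int n = 0; n < errorSignal.Length; n++)
        {
            result[n] = new float[errorSignal[n].Length];
            int offset = 0;
            for (int g = 0; g < groupSizes.Length; g++)
            {
                // Scale every column of this group by the sample's weight for the group.
                for (int k = 0; k < groupSizes[g]; k++)
                    result[n][offset + k] = errorSignal[n][offset + k] * weights[n][g];
                offset += groupSizes[g];
            }
        }
        return result;
    }
}
```

A weight of 0 removes a sample's contribution to that group's loss entirely, and a weight above 1 amplifies it, which is how class balancing would be expressed under these assumptions.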

basejn commented 5 years ago

Thank you. I will give it a try. The only thing that worries me is whether I can have N groups of softmax outputs, i.e. N probability distributions. For example, one thing can have 3 sub categories: [dog, cat, horse] [male, female] [white, brown, black].
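One way to picture the three category groups in this example is as a single target vector made of concatenated one-hot blocks, one block per group. The sketch below shows that encoding on plain arrays. It is an assumption about how the batchLabels rows could be prepared for a vector-based data source, not something the thread or the library prescribes, and `LabelEncoding` and `ToTargetVector` are hypothetical names.

```csharp
using System;
using System.Linq;

public static class LabelEncoding
{
    // Hypothetical helper, not a BrightWire API.
    // classIds[g]:   chosen class within group g, e.g. {1, 0, 2} = cat, male, black.
    // numClasses[g]: number of classes in group g, e.g. {3, 2, 3}.
    // Returns the one-hot blocks of all groups joined into one vector.
    public static float[] ToTargetVector(int[] classIds, int[] numClasses)
    {
        if (classIds.Length != numClasses.Length)
            throw new ArgumentException("Expected one class id per group.");

        var target = new float[numClasses.Sum()];
        int offset = 0;
        for (int g = 0; g < numClasses.Length; g++)
        {
            if (classIds[g] < 0 || classIds[g] >= numClasses[g])
                throw new ArgumentOutOfRangeException(nameof(classIds));
            target[offset + classIds[g]] = 1f;
            offset += numClasses[g];
        }
        return target;
    }
}
```

With the groups above, the ids {1, 0, 2} (cat, male, black) give a vector of length 8 with ones at positions 1, 3 and 7.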

jdermody commented 5 years ago

You would need to create a custom Softmax activation to achieve that. You could use the existing softmax activation as a guide, but instead of calculating softmax over the entire matrix it would split the matrix into one block of columns per distribution (GetNewMatrixFromColumns), apply softmax to each block in both the forward and backward pass, and then combine the blocks again as the output.
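The split-and-recombine idea can be sketched on a single output row with plain arrays. This is only an illustration of the math, assuming each distribution occupies a consecutive slice of the row. `GroupedSoftmax`, `Forward`, `Backward` and `groupSizes` are hypothetical names, no BrightWire types are used, and the sign convention of the library's error signal may differ from the gradient convention used here.

```csharp
using System;
using System.Linq;

public static class GroupedSoftmax
{
    // Hypothetical helper, not a BrightWire API.
    // Applies softmax independently to each consecutive slice of `logits`.
    // groupSizes = {3, 2, 3} treats indices 0-2, 3-4 and 5-7 as separate distributions.
    public static float[] Forward(float[] logits, int[] groupSizes)
    {
        if (groupSizes.Sum() != logits.Length)
            throw new ArgumentException("Group sizes must sum to the number of logits.");

        var output = new float[logits.Length];
        int offset = 0;
        foreach (int size in groupSizes)
        {
            // Subtract the slice maximum before exponentiating for numerical stability.
            float max = float.NegativeInfinity;
            for (int i = 0; i < size; i++)
                max = Math.Max(max, logits[offset + i]);

            float sum = 0f;
            for (int i = 0; i < size; i++)
            {
                output[offset + i] = (float)Math.Exp(logits[offset + i] - max);
                sum += output[offset + i];
            }
            for (int i = 0; i < size; i++)
                output[offset + i] /= sum;

            offset += size;
        }
        return output;
    }

    // Propagates an upstream gradient through each slice's softmax separately:
    // grad[i] = s[i] * (g[i] - sum_j s[j] * g[j]), with j ranging over the same slice.
    public static float[] Backward(float[] output, float[] upstream, int[] groupSizes)
    {
        var gradient = new float[output.Length];
        int offset = 0;
        foreach (int size in groupSizes)
        {
            float dot = 0f;
            for (int i = 0; i < size; i++)
                dot += output[offset + i] * upstream[offset + i];
            for (int i = 0; i < size; i++)
                gradient[offset + i] = output[offset + i] * (upstream[offset + i] - dot);
            offset += size;
        }
        return gradient;
    }
}
```

Each slice of the forward output sums to 1, so a row with group sizes {3, 2, 3} carries three independent probability distributions side by side.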

PatriceDargenton commented 2 years ago

Hi, how is it possible to use the NuGet packages if "No, there's no lower level API"? Is there an example using these packages? https://www.nuget.org/packages/BrightWire/ https://www.nuget.org/packages/BrightWire.Net4/

jdermody commented 2 years ago

Example code for https://www.nuget.org/packages/BrightWire/ (v3) is here: https://github.com/jdermody/brightwire/tree/master/ExampleCode

Example code for https://www.nuget.org/packages/BrightWire.Net4/ (v2) is here: https://github.com/jdermody/brightwire-v2/tree/master/SampleCode

PatriceDargenton commented 2 years ago

Ok, all these examples work fine when tested in their original solution. But if I add a NuGet package, for example BrightWire.Net4, to my personal project and try to compile the XOR example, I cannot use DataTableBuilder because its accessibility is Friend (internal in C#) and not Public: it is impossible to compile in a personal project!

jdermody commented 2 years ago

I see your point - the current design is that everything is created through indirection and the classes themselves are not public from the assembly. The drawback of this is that the usage is less obvious.

For example in BrightWire.Net4 you create a data table builder like this:

var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();
dataTableBuilder.AddColumn(ColumnType.Float, "capital costs");
dataTableBuilder.AddColumn(ColumnType.Float, "labour costs");
dataTableBuilder.AddColumn(ColumnType.Float, "energy costs");
dataTableBuilder.AddColumn(ColumnType.Float, "output", true);

In BrightWire (.net core) you do it like this:

var context = new BrightDataContext();
var dataTableBuilder = context.BuildTable();
dataTableBuilder.AddColumn(BrightDataType.Float, "capital costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "labour costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "energy costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "output").SetTarget(true);

BrightWire (.net core) is more consistent in that everything is available through the context, which acts as an extension point via extension methods in other assemblies. I suppose this approach is one of framework rather than library, but perhaps the better design is to be both.

PatriceDargenton commented 2 years ago

Ok, thanks! Now I have a problem with:

var testData = graph.CreateDataSource(data);

Visual Studio recommends using one of these 3 possibilities, but none can be cast from BrightWire.TabularData.Helper.DataTableWriter (InvalidCastException):

var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatTensor)));
var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatMatrix)));
var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatVector)));

PatriceDargenton commented 2 years ago

Ok, I found out how it works, thank you very much! Here it is with BrightWire.Net4 in VB .NET:

Public Sub TestXOR()

    Dim lap = BrightWireProvider.CreateLinearAlgebra
    ' Some training data that the network will learn. The XOR pattern looks like:
    ' 0 0 => 0
    ' 1 0 => 1
    ' 0 1 => 1
    ' 1 1 => 0
    Dim builder = BrightWireProvider.CreateDataTableBuilder()
    builder.AddColumn(ColumnType.Float, "X")
    builder.AddColumn(ColumnType.Float, "Y")
    builder.AddColumn(ColumnType.Float, "XOR", True)
    builder.Add(0.0!, 0.0!, 0.0!)
    builder.Add(1.0!, 0.0!, 1.0!)
    builder.Add(0.0!, 1.0!, 1.0!)
    builder.Add(1.0!, 1.0!, 0.0!)
    Dim dataTable = builder.Build()

    ' create the graph
    Dim graph = New GraphFactory(lap)
    Dim errorMetric = graph.ErrorMetric.CrossEntropy
    graph.CurrentPropertySet.Use(graph.GradientDescent.RmsProp)
    ' and gaussian weight initialisation
    graph.CurrentPropertySet.Use(graph.WeightInitialisation.Gaussian)
    ' create the engine
    Dim testData = graph.CreateDataSource(dataTable)

    Dim engine = graph.CreateTrainingEngine(testData, learningRate:=0.1!, batchSize:=4)
    ' create the network
    Const HIDDEN_LAYER_SIZE As Integer = 6
    With graph.Connect(engine)
        ' create a feed forward layer with sigmoid activation
        .AddFeedForward(HIDDEN_LAYER_SIZE).Add(graph.SigmoidActivation)
        ' create a second feed forward layer with sigmoid activation
        .AddFeedForward(engine.DataSource.OutputSize).Add(graph.SigmoidActivation)
        ' calculate the error and backpropagate the error signal
        .AddBackpropagation(errorMetric)
    End With

    ' train the network
    Dim executionContext = graph.CreateExecutionContext
    For i = 0 To 999
        engine.Train(executionContext)
        If i Mod 100 = 0 Then engine.Test(testData, errorMetric)
    Next

    engine.Test(testData, errorMetric)
    ' create a new network to execute the learned network
    Dim networkGraph = engine.Graph
    Dim executionEngine = graph.CreateEngine(networkGraph)
    Dim output = executionEngine.Execute(testData)
    Debug.WriteLine("Average error = " &
        output.Average(Function(o) o.CalculateError(errorMetric)))

    ' print the values that have been learned
    For Each item In output
        For Each index In item.MiniBatchSequence.MiniBatch.Rows
            Dim row = dataTable.GetRow(index)
            Dim result = item.Output(index)
            Dim rOutput! = result.Data(0)
            Debug.WriteLine(
                row.GetField(Of Single)(0) & " XOR " &
                row.GetField(Of Single)(1) & " = " & rOutput.ToString("0.00000"))
        Next
    Next

End Sub