Open basejn opened 5 years ago
No, there's no lower-level API, unfortunately. But you could adapt your data, set the batch size to N and train for one iteration to get what you want.
You would need to implement an IDataSource that accepted the first two arguments to your Train method (something like VectorDataSource).
For your weights argument you would need to implement an IAction that weighted the error signal based on the weights (something like ConstrainSignal) and add it to the graph just before the backpropagation.
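To make the weighting idea concrete, here is a minimal sketch in plain C#. It does not use any BrightWire types: the class name WeightedErrorSignal, the method Apply, and the jagged-array representation of the error matrix (one row per sample in the batch) are all assumptions for illustration. A real IAction would have to do the equivalent on the library's own matrix type.

```csharp
using System;

// Hypothetical helper, not part of BrightWire.
public static class WeightedErrorSignal
{
    // errorSignal[i] is the error row for sample i in the batch;
    // weights[i] is the weight for that sample.
    public static float[][] Apply(float[][] errorSignal, float[] weights)
    {
        if (errorSignal.Length != weights.Length)
            throw new ArgumentException("One weight is required per sample.");

        var result = new float[errorSignal.Length][];
        for (int i = 0; i < errorSignal.Length; i++)
        {
            var row = errorSignal[i];
            var scaled = new float[row.Length];
            // Scale every component of this sample's error by its weight.
            for (int j = 0; j < row.Length; j++)
                scaled[j] = row[j] * weights[i];
            result[i] = scaled;
        }
        return result;
    }
}
```

A weight of 0 removes a sample's contribution to the gradient entirely, while a weight of 1 leaves it unchanged.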
Thank you. I will give it a try. The only thing that worries me is whether I can have N groups of softmax outputs, i.e. N probability distributions. For example, one thing can have 3 subcategories: [dog, cat, horse] [male, female] [white, brown, black].
You would need to create a custom Softmax activation to achieve that. You could use the existing softmax activation as a guide but instead of calculating softmax over the entire matrix it could split the matrix both forward and backward on each distribution (GetNewMatrixFromColumns) and then combine them again afterwards as the output.
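The split-and-recombine idea can be sketched in plain C# as follows. This is not BrightWire code: the class GroupedSoftmax, its methods, and the use of a flat float array with a list of group sizes are assumptions made for illustration. A custom activation would apply the same logic to each row of the library's matrix.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of BrightWire.
public static class GroupedSoftmax
{
    // Forward pass: an independent, numerically stable softmax per group.
    public static float[] Forward(float[] input, int[] groupSizes)
    {
        if (groupSizes.Sum() != input.Length)
            throw new ArgumentException("Group sizes must sum to the input length.");

        var output = new float[input.Length];
        int offset = 0;
        foreach (int size in groupSizes)
        {
            float max = float.NegativeInfinity;
            for (int i = offset; i < offset + size; i++)
                max = Math.Max(max, input[i]);

            float sum = 0f;
            for (int i = offset; i < offset + size; i++)
            {
                output[i] = (float)Math.Exp(input[i] - max);
                sum += output[i];
            }
            for (int i = offset; i < offset + size; i++)
                output[i] /= sum;

            offset += size;
        }
        return output;
    }

    // Backward pass: given the forward output s and the upstream gradient g,
    // each group's input gradient is s_i * (g_i - sum_j(g_j * s_j)).
    public static float[] Backward(float[] softmaxOutput, float[] gradient, int[] groupSizes)
    {
        var result = new float[softmaxOutput.Length];
        int offset = 0;
        foreach (int size in groupSizes)
        {
            float dot = 0f;
            for (int i = offset; i < offset + size; i++)
                dot += gradient[i] * softmaxOutput[i];
            for (int i = offset; i < offset + size; i++)
                result[i] = softmaxOutput[i] * (gradient[i] - dot);
            offset += size;
        }
        return result;
    }
}
```

For the example above, the group sizes would be { 3, 2, 3 }: positions 0 to 2 form the [dog, cat, horse] distribution, 3 to 4 the [male, female] distribution, and 5 to 7 the [white, brown, black] distribution, each summing to 1.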
Hi, how is it possible to use the NuGet packages if "No, there's no lower level API"? Is there an example using these packages? https://www.nuget.org/packages/BrightWire/ https://www.nuget.org/packages/BrightWire.Net4/
Example code for https://www.nuget.org/packages/BrightWire/ (v3) is here: https://github.com/jdermody/brightwire/tree/master/ExampleCode
Example code for https://www.nuget.org/packages/BrightWire.Net4/ (v2) is here: https://github.com/jdermody/brightwire-v2/tree/master/SampleCode
OK, testing all these examples in their original solution works fine. But if I add a NuGet package, for example BrightWire.Net4, to my own project and try to compile the XOR example, I cannot use DataTableBuilder because its accessibility is Friend rather than Public, so the example cannot compile in a personal project!
I see your point - the current design is that everything is created through indirection and the classes themselves are not public from the assembly. The drawback of this is that the usage is less obvious.
For example in BrightWire.Net4 you create a data table builder like this:
var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();
dataTableBuilder.AddColumn(ColumnType.Float, "capital costs");
dataTableBuilder.AddColumn(ColumnType.Float, "labour costs");
dataTableBuilder.AddColumn(ColumnType.Float, "energy costs");
dataTableBuilder.AddColumn(ColumnType.Float, "output", true);
In BrightWire (.NET Core) you do it like this:
var context = new BrightDataContext();
var dataTableBuilder = context.BuildTable();
dataTableBuilder.AddColumn(BrightDataType.Float, "capital costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "labour costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "energy costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "output").SetTarget(true);
BrightWire (.NET Core) is more consistent in that everything is available through the context, which acts as an extension point via extension methods in other assemblies. I suppose this approach is that of a framework rather than a library, but perhaps the better design is to be both.
OK, thanks! Now I have a problem with:
var testData = graph.CreateDataSource(data);
Visual Studio recommends one of these three possibilities, but none of them can be cast from BrightWire.TabularData.Helper.DataTableWriter (InvalidCastException):
var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatTensor)));
var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatMatrix)));
var testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatVector)));
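OK, I found out how it works, thank you very much! The data source has to be created from the built data table (builder.Build()), not from the builder itself. Here it is with BrightWire.Net4 in VB.NET: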
Public Sub TestXOR()
    Dim lap = BrightWireProvider.CreateLinearAlgebra

    ' Some training data that the network will learn. The XOR pattern looks like:
    ' 0 0 => 0
    ' 1 0 => 1
    ' 0 1 => 1
    ' 1 1 => 0
    Dim builder = BrightWireProvider.CreateDataTableBuilder()
    builder.AddColumn(ColumnType.Float, "X")
    builder.AddColumn(ColumnType.Float, "Y")
    builder.AddColumn(ColumnType.Float, "XOR", True)
    builder.Add(0.0!, 0.0!, 0.0!)
    builder.Add(1.0!, 0.0!, 1.0!)
    builder.Add(0.0!, 1.0!, 1.0!)
    builder.Add(1.0!, 1.0!, 0.0!)

    ' build the data table (the data source is created from this, not from the builder)
    Dim dataTable = builder.Build()

    ' create the graph
    Dim graph = New GraphFactory(lap)
    Dim errorMetric = graph.ErrorMetric.CrossEntropy

    ' use rmsprop gradient descent
    graph.CurrentPropertySet.Use(graph.GradientDescent.RmsProp)
    ' and gaussian weight initialisation
    graph.CurrentPropertySet.Use(graph.WeightInitialisation.Gaussian)

    ' create the engine
    Dim testData = graph.CreateDataSource(dataTable)
    Dim engine = graph.CreateTrainingEngine(testData, learningRate:=0.1!, batchSize:=4)

    ' create the network
    Const HIDDEN_LAYER_SIZE As Integer = 6
    With graph.Connect(engine)
        ' create a feed forward layer with sigmoid activation
        .AddFeedForward(HIDDEN_LAYER_SIZE).Add(graph.SigmoidActivation)
        ' create a second feed forward layer with sigmoid activation
        .AddFeedForward(engine.DataSource.OutputSize).Add(graph.SigmoidActivation)
        ' calculate the error and backpropagate the error signal
        .AddBackpropagation(errorMetric)
    End With

    ' train the network, testing every 100 iterations
    Dim executionContext = graph.CreateExecutionContext
    Dim i = 0
    Do While i < 1000
        engine.Train(executionContext)
        If i Mod 100 = 0 Then engine.Test(testData, errorMetric)
        i += 1
    Loop
    engine.Test(testData, errorMetric)

    ' create a new network to execute the learned network
    Dim networkGraph = engine.Graph
    Dim executionEngine = graph.CreateEngine(networkGraph)
    Dim output = executionEngine.Execute(testData)
    Debug.WriteLine("Average output = " &
        output.Average(Function(o) o.CalculateError(errorMetric)))

    ' print the values that have been learned
    For Each item In output
        For Each index In item.MiniBatchSequence.MiniBatch.Rows
            Dim row = dataTable.GetRow(index)
            Dim result = item.Output(index)
            Dim rOutput! = result.Data(0)
            Debug.WriteLine(
                row.GetField(Of Single)(0) & " XOR " &
                row.GetField(Of Single)(1) & " = " & rOutput.ToString("0.00000"))
        Next
    Next
End Sub
Hi, I want to use BrightWire to implement the following simple interface.
The problem is that the library is very strongly coupled to its data source.
Is there a lower-level API where I can implement a feed-forward neural network model that holds its state (weights), is serializable, and is always ready for training and prediction?
This is easily achievable with CNTK and TensorFlow. Their problem is that they only support 64-bit.