maniospas closed this issue 2 months ago
Progress: began separating the base `ModelTraining` class from `NodeClassification`, which becomes a training strategy. As of v1.3.28, both have been moved into the `adhoc` package.
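The split between a base trainer and strategy-specific subclasses can be illustrated with a small template-method sketch: the base class owns the epoch loop, and each strategy decides how its data are stored and batched. All names here (`BaseTraining`, `NodeClassificationSketch`, `getBatches`, `trainOnBatch`, `setTrainNodes`) are hypothetical and are not JGNN's actual API. The returned "loss" is a placeholder so that the example runs without a model.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base trainer: owns the training loop (template method),
// while subclasses supply data handling and batch generation.
abstract class BaseTraining<S> {
    protected int epochs = 1;
    protected int numBatches = 1;

    // Fluent setters return the trainer so that calls can be chained.
    public BaseTraining<S> setEpochs(int epochs) { this.epochs = epochs; return this; }
    public BaseTraining<S> setNumBatches(int numBatches) { this.numBatches = numBatches; return this; }

    // Strategy-specific hooks, overridden by each training strategy.
    protected abstract List<List<S>> getBatches();
    protected abstract double trainOnBatch(List<S> batch);

    // Fixed loop; returns the summed placeholder loss of the last epoch.
    public double train() {
        double lastEpochLoss = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            lastEpochLoss = 0;
            for (List<S> batch : getBatches())
                lastEpochLoss += trainOnBatch(batch);
        }
        return lastEpochLoss;
    }
}

// Hypothetical node classification strategy: samples are node indices.
class NodeClassificationSketch extends BaseTraining<Integer> {
    private final List<Integer> trainNodes = new ArrayList<>();
    int batchesSeen = 0;

    public NodeClassificationSketch setTrainNodes(List<Integer> nodes) {
        trainNodes.clear();
        trainNodes.addAll(nodes);
        return this;
    }

    @Override
    protected List<List<Integer>> getBatches() {
        List<List<Integer>> batches = new ArrayList<>();
        int size = (trainNodes.size() + numBatches - 1) / numBatches; // ceiling division
        for (int start = 0; start < trainNodes.size(); start += size)
            batches.add(trainNodes.subList(start, Math.min(start + size, trainNodes.size())));
        return batches;
    }

    @Override
    protected double trainOnBatch(List<Integer> batch) {
        batchesSeen++;
        return batch.size(); // placeholder: a real strategy would run the model and compute a loss
    }
}
```

Keeping the loop in the base class means a new task (for example, graph classification) only needs to override the two hooks.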
In the final design, trainers are given all the necessary data and internally override the methods that set data and generate batches. The resulting interface automates the training of attributed graph functions as follows:
```java
// dtrain (graphs, features, labels) and builder are assumed to be defined earlier.
ModelTraining trainer = new AGFTraining() // for any attributed graph function
    .setGraphs(dtrain.graphs)
    .setNodeFeatures(dtrain.features)
    .setGraphLabels(dtrain.labels)
    .setValidationSplit(0.2)
    .setEpochs(300)
    .setOptimizer(new Adam(0.001))
    .setLoss(new CategoricalCrossEntropy())
    //.setNumBatches(10)
    //.setParallelizedStochasticGradientDescent(true)
    .setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()));

Model model = builder.getModel()
    .init(new XavierNormal())
    .train(trainer);
```
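The `setValidationSplit(0.2)` and (commented-out) `setNumBatches(10)` calls imply that the trainer partitions the graphs itself. Below is a minimal sketch of one way to do that, assuming a random hold-out split followed by contiguous, nearly equal batches; `GraphSplitSketch` and its members are hypothetical and not part of JGNN.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper: how a graph classification trainer might split graph ids
// into training/validation sets and then into batches. Not JGNN's API.
final class GraphSplitSketch {
    final List<Integer> trainIds = new ArrayList<>();
    final List<Integer> validationIds = new ArrayList<>();

    GraphSplitSketch(int numGraphs, double validationSplit, long seed) {
        List<Integer> ids = new ArrayList<>();
        for (int i = 0; i < numGraphs; i++) ids.add(i);
        // Shuffle once with a fixed seed so the split is random but reproducible.
        Collections.shuffle(ids, new Random(seed));
        int numValidation = (int) Math.round(numGraphs * validationSplit);
        validationIds.addAll(ids.subList(0, numValidation));
        trainIds.addAll(ids.subList(numValidation, numGraphs));
    }

    // Split the training ids into numBatches contiguous, nearly equal batches.
    List<List<Integer>> batches(int numBatches) {
        List<List<Integer>> result = new ArrayList<>();
        int n = trainIds.size();
        for (int b = 0; b < numBatches; b++) {
            int from = b * n / numBatches;
            int to = (b + 1) * n / numBatches;
            if (from < to) result.add(trainIds.subList(from, to));
        }
        return result;
    }
}
```

With 100 graphs and a 0.2 split, this holds out 20 graphs for validation and yields 10 batches of 8 training graphs each.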
Let `ModelTraining` or a similar class support graph classification tasks to streamline training.