OpenMined / SwiftSyft

The official Syft worker for iOS, built in Swift
Apache License 2.0

Training API #172

Open: vvmnnnkv opened this issue 4 years ago

vvmnnnkv commented 4 years ago

Feature Description

Add a convenience API for the training cycle so that users don't need to write the training loop from scratch each time.

Add new methods to Job:

  1. Job.request(): does what the .start() method currently does (authentication, downloading the model and plan).

  2. trainingProcess = Job.train(trainingPlan, parameters): helper for the training loop.

trainingProcess - object containing the current epoch, batch, and modelParameters
trainingPlan - string, the key of the training plan in job.plans
parameters - dict of values:

  planInputs: list of PlanInputSpec
  planOutputs: list of PlanOutputSpec
  data: tensor
  target: (optional) tensor
  epochs: number - how many epochs to train
  batchSize: number
  stepsPerEpoch: (optional) number - maximum number of steps per epoch
  events: list of handlers for the events 'start', 'end', 'epochStart', 'epochEnd', 'batchStart', 'batchEnd', 'error'
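To make the shape of `parameters` concrete, here is a minimal Python sketch of it as a typed structure. The snake_case field names, the `TrainingParameters` class, and the use of plain lists in place of tensors are all hypothetical, not an existing Syft or SwiftSyft API. It also shows one way to resolve the effective number of steps when stepsPerEpoch is treated as an upper bound, as the description above says.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

# Event names listed in the proposal.
EVENT_NAMES = {"start", "end", "epochStart", "epochEnd", "batchStart", "batchEnd", "error"}


@dataclass
class TrainingParameters:
    """Hypothetical container for the `parameters` dict passed to Job.train()."""
    plan_inputs: List[Any]
    plan_outputs: List[Any]
    data: List[Any]                        # stand-in for a tensor of samples
    epochs: int
    batch_size: int
    target: Optional[List[Any]] = None     # optional, as in the proposal
    steps_per_epoch: Optional[int] = None  # optional upper bound
    events: Dict[str, Callable[..., None]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject handler names the training loop would never fire.
        unknown = set(self.events) - EVENT_NAMES
        if unknown:
            raise ValueError(f"unknown event names: {sorted(unknown)}")
        if self.epochs < 1 or self.batch_size < 1:
            raise ValueError("epochs and batch_size must be positive")

    def resolved_steps_per_epoch(self) -> int:
        # Enough steps to cover the data once (the last batch may be partial)...
        full = math.ceil(len(self.data) / self.batch_size)
        # ...capped by stepsPerEpoch, which is defined as a maximum.
        if self.steps_per_epoch is None:
            return full
        return min(self.steps_per_epoch, full)
```

For example, 10 samples with a batch size of 4 give 3 steps per epoch, and a stepsPerEpoch of 2 would cap that at 2.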

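PlanInputSpec: object that describes one plan input argument

  type: 'data' | 'target' | 'epoch' | 'batchSize' | 'step' | 'modelParameter' | 'value'
  index: (optional) number - position within a list-valued input such as modelParameter
  name: (optional) string
  value: (optional) tensor - required when type is 'value'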

PlanOutputSpec: object that describes one plan output

  type: 'loss' | 'metric' | 'modelParameter'
  index: (optional) number - position within a list-valued output such as modelParameter
  name: (optional) string
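The two spec objects could be modelled as small immutable records. The sketch below is a hypothetical Python rendering (the class names follow the proposal, but nothing here is an existing library API). The `Literal` types only document the allowed values and are not enforced at runtime; the only runtime check is that a 'value' spec actually carries a value.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional

InputType = Literal["data", "target", "epoch", "batchSize", "step", "modelParameter", "value"]
OutputType = Literal["loss", "metric", "modelParameter"]


@dataclass(frozen=True)
class PlanInputSpec:
    type: InputType
    index: Optional[int] = None  # selects one element of a list-valued input
    name: Optional[str] = None
    value: Any = None            # only meaningful when type == "value"

    def __post_init__(self) -> None:
        if self.type == "value" and self.value is None:
            raise ValueError("a 'value' input spec must carry a value")


@dataclass(frozen=True)
class PlanOutputSpec:
    type: OutputType
    index: Optional[int] = None  # position within a list-valued output
    name: Optional[str] = None
```

Keeping index as `Optional[int]` rather than defaulting it to 0 matters: index 0 is a real position, so "no index" has to be a distinct value.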

Pseudocode for the training loop:

train(trainingPlan, parameters):
  // stepsPerEpoch is an upper bound; by default cover the data once per epoch
  numBatches = ceil(len(data) / batchSize)
  stepsPerEpoch = stepsPerEpoch ? min(stepsPerEpoch, numBatches) : numBatches
  trigger_event('start')
  modelParameters = job.model.parameters
  for (i = 0; i < epochs; i++) {
    trigger_event('epochStart', (i))
    for (j = 0; j < stepsPerEpoch; j++) {
      trigger_event('batchStart', (i, j))
      // fetch a new batch on every step, not once before the loop
      x = get_batch(data, batchSize, j)
      y = target ? get_batch(target, batchSize, j) : null
      plan_args = resolve_inputs(planInputs,
        {
          modelParameter: modelParameters,
          data: x,
          target: y,
          epoch: i,
          batchSize: batchSize,
          step: j
        }
      )
      raw_outputs = job.plans[trainingPlan].execute(...plan_args)
      outputs = resolve_outputs(planOutputs, raw_outputs)
      status = {loss: outputs.loss, metric: outputs.metric}
      // feed the updated parameters into the next step
      modelParameters = outputs.modelParameter
      trigger_event('batchEnd', (i, j, status))
    }
    trigger_event('epochEnd', (i))
  }
  trigger_event('end')
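The loop can be exercised end to end in plain Python. In this sketch, a hypothetical one-parameter "plan" (one SGD step for y = w * x) stands in for a real Syft plan, and `build_args` / `parse_outputs` play the role of resolve_inputs / resolve_outputs already bound to planInputs / planOutputs; they are injected as callables only to keep the example self-contained. Routing exceptions to an 'error' handler, and re-raising when none is registered, is an assumed interpretation of the 'error' event.

```python
import math
from typing import Any, Callable, Dict, List, Optional

Handler = Callable[..., None]


def get_batch(items: Optional[List[Any]], batch_size: int, step: int) -> Optional[List[Any]]:
    """Return the step-th batch; the last one may be shorter than batch_size."""
    if items is None:
        return None
    start = step * batch_size
    return items[start:start + batch_size]


def train(plan, model_params, data, epochs, batch_size, build_args, parse_outputs,
          target=None, steps_per_epoch=None, events: Optional[Dict[str, Handler]] = None):
    """Run `plan` over `data` for `epochs` epochs, firing the proposed events."""
    events = events or {}

    def trigger(name: str, *args: Any) -> None:
        if name in events:
            events[name](*args)

    full = math.ceil(len(data) / batch_size)
    steps = full if steps_per_epoch is None else min(steps_per_epoch, full)
    try:
        trigger("start")
        for epoch in range(epochs):
            trigger("epochStart", epoch)
            for step in range(steps):
                trigger("batchStart", epoch, step)
                plan_args = build_args({
                    "modelParameter": model_params,
                    "data": get_batch(data, batch_size, step),
                    "target": get_batch(target, batch_size, step),
                    "epoch": epoch,
                    "batchSize": batch_size,
                    "step": step,
                })
                outputs = parse_outputs(plan(*plan_args))
                model_params = outputs["modelParameter"]  # fed into the next step
                status = {"loss": outputs.get("loss"), "metric": outputs.get("metric")}
                trigger("batchEnd", epoch, step, status)
            trigger("epochEnd", epoch)
        trigger("end")
    except Exception as err:
        if "error" not in events:
            raise
        trigger("error", err)  # hand the failure to the registered handler
    return model_params


def toy_plan(xs, ys, lr, w):
    """Hypothetical stand-in plan: one SGD step for y = w * x under mean squared error."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return [loss, w - lr * grad]


if __name__ == "__main__":
    final = train(
        toy_plan, [0.0], data=[1.0, 2.0, 3.0, 4.0], epochs=10, batch_size=2,
        build_args=lambda v: [v["data"], v["target"], 0.05, v["modelParameter"][0]],
        parse_outputs=lambda raw: {"loss": raw[0], "modelParameter": [raw[1]]},
        target=[2.0, 4.0, 6.0, 8.0],
        events={"epochEnd": lambda e: print("finished epoch", e)},
    )
    print("learned w:", final[0])  # approaches 2.0
```

Because the data follows y = 2x, the learned weight should converge towards 2.0 over the ten epochs.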

Resolving plan inputs/outputs from specs:

resolve_inputs(specs, vars) {
  args = []
  for (spec in specs) {
    if (spec.type == 'value') {
      args.push(spec.value)
    } else if (spec.index != null) {
      // index 0 is valid, so test for presence rather than truthiness
      args.push(vars[spec.type][spec.index])
    } else {
      args.push(vars[spec.type])
    }
  }
  return args
}
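A direct Python translation of resolve_inputs, with specs written as plain dicts that mirror the object notation used here, might look like the sketch below. Dict-shaped specs are an assumption made for brevity.

```python
from typing import Any, List, Mapping


def resolve_inputs(specs: List[Mapping[str, Any]], variables: Mapping[str, Any]) -> List[Any]:
    """Build the positional argument list for a plan from its input specs."""
    args: List[Any] = []
    for spec in specs:
        kind = spec["type"]
        if kind == "value":
            # Constants (e.g. a learning rate) are embedded in the spec itself.
            args.append(spec["value"])
        elif spec.get("index") is not None:
            # Index 0 is valid, so check for presence rather than truthiness.
            args.append(variables[kind][spec["index"]])
        else:
            args.append(variables[kind])
    return args
```

With a truthiness test (`elif spec.get("index"):`), a spec with index 0 would fall through to the last branch and pass the whole parameter list instead of its first element.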

resolve_outputs(specs, output) {
  out = {}
  i = 0
  for (spec in specs) {
    if (spec.index != null) {
      out[spec.type] = out[spec.type] || []
      out[spec.type][spec.index] = output[i]
    } else {
      out[spec.type] = output[i]
    }
    i++
  }
  return out
}
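The matching Python sketch of resolve_outputs pairs each spec with the plan output at the same position. The length check is an added safeguard (an assumption, not part of the proposal) so that a spec list that doesn't match the plan fails loudly instead of silently dropping outputs.

```python
from typing import Any, Dict, List, Mapping, Sequence


def resolve_outputs(specs: List[Mapping[str, Any]], outputs: Sequence[Any]) -> Dict[str, Any]:
    """Map a plan's positional outputs back to named fields using output specs."""
    if len(specs) != len(outputs):
        raise ValueError(f"expected {len(specs)} outputs, got {len(outputs)}")
    out: Dict[str, Any] = {}
    for spec, value in zip(specs, outputs):
        index = spec.get("index")
        if index is None:
            out[spec["type"]] = value
        else:
            # Grow the list on demand so indices may arrive in any order.
            slots = out.setdefault(spec["type"], [])
            slots.extend([None] * (index + 1 - len(slots)))
            slots[index] = value
    return out
```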

Example input and output specs for the MNIST training plan (<lr> stands for the learning rate):

Inputs:
[{type: 'data'}, {type: 'target'}, {type: 'batchSize'}, {type: 'value', value: <lr>},
{type: 'modelParameter', index: 0}, {type: 'modelParameter', index: 1},
{type: 'modelParameter', index: 2}, {type: 'modelParameter', index: 3}]

Outputs:
[{type: 'loss'}, {type: 'metric'},
{type: 'modelParameter', index: 0}, {type: 'modelParameter', index: 1},
{type: 'modelParameter', index: 2}, {type: 'modelParameter', index: 3}]
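The MNIST specs can be generated rather than written out by hand. In this hypothetical sketch, `lr` is a function argument so that the <lr> placeholder stays unfilled. The `check_round_trip` helper verifies that every parameter index the plan reads is also written back, which the training loop relies on when it feeds outputs.modelParameter into the next step.

```python
from typing import Any, Dict, List, Tuple

NUM_PARAMS = 4  # the MNIST example passes four model parameters


def param_specs() -> List[Dict[str, Any]]:
    return [{"type": "modelParameter", "index": i} for i in range(NUM_PARAMS)]


def mnist_specs(lr: float) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Input/output specs of the MNIST example; `lr` fills the <lr> placeholder."""
    inputs = [{"type": "data"}, {"type": "target"}, {"type": "batchSize"},
              {"type": "value", "value": lr}] + param_specs()
    outputs = [{"type": "loss"}, {"type": "metric"}] + param_specs()
    return inputs, outputs


def check_round_trip(inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]]) -> None:
    """Every parameter the plan reads must be written back, or the next step
    would run with a stale or missing parameter."""
    read = {s["index"] for s in inputs if s["type"] == "modelParameter"}
    written = {s["index"] for s in outputs if s["type"] == "modelParameter"}
    if read != written:
        raise ValueError(f"parameter indices differ: read {sorted(read)}, written {sorted(written)}")
    if sorted(read) != list(range(len(read))):
        raise ValueError("parameter indices must be contiguous from 0")
```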

What alternatives have you considered?

The API was discussed within the FL team.

Additional Context

n/a