OpenMined / KotlinSyft

The official Syft worker for secure on-device machine learning
https://www.openmined.org
Apache License 2.0

Batch size gets hard coded when we create Training Plan #332

Open · mustansarsaeed opened this issue 3 years ago

mustansarsaeed commented 3 years ago

Description

Hi, thank you for the exciting library. I have observed that when we create a Training Plan, we provide sample data that the Plan should expect during training, but the generated TorchScript hard-codes the batch size.
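
This mirrors a general property of torch.jit.trace: any value that enters the traced graph as a plain Python number is frozen into the emitted TorchScript as a constant. A minimal standalone sketch of that behavior in plain PyTorch (not the Plan API):

```python
import torch

batch_size = 8  # a plain Python int captured by the function

def step(X):
    # batch_size enters the graph as a Python number, so tracing
    # freezes the literal 8 into the emitted TorchScript
    return X.sum() / batch_size

traced = torch.jit.trace(step, torch.randn(batch_size, 2))
print(traced.code)  # the constant 8 appears in the code, not as an input
```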

How to Reproduce

  1. Create a Training Plan.
  2. Pass sample data, including the batch size.
  3. Build the training plan: training_plan.build(X[:20], y[:20], batch_size, lr, model_params, trace_autograd=True)
  4. Print the TorchScript code (the full sequence is consolidated in the sketch after this list):
     training_plan.base_framework = TranslationTarget.PYTORCH.value
     print(training_plan.torchscript.code)
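
Put together, the steps above look roughly like this. The TranslationTarget import path below is the PySyft 0.2.x one and may differ across versions; X, y, batch_size, lr and model_params are whatever your own plan setup defines:

```python
from syft.execution.translation import TranslationTarget  # PySyft 0.2.x path (assumption)

# X, y, batch_size, lr and model_params come from the plan setup
training_plan.build(X[:20], y[:20], batch_size, lr, model_params,
                    trace_autograd=True)

# Switch the target framework and inspect the traced code
training_plan.base_framework = TranslationTarget.PYTORCH.value
print(training_plan.torchscript.code)  # the build-time batch size shows up as a constant
```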

Expected Behavior

The batch size should be dynamic: if it is hard-coded and the total number of samples is not divisible by the batch size, the final partial batch has a smaller first dimension and the TorchScript throws a dimension-mismatch exception.
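
Until the plan can be traced with a dynamic batch dimension, one possible client-side workaround (an illustration, not a fix in the library) is to make the data conform to the traced size, e.g. drop the trailing partial batch. For example, 104 samples with batch_size = 32 leave a final batch of 8 that would trigger the mismatch:

```python
import torch

batch_size = 32
X = torch.randn(104, 10)  # 104 % 32 == 8, so the last batch would be partial
y = torch.randn(104, 1)

# Keep only full batches so every execution matches the baked-in size
num_full = (len(X) // batch_size) * batch_size
for start in range(0, num_full, batch_size):
    X_batch = X[start:start + batch_size]
    y_batch = y[start:start + batch_size]
    # ... execute the built plan on (X_batch, y_batch, ...) here ...
```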