sheryjoe opened 4 years ago
BeforeIteration: compute the shape statistics and the correspondence object's gradient. We can use Python to train the network and compute the correspondence gradient. @AtefehKashani, which APIs are needed from the Optimize library to train the network and compute gradients?
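As a rough illustration of the BeforeIteration step, the shape statistics could be computed in Python from the matrix returned by GetParticleSystem(). This is only a sketch under assumptions: the layout (one column per shape, each column stacking the x, y, z coordinates of every particle) is inferred from the example output in this thread rather than from the API, and `shape_statistics` is a hypothetical helper.

```python
import numpy as np


def shape_statistics(particles: np.ndarray):
    """Return the mean shape, covariance, and PCA modes of a particle matrix.

    Hypothetical helper. ASSUMPTION: `particles` is laid out as
    (3 * num_particles) x num_shapes, i.e. each column holds the stacked
    x, y, z coordinates of one shape.
    """
    num_shapes = particles.shape[1]
    if num_shapes < 2:
        raise ValueError("need at least two shapes for a covariance estimate")

    # Mean shape: average each coordinate across all shapes (columns).
    mean_shape = particles.mean(axis=1)

    # Center every shape on the mean before estimating the covariance.
    centered = particles - mean_shape[:, np.newaxis]

    # Unbiased sample covariance; rows are the variables, columns the samples.
    covariance = centered @ centered.T / (num_shapes - 1)

    # Principal modes of variation, sorted by decreasing variance.
    eigenvalues, eigenvectors = np.linalg.eigh(covariance)
    order = np.argsort(eigenvalues)[::-1]
    return mean_shape, covariance, eigenvalues[order], eigenvectors[:, order]
```

Inside an iteration callback this would be called as `shape_statistics(opt.GetParticleSystem())`. With many particles the full covariance gets large, so a real implementation might work with the smaller num_shapes x num_shapes Gram matrix instead.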
I have added a list of needed APIs here.
Let's define some notation before moving on to the pseudocode.
Here are the steps in nonlinear shape modeling:
@AtefehKashani, if you want to try the pybind_optimize branch, both methods of using the Optimize class with Python are implemented:
https://github.com/SCIInstitute/ShapeWorks/tree/pybind_optimize
Implemented so far:
You can drive the optimizer from Python like this:
```python
import numpy
import shapeworks

opt = shapeworks.Optimize()
opt.LoadParameterFile("optimize_python.xml")

def callback():
    print("python callback")
    particles = opt.GetParticleSystem()
    print(type(particles))
    print(particles)

opt.SetIterationCallbackFunction(callback)
opt.Run()
```
Alternatively, you can have ShapeWorksRun (and later Studio) drive by specifying:
```xml
<python_filename>file.py</python_filename>
```
This module is loaded at startup, and its `run` function is called with the Optimize object as its argument. An example file that accomplishes the same as above:
```python
import shapeworks

# Set by run() to the Optimize object that ShapeWorksRun passes in.
opt = None

def callback():
    print("python callback")
    particles = opt.GetParticleSystem()
    print(type(particles))
    print(particles)

def run(optimizer):
    global opt
    opt = optimizer
    opt.SetIterationCallbackFunction(callback)
```
Example output:
```
...
1. 0ms
python callback
<class 'numpy.ndarray'>
[[-19.12674606 -19.12674606 -23.51327091 -23.51327091]
 [-35.55056751 -35.55056751 -35.6341362  -35.6341362 ]
 [-91.48947597 -91.48947597 -90.42861462 -90.42861462]
 [-21.17568552 -21.17568552 -23.91829342 -23.91829342]
 [-39.20589983 -39.20589983 -37.9530549  -37.9530549 ]
 [-89.5236671  -89.5236671  -89.37270641 -89.37270641]]
Energy: 1.44793
...
```
Let me know which calls, in both directions (to and from Python), you would like added to the API.
pybind11 automatically converts Eigen matrices to NumPy arrays.
Callback List
I'll add in the callbacks here as I need them. I think we are good for now.
Callback | Status | Description |
---|---|---|
GetShapeWorksStage() | 🕒 | Flag indicating whether ShapeWorks is in the initialization stage or the optimization stage. This is needed because some processing applies only to the optimization stage and need not be performed while particles are still being initialized |
GetParticleSystem() | ✅ | Gets the current matrix containing the correspondence particles |
GetCorrespondenceUpdateMatrix() | ✅ | Gets the correspondence updates for particles |
GetParticles() | ✅ | Gets the current matrix containing the correspondence points |
SetCorrespondenceUpdateMatrix(updateMatrix) | ✅ | Set the updates for the correspondence term from Python |
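To make the round trip concrete, here is a minimal sketch of an iteration callback that reads the particles and the current correspondence update, adds a term computed in Python, and pushes the result back with SetCorrespondenceUpdateMatrix(). Assumptions: `make_update_callback`, `compute_python_update`, and `weight` are hypothetical names, the Python term is assumed to have the same shape as the update matrix, and SetCorrespondenceUpdateMatrix() is assumed to accept a NumPy array of that shape.

```python
import numpy as np


def make_update_callback(opt, compute_python_update, weight=1.0):
    """Build a zero-argument iteration callback that blends a Python-computed
    term into the correspondence update.

    ASSUMPTIONS: `compute_python_update` is a hypothetical user function that
    takes the particle matrix and returns an array with the same shape as the
    correspondence update matrix; `weight` is a hypothetical blending factor.
    """
    def callback():
        particles = opt.GetParticleSystem()
        update = np.asarray(opt.GetCorrespondenceUpdateMatrix())
        extra = np.asarray(compute_python_update(particles))
        if extra.shape != update.shape:
            raise ValueError(f"update shape {extra.shape} != {update.shape}")
        # Push the combined update back to the C++ side.
        opt.SetCorrespondenceUpdateMatrix(update + weight * extra)
    return callback
```

It would be registered the same way as the plain callback, e.g. `opt.SetIterationCallbackFunction(make_update_callback(opt, my_gradient))`, where `my_gradient` is again a hypothetical user function.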
I've added GetOptimizing(), which returns true during the optimization phase and false during initialization.
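A callback can use GetOptimizing() to skip work that only matters once particles are past initialization. A minimal sketch; `make_stage_aware_callback` and `on_optimize` are hypothetical names for illustration.

```python
def make_stage_aware_callback(opt, on_optimize):
    """Wrap `on_optimize` so it runs only during the optimization phase.

    `on_optimize` is a hypothetical user function taking the Optimize object.
    GetOptimizing() returns True during optimization and False while
    particles are still being initialized.
    """
    def callback():
        if not opt.GetOptimizing():
            return  # Initialization stage: skip the extra processing.
        on_optimize(opt)
    return callback
```

Usage would be `opt.SetIterationCallbackFunction(make_stage_aware_callback(opt, my_step))`, with `my_step` holding the optimization-only processing.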
Perfect!
@AtefehKashani, please define the APIs and callbacks needed to perform particle optimization in Python for the nonlinear shape model. Adding @cchriste to help Alan and Atefeh get started with the Python APIs (pybind).