voize-gmbh / pytorch-lite-multiplatform

A Kotlin multi-platform wrapper around the PyTorch Lite libraries on Android and iOS.
Apache License 2.0

pytorch-lite-multiplatform


A Kotlin multi-platform wrapper around the PyTorch Lite libraries on Android and iOS. You can use this library in your Kotlin multi-platform project to write mobile inference code for PyTorch Lite models. The API is very close to the Android API of PyTorch Lite. A high-level introduction is available in our blog post.

Installation

Add the following to your shared/build.gradle.kts as a commonMain dependency.

implementation("de.voize:pytorch-lite-multiplatform:<version>")

Add the PLMLibTorchWrapper pod to the cocoapods block in shared/build.gradle.kts and call useLibraries(). This is required because the PLMLibTorchWrapper pod depends on the LibTorch-Lite pod, which contains static libraries.

cocoapods {
    ...

    pod("PLMLibTorchWrapper") {
        version = "<version>"
        headers = "LibTorchWrapper.h"
    }

    useLibraries()
}

If you use a Kotlin version below 1.8.0, the headers property is not available. Instead, add the following to your shared/build.gradle.kts (see this issue for more information):

tasks.named<org.jetbrains.kotlin.gradle.tasks.DefFileTask>("generateDefPLMLibTorchWrapper").configure {
    doLast {
        outputFile.writeText("""
            language = Objective-C
            headers = LibTorchWrapper.h
        """.trimIndent())
    }
}

Additional steps:

Usage

First, export your PyTorch model for the lite interpreter. Your application is responsible for how the exported model file is stored on the device, e.g. bundled with the app, downloaded from a server during app initialization, or provided some other way. Then initialize a TorchModule with the path to the model file.

import de.voize.pytorch_lite_multiplatform.TorchModule

val module = TorchModule(path = "<path/to/model.ptl>")

Once the module is initialized, you are ready to run inference.

Just like in the Android API of PyTorch Lite, you use IValue and Tensor to pass input data into your model and to process the model output. To manage the memory allocated for your tensors, wrap your inference code in plmScoped, which defines how long that memory stays allocated.

import de.voize.pytorch_lite_multiplatform.*

plmScoped {
    val inputTensor = Tensor.fromBlob(
        data = floatArrayOf(...),
        shape = longArrayOf(...),
        scope = this
    )

    val inputIValue = IValue.fromTensor(inputTensor)

    val output = module.forward(inputIValue)
    // you could also use
    // module.runMethod("forward", inputIValue)

    val outputTensor = output.toTensor()
    val outputData = outputTensor.getDataAsFloatArray()

    ...
}
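Tensor.fromBlob takes the tensor contents as a flat array plus a separate shape, and getDataAsFloatArray likewise hands back a flat FloatArray. The hypothetical helpers below sketch how nested input could be flattened and how a flat [rows, cols] output could be read back, assuming a contiguous row-major layout and that you know the output shape from your model. numel, flattenRowMajor and argmaxPerRow are illustrative names, not part of this library.

```kotlin
// Hypothetical helpers, NOT part of pytorch-lite-multiplatform.
// Assumption: data is laid out contiguously in row-major (C) order,
// i.e. element [r, c] of a [rows, cols] tensor lives at index r * cols + c.

/** Number of elements implied by a shape, e.g. [2, 3] -> 6 (an empty shape is a scalar -> 1). */
fun numel(shape: LongArray): Long = shape.fold(1L) { acc, dim -> acc * dim }

/** Flattens equally sized rows into one row-major array and returns it with its [rows, cols] shape. */
fun flattenRowMajor(rows: List<FloatArray>): Pair<FloatArray, LongArray> {
    require(rows.isNotEmpty()) { "need at least one row" }
    val cols = rows[0].size
    require(rows.all { it.size == cols }) { "all rows must have the same length" }
    val data = FloatArray(rows.size * cols)
    // Copy row r into the slice starting at r * cols.
    rows.forEachIndexed { r, row -> row.copyInto(data, destinationOffset = r * cols) }
    return data to longArrayOf(rows.size.toLong(), cols.toLong())
}

/** Returns, for each row of a flat [rows, cols] array, the column index of the largest value. */
fun argmaxPerRow(data: FloatArray, shape: LongArray): IntArray {
    require(shape.size == 2 && numel(shape) == data.size.toLong()) { "shape does not match data" }
    val rows = shape[0].toInt()
    val cols = shape[1].toInt()
    return IntArray(rows) { r ->
        var best = 0
        for (c in 1 until cols) {
            if (data[r * cols + c] > data[r * cols + best]) best = c
        }
        best
    }
}
```

The `data` and `shape` returned by flattenRowMajor match the `data` and `shape` parameters of Tensor.fromBlob. Checking `numel(shape) == data.size` before building a tensor catches shape mistakes early, on both platforms.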

IValues are flexible enough to construct whatever input your model needs, e.g. tensors, scalars, flags, dicts, tuples, etc. Refer to the IValue interface for all available options, and browse PyTorch's Android Demo for examples of running inference with IValue.
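Conceptually, an IValue in PyTorch is a tagged union: a single type that can hold any of the values the TorchScript interpreter understands, including nested containers. The hypothetical sealed class below models that idea in plain Kotlin. SketchValue and describe are illustrative names, not the library's IValue interface, and the variants shown are deliberately incomplete.

```kotlin
// Hypothetical model of a tagged union, for illustration only.
// This is NOT the library's IValue interface; consult that interface for the real API.
sealed class SketchValue {
    class TensorValue(val data: FloatArray, val shape: LongArray) : SketchValue()
    data class LongValue(val value: Long) : SketchValue()
    data class DoubleValue(val value: Double) : SketchValue()
    data class BoolValue(val value: Boolean) : SketchValue()
    data class StringValue(val value: String) : SketchValue()
    data class TupleValue(val items: List<SketchValue>) : SketchValue()
    data class DictValue(val map: Map<String, SketchValue>) : SketchValue()
}

/** Renders a value by matching on its tag, recursing into tuples and dicts. */
fun describe(v: SketchValue): String = when (v) {
    is SketchValue.TensorValue -> "Tensor${v.shape.toList()}"
    is SketchValue.LongValue -> "int(${v.value})"
    is SketchValue.DoubleValue -> "float(${v.value})"
    is SketchValue.BoolValue -> "bool(${v.value})"
    is SketchValue.StringValue -> "str(${v.value})"
    is SketchValue.TupleValue -> v.items.joinToString(prefix = "(", postfix = ")") { describe(it) }
    is SketchValue.DictValue ->
        v.map.entries.joinToString(prefix = "{", postfix = "}") { (k, x) -> "$k: ${describe(x)}" }
}
```

Because containers hold other values, a model input such as a tuple of a tensor, a flag and a dict is built by nesting, which is the same idea behind composing IValues for a multi-input model.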

Memory Management

To simplify managing the resources allocated for inference across Android and iOS, we introduced PLMScope and the plmScoped utility. On Android, the JVM garbage collector and PyTorch Lite manage the allocated memory, so plmScoped is a no-op. On iOS, however, memory is allocated in Kotlin and exchanged with native Objective-C code (and vice versa) without automatic deallocation. This is where plmScoped comes in: it frees the memory allocated for your inference when the scope ends. It is therefore important to define the scope in which resources need to stay allocated carefully, both to avoid memory leaks and to avoid freeing memory that is still needed later.
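To illustrate the pattern, here is a minimal, self-contained sketch of a scope that records a cleanup action for each resource and runs all of them when the block exits. SketchScope, track and sketchScoped are hypothetical names chosen for this example; they do not reflect how PLMScope or plmScoped are actually implemented.

```kotlin
// Hypothetical sketch of scope-based cleanup. NOT the library's PLMScope / plmScoped.

class SketchScope {
    // Cleanup actions in allocation order.
    private val cleanups = ArrayDeque<() -> Unit>()

    /** Registers a resource together with the action that frees it, and returns the resource. */
    fun <T> track(resource: T, free: (T) -> Unit): T {
        cleanups.addLast { free(resource) }
        return resource
    }

    /** Frees everything in reverse allocation order (last allocated, first freed). */
    internal fun close() {
        while (cleanups.isNotEmpty()) {
            cleanups.removeLast().invoke()
        }
    }
}

/** Runs [block] and frees every resource tracked in its scope, even if [block] throws. */
fun <R> sketchScoped(block: SketchScope.() -> R): R {
    val scope = SketchScope()
    try {
        return scope.block()
    } finally {
        scope.close()
    }
}
```

Cleanups run inside `finally`, so resources are released even if inference throws. Anything needed after the block should be copied into ordinary Kotlin values before the scope ends, since everything tracked by the scope is freed at its closing brace.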

Running tests

iOS

To run the tests on iOS, execute the iosSimulatorX64Test gradle task:

./gradlew iosSimulatorX64Test

This automatically runs build_dummy_model.py to create the dummy TorchScript module for testing, copies it into the simulator's files directory and executes the tests. Make sure to select a Python environment in which the torch dependency is available.