deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.12k stars 655 forks source link

How to load .pt model in scala? #3402

Open zaryabRiasat opened 2 months ago

zaryabRiasat commented 2 months ago

I've downloaded a pre-trained model, 20180402-114759-vggface2.pt, from a GitHub repository. I've used it in Python and it works fine, with great accuracy.

python.py:

# Reference pipeline: detect and crop faces with MTCNN, embed each crop with
# InceptionResnetV1, then compare the two embeddings by cosine similarity.
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import torch

mtcnn = MTCNN(image_size=160, margin=0)
resnet = InceptionResnetV1(pretrained='vggface2').eval()

# strict=False silently skips keys that do not match the network, so a mismatched
# checkpoint would go unnoticed. NOTE(review): pretrained='vggface2' presumably
# already loads these weights; confirm whether this extra load is needed.
resnet.load_state_dict(torch.load('../20180402-114759-vggface2.pt'), strict=False)

img1 = Image.open('../img1')
img2 = Image.open('../img2')

# mtcnn returns None when it finds no face (handled in the else branch below).
img1_cropped = mtcnn(img1)
img2_cropped = mtcnn(img2)

if img1_cropped is not None and img2_cropped is not None:
    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        # unsqueeze(0) adds a leading batch dimension of size 1.
        img1_embedding = resnet(img1_cropped.unsqueeze(0))
        img2_embedding = resnet(img2_cropped.unsqueeze(0))

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    # Single pair -> single score; extract it as a Python float once.
    similarity = cos(img1_embedding, img2_embedding).item()

    print(f"Cosine Similarity: {similarity}")

    threshold = 0.6
    if similarity > threshold:
        print("The faces are similar!")
    else:
        print("The faces are different!")
else:
    print("Face not detected in one or both images.")

Now I want to use it in Scala (JVM environment). After a lot of searching, I found that a .pt model can be used in Scala through DJL (Deep Java Library). This is the code I tried in Scala:

libraries in build.sbt:

// DJL dependencies for running a PyTorch model on Linux x86_64 CPU.
// The versions have to agree with each other: DJL 0.29.0 is paired here with PyTorch
// native 2.3.1, and the pytorch-jni version appears to be "<pytorch>-<djl>" (2.3.1-0.29.0).
libraryDependencies ++= Seq(
  "ai.djl" % "api" % "0.29.0",
  // PyTorch engine; it only loads TorchScript models, not plain state_dict checkpoints.
  "ai.djl.pytorch" % "pytorch-engine" % "0.29.0" % "runtime",
  // NOTE(review): nothing in the Scala code of this issue uses the model zoo; presumably removable.
  "ai.djl.pytorch" % "pytorch-model-zoo" % "0.29.0",
  "ai.djl.pytorch" % "pytorch-native-cpu" % "2.3.1" % "runtime" classifier "linux-x86_64",
  "ai.djl.pytorch" % "pytorch-jni" % "2.3.1-0.29.0" % "runtime"
)

main:

import ai.djl.Model
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.ndarray.{NDArray, NDList, NDManager}
import ai.djl.ndarray.types.Shape
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}

import java.nio.file.Paths

object FaceRecognitionDJL {

  /** Cosine-similarity threshold above which two faces are reported as the same person.
    *
    * NOTE(review): the Python reference script uses 0.6; confirm which value is intended.
    */
  private val SimilarityThreshold = 0.7

  /** Embeds two face images with the loaded model and prints whether they match. */
  def main(args: Array[String]): Unit = {
    val imageFactory = ImageFactory.getInstance()
    val image1 = imageFactory.fromFile(Paths.get("../img_1.png"))
    val image2 = imageFactory.fromFile(Paths.get("../img_2.png"))

    val model = Model.newInstance("face_recognition_model")
    try {
      // DJL's PyTorch engine only loads TorchScript models (torch.jit.trace/script + save).
      // The Python script loads this file via load_state_dict, i.e. it is a plain state_dict,
      // and loading it here fails with "PytorchStreamReader failed reading zip archive".
      // Point this at a traced/scripted export of the network instead.
      model.load(Paths.get("../20180402-114759-vggface2.pt"))

      val embeddings1 = getEmbeddings(model, image1)
      val embeddings2 = getEmbeddings(model, image2)

      val similarity = compareEmbeddings(embeddings1, embeddings2)
      println(s"Similarity between faces: $similarity")

      if (similarity > SimilarityThreshold) {
        println("Faces belong to the same person.")
      } else {
        println("Faces do not belong to the same person.")
      }
    } finally {
      // Model holds native resources; release them even if loading or prediction fails.
      model.close()
    }
  }

  /** Runs `image` through `model` and returns its embedding vector.
    *
    * @param model a loaded DJL model
    * @param image the input image
    * @return the embedding produced by [[MyTranslator]]
    */
  def getEmbeddings(model: Model, image: Image): Array[Float] = {
    val predictor = model.newPredictor(new MyTranslator)
    // Predictor is AutoCloseable; the original never closed it, leaking one per call.
    try predictor.predict(image)
    finally predictor.close()
  }

  /** Cosine similarity of two embeddings.
    *
    * Sums are accumulated in Double to avoid Float rounding error. As with the Python script's
    * `CosineSimilarity(eps = 1e-6)`, the norm product is clamped to `eps`, so a zero vector
    * yields 0.0 instead of NaN.
    *
    * @param embedding1 first embedding
    * @param embedding2 second embedding, same length as `embedding1`
    * @param eps        lower bound for the product of the two norms
    * @return the cosine similarity, in [-1, 1] for non-degenerate inputs
    * @throws IllegalArgumentException if the embeddings have different lengths
    */
  def compareEmbeddings(
      embedding1: Array[Float],
      embedding2: Array[Float],
      eps: Double = 1e-6
  ): Double = {
    // zip would silently truncate to the shorter array and produce a meaningless score.
    require(
      embedding1.length == embedding2.length,
      s"Embedding lengths differ: ${embedding1.length} vs ${embedding2.length}"
    )
    val dotProduct = embedding1.iterator.zip(embedding2.iterator).map { case (a, b) => a.toDouble * b }.sum
    val norm1 = math.sqrt(embedding1.iterator.map(x => x.toDouble * x).sum)
    val norm2 = math.sqrt(embedding2.iterator.map(x => x.toDouble * x).sum)
    dotProduct / math.max(norm1 * norm2, eps)
  }
}

/** DJL translator that turns an [[Image]] into the face-embedding model's input tensor, and the
  * model's first output back into a flat `Array[Float]` embedding.
  *
  * NOTE(review): the Python reference script embeds MTCNN face crops, while this translator embeds
  * the whole image. The Python input is also whatever MTCNN produces, not pixels scaled to [0, 1].
  * Confirm the preprocessing the traced model expects before trusting the similarity scores.
  */
class MyTranslator extends Translator[Image, Array[Float]] {

  /** Resizes the image to 160x160 and converts it to a float32 CHW tensor scaled to [0, 1]. */
  override def processInput(ctx: TranslatorContext, input: Image): NDList = {
    // Allocate on the context's manager so the native memory is released with the prediction.
    // The original NDManager.newBaseManager() was created on every call and never closed.
    val imgArray: NDArray = input.toNDArray(ctx.getNDManager)

    // reshape() only reinterprets the existing elements under a new shape and fails unless the
    // element counts match. Scaling an arbitrary image to 160x160 needs an actual resize.
    val resized = NDImageUtils.resize(imgArray, 160, 160)

    // HWC -> CHW, cast to float32 and divide by 255 (the same scaling as the original div(255.0)).
    new NDList(NDImageUtils.toTensor(resized))
  }

  /** Returns the model's first output as a flat float array. */
  override def processOutput(ctx: TranslatorContext, list: NDList): Array[Float] =
    list.get(0).toFloatArray

  /** STACK adds a leading batch dimension to the input and removes it from the output. This
    * mirrors `unsqueeze(0)` in the Python script; returning null fed the network an unbatched
    * tensor.
    */
  override def getBatchifier: Batchifier = Batchifier.STACK
}

I put together the code above after searching various websites, but it fails with this error:

[error] Exception in thread "main" ai.djl.engine.EngineException: PytorchStreamReader failed reading zip archive: failed finding central directory

The same .pt model works fine in Python, but I'm unable to run it in Scala. What am I doing wrong?

frankfliu commented 2 months ago

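DJL only supports the TorchScript model format. Your .pt file holds a state_dict (your Python code loads it with `load_state_dict`), not a TorchScript module, so DJL cannot load it. You need to trace your model into TorchScript format.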

See: https://docs.djl.ai/master/docs/pytorch/how_to_convert_your_model_to_torchscript.html