szymonmaszke / torchlambda

Lightweight tool to deploy PyTorch models to AWS Lambda
MIT License

Top k predictions in output? #5

Open mihow opened 4 years ago

mihow commented 4 years ago

Hello, thanks for the great library! Is it possible to return the top k predictions in the output for the ResNet18 image classification example? And is it possible to include their confidence scores? This is what I am currently doing in PyTorch:

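import torch

# Hypothetical wrapper: model, input_object and DEFAULT_MODEL_CLASSES come from my own code
def predict_top_k(model, input_object, top_k=5):
    predict_values = model(input_object)
    preds = torch.nn.functional.softmax(predict_values, dim=1)

    # Get top k results (values and indices, each of shape [1, top_k])
    top_preds_raw = preds.topk(top_k)

    top_preds = [
        {"label": DEFAULT_MODEL_CLASSES[int(i)], "score": float(score)}
        for i, score in zip(top_preds_raw.indices[0], top_preds_raw.values[0])
    ]

    return top_preds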
szymonmaszke commented 4 years ago

Yes, you can do that. torchlambda uses PyTorch's C++ frontend, so you can do most of the things it supports, though you may have to get your hands dirty with C++. If that's not a problem for you, the easiest approach is to generate template source code with torchlambda template and work on that; see the source code below:

C++ version

Run torchlambda template and read through the generated code (you may have to modify or delete parts of it, e.g. depending on whether you want normalization). Assuming normalization with the ImageNet means and standard deviations, the code below should work. Only the part from THIS IS CUSTOMIZED PART to the end of the handler function differs from the template:

#include <aws/core/Aws.h>
#include <aws/core/utils/base64/Base64.h>
#include <aws/core/utils/json/JsonSerializer.h>
#include <aws/core/utils/memory/stl/AWSString.h>

#include <aws/lambda-runtime/runtime.h>

#include <torch/script.h>
#include <torch/torch.h>

/*!
 *
 *                    HANDLE REQUEST
 *
 */

static aws::lambda_runtime::invocation_response
handler(torch::jit::script::Module &module,
        const Aws::Utils::Base64::Base64 &transformer,
        const aws::lambda_runtime::invocation_request &request) {

  const Aws::String data_field{"data"};

  /*!
   *
   *              PARSE AND VALIDATE REQUEST
   *
   */

  const auto json = Aws::Utils::Json::JsonValue{request.payload};
  if (!json.WasParseSuccessful())
    return aws::lambda_runtime::invocation_response::failure(
        "Failed to parse input JSON file.", "InvalidJSON");

  const auto json_view = json.View();
  if (!json_view.KeyExists(data_field))
    return aws::lambda_runtime::invocation_response::failure(
        "Required data was not provided.", "InvalidJSON");

  /*!
   *
   *          LOAD DATA, TRANSFORM TO TENSOR, NORMALIZE
   *
   */

  const auto base64_data = json_view.GetString(data_field);
  Aws::Utils::ByteBuffer decoded = transformer.Decode(base64_data);

  torch::Tensor tensor =
      torch::from_blob(decoded.GetUnderlyingData(),
                       {
                           static_cast<long>(decoded.GetLength()),
                       },
                       torch::kUInt8)
          .reshape({1, 3, 64, 64})
          .toType(torch::kFloat32) /
      255.0;

  torch::Tensor normalized_tensor = torch::data::transforms::Normalize<>{
      {0.485, 0.456, 0.406}, {0.229, 0.224, 0.225}}(tensor);

  /*!
   *
   *                    THIS IS CUSTOMIZED PART
   *
   */

  auto output = module.forward({normalized_tensor}).toTensor();
  // Class probabilities, shape [1, num_classes]
  auto probabilities = torch::softmax(output, 1);
  // torch::topk returns (values, indices) along the last dimension
  auto [values, indices] = torch::topk(probabilities, 5);

  /*!
   *
   *                       RETURN CUSTOM JSON
   *
   */

  // Create JSON array types to hold indices and values
  Aws::Utils::Array<Aws::Utils::Json::JsonValue> json_indices{
      static_cast<std::size_t>(indices.numel())};

  Aws::Utils::Array<Aws::Utils::Json::JsonValue> json_values{
      static_cast<std::size_t>(values.numel())};

  // Cast indices tensor to C++
  const auto *indices_ptr = indices.data_ptr<int64_t>();
  for (int64_t i = 0; i < indices.numel(); ++i)
    // Input JSON indices values into JSON array type
    json_indices[i] =
        Aws::Utils::Json::JsonValue{}.WithInt64(*(indices_ptr + i));

  // Cast values (probabilities) tensor to C++
  const auto *values_ptr = values.data_ptr<float>();
  for (int64_t i = 0; i < values.numel(); ++i)
    // Input JSON probability values into JSON array type
    json_values[i] =
        Aws::Utils::Json::JsonValue{}.WithDouble(*(values_ptr + i));

  return aws::lambda_runtime::invocation_response::success(
      Aws::Utils::Json::JsonValue{}
          .WithArray("indices", json_indices)
          .WithArray("probabilities", json_values)
          .View()
          .WriteCompact(),
      "application/json");
}

int main() {
  /*!
   *
   *                        LOAD MODEL ON CPU
   *                    & SET IT TO EVALUATION MODE
   *
   */

  /* Turn off gradient */
  torch::NoGradGuard no_grad_guard{};
  /* No optimization during first pass as it might slow down inference by 30s */
  torch::jit::setGraphExecutorOptimize(false);

  constexpr auto model_path = "/opt/model.ptc";

  torch::jit::script::Module module = torch::jit::load(model_path, torch::kCPU);
  module.eval();

  /*!
   *
   *                        INITIALIZE AWS SDK
   *                    & REGISTER REQUEST HANDLER
   *
   */

  Aws::SDKOptions options;
  Aws::InitAPI(options);
  {
    const Aws::Utils::Base64::Base64 transformer{};
    const auto handler_fn =
        [&module,
         &transformer](const aws::lambda_runtime::invocation_request &request) {
          return handler(module, transformer, request);
        };
    aws::lambda_runtime::run_handler(handler_fn);
  }
  Aws::ShutdownAPI(options);
  return 0;
}
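For reference, the preprocessing in the handler above (raw bytes reshaped to [1, 3, 64, 64], scaled to [0, 1] and normalized with the ImageNet means and standard deviations) can be written without libtorch as below. This is a minimal sketch for sanity-checking inputs, assuming the decoded payload holds exactly 3 * height * width bytes in channel-first order; normalize_chw is a hypothetical helper, not part of torchlambda.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// ImageNet channel statistics, the same values passed to
// torch::data::transforms::Normalize in the handler above.
constexpr std::array<float, 3> kMean{0.485f, 0.456f, 0.406f};
constexpr std::array<float, 3> kStd{0.229f, 0.224f, 0.225f};

// Converts raw channel-first (CHW) uint8 pixels into normalized floats:
// out = (byte / 255 - mean[c]) / std[c].
// Assumes bytes holds 3 * height * width values in C, H, W order,
// which is how the handler's reshape interprets the payload.
std::vector<float> normalize_chw(const std::vector<std::uint8_t> &bytes,
                                 std::size_t height, std::size_t width) {
  const std::size_t plane = height * width;
  assert(bytes.size() == 3 * plane);
  std::vector<float> out(bytes.size());
  for (std::size_t c = 0; c < 3; ++c)
    for (std::size_t i = 0; i < plane; ++i) {
      const float scaled = static_cast<float>(bytes[c * plane + i]) / 255.0f;
      out[c * plane + i] = (scaled - kMean[c]) / kStd[c];
    }
  return out;
}
```

If the model expects a different input size, height and width (and the reshape in the handler) have to change together.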

All in all, you will almost always only need to modify how the module's output is processed and returned as JSON.
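The customized part (softmax, top 5, JSON with indices and probabilities) boils down to the dependency-free C++ sketch below. It only illustrates what torch::softmax and torch::topk compute for a single row of logits and what the response looks like; softmax, top_k_indices and to_json are hypothetical helpers, and a real handler should keep using the libtorch and AWS SDK calls above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

// Numerically stable softmax over one row of logits (what
// torch::softmax(output, 1) computes for a [1, num_classes] tensor).
std::vector<double> softmax(const std::vector<double> &logits) {
  assert(!logits.empty());
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> probs(logits.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  for (double &p : probs)
    p /= sum;
  return probs;
}

// Indices of the k largest probabilities, highest first (like the indices
// returned by torch::topk with its defaults largest=true, sorted=true).
std::vector<std::size_t> top_k_indices(const std::vector<double> &probs,
                                       std::size_t k) {
  std::vector<std::size_t> idx(probs.size());
  std::iota(idx.begin(), idx.end(), std::size_t{0});
  k = std::min(k, idx.size());
  std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k),
                    idx.end(), [&probs](std::size_t a, std::size_t b) {
                      return probs[a] > probs[b];
                    });
  idx.resize(k);
  return idx;
}

// Builds the same {"indices": [...], "probabilities": [...]} shape the
// handler returns, as a plain string (default stream precision).
std::string to_json(const std::vector<double> &probs,
                    const std::vector<std::size_t> &top) {
  std::ostringstream out;
  out << "{\"indices\":[";
  for (std::size_t i = 0; i < top.size(); ++i)
    out << (i ? "," : "") << top[i];
  out << "],\"probabilities\":[";
  for (std::size_t i = 0; i < top.size(); ++i)
    out << (i ? "," : "") << probs[top[i]];
  out << "]}";
  return out.str();
}
```

For logits {1.0, 3.0, 2.0, 0.5} and k = 2 this selects indices 1 and 2. Subtracting the maximum logit before exponentiating avoids overflow without changing the result.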

Non-C++ version

You could simply return the probabilities for all classes after softmax (the overhead shouldn't be large with a reasonable number of classes) and post-process them on the client side, if that's an option for you.
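On the receiving side this post-processing is just a top-k over the returned vector plus a label lookup. A minimal C++ sketch, assuming the Lambda returns one probability per class in class-index order and that the client holds a matching list of class names (Prediction, top_k_predictions and the labels are hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// One entry of the result, mirroring {"label": ..., "score": ...} in the question.
struct Prediction {
  std::string label;
  double score;
};

// probabilities: full softmax output returned by the Lambda.
// labels: hypothetical class names in the same order (like DEFAULT_MODEL_CLASSES).
std::vector<Prediction> top_k_predictions(const std::vector<double> &probabilities,
                                          const std::vector<std::string> &labels,
                                          std::size_t k) {
  assert(probabilities.size() == labels.size());
  std::vector<std::pair<double, std::size_t>> scored;
  scored.reserve(probabilities.size());
  for (std::size_t i = 0; i < probabilities.size(); ++i)
    scored.emplace_back(probabilities[i], i);
  k = std::min(k, scored.size());
  // Move the k highest scores to the front, sorted in descending order.
  std::partial_sort(scored.begin(),
                    scored.begin() + static_cast<std::ptrdiff_t>(k), scored.end(),
                    [](const auto &a, const auto &b) { return a.first > b.first; });
  std::vector<Prediction> result;
  result.reserve(k);
  for (std::size_t i = 0; i < k; ++i)
    result.push_back({labels[scored[i].second], scored[i].first});
  return result;
}
```

With probabilities {0.1, 0.6, 0.3}, labels {"cat", "dog", "bird"} and k = 2 this returns dog (0.6) followed by bird (0.3), the same shape as the top_preds list in the question.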

Edit:

If you see a way to handle similar cases generally via .yaml settings, I'm open to suggestions (covering every custom use case through simple .yaml settings would be far too hard).