aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Tracing LightGlue with PyTorch NeuronX #809

Open layel2 opened 10 months ago

layel2 commented 10 months ago

Hi, I tried to trace the LightGlue model on an inf2 instance, but it errored and crashed. Trace command: model_neuron = torch_neuronx.trace(model, input_features, compiler_args=["--target","inf2"]). It produced this output and then crashed:

2024-01-08 09:23:43.000175:  17179  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-01-08 09:23:43.000186:  17179  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.12.54.0+f631c2365/MODULE_2108354257835340671+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-01-08 09:23:43.000456:  17179  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-01-08 09:23:43.000465:  17179  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.12.54.0+f631c2365/MODULE_8221219792448796712+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-01-08 09:23:43.000693:  17179  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-01-08 09:23:43.000701:  17179  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.12.54.0+f631c2365/MODULE_5225604489184339909+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-01-08 09:23:43.000889:  17179  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-01-08 09:23:43.000890:  17179  ERROR ||NEURON_CC_WRAPPER||: Got a cached failed neff at /var/tmp/neuron-compile-cache/neuronxcc-2.12.54.0+f631c2365/MODULE_7572268680822998320+d41d8cd9/model.neff. Will skip compilation, please set --retry_failed_compilation for recompilation: 
 Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/2f43b928-1f74-403f-84ea-f40d67d8264d/model.MODULE_7572268680822998320+d41d8cd9.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/2f43b928-1f74-403f-84ea-f40d67d8264d/model.MODULE_7572268680822998320+d41d8cd9.neff', '--verbose=35']: 2024-01-08T07:47:13Z [TEN404] (_gather.1446) Internal tensorizer error - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new
.
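The cache message above says to set --retry_failed_compilation so the compiler retries instead of reusing the failed cached NEFF. One way to pass extra neuronx-cc flags from torch-neuronx is the NEURON_CC_FLAGS environment variable; this is a sketch based on the flag named in the log, so verify it against your SDK version:

```python
import os

# Ask the Neuron compiler wrapper to retry the previously failed cached
# compilation (flag name taken from the cache message above). Set this
# before calling torch_neuronx.trace so the wrapper picks it up.
os.environ["NEURON_CC_FLAGS"] = "--retry_failed_compilation"
```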

Model analyze result: lightglue-neuronx-model-analyze.txt

neuronx-cc -V

NeuronX Compiler version 2.12.54.0+f631c2365

Python version 3.8.10
HWM version 2.12.0.0-422c9037c
NumPy version 1.24.4

Running on AMI ami-06ab1b228b2610fbe
Running in region apse1-az3

pytorch version

torch==2.1.2
torch-neuronx==2.1.1.2.0.0b0
torch-xla==2.1.1
torchvision==0.16.2

Thank you for your help.

shebbur-aws commented 10 months ago

Thanks for reporting the issue. We are trying to reproduce the problem on our end. Will get back to you shortly.

jeffhataws commented 10 months ago

Hi @layel2 ,

Since there's no reproduction code, I modified the LightGlue benchmark.py example as shown below, and was able to compile with the release 2.16 compiler and PyTorch 2.1:

@@ -192,6 +193,11 @@ if __name__ == "__main__":
                 extractor.conf.max_num_keypoints = num_kpts
                 feats0 = extractor.extract(image0)
                 feats1 = extractor.extract(image1)
+                import torch_neuronx
+                matcher.pruning_keypoint_thresholds['xla'] = -1
+                #new_matcher = torch.jit.trace(matcher, {"image0": feats0, "image1": feats1})
+                new_matcher = torch_neuronx.trace(matcher, {"image0": feats0, "image1": feats1}, compiler_workdir="./workdir")
                 runtime = measure(
                     matcher,
                     {"image0": feats0, "image1": feats1},

However, after a successful compilation (I see "Compiler status PASS"), I then see the error RuntimeError: Tracer cannot infer type of ... Dictionary inputs to traced functions must have consistent type. Found Tensor and int for the returned results. This is a limitation of TorchScript, which you will also see if you use torch.jit.trace instead of torch_neuronx.trace.

The torch_neuronx.trace API uses torch.jit.trace under the hood, so to work with torch_neuronx.trace a model must first be compatible with torch.jit.trace. The error indicates that the model creates an output dictionary with mixed value types (float and int tensors). This is not supported by torch.jit.trace even when strict=False, so the trace fails. To get this model running, first see if you can get it working with torch.jit.trace. One thing you can try is to avoid mixed-value-type dictionary outputs, either by using the same data type across all of the tensors or by creating a module wrapper that returns only tensors of one data type.
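A minimal sketch of the wrapper idea above (the toy model and key names here are made up for illustration, not taken from LightGlue): the inner module returns a dictionary mixing a Tensor with a plain int, which torch.jit.trace cannot type consistently; the wrapper casts the int to a tensor and returns a tuple of tensors, which traces cleanly.

```python
import torch
import torch.nn as nn

class MixedOutputModel(nn.Module):
    """Toy stand-in for a model whose output dict mixes Tensor and int."""
    def forward(self, x):
        return {"scores": x * 2.0, "stop": 3}  # Tensor and plain int

class TraceableWrapper(nn.Module):
    """Wraps the model and returns only tensors, in a fixed order."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        # Cast the non-tensor value to a tensor and return a plain
        # tuple of tensors, which torch.jit.trace supports.
        return out["scores"], torch.tensor(out["stop"])

x = torch.randn(4)
traced = torch.jit.trace(TraceableWrapper(MixedOutputModel()), x)
```

The same pattern applies to LightGlue's output dictionary: cast every non-tensor value to a tensor before returning, then trace the wrapper instead of the original model.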

layel2 commented 10 months ago

Hi @jeffhataws ,

Sorry, I forgot to provide reproduction code. I actually already fixed the error you got: to get the model to compile, I cast everything into torch tensors.

Here's my repo containing the fixed version of lightglue.py and the compile code that I use: https://github.com/layel2/LightGlue-inf2 Compile code: https://github.com/layel2/LightGlue-inf2/blob/main/inf/lightglue.ipynb

LightGlue/lightglue.py

     scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
     return scores

-
+from types import SimpleNamespace
 class MatchAssignment(nn.Module):
     def __init__(self, dim: int) -> None:
         super().__init__()
@@ -332,6 +332,7 @@ class LightGlue(nn.Module):
         "mps": -1,
         "cuda": 1024,
         "flash": 1536,
+        "xla": -1,
     }

     required_data_keys = ["image0", "image1"]
@@ -579,14 +580,14 @@ class LightGlue(nn.Module):
             "matches1": m1,
             "matching_scores0": mscores0,
             "matching_scores1": mscores1,
-            "stop": i + 1,
-            "matches": matches,
-            "scores": mscores,
+            "stop": torch.tensor(i + 1).to(device),
+            "matches": torch.stack(matches).to(device),
+            "scores": torch.stack(mscores).to(device),
             "prune0": prune0,
             "prune1": prune1,
         }

-        return pred
+        return list(pred.values())

     def confidence_threshold(self, layer_index: int) -> float:
         """scaled confidence threshold"""

Thanks