@jonathan-ibex Could you please try torch scripting the faster-rcnn model as shown here in alexnet example and then try your test. https://github.com/pytorch/serve/tree/master/examples/image_classifier/alexnet#torchscript-example-using-alexnet-image-classifier
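For reference, the linked example exports a TorchScript version of the model and points the archiver at the resulting .pt file. A rough sketch of that export step, assuming AlexNet and a dummy 1x3x224x224 input for tracing (the exact code in the repo may differ slightly):

import torch
import torchvision

# export a TorchScript version of the model, along the lines of the alexnet example
model = torchvision.models.alexnet(pretrained=True)
model.eval()
example_input = torch.rand(1, 3, 224, 224)  # dummy input, only used for tracing
traced = torch.jit.trace(model, example_input)
traced.save("alexnet.pt")  # this .pt file is what torch-model-archiver gets as --serialized-file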
Hey @agunapal, I tried saving the model's weights with torch.jit, and I now get a different error (moving forward!):
2022-06-30T13:29:14,635 [INFO ] W-9003-stry_1.0 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9003 in 34 seconds.
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - Backend worker process died.
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - Traceback (most recent call last):
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 210, in <module>
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - worker.run_server()
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 181, in run_server
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - self.handle_connection(cl_socket)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 139, in handle_connection
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - service, result, code = self.load_model(msg)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 104, in load_model
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - service = model_loader.load(
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_loader.py", line 151, in load
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - initialize_fn(service.context)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/torch_handler/object_detector.py", line 22, in initialize
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - super().initialize(context)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/torch_handler/vision_handler.py", line 20, in initialize
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - super().initialize(context)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/torch_handler/base_handler.py", line 83, in initialize
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - self.model = self._load_pickled_model(
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/torch_handler/base_handler.py", line 151, in _load_pickled_model
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - model.load_state_dict(state_dict)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1379, in load_state_dict
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - state_dict = state_dict.copy()
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 667, in __getattr__
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - return super(RecursiveScriptModule, self).__getattr__(attr)
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 384, in __getattr__
2022-06-30T13:29:14,742 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - return super(ScriptModule, self).__getattr__(attr)
2022-06-30T13:29:14,743 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in __getattr__
2022-06-30T13:29:14,743 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - raise AttributeError("'{}' object has no attribute '{}'".format(
2022-06-30T13:29:14,743 [INFO ] W-9002-stry_1.0-stdout MODEL_LOG - AttributeError: 'RecursiveScriptModule' object has no attribute 'copy'
2022-06-30T13:29:14,759 [WARN ] W-9002-stry_1.0-stderr MODEL_LOG - /opt/conda/lib/python3.8/site-packages/torch/serialization.py:602: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
2022-06-30T13:29:14,759 [WARN ] W-9002-stry_1.0-stderr MODEL_LOG - warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
2022-06-30T13:29:14,759 [INFO ] epollEventLoopGroup-5-72 org.pytorch.serve.wlm.WorkerThread - 9002 Worker disconnected. WORKER_STARTED
2022-06-30T13:29:14,760 [DEBUG] W-9002-stry_1.0 org.pytorch.serve.wlm.WorkerThread - System state is : WORKER_STARTED
2022-06-30T13:29:14,760 [DEBUG] W-9002-stry_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException: null
at java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.reportInterruptAfterWait(AbstractQueuedSynchronizer.java:2056) ~[?:?]
at java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2133) ~[?:?]
at java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:432) ~[?:?]
at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:189) [model-server.jar:?]
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128) [?:?]
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628) [?:?]
at java.lang.Thread.run(Thread.java:829) [?:?]
2022-06-30T13:29:14,760 [WARN ] W-9002-stry_1.0 org.pytorch.serve.wlm.BatchAggregator - Load model failed: stry, error: Worker died.
2022-06-30T13:29:14,760 [DEBUG] W-9002-stry_1.0 org.pytorch.serve.wlm.WorkerThread - W-9002-stry_1.0 State change WORKER_STARTED -> WORKER_STOPPED
2022-06-30T13:29:14,760 [WARN ] W-9002-stry_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9002-stry_1.0-stderr
2022-06-30T13:29:14,760 [WARN ] W-9002-stry_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9002-stry_1.0-stdout
(Comment: I named the model "stry") What might the issue be here? Thanks!
@jonathan-ibex Will try it and get back to you
@jonathan-ibex Do you mind sharing your code?
I tried the following and it worked for me:
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torch

model = fasterrcnn_resnet50_fpn(pretrained=True)
# script (rather than trace) the detection model so its data-dependent control flow is preserved
sm = torch.jit.script(model)
sm.save("frcnn.pt")
Hey @agunapal
I used exactly your code, and the saving works without any issues. ✔️
I then proceeded to archive the model using torch-model-archiver as described here (without the model file!). ✔️
Now the server starts correctly and I can see the model runs correctly on the workers. ✔️
But I still have the backbone on cuda:0.
Appreciate your help
@jonathan-ibex Thanks for checking. Will debug further and get back to you
@jonathan-ibex I tried reproducing this with PyTorch 1.12 / CUDA 11.3 and I think it's working as expected. Please find the screenshot below. Does this look fine? Can you please try and let me know?
Closing this. Please re-open if the issue is not resolved.
Hey @agunapal!
Good news! I managed to get the same result once I updated to the latest torch (1.12) + CUDA 11.3 support.
Now I'm facing a new issue, perhaps you could help me here: when I try to change the default network to a custom model that gives the full probability output (I want the softmax result instead of the argmax), I get the error I previously got (backbone loaded on cuda:0).
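(To be concrete about the softmax-instead-of-argmax part: I mean something along these lines, sketched for a generic model whose head returns raw class logits; ProbabilityHead and base_model are illustrative names, not my actual code:)

import torch
import torch.nn.functional as F

class ProbabilityHead(torch.nn.Module):
    # hypothetical wrapper: return the full class-probability vector instead of the argmax
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        logits = self.base_model(x)       # assumes the wrapped model returns raw logits
        return F.softmax(logits, dim=-1)  # probabilities over all classes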
Here are my model files. I give model.py as the model file and supply network_modules.py as an extra file to the model archiver.
Some things I tried that didn't work:
- Adding an __all__ expression to the network_modules file
- Running on the pytorch/torchserve:latest and pytorch/torchserve:latest-gpu images
- Adding install_py_dep_per_model=true to the config.properties file (I tried with and without supplying a requirements.txt file)

Appreciate any further help, Thanks!
If anybody ever comes across the same issue in the future:
I solved it eventually. I followed this example and had to rewrite my custom handler. The big difference is that, in the approach described in that example, they skip the default loading and load the model themselves in the initialize method. Here's how I did it:
# (requires import os, import logging, import torch, plus the FRCNNObjectDetector class
#  from the model file, and a module-level logger = logging.getLogger(__name__))
def initialize(self, context):
    properties = context.system_properties
    self.map_location = "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
    # build the per-worker device from the gpu_id TorchServe assigns (e.g. cuda:3)
    self.device = torch.device(
        self.map_location + ":" + str(properties.get("gpu_id"))
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else self.map_location
    )
    self.manifest = context.manifest
    model_dir = properties.get("model_dir")
    serialized_file = self.manifest["model"]["serializedFile"]
    model_pt_path = os.path.join(model_dir, serialized_file)
    self.model = FRCNNObjectDetector()
    # map the checkpoint onto this worker's device instead of the device it was saved from
    self.model.load_state_dict(torch.load(model_pt_path, map_location=self.device))
    self.model.to(self.device)
    self.model.eval()
    logger.debug('Model loaded successfully')
    self.initialized = True
Make sure to use torch.load's map_location argument; it made the difference for me.
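In isolation, the map_location idea looks like this (the filename and device string are just placeholders):

import torch

# without map_location, torch.load restores tensors to the device they were saved from
# (typically cuda:0); remapping loads them straight onto the worker's assigned device
state_dict = torch.load("model.pt", map_location=torch.device("cuda:1"))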
🐛 Describe the bug
Hey, I have an issue when running the FasterRCNN example given here: it seems that the backbone is being loaded on cuda:0 while the model itself is being distributed to multiple GPUs. I saw that this issue was mentioned before for other architectures: #1037, #1038, vision issue.
I believe that this is a similar issue but I'm not sure how to handle this case. Appreciate the help!
Error logs
Installation instructions
Install torchserve from source: No
Running in docker: Yes, inside this image: nvcr.io/nvidia/pytorch:21.02-py3
I clone the serve repo, run the install dependencies script, and then pip install torchserve.
Model Packaging
I use the built in handler: https://github.com/pytorch/serve/blob/master/ts/torch_handler/object_detector.py
config.properties: default
Versions
Environment headers
Torchserve branch:
torchserve==0.6.0 torch-model-archiver==0.6.0
Python version: 3.8 (64-bit runtime)
Python executable: /opt/conda/bin/python
Versions of relevant python libraries:
captum==0.5.0
future==0.18.2
numpy==1.23.0
nvgpu==0.9.0
psutil==5.9.1
pytest==6.2.2
pytest-cov==2.11.1
pytest-pythonpath==0.7.3
pytorch-transformers==1.1.0
requests==2.28.0
sentencepiece==0.1.95
torch==1.9.0+cu111
torch-model-archiver==0.6.0
torch-workflow-archiver==0.2.4
torchaudio==0.9.0
torchserve==0.6.0
torchserve-dashboard==0.5.0
torchtext==0.10.0
torchvision==0.10.0+cu111
wheel==0.37.1
Java Version:
OS: N/A
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: N/A
CMake version: version 3.19.4
Is CUDA available: Yes
CUDA runtime version: 11.2.67
GPU models and configuration: GPU 0-7: NVIDIA GeForce RTX 3090 (8x)
Nvidia driver version: 510.54
cuDNN version: Probably one of the following in /usr/lib/x86_64-linux-gnu: libcudnn.so.8.1.0, libcudnn_adv_infer.so.8.1.0, libcudnn_adv_train.so.8.1.0, libcudnn_cnn_infer.so.8.1.0, libcudnn_cnn_train.so.8.1.0, libcudnn_ops_infer.so.8.1.0, libcudnn_ops_train.so.8.1.0
Repro instructions
Follow the steps here https://github.com/pytorch/serve/tree/master/examples/object_detector/fast-rcnn and run nvidia-smi in a different terminal
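Besides watching nvidia-smi, a quick sanity check (illustrative, not from the example) is to confirm that every parameter actually follows a .to(device) call instead of the backbone staying on cuda:0:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(pretrained=True)
model.to(torch.device("cuda:1"))  # example device; in TorchServe this would be the worker's assigned GPU
print({p.device for p in model.parameters()})  # expect a single entry, e.g. {device(type='cuda', index=1)}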
Possible Solution
No response