bohundan / treadscan

Python package providing tools for scanning tire tread(s). Part of my bachelor's thesis at the Faculty of Information Technology, Czech Technical University.
MIT License

convert .pth to .pt (Torch Script) for Mobile deployment #3

Closed: Yogeshvasu closed this issue 7 months ago

Yogeshvasu commented 7 months ago

Thanks for your continued support. I would like to convert the model from .pth to .pt (TorchScript) for mobile deployment. I have tried several approaches, but I couldn't convert it to TorchScript because the output value is not a tensor.

import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.models.detection.anchor_utils import AnchorGenerator  # Import AnchorGenerator

def get_model(num_keypoints, weights_path=None):
    # Custom anchors: one size per FPN level, five aspect ratios at each level
    anchor_generator = AnchorGenerator(sizes=(100, 250, 400, 650, 800), 
                                       aspect_ratios=(1.0, 1.25, 1.5, 1.75, 2.0))
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,
                                                                   pretrained_backbone=True,
                                                                   num_keypoints=num_keypoints,
                                                                   num_classes=2,
                                                                   rpn_anchor_generator=anchor_generator)
    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    return model

# Load the trained weights from the .pth file (on CPU)
device = torch.device('cpu')
model = get_model(num_keypoints=5, weights_path='saved_model.pth')
model.eval()

# Example usage:
image_path = 'object/ADAS/20240228_112139468_iOS.jpeg'  # Path to your image
original_image = Image.open(image_path).convert('RGB')  # Drop any alpha channel
image_tensor = F.to_tensor(original_image).unsqueeze(0)  # Convert image to tensor and add batch dimension

# Trace the model.
# This is the step that fails: the model returns a list of dicts, not a tensor.
traced_model = torch.jit.trace(model, image_tensor)

# Save the traced model to a .pt file
traced_model.save('traced_model.pt')

bohundan commented 7 months ago

I fail to see how this relates to the treadscan package in any way. If you want help with the intricacies of PyTorch, try their forums https://discuss.pytorch.org/.