pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Serve the generative model #1192

Closed: SimKarras closed this issue 3 years ago

SimKarras commented 3 years ago

📚 Documentation

I tried to serve a generative model but failed. Can you help me? Repo: https://github.com/TencentARC/GFPGAN

Files

Main network: examples/GAN/gfpgan/gfpganv1_clean_arch.py

import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from stylegan2_clean_arch import (ResBlock, StyleGAN2GeneratorCSFT)

class GFPGANv1Clean(nn.Module):
    """GFPGANv1 Clean version."""

    def __init__(
            self,
            out_size=512,
            num_style_feat=512,
            channel_multiplier=2,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        if fix_decoder:
            for name, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self,
                x,
                return_latents=False,
                save_feat_path=None,
                load_feat_path=None,
                return_rgb=True,
                randomize_noise=True):
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layer
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        if save_feat_path is not None:
            torch.save(conditions, save_feat_path)
        if load_feat_path is not None:
            conditions = torch.load(load_feat_path)
            conditions = [v.cuda() for v in conditions]

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        # return image, out_rgbs
        return image

Sub-network: examples/GAN/gfpgan/stylegan2_clean_arch.py

import math
import random
import torch
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F

class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                False. Default: True.
            truncation (float): TODO. Default: 1.
            truncation_latent (Tensor | None): TODO. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None

class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels.
            Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)

        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                False. Default: True.
            truncation (float): TODO. Default: 1.
            truncation_latent (Tensor | None): TODO. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None

class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up.
            Default: down.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out

class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)

class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')

class StyleConv(nn.Module):
    """Style conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out

class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            out = out + skip
        return out

class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out
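
(For reference, a quick local smoke test of the two files above can rule out model-side errors before archiving. The following is a minimal sketch under assumptions: that model_zoo/clean.pth stores its weights under a 'params_ema' key, following the GFPGAN release convention, and that both architecture files are on the import path.)

import torch
from gfpganv1_clean_arch import GFPGANv1Clean

model = GFPGANv1Clean(out_size=512)
# 'params_ema' is the key used by GFPGAN release checkpoints (assumption)
state = torch.load('model_zoo/clean.pth', map_location='cpu')
model.load_state_dict(state.get('params_ema', state), strict=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))
print(out.shape)  # expect torch.Size([1, 3, 512, 512])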

Simple handler: ts/torch_handler/generative.py

"""
Base Handler for Generative Model.
Created by sjw, 2021/7/27.
"""
from PIL import Image
from captum.attr import IntegratedGradients
import base64
import io
import torch
from torchvision import transforms, utils
from ts.torch_handler.vision_handler import VisionHandler

class Generative_Handler(VisionHandler):
    """
    Base class for all generative handlers.

    Note: postprocess is not overridden here, so the BaseHandler default
    applies and returns the raw output tensor as a nested Python list.
    """
    image_processing = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
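
Because nothing here overrides postprocess, the handler falls back to the BaseHandler default (return data.tolist()), which serializes the raw output tensor as a huge nested list; for a 512x512 RGB float image that runs to tens of megabytes, matching the oversized response in the logs below. A minimal sketch of a postprocess override that returns PNG-encoded images instead (the de-normalization mirrors image_processing above; names and constants are illustrative, not a shipped API):

import base64
import io

from PIL import Image
from torchvision import transforms
from ts.torch_handler.vision_handler import VisionHandler

class Generative_Handler(VisionHandler):
    image_processing = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def postprocess(self, data):
        # data: (b, 3, h, w) tensor in roughly [-1, 1]; undo the Normalize above
        responses = []
        for img in data:
            img = (img.detach().clamp(-1, 1) + 1) * 127.5      # -> [0, 255]
            img = img.round().byte().cpu().permute(1, 2, 0)    # -> (h, w, 3) uint8
            buf = io.BytesIO()
            Image.fromarray(img.numpy()).save(buf, format='PNG')
            # base64 keeps the response JSON-safe and far below the size limit
            responses.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
        return responses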

Store a Model

torch-model-archiver --model-name clean --version 1.0 --model-file examples/GAN/gfpgan/gfpganv1_clean_arch.py --serialized-file model_zoo/clean.pth --export-path model_store --handler ts/torch_handler/generative.py  --extra-files examples/GAN/gfpgan/stylegan2_clean_arch.py -f

Start TorchServe

torchserve --start --ncs --model-store model_store --models clean.mar

logs:

2021-08-10 09:25:22,904 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Initializing plugins manager...
2021-08-10 09:25:23,021 [INFO ] main org.pytorch.serve.ModelServer - 
Torchserve version: 0.4.2
TS Home: /opt/conda/lib/python3.8/site-packages
Current directory: /workspace
Temp directory: /tmp
Number of GPUs: 2
Number of CPUs: 24
Max heap size: 30688 M
Python executable: /opt/conda/bin/python
Config file: N/A
Inference address: http://127.0.0.1:8080
Management address: http://127.0.0.1:8081
Metrics address: http://127.0.0.1:8082
Model Store: /workspace/model_store
Initial Models: clean.mar
Log dir: /workspace/logs
Metrics dir: /workspace/logs
Netty threads: 0
Netty client threads: 0
Default workers per model: 2
Blacklist Regex: N/A
Maximum Response Size: 6553500
Maximum Request Size: 6553500
Prefer direct buffer: false
Allowed Urls: [file://.*|http(s)?://.*]
Custom python dependency for model allowed: false
Metrics report format: prometheus
Enable metrics API: true
Workflow Store: /workspace/model_store
Model config: N/A
2021-08-10 09:25:23,028 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager -  Loading snapshot serializer plugin...
2021-08-10 09:25:23,044 [INFO ] main org.pytorch.serve.ModelServer - Loading initial models: clean.mar
2021-08-10 09:25:25,856 [DEBUG] main org.pytorch.serve.wlm.ModelVersionedRefs - Adding new version 1.0 for model clean
2021-08-10 09:25:25,856 [DEBUG] main org.pytorch.serve.wlm.ModelVersionedRefs - Setting default version to 1.0 for model clean
2021-08-10 09:25:25,856 [INFO ] main org.pytorch.serve.wlm.ModelManager - Model clean loaded.
2021-08-10 09:25:25,856 [DEBUG] main org.pytorch.serve.wlm.ModelManager - updateModel: clean, count: 2
2021-08-10 09:25:25,863 [INFO ] main org.pytorch.serve.ModelServer - Initialize Inference server with: EpollServerSocketChannel.
2021-08-10 09:25:25,908 [INFO ] main org.pytorch.serve.ModelServer - Inference API bind to: http://127.0.0.1:8080
2021-08-10 09:25:25,908 [INFO ] main org.pytorch.serve.ModelServer - Initialize Management server with: EpollServerSocketChannel.
2021-08-10 09:25:25,909 [INFO ] main org.pytorch.serve.ModelServer - Management API bind to: http://127.0.0.1:8081
2021-08-10 09:25:25,909 [INFO ] main org.pytorch.serve.ModelServer - Initialize Metrics server with: EpollServerSocketChannel.
2021-08-10 09:25:25,910 [INFO ] main org.pytorch.serve.ModelServer - Metrics API bind to: http://127.0.0.1:8082
Model server started.
2021-08-10 09:25:26,022 [WARN ] pool-2-thread-1 org.pytorch.serve.metrics.MetricCollector - worker pid is not available yet.
2021-08-10 09:25:26,066 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:0.0|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,067 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:15.87015151977539|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,067 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:852.9611473083496|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,067 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:98.2|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,068 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:107751.2734375|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,068 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:19003.48828125|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,068 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:16.3|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587526
2021-08-10 09:25:26,458 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Listening on port: /tmp/.ts.sock.9001
2021-08-10 09:25:26,459 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - [PID]31611
2021-08-10 09:25:26,459 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Torch worker started.
2021-08-10 09:25:26,459 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Python runtime: 3.8.5
2021-08-10 09:25:26,459 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9001-clean_1.0 State change null -> WORKER_STARTED
2021-08-10 09:25:26,464 [INFO ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Connecting to: /tmp/.ts.sock.9001
2021-08-10 09:25:26,468 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - Listening on port: /tmp/.ts.sock.9000
2021-08-10 09:25:26,468 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - [PID]31612
2021-08-10 09:25:26,468 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - Torch worker started.
2021-08-10 09:25:26,468 [DEBUG] W-9000-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9000-clean_1.0 State change null -> WORKER_STARTED
2021-08-10 09:25:26,468 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - Python runtime: 3.8.5
2021-08-10 09:25:26,468 [INFO ] W-9000-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Connecting to: /tmp/.ts.sock.9000
2021-08-10 09:25:26,476 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - Connection accepted: /tmp/.ts.sock.9000.
2021-08-10 09:25:26,476 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Connection accepted: /tmp/.ts.sock.9001.
2021-08-10 09:25:26,500 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - model_name: clean, batchSize: 1
2021-08-10 09:25:26,500 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - model_name: clean, batchSize: 1
2021-08-10 09:25:29,951 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Missing the index_to_name.json file. Inference output will not include class name.
2021-08-10 09:25:29,958 [INFO ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 3457
2021-08-10 09:25:29,958 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9001-clean_1.0 State change WORKER_STARTED -> WORKER_MODEL_LOADED
2021-08-10 09:25:29,958 [INFO ] W-9001-clean_1.0 TS_METRICS - W-9001-clean_1.0.ms:4098|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587529
2021-08-10 09:25:29,959 [INFO ] W-9001-clean_1.0 TS_METRICS - WorkerThreadTime.ms:24|#Level:Host|#hostname:14ef1de0d587,timestamp:null
2021-08-10 09:25:29,968 [INFO ] W-9000-clean_1.0-stdout MODEL_LOG - Missing the index_to_name.json file. Inference output will not include class name.
2021-08-10 09:25:29,968 [INFO ] W-9000-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 3468
2021-08-10 09:25:29,969 [DEBUG] W-9000-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9000-clean_1.0 State change WORKER_STARTED -> WORKER_MODEL_LOADED
2021-08-10 09:25:29,969 [INFO ] W-9000-clean_1.0 TS_METRICS - W-9000-clean_1.0.ms:4110|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587529
2021-08-10 09:25:29,969 [INFO ] W-9000-clean_1.0 TS_METRICS - WorkerThreadTime.ms:23|#Level:Host|#hostname:14ef1de0d587,timestamp:null

Inference

curl http://127.0.0.1:8080/predictions/clean -T examples/face_512resolution.png

output:

{
  "code": 500,
  "type": "InternalServerException",
  "message": "Worker died."
}

logs:

2021-08-10 09:26:26,062 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:0.0|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,063 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:15.86968994140625|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,063 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:852.9616088867188|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,063 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:98.2|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,063 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:101453.48828125|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,064 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:25364.87890625|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:26,064 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:21.2|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587586
2021-08-10 09:26:49,873 [ERROR] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - Unknown exception
io.netty.handler.codec.CorruptedFrameException: Message size exceed limit: 21025425
    at org.pytorch.serve.util.codec.CodecUtils.readLength(CodecUtils.java:24)
    at org.pytorch.serve.util.codec.ModelResponseDecoder.decode(ModelResponseDecoder.java:75)
    at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:501)
    at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:440)
    at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:276)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:357)
    at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
    at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
    at io.netty.channel.epoll.AbstractEpollStreamChannel$EpollStreamUnsafe.epollInReady(AbstractEpollStreamChannel.java:795)
    at io.netty.channel.epoll.EpollDomainSocketChannel$EpollDomainUnsafe.epollInReady(EpollDomainSocketChannel.java:138)
    at io.netty.channel.epoll.EpollEventLoop.processReady(EpollEventLoop.java:475)
    at io.netty.channel.epoll.EpollEventLoop.run(EpollEventLoop.java:378)
    at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:989)
    at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
    at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
    at java.base/java.lang.Thread.run(Thread.java:829)
2021-08-10 09:26:49,876 [INFO ] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - 9001 Worker disconnected. WORKER_MODEL_LOADED
2021-08-10 09:26:49,876 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Backend worker process died.
2021-08-10 09:26:49,876 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Traceback (most recent call last):
2021-08-10 09:26:49,876 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -   File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in <module>
2021-08-10 09:26:49,876 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - System state is : WORKER_MODEL_LOADED
2021-08-10 09:26:49,876 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -     worker.run_server()
2021-08-10 09:26:49,876 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -   File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 155, in run_server
2021-08-10 09:26:49,877 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -     self.handle_connection(cl_socket)
2021-08-10 09:26:49,877 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -   File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 115, in handle_connection
2021-08-10 09:26:49,876 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException
    at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.reportInterruptAfterWait(AbstractQueuedSynchronizer.java:2056)
    at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2133)
    at java.base/java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:432)
    at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:188)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
    at java.base/java.lang.Thread.run(Thread.java:829)
2021-08-10 09:26:49,877 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG -     cl_socket.sendall(resp)
2021-08-10 09:26:49,877 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - BrokenPipeError: [Errno 32] Broken pipe
2021-08-10 09:26:49,877 [ERROR] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - Unknown exception
io.netty.handler.codec.CorruptedFrameException: Message size exceed limit: 21025425
    at org.pytorch.serve.util.codec.CodecUtils.readLength(CodecUtils.java:24)
    at org.pytorch.serve.util.codec.ModelResponseDecoder.decode(ModelResponseDecoder.java:75)
    at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:501)
    at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:440)
    at io.netty.handler.codec.ByteToMessageDecoder.channelInputClosed(ByteToMessageDecoder.java:404)
    at io.netty.handler.codec.ByteToMessageDecoder.channelInputClosed(ByteToMessageDecoder.java:371)
    at io.netty.handler.codec.ByteToMessageDecoder.channelInactive(ByteToMessageDecoder.java:354)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:262)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:248)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelInactive(AbstractChannelHandlerContext.java:241)
    at io.netty.channel.DefaultChannelPipeline$HeadContext.channelInactive(DefaultChannelPipeline.java:1405)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:262)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelInactive(AbstractChannelHandlerContext.java:248)
    at io.netty.channel.DefaultChannelPipeline.fireChannelInactive(DefaultChannelPipeline.java:901)
    at io.netty.channel.AbstractChannel$AbstractUnsafe$8.run(AbstractChannel.java:819)
    at io.netty.util.concurrent.AbstractEventExecutor.safeExecute(AbstractEventExecutor.java:164)
    at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:472)
    at io.netty.channel.epoll.EpollEventLoop.run(EpollEventLoop.java:384)
    at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:989)
    at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
    at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
    at java.base/java.lang.Thread.run(Thread.java:829)
2021-08-10 09:26:49,878 [WARN ] W-9001-clean_1.0-stderr MODEL_LOG - /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3657: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. 
2021-08-10 09:26:49,878 [WARN ] W-9001-clean_1.0-stderr MODEL_LOG -   warnings.warn(
2021-08-10 09:26:49,887 [INFO ] W-9001-clean_1.0 ACCESS_LOG - /127.0.0.1:59494 "PUT /predictions/clean HTTP/1.1" 500 876
2021-08-10 09:26:49,888 [INFO ] W-9001-clean_1.0 TS_METRICS - Requests5XX.Count:1|#Level:Host|#hostname:14ef1de0d587,timestamp:null
2021-08-10 09:26:49,888 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.job.Job - Waiting time ns: 199942, Inference time ns: 872438847
2021-08-10 09:26:49,888 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9001-clean_1.0 State change WORKER_MODEL_LOADED -> WORKER_STOPPED
2021-08-10 09:26:49,888 [WARN ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9001-clean_1.0-stderr
2021-08-10 09:26:49,889 [WARN ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9001-clean_1.0-stdout
2021-08-10 09:26:49,889 [INFO ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9001 in 1 seconds.
2021-08-10 09:26:50,055 [INFO ] W-9001-clean_1.0-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9001-clean_1.0-stderr
2021-08-10 09:26:50,055 [INFO ] W-9001-clean_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9001-clean_1.0-stdout
2021-08-10 09:26:51,401 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Listening on port: /tmp/.ts.sock.9001
2021-08-10 09:26:51,402 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - [PID]31757
2021-08-10 09:26:51,402 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Torch worker started.
2021-08-10 09:26:51,402 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Python runtime: 3.8.5
2021-08-10 09:26:51,402 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9001-clean_1.0 State change WORKER_STOPPED -> WORKER_STARTED
2021-08-10 09:26:51,402 [INFO ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Connecting to: /tmp/.ts.sock.9001
2021-08-10 09:26:51,403 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Connection accepted: /tmp/.ts.sock.9001.
2021-08-10 09:26:51,410 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - model_name: clean, batchSize: 1
2021-08-10 09:26:54,554 [INFO ] W-9001-clean_1.0-stdout MODEL_LOG - Missing the index_to_name.json file. Inference output will not include class name.
2021-08-10 09:26:54,555 [INFO ] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 3145
2021-08-10 09:26:54,555 [DEBUG] W-9001-clean_1.0 org.pytorch.serve.wlm.WorkerThread - W-9001-clean_1.0 State change WORKER_STARTED -> WORKER_MODEL_LOADED
2021-08-10 09:26:54,555 [INFO ] W-9001-clean_1.0 TS_METRICS - W-9001-clean_1.0.ms:88695|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587614
2021-08-10 09:26:54,555 [INFO ] W-9001-clean_1.0 TS_METRICS - WorkerThreadTime.ms:7|#Level:Host|#hostname:14ef1de0d587,timestamp:null
2021-08-10 09:27:26,060 [INFO ] pool-2-thread-2 TS_METRICS - CPUUtilization.Percent:0.0|#Level:Host|#hostname:14ef1de0d587,timestamp:1628587646
...
p1x31 commented 3 years ago

Hi @JiaweiShiCV, try setting max_response_size to a larger value in config.properties. The default (6553500 bytes, visible in your startup log) is smaller than the 21025425-byte message the worker tried to send back.
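
A minimal config.properties along those lines might look like this (65535000 matches the value suggested below; raising max_request_size the same way is an assumption, needed only for large uploads):

max_response_size=65535000
max_request_size=65535000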

SimKarras commented 3 years ago

Thanks for your help! It works.

p1x31 commented 3 years ago

Try running your server without the --ncs flag; this will create a config folder under the logs folder, e.g. torchserve --start --model-store model_store --models densenet161.mar. You can copy the settings to a different file and add max_response_size=65535000, then run with the --ts_config flag, e.g. torchserve --start --ncs --model-store model_store --models clean.mar --ts_config <your_config_file>. Please refer to the docs; the full sequence is sketched below.
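
Put together, the sequence would look roughly like this (file names are illustrative; the snapshot location follows the description above):

torchserve --stop
# 1. run once without --ncs so a config snapshot is written under logs/config/
torchserve --start --model-store model_store --models clean.mar
torchserve --stop
# 2. copy the latest snapshot and append the new limit
cp logs/config/<latest_snapshot>.cfg my_config.properties
echo "max_response_size=65535000" >> my_config.properties
# 3. restart against the edited config
torchserve --start --ncs --model-store model_store --models clean.mar --ts_config my_config.properties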

SimKarras commented 3 years ago

Quoting the reply above: "Try running your server without the --ncs flag … run with the --ts_config flag … Please refer to the docs."

Yes, I solved the problem with your method. Thank you very much!