bloc97 / Anime4K

A High-Quality Real Time Upscaler for Anime Video
MIT License
17.98k stars 1.34k forks

How to train/load S/M/L CNN models in tensorflow? #220

Open kristoftunner opened 6 months ago

kristoftunner commented 6 months ago

Is there a way to train/load the S/M/L CNN models in TensorFlow? I am interested in experimenting a bit with these models in TensorFlow or onnxruntime. I see that there is one specific model in the tensorflow directory, but I am not sure which one it is.

Tama47 commented 4 months ago

Is there a way to train/load S/M/L CNN models in tensorflow?

Yes, you would need to load the original models in TensorFlow.

I see that there is one specific model in the tensorflow directory, but I am not sure which one it is.

Someone has converted the original Anime4K models into Core ML models. I can provide you the link.

The ones you're looking for are under Models:
model-sr-s.wifm / model-sr-m.wifm / model-sr-l.wifm for upscale models
model-restore-s.wifm / model-restore-m.wifm / model-restore-l.wifm for restore models

You would need to convert them to TensorFlow, then create a Python or Jupyter Notebook script to load the weights and models. You can use the models to fine-tune and train your own, better model.

Note: I have not converted or trained the models myself, and cannot guarantee success. I can only provide general steps, and you will need to do your own research. Supposedly, converting between Core ML and TensorFlow should be relatively straightforward. The training process itself should be more or less the same as training any other TensorFlow or ESRGAN model.

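Sample Python Script:

import coremltools as ct

# Load Core ML model
coreml_model = ct.models.MLModel('model-sr-s.wifm')

# Convert Core ML to TensorFlow
# NOTE: kept as posted and untested. coremltools' ct.convert() converts models
# into Core ML, so this call is not expected to work as written.
tf_model = ct.convert(coreml_model, source='mlmodel', target='tensorflow')
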
arianaa30 commented 4 months ago

@Tama47 Is the training code located in the \tensorflow dir for the restore or the upscale model? And if it is for restore, is it easy to change it to train the "upscale" model?

Fannovel16 commented 4 months ago

@Tama47 From what I've researched so far, there is no way to convert the current version of MLModel to TF2 or ONNX. However, I managed to get Netron to open the model and load the weights:

  1. Change the file extension from .whml to .zip
  2. Compress all files inside the *.mlpackage folder (not including the folder itself) into a zip file
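
Step 2 can also be scripted. This is a minimal sketch using only Python's standard library; the function name zip_package_contents and the example paths are hypothetical, and it only packs the folder's contents (paths are stored relative to the folder, so the folder itself is not included):

```python
import zipfile
from pathlib import Path


def zip_package_contents(package_dir, zip_path):
    """Zip every file inside package_dir, storing paths relative to it."""
    package_dir = Path(package_dir)
    zip_path = Path(zip_path)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for path in sorted(package_dir.rglob("*")):
            # Skip directories and the archive itself if it lives in the folder.
            if path.is_file() and path.resolve() != zip_path.resolve():
                # A relative arcname drops the top-level folder name.
                archive.write(path, arcname=path.relative_to(package_dir).as_posix())
    return zip_path


# Example (hypothetical paths):
# zip_package_contents("model.mlpackage", "model.zip")
```

The resulting archive should then contain the package's files at its top level rather than a single nested folder.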

Preset-a-hq: preset-a-hq zip

Fannovel16 commented 4 months ago

I can convert some GLSL files to PyTorch now but I'm still stuck at converting the weights. Here is the code if anyone is interested:

Fannovel16 commented 4 months ago


Is the training code located in the \tensorflow dir for the restore or the upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

arianaa30 commented 4 months ago


Is the training code located in the \tensorflow dir for the restore or the upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

Also, the code uses epochs=1 (3 times). Should I change them to like 100? I noticed the loss doesn't really decrease.

Fannovel16 commented 4 months ago


Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

I guess you can figure out the block_depth from a model's components. Conv2d(3, 4) means it has 3 input channels and 4 output channels. The CReLU() activation function doubles the channel count, e.g. (1, 128, 128, 4) -> (1, 128, 128, 8).
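
To make the channel doubling concrete, here is a minimal pure-Python sketch of CReLU applied to the channels of a single pixel. It uses the standard definition (ReLU of x concatenated with ReLU of -x, as in tf.nn.crelu); that the converted shaders order the two halves the same way is an assumption, and the sample values are made up:

```python
def crelu(channels):
    """Concatenated ReLU: ReLU(x) followed by ReLU(-x) along the channel axis."""
    positive = [max(v, 0.0) for v in channels]
    negative = [max(-v, 0.0) for v in channels]
    return positive + negative


pixel = [1.0, -2.0, 0.5, -0.25]   # 4 channels, e.g. the output of a Conv2d(3, 4)
activated = crelu(pixel)          # 8 channels, the input of the next Conv2d(8, 4)
print(len(pixel), "->", len(activated))
print(activated)
```

This is why every block after the first is listed as Conv2d(8, 4) in the single-branch sizes: each 4-channel output becomes 8 channels before the next convolution.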

Size S:

  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_last_0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)

Size M:

  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_last_0): Conv2d(56, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)

Size L:

  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_last_0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_1): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_2): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)

Size VL:

  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_last_0): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)

Size UL:

  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_1_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_2_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_3_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_4_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_5_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_6_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  (conv2d_last_0): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
arianaa30 commented 4 months ago

Thanks. I have those architectures. But do you know what to pass to this function to get each of the S, L, VL sizes? I need it for training.

def SR2Model(input_depth=3, highway_depth=4, block_depth=4, init='he_normal', init_last=RandomNormal(mean=0.0, stddev=0.001)):

Fannovel16 commented 4 months ago

@arianaa30 My main library is PyTorch so Idk tbh

arianaa30 commented 4 months ago

@Fannovel16 Btw, do you know how to measure SSIM/PSNR of what Anime4K shaders provide me (upscaled version of low-res image) vs the original high resolution image? Is there a way to measure them?

Fannovel16 commented 4 months ago

@arianaa30 You can pass images to mpv, or to ffmpeg compiled with libplacebo, from the command line and save the upscaled images. I'm not sure how to do that tho since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use canvas.toDataURL("image/png") to save the results.
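
Once the upscaled output has been saved, PSNR against the original high-resolution image is a short computation. This is a minimal sketch, assuming both images have the same resolution and have already been decoded into flat lists of 8-bit values; the psnr helper and the toy data are hypothetical. SSIM is more involved and is usually taken from a library (scikit-image provides skimage.metrics.structural_similarity, for example):

```python
import math


def psnr(reference, test, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(reference) != len(test) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((r - t) ** 2 for r, t in zip(reference, test)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy 8-bit data standing in for the flattened pixels of the two images.
original = [10, 20, 30, 40]
upscaled = [11, 19, 31, 39]   # every pixel is off by one, so MSE = 1
print(round(psnr(original, upscaled), 2))
```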

Fannovel16 commented 4 months ago

Btw I have no luck on converting GLSL shaders into actual models in PyTorch. I believe the problem is transforming the weight.

arianaa30 commented 4 months ago

@arianaa30 You can pass images to mpv, or to ffmpeg compiled with libplacebo, from the command line and save the upscaled images. I'm not sure how to do that tho since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use canvas.toDataURL("image/png") to save the results.

Hmm ok thanks. The problem is we apply multiple anime4k shaders (restore, upscale, restore, ...). Not sure if we can do that..

Fannovel16 commented 4 months ago

@arianaa30 It's possible: But now you mentioned it, I kinda wonder how A4K shaders were actually trained.

arianaa30 commented 4 months ago

@Fannovel16 Yeah the training has some unknowns. Using the Tensorflow script, I trained a model/shader by calling SR2Model() function, and it works. But when I trained the SR1Model (which should be the Restore), the h5 model training works. But when trying to convert with, it shows me a "Shape Mismatch" error. Have you experienced it before?

 Layer (type)                                Output Shape                                 Param #        Connected to
 input.MAIN (InputLayer)                     [(None, None, None, 3)]                      0              []

 conv2d (Conv2D)                             (None, None, None, 4)                        112            ['input.MAIN[0][0]']

 tf.compat.v1.nn.crelu (TFOpLambda)          (None, None, None, 8)                        0              ['conv2d[0][0]']

 conv2d_1 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu[0][0]']

 tf.compat.v1.nn.crelu_1 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_1[0][0]']

 conv2d_2 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_1[0][0]']

 tf.compat.v1.nn.crelu_2 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_2[0][0]']

 conv2d_3 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_2[0][0]']

 tf.compat.v1.nn.crelu_3 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_3[0][0]']

 conv2d_4 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_3[0][0]']

 tf.compat.v1.nn.crelu_4 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_4[0][0]']

 conv2d_5 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_4[0][0]']

 tf.compat.v1.nn.crelu_5 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_5[0][0]']

 conv2d_6 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_5[0][0]']

 tf.compat.v1.nn.crelu_6 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_6[0][0]']

 concatenate (Concatenate)                   (None, None, None, 56)                       0              ['tf.compat.v1.nn.crelu[0][0]',

 conv2d_lastresid.MAIN (Conv2D)              (None, None, None, 3)                        171            ['concatenate[0][0]']

 add.ignore.MAIN (Add)                       (None, None, None, 3)                        0              ['conv2d_lastresid.MAIN[0][0]',

Total params: 2035 (7.95 KB)
Trainable params: 2035 (7.95 KB)
Non-trainable params: 0 (0.00 Byte)
Traceback (most recent call last):
  File "", line 141, in <module>
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/utils/", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/", line 4361, in _assign_value_to_variable
ValueError: Cannot assign value to variable ' conv2d_lastresid.MAIN/kernel:0': Shape mismatch.The variable shape (1, 1, 56, 3), and the assigned value shape (12, 56, 1, 1) are incompatible.
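
As a sanity check on the summary and the error, the parameter counts can be reproduced by hand, and the two shapes in the message can be compared after reordering. This is a sketch under assumptions: the kernel sizes (3x3 for the body, 1x1 for the last layer) are inferred from the counts, and reading the assigned value as (out, in, kh, kw) is a guess. Keras stores Conv2D kernels as (kh, kw, in, out):

```python
def conv2d_params(in_ch, out_ch, kh, kw):
    """Weights plus one bias per output channel for a standard 2D convolution."""
    return kh * kw * in_ch * out_ch + out_ch


first = conv2d_params(3, 4, 3, 3)     # conv2d
middle = conv2d_params(8, 4, 3, 3)    # conv2d_1 .. conv2d_6 (8 channels after CReLU)
last = conv2d_params(56, 3, 1, 1)     # conv2d_lastresid.MAIN
print(first, middle, last, first + 6 * middle + last)


def oihw_to_keras(shape):
    """Reorder an assumed (out, in, kh, kw) shape to Keras's (kh, kw, in, out)."""
    out_ch, in_ch, kh, kw = shape
    return (kh, kw, in_ch, out_ch)


print(oihw_to_keras((12, 56, 1, 1)), "vs variable shape", (1, 1, 56, 3))
```

Under that reading the axis order alone does not explain the mismatch: the assigned value has 12 output channels where the variable has 3. Since 12 = 3 x 2 x 2, it might belong to an upscale-style last layer feeding a 2x depth-to-space, but that is only a guess from the shapes.
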
kato-megumi commented 4 months ago

Btw I have no luck on converting GLSL shaders into actual models in PyTorch. I believe the problem is transforming the weight.

@Fannovel16 This is my script to convert GLSL shaders to PyTorch model.

arianaa30 commented 4 months ago

Btw I have no luck on converting GLSL shaders into actual models in PyTorch. I believe the problem is transforming the weight.

@Fannovel16 This is my script to convert GLSL shaders to PyTorch model.

Is the displayed image the upscaled output? Can we apply multiple shaders as well?

kato-megumi commented 4 months ago

Of course, just string the models together like this: model2(model1(image))
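
For longer shader lists, the same nesting can be wrapped in a small helper. This is a minimal sketch: chain, restore and upscale are hypothetical stand-ins operating on plain lists, where the real stages would be the converted PyTorch models:

```python
from functools import reduce


def chain(*stages):
    """Return a callable that applies each stage in order, left to right."""
    def run(image):
        return reduce(lambda current, stage: stage(current), stages, image)
    return run


# Hypothetical stand-ins for converted models.
def restore(image):
    return [v + 1 for v in image]


def upscale(image):
    return [v for v in image for _ in range(2)]


pipeline = chain(restore, upscale, restore)
print(pipeline([0, 10]))   # same as restore(upscale(restore([0, 10])))
```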

Fannovel16 commented 4 months ago

@kato-megumi Thanks! It seems like I got the CReLU formula wrong

arianaa30 commented 4 months ago

Of course, just string the models together like this: model2(model1(image))

Great thanks. Can we simply add other shaders to the list as well? I want to use Anime4K_Clamp_Highlights.glsl as well. Instructions highly recommend to have this in the list as it highly increases the quality.

Fannovel16 commented 4 months ago

@arianaa30 Here it is. P/S: I made some changes based on kato's advice

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_luma(x):
    # x: (N, 3, H, W) RGB -> (N, 1, H, W) luma (BT.601 weights)
    x = x[:, 0] * 0.299 + x[:, 1] * 0.587 + x[:, 2] * 0.114
    x = x.unsqueeze(1)
    return x

class MaxPoolKeepShape(nn.Module):
    # Max pooling that pads the input so the output keeps the input's H x W (for stride 1)
    def __init__(self, kernel_size, stride=None):
        super(MaxPoolKeepShape, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        kernel_height, kernel_width = self.kernel_size
        pad_height = (((height - 1) // self.stride + 1) - 1) * self.stride + kernel_height - height
        pad_width = (((width - 1) // self.stride + 1) - 1) * self.stride + kernel_width - width

        # F.pad order is (left, right, top, bottom); padding is zero-filled
        x = F.pad(x, (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2))
        x = F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.stride)
        return x

class ClampHighlight(nn.Module):
    def __init__(self):
        super(ClampHighlight, self).__init__()
        self.max_pool = MaxPoolKeepShape(kernel_size=(5, 5), stride=1)
    def forward(self, shader_img, orig_img):
        curr_luma = get_luma(shader_img)
        # Local 5x5 luma maximum of the original image
        statsmax = self.max_pool(get_luma(orig_img))
        if statsmax.shape != curr_luma.shape:
            statsmax = F.interpolate(statsmax, curr_luma.shape[2:4])
        # Clamp the shader output's luma to that local maximum
        new_luma = torch.min(curr_luma, statsmax)
        return shader_img - (curr_luma - new_luma)

# out (shader output without a batch dim) and image2 (original image) are assumed to be defined elsewhere
new_img = ClampHighlight()(out[None], image2)

kato-megumi commented 4 months ago
Fannovel16 commented 4 months ago


The kernel size should be (5, 5).

Oh so the first block iterates x-axis while the second block iterates y-axis

ClampHighlight clamps the output of another shader using the original image's luminance, so it requires two images as input.

What is PREKERNEL? I assumed it is the same as MAIN as the mpv doc is a bit ambiguous

kato-megumi commented 4 months ago

Oh so the first block iterates x-axis while the second block iterates y-axis

Yeah, it reduces the computation cost compared to finding the max of 25 pixels in a single pass.
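
The saving comes from the max filter being separable: a horizontal 5-tap max followed by a vertical 5-tap max reads 5 + 5 = 10 samples per pixel instead of 25 and gives the same result as a direct 5x5 max. Here is a minimal pure-Python sketch of that equivalence; the function names are hypothetical, and truncating the window at the image edges is an assumption about border handling:

```python
def max_filter_1d(values, radius):
    """Sliding-window max over a list; the window is truncated at the edges."""
    n = len(values)
    return [max(values[max(0, i - radius):min(n, i + radius + 1)]) for i in range(n)]


def max_filter_separable(image, radius=2):
    """5x5 max (radius 2) as a horizontal pass followed by a vertical pass."""
    rows = [max_filter_1d(row, radius) for row in image]             # x-axis pass
    cols = [max_filter_1d(list(col), radius) for col in zip(*rows)]  # y-axis pass
    return [list(row) for row in zip(*cols)]                         # back to row-major


def max_filter_2d(image, radius=2):
    """Reference: direct max over the full (2r+1) x (2r+1) window."""
    h, w = len(image), len(image[0])
    return [[max(image[yy][xx]
                 for yy in range(max(0, y - radius), min(h, y + radius + 1))
                 for xx in range(max(0, x - radius), min(w, x + radius + 1)))
             for x in range(w)] for y in range(h)]


image = [[(7 * x + 13 * y) % 10 for x in range(8)] for y in range(6)]
print(max_filter_separable(image) == max_filter_2d(image))
```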

What is PREKERNEL? I assumed it is the same as MAIN as the mpv doc is a bit ambiguous

In anime4k doc about ClampHighlight: "Computes and saves image statistics at the location it is placed in the shader stage, then clamps the image highlights at the end after all the shaders to prevent overshoot and reduce ringing."

PREKERNEL The image immediately before the scaler kernel runs.

I think it refers to the image right before mpv performs internal scaling. Other shaders are hooked to MAIN, which comes before PREKERNEL in mpv's rendering process, so those should run first.

Fannovel16 commented 4 months ago

I added ClampHighlight, AutoDownscalePre, automatic GLSL downloading and a pipeline class for convenience:

arianaa30 commented 4 months ago

@arianaa30 Here it is. P/S: I made some changes based on kato's advice

Great I will try it.

kato-megumi commented 4 months ago

I recommend using to train model. Just put pytorch model in arch/ folder, tweak some config in yml file and train.

arianaa30 commented 4 months ago

@Fannovel16 Btw do you have a training code for the PyTorch models? Would you be able to share?

Fannovel16 commented 4 months ago

@arianaa30 No but you can use my notebook to get the model and randomize its parameters to train

arianaa30 commented 4 months ago

@arianaa30 No but you can use my notebook to get the model and randomize its parameters to train

Should I fine-tune it (only train the last layers) or train the whole network? Btw your notebook shows some errors in the convert() function and in the use of combination() when I want to run the pipeline code. Maybe something changed recently.

Fannovel16 commented 4 months ago

@arianaa30 I forgot to test 😅 . It works now

Should I fine tune it (only train last layers) or train the whole network?

Anime4K's CNN networks are pretty small so training from scratch is a better choice, imo.

arianaa30 commented 1 month ago

@Fannovel16 @kato-megumi Bumping this thread: I was trying to train the PyTorch Anime4K models using NeoSR. I trained and tested the produced model, and idk why it generates a bad-quality image with a lot of grain, with SSIM around 0.65-0.70. Here are more details.

Have any of you had any success training the PyTorch models, at least just to test out? It is so weird

kato-megumi commented 1 month ago

You might have encountered grain artifacts due to a flawed dataset or improper loss configuration. You'd likely receive better assistance from the folks at Enhance Everything!.