tensorflow / lattice

Lattice methods in TensorFlow
Apache License 2.0

Creating a 3D LUT for Colour Management #75

Open ethan-ou opened 1 year ago


Hello, I'd like to ask how to set up TF Lattice to create 3D LUTs (similar to Section IV of "Optimized Regression for Efficient Function Evaluation", in 3D). Given pairs of RGB values, how would you set up a lattice to learn a 3D transform between them and then evaluate it on new input data?
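
For reference, evaluating a 3D LUT on new RGB input usually means trilinear interpolation between the eight grid vertices that surround each input colour. For three inputs this is the multilinear ("hypercube") interpolation a TF Lattice lattice performs. Below is a minimal NumPy sketch, assuming inputs scaled to [0, 1] and a LUT stored as an (n, n, n, 3) array indexed [r, g, b]. `apply_lut` is a hypothetical helper name, not a TF Lattice API.

```python
import itertools

import numpy as np


def apply_lut(lut, rgb):
    """Evaluate a 3D LUT at RGB points by trilinear interpolation.

    lut: (n, n, n, 3) array, n >= 2; lut[i, j, k] is the output colour at the
         grid vertex (i, j, k) / (n - 1), i.e. indexed [r, g, b].
    rgb: (m, 3) array of input colours, clipped to [0, 1].
    Returns an (m, 3) array of output colours.
    """
    lut = np.asarray(lut, dtype=float)
    rgb = np.clip(np.asarray(rgb, dtype=float), 0.0, 1.0)
    n = lut.shape[0]
    scaled = rgb * (n - 1)
    # Lower corner of the enclosing cell, kept <= n - 2 so the upper corner exists.
    lo = np.clip(np.floor(scaled).astype(int), 0, n - 2)
    frac = scaled - lo  # position inside the cell, each coordinate in [0, 1]
    out = np.zeros((rgb.shape[0], lut.shape[-1]))
    # Weighted sum over the 8 cell corners; each weight is a product of 1D weights.
    for dr, dg, db in itertools.product((0, 1), repeat=3):
        w = ((frac[:, 0] if dr else 1 - frac[:, 0])
             * (frac[:, 1] if dg else 1 - frac[:, 1])
             * (frac[:, 2] if db else 1 - frac[:, 2]))
        out += w[:, None] * lut[lo[:, 0] + dr, lo[:, 1] + dg, lo[:, 2] + db]
    return out


# Usage: an identity LUT (every vertex maps to its own coordinates) should
# return the input colours unchanged.
grid = np.linspace(0.0, 1.0, 5)
identity_lut = np.stack(np.meshgrid(grid, grid, grid, indexing="ij"), axis=-1)
colours = np.array([[0.185776889324, 0.142643123865, 0.118471339345]])
print(apply_lut(identity_lut, colours))
```

Because trilinear interpolation reproduces any linear function exactly, the identity LUT is a convenient sanity check before plugging in a fitted table.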

Here's an example of how the input data would look:

```python
import numpy as np

# RGB values recorded by the device (model inputs), one row per colour patch.
device_data = np.array([
    [0.185776889324, 0.142643123865, 0.118471339345],
    [0.322191089392, 0.256722867489, 0.232894062996],
    [0.188466250896, 0.213120296597, 0.231695920229],
    [0.16567106545, 0.168356031179, 0.126793220639],
    [0.233433827758, 0.226880550385, 0.251555323601],
    # More rows of patches
])
# Target RGB values for the same patches (model outputs), row-aligned with device_data.
patch_data = np.array([
    [0.224871307611, 0.199941590428, 0.131533548236],
    [0.32054439187, 0.293659061193, 0.254355072975],
    [0.196207210422, 0.242451697588, 0.290418058634],
    [0.188654735684, 0.229085355997, 0.136322394013],
    [0.246626213193, 0.252050608397, 0.297366410494],
    # More rows of patches
])
```
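
Fitting a LUT to pairs like these can be framed as regularized linear least squares over the vertex values, which is the general idea behind lattice regression. Each input row becomes a row of trilinear interpolation weights W, and the vertex table B minimizes ||W B - Y||² + λ·tr(Bᵀ L B). The minimal NumPy sketch below assumes inputs in [0, 1]. The graph-Laplacian smoothness penalty L, the helper names (`interpolation_weights`, `grid_laplacian`, `fit_lut`) and the defaults n=5, λ=1e-3 are illustrative choices. They are not necessarily what the cited paper or TF Lattice uses.

```python
import itertools

import numpy as np


def interpolation_weights(rgb, n):
    """Trilinear weight matrix W of shape (m, n**3), for n >= 2.

    Row i of W @ lut.reshape(-1, 3) is the LUT evaluated at rgb[i], for a LUT
    stored as an (n, n, n, 3) array indexed [r, g, b].
    """
    rgb = np.clip(np.asarray(rgb, dtype=float), 0.0, 1.0)
    scaled = rgb * (n - 1)
    lo = np.clip(np.floor(scaled).astype(int), 0, n - 2)  # lower cell corner
    frac = scaled - lo                                     # offset inside cell
    rows = np.arange(rgb.shape[0])
    W = np.zeros((rgb.shape[0], n ** 3))
    for dr, dg, db in itertools.product((0, 1), repeat=3):
        w = ((frac[:, 0] if dr else 1 - frac[:, 0])
             * (frac[:, 1] if dg else 1 - frac[:, 1])
             * (frac[:, 2] if db else 1 - frac[:, 2]))
        # Flat (C-order) index of this corner in the n x n x n grid.
        cols = np.ravel_multi_index(
            (lo[:, 0] + dr, lo[:, 1] + dg, lo[:, 2] + db), (n, n, n))
        W[rows, cols] += w
    return W


def grid_laplacian(n):
    """Graph Laplacian of the n x n x n grid, linking neighbours along each axis."""
    idx = np.arange(n ** 3).reshape(n, n, n)
    L = np.zeros((n ** 3, n ** 3))
    for axis in range(3):
        a = np.take(idx, np.arange(n - 1), axis=axis).ravel()
        b = np.take(idx, np.arange(1, n), axis=axis).ravel()
        L[a, a] += 1
        L[b, b] += 1
        L[a, b] -= 1
        L[b, a] -= 1
    return L


def fit_lut(inputs, targets, n=5, smoothness=1e-3):
    """Solve min_B ||W B - Y||^2 + smoothness * tr(B^T L B) for the vertex values B."""
    W = interpolation_weights(inputs, n)
    # Normal equations: (W^T W + smoothness * L) B = W^T Y.
    A = W.T @ W + smoothness * grid_laplacian(n)
    B = np.linalg.solve(A, W.T @ np.asarray(targets, dtype=float))
    return B.reshape(n, n, n, 3)
```

With the arrays above, `fit_lut(device_data, patch_data)` returns a (5, 5, 5, 3) table. With only a handful of patches, most vertices are set largely by the smoothness term, so the choice of n and λ matters. The system is solvable as long as there is at least one training row, because the interpolation weights in each row sum to one.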

Here's the code I've tried for fitting a lattice between them. I believe it can only learn a single output value (a 1D transform), not a full 3D RGB-to-RGB one:

```python
import logging
import sys

import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl

# Silence log output during training.
logging.disable(sys.maxsize)

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500
PREFITTING_NUM_EPOCHS = 10

feature_names = ['r', 'g', 'b']

# Transpose so that row i holds channel i for every patch.
device_data_T = device_data.T
patch_data_T = patch_data.T

# Dict of per-feature columns, used to compute the calibration keypoints.
train_dict = {name: device_data_T[i] for i, name in enumerate(feature_names)}

# Premade models take one input array per feature, in feature_configs order.
train_input = [device_data_T[0], device_data_T[1], device_data_T[2]]
# A CalibratedLattice has a single output, so one model can only fit one
# target channel (red here); a list of all three channels does not match it.
train_output = patch_data_T[0]

feature_configs = [
    tfl.configs.FeatureConfig(name=name, monotonicity='increasing')
    for name in feature_names
]

# Compute each feature's calibration keypoints from the training data.
feature_keypoints = tfl.premade_lib.compute_feature_keypoints(
    feature_configs=feature_configs, features=train_dict)
tfl.premade_lib.set_feature_keypoints(
    feature_configs=feature_configs,
    feature_keypoints=feature_keypoints,
    add_missing_feature_configs=False)

lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs,
    output_initialization=[0.0, 1.0],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)

# The targets are continuous RGB values, so use a regression loss instead of
# a classification loss (binary cross-entropy) and AUC metric.
lattice_model.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
lattice_model.fit(
    train_input,
    train_output,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
```
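
Because each CalibratedLattice above predicts a single value, one way to get a full RGB-to-RGB LUT is to train one such model per output channel and then sample all three on a regular grid. The minimal sketch below covers only the sampling step. It assumes each trained model is wrapped in a callable that maps an (m, 3) array in [0, 1] to m output values; for a premade model, this might look something like `lambda x: model.predict([x[:, 0], x[:, 1], x[:, 2]]).ravel()`. `sample_lut` is an illustrative helper, not a TF Lattice function, and the usage below uses simple stand-in functions in place of trained models.

```python
import numpy as np


def sample_lut(channel_fns, n=17):
    """Sample per-channel predictors on an n x n x n grid over [0, 1]^3.

    channel_fns: three callables (red, green, blue); each maps an (m, 3)
    array of RGB inputs to an array of m output values.
    Returns an (n, n, n, 3) array indexed [r, g, b, channel].
    """
    grid = np.linspace(0.0, 1.0, n)
    r, g, b = np.meshgrid(grid, grid, grid, indexing="ij")
    points = np.stack([r.ravel(), g.ravel(), b.ravel()], axis=1)  # (n**3, 3)
    channels = [np.asarray(fn(points), dtype=float).reshape(-1) for fn in channel_fns]
    # C-order ravel above means reshaping back gives lut[i, j, k] at (g[i], g[j], g[k]).
    return np.stack(channels, axis=1).reshape(n, n, n, 3)


# Usage with stand-in per-channel gains in place of trained models.
lut = sample_lut([lambda x: 0.9 * x[:, 0],
                  lambda x: x[:, 1],
                  lambda x: 1.1 * x[:, 2]], n=5)
print(lut.shape)  # (5, 5, 5, 3)
```

The resulting array is an ordinary 3D LUT, so new inputs can then be evaluated with trilinear interpolation. Note that three independent per-channel models cannot share information across output channels, which is one trade-off of this approach.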