creotiv / hdrnet-pytorch

Unofficial PyTorch implementation of 'Deep Bilateral Learning for Real-Time Image Enhancement', SIGGRAPH 2017 https://groups.csail.mit.edu/graphics/hdrnet/

I wrote some code for HDRNetCurves #24

Open alexliyang opened 2 years ago

alexliyang commented 2 years ago

I think a 1x1 conv can replace the ccm (color space change) block in google/hdrnet:

# Color space change

idtity = np.identity(nchans, dtype=np.float32) + np.random.randn(1).astype(np.float32)*1e-4
ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity)
with tf.name_scope('ccm'):
  ccm_bias = tf.get_variable('ccm_bias', shape=[nchans,], dtype=tf.float32, initializer=tf.constant_initializer(0.0))

  guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)
  guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')

  guidemap = tf.reshape(guidemap, tf.shape(input_tensor))
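Why a conv can stand in for this block: multiplying every pixel's channel vector by `ccm` (`x @ ccm` on NHWC data, then adding `ccm_bias`) is exactly a 1x1 convolution on NCHW data whose weight matrix is `ccm` transposed. The NumPy sketch below checks that numerically without TF or PyTorch. The helper names `ccm_tf_style` and `conv1x1_nchw` are hypothetical, and the layout transposes assume TF's NHWC and PyTorch's NCHW defaults.

```python
import numpy as np

def ccm_tf_style(x_nhwc, ccm, bias):
    """Hypothetical mirror of the TF snippet: reshape to [-1, C], matmul, add bias, reshape back."""
    c = x_nhwc.shape[-1]
    flat = x_nhwc.reshape(-1, c) @ ccm + bias
    return flat.reshape(x_nhwc.shape)

def conv1x1_nchw(x_nchw, weight, bias):
    """Hypothetical 1x1 convolution; weight is (C_out, C_in), like Conv2d's weight[:, :, 0, 0]."""
    return np.einsum("oc,nchw->nohw", weight, x_nchw) + bias[None, :, None, None]

rng = np.random.default_rng(0)
nchans = 3
x_nhwc = rng.random((2, 4, 5, nchans), dtype=np.float32)
# Near-identity init as in the TF code: identity plus one shared tiny random offset.
ccm = np.identity(nchans, dtype=np.float32) + rng.standard_normal(1).astype(np.float32) * 1e-4
bias = rng.standard_normal(nchans).astype(np.float32)

out_tf = ccm_tf_style(x_nhwc, ccm, bias)
# Output channel o = sum_c x[c] * ccm[c, o], so the conv weight is ccm.T.
out_conv = conv1x1_nchw(x_nhwc.transpose(0, 3, 1, 2), ccm.T, bias)

print(np.allclose(out_tf, out_conv.transpose(0, 2, 3, 1), atol=1e-6))  # True
```

The same relation applies when porting trained TF weights: the PyTorch conv weight of shape (3, 3, 1, 1) would be `ccm.T` reshaped to that shape. Note that `nn.Conv2d` starts from PyTorch's default random initialization, not the near-identity init above, so matching the TF starting point would require copying an identity matrix into the weight explicitly.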

So the PyTorch version looks like this:

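import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ConvBlock is this repo's conv helper (Conv2d with optional activation / batch norm);
# it is assumed to be in scope, e.g. defined in the same module.


class GuideCurves(nn.Module):
    """Guide map in the style of google/hdrnet's HDRNetCurves: 1x1 color transform (ccm),
    per-channel piecewise-linear curves, then a 1x1 projection to a single channel."""

    def __init__(self, npts=16):
        super(GuideCurves, self).__init__()
        self.guide_pts = npts
        # Learned 3x3 color space change plus bias, expressed as a 1x1 conv.
        self.ccm = ConvBlock(3, 3, kernel_size=1, padding=0, use_bias=True, activation=None, batch_norm=False)

        # Curve breakpoints: npts evenly spaced shifts in [0, 1), one set per channel -> (3, 1, 1, npts).
        self.shifts = np.linspace(0, 1, self.guide_pts, endpoint=False, dtype=np.float32)
        self.shifts = self.shifts[np.newaxis, np.newaxis, np.newaxis, :]
        self.shifts = np.tile(self.shifts, (3, 1, 1, 1))
        self.shifts = nn.Parameter(data=torch.from_numpy(self.shifts))

        # Slopes (1, 3, 1, 1, npts): only the first segment is active, so each curve starts as the identity.
        self.slopes = np.zeros([1, 3, 1, 1, self.guide_pts], dtype=np.float32)
        self.slopes[:, :, :, :, 0] = 1.0
        self.slopes = nn.Parameter(data=torch.from_numpy(self.slopes))

        # Mix the 3 curved channels into a single-channel guide.
        self.projection = ConvBlock(3, 1, kernel_size=1, padding=0, use_bias=True, activation=None, batch_norm=False)

    def forward(self, x):
        guidemap = self.ccm(x)                # (N, 3, H, W)
        guidemap = guidemap.unsqueeze(dim=4)  # (N, 3, H, W, 1)
        # sum_k slope_k * relu(x - shift_k): broadcasts to (N, 3, H, W, npts), then sums over npts
        guidemap = (self.slopes * F.relu(guidemap - self.shifts)).sum(dim=4)
        guidemap = self.projection(guidemap)  # (N, 1, H, W)
        guidemap = F.hardtanh(guidemap, min_val=0, max_val=1)  # clip to [0, 1]
        #guidemap = guidemap.squeeze(dim=1)

        return guidemap
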
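What the curve step computes: for each channel c, f_c(x) = sum_k a_{c,k} * relu(x - t_{c,k}), where the t are the `npts` evenly spaced shifts in [0, 1) and the a are the slopes. Both are learnable parameters, and with the initialization above (a_{c,0} = 1, all other slopes 0, t_{c,0} = 0) every curve starts as the identity for x >= 0. The NumPy sketch below reproduces only the broadcasting in `forward()` to make the shapes concrete; `apply_curves` is a hypothetical helper for illustration, not part of the module, and it uses random inputs.

```python
import numpy as np

npts, nchans = 16, 3

# Same initial values as GuideCurves.__init__, kept as plain NumPy arrays.
shifts = np.linspace(0, 1, npts, endpoint=False, dtype=np.float32)
shifts = np.tile(shifts[np.newaxis, np.newaxis, np.newaxis, :], (nchans, 1, 1, 1))  # (3, 1, 1, 16)
slopes = np.zeros([1, nchans, 1, 1, npts], dtype=np.float32)
slopes[:, :, :, :, 0] = 1.0  # (1, 3, 1, 1, 16): only the first segment is active

def apply_curves(x):
    """x: (N, 3, H, W) -> (N, 3, H, W); per channel, sum_k slope_k * relu(x - shift_k)."""
    x = x[..., np.newaxis]  # (N, 3, H, W, 1), like unsqueeze(dim=4)
    # (N, 3, H, W, 1) - (3, 1, 1, 16) broadcasts to (N, 3, H, W, 16); channel axes line up at dim 1
    return (slopes * np.maximum(x - shifts, 0.0)).sum(axis=4)

x = np.random.default_rng(0).random((2, nchans, 4, 5), dtype=np.float32)
y = apply_curves(x)
print(y.shape)            # (2, 3, 4, 5)
print(np.allclose(y, x))  # True: at init each curve is the identity on [0, 1)
```

Because the shifts are tiled to (3, 1, 1, npts), trailing-axis broadcasting aligns their channel axis with dim 1 of the (N, 3, H, W, 1) input, which is why the 5-D unsqueeze works without an explicit reshape.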
creotiv commented 2 years ago

Curves work worse than Conv, but it's still a good thing. You can make a PR and I'll add it to the repo. Thanks!