Open alexliyang opened 2 years ago
I think a 1x1 conv can replace the `ccm` block in google/hdrnet's guide function:
# Color-space change: multiply every pixel's channel vector by a learned
# nchans x nchans matrix and add a learned per-channel bias.
# `nchans` and `input_tensor` come from the surrounding code (not shown here).

# Start close to the identity matrix. Note that randn(1) draws a single scalar,
# so the same tiny offset (scaled by 1e-4) is broadcast onto every entry.
idtity = np.identity(nchans, dtype=np.float32) + np.random.randn(1).astype(np.float32)*1e-4
ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity)
with tf.name_scope('ccm'):
    # Bias starts at zero.
    ccm_bias = tf.get_variable('ccm_bias', shape=[nchans,], dtype=tf.float32, initializer=tf.constant_initializer(0.0))

    # Flatten to (num_pixels, nchans) so the matrix applies to each pixel.
    guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)
    guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')

    # Restore the input's original shape.
    guidemap = tf.reshape(guidemap, tf.shape(input_tensor))
So the code looks like the following:
class GuideCurves(nn.Module):
    """Guide-map module: 1x1 channel mix, learned piecewise-linear curves, 1x1 projection.

    The input passes through ``self.ccm`` (3 -> 3 channels, kernel size 1),
    then each channel is remapped by a sum of ``npts`` shifted ReLUs with
    learnable shifts and slopes, and finally ``self.projection`` (3 -> 1
    channels, kernel size 1) is applied and the result is clamped to [0, 1].

    ``ConvBlock`` is defined elsewhere in the project.

    Args:
        npts: number of control points (ReLU segments) per channel curve.
    """

    def __init__(self, npts=16):
        # The double underscores are required: without them nn.Module is never
        # initialised and none of the layers/parameters below are created.
        super(GuideCurves, self).__init__()
        self.guide_pts = npts

        self.ccm = ConvBlock(3, 3, kernel_size=1, padding=0, use_bias=True,
                             activation=None, batch_norm=False)

        # Control-point positions: npts evenly spaced values in [0, 1),
        # replicated for the 3 channels -> shape (3, 1, 1, npts).
        shifts = np.linspace(0, 1, self.guide_pts, endpoint=False, dtype=np.float32)
        shifts = shifts[np.newaxis, np.newaxis, np.newaxis, :]
        shifts = np.tile(shifts, (3, 1, 1, 1))
        self.shifts = nn.Parameter(data=torch.from_numpy(shifts))

        # Slopes of shape (1, 3, 1, 1, npts). Only the first segment starts
        # at 1, so (its shift being 0) the initial curve is relu(value).
        slopes = np.zeros([1, 3, 1, 1, self.guide_pts], dtype=np.float32)
        slopes[:, :, :, :, 0] = 1.0
        self.slopes = nn.Parameter(data=torch.from_numpy(slopes))

        self.projection = ConvBlock(3, 1, kernel_size=1, padding=0, use_bias=True,
                                    activation=None, batch_norm=False)

    def forward(self, x):
        """Compute the guide map for ``x``.

        Args:
            x: input tensor; assumed to be (N, 3, H, W) -- the broadcasting
                below needs the ``ccm`` output to be 4-D with 3 channels.

        Returns:
            The output of ``self.projection`` clamped to [0, 1]. The channel
            dimension is kept (it is not squeezed).
        """
        guidemap = self.ccm(x)
        # Add a trailing axis so every value is compared against all npts shifts.
        guidemap = guidemap.unsqueeze(dim=4)
        # Sum of shifted ReLUs = piecewise-linear curve per channel.
        guidemap = (self.slopes * F.relu(guidemap - self.shifts)).sum(dim=4)
        guidemap = self.projection(guidemap)
        guidemap = F.hardtanh(guidemap, min_val=0.0, max_val=1.0)
        return guidemap
Curves work worse than Conv, but it's still a good thing to have. You can make a PR and I will add it to the repo. Thanks.
I think a 1x1 conv can replace the `ccm` block in google/hdrnet's guide function:
Color space change
So the code looks like the following:
class GuideCurves(nn.Module): def init(self,npts = 16): super(GuideCurves, self).init() self.guide_pts = npts self.ccm = ConvBlock(3,3,kernel_size=1,padding=0,use_bias=True, activation=None, batch_norm=False)