google-deepmind / deepmind-research

This repository contains implementations and illustrative code to accompany DeepMind publications
Apache License 2.0
12.99k stars 2.55k forks source link

snt.wrap_with_spectral_norm(snt.Conv3D, {'eps': 1e-4}) #339

Open wang6501sfx opened 2 years ago

wang6501sfx commented 2 years ago

Problems with the pseudocode for the article "Skilful precipitation nowcasting using deep generative models of radar"

I would like to ask a question about the pseudocode. In the `layers` class, it defines `SNConv3D = snt.wrap_with_spectral_norm(snt.Conv3D, {'eps': 1e-4})`.

I don't see this function in Sonnet. How should I deal with this?

wang6501sfx commented 2 years ago

I would like to ask a question about the pseudocode accompanying "Skilful precipitation nowcasting using deep generative models of radar". The `layers` class defines `SNConv3D = snt.wrap_with_spectral_norm(snt.Conv3D, {'eps': 1e-4})`. I don't see this function in Sonnet. How should I deal with this?

The following is the program I implemented; I hope to get your help.

class SNConv3D(Model):
    """3D convolution with spectral regularisation.

    The convolution kernel is passed through ``spectral_norm`` before it is
    handed to ``tf.nn.conv3d``.
    """

    def __init__(self, output_channels, kernel_size, stride=1, rate=1,
                 padding='SAME', sn_eps=0.0001, use_bias=True):
        """Constructor.

        Args:
          output_channels: Size of the last (output-channel) kernel dimension.
          kernel_size: Kernel extent, used for all three leading kernel
            dimensions (the kernel is cubic).
          stride: Stride applied to the three middle dimensions of the input.
          rate: Dilation rate applied to the three middle dimensions.
          padding: Padding argument forwarded to ``tf.nn.conv3d``.
          sn_eps: Epsilon intended for the spectral normalisation. It is
            stored on the instance; see the note in ``call``.
          use_bias: Whether a bias is added to the convolution output.
        """
        super().__init__()
        self._output_channels = output_channels
        self._kernel_size = kernel_size
        self._stride = stride
        self._rate = rate
        self._padding = padding
        self._sn_eps = sn_eps
        self._use_bias = use_bias
        self._snConv3D = tf.nn.conv3d

    def call(self, tensor):
        """Apply the spectrally-normalised 3D convolution to ``tensor``.

        Args:
          tensor: Input whose last dimension is the input-channel count (the
            kernel shape and the 5-element strides assume a 5-D,
            channels-last layout).

        Returns:
          The convolution output, with the bias added when ``use_bias`` is set.
        """
        # The reference pseudocode suggests a Sonnet-based implementation:
        # SNConv3D = snt.wrap_with_spectral_norm(snt.Conv3D, {'eps': 1e-4})
        #
        # NOTE(review): the variables below use the fixed names "kernel" and
        # "bias"; several instances may collide unless each one is called
        # inside its own variable scope - confirm against the calling code.
        kernel_shape = [
            self._kernel_size,
            self._kernel_size,
            self._kernel_size,
            tensor.get_shape()[-1],
            self._output_channels,
        ]
        w = tf.compat.v1.get_variable(
            "kernel",
            shape=kernel_shape,
            initializer=tf.compat.v1.truncated_normal_initializer())

        # NOTE(review): self._sn_eps is stored but not forwarded here;
        # spectral_norm is called with its default settings.
        tensor = self._snConv3D(
            input=tensor,
            filters=spectral_norm(w),
            strides=[1, self._stride, self._stride, self._stride, 1],
            padding=self._padding,
            # Honour the `rate` constructor argument (previously ignored).
            dilations=[1, self._rate, self._rate, self._rate, 1])

        # Honour the `use_bias` constructor argument (previously ignored).
        if self._use_bias:
            b = tf.compat.v1.get_variable(
                "bias",
                shape=[self._output_channels],
                initializer=tf.compat.v1.zeros_initializer())
            tensor = tf.nn.bias_add(tensor, b)
        return tensor

def l2_norm(v, eps=1e-4):
    """Return ``v`` divided by its L2 norm.

    Args:
      v: Tensor to normalise; the norm is taken over all of its elements.
      eps: Small constant added to the norm to avoid division by zero.

    Returns:
      ``v / (sqrt(sum(v ** 2)) + eps)``.
    """
    # The `**` operators had been stripped from the pasted source
    # ("v 2) 0.5"); they are restored here.
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

def spectral_norm(w, iteration=1, eps=1e-4):
    """Divide ``w`` by a power-iteration estimate of its largest singular value.

    ``w`` is flattened to a 2-D matrix whose columns correspond to the last
    dimension of ``w``. A persistent, non-trainable vector ``u`` is refined by
    ``iteration`` rounds of power iteration, and the refined value is written
    back to ``u`` whenever the returned tensor is evaluated.

    Args:
      w: Weight tensor with a statically known shape.
      iteration: Number of power-iteration rounds; must be at least 1.
      eps: Epsilon forwarded to ``l2_norm`` for numerical stability.

    Returns:
      A tensor with the same shape as ``w``, scaled by ``1 / sigma``.

    Raises:
      ValueError: If ``iteration`` is smaller than 1.
    """
    if iteration < 1:
        # With zero rounds v_hat would stay None and tf.matmul would fail
        # with an unrelated-looking error further down.
        raise ValueError("iteration must be >= 1, got %r" % (iteration,))

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])
    u = tf.compat.v1.get_variable(
        "u",
        [1, w_shape[-1]],
        initializer=tf.compat.v1.truncated_normal_initializer(),
        trainable=False)

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_, eps=eps)
        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_, eps=eps)

    # sigma = v_hat . w . u_hat^T approximates the largest singular value.
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    # Tie the update of `u` to the evaluation of the normalised weights.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)
    return w_norm