Closed (MADAOOOOO closed 2 months ago)
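Hello, I think I roughly understand the ConditionalNAFNet class in codes/config/deraining/models/modules/DenoisingNAFNet_arch.py, but I still have a question. The forward function of this class still takes two parameters, condition and time. Does this mean that the combined NAFNet still uses the multi-step denoising procedure of the diffusion model in IR-SDE? I would appreciate your guidance. Thank you very much!
Hello, the NAFNet here is designed specifically for the diffusion model. If you want to use NAFNet directly, you can remove the code for the condition- and time-related modules.
Hello, I don't think I expressed myself clearly. What I want to ask is how the prompt module is combined directly with NAFNet. Thanks.
Hello, here is an example of the combination (the degra and context embeddings are obtained in the same way as in the U-Net):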
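To illustrate why forward takes both condition and time, here is a minimal sketch of a network being called once per sampling step. ToyConditionalDenoiser and iterative_restore are hypothetical names, the starting state (degraded image plus noise) is an assumption, and the update rule is a placeholder rather than the IR-SDE reverse step.

```python
import torch
from torch import nn


class ToyConditionalDenoiser(nn.Module):
    """Hypothetical stand-in for a network with forward(x, cond, time)."""

    def __init__(self, channels: int = 3):
        super().__init__()
        # The current state and the degraded input are concatenated along channels.
        self.conv = nn.Conv2d(channels * 2, channels, kernel_size=3, padding=1)
        # A scalar timestep is mapped to one bias per channel.
        self.time_proj = nn.Linear(1, channels)

    def forward(self, x, cond, time):
        t = self.time_proj(time.float().view(-1, 1))  # (B, C)
        return self.conv(torch.cat([x, cond], dim=1)) + t[:, :, None, None]


@torch.no_grad()
def iterative_restore(model, cond, num_steps=10):
    """Calls the model once per step, from t = num_steps down to t = 1."""
    x = cond + torch.randn_like(cond)  # assumed start: degraded image plus noise
    calls = 0
    for step in reversed(range(1, num_steps + 1)):
        time = torch.full((cond.shape[0],), step)
        pred = model(x, cond, time)
        # Placeholder update, NOT the IR-SDE reverse step: it only marks
        # where a real solver would consume the network prediction.
        x = x - pred / num_steps
        calls += 1
    return x, calls


model = ToyConditionalDenoiser()
lq = torch.rand(1, 3, 16, 16)  # degraded (low-quality) input
restored, calls = iterative_restore(model, lq, num_steps=10)
print(restored.shape, calls)
```

Dropping the time argument and the loop would reduce this to a single forward pass, which is roughly what using NAFNet directly amounts to.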
import torch
import torch.nn as nn
from einops import rearrange

# SimpleGate, LayerNorm, SimpleChannelAttention, Residual, PreNorm and
# SpatialTransformer are assumed to be imported from the repository's own modules.


class NAFBlock(nn.Module):
    def __init__(self, c, emb_dim=None, att_type='sca', context_dim=512, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        # Maps the degradation embedding to 4 * c values:
        # (shift, scale) for the attention branch and for the FFN branch.
        # forward() calls this unconditionally, so emb_dim must be set.
        self.mlp = nn.Sequential(
            SimpleGate(), nn.Linear(emb_dim // 2, c * 4)
        ) if emb_dim else None

        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        dim = dw_channel // 2
        dim_head = 32

        # Attention
        # NOTE: only 'simple' and 'cross' define self.att here; the default 'sca' does not.
        if att_type == 'simple':
            self.att = SimpleChannelAttention(dim)
        elif att_type == 'cross':
            num_heads = dim // dim_head
            self.att = Residual(PreNorm(dim, SpatialTransformer(dim, num_heads, dim_head, depth=1, context_dim=context_dim)))

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm(c)
        self.norm2 = LayerNorm(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def degra_forward(self, degra, mlp):
        degra_emb = mlp(degra)
        degra_emb = rearrange(degra_emb, 'b c -> b c 1 1')
        # Split into shift_att, scale_att, shift_ffn, scale_ffn.
        return degra_emb.chunk(4, dim=1)

    def forward(self, x):
        inp, degra, context = x
        shift_att, scale_att, shift_ffn, scale_ffn = self.degra_forward(degra, self.mlp)

        x = inp

        # Attention branch, modulated by the degradation embedding.
        x = self.norm1(x)
        x = x * (scale_att + 1) + shift_att
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = self.att(x, context)
        x = self.conv3(x)
        x = self.dropout1(x)

        y = inp + x * self.beta

        # FFN branch, modulated by the degradation embedding.
        x = self.norm2(y)
        x = x * (scale_ffn + 1) + shift_ffn
        x = self.conv4(x)
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)

        x = y + x * self.gamma

        # degra and context are passed through so blocks can be chained.
        return x, degra, context
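The two conditioning paths in the block above can be tried in isolation with the following self-contained sketch. It is only an approximation under stated assumptions: DegradationModulation and ContextCrossAttention are hypothetical names, the embedding sizes (16 and 32) are arbitrary, the SimpleGate in front of the linear layer is omitted, and nn.MultiheadAttention is swapped in for the repository's SpatialTransformer.

```python
import torch
from torch import nn


class DegradationModulation(nn.Module):
    """Maps a degradation embedding to per-channel shift/scale tensors."""

    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        # One linear layer emits 4 * channels values:
        # shift/scale for the attention branch and for the FFN branch.
        self.proj = nn.Linear(emb_dim, channels * 4)

    def forward(self, degra: torch.Tensor):
        emb = self.proj(degra)        # (B, 4C)
        emb = emb[:, :, None, None]   # (B, 4C, 1, 1), broadcasts over H and W
        return emb.chunk(4, dim=1)    # four (B, C, 1, 1) tensors


class ContextCrossAttention(nn.Module):
    """Feature-map pixels (queries) attend to context tokens (keys/values)."""

    def __init__(self, channels: int, context_dim: int, num_heads: int = 2):
        super().__init__()
        # Stand-in for SpatialTransformer, not the repository's implementation.
        self.attn = nn.MultiheadAttention(
            channels, num_heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)         # (B, H*W, C)
        out, _ = self.attn(tokens, context, context)  # (B, H*W, C)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return x + out                                # residual connection


torch.manual_seed(0)
B, C, H, W = 2, 8, 4, 4
feat = torch.randn(B, C, H, W)
degra = torch.randn(B, 16)        # degradation embedding (assumed size)
context = torch.randn(B, 1, 32)   # one context token per image (assumed size)

mod = DegradationModulation(C, emb_dim=16)
shift_att, scale_att, shift_ffn, scale_ffn = mod(degra)
modulated = feat * (scale_att + 1) + shift_att   # same form as in NAFBlock.forward

xattn = ContextCrossAttention(C, context_dim=32)
attended = xattn(modulated, context)
print(modulated.shape, attended.shape)
```

Because the scale enters as (scale + 1), a zero embedding projection leaves the features unchanged, so the degradation prompt acts as a learned per-channel adjustment on top of the plain NAFNet path.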
OK, thank you very much!
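Hello, I have read the paper and have some questions about how NAFNet is combined with DA-CLIP and about the paper's content:
1. Following the answers in earlier issues, I found codes/config/deraining/models/modules/DenoisingNAFNet_arch.py, but I do not see image_context or degra_context used anywhere in the ConditionalNAFNet class, so I still do not understand how the two are combined.
2. Do the paper's experiments on specific restoration tasks use the prompt embedding module, or only cross-attention?
Thank you very much!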