Closed ResearchingDexter closed 4 years ago
调整了代码后能够复现
import paddle.fluid as fluid
from paddle.fluid import ParamAttr as ParamAttr
from paddle.fluid.initializer import Normal
import numpy as np
def _norm_initial(loc=0.0, scale=1.0, seed=0):
    """Return a ParamAttr whose weights are initialized from N(loc, scale).

    Args:
        loc: mean of the normal distribution.
        scale: standard deviation of the normal distribution.
        seed: random seed forwarded to the initializer.
    """
    initializer = Normal(loc, scale, seed)
    return ParamAttr(initializer=initializer)
class FeatureAdaption(object):
    """Feature-adaption head: predicts per-location offsets from a ``shape``
    tensor with a 1x1 conv, then applies a deformable convolution to the
    input feature map and a ReLU.

    Args:
        out_channels: channel count of the adapted output feature map.
        kernel_size: spatial size of the deformable conv kernel.
        deformable_groups: number of deformable-offset groups.
    """

    def __init__(self, out_channels, kernel_size=3, deformable_groups=4):
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.deformable_groups = deformable_groups
        # Offset branch uses a larger init std than the adaption conv.
        self.offset_attr = _norm_initial(scale=0.1)
        self.adaption_attr = _norm_initial(scale=0.01)

    def __call__(self, x, shape):
        """Predict offsets from ``shape`` and deform-convolve ``x``.

        Returns the ReLU-activated adapted feature map.
        """
        # 2 coordinates (dx, dy) per kernel sample point.
        offset_channels = 2 * self.kernel_size * self.kernel_size
        offset = fluid.layers.conv2d(
            shape,
            offset_channels * self.deformable_groups,
            filter_size=1,
            param_attr=self.offset_attr,
            bias_attr=False)
        # Only DCNv2 (modulated) is available in this Paddle version, so
        # DCNv1 is emulated by feeding a constant all-ones mask.
        _, c, h, w = offset.shape
        mask = fluid.layers.fill_constant_batch_size_like(
            offset, [-1, c // 2, h, w], dtype='float32', value=1.)
        out = fluid.layers.deformable_conv(
            x, offset, mask,
            self.out_channels,
            filter_size=self.kernel_size,
            stride=1,
            padding=(self.kernel_size - 1) // 2,
            deformable_groups=self.deformable_groups,
            param_attr=self.adaption_attr)
        return fluid.layers.relu(out)
if __name__ == '__main__':
    # Minimal repro: build forward + backward in a static-graph program
    # pair and run one step on GPU. The paste had lost all indentation;
    # structure is restored here. NOTE(review): `optimizer.minimize` is
    # placed inside the program_guard so the backward ops are added to
    # `main_program` — outside the guard they would land in the default
    # program and the gradient for `offset` would never be built.
    feature = FeatureAdaption(256, 3, 4)
    main_program = fluid.Program()
    start_program = fluid.Program()
    with fluid.program_guard(main_program, start_program):
        x = fluid.layers.data('x', [256, 50, 50])
        shape = fluid.layers.data('shape', [2, 50, 50])
        out = feature(x, shape)
        out = fluid.layers.reduce_mean(out)
        optimizer = fluid.optimizer.SGD(0.1)
        optimizer.minimize(out)
    exe = fluid.Executor(fluid.CUDAPlace(0))
    x_np = np.ones((1, 256, 50, 50), 'float32')
    shape_np = np.ones((1, 2, 50, 50), 'float32')
    exe.run(start_program)
    out = exe.run(main_program,
                  feed={'x': x_np, 'shape': shape_np},
                  fetch_list=[out.name])
    print(out)
设置 offset.stop_gradient = True
能够运行。
但是我需要 offset 的梯度回传,设置为 True 就没办法回传了。我想要的是 DCNv1,但是 paddle 目前不支持,请问该如何实现?
好的,我下一下
报错信息