uber-research / sbnet

Sparse Blocks Networks

Time cost increases #28

Open IwakuraRein opened 3 years ago

IwakuraRein commented 3 years ago

Hi. Thanks for the code and the detailed instructions.

I added sparse convolution to my encoder:

with tf.variable_scope('featureEncoder'):
    auxiShape = (self.inputShape[0], self.inputShape[1], self.inputShape[2], 7)
    featureShape = (self.inputShape[0], self.inputShape[1], self.inputShape[2], 32)
    blockSize = 8
    blockStride = (8,8)
    blockOffset = (0,0)
    blockCount = (self.divup(self.inputShape[1], blockStride[0]), self.divup(self.inputShape[2], blockStride[1]))
    inBlockParams = { "dynamic_bsize": (blockSize, blockSize), "dynamic_boffset": blockOffset, "dynamic_bstride": blockStride }
    outBlockParams = { "dynamic_bsize": (blockSize, blockSize), "dynamic_boffset": blockOffset, "dynamic_bstride": blockStride }

    if not self.training:
        indices = sbnet_module.reduce_mask(self.mask, blockCount, tol=0.1, **inBlockParams)

        # stack active overlapping tiles to batch dimension
        stack = sbnet_module.sparse_gather(
            auxi, indices.bin_counts, indices.active_block_indices, transpose=False, **inBlockParams)
    else:
        stack = auxi
    # perform dense convolution on a sparse stack of tiles
    stack = self.conv_layer2(stack, 7, 32, name='1')
    stack = tf.nn.leaky_relu(stack)
    stack = self.conv_layer2(stack, 32, 32, name='2')
    stack = tf.nn.leaky_relu(stack)
    stack = self.conv_layer2(stack, 32, 32, name='3')
    stack = tf.nn.leaky_relu(stack)
    stack = self.conv_layer2(stack, 32, 32, name='4')
    stack = tf.nn.leaky_relu(stack)
    stack = self.conv_layer2(stack, 32, 32, name='5')
    stack = tf.nn.leaky_relu(stack)

    # write/scatter the tiles back on top of original tensor
    # note that the output tensor is reduced by 1 on each side due to 'VALID' convolution
    if not self.training:
        feature = sbnet_module.sparse_scatter(
            stack, indices.bin_counts, indices.active_block_indices,
            self.lastFeature, transpose=False, add=False, atomic=False, **outBlockParams)
        feature.set_shape(featureShape)
    else:
        feature = stack

self.training is set to False during training and True during testing. The mask variable is generated outside the network and fed in via a tf.placeholder, as is self.lastFeature.
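To make the block bookkeeping concrete, here is a rough NumPy approximation of what the mask reduction does in the non-overlapping case used above (block size = block stride = 8, offset 0). This is a hypothetical sketch, not sbnet's GPU kernel: it assumes self.divup is ceiling division and that a block counts as active when its maximum mask value exceeds tol. The names divup and active_blocks are illustrative only.

```python
import numpy as np


def divup(a, b):
    # Ceiling division; assumed to match the self.divup helper in the snippet
    return (a + b - 1) // b


def active_blocks(mask, block=8, tol=0.1):
    """Hypothetical non-overlapping block reduction: a block is active if the
    max mask value inside it exceeds tol. Returns an (N, 2) array of
    (block_row, block_col) indices."""
    h, w = mask.shape
    bh, bw = divup(h, block), divup(w, block)
    # Zero-pad so partial blocks at the right/bottom edges are full-sized
    padded = np.zeros((bh * block, bw * block), dtype=mask.dtype)
    padded[:h, :w] = mask
    # View as (bh, block, bw, block) tiles and take the max within each tile
    tile_max = padded.reshape(bh, block, bw, block).max(axis=(1, 3))
    return np.argwhere(tile_max > tol)


full = np.ones((720, 1280), dtype=np.float32)
print(active_blocks(full).shape)
```

With an all-ones 720×1280 mask, every one of the divup(720, 8) × divup(1280, 8) = 90 × 160 = 14,400 blocks is active, so the gathered stack covers the whole image and nothing is skipped.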

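I measured the inference time with the TensorFlow timeline:

import os
import tensorflow as tf
from tensorflow.python.client import timeline

# Request a full trace so per-op step statistics are recorded
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
feed_dict = {model.source: src, model.target: tgt, model.batch_size: src_hdr.shape[0], model.mask: Mask, model.feature: Feature}
denoised_1_bd, Feature = sess.run([model.fake_image, model.feature], feed_dict, options=run_options, run_metadata=run_metadata)
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format(show_memory=True)
with open(os.path.join(errorlog_dir, 'timeline.json'), 'w') as wd:
    wd.write(ctf)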

[Screenshot: timeline trace]

However, I can't find any timing records for the layers under 'featureEncoder'. There are also two bars labeled 'unknown', the second of which is unusually long. Some Pooling and LeakyRelu ops also take suspiciously long, nearly 2 ms each.

[Screenshot: the two 'unknown' bars]

How can I get an accurate time measurement? Thanks.
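A complementary way to cross-check the timeline is plain wall-clock timing over many runs, discarding a few warm-up runs that may include one-off setup costs. The sketch below assumes the timed callable blocks until its results are ready (Session.run returns NumPy arrays, so it does); the helper name benchmark and its defaults are hypothetical.

```python
import statistics
import time


def benchmark(step, warmup=10, iters=50):
    """Call step() warmup times untimed, then time iters calls.
    Returns (median_ms, mean_ms)."""
    for _ in range(warmup):
        step()  # untimed warm-up runs
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        step()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return statistics.median(samples), statistics.mean(samples)


# Hypothetical usage with the session and feed_dict from the post:
# median_ms, mean_ms = benchmark(lambda: sess.run(model.fake_image, feed_dict))
```

Reporting the median alongside the mean makes occasional slow outliers visible instead of hiding them in a single trace.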

My environment:
TensorFlow version: 1.15.0
Operating system: Ubuntu 16.04
Python version: 3.6.13
CUDA version: 10.0
cuDNN version: 7.6.4
GPU type: RTX 2080 Ti
NVIDIA driver version: 460.67

IwakuraRein commented 3 years ago

I wrapped the layers that follow the featureEncoder in a with tf.control_dependencies([feature]): block, and now the timeline looks correct. Its result is nearly the same as that of sbnet_module.cuda_timer.

However, the featureEncoder's time cost increases substantially. My input is (720, 1280, 7). The original network takes roughly 38 ms, of which the featureEncoder accounts for about 10 ms. I want to bring the total inference time below 33 ms. After wrapping the featureEncoder with SparseGather and SparseScatter, the network's inference time rises to 44 ms with a mask of all ones.
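Assuming the layers outside the featureEncoder keep the same cost in every configuration, the figures above imply the following time budget:

```latex
\begin{aligned}
t_{\text{rest}} &= 38 - 10 = 28\ \text{ms} \\
t_{\text{enc}}^{\max} &= 33 - t_{\text{rest}} = 5\ \text{ms} \\
t_{\text{enc}}^{\text{sparse, all ones}} &\approx 44 - t_{\text{rest}} = 16\ \text{ms}
\end{aligned}
```

So the encoder would need to roughly halve its dense cost, whereas the all-ones sparse path appears to cost about 6 ms more than the dense encoder.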

When I feed a mask containing zeros, something strange happens. At a sparsity of about 0.1, the time rises to 150 ms, and the convolutions under the featureEncoder appear as many separate fragments in the timeline chart. The time is 90 ms at a sparsity of 0.5 and 64 ms at 0.8.
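The post does not say how the zeros are distributed in the mask. If they are scattered at random pixel positions (an assumption), and a block is active whenever it contains at least one nonzero pixel (assuming a max-style reduction with tol = 0.1 on a binary mask), then the expected fraction of active 8×8 blocks is 1 - (1 - p)^64, where p is the fraction of ones. The function name below is hypothetical.

```python
def expected_active_fraction(p_one, block=8):
    """Expected fraction of active blocks when each mask pixel is 1
    independently with probability p_one and a block is active if it
    contains at least one 1."""
    return 1.0 - (1.0 - p_one) ** (block * block)


for p in (0.1, 0.5, 0.8):
    print(f"p={p}: {expected_active_fraction(p):.4f}")
```

Under this assumption even p = 0.1 leaves about 99.9% of blocks active, so the gather/scatter path would skip almost no work. This does not by itself explain why the time grows as p shrinks.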

I looked through the existing issues and have tried many block sizes and sparsity levels, but I still see no improvement. At first I guessed that the sparse convolution reduces GPU memory usage so much that the GPU is left underutilized. However, since your experiments used a GTX 1080 Ti, I assume the method works well on powerful GPUs.

I suspect I have misunderstood something or made a serious mistake. I hope to hear back. Thanks.