Comments (15)

andrei-pokrovsky commented on May 9, 2024

Also, template instantiations for some block sizes are missing from the source. I recommend autotuning over block size for your specific problem (i.e. searching for the fastest block/sparsity trade-off and the block size that fits your sparsity pattern best). Sorry, at first I didn't realize you were using block size 3. For our internal detector the smallest block size we used was 5, and it went up to 29 for the high-resolution layers.

fferroni commented on May 9, 2024

Thanks for all the detailed tips! I got it running fine now :-)

fferroni commented on May 9, 2024

Hi,

Digging a bit deeper, I tried to use the library directly rather than the Python wrappers...

import numpy as np
import tensorflow as tf

def divup(a, b):
    return (a + b - 1) // b

# Specify input tensor dimensions and block-sparsity parameters
batch = 1
hw = 1000
nb_inputs = 1
nb_outputs = 1
blockSize = [3, 3]
blockStride = [1, 1]
blockOffset = [1, 1]
blockCount = [divup(hw, blockStride[0]), divup(hw, blockStride[1])]

# example input: a 1-channel array at ~95% sparsity ("grid" is fed below
# but was not defined in the snippet; this stand-in matches its description)
grid = (np.random.rand(batch, hw, hw, nb_inputs) > 0.95).astype(np.float32)

# build kwargs to simplify op calls
inBlockParams = {"dynamic_bsize": blockSize,
                 "dynamic_boffset": blockOffset,
                 "dynamic_bstride": blockStride}
outBlockParams = {"dynamic_bsize": [blockSize[0]-2, blockSize[1]-2],
                  "dynamic_boffset": blockOffset,
                  "dynamic_bstride": blockStride }

sbnet_module = tf.load_op_library('../sbnet_ops/libsbnet.so')

tol = 0.1

with tf.Session() as sess:
    
    # sparse input
    input_placeholder = tf.placeholder(tf.float32, [batch, hw, hw, nb_inputs])
    
    # create a weight tensor
    w = tf.constant(np.ones((3, 3, nb_inputs, nb_outputs)).astype(np.float32))

    # reduce the input to indices by using a fused pooling+indexing operation
    indices = sbnet_module.reduce_mask(input_placeholder, 
                                       blockCount,
                                       avgpool=False,
                                       tol=tol,
                                       **inBlockParams)
    
    # stack active overlapping tiles to batch dimension
    blockStack = sbnet_module.sparse_gather(
        input_placeholder, 
        indices.bin_counts,
        indices.active_block_indices,
        transpose=True,
        **inBlockParams)
    
    # perform dense convolution on a sparse stack of tiles
    convBlocks = tf.nn.conv2d(
        blockStack, w, strides=[1, 1, 1, 1], padding='VALID', data_format='NCHW')
    
    # write/scatter the tiles back on top of the output tensor
    # note that each tile shrinks by 1 pixel per side due to the 'VALID' convolution
    validX = tf.zeros([batch, hw, hw, nb_outputs])
    y = sbnet_module.sparse_scatter(
        convBlocks,
        indices.bin_counts,
        indices.active_block_indices,
        validX,
        transpose=True,
        add=False,
        atomic=False,
        **outBlockParams)

    print("Sparse:")
    print("Up to indices:")
    %timeit -n10 sess.run([indices], feed_dict={input_placeholder: grid})
    print("Up to blockStack:")
    %timeit -n10 sess.run([blockStack], feed_dict={input_placeholder: grid})
    print("Up to convBlocks:")
    %timeit -n10 sess.run([convBlocks], feed_dict={input_placeholder: grid})
    print("Up to output")
    y_sparse, = sess.run([y], feed_dict={input_placeholder: grid})
    %timeit -n10 sess.run([y], feed_dict={input_placeholder: grid})
    
    print("Dense:")
    y_dense = tf.nn.conv2d(input_placeholder, w, strides=[1, 1, 1, 1], padding='SAME')
    %timeit -n10 sess.run([y_dense], feed_dict={input_placeholder: grid})
    y_dense_res, = sess.run([y_dense], feed_dict={input_placeholder: grid})

These are the timings:

Sparse:
Up to indices:
42.9 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Up to blockStack:
112 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Up to convBlocks:
116 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Up to output
147 ms ± 282 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Dense:
5.52 ms ± 302 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The input is a 1-channel NumPy array at 95% sparsity...

dhingratul commented on May 9, 2024

+1

andrei-pokrovsky commented on May 9, 2024

Using timeit on sess.run is incorrect for many reasons, including the large overhead of session startup. Proper CUDA time measurement in TensorFlow is not trivial; I recommend inspecting our benchmarking code, in particular the cuda_timer_start and cuda_timer_end custom ops. You also have to properly preheat the GPU and throw away the initial measurements while the clock ramps up (which could be one of many reasons why your observed timings are skewed). There are many other nontrivial details to proper GPU benchmarking that you can find in our benchmarking code, including some TensorFlow and cuDNN internals, such as either disabling the internal TF autotuner or letting it run for at least one session.run(), just off the top of my head.

Also see my answer here:
https://stackoverflow.com/questions/34293714/can-i-measure-the-execution-time-of-individual-operations-with-tensorflow
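
A minimal warm-up pattern along those lines; this is a generic sketch (not the repo's cuda_timer_start/cuda_timer_end ops) that discards the ramp-up runs before timing:

import time

# generic sketch: discard warm-up runs (GPU clock ramp-up, TF/cuDNN
# autotuning), then average only the steady-state iterations
def time_op(sess, fetches, feed_dict, warmup=20, iters=100):
    for _ in range(warmup):
        sess.run(fetches, feed_dict=feed_dict)
    start = time.time()
    for _ in range(iters):
        sess.run(fetches, feed_dict=feed_dict)
    return (time.time() - start) / iters   # mean seconds per run

Note that this still pays the per-sess.run overhead on every iteration, which is exactly the caveat above; the CUDA-event-based ops avoid that for per-op measurements.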

fferroni commented on May 9, 2024

I appreciate that it might not be perfect due to the overheads; however, even if I run the operations 100 or 1000 times (-n100 instead of -n10), I get essentially the same mean timings with smaller standard deviations. Similarly, the TF Timeline object reports the same. Also, for all practical purposes, the time taken by sess.run is the one people are interested in, and it will dictate whether people use these operations or not... and initialization overheads / transferring to the GPU should affect both the normal convolution and this method.

andrei-pokrovsky commented on May 9, 2024

Also, in this case the block size is only 3, which is very suboptimal for this method. For a 3x3 convolution kernel I would recommend a block size of at least 5-10, so that Winograd convolutions can be leveraged and the overlap overhead is reduced.

fferroni commented on May 9, 2024

Thanks, using the larger block sizes present in the macro helped.

  • Switching on avg-pooling is especially helpful; however, it also makes the output substantially different (filtered) from the dense version, which is not something I necessarily want. Any comments on this latter point? It doesn't seem to affect the detection rates in your paper, but I worry that information will gradually be lost.

  • It also helps to use sparse_scatter_var with a Variable, but this raises a question: how do you export this to a frozen protobuf for inference-only use? Comments on this?

  • It helps to stack operations in between the gather/scatter, as you suggest. In your example you have a 1x1, 3x3, and 1x1 ResNet module. Is it possible to do this with consecutive 3x3 convolutions (i.e. VGG-style) as well? I tried, but the results look strange, presumably because it messes up the indexing (see the sketch after the timings below).
    For example, for a single 3x3 layer + ReLU it's:

Sparse:
23 ms ± 477 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)
Dense:
24.7 ms ± 785 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)

And for stacked 1x1, 3x3, 1x1 with relus:

Sparse:
37.2 ms ± 549 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)
Dense:
63.3 ms ± 630 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)

so there is a real need to avoid gathering / scattering unless really necessary.
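
As a sketch of the consecutive-3x3 case: the following is an assumption based on the halo arithmetic, not code from the repo. Each 3x3 VALID convolution shrinks a tile by 2 pixels in total per dimension, so for two stacked convolutions the output block size and block stride must shrink by 4 relative to the gathered block size for the scattered tiles to line up (w1 and w2 are assumed weight tensors):

bsize = 9   # gathered tile: 9x9 -> 7x7 -> 5x5 after two VALID 3x3 convs
inParams = {"dynamic_bsize": [bsize, bsize],
            "dynamic_boffset": [0, 0],
            "dynamic_bstride": [bsize - 4, bsize - 4]}   # tiles overlap by the halo
outParams = {"dynamic_bsize": [bsize - 4, bsize - 4],
             "dynamic_boffset": [0, 0],
             "dynamic_bstride": [bsize - 4, bsize - 4]}

# indices must come from reduce_mask with this same block geometry
blockStack = sbnet_module.sparse_gather(
    input_placeholder, indices.bin_counts, indices.active_block_indices,
    transpose=True, **inParams)

# two convolutions on the gathered stack, no intermediate scatter
h = tf.nn.relu(tf.nn.conv2d(blockStack, w1, [1, 1, 1, 1], 'VALID', data_format='NCHW'))
h = tf.nn.relu(tf.nn.conv2d(h, w2, [1, 1, 1, 1], 'VALID', data_format='NCHW'))

y = sbnet_module.sparse_scatter(
    h, indices.bin_counts, indices.active_block_indices, validX,
    transpose=True, add=False, atomic=False, **outParams)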

Cheers

dhingratul commented on May 9, 2024

@andrei-pokrovsky Is there an automated way to tune the block-size hyperparameter, or can we learn it?

andrei-pokrovsky commented on May 9, 2024

As a side note, with a 3x3 block size for a 3x3 kernel you get about 9 pixels of overlap per output pixel (extra memory bandwidth to replicate that data into blocks), so roughly a 9x slowdown is expected, plus you lose Winograd. I'm slightly surprised you only got a 10x slowdown and not the 20x I would roughly expect. It's also possible that at those kernel sizes and resolutions the GPU compute is underutilized and the convolution is bandwidth-bound (I'd have to crunch the numbers to say exactly), which might be why you were only seeing a 10x slowdown.
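
The overlap factor behind that estimate can be computed directly: a b x b gathered tile yields only (b-2) x (b-2) valid outputs under a 3x3 VALID convolution, so each output pixel costs b^2/(b-2)^2 gathered input pixels:

# read amplification from overlapping tiles for a kxk VALID convolution
def overlap_factor(b, k=3):
    out = b - (k - 1)              # valid outputs per tile side
    return (b * b) / (out * out)   # gathered pixels per output pixel

print(overlap_factor(3))   # 9.0  -> the ~9x overhead mentioned above
print(overlap_factor(5))   # ~2.8
print(overlap_factor(9))   # ~1.7 -> larger blocks amortize the halo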

andrei-pokrovsky commented on May 9, 2024

Wrt average pooling: if your tolerance cut-off is fixed, then average pooling will produce higher sparsity than max pooling, and the GPU can become underutilized for such a small tensor / high sparsity, which is my current guess as to why you are seeing a lower speedup. The GPU needs a decent amount of work left over after the block gather to realize a close-to-linear speedup, or the same speedup as with maxpool. This is just my guess right now; I haven't run your code yet.

For inference there are many alternative solutions, such as TensorRT (SBNet could be integrated via the IPlugin interface) and ONNX with a TensorRT backend. You also don't necessarily have to freeze the graph for inference in TensorFlow.

Another alternative is using the nvGraphAddNode API, etc.

There are many inference solutions out there, and getting the perfect one might require some engineering work. I personally find that TensorFlow out of the box is too heavyweight for production inference, and it's not that hard to roll your own inference mini-framework if you are not trying to be middleware and are just solving your specific application: you capture the graph, export the dependencies and parameters, do a topological sort on the graph, and you are basically done. There's some work involved, but you have full control over your source, as opposed to closed-source libraries like TensorRT. Plus you can always splice in subgraphs from TRT and have more control that way.
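
As an illustration of the topological-sort step (the node names and the dependency dict here are made up for the example, not from any real exporter):

# illustrative dependency graph: op -> list of ops it consumes
deps = {"input": [], "conv1": ["input"], "relu1": ["conv1"], "output": ["relu1"]}

def toposort(deps):
    order, seen = [], set()
    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for dep in deps[node]:   # emit dependencies before the node itself
            visit(dep)
        order.append(node)
    for node in deps:
        visit(node)
    return order                 # execution order for the mini-framework

print(toposort(deps))   # ['input', 'conv1', 'relu1', 'output']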

Wrt session overhead: I don't think timing a session with timeit is representative of the actual workload, where a full network is wrapped in a single session. So I recommend using the provided timing operation based on CUDA event wrappers; refer to the stackoverflow post I referenced earlier for details on why.

HTH

andrei-pokrovsky commented on May 9, 2024

If you really want to continue using timeit to time a single sparse convolution, I recommend wrapping 50-100 repeated subgraphs, each with a single sparse convolution, into one session; this reduces the session overhead to something more representative of a full end-to-end inference run for a whole network. In this scenario, pay attention to feeding different randomized inputs into each separate convolution, and make sure all the outputs are consumed; otherwise TensorFlow will be aggressive about optimizing out unused subgraphs or subgraphs with duplicated inputs.
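
A sketch of that setup with illustrative names; a dense convolution stands in here for the full sparse gather/conv/scatter subgraph to keep it short:

n_copies = 50
inputs, outputs = [], []
for _ in range(n_copies):
    # a separate placeholder per copy, so TF cannot deduplicate the subgraphs
    x_i = tf.placeholder(tf.float32, [batch, hw, hw, nb_inputs])
    y_i = tf.nn.conv2d(x_i, w, strides=[1, 1, 1, 1], padding='SAME')
    inputs.append(x_i)
    outputs.append(y_i)

# consuming every output keeps TF from pruning any of the subgraphs
combined = tf.add_n([tf.reduce_sum(y_i) for y_i in outputs])

# different randomized input per copy, as recommended above
feed = {x_i: np.random.rand(batch, hw, hw, nb_inputs).astype(np.float32)
        for x_i in inputs}
# %timeit -n10 sess.run(combined, feed_dict=feed)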

andrei-pokrovsky commented on May 9, 2024

@dhingratul Wrt autotuning: the autotuner in this case just tries all the different block sizes and picks the fastest. You can check how we do it in the timing code included with the distribution. We don't try non-square block sizes, which could potentially work better for a particular sparsity pattern (for instance 16x8, 32x4, etc.), but those template instantiations would have to be added explicitly to the C++ code as described in the readme. More generally, autotuning CUDA kernels is a fairly extensive subject in and of itself and a bit outside the scope of this project, but you can look at projects like NVIDIA's jitify, etc.
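
A brute-force version of that search, sketched with a hypothetical build_sparse_conv(x, bsize) helper that assembles the gather/conv/scatter graph for a given square block size (the candidate list is arbitrary):

candidate_sizes = [5, 9, 13, 17, 25, 33]    # square block sizes to try
best_size, best_time = None, float('inf')

for bsize in candidate_sizes:
    tf.reset_default_graph()                 # fresh graph per candidate
    x = tf.placeholder(tf.float32, [batch, hw, hw, nb_inputs])
    y = build_sparse_conv(x, bsize)          # hypothetical graph builder
    with tf.Session() as sess:
        t = time_op(sess, y, {x: grid})      # warm-up + timing helper sketched above
    if t < best_time:
        best_size, best_time = bsize, t

print("fastest block size:", best_size)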

fferroni commented on May 9, 2024

Hello @andrei-pokrovsky
I was also wondering about the following difference:

sbnet_module.reduce_mask(x, ...

or

sbnet_module.reduce_mask(x[..., 0:1], ...

Here x is a tensor with channels > 1, and the first form is how you use it in the sparse_conv_lib.py file.

However, I am unsure how a block is selected when there is more than one channel and a fixed numerical mask threshold. Do you internally take a sum along the channel dimension, or a max?

With the first option I cannot reproduce the 'expected' output of a normal dense convolution, but with the second one I get output identical to the dense convolution.

andrei-pokrovsky commented on May 9, 2024

Right now the implementation expects/requires a tensor of shape [N,H,W,1] for the mask. This is somewhat redundant, since it's always a single channel. In the end it's either max or average pooling per block, so you can come up with your own way of reducing C channels to 1, such as averaging across channels or some other reduction.
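
For example, one possible way to reduce a multi-channel tensor to the single-channel mask shape before calling reduce_mask; the max-of-absolute-values reduction is just one choice, not something the library prescribes:

# collapse C channels into the [N, H, W, 1] mask shape reduce_mask expects
mask = tf.reduce_max(tf.abs(x), axis=3, keepdims=True)
indices = sbnet_module.reduce_mask(mask, blockCount,
                                   avgpool=False, tol=tol, **inBlockParams)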
