I went into the TensorRT documentation and a different Nvidia whitepaper to look into this more, and here's what I think is going on, to the best of my understanding. I'll split this post into two sections: the sparsity workflow and the theoretical 2x speedup.
Sparsity workflow
To use 2:4 sparsity, each layer of the neural network has to have zeros in 2 out of every 4 consecutive weights along the input-channel dimension.
When you train a dense neural network, this pattern doesn't emerge naturally, but many of the weights end up very small. You can prune the smallest 2 out of every 4 weights to zero and force the network into structured sparsity, but you lose accuracy in the process. However, Nvidia provides a tool called Automatic SParsity (ASP) that lets you do a second training step after pruning to re-optimize the remaining weights and recover most of the lost accuracy. I think this is the extra optimization step you're talking about.
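To make the pruning step concrete, here's a minimal sketch of 2:4 magnitude pruning in NumPy. The function name and the example weights are my own for illustration; ASP's actual implementation operates on whole PyTorch models and handles the retraining loop:

```python
import numpy as np

def prune_2_4(weights):
    """Zero the 2 smallest-magnitude weights in every group of 4.

    weights: 1-D array whose length is a multiple of 4.
    """
    w = weights.reshape(-1, 4).copy()
    # Indices of the two smallest |w| entries in each group of four.
    drop = np.argsort(np.abs(w), axis=1)[:, :2]
    np.put_along_axis(w, drop, 0.0, axis=1)
    return w.reshape(weights.shape)

w = np.array([0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.01, 0.8])
# Zeros out -0.1 and 0.05 in the first group, 0.2 and -0.01 in the second.
print(prune_2_4(w))
```

After this step, each group of four weights has exactly two nonzeros, which is the pattern the sparse tensor cores require; the follow-up training pass then adjusts the surviving weights with the zero pattern held fixed.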
The README in the GitHub repo for ASP shows the standard way to use this with a pretrained dense model. Basically, you train the dense model, prune the weights you don't need, then do the second training step with the sparsity pattern fixed until the difference in accuracy is negligible:
There's a nice figure in the whitepaper that diagrams what this multistep training process would look like:
2x speedup
My best guess for why it's "up to 2x" is that, for some layers, structured sparsity is no faster than the dense path, even though the sparse tensor cores have twice the arithmetic throughput. TensorRT will tell you which layers fall into this category.
With a normal matrix, you can store all of the values in contiguous arrays, which are very fast to access, but with sparse matrices, you have to come up with some compressed representation that stores where each data point is located in the matrix. From the whitepaper:
There is an extra mapping step that uses this metadata to route the sparse data to the right position for the accumulator in the tensor cores. For small matrices, this mapping step and the associated overhead are significant, so the speedup over dense matrices is not fully realized.
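To make the metadata idea concrete, here's a toy sketch of the compressed format: the nonzero values are packed into an array half the original width, and a small index records each kept value's position within its group of four (which fits in 2 bits per value in hardware). The function names are mine, and this assumes exactly two nonzeros per group:

```python
import numpy as np

def compress_2_4(row):
    """Pack a 2:4-sparse row into nonzero values plus positional metadata.

    row: 1-D array, length a multiple of 4, with exactly 2 nonzeros per
    group of 4 (the structured-sparsity guarantee).
    """
    groups = row.reshape(-1, 4)
    # Position of each group's two nonzeros -- the "metadata" (2 bits each).
    idx = np.array([np.flatnonzero(g)[:2] for g in groups])
    vals = np.take_along_axis(groups, idx, axis=1)
    return vals.ravel(), idx.ravel()

def decompress_2_4(vals, idx, n):
    """Scatter the packed values back into a dense row of length n."""
    out = np.zeros(n)
    groups = out.reshape(-1, 4)  # view into out
    np.put_along_axis(groups, idx.reshape(-1, 2), vals.reshape(-1, 2), axis=1)
    return out

row = np.array([0.9, 0.0, 0.0, -0.7, 0.0, 0.3, 0.0, 0.8])
vals, idx = compress_2_4(row)
# vals holds the 4 nonzeros; idx holds their in-group positions.
```

The hardware does the equivalent of `decompress_2_4` on the fly when selecting operands, and that per-element selection is the mapping overhead that eats into the speedup for small matrices.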
For a GEMM operation, Nvidia uses MxNxK notation to denote multiplying an MxK matrix by a KxN matrix. As K gets larger, the arithmetic work becomes the main bottleneck, and you asymptotically approach the theoretical 2x speedup. I believe M is the number of output channels, N is the size of the input image times the batch size, and K is the kernel area times the number of input channels, e.g. for a 3x3 kernel and 6 input channels, K would be 3 * 3 * 6 = 54. Essentially, every row in the MxK matrix is a learned filter that acts on the input data in the columns of the KxN matrix.
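As a sanity check on those dimensions, here's a small helper mapping a convolution onto GEMM M/N/K under the im2col-style interpretation I described above (the function name and the layer sizes are my own, purely for illustration):

```python
def conv_gemm_shape(out_channels, in_channels, kh, kw, out_h, out_w, batch):
    """Map a convolution onto implicit-GEMM M x N x K dimensions."""
    M = out_channels           # one row per learned filter
    K = in_channels * kh * kw  # each filter flattened into a length-K row
    N = out_h * out_w * batch  # one column per output pixel per batch item
    return M, N, K

# The example from the text: 3x3 kernel, 6 input channels -> K = 54.
print(conv_gemm_shape(16, 6, 3, 3, 64, 64, 1))  # (16, 4096, 54)
```

Note that only K grows with kernel size and input channel count, which is why deep, wide layers get closer to the full 2x than the shallow early layers.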
Since DLSS runs in real time in a couple of milliseconds, it is likely on the lower end of this plot. For reference, the Facebook neural supersampling paper had only 128 channels in its deepest layer and used 3x3 kernels, which I believe corresponds to a GEMM-K of 128 * 3 * 3 = 1152.