# Development Notes

This Python package provides a PyTorch extension.
## Organisation

### Build

The `setup.py` script uses the standard PyTorch extension mechanism to build the package:

```python
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
...
    ext_modules=[
        CUDAExtension('block_sparse_native',
            ['pytorch_block_sparse/native/block_sparse_native.cpp',
             'pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu',
             'pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu'],
            extra_compile_args=['-I', '%s/pytorch_block_sparse' % rootdir]
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
```

### Native functions Python interface
A single C++ file, `block_sparse_native.cpp`, provides the native functions visible from Python.
These functions give access to CUDA kernels that compute:
- dense x sparse -> dense
- dense x dense on sparse support -> sparse
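
The two products can be illustrated with a pure-PyTorch reference (the block size, mask, and variable names below are illustrative, not the library's API):

```python
import torch

# Hypothetical 2x2 block pattern with 4x4 blocks, expanded to an 8x8 element mask
block = 4
block_mask = torch.tensor([[1., 0.], [0., 1.]])
mask = block_mask.repeat_interleave(block, 0).repeat_interleave(block, 1)

a = torch.randn(8, 8)
b = torch.randn(8, 8)
sparse_b = b * mask            # a "sparse" matrix, stored densely for reference

# dense x sparse -> dense
dense_out = a @ sparse_b

# dense x dense on sparse support -> sparse:
# full dense product, then keep only the entries inside the sparse support
sparse_out = (a @ b) * mask
```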

### CUDA/Cutlass kernels
The `*.cu` files in the `native` directory provide the kernels themselves.
They use the Cutlass primitives available in the `cutlass` subdirectory.

Multiple levels of C++ templating provide dispatch and code generation for the kernels.

The main files in the `cutlass/gemm` directory are `block_task.h` and `block_task_back.h`.
They express the final CUDA kernel that will be executed, using:
- `block_loader_.*` to load A and B matrix tiles in an efficient way
- `thread_accumulator.h` to store the result tiles `R`
- `epilogue_function` to combine `R` with `C`: `C' = alpha * R + beta * C`
- `grid_raster_.*` to list the output tiles that must be computed
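
The epilogue step can be written out as a small PyTorch sketch (the `alpha`, `beta` values and tile shapes here are illustrative):

```python
import torch

alpha, beta = 1.0, 0.5
A = torch.randn(4, 6)          # tile of the A matrix
B = torch.randn(6, 4)          # tile of the B matrix
C = torch.randn(4, 4)          # existing output tile

R = A @ B                      # accumulator tile produced by the GEMM main loop
C_new = alpha * R + beta * C   # what epilogue_function computes
```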

### block_sparse python module
This library includes as little native code as possible, because native code is hard to write, debug, and understand.

The native functions perform the performance-critical tasks, while the Python code in `block_sparse.py` does
all the preparatory work, which is executed only once, or infrequently.

The main job of `block_sparse.py` is to build indexes into the sparse matrices.
Three sets of sparse indices are built:
- row-wise index of non-zero entries (for dense x sparse)
- column-wise index of non-zero entries (for dense x sparse with transposition)
- linear list of 2D coordinates of non-zero entries (for dense x dense on sparse support)

These structures are created using standard PyTorch primitives, so they are easy to debug, understand,
or reimplement in other languages.
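
A hedged sketch of how such indices can be built from a block mask with standard PyTorch primitives (the mask and variable names are illustrative, not the actual code of `block_sparse.py`):

```python
import torch

# non-zero block pattern: True where a block is present
mask = torch.tensor([[True, False, True],
                     [False, True, False]])

# linear list of 2D coordinates of non-zero blocks
# (used for dense x dense on sparse support)
coords = mask.nonzero()        # shape (nnz, 2), row-major order

# row-wise index: CSR-like row pointers plus the column of each non-zero block
row_counts = mask.sum(dim=1)
row_ptr = torch.cat([torch.zeros(1, dtype=torch.long), row_counts.cumsum(0)])
row_cols = coords[:, 1]

# column-wise index: the same construction on the transposed mask
# (used for dense x sparse with transposition)
coords_t = mask.t().nonzero()
col_rows = coords_t[:, 1]
```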

### block_sparse_linear python module
`block_sparse_linear` is a thin layer on top of `block_sparse`.
It uses the linear algebra primitives of `block_sparse` to create a drop-in replacement for `torch.nn.Linear`,
with the proper back-propagation primitives, implemented using a `torch.autograd.Function` subclass.
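
The wiring can be sketched with a dense weight for clarity (in the real module the backward products go through the block-sparse kernels; the class and variable names here are illustrative):

```python
import torch

class SparseLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()            # forward: dense x (block-)sparse

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight       # dense x sparse with transposition
        grad_w = grad_out.t() @ x        # dense x dense on sparse support
        return grad_x, grad_w

x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(5, 4, requires_grad=True)
y = SparseLinearFunction.apply(x, w)
y.sum().backward()
```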

## Testing
Debugging CUDA kernels is hard. Fortunately, it is easy to compare the kernel results with
a reference PyTorch implementation.
The `tests` directory provides code to test the library and measure its performance.
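
The pattern is simply: run the kernel under test and a pure-PyTorch reference on the same inputs, then compare (`native_matmul` below is a placeholder standing in for a real native function):

```python
import torch

def native_matmul(a, b):
    # placeholder for the CUDA kernel under test;
    # the real tests would call into block_sparse_native here
    return a @ b

torch.manual_seed(0)
a = torch.randn(16, 32)
b = torch.randn(32, 8)

result = native_matmul(a, b)
reference = a @ b                       # reference PyTorch implementation
max_err = (result - reference).abs().max().item()
```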

## TODO

block_sparse
- add input parameter sanity checks
- add dispatch for:
  - different matrix sizes -> different dispatch strategies (tile sizes in the k-dimension)
  - different block sizes

tests
- refactor/clean up tests

doc
- schema of the sparse index structures

cutlass
- move to the 2.x version

cleanup algorithms
- add algorithms to measure weight importance and optimize the sparsity pattern