https://github.com/huggingface/pytorch_block_sparse/blob/0985083851a5708cfb3adf50da19860f467e51ae/pytorch_block_sparse/block_sparse_linear.py#L141