
Implement LogCumSumExp #3640

Open · wants to merge 65 commits into base: develop

Conversation

anhskrttt
Collaborator

  • Added LogCumSumExp [ref] forward and backward operations and kernels (a minimal reference sketch of the operation's semantics follows this list).
  • This implementation is only valid when the size of the dimension being cumulated is smaller than the maximum number of threads per kernel block; otherwise, it returns invalid results.
  • Added a driver test and gtests for the LogCumSumExp operations.
  • The new API is guarded by the MIOPEN_BETA_API macro.
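
For reference, here is a minimal host-side sketch (not code from this PR; the helper name is illustrative) of the semantics the kernels compute along the selected dim: `y[i] = log(exp(x[0]) + ... + exp(x[i]))`, accumulated around a running max so `exp()` never overflows.

```cpp
// Reference-only sketch of LogCumSumExp over a 1-D slice:
//   y[i] = log(exp(x[0]) + exp(x[1]) + ... + exp(x[i]))
// computed with a numerically stable running log-sum-exp.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

std::vector<double> LogCumSumExpRef(const std::vector<double>& x)
{
    std::vector<double> y(x.size());
    double running = -std::numeric_limits<double>::infinity(); // log(0)
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        // running = log(exp(running) + exp(x[i])), rewritten around the max
        // of the two operands so that exp() stays bounded.
        const double m = std::max(running, x[i]);
        running        = m + std::log(std::exp(running - m) + std::exp(x[i] - m));
        y[i]           = running;
    }
    return y;
}

int main()
{
    const std::vector<double> x = {0.5, -1.0, 2.0, 0.0};
    for(const double v : LogCumSumExpRef(x))
        std::printf("%f ", v);
    std::printf("\n");
    return 0;
}
```

A per-slice recurrence like this is presumably what each kernel block computes, which is consistent with the dimension-size limitation noted above.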

Average improvement over ROCm

| type | fwd | bwd |
|---|---|---|
| float | 13.72 | 15.37 |
| float16 | - | - |
| bfloat16 | 13.11 | 15.38 |

Detailed Benchmark

In the tables below, Improvement is the ratio of the ROCm PyTorch time to the MIOpen HIP time (higher is better).

float32 Forward
| op_name | dtype | size | dim | contiguous | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| LogCumSumExp | float32 | [512 64 112 112] | -1 | contiguous | fwd | 145809128 | 7904310 | 18.44678764 |
| LogCumSumExp | float32 | [512 64 56 56] | -1 | contiguous | fwd | 69732172 | 2374840 | 29.36289266 |
| LogCumSumExp | float32 | [512 128 56 56] | -1 | contiguous | fwd | 139499742 | 4745900 | 29.39373817 |
| LogCumSumExp | float32 | [512 128 28 28] | -1 | contiguous | fwd | 69528870 | 2347620 | 29.61674803 |
| LogCumSumExp | float32 | [512 256 28 28] | -1 | contiguous | fwd | 139033407 | 4694170 | 29.61831527 |
| LogCumSumExp | float32 | [512 256 14 14] | -1 | contiguous | fwd | 69378332 | 2329770 | 29.77904772 |
| LogCumSumExp | float32 | [512 512 14 14] | -1 | contiguous | fwd | 138751082 | 4658060 | 29.78731103 |
| LogCumSumExp | float32 | [512 512 7 7] | -1 | contiguous | fwd | 69328528 | 2316700 | 29.92555273 |
| LogCumSumExp | float32 | [512 1024 7 7] | -1 | contiguous | fwd | 138644482 | 4633250 | 29.92380769 |
| LogCumSumExp | float32 | [512 1024 100] | -1 | contiguous | fwd | 20109653 | 1117820 | 17.9900637 |
| LogCumSumExp | float32 | [1024 1024 7 7] | -1 | contiguous | fwd | 277302592 | 11619000 | 23.8663045 |
| LogCumSumExp | float32 | [1024 1024 100] | -1 | contiguous | fwd | 40241550 | 2233520 | 18.01709857 |
| LogCumSumExp | float32 | [64 112 112 512] | -1 | contiguous | fwd | 43455625 | 7936570 | 5.475365933 |
| LogCumSumExp | float32 | [64 56 56 512] | -1 | contiguous | fwd | 5013870 | 1983640 | 2.527610857 |
| LogCumSumExp | float32 | [128 56 56 512] | -1 | contiguous | fwd | 21722237 | 3965290 | 5.478095423 |
| LogCumSumExp | float32 | [128 28 28 512] | -1 | contiguous | fwd | 2557599 | 993785 | 2.573593886 |
| LogCumSumExp | float32 | [256 28 28 512] | -1 | contiguous | fwd | 5011870 | 1983890 | 2.526284219 |
| LogCumSumExp | float32 | [256 14 14 512] | -1 | contiguous | fwd | 1330935 | 498963 | 2.667402192 |
| LogCumSumExp | float32 | [512 14 14 512] | -1 | contiguous | fwd | 2558607 | 994034 | 2.573963265 |
| LogCumSumExp | float32 | [512 7 7 512] | -1 | contiguous | fwd | 682939 | 251962 | 2.710484121 |
| LogCumSumExp | float32 | [1024 7 7 512] | -1 | contiguous | fwd | 1330327 | 499088 | 2.665515901 |
| LogCumSumExp | float32 | [1024 100 512] | -1 | contiguous | fwd | 2564831 | 1014140 | 2.529069951 |
| LogCumSumExp | float32 | [1024 7 7 1024] | -1 | contiguous | fwd | 2653214 | 1187580 | 2.234134964 |
| LogCumSumExp | float32 | [1024 100 1024] | -1 | contiguous | fwd | 5115869 | 2420190 | 2.113829493 |
float32 Backward
| op_name | dtype | size | dim | contiguous | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| LogCumSumExp | float32 | [512 64 112 112] | -1 | contiguous | bwd | 341245668 | 15222100 | 22.41777862 |
| LogCumSumExp | float32 | [512 64 56 56] | -1 | contiguous | bwd | 154188858 | 3600300 | 42.82666944 |
| LogCumSumExp | float32 | [512 128 56 56] | -1 | contiguous | bwd | 309668527 | 7199660 | 43.01154874 |
| LogCumSumExp | float32 | [512 128 28 28] | -1 | contiguous | bwd | 146470549 | 3496350 | 41.89241609 |
| LogCumSumExp | float32 | [512 256 28 28] | -1 | contiguous | bwd | 292874323 | 6990730 | 41.89466951 |
| LogCumSumExp | float32 | [512 256 14 14] | -1 | contiguous | bwd | 142612281 | 3445940 | 41.38559609 |
| LogCumSumExp | float32 | [512 512 14 14] | -1 | contiguous | bwd | 286393459 | 6889560 | 41.5691944 |
| LogCumSumExp | float32 | [512 512 7 7] | -1 | contiguous | bwd | 140653822 | 3396510 | 41.41127864 |
| LogCumSumExp | float32 | [512 1024 7 7] | -1 | contiguous | bwd | 281166630 | 6790880 | 41.40356331 |
| LogCumSumExp | float32 | [512 1024 100] | -1 | contiguous | bwd | 47806229 | 2152690 | 22.20766994 |
| LogCumSumExp | float32 | [1024 1024 7 7] | -1 | contiguous | bwd | 562217927 | 13587700 | 41.37697528 |
| LogCumSumExp | float32 | [1024 1024 100] | -1 | contiguous | bwd | 95480131 | 4303720 | 22.1854886 |
| LogCumSumExp | float32 | [64 112 112 512] | -1 | contiguous | bwd | 143977019 | 15467800 | 9.308176922 |
| LogCumSumExp | float32 | [64 56 56 512] | -1 | contiguous | bwd | 24262716 | 3865060 | 6.277448733 |
| LogCumSumExp | float32 | [128 56 56 512] | -1 | contiguous | bwd | 71758202 | 7731350 | 9.281458219 |
| LogCumSumExp | float32 | [128 28 28 512] | -1 | contiguous | bwd | 12271517 | 1936850 | 6.335811756 |
| LogCumSumExp | float32 | [256 28 28 512] | -1 | contiguous | bwd | 24321708 | 3866090 | 6.291035129 |
| LogCumSumExp | float32 | [256 14 14 512] | -1 | contiguous | bwd | 6260454 | 971563 | 6.443693307 |
| LogCumSumExp | float32 | [512 14 14 512] | -1 | contiguous | bwd | 12272334 | 1937170 | 6.335186896 |
| LogCumSumExp | float32 | [512 7 7 512] | -1 | contiguous | bwd | 3209674 | 489470 | 6.557447852 |
| LogCumSumExp | float32 | [1024 7 7 512] | -1 | contiguous | bwd | 6262886 | 971367 | 6.447497187 |
| LogCumSumExp | float32 | [1024 100 512] | -1 | contiguous | bwd | 12416076 | 1975960 | 6.283566469 |
| LogCumSumExp | float32 | [1024 7 7 1024] | -1 | contiguous | bwd | 12463612 | 2356880 | 5.288182682 |
| LogCumSumExp | float32 | [1024 100 1024] | -1 | contiguous | bwd | 24801400 | 4803570 | 5.163118264 |
float16 Forward

This operation doesn't work with float16 tensors in ROCm/PyTorch.

float16 Backward

This operation doesn't work with float16 tensors in ROCm/PyTorch.

bfloat16 Forward
| op_name | dtype | size | dim | contiguous | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| LogCumSumExp | bfloat16 | [512 64 112 112] | -1 | contiguous | fwd | 155711190 | 7766570 | 20.04890061 |
| LogCumSumExp | bfloat16 | [512 64 56 56] | -1 | contiguous | fwd | 78634221 | 2369570 | 33.18501711 |
| LogCumSumExp | bfloat16 | [512 128 56 56] | -1 | contiguous | fwd | 154616908 | 4737550 | 32.63646991 |
| LogCumSumExp | bfloat16 | [512 128 28 28] | -1 | contiguous | fwd | 77031384 | 2345150 | 32.84710317 |
| LogCumSumExp | bfloat16 | [512 256 28 28] | -1 | contiguous | fwd | 155450902 | 4690190 | 33.14383895 |
| LogCumSumExp | bfloat16 | [512 256 14 14] | -1 | contiguous | fwd | 76914969 | 2329770 | 33.0139752 |
| LogCumSumExp | bfloat16 | [512 512 14 14] | -1 | contiguous | fwd | 153749491 | 4657050 | 33.01435265 |
| LogCumSumExp | bfloat16 | [512 512 7 7] | -1 | contiguous | fwd | 76870091 | 2317380 | 33.1711204 |
| LogCumSumExp | bfloat16 | [512 1024 7 7] | -1 | contiguous | fwd | 153749893 | 5881890 | 26.13953899 |
| LogCumSumExp | bfloat16 | [512 1024 100] | -1 | contiguous | fwd | 22237226 | 1104490 | 20.1334788 |
| LogCumSumExp | bfloat16 | [1024 1024 7 7] | -1 | contiguous | fwd | 307385054 | 9265820 | 33.17408001 |
| LogCumSumExp | bfloat16 | [1024 1024 100] | -1 | contiguous | fwd | 44433574 | 2205060 | 20.15073241 |
| LogCumSumExp | bfloat16 | [64 112 112 512] | -1 | contiguous | fwd | 45994588 | 8012800 | 5.740139277 |
| LogCumSumExp | bfloat16 | [64 56 56 512] | -1 | contiguous | fwd | 5199885 | 2002470 | 2.596735532 |
| LogCumSumExp | bfloat16 | [128 56 56 512] | -1 | contiguous | fwd | 23011078 | 4003970 | 5.747065537 |
| LogCumSumExp | bfloat16 | [128 28 28 512] | -1 | contiguous | fwd | 2649870 | 1003100 | 2.64168079 |
| LogCumSumExp | bfloat16 | [256 28 28 512] | -1 | contiguous | fwd | 5200894 | 2002679 | 2.596968361 |
| LogCumSumExp | bfloat16 | [256 14 14 512] | -1 | contiguous | fwd | 1375703 | 503568 | 2.731911083 |
| LogCumSumExp | bfloat16 | [512 14 14 512] | -1 | contiguous | fwd | 2650126 | 1002850 | 2.642594605 |
| LogCumSumExp | bfloat16 | [512 7 7 512] | -1 | contiguous | fwd | 1908003 | 254218 | 7.505381208 |
| LogCumSumExp | bfloat16 | [1024 7 7 512] | -1 | contiguous | fwd | 1375319 | 503923 | 2.729224505 |
| LogCumSumExp | bfloat16 | [1024 100 512] | -1 | contiguous | fwd | 2657150 | 1023539 | 2.596041773 |
| LogCumSumExp | bfloat16 | [1024 7 7 1024] | -1 | contiguous | fwd | 2742302 | 1202620 | 2.28027307 |
| LogCumSumExp | bfloat16 | [1024 100 1024] | -1 | contiguous | fwd | 5305517 | 2450040 | 2.165481788 |
bfloat16 Backward
| op_name | dtype | size | dim | contiguous | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|
| LogCumSumExp | bfloat16 | [512 64 112 112] | -1 | contiguous | bwd | 351717497 | 15334900 | 22.93575419 |
| LogCumSumExp | bfloat16 | [512 64 56 56] | -1 | contiguous | bwd | 164634038 | 3604780 | 45.67103624 |
| LogCumSumExp | bfloat16 | [512 128 56 56] | -1 | contiguous | bwd | 329388221 | 7209500 | 45.68808114 |
| LogCumSumExp | bfloat16 | [512 128 28 28] | -1 | contiguous | bwd | 159209435 | 3536600 | 45.01765396 |
| LogCumSumExp | bfloat16 | [512 256 28 28] | -1 | contiguous | bwd | 318429575 | 8015910 | 39.72469439 |
| LogCumSumExp | bfloat16 | [512 256 14 14] | -1 | contiguous | bwd | 156468733 | 3479140 | 44.97339371 |
| LogCumSumExp | bfloat16 | [512 512 14 14] | -1 | contiguous | bwd | 312781554 | 6956140 | 44.96481583 |
| LogCumSumExp | bfloat16 | [512 512 7 7] | -1 | contiguous | bwd | 155144215 | 3436940 | 45.1402163 |
| LogCumSumExp | bfloat16 | [512 1024 7 7] | -1 | contiguous | bwd | 310310835 | 6874340 | 45.14045494 |
| LogCumSumExp | bfloat16 | [512 1024 100] | -1 | contiguous | bwd | 49671856 | 2181040 | 22.7743902 |
| LogCumSumExp | bfloat16 | [1024 1024 7 7] | -1 | contiguous | bwd | 620280020 | 13746200 | 45.12374474 |
| LogCumSumExp | bfloat16 | [1024 1024 100] | -1 | contiguous | bwd | 99066678 | 4358260 | 22.7307866 |
| LogCumSumExp | bfloat16 | [64 112 112 512] | -1 | contiguous | bwd | 131557996 | 15735600 | 8.36053255 |
| LogCumSumExp | bfloat16 | [64 56 56 512] | -1 | contiguous | bwd | 20352407 | 3931210 | 5.177135538 |
| LogCumSumExp | bfloat16 | [128 56 56 512] | -1 | contiguous | bwd | 65836374 | 7861480 | 8.37455212 |
| LogCumSumExp | bfloat16 | [128 28 28 512] | -1 | contiguous | bwd | 10347722 | 1968140 | 5.257614804 |
| LogCumSumExp | bfloat16 | [256 28 28 512] | -1 | contiguous | bwd | 20364119 | 3931120 | 5.180233369 |
| LogCumSumExp | bfloat16 | [256 14 14 512] | -1 | contiguous | bwd | 5309644 | 987474 | 5.376996255 |
| LogCumSumExp | bfloat16 | [512 14 14 512] | -1 | contiguous | bwd | 10341083 | 1967850 | 5.25501588 |
| LogCumSumExp | bfloat16 | [512 7 7 512] | -1 | contiguous | bwd | 2735166 | 497826 | 5.494220872 |
| LogCumSumExp | bfloat16 | [1024 7 7 512] | -1 | contiguous | bwd | 5307532 | 987901 | 5.372534292 |
| LogCumSumExp | bfloat16 | [1024 100 512] | -1 | contiguous | bwd | 10427898 | 2008300 | 5.192400538 |
| LogCumSumExp | bfloat16 | [1024 7 7 1024] | -1 | contiguous | bwd | 10520905 | 2365660 | 4.447344504 |
| LogCumSumExp | bfloat16 | [1024 100 1024] | -1 | contiguous | bwd | 20781652 | 4823750 | 4.308194247 |
