Template recursion-free static_ford #2011

Open
wants to merge 29 commits into
base: develop

@tenpercent tenpercent commented Mar 25, 2025

Proposed changes

The current implementation of static_ford recurses over the sequence of dimensions and instantiates a lambda at each internal step. This can be avoided, which should improve compilation time. Currently, the most time-consuming template is static_for(0,1) from the internals of static_ford instantiation, for both old CK and CK-tile.
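For readers unfamiliar with the idea, here is a minimal standalone sketch of a recursion-free multi-dimensional static loop. It is not the code in this PR: the names unflatten, multi_index and flat_static_ford are hypothetical, it uses only the C++17 standard library, and it assumes row-major order and at least one dimension. The point is that a single fold expression over the flat index range replaces one nested lambda per dimension.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical helper (not CK): row-major decomposition of a flat index.
template <std::size_t... Lengths>
constexpr std::array<std::size_t, sizeof...(Lengths)> unflatten(std::size_t flat)
{
    constexpr std::size_t lens[] = {Lengths...};
    std::array<std::size_t, sizeof...(Lengths)> idx{};
    for(std::size_t d = sizeof...(Lengths); d-- > 0;)
    {
        idx[d] = flat % lens[d]; // fastest-varying dimension is the last one
        flat /= lens[d];
    }
    return idx;
}

// Hypothetical tag type carrying the multi-index as a compile-time constant.
template <std::size_t Flat, std::size_t... Lengths>
struct multi_index
{
    static constexpr auto value = unflatten<Lengths...>(Flat);
};

// Hypothetical recursion-free loop: one fold over all flat indices.
template <std::size_t... Lengths>
struct flat_static_ford
{
    static constexpr std::size_t total = (Lengths * ... * 1);

    template <class F>
    constexpr void operator()(F&& f) const
    {
        apply(f, std::make_index_sequence<total>{});
    }

  private:
    template <class F, std::size_t... Flat>
    static constexpr void apply(F& f, std::index_sequence<Flat...>)
    {
        // The comma fold calls f once per flat index, in increasing order.
        (f(multi_index<Flat, Lengths...>{}), ...);
    }
};

// Usage: visit a 2 x 3 index space and accumulate 10 * i + j.
inline std::size_t demo_sum()
{
    std::size_t sum = 0;
    flat_static_ford<2, 3>{}([&](auto I) {
        constexpr auto idx = decltype(I)::value;
        sum += 10 * idx[0] + idx[1];
    });
    return sum;
}
```

The functor is instantiated once per point of the index space either way; what this shape avoids is the intermediate per-dimension lambda and static_for instantiations.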

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

Snapshot -

ck::static_ford<ck::Sequence<14, 14, 16>>{} ([] (auto I) {
        I.Print();
    });
*** Old
 10703 ms: ck::static_for<0, 14, 1>::operator()<(lambda at ../include/ck/utility/functional3.hpp:31:52)> (15 times, avg 713 ms)
 10701 ms: ck::detail::static_for_impl<ck::Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13>>::operator()<(lambda at ../include/ck/utility/functional3.hpp:31:52)> (15 times, avg 713 ms)
  5364 ms: ck::static_ford<ck::Sequence<14, 14, 16>>::operator()<(lambda at main.cpp:114:50)> (1 times, avg 5364 ms)
  5358 ms: ck::detail::static_ford_impl<const ck::Sequence<14, 14, 16>, ck::Sequence<0, 1, 2>>::operator()<(lambda at main.cpp:114:50), ck::Sequence<>> (1 times, avg 5358 ms)
  5210 ms: ck::static_for<0, 16, 1>::operator()<(lambda at ../include/ck/utility/functional3.hpp:31:52)> (196 times, avg 26 ms)
  5184 ms: ck::detail::static_for_impl<ck::Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15>>::operator()<(lambda at ../include/ck/utility/functional3.hpp:31:52)> (196 times, avg 26 ms)
  1844 ms: ck::static_for<0, 3, 1>::operator()<(lambda at ../include/ck/utility/sequence.hpp:182:36)> (3136 times, avg 0 ms)
   401 ms: ck::detail::static_ford_impl<const ck::Sequence<14, 14, 16>, ck::Sequence<0, 1, 2>>::operator()((lambda at main.cpp:114:50), ck::Sequence<>)::(anonymous class)::operator()<ck::integral_constant<int, 7>> (1 times, avg 401 ms)
   400 ms: ck::detail::static_ford_impl<ck::Sequence<14, 16>, ck::Sequence<0, 1, 2>>::operator()<(lambda at main.cpp:114:50), ck::Sequence<7>> (1 times, avg 400 ms)
   399 ms: ck::detail::static_ford_impl<const ck::Sequence<14, 14, 16>, ck::Sequence<0, 1, 2>>::operator()((lambda at main.cpp:114:50), ck::Sequence<>)::(anonymous class)::operator()<ck::integral_constant<int, 11>> (1 times, avg 399 ms)
   398 ms: ck::detail::static_ford_impl<ck::Sequence<14, 16>, ck::Sequence<0, 1, 2>>::operator()<(lambda at main.cpp:114:50), ck::Sequence<11>> (1 times, avg 398 ms)
...
*** New
  2978 ms: ck::static_for<0, 3, 1>::operator()<(lambda at ../include/ck/utility/sequence.hpp:182:36)> (3136 times, avg 0 ms)
  2639 ms: ck::detail::static_for_impl<ck::Sequence<0, 1, 2>>::operator()<(lambda at ../include/ck/utility/sequence.hpp:182:36)> (3136 times, avg 0 ms)
  1457 ms: ck::static_ford<ck::Sequence<14, 14, 16>>::operator()<(lambda at main.cpp:114:50), ck::static_ford<ck::Sequence<14, 14, 16>>::convert_t> (1 times, avg 1457 ms)
  1457 ms: ck::detail::applier<int, 0, 1, 2, 3, 4, 5, 6, 7... (1 times, avg 1457 ms)
    20 ms: ck::Sequence<9, 12, 13>::Print (1 times, avg 20 ms)
...

@tenpercent tenpercent changed the title from "Special static for d" to "Recursion-free static_fordd" Mar 25, 2025
@tenpercent tenpercent changed the title from "Recursion-free static_fordd" to "Template recursion-free static_ford" Mar 25, 2025
@@ -4,6 +4,7 @@
#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/functional3.hpp"

static_ford is used in this header, but the header that defines it (functional3.hpp) was not included explicitly.

@@ -547,6 +547,10 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
add_compile_options(-fdiagnostics-color=always)
endif()

# fold expression depth
# device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
add_compile_options(-fbracket-depth=3136)

An alternative is to implement some form of paging for the parameter packs; I am not sure how to do that yet.
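One possible shape for such paging, offered only as a sketch and not as what this PR does: split the flat range into fixed-size pages, fold over the pages, and run a second, shorter fold inside each page. The names page_length, run_page, run_pages and paged_static_for are hypothetical, and the default page size of 64 is an arbitrary assumption. The intent is that no single fold expression has more operands than roughly max(page size, number of pages), so a raised -fbracket-depth should no longer be required.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical helper: number of indices in a given page (the last page may be short).
constexpr std::size_t page_length(std::size_t total, std::size_t page, std::size_t page_size)
{
    const std::size_t begin = page * page_size;
    return (total - begin < page_size) ? (total - begin) : page_size;
}

// Inner fold: calls f for Begin, Begin + 1, ... within one page.
template <std::size_t Begin, class F, std::size_t... Offsets>
constexpr void run_page(F& f, std::index_sequence<Offsets...>)
{
    (f(std::integral_constant<std::size_t, Begin + Offsets>{}), ...);
}

// Outer fold: one operand per page instead of one per index.
template <std::size_t Total, std::size_t PageSize, class F, std::size_t... Pages>
constexpr void run_pages(F& f, std::index_sequence<Pages...>)
{
    (run_page<Pages * PageSize>(
         f, std::make_index_sequence<page_length(Total, Pages, PageSize)>{}),
     ...);
}

// Hypothetical entry point; PageSize = 64 is an arbitrary choice.
template <std::size_t Total, std::size_t PageSize = 64, class F>
constexpr void paged_static_for(F&& f)
{
    static_assert(PageSize > 0, "page size must be positive");
    constexpr std::size_t num_pages = (Total + PageSize - 1) / PageSize;
    run_pages<Total, PageSize>(f, std::make_index_sequence<num_pages>{});
}

// Usage: record the visiting order for 10 indices split into pages of 4.
inline std::vector<std::size_t> collect_paged()
{
    std::vector<std::size_t> seen;
    paged_static_for<10, 4>([&](auto I) { seen.push_back(decltype(I)::value); });
    return seen;
}
```

The trade-off is one extra function template instantiation per page, which would need to be measured against the cost of the single large fold.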

__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
base::template operator()<F, IndexTransform>(f);

A possible improvement here is to bake the index transform into the base class, so that a using base::operator() declaration suffices and one extra template instantiation is avoided.

using TCumProd = typename make_cumulative_product<Dims...>::type;

template <ck::index_t flat_idx>
using type = decltype((TCumProd{} * ck::Number<flat_idx>{} / ck::Number<Prod>{}) % SDim{});

My initial expectation was that this would be fairly cheap, but it is not.
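To make the arithmetic in the expression above concrete, here is a value-level sketch of the same formula using plain integers instead of ck::Number and ck::Sequence. The function name index_from_flat is hypothetical, and the sketch assumes TCumProd is the inclusive cumulative product of the dimensions and Prod is their total product. Because Prod is an exact multiple of each cumulative product, (cum * flat / Prod) equals flat divided by the stride of that dimension.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical value-level version of (TCumProd * flat_idx / Prod) % SDim.
template <std::size_t N>
constexpr std::array<std::size_t, N> index_from_flat(const std::array<std::size_t, N>& dims,
                                                     std::size_t flat)
{
    // Total number of points in the index space.
    std::size_t prod = 1;
    for(std::size_t d = 0; d < N; ++d)
    {
        prod *= dims[d];
    }

    std::array<std::size_t, N> idx{};
    std::size_t cum = 1;
    for(std::size_t d = 0; d < N; ++d)
    {
        cum *= dims[d]; // inclusive cumulative product: dims[0] * ... * dims[d]
        idx[d] = (cum * flat / prod) % dims[d];
    }
    return idx;
}
```

For dimensions (14, 14, 16) the cumulative products are (14, 196, 3136), so flat index 17 maps to (14 * 17 / 3136) % 14 = 0, (196 * 17 / 3136) % 14 = 1 and (3136 * 17 / 3136) % 16 = 1.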


// clang-format off
template <int32_t... Idims>
struct make_cumulative_product

A recursive template definition seems doable but complex.

// clang-format off

template <int32_t IDim0>
constexpr auto make_cumulative_product(ck::Number<IDim0>)

Function templates are faster to compile than struct templates here.
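As an illustration of the function-based style, a cumulative product can be computed by a single constexpr function with no recursive struct specializations. This is a sketch with a hypothetical name (cumulative_product) that returns a std::array rather than a CK sequence type, so it differs from the overloads in this PR.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical non-recursive helper: inclusive cumulative product of Dims.
template <std::size_t... Dims>
constexpr std::array<std::size_t, sizeof...(Dims)> cumulative_product()
{
    std::array<std::size_t, sizeof...(Dims)> out{};
    std::size_t running = 1;
    std::size_t i       = 0;
    // Comma fold, evaluated left to right: update the running product, then store it.
    ((running *= Dims, out[i++] = running), ...);
    return out;
}
```

A single instantiation handles the whole pack, whereas a recursive struct would typically be instantiated once per prefix of the dimension list.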
