diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 6699b70bad0d7..82baf958553ee 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -135,8 +135,8 @@ else() endif() FetchContent_Declare( ONEMATH - GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git - GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a + GIT_REPOSITORY https://github.com/EwanC/oneMath.git + GIT_TAG 671d9bcc5aa6cc52cce6b6518c9b8f126c0352f4 ) FetchContent_MakeAvailable(ONEMATH) # Create alias to match with find_package targets name diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 66b6f2cca4da9..a4f65d8c99644 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3639,11 +3639,37 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc } } +#ifdef GGML_SYCL_GRAPH +static bool check_node_graph_compatibility(ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + switch(node->op) { + default: break; + case GGML_OP_CONCAT: + // ggml_sycl_op_concat() does a blocking host wait after memcpy operations, + // but wait() can't be called on the events returned by a queue recording + // to a graph. + [[fallthrough]]; + case GGML_OP_MUL_MAT_ID: + // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after + // submitting a memcpy operation, but wait() can't be called on a queue that + // is recording to a graph. +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling SYCL graphs due to unsupported node type\n", __func__); +#endif + return false; + } + } + return true; +} +#endif + static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { auto * sycl_ctx = static_cast(backend->context); #ifdef GGML_SYCL_GRAPH - if (!g_ggml_sycl_disable_graph) { + bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_node_graph_compatibility(cgraph); + if (use_sycl_graph) { const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); if (!graph_support) { GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);