Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/ops/SYCL.csv
Original file line number Diff line number Diff line change
Expand Up @@ -6025,7 +6025,7 @@
"SYCL0","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
"SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","SYCL"
"SYCL0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","SYCL"
"SYCL0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","SYCL"
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,7 @@
#include "softmax.hpp"
#include "tsembd.hpp"
#include "wkv.hpp"
#include "pad_reflect_d1.hpp"


#endif // GGML_SYCL_BACKEND_HPP
7 changes: 7 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3673,6 +3673,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_CONCAT:
ggml_sycl_op_concat(ctx, dst);
break;
case GGML_OP_PAD_REFLECT_1D:
ggml_sycl_op_pad_reflect_d1(ctx,dst);
break;
case GGML_OP_UPSCALE:
ggml_sycl_upscale(ctx, dst);
break;
Expand Down Expand Up @@ -4369,6 +4372,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD_REFLECT_1D:
return ggml_is_contiguous(op->src[0]) &&
op-> type == GGML_TYPE_F32 &&
op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
#if defined (GGML_SYCL_F16)
return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
Expand Down
77 changes: 77 additions & 0 deletions ggml/src/ggml-sycl/pad_reflect_d1.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "pad_reflect_d1.hpp"

void pad_reflect_d1_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){

const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;

if (i0 >= p0 + ne0 + p1) return;

int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);

int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];

}

void ggml_sycl_op_pad_reflect_d1(ggml_backend_sycl_context& ctx, ggml_tensor* dst){

const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];

const int64_t ne0 = src0->ne[0];

const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];

const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];

int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;

sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);

stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_d1_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
}




8 changes: 8 additions & 0 deletions ggml/src/ggml-sycl/pad_reflect_d1.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_PAD_REFLECT_D1_HPP
#define GGML_SYCL_PAD_REFLECT_D1_HPP

#include "common.hpp"

void ggml_sycl_op_pad_reflect_d1(ggml_backend_sycl_context& ctx, ggml_tensor* dst);

#endif // GGML_SYCL_PAD_REFLECT_D1_HPP