Skip to content

Commit 9c8da66

Browse files
yongweiy authored and ppotapov-aws committed
add nisa.nc_n_gather in place of nl.gather_flattened
1 parent 11e8925 commit 9c8da66

File tree

9 files changed

+81
-2
lines changed

9 files changed

+81
-2
lines changed

KLR/Core/Basic.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
181181
| .setRngSeed r | .randSetState r => [r.src]
182182
| .extendedInst _ => []
183183
| .tensorScalarCumulative t => [t.dst, t.src]
184+
| .ncNGather g => [g.dst, g.data, g.indices]
184185

185186
partial def operatorAdditionalTensors : Operator → List TensorName
186187
| .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
@@ -202,6 +203,7 @@ partial def operatorAdditionalTensors : Operator → List TensorName
202203
| .recv r => tensors r.dsts
203204
| .rand2 r => tensors r.min ++ tensors r.max
204205
| .tensorScalarCumulative t => (tensors t.imm0) ++ (tensors t.imm1)
206+
| .ncNGather _ => []
205207
| _ => []
206208

207209
instance : Tensors Operator where

KLR/Core/LowerAP.lean

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ def Operator.lowerAccessPatterns (k : Operator) : KLR.Err Operator :=
197197
imm0 := <- Operand.lowerAccessPatterns op.imm0
198198
imm1 := <- op.imm1.mapM Operand.lowerAccessPatterns
199199
}
200+
| .ncNGather op => return .ncNGather { op with
201+
dst := <- op.dst.lowerAccessPatterns
202+
data := <- op.data.lowerAccessPatterns
203+
indices := <- op.indices.lowerAccessPatterns
204+
}
200205

201206
def Stmt.lowerAccessPatterns : Stmt → KLR.Err Stmt
202207
| .oper op name pos => return .oper (<- op.lowerAccessPatterns) name pos

KLR/Core/Operators.lean

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,14 @@ structure TensorScalarCumulative where
944944
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
945945

946946
@[serde tag = 209]
947+
-- Payload of the `ncNGather` operator: three tensor references (`dst`, `data`,
-- `indices`) and an optional `dtype`.
-- The derived instances supply equality, `Repr`, and CBOR / JSON / S-expression
-- encoding and decoding for this record.
-- NOTE(review): presumably `dst` receives the elements of `data` selected by
-- `indices` — confirm against the ISA documentation.
structure NcNGather where
948+
dst : TensorRef
949+
data : TensorRef
950+
indices : TensorRef
951+
dtype : Option Dtype
952+
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
953+
954+
@[serde tag = 210]
947955
inductive Operator where
948956
| activate (op : Activate)
949957
| ncActivate (op : NcActivate)
@@ -1012,9 +1020,10 @@ inductive Operator where
10121020
| randSetState (op : RandSetState)
10131021
| extendedInst (op : ExtendedInst)
10141022
| tensorScalarCumulative (op: TensorScalarCumulative)
1023+
| ncNGather (op: NcNGather)
10151024
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
10161025

1017-
@[serde tag = 210]
1026+
@[serde tag = 211]
10181027
inductive TGROperator where
10191028
| activate (op : Activate)
10201029
| affineSelect (op : AffineSelect)

KLR/Extract/Extract/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def klrAST: MetaM (List LeanType) := do
367367
`KLR.Core.RandSetState,
368368
`KLR.Core.ExtendedInst,
369369
`KLR.Core.TensorScalarCumulative,
370+
`KLR.Core.NcNGather,
370371
`KLR.Core.Operator,
371372
-- Core.Basic
372373
`KLR.Core.Stmt,

KLR/Trace/ISA.lean

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,3 +1180,19 @@ nki builtin.isa.extended_inst
11801180
data1
11811181
}) name
11821182
return .none
1183+
1184+
-- Builtin `builtin.isa.nc_n_gather`: appends an `ncNGather` operator statement
-- (with the optional `name`) to the trace, built from `dst`, `data` and `indices`.
--   mask  : not supported; supplying one throws `maskNotSupported`.
--   dtype : when omitted, the dtype of `dst`'s tensor is recorded instead.
-- Returns `.none`.
nki builtin.isa.nc_n_gather
1185+
(dst: Access)
1186+
(data: Access)
1187+
(indices: Access)
1188+
(mask: Option Immediate := none)
1189+
(dtype: Option Dtype := none)
1190+
(name : Option String := none) := do
1191+
-- Reject masks up front, before anything is added to the trace.
if mask.isSome then throw maskNotSupported
1192+
Trace.add_stmt $ .oper (.ncNGather {
1193+
dst := .abstract dst
1194+
data := .abstract data
1195+
indices := .abstract indices
1196+
-- An explicitly passed dtype wins; otherwise fall back to dst's dtype.
dtype := dtype.or dst.tensor.dtype
1197+
}) name
1198+
return .none

interop/klr/NKI.asdl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ RandSetState = (TensorRef src, Engine engine)
339339
ExtendedInst = (Nat opcode, Bool hasRead, Bool hasWrite, Nat ports, Nat* data0, Nat* data1)
340340

341341
TensorScalarCumulative = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand? imm1, AccumCmd reduceCmd, TensorScalarReverseOps reverse, Dtype? dtype)
342+
343+
NcNGather = (TensorRef dst, TensorRef data, TensorRef indices, Dtype? dtype)
342344
Operator =
343345
| activate(Activate op)
344346
| ncActivate(NcActivate op)
@@ -407,6 +409,7 @@ Operator =
407409
| randSetState(RandSetState op)
408410
| extendedInst(ExtendedInst op)
409411
| tensorScalarCumulative(TensorScalarCumulative op)
412+
| ncNGather(NcNGather op)
410413

411414
Stmt =
412415
| oper(Operator op, String? name, Pos pos)

interop/klr/klir_ast.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,13 @@ struct TensorScalarCumulative final {
999999
Option<Dtype> dtype;
10001000
};
10011001

1002+
// Payload of the ncNGather operator. The fields are declared in the same order
// as the ASDL definition
//   NcNGather = (TensorRef dst, TensorRef data, TensorRef indices, Dtype? dtype)
// and NcNGather_des fills them in that order.
struct NcNGather final {
1003+
Ptr<TensorRef> dst;
1004+
Ptr<TensorRef> data;
1005+
Ptr<TensorRef> indices;
1006+
Option<Dtype> dtype;
1007+
};
1008+
10021009
struct Operator {
10031010
enum class Tag {
10041011
activate = 1,
@@ -1068,6 +1075,7 @@ struct Operator {
10681075
randSetState,
10691076
extendedInst,
10701077
tensorScalarCumulative,
1078+
ncNGather,
10711079
};
10721080
Tag tag;
10731081
Operator(Tag tag) : tag(tag) {}
@@ -1410,6 +1418,11 @@ struct OperatorTensorScalarCumulativeWrapper final : Operator {
14101418
: Operator(Tag::tensorScalarCumulative) {}
14111419
};
14121420

1421+
// Operator variant that carries an NcNGather payload in `op`.
// The constructor sets the base Operator's tag to Tag::ncNGather.
struct OperatorNcNGatherWrapper final : Operator {
1422+
Ptr<NcNGather> op;
1423+
OperatorNcNGatherWrapper() : Operator(Tag::ncNGather) {}
1424+
};
1425+
14131426
struct Stmt {
14141427
enum class Tag {
14151428
oper = 1,

interop/klr/klir_serde.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3232,11 +3232,32 @@ Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in) {
32323232
return x;
32333233
}
32343234

3235+
// Deserializes one NcNGather record from `in`.
//
// Reads a tag header and requires it to be exactly (209, 0, 4), then reads the
// fields in order: dst, data, indices (TensorRef each) and dtype (Option<Dtype>).
//
// Throws std::runtime_error if the tag header cannot be read or does not match.
Ptr<NcNGather> NcNGather_des(FILE *in) {
3236+
u8 t, c, l;
3237+
if (!deserialize_tag(in, &t, &c, &l)) {
3238+
std::ostringstream msg;
3239+
msg << "Could not find tag, expecting NcNGather:209,0";
3240+
throw std::runtime_error(msg.str());
3241+
}
3242+
// The header must be exactly (209, 0, 4); the message reports the three values actually read.
if (t != 209 || c != 0 || l != 4) {
3243+
std::ostringstream msg;
3244+
msg << "Expecting NcNGather:(209,0,4)";
3245+
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
3246+
throw std::runtime_error(msg.str());
3247+
}
3248+
Ptr<NcNGather> x = ptr<NcNGather>();
3249+
// Fields are read in declaration order: dst, data, indices, dtype.
x->dst = TensorRef_des(in);
3250+
x->data = TensorRef_des(in);
3251+
x->indices = TensorRef_des(in);
3252+
x->dtype = Option_Dtype_des(in);
3253+
return x;
3254+
}
3255+
32353256
Ptr<Operator> Operator_des(FILE *in) {
32363257
u8 t, c, l;
32373258
if (!deserialize_tag(in, &t, &c, &l))
32383259
throw std::runtime_error("Could not read tag");
3239-
if (t != 209)
3260+
if (t != 210)
32403261
throw std::runtime_error("Unexpected type tag");
32413262
switch (c) {
32423263
case 0: {
@@ -3785,6 +3806,14 @@ Ptr<Operator> Operator_des(FILE *in) {
37853806
return x;
37863807
break;
37873808
}
3809+
case 67: {
3810+
if (l != 1)
3811+
throw std::runtime_error("Wrong number of elements");
3812+
Ptr<OperatorNcNGatherWrapper> x = ptr<OperatorNcNGatherWrapper>();
3813+
x->op = NcNGather_des(in);
3814+
return x;
3815+
break;
3816+
}
37883817
default:
37893818
throw std::runtime_error("Invalid value tag");
37903819
}

interop/klr/klir_serde.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ Ptr<SetRngSeed> SetRngSeed_des(FILE *in);
135135
Ptr<RandSetState> RandSetState_des(FILE *in);
136136
Ptr<ExtendedInst> ExtendedInst_des(FILE *in);
137137
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in);
138+
Ptr<NcNGather> NcNGather_des(FILE *in);
138139
Ptr<Operator> Operator_des(FILE *in);
139140
Ptr<Stmt> Stmt_des(FILE *in);
140141
Ptr<Block> Block_des(FILE *in);

0 commit comments

Comments
 (0)