Skip to content

Commit 33a7f13

Browse files
committed
add nisa.nc_n_gather in place of nl.gather_flattened
1 parent fe116e3 commit 33a7f13

File tree

9 files changed

+92
-6
lines changed

9 files changed

+92
-6
lines changed

KLR/Core/Basic.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
181181
| .setRngSeed r | .randSetState r => [r.src]
182182
| .extendedInst _ => []
183183
| .tensorScalarCumulative t => [t.dst, t.src]
184+
| .ncNGather g => [g.dst, g.data, g.indices]
184185

185186
partial def operatorAdditionalTensors : Operator → List TensorName
186187
| .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
@@ -202,6 +203,7 @@ partial def operatorAdditionalTensors : Operator → List TensorName
202203
| .recv r => tensors r.dsts
203204
| .rand2 r => tensors r.min ++ tensors r.max
204205
| .tensorScalarCumulative t => (tensors t.imm0) ++ (tensors t.imm1)
206+
| .ncNGather _ => []
205207
| _ => []
206208

207209
instance : Tensors Operator where

KLR/Core/LowerAP.lean

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ def Operator.lowerAccessPatterns (k : Operator) : KLR.Err Operator :=
197197
imm0 := <- Operand.lowerAccessPatterns op.imm0
198198
imm1 := <- op.imm1.mapM Operand.lowerAccessPatterns
199199
}
200+
| .ncNGather op => return .ncNGather { op with
201+
dst := <- op.dst.lowerAccessPatterns
202+
data := <- op.data.lowerAccessPatterns
203+
indices := <- op.indices.lowerAccessPatterns
204+
}
200205

201206
def Stmt.lowerAccessPatterns : Stmt → KLR.Err Stmt
202207
| .oper op name pos => return .oper (<- op.lowerAccessPatterns) name pos

KLR/Core/Operators.lean

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,14 @@ structure TensorScalarCumulative where
944944
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
945945

946946
@[serde tag = 209]
947+
structure NcNGather where
948+
dst : TensorRef
949+
data : TensorRef
950+
indices : TensorRef
951+
dtype : Option Dtype
952+
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
953+
954+
@[serde tag = 210]
947955
inductive Operator where
948956
| activate (op : Activate)
949957
| ncActivate (op : NcActivate)
@@ -1012,9 +1020,10 @@ inductive Operator where
10121020
| randSetState (op : RandSetState)
10131021
| extendedInst (op : ExtendedInst)
10141022
| tensorScalarCumulative (op: TensorScalarCumulative)
1023+
| ncNGather (op: NcNGather)
10151024
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
10161025

1017-
@[serde tag = 210]
1026+
@[serde tag = 211]
10181027
inductive TGROperator where
10191028
| activate (op : Activate)
10201029
| affineSelect (op : AffineSelect)

KLR/Extract/Extract/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def klrAST: MetaM (List LeanType) := do
367367
`KLR.Core.RandSetState,
368368
`KLR.Core.ExtendedInst,
369369
`KLR.Core.TensorScalarCumulative,
370+
`KLR.Core.NcNGather,
370371
`KLR.Core.Operator,
371372
-- Core.Basic
372373
`KLR.Core.Stmt,

KLR/Trace/ISA.lean

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,3 +1180,19 @@ nki builtin.isa.extended_inst
11801180
data1
11811181
}) name
11821182
return .none
1183+
1184+
nki builtin.isa.nc_n_gather
1185+
(dst: Access)
1186+
(data: Access)
1187+
(indices: Access)
1188+
(mask: Option Immediate := none)
1189+
(dtype: Option Dtype := none)
1190+
(name : Option String := none) := do
1191+
if mask.isSome then throw maskNotSupported
1192+
Trace.add_stmt $ .oper (.ncNGather {
1193+
dst := .abstract dst
1194+
data := .abstract data
1195+
indices := .abstract indices
1196+
dtype := dtype.or dst.tensor.dtype
1197+
}) name
1198+
return .none

interop/klr/NKI.asdl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ RandSetState = (TensorRef src, Engine engine)
339339
ExtendedInst = (Nat opcode, Bool hasRead, Bool hasWrite, Nat ports, Nat* data0, Nat* data1)
340340

341341
TensorScalarCumulative = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand? imm1, AccumCmd reduceCmd, TensorScalarReverseOps reverse, Dtype? dtype)
342+
343+
NcNGather = (TensorRef dst, TensorRef data, TensorRef indices, Dtype? dtype)
342344
Operator =
343345
| activate(Activate op)
344346
| ncActivate(NcActivate op)
@@ -407,6 +409,7 @@ Operator =
407409
| randSetState(RandSetState op)
408410
| extendedInst(ExtendedInst op)
409411
| tensorScalarCumulative(TensorScalarCumulative op)
412+
| ncNGather(NcNGather op)
410413

411414
Stmt =
412415
| oper(Operator op, String? name, Pos pos)

interop/klr/klir_ast.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,13 @@ struct TensorScalarCumulative final {
998998
Option<Dtype> dtype;
999999
};
10001000

1001+
struct NcNGather final {
1002+
Ptr<TensorRef> dst;
1003+
Ptr<TensorRef> data;
1004+
Ptr<TensorRef> indices;
1005+
Option<Dtype> dtype;
1006+
};
1007+
10011008
struct Operator {
10021009
enum class Tag {
10031010
activate = 1,
@@ -1067,6 +1074,7 @@ struct Operator {
10671074
randSetState,
10681075
extendedInst,
10691076
tensorScalarCumulative,
1077+
ncNGather,
10701078
};
10711079
Tag tag;
10721080
Operator(Tag tag) : tag(tag) {}
@@ -1409,6 +1417,11 @@ struct OperatorTensorScalarCumulativeWrapper final : Operator {
14091417
: Operator(Tag::tensorScalarCumulative) {}
14101418
};
14111419

1420+
struct OperatorNcNGatherWrapper final : Operator {
1421+
Ptr<NcNGather> op;
1422+
OperatorNcNGatherWrapper() : Operator(Tag::ncNGather) {}
1423+
};
1424+
14121425
struct Stmt {
14131426
enum class Tag {
14141427
oper = 1,

interop/klr/klir_serde.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,10 +3207,17 @@ Ptr<ExtendedInst> ExtendedInst_des(FILE *in) {
32073207

32083208
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in) {
32093209
u8 t, c, l;
3210-
if (!deserialize_tag(in, &t, &c, &l))
3211-
throw std::runtime_error("Could not find tag");
3212-
if (t != 208 || c != 0 || l != 9)
3213-
throw std::runtime_error("Invalid Tag");
3210+
if (!deserialize_tag(in, &t, &c, &l)) {
3211+
std::ostringstream msg;
3212+
msg << "Could not find tag, expecting TensorScalarCumulative:208,0";
3213+
throw std::runtime_error(msg.str());
3214+
}
3215+
if (t != 208 || c != 0 || l != 9) {
3216+
std::ostringstream msg;
3217+
msg << "Expecting TensorScalarCumulative:(208,0,9)";
3218+
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
3219+
throw std::runtime_error(msg.str());
3220+
}
32143221
Ptr<TensorScalarCumulative> x = ptr<TensorScalarCumulative>();
32153222
x->dst = TensorRef_des(in);
32163223
x->src = TensorRef_des(in);
@@ -3224,11 +3231,32 @@ Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in) {
32243231
return x;
32253232
}
32263233

3234+
Ptr<NcNGather> NcNGather_des(FILE *in) {
3235+
u8 t, c, l;
3236+
if (!deserialize_tag(in, &t, &c, &l)) {
3237+
std::ostringstream msg;
3238+
msg << "Could not find tag, expecting NcNGather:209,0";
3239+
throw std::runtime_error(msg.str());
3240+
}
3241+
if (t != 209 || c != 0 || l != 4) {
3242+
std::ostringstream msg;
3243+
msg << "Expecting NcNGather:(209,0,4)";
3244+
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
3245+
throw std::runtime_error(msg.str());
3246+
}
3247+
Ptr<NcNGather> x = ptr<NcNGather>();
3248+
x->dst = TensorRef_des(in);
3249+
x->data = TensorRef_des(in);
3250+
x->indices = TensorRef_des(in);
3251+
x->dtype = Option_Dtype_des(in);
3252+
return x;
3253+
}
3254+
32273255
Ptr<Operator> Operator_des(FILE *in) {
32283256
u8 t, c, l;
32293257
if (!deserialize_tag(in, &t, &c, &l))
32303258
throw std::runtime_error("Could not read tag");
3231-
if (t != 209)
3259+
if (t != 210)
32323260
throw std::runtime_error("Unexpected type tag");
32333261
switch (c) {
32343262
case 0: {
@@ -3777,6 +3805,14 @@ Ptr<Operator> Operator_des(FILE *in) {
37773805
return x;
37783806
break;
37793807
}
3808+
case 67: {
3809+
if (l != 1)
3810+
throw std::runtime_error("Wrong number of elements");
3811+
Ptr<OperatorNcNGatherWrapper> x = ptr<OperatorNcNGatherWrapper>();
3812+
x->op = NcNGather_des(in);
3813+
return x;
3814+
break;
3815+
}
37803816
default:
37813817
throw std::runtime_error("Invalid value tag");
37823818
}

interop/klr/klir_serde.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ Ptr<SetRngSeed> SetRngSeed_des(FILE *in);
135135
Ptr<RandSetState> RandSetState_des(FILE *in);
136136
Ptr<ExtendedInst> ExtendedInst_des(FILE *in);
137137
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in);
138+
Ptr<NcNGather> NcNGather_des(FILE *in);
138139
Ptr<Operator> Operator_des(FILE *in);
139140
Ptr<Stmt> Stmt_des(FILE *in);
140141
Ptr<Block> Block_des(FILE *in);

0 commit comments

Comments
 (0)