Skip to content

Commit fe116e3

Browse files
committed
Add tensor_scalar_cumulative API
1 parent 7d7ff0e commit fe116e3

File tree

9 files changed

+106
-2
lines changed

9 files changed

+106
-2
lines changed

KLR/Core/Basic.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
180180
| .rng r | .rand2 r | .randGetState r => [r.dst]
181181
| .setRngSeed r | .randSetState r => [r.src]
182182
| .extendedInst _ => []
183+
| .tensorScalarCumulative t => [t.dst, t.src]
183184

184185
partial def operatorAdditionalTensors : Operator → List TensorName
185186
| .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
@@ -200,6 +201,7 @@ partial def operatorAdditionalTensors : Operator → List TensorName
200201
| .send s => tensors s.srcs
201202
| .recv r => tensors r.dsts
202203
| .rand2 r => tensors r.min ++ tensors r.max
204+
| .tensorScalarCumulative t => (tensors t.imm0) ++ (tensors t.imm1)
203205
| _ => []
204206

205207
instance : Tensors Operator where

KLR/Core/LowerAP.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ def Operator.lowerAccessPatterns (k : Operator) : KLR.Err Operator :=
191191
| .setRngSeed r => return .setRngSeed { r with src := (<- r.src.lowerAccessPatterns)}
192192
| .randSetState r => return .randSetState { r with src := (<- r.src.lowerAccessPatterns)}
193193
| .extendedInst i => return .extendedInst i
194+
| .tensorScalarCumulative op => return .tensorScalarCumulative { op with
195+
dst := <- op.dst.lowerAccessPatterns
196+
src := <- op.src.lowerAccessPatterns
197+
imm0 := <- Operand.lowerAccessPatterns op.imm0
198+
imm1 := <- op.imm1.mapM Operand.lowerAccessPatterns
199+
}
194200

195201
def Stmt.lowerAccessPatterns : Stmt → KLR.Err Stmt
196202
| .oper op name pos => return .oper (<- op.lowerAccessPatterns) name pos

KLR/Core/Operators.lean

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,19 @@ structure ExtendedInst where
931931
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
932932

933933
@[serde tag = 208]
934+
structure TensorScalarCumulative where
935+
dst : TensorRef
936+
src : TensorRef
937+
op0 : AluOp
938+
op1: AluOp
939+
imm0: Operand
940+
imm1: Option Operand
941+
reduceCmd: AccumCmd
942+
reverse : TensorScalarReverseOps
943+
dtype : Option Dtype
944+
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
945+
946+
@[serde tag = 209]
934947
inductive Operator where
935948
| activate (op : Activate)
936949
| ncActivate (op : NcActivate)
@@ -998,9 +1011,10 @@ inductive Operator where
9981011
| setRngSeed (op : SetRngSeed)
9991012
| randSetState (op : RandSetState)
10001013
| extendedInst (op : ExtendedInst)
1014+
| tensorScalarCumulative (op: TensorScalarCumulative)
10011015
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
10021016

1003-
@[serde tag = 209]
1017+
@[serde tag = 210]
10041018
inductive TGROperator where
10051019
| activate (op : Activate)
10061020
| affineSelect (op : AffineSelect)

KLR/Extract/Extract/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def klrAST: MetaM (List LeanType) := do
366366
`KLR.Core.SetRngSeed,
367367
`KLR.Core.RandSetState,
368368
`KLR.Core.ExtendedInst,
369+
`KLR.Core.TensorScalarCumulative,
369370
`KLR.Core.Operator,
370371
-- Core.Basic
371372
`KLR.Core.Stmt,

KLR/Trace/ISA.lean

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,36 @@ nki builtin.isa.tensor_scalar_reduce
426426
}) name
427427
return .none
428428

429+
nki builtin.isa.tensor_scalar_cumulative
430+
(dst: Access)
431+
(src: Access)
432+
(op0: AluOp)
433+
(op1: AluOp)
434+
(imm0: Sum Immediate Access)
435+
(imm1: Option (Sum Immediate Access) := none)
436+
(reduce_cmd: AccumCmd := AccumCmd.ZeroAccumulate)
437+
(mask: Option Immediate := none)
438+
(name : Option String := none) := do
439+
if mask.isSome then
440+
throw maskNotSupported
441+
Trace.add_stmt $ .oper (.tensorScalarCumulative {
442+
dst := .abstract dst
443+
src := .abstract src
444+
op0 := op0
445+
op1 := op1
446+
imm0 := match imm0 with
447+
| .inl i => .imm i
448+
| .inr t => .tile $ .abstract t
449+
imm1 := match imm1 with
450+
| some (.inl i) => some $ .imm i
451+
| some (.inr t) => some $ .tile $ .abstract t
452+
| none => .none
453+
reduceCmd := reduce_cmd
454+
reverse := TensorScalarReverseOps.none
455+
dtype := dst.tensor.dtype
456+
}) name
457+
return .none
458+
429459
nki builtin.isa.tensor_copy
430460
(dst: Access)
431461
(src : Access)

interop/klr/NKI.asdl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ SetRngSeed = (TensorRef src)
337337
RandSetState = (TensorRef src, Engine engine)
338338

339339
ExtendedInst = (Nat opcode, Bool hasRead, Bool hasWrite, Nat ports, Nat* data0, Nat* data1)
340+
341+
TensorScalarCumulative = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand? imm1, AccumCmd reduceCmd, TensorScalarReverseOps reverse, Dtype? dtype)
340342
Operator =
341343
| activate(Activate op)
342344
| ncActivate(NcActivate op)
@@ -404,6 +406,7 @@ Operator =
404406
| setRngSeed(SetRngSeed op)
405407
| randSetState(RandSetState op)
406408
| extendedInst(ExtendedInst op)
409+
| tensorScalarCumulative(TensorScalarCumulative op)
407410

408411
Stmt =
409412
| oper(Operator op, String? name, Pos pos)

interop/klr/klir_ast.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,18 @@ struct ExtendedInst final {
986986
List<Nat> data1;
987987
};
988988

989+
struct TensorScalarCumulative final {
990+
Ptr<TensorRef> dst;
991+
Ptr<TensorRef> src;
992+
AluOp op0;
993+
AluOp op1;
994+
Ptr<Operand> imm0;
995+
Option<Ptr<Operand>> imm1;
996+
AccumCmd reduceCmd;
997+
TensorScalarReverseOps reverse;
998+
Option<Dtype> dtype;
999+
};
1000+
9891001
struct Operator {
9901002
enum class Tag {
9911003
activate = 1,
@@ -1054,6 +1066,7 @@ struct Operator {
10541066
setRngSeed,
10551067
randSetState,
10561068
extendedInst,
1069+
tensorScalarCumulative,
10571070
};
10581071
Tag tag;
10591072
Operator(Tag tag) : tag(tag) {}
@@ -1390,6 +1403,12 @@ struct OperatorExtendedInstWrapper final : Operator {
13901403
OperatorExtendedInstWrapper() : Operator(Tag::extendedInst) {}
13911404
};
13921405

1406+
struct OperatorTensorScalarCumulativeWrapper final : Operator {
1407+
Ptr<TensorScalarCumulative> op;
1408+
OperatorTensorScalarCumulativeWrapper()
1409+
: Operator(Tag::tensorScalarCumulative) {}
1410+
};
1411+
13931412
struct Stmt {
13941413
enum class Tag {
13951414
oper = 1,

interop/klr/klir_serde.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3205,11 +3205,30 @@ Ptr<ExtendedInst> ExtendedInst_des(FILE *in) {
32053205
return x;
32063206
}
32073207

3208+
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in) {
3209+
u8 t, c, l;
3210+
if (!deserialize_tag(in, &t, &c, &l))
3211+
throw std::runtime_error("Could not find tag");
3212+
if (t != 208 || c != 0 || l != 9)
3213+
throw std::runtime_error("Invalid Tag");
3214+
Ptr<TensorScalarCumulative> x = ptr<TensorScalarCumulative>();
3215+
x->dst = TensorRef_des(in);
3216+
x->src = TensorRef_des(in);
3217+
x->op0 = AluOp_des(in);
3218+
x->op1 = AluOp_des(in);
3219+
x->imm0 = Operand_des(in);
3220+
x->imm1 = Option_Operand_des(in);
3221+
x->reduceCmd = AccumCmd_des(in);
3222+
x->reverse = TensorScalarReverseOps_des(in);
3223+
x->dtype = Option_Dtype_des(in);
3224+
return x;
3225+
}
3226+
32083227
Ptr<Operator> Operator_des(FILE *in) {
32093228
u8 t, c, l;
32103229
if (!deserialize_tag(in, &t, &c, &l))
32113230
throw std::runtime_error("Could not read tag");
3212-
if (t != 208)
3231+
if (t != 209)
32133232
throw std::runtime_error("Unexpected type tag");
32143233
switch (c) {
32153234
case 0: {
@@ -3749,6 +3768,15 @@ Ptr<Operator> Operator_des(FILE *in) {
37493768
return x;
37503769
break;
37513770
}
3771+
case 66: {
3772+
if (l != 1)
3773+
throw std::runtime_error("Wrong number of elements");
3774+
Ptr<OperatorTensorScalarCumulativeWrapper> x =
3775+
ptr<OperatorTensorScalarCumulativeWrapper>();
3776+
x->op = TensorScalarCumulative_des(in);
3777+
return x;
3778+
break;
3779+
}
37523780
default:
37533781
throw std::runtime_error("Invalid value tag");
37543782
}

interop/klr/klir_serde.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ Ptr<RandGetState> RandGetState_des(FILE *in);
134134
Ptr<SetRngSeed> SetRngSeed_des(FILE *in);
135135
Ptr<RandSetState> RandSetState_des(FILE *in);
136136
Ptr<ExtendedInst> ExtendedInst_des(FILE *in);
137+
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in);
137138
Ptr<Operator> Operator_des(FILE *in);
138139
Ptr<Stmt> Stmt_des(FILE *in);
139140
Ptr<Block> Block_des(FILE *in);

0 commit comments

Comments
 (0)