Skip to content

Commit f46a0c4

Browse files
committed
feat: add dtype override to ap method
1 parent 095a195 commit f46a0c4

File tree

5 files changed

+20
-8
lines changed

5 files changed

+20
-8
lines changed

KLR/Core/Tensor.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ structure BirAccessPattern where
569569
scalarOffset : Option ScalarOffset
570570
vectorOffset : Option Access
571571
indirectDim : Int
572+
dtypeOverride : Option Dtype := none
572573
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
573574

574575

KLR/Trace/Term.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ nki builtin.access.ap
594594
(offset : Nat := 0)
595595
(scalar_offset : Option (Sum Access Term) := none)
596596
(vector_offset : Option Access := none)
597-
(indirect_dim : Int := 0) := do
597+
(indirect_dim : Int := 0)
598+
(dtype : Option Dtype := none) := do
598599
match self with
599600
| .simple t =>
600601
let pattern := pattern.map fun (s,c) => Core.APPair.mk s c 0
@@ -610,6 +611,7 @@ nki builtin.access.ap
610611
scalarOffset
611612
vectorOffset := vector_offset
612613
indirectDim := indirect_dim
614+
dtypeOverride := dtype
613615
}
614616
return .access (.birPattern ap)
615617
-- TODO: need to figoure out how to combine cannonical form AP with user specified AP

interop/klr/NKI.asdl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ ScalarOffset =
5252
| acc(Access a)
5353

5454

55-
BirAccessPattern = (TensorName tensor, Nat offset, APPair* pattern, ScalarOffset? scalarOffset, Access? vectorOffset, Int indirectDim)
55+
BirAccessPattern = (TensorName tensor, Nat offset, APPair* pattern, ScalarOffset? scalarOffset, Access? vectorOffset, Int indirectDim, Dtype? dtypeOverride)
5656
Access =
5757
| simple(TensorName tensor)
5858
| basic(AccessBasic access)

interop/klr/klir_ast.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ struct BirAccessPattern final {
181181
Option<Ptr<ScalarOffset>> scalarOffset;
182182
Option<Ptr<Access>> vectorOffset;
183183
Int indirectDim;
184+
Option<Dtype> dtypeOverride;
184185
};
185186

186187
struct Access {

interop/klr/klir_serde.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -749,9 +749,9 @@ Ptr<BirAccessPattern> BirAccessPattern_des(FILE *in) {
749749
msg << "Could not find tag, expecting BirAccessPattern:124,0";
750750
throw std::runtime_error(msg.str());
751751
}
752-
if (t != 124 || c != 0 || l != 6) {
752+
if (t != 124 || c != 0 || l != 7) {
753753
std::ostringstream msg;
754-
msg << "Expecting BirAccessPattern:(124,0,6)";
754+
msg << "Expecting BirAccessPattern:(124,0,7)";
755755
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
756756
throw std::runtime_error(msg.str());
757757
}
@@ -762,6 +762,7 @@ Ptr<BirAccessPattern> BirAccessPattern_des(FILE *in) {
762762
x->scalarOffset = Option_ScalarOffset_des(in);
763763
x->vectorOffset = Option_Access_des(in);
764764
x->indirectDim = Int_des(in);
765+
x->dtypeOverride = Option_Dtype_des(in);
765766
return x;
766767
}
767768

@@ -3207,10 +3208,17 @@ Ptr<ExtendedInst> ExtendedInst_des(FILE *in) {
32073208

32083209
Ptr<TensorScalarCumulative> TensorScalarCumulative_des(FILE *in) {
32093210
u8 t, c, l;
3210-
if (!deserialize_tag(in, &t, &c, &l))
3211-
throw std::runtime_error("Could not find tag");
3212-
if (t != 208 || c != 0 || l != 9)
3213-
throw std::runtime_error("Invalid Tag");
3211+
if (!deserialize_tag(in, &t, &c, &l)) {
3212+
std::ostringstream msg;
3213+
msg << "Could not find tag, expecting TensorScalarCumulative:208,0";
3214+
throw std::runtime_error(msg.str());
3215+
}
3216+
if (t != 208 || c != 0 || l != 9) {
3217+
std::ostringstream msg;
3218+
msg << "Expecting TensorScalarCumulative:(208,0,9)";
3219+
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
3220+
throw std::runtime_error(msg.str());
3221+
}
32143222
Ptr<TensorScalarCumulative> x = ptr<TensorScalarCumulative>();
32153223
x->dst = TensorRef_des(in);
32163224
x->src = TensorRef_des(in);

0 commit comments

Comments
 (0)