Skip to content

Commit 01d7302

Browse files
committed
feat: implement isinstance for objects and basic values
1 parent 3bd8aa2 commit 01d7302

File tree

6 files changed

+43
-1
lines changed

6 files changed

+43
-1
lines changed

KLR/Trace/Builtin.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ def neuronxcc : Name := .str .anonymous "neuronxcc"
3333
def nki_ : Name := .str neuronxcc "nki"
3434
def nki_isa : Name := .str nki_ "isa"
3535
def nki_lang : Name := .str nki_ "language"
36+
def nki_typing : Name := .str nki_ "typing"
3637

3738
def nl : String -> Name := .str nki_lang
3839
def nisa : String -> Name := .str nki_isa
40+
def nt : String -> Name := .str nki_typing
3941

4042
-- conveience functions for creating environment entries
4143

KLR/Trace/ISA.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def getTransposeOps(op: Option (List Int)) : Trace TransposeOps :=
8888
nki builtin.isa.get_nc_version := do
8989
lookup `arch
9090

91+
nki builtin.typing.scalar (t : Term) := do
92+
return .scalar t
93+
9194
-- set_option linter.unusedVariables false
9295

9396
nki builtin.isa.nc_matmul

KLR/Trace/NKI.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def NKIEnv : List (Name × Term) :=
3838
, module nki_
3939
, module nki_isa
4040
, module nki_lang
41+
, module nki_typing
4142
, module `math
4243
, const_int (.str (nl "tile_size") "pmax") 128
4344
, const_int (.str (nl "tile_size") "gemm_stationary_fmax") 128

KLR/Trace/Python.lean

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,34 @@ The builtin.python namespace is mapped to the top-level namespace.
5858
For example, builtin.python.f will appear as f.
5959
-/
6060

61+
nki builtin.python.isinstance (t : Term) (ty : Term) := do
62+
match t, ty with
63+
| .object cls .., .source { name }
64+
| .ref _ (.object cls), .source { name } => return .bool (cls == name)
65+
| .none, .builtin `builtin.python.NoneType ..
66+
| .bool .., .builtin `builtin.python.bool ..
67+
| .int .., .builtin `builtin.python.int ..
68+
| .float .., .builtin `builtin.python.float ..
69+
| .string .., .builtin `builtin.python.str ..
70+
| .tuple .., .builtin `builtin.python.tuple ..
71+
| .list .., .builtin `builtin.python.list ..
72+
| .ref _ .list, .builtin `builtin.python.list ..
73+
| .dict .., .builtin `builtin.python.dict ..
74+
| .ref _ .dict, .builtin `builtin.python.dict ..
75+
| .scalar .., .builtin `builtin.typing.scalar ..
76+
| .ellipsis, .builtin `builtin.python.ellipsis ..
77+
| .slice .., .builtin `builtin.python.slice .. => return .bool true
78+
| _, _ => return .bool false
79+
80+
nki builtin.python.NoneType := do
81+
return .none
82+
83+
nki builtin.python.EllipsisType := do
84+
return .ellipsis
85+
86+
nki builtin.python.ellipsis := do
87+
return .ellipsis
88+
6189
nki builtin.python.slice (args : List Term) := do
6290
match args with
6391
| [e] => return .slice (some 0) (<- fromNKI? e) (some 1)
@@ -91,6 +119,9 @@ nki builtin.python.abs (t : Term) := do
91119
| .float f => return .float f.abs
92120
| _ => throw "abs expects an integer or float number"
93121

122+
nki builtin.python.tuple (l : List Term := []) := do
123+
return .tuple l
124+
94125
nki builtin.python.str (t : Term) := do
95126
return .string (<- t.toStr)
96127

KLR/Trace/Term.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def builtinEnv : List (Name × Term) := Id.run do
629629
let names : List Name := match name with
630630
| .str `builtin.python n => [.str `builtins n, .str .anonymous n]
631631
| .str `builtin.isa n => [nisa n, name]
632+
| .str `builtin.typing n => [nt n, name]
632633
| .str `builtin.lang n => [nl n, name]
633634
| _ => [name]
634635
names.map fun n => (n, fn)

KLR/Trace/Types.lean

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ inductive Term where
112112
| list : Array Term -> Term
113113
| dict : Array (String × Term) -> Term
114114
| tensor : TensorLib.Tensor -> Term
115+
| scalar : Term -> Term
115116
-- indexing
116117
| ellipsis : Term
117118
| slice : Option Int -> Option Int -> Option Int -> Term
@@ -138,6 +139,7 @@ def kindStr : Term → String
138139
| .list _ => "list"
139140
| .dict _ => "dict"
140141
| .tensor _ => "tensor"
142+
| .scalar .. => "scalar"
141143
| .ellipsis => "ellipsis"
142144
| .slice _ _ _ => "slice"
143145
| .pointer _ => "pointer"
@@ -335,6 +337,7 @@ partial def toStr : Term -> Trace String
335337
| .slice a b c => return s!"slice({a},{b},{c})"
336338
| .pointer a => return s!"<Ptr({a.name})>"
337339
| .tensor .. => return "<Tensor>"
340+
| .scalar .. => return "<scalar>"
338341

339342
-- This is partial because the user could have created a heap graph
340343
partial def isTrue (t : Term) : Trace Bool := do
@@ -369,7 +372,8 @@ partial def isTrue (t : Term) : Trace Bool := do
369372
throw "The truth value of an array with more than one element is ambiguous"
370373
else
371374
throw "The truth value of an empty array is ambiguous"
372-
375+
| .scalar .. =>
376+
throw "boolean conversion of scalar not supported"
373377

374378
def isFalse (t : Term) : Trace Bool :=
375379
return not (<- t.isTrue)

0 commit comments

Comments
 (0)