diff --git a/KLR/NKI/Annotations.lean b/KLR/NKI/Annotations.lean index 57063246..44b2c459 100644 --- a/KLR/NKI/Annotations.lean +++ b/KLR/NKI/Annotations.lean @@ -41,6 +41,7 @@ abbrev Ann := Pass Unit private def isValidName' : Name -> Bool | .str `neuronxcc.nki._pre_prod_kernels _ + | .str `neuronxcc.nki._private_nkl _ | .str `neuronxcc.nki._pre_prod_nkl _ => true | .str _ "neuronxcc" => false | .str n _ => isValidName' n diff --git a/KLR/NKI/Simplify.lean b/KLR/NKI/Simplify.lean index c258d971..d6c02afb 100644 --- a/KLR/NKI/Simplify.lean +++ b/KLR/NKI/Simplify.lean @@ -202,6 +202,14 @@ private def expr' (e' : Python.Expr') : Simplify Expr' := | .boolOp op l => return (<- booleanOp op (<- exprs l)).expr | .binOp op l r => return .binOp (<- binOp op) (<- expr l) (<- expr r) | .unaryOp op e => return (<- unaryOp op) (<- expr e) + | .compare a [.is] [⟨.const .none, pos⟩] => + return .binOp .eq (<- expr a) ⟨ .value .none, pos ⟩ + | .compare a [.isNot] [⟨.const .none, pos⟩] => + return .binOp .ne (<- expr a) ⟨ .value .none, pos ⟩ + | .compare a [.isIn] [b] => + return .call ⟨ .var `builtin.op.in, a.pos ⟩ [<- expr a, <- expr b] [] + | .compare a [.notIn] [b] => + return .call ⟨ .var `builtin.op.notin, a.pos ⟩ [<- expr a, <- expr b] [] | .compare a ops l => do let a <- expr a let ops <- ops.mapM cmpOp @@ -396,10 +404,10 @@ private def params (args : Python.Args) : Simplify (List Param) := do throw "varargs are not supported in NKI" if args.kwarg.isSome then warn "variable keyword arguments are not supported in NKI" - if args.posonlyargs.length > 0 then - warn "position-only arguments are not supported in NKI" - if args.kwonlyargs.length > 0 then - warn "keyword-only arguments are not supported in NKI" + --if args.posonlyargs.length > 0 then + -- warn "position-only arguments are not supported in NKI" + --if args.kwonlyargs.length > 0 then + -- warn "keyword-only arguments are not supported in NKI" let defaults := args.all_defaults let mut params := [] for name in args.names do diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index d1383366..d1b8576c 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -316,7 +316,7 @@ partial def mutate (x e : Expr) : Trace Unit := let i : Nat <- fromNKI? i let e <- expr e if h : i < a.size then - extend name (.list (a.set i e h)) + extend_global name (.list (a.set i e h)) return () else throw "index out of range" else throw "internal error: expecting list literal" @@ -326,7 +326,7 @@ partial def mutate (x e : Expr) : Trace Unit := let i : String <- fromNKI? i let e <- expr e let a := AA.insert a i e - extend name (.dict (AA.insert a i e)) + extend_global name (.dict (AA.insert a i e)) return () else throw "internal error: expecting dictionary literal" | r@(.ref _ (.object cls)) => diff --git a/KLR/Trace/Python.lean b/KLR/Trace/Python.lean index 8be6725f..27eb24a6 100644 --- a/KLR/Trace/Python.lean +++ b/KLR/Trace/Python.lean @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. -/ +import Batteries.Data.String import KLR.Core import KLR.Trace.ISA import KLR.Trace.Types @@ -24,6 +25,7 @@ Python related builtins -/ namespace KLR.Trace +open Substring (containsSubstr) open Core /- @@ -53,6 +55,22 @@ nki builtin.op.invert (t : Term) := do let i : Int <- fromNKI? t return .int i.toInt32.complement.toInt +private def isin (t : Term) (l : Term) : Trace Bool := do + let l <- match l with + | .ref name _ => lookup name + | _ => pure l + match t, l with + | _, .tuple l => return l.contains t + | _, .list a => return a.contains t + | .string t, .string s => return containsSubstr s t + | _ , _ => throw "in operator not support on types {kindStr t} and {kindStr l}" + +nki builtin.op.in (t : Term) (l : Term) := do + return .bool (<- isin t l) + +nki builtin.op.notin (t : Term) (l : Term) := do + return .bool (<- isin t l).not + /- The builtin.python namespace is mapped to the top-level namespace. For example, builtin.python.f will appear as f.