diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index 90d985a5d..9ed67ce97 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -23,16 +23,14 @@ object Primitive { class Impl(val c: blackbox.Context) { import c.universe._ - def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) + def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = { annottees.map(_.tree) match { - case (cdef: ClassDef) :: Nil => - c.Expr(transform(cdef)) - case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => - c.Expr(q"{${transform(cdef)}; $md}") + case (cdef: ClassDef) :: Nil => c.Expr(transform(cdef)) + case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => c.Expr(q"{${transform(cdef)}; $md}") case _ => c.abort(c.enclosingPosition, "expected a class definition") } } @@ -40,6 +38,47 @@ object Primitive { def makeLowerCaseName(s: String): String = s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}" + def makeTraverseCall(v : Tree, name : TermName) : Tree => Option[Tree] = { + case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | + Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)") + case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)") + case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.natDispatch($name)") + case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.typeIdentifierDispatch(shine.DPIA.Phrases.traverse.Reference)($name)") + case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)") + case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)") + case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)") + case Ident(TypeName("AddressSpace")) => Some(fq"${name} <- $v.addressSpace($name)") + // Phrase[ExpType] + case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)") + // Vector[Phrase[ExpType]] + case AppliedTypeTree((Ident(TypeName("Vector")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverseV($name.map($v.phrase(_)))") + case AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverse($name.map($v.phrase(_)))") + case _ => None + } + + def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef], parent : Tree): Tree = { + val v = q"v" + val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" } + val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" } + val forLoopBindings : List[Tree] = params.flatMap { + case ValDef(_, name, tpt, _) => makeTraverseCall(v, name)(tpt) + } + val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)" + else q"new $name(..$additionalParamNames)(..$paramNames)" + val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)" + else q"for (..${forLoopBindings}) yield ${construct}" + + q""" + override def traverse[M[+_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { + import util.monads._ + implicit val monad: Monad[M] = implicitly($v.monad) + $forloop + } + """ + } + def makeVisitAndRebuild(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { @@ -81,13 +120,77 @@ object Primitive { """ } + def makeXMLPrinter(name: TypeName, + additionalParams: List[ValDef], + params: List[ValDef]): Tree = { + def makeAttributes(params: List[ValDef]): (List[ValDef], Tree) = { + if (params.isEmpty) return (params, q"scala.xml.Null") + params.head match { + case ValDef(_, name, tpt, _) => tpt match { + case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | + Ident(TypeName("BasicType")) | Ident(TypeName("Nat")) | + Ident(TypeName("NatToNat")) | Ident(TypeName("NatToData")) | + Ident(TypeName("AccessType")) | Ident(TypeName("AddressSpace")) + => + val (list, next) = makeAttributes(params.tail) + (list, q""" + scala.xml.Attribute(${name.toString}, + scala.xml.Text( + shine.DPIA.Phrases.ToString($name)), + $next) + """) + case _ => (params, q"scala.xml.Null") + } + } + } + + def makeBody(params: List[ValDef]): List[Tree] = { + params.map { + case ValDef(_, name, tpt, _) => + + val body = tpt match { + // Phrase[ExpType] + case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => + q"shine.DPIA.Phrases.xmlPrinter($name)" + // Vector[Phrase[ExpType]] + case AppliedTypeTree((Ident(TypeName("Vector")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) + | AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) + => + q"$name.flatMap(shine.DPIA.Phrases.xmlPrinter(_)):_*" + case _ => + q"scala.xml.Text(shine.DPIA.Phrases.ToString($name))" + } + q""" + scala.xml.Elem(null, ${name.toString}, + scala.xml.Null, scala.xml.TopScope, + minimizeEmpty = false, $body) + """ + } + } + + val lowerCaseName = makeLowerCaseName(name.toString) + val (rest, attributes) = makeAttributes(params) + val body = makeBody(rest) + + q""" + override def xmlPrinter: scala.xml.Elem = { + val attributes_ = $attributes + val body_ = $body + scala.xml.Elem(null, $lowerCaseName, attributes_, scala.xml.TopScope, + minimizeEmpty = false, (body_):_*) + } + """ + } + case class ClassInfo(name: TypeName, additionalParams: List[ValDef], params: List[ValDef], body: List[Tree], parents: List[Tree]) - def primitivesFromClassDef: ClassDef => ClassInfo = { + def getClassInfo: ClassDef => ClassInfo = { case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " => ClassInfo( name.asInstanceOf[c.TypeName], @@ -123,14 +226,15 @@ object Primitive { } def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) => + val traverseMissing = + body.collectFirst({ case DefDef(_, TermName("traverse"), _, _, _, _) => ()}).isEmpty val visitAndRebuildMissing = body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty val generated = q""" - ${if (visitAndRebuildMissing) - makeVisitAndRebuild(name, additionalParams, params) - else q""} - """ + ${if (traverseMissing) makeTraverse(name, additionalParams, params, parents(0)) else q""} + ${if (visitAndRebuildMissing) makeVisitAndRebuild(name, additionalParams, params) else q""} + """ val expClass = (additionalParams match { case List() => diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index d15e0d7d5..148e854ad 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -10,7 +10,7 @@ object traverse { case object Binding extends VarType case object Reference extends VarType - trait Traversal[M[_]] { + trait Traversal[M[+_]] { protected[this] implicit def monad : Monad[M] def return_[T] : T => M[T] = monad.return_ def bind[T,S] : M[T] => (T => M[S]) => M[S] = monad.bind @@ -187,13 +187,13 @@ object traverse { } } - trait ExprTraversal[M[_]] extends Traversal[M] { + trait ExprTraversal[M[+_]] extends Traversal[M] { override def `type`[T <: Type] : T => M[T] = return_ } trait PureTraversal extends Traversal[Pure] {override def monad : PureMonad.type = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] - trait AccumulatorTraversal[F,M[_]] extends Traversal[InMonad[M]#SetFst[F]#Type] { + trait AccumulatorTraversal[F,M[+_]] extends Traversal[InMonad[M]#SetFst[F]#Type] { type Pair[T] = InMonad[M]#SetFst[F]#Type[T] implicit val accumulator : Monoid[F] implicit val wrapperMonad : Monad[M] @@ -211,6 +211,6 @@ object traverse { def traverse[T <: Type](t: T, f: PureTraversal): T = f.`type`(t).unwrap def traverse[F](e: Expr, f: PureAccumulatorTraversal[F]): (F, Expr) = f.expr(e).unwrap def traverse[F,T <: Type](t: T, f: PureAccumulatorTraversal[F]): (F, T) = f.`type`(t).unwrap - def traverse[M[_]](e: Expr, f: Traversal[M]): M[Expr] = f.expr(e) - def traverse[T <: Type, M[_]](e: T, f: Traversal[M]): M[T] = f.`type`(e) + def traverse[M[+_]](e: Expr, f: Traversal[M]): M[Expr] = f.expr(e) + def traverse[T <: Type, M[+_]](e: T, f: Traversal[M]): M[T] = f.`type`(e) } diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index 57dcfd0fb..a982cd1c2 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -2,10 +2,12 @@ package shine.DPIA.Phrases import arithexpr.arithmetic.{NamedVar, RangeAdd} import shine.DPIA.Lifting.{liftDependentFunction, liftFunction, liftPair} +import shine.DPIA.Phrases.traverse._ import shine.DPIA.Types._ import shine.DPIA.Types.TypeCheck._ import shine.DPIA._ import shine.DPIA.primitives.functional.NatAsIndex +import util.monads.Pure sealed trait Phrase[T <: PhraseType] { val t: T @@ -43,6 +45,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) extends Phrase[K `()->:` T] { override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" + val kindName : KindName[K] = implicitly(kn) } object DepLambda { @@ -134,7 +137,7 @@ object Phrase { `for`: Phrase[T1], in: Phrase[T2]): Phrase[T2] = { var substCounter = 0 - object Visitor extends VisitAndRebuild.Visitor { + object Visitor extends PureTraversal { def renaming[X <: PhraseType](p: Phrase[X]): Phrase[X] = { case class Renaming(idMap: Map[String, String]) extends VisitAndRebuild.Visitor { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { @@ -151,8 +154,7 @@ object Phrase { } override def nat[N <: Nat](n: N): N = n.visitAndRebuild({ - case i: NatIdentifier => - NatIdentifier(idMap.getOrElse(i.name, i.name)) + case i: NatIdentifier => NatIdentifier(idMap.getOrElse(i.name, i.name)) case ae => ae }).asInstanceOf[N] @@ -164,13 +166,13 @@ object Phrase { } VisitAndRebuild(p, Renaming(Map())) } - override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { - p match { + + override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = { case `for` => val newPh = if (substCounter == 0) ph else renaming(ph) substCounter += 1 - Stop(newPh.asInstanceOf[Phrase[T]]) - case Natural(n) => + return_(newPh.asInstanceOf[Phrase[T]]) + case p@Natural(n) => val v = NatIdentifier(`for` match { case Identifier(name, _) => name case _ => throw new Exception("This should never happen") @@ -178,19 +180,18 @@ object Phrase { ph.t match { case ExpType(NatType, _) => - Stop(Natural(Nat.substitute( + return_(Natural(Nat.substitute( Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) case ExpType(IndexType(_), _) => - Stop(Natural(Nat.substitute( + return_(Natural(Nat.substitute( Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) - case _ => Continue(p, this) + case _ => super.phrase(p) } - case _ => Continue(p, this) - } + case p => super.phrase(p) } } - VisitAndRebuild(in, Visitor) + Visitor.phrase(in).unwrap } def substitute[T2 <: PhraseType](substitutionMap: Map[Phrase[_], Phrase[_]], @@ -361,6 +362,9 @@ object Phrase { sealed trait Primitive[T <: PhraseType] extends Phrase[T] { def prettyPrint: String = this.toString + def traverse[M[+_]](f: Traversal[M]): M[Phrase[T]] = + throw new Exception("traverse should be implemented by a macro") + def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] = throw new Exception("visitAndRebuild should be implemented by a macro") } diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala new file mode 100644 index 000000000..0bd2477b8 --- /dev/null +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -0,0 +1,190 @@ +package shine.DPIA.Phrases + +import scala.language.implicitConversions +import util.monads._ +import shine.DPIA.Types._ +import shine.DPIA._ + +object traverse { + trait ExprTraversal[M[+_]] extends Traversal[M] { + override def `type`[T <: PhraseType] : T => M[T] = return_ + } + trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } + trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] + + def traverse[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap + def traverse[T <: PhraseType, M[+_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) + def traverse[T <: PhraseType] (t : T, f : PureTraversal) : T = f.`type`(t).unwrap + def traverse[T <: PhraseType, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) + + sealed trait VarType + case object Binding extends VarType + case object Reference extends VarType + + trait Traversal[M[+_]] { + implicit def monad: Monad[M] + def return_[T]: T => M[T] = monad.return_ + def bind[T, S]: M[T] => (T => M[S]) => M[S] = monad.bind + + def nat[N <: Nat] : N => M[N] = return_ + def identifier[T <: PhraseType]: VarType => Identifier[T] => M[Identifier[T]] = _ => i => + for {t1 <- `type`(i.t)} + yield Identifier(i.name, t1) + def typeIdentifier[I <: Kind.Identifier]: VarType => I => M[I] = _ => return_ + def typeIdentifierDispatch[I <: Kind.Identifier]: VarType => I => M[I] = vt => i => (i match { + case n: NatIdentifier => bind(typeIdentifier(vt)(n))(nat) + case dt: DataTypeIdentifier => bind(typeIdentifier(vt)(dt))(datatype) + case a: AddressSpaceIdentifier => bind(typeIdentifier(vt)(a))(addressSpace) + case ac: AccessTypeIdentifier => bind(typeIdentifier(vt)(ac))(accessType) + case n2n: NatToNatIdentifier => bind(typeIdentifier(vt)(n2n))(natToNat) + case n2d: NatToDataIdentifier => bind(typeIdentifier(vt)(n2d))(natToData) + }).asInstanceOf[M[I]] + def natDispatch : Nat => M[Nat] = { + case i : NatIdentifier => bind(typeIdentifier(Reference)(i))(nat) + case n => nat(n) + } + + def addressSpace: AddressSpace => M[AddressSpace] = return_ + def accessType: AccessType => M[AccessType] = return_ + def data: Data => M[Data] = { + case VectorData(vd) => return_(VectorData(vd) : Data) + case NatData(n) => + for { n1 <- natDispatch(n) } + yield NatData(n1) + case IndexData(i, n) => + for { i1 <- natDispatch(i); n1 <- natDispatch(n) } + yield IndexData(i1, n1) + case ArrayData(ad) => + for { ad1 <- monad.traverseV(ad.map(data)) } + yield ArrayData(ad1) + case PairData(l, r) => + for { l1 <- data(l); r1 <- data(r) } + yield PairData(l1, r1) + case d => return_(d) + } + + def datatype[ D <:DataType] : D => M[D] = d => (d match { + case NatType => return_(NatType) + case s : ScalarType => return_(s) + case IndexType(size) => + for {n1 <- natDispatch(size)} + yield IndexType(n1) + case ArrayType(size, dt) => + for {n1 <- natDispatch(size); dt1 <- datatype(dt)} + yield ArrayType(n1, dt1) + case DepArrayType(n, n2d) => + for {n1 <- natDispatch(n); n2d1 <- natToData(n2d)} + yield DepArrayType(n1, n2d1) + case VectorType(size, dt) => + for {n1 <- natDispatch(size); dt1 <- datatype(dt)} + yield VectorType(n1, dt1) + case PairType(l, r) => + for {l1 <- datatype(l); r1 <- datatype(r)} + yield PairType(l1, r1) + case pair@DepPairType(x, e) => + for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e)} + yield DepPairType(x1, e1) + case NatToDataApply(ntdf, n) => + for {ntdf1 <- natToData(ntdf); n1 <- natDispatch(n)} + yield NatToDataApply(ntdf1, n1) + case FragmentType(rs, cs, d3, dt, fk, l) => + for {rs1 <- natDispatch(rs); cs1 <- natDispatch(cs); d31 <- natDispatch(d3); dt1 <- datatype(dt)} + yield FragmentType(rs1, cs1, d31, dt1, fk, l) + + }).asInstanceOf[M[D]] + + def natToNat: NatToNat => M[NatToNat] = { + case i: NatToNatIdentifier => return_(i : NatToNat) + case NatToNatLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- natDispatch(body)} + yield NatToNatLambda(n1, body1) + } + + def natToData: NatToData => M[NatToData] = { + case i: NatToDataIdentifier => return_(i) + case NatToDataLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- datatype(body)} + yield NatToDataLambda(n1, body1) + } + + def phrase[T <: PhraseType]: Phrase[T] => M[Phrase[T]] = { + case i: Identifier[T] => for {i1 <- identifier(Reference)(i)} yield i1 + case Lambda(x, p) => + for {x1 <- identifier(Binding)(x); p1 <- phrase(p)} + yield Lambda(x1, p1) + case Apply(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield Apply(p1, q1) + case dl@DepLambda(i, p) => + for {i1 <- typeIdentifierDispatch(Binding)(i); p1 <- phrase(p)} + yield DepLambda(i1, p1)(dl.kindName) + case da@DepApply(f, x) => x match { + case n: Nat => + for {f1 <- phrase(f); n1 <- natDispatch(n)} + yield DepApply[NatKind, T](f1.asInstanceOf[Phrase[NatKind `()->:` T]], n1) + case dt: DataType => + for {f1 <- phrase(f); dt1 <- datatype(dt)} + yield DepApply[DataKind, T](f1.asInstanceOf[Phrase[DataKind `()->:` T]], dt1) + case a: AddressSpace => + for {f1 <- phrase(f); a1 <- addressSpace(a)} + yield DepApply[AddressSpaceKind, T](f1.asInstanceOf[Phrase[AddressSpaceKind `()->:` T]], a1) + case n2n: NatToNat => + for {f1 <- phrase(f); n2n1 <- natToNat(n2n)} + yield DepApply[NatToNatKind, T](f1.asInstanceOf[Phrase[NatToNatKind `()->:` T]], n2n1) + case n2d: NatToData => + for {f1 <- phrase(f); n2d1 <- natToData(n2d)} + yield DepApply[NatToDataKind, T](f1.asInstanceOf[Phrase[NatToDataKind `()->:` T]], n2d1) + } + case LetNat(binder, defn, body) => + for {defn1 <- phrase(defn); body1 <- phrase(body)} + yield LetNat(binder, defn1, body1) + case PhrasePair(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield PhrasePair(p1, q1) + case Proj1(p) => + for {p1 <- phrase(p)} + yield Proj1(p1) + case Proj2(p) => + for {p1 <- phrase(p)} + yield Proj2(p1) + case IfThenElse(cond, thenP, elseP) => + for {cond1 <- phrase(cond); thenP1 <- phrase(thenP); elseP1 <- phrase(elseP)} + yield IfThenElse(cond1, thenP1, elseP1) + case Literal(d) => + for {d1 <- data(d)} + yield Literal(d1) + case Natural(n) => + for {n1 <- natDispatch(n)} + yield Natural(n1) + case UnaryOp(op, x) => + for {x1 <- phrase(x)} + yield UnaryOp(op, x1) + case BinOp(op, lhs, rhs) => + for {lhs1 <- phrase(lhs); rhs1 <- phrase(rhs)} + yield BinOp(op, lhs1, rhs1) + case c: Primitive[T] => c.traverse(this) + } + + def `type`[T <: PhraseType] : T => M[T] = t => (t match { + case CommType() => return_(CommType()) + case ExpType(dt, w) => + for {dt1 <- datatype(dt); w1 <- accessType(w)} + yield ExpType(dt1, w1) + case AccType(dt) => + for {dt1 <- datatype(dt)} + yield AccType(dt1) + case PhrasePairType(l, r) => + for {l1 <- `type`(l); r1 <- `type`(r)} + yield PhrasePairType(l1, r1) + case FunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield FunType(inT1, outT1) + case PassiveFunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield PassiveFunType(inT1, outT1) + case df@DepFunType(x, t) => + for {x1 <- typeIdentifierDispatch(Binding)(x); t1 <- `type`(t)} + yield DepFunType(x1, t1)(df.kindName) + }).asInstanceOf[M[T]] + } +} \ No newline at end of file diff --git a/src/main/scala/shine/DPIA/Types/PhraseType.scala b/src/main/scala/shine/DPIA/Types/PhraseType.scala index 828d3a3dc..bd6db99f8 100644 --- a/src/main/scala/shine/DPIA/Types/PhraseType.scala +++ b/src/main/scala/shine/DPIA/Types/PhraseType.scala @@ -45,6 +45,7 @@ final case class DepFunType[K <: Kind, +R <: PhraseType](x: K#I, t: R) (implicit val kn: KindName[K]) extends PhraseType { override def toString = s"(${x.name}: ${kn.get}) -> $t" + val kindName = implicitly(kn) } object PhraseType { diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index 285c24129..16e59d949 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -1,5 +1,6 @@ package shine.DPIA +import arithexpr.arithmetic._ import elevate.core.strategies.Traversable import elevate.core.strategies.basic.normalize import rise.elevate.Rise @@ -56,15 +57,15 @@ object fromRise { arg) x match { - case n: Nat => depApp[NatKind](f, n) + case n: rt.Nat => depApp[NatKind](f, nat(n)) case dt: rt.DataType => depApp[DataKind](f, dataType(dt)) case a: rt.AddressSpace => depApp[AddressSpaceKind](f, addressSpace(a)) case n2n: rt.NatToNat => depApp[NatToNatKind](f, nat2nat(n2n)) } case r.Literal(d) => d match { - case rs.NatData(n) => Natural(n) - case rs.IndexData(i, n) => NatAsIndex(n, Natural(i)) + case rs.NatData(n) => Natural(nat(n)) + case rs.IndexData(i, n) => NatAsIndex(nat(n), Natural(nat(i))) case _ => Literal(data(d)) } @@ -82,8 +83,8 @@ object fromRise { case rs.FloatData(f) => FloatData(f) case rs.DoubleData(d) => DoubleData(d) case rs.VectorData(v) => VectorData(v.map(data(_)).toVector) - case rs.IndexData(i, n) => IndexData(i, n) - case rs.NatData(n) => NatData(n) + case rs.IndexData(i, n) => IndexData(nat(i), nat(n)) + case rs.NatData(n) => NatData(nat(n)) } import rise.core.{primitives => core} @@ -946,22 +947,22 @@ object fromRise { case rt.NatToNatIdentifier(name, _) => NatToNatIdentifier(name) case rt.NatToNatLambda(x, body) => - NatToNatLambda(x.range, NatIdentifier(x.name), body) + NatToNatLambda(x.range, NatIdentifier(x.name), nat(body)) } def dataType(t: rt.DataType): DataType = t match { case st: rt.ScalarType => scalarType(st) case rt.NatType => NatType - case rt.IndexType(sz) => IndexType(sz) + case rt.IndexType(sz) => IndexType(nat(sz)) case rt.VectorType(sz, et) => et match { - case e : rt.ScalarType => VectorType(sz, scalarType(e)) + case e : rt.ScalarType => VectorType(nat(sz), scalarType(e)) case _ => ??? } case i: rt.DataTypeIdentifier => dataTypeIdentifier(i) - case rt.ArrayType(sz, et) => ArrayType(sz, dataType(et)) - case rt.DepArrayType(sz, f) => DepArrayType(sz, ntd(f)) + case rt.ArrayType(sz, et) => ArrayType(nat(sz), dataType(et)) + case rt.DepArrayType(sz, f) => DepArrayType(nat(sz), ntd(f)) case rt.PairType(a, b) => PairType(dataType(a), dataType(b)) - case rt.NatToDataApply(f, n) => NatToDataApply(ntd(f), n) + case rt.NatToDataApply(f, n) => NatToDataApply(ntd(f), nat(n)) case rt.DepPairType(x, t) => x match { case x:rt.NatIdentifier => DepPairType(natIdentifier(x), dataType(t)) @@ -970,11 +971,11 @@ object fromRise { case f: rt.FragmentType => f.fragmentKind match { case rt.FragmentKind.AMatrix => - FragmentType(f.rows, f.d3, f.columns, dataType(f.dataType), FragmentKind.AMatrix, layout(f.layout)) + FragmentType(nat(f.rows), nat(f.d3), nat(f.columns), dataType(f.dataType), FragmentKind.AMatrix, layout(f.layout)) case rt.FragmentKind.BMatrix => - FragmentType(f.d3, f.columns, f.rows, dataType(f.dataType), FragmentKind.BMatrix, layout(f.layout)) + FragmentType(nat(f.d3), nat(f.columns), nat(f.rows), dataType(f.dataType), FragmentKind.BMatrix, layout(f.layout)) case rt.FragmentKind.Acuumulator => - FragmentType(f.rows, f.columns, f.d3, dataType(f.dataType), FragmentKind.Accumulator, layout(f.layout)) + FragmentType(nat(f.rows), nat(f.columns), nat(f.d3), dataType(f.dataType), FragmentKind.Accumulator, layout(f.layout)) case _ => throw new Exception("this should not happen") } } @@ -1011,10 +1012,15 @@ object fromRise { } def ntn(ntn: rt.NatToNat): NatToNat= ntn match { - case rt.NatToNatLambda(n, body) => NatToNatLambda(natIdentifier(n), body) + case rt.NatToNatLambda(n, body) => NatToNatLambda(natIdentifier(n), nat(body)) case rt.NatToNatIdentifier(x, _) => NatToNatIdentifier(x) } + def nat(n : rt.Nat) : Nat = n.visitAndRebuild{ + case i : NamedVar => NatIdentifier(i.name, i.range) + case e => e + } + def dataTypeIdentifier(dt: rt.DataTypeIdentifier): DataTypeIdentifier = DataTypeIdentifier(dt.name) def natIdentifier(n: rt.NatIdentifier): NatIdentifier = diff --git a/src/main/scala/shine/DPIA/package.scala b/src/main/scala/shine/DPIA/package.scala index 55a6515a1..306caf9ee 100644 --- a/src/main/scala/shine/DPIA/package.scala +++ b/src/main/scala/shine/DPIA/package.scala @@ -16,16 +16,14 @@ package object DPIA { } type Nat = ArithExpr - type NatIdentifier = NamedVar with Kind.Identifier + case class NatIdentifier(override val name : String, override val range : Range = RangeUnknown) extends NamedVar(name, range) with Kind.Identifier { + override def withRange(r : Range) : NatIdentifier = NatIdentifier(name, r) + } object Nat { - def substitute[N <: Nat](ae: Nat, `for`: NatIdentifier, in: N): N = + def substitute[N <: Nat](ae: Nat, `for`: NatIdentifier, in: N): N = { ArithExpr.substitute(in, Map(`for` -> ae)).asInstanceOf[N] - } - - object NatIdentifier { - def apply(name: String): NatIdentifier = new NamedVar(name) with Kind.Identifier - def apply(name: String, range: Range): NatIdentifier = new NamedVar(name, range) with Kind.Identifier + } } // note: this is an easy fix to avoid name conflicts between lift and dpia diff --git a/src/main/scala/util/monads.scala b/src/main/scala/util/monads.scala index 246fd7f3c..a7f03a52e 100644 --- a/src/main/scala/util/monads.scala +++ b/src/main/scala/util/monads.scala @@ -4,20 +4,23 @@ import scala.collection.immutable.HashSet import scala.language.implicitConversions object monads { - trait Monad[M[_]] { + trait Monad[M[+_]] { def return_[T] : T => M[T] def bind[T,S] : M[T] => (T => M[S]) => M[S] def traverse[A] : Seq[M[A]] => M[Seq[A]] = _.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) => bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) + def traverseV[A] : Vector[M[A]] => M[Vector[A]] = + _.foldRight(return_(Vector.empty : Vector[A]))({case (mx, mxs) => + bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) } - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: Monad[M]) = new { def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) } - case class Pure[T](unwrap : T) + case class Pure[+T](unwrap : T) implicit object PureMonad extends Monad[Pure] { override def return_[T] : T => Pure[T] = t => Pure(t) override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = @@ -61,9 +64,9 @@ object monads { } } - trait InMonad[M[_]] { trait SetFst[F] { type Type[S] = M[Tuple2[F, S]] } } - trait PairMonoidMonad[F, M[_]] extends Monad[InMonad[M]#SetFst[F]#Type] { - type Pair[T] = InMonad[M]#SetFst[F]#Type[T] + trait InMonad[M[+_]] { trait SetFst[F] { type Type[+S] = M[Tuple2[F, S]] } } + trait PairMonoidMonad[F, M[+_]] extends Monad[InMonad[M]#SetFst[F]#Type] { + type Pair[+T] = InMonad[M]#SetFst[F]#Type[T] implicit val monoid : Monoid[F] implicit val monad : Monad[M] override def return_[T]: T => Pair[T] = t => monad.return_((monoid.empty, t))