diff --git a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala index 14de7d8e0..2b86006d9 100644 --- a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala +++ b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala @@ -24,7 +24,7 @@ object ConstEmbeddingsGlove { val logger:Logger = LoggerFactory.getLogger(classOf[ConstEmbeddingsGlove]) // This is not marked private for debugging purposes - private var SINGLETON_WORD_EMBEDDING_MAP: Option[WordEmbeddingMap] = None + var SINGLETON_WORD_EMBEDDING_MAP: Option[WordEmbeddingMap] = None // make sure the singleton is loaded load() diff --git a/main/src/main/scala/org/clulab/dynet/TestOnnx.scala b/main/src/main/scala/org/clulab/dynet/TestOnnx.scala index 3690142a3..3de1b2bcf 100644 --- a/main/src/main/scala/org/clulab/dynet/TestOnnx.scala +++ b/main/src/main/scala/org/clulab/dynet/TestOnnx.scala @@ -1,50 +1,70 @@ package org.clulab.dynet -import org.clulab.embeddings.{CompactWordEmbeddingMap, WordEmbeddingMapPool} - -import java.io.{FileWriter, PrintWriter} - +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} import com.typesafe.config.ConfigFactory -import org.clulab.dynet.Utils._ -import org.clulab.utils.StringUtils +import org.clulab.utils.{StringUtils, Timer} import scala.io.Source import scala.util.parsing.json._ - -import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} -import org.slf4j.{Logger, LoggerFactory} - import java.time.LocalDateTime import java.time.Duration -import scala.io.Source - - object TestOnnx extends App { + class TextEmbedder(filename: String) { + def get_embeddings(embed_file_path: String): Map[String,Array[Float]]={ - val emb = Source.fromFile(embed_file_path) - var emb_map:Map[String,Array[Float]] = Map() - for (s<-emb.getLines){ - if (s.split(" ")(0) == ""){ - emb_map += (""-> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat)) - }else{ - emb_map += (s.split(" ")(0) -> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat)) - } + val emb = Source.fromFile(embed_file_path) + var emb_map:Map[String,Array[Float]] = Map() + for (s<-emb.getLines){ + if (s.split(" ")(0) == ""){ + // TODO: These probably need to be normalized in both cases. + emb_map += (""-> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat)) + }else{ + emb_map += (s.split(" ")(0) -> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat)) } - emb_map + } + emb_map } + + protected val map = { + val timer = new Timer("get_embeddings") + val result = timer.time(get_embeddings(filename)) + + println(timer.toString) + result + } + protected val unknown = map("") + + def apply(key: String): Array[Float] = map.getOrElse(key, unknown) + } + + class ProcessorsEmbedder() { + val map = { + val timer = new Timer("SINGLETON_WORD_EMBEDDING_MAP") + val result = timer.time(ConstEmbeddingsGlove.SINGLETON_WORD_EMBEDDING_MAP.get) + + println(timer.toString) + result + } + + def apply(key: String): Array[Float] = map.getOrElseUnknown(key).toArray + } + val start_time = LocalDateTime.now() val props = StringUtils.argsToProperties(args) val configName = props.getProperty("conf") val config = ConfigFactory.load(configName) val taskManager = new TaskManager(config) - - val embed_file_path: String = "/data1/home/zheng/processors/main/src/main/python/glove.840B.300d.10f.txt" - val wordEmbeddingMap = get_embeddings(embed_file_path) + + // Pick one of these. + val embedder = new TextEmbedder("/data1/home/zheng/processors/main/src/main/python/glove.840B.300d.10f.txt") + // val embedder = new TextEmbedder("../glove/glove.840B.300d.10f.txt") + // val embedder = new ProcessorsEmbedder() val jsonString = Source.fromFile("/data1/home/zheng/processors/ner.json").getLines.mkString + // val jsonString = Source.fromFile("../onnx/ner.json").getLines.mkString val parsed = JSON.parseFull(jsonString) val w2i = parsed.get.asInstanceOf[List[Any]](0).asInstanceOf[Map[String, Any]]("x2i").asInstanceOf[Map[String, Any]]("initialLayer").asInstanceOf[Map[String, Any]]("w2i").asInstanceOf[Map[String, Double]] val c2i = parsed.get.asInstanceOf[List[Any]](0).asInstanceOf[Map[String, Any]]("x2i").asInstanceOf[Map[String, Any]]("initialLayer").asInstanceOf[Map[String, Any]]("c2i").asInstanceOf[Map[String, Double]] @@ -53,8 +73,10 @@ object TestOnnx extends App { val ortEnvironment = OrtEnvironment.getEnvironment val modelpath1 = "/data1/home/zheng/processors/char.onnx" + /// val modelpath1 = "../onnx/char.onnx" val session1 = ortEnvironment.createSession(modelpath1, new OrtSession.SessionOptions) val modelpath2 = "/data1/home/zheng/processors/model.onnx" + /// val modelpath2 = "../onnx/model.onnx" val session2 = ortEnvironment.createSession(modelpath2, new OrtSession.SessionOptions) println(session1.getOutputInfo) @@ -81,7 +103,7 @@ object TestOnnx extends App { var char_embs:Array[Array[Float]] = new Array[Array[Float]](words.length) for(i <- words.indices){ val word = words(i) - embeddings(i) = wordEmbeddingMap.getOrElse(word,wordEmbeddingMap.get( "").get) + embeddings(i) = embedder(word) wordIds(i) = w2i.getOrElse(word, 0).asInstanceOf[Number].longValue val char_input = new java.util.HashMap[String, OnnxTensor]() char_input.put("char_ids", OnnxTensor.createTensor(ortEnvironment, word.map(c => c2i.getOrElse(c.toString, 0).asInstanceOf[Number].longValue).toArray))