|
| 1 | +package edu.umass.cs.iesl.apassos |
| 2 | + |
| 3 | +import cc.factorie.app.nlp._ |
| 4 | +import cc.factorie.app.nlp.pos.PTBPosLabel |
| 5 | +import cc.factorie.{FeatureVectorVariable, CategoricalTensorDomain, LabeledCategoricalVariable, CategoricalDomain} |
| 6 | +import cc.factorie.optimize.{LinearMultiClassClassifier, OnlineLinearMultiClassTrainer} |
| 7 | +import scala.annotation.tailrec |
| 8 | + |
| 9 | +/** |
| 10 | + * User: apassos |
| 11 | + * Date: 9/10/13 |
| 12 | + * Time: 5:43 PM |
| 13 | + */ |
| 14 | + |
/**
 * The closed label set used by the chunker: a token is either inside a
 * chunk ("IN") or outside of one ("OUT").
 *
 * The domain is frozen immediately after the two categories are registered,
 * so the category-to-index mapping is fixed from construction time on.
 */
object ChunkLabelDomain extends CategoricalDomain[String] {
  // Registration order determines the category indices: "IN" first, then "OUT".
  this ++= List("IN", "OUT")
  freeze()
}
/**
 * A labeled categorical variable holding one chunk label.
 *
 * @param target the gold/initial category; it is passed straight to
 *               `LabeledCategoricalVariable` and must be a category of
 *               [[ChunkLabelDomain]].
 */
class ChunkLabel(target: String) extends LabeledCategoricalVariable[String](target) {
  /** All chunk labels share the single frozen [[ChunkLabelDomain]]. */
  def domain: ChunkLabelDomain.type = ChunkLabelDomain
}
| 22 | + |
/**
 * A per-token chunker that labels every token "IN" or "OUT" with a linear
 * multi-class classifier over features of the surrounding tokens.
 *
 * Usage: call [[train]] first to fit `model`, then [[process]] to attach a
 * [[ChunkLabel]] to every token of a document.
 */
class Lecture1Chunker extends DocumentAnnotator {
  // Tokens must already carry a PTB part-of-speech label; this annotator adds ChunkLabel.
  def prereqAttrs = Seq(classOf[PTBPosLabel])
  def postAttrs = Seq(classOf[ChunkLabel])

  /** Renders the predicted chunk category of `token` (requires that a ChunkLabel is attached). */
  def tokenAnnotationString(token: Token): String = token.attr[ChunkLabel].categoryValue

  /** Domain of feature strings; it is frozen by [[train]] once training features are extracted. */
  val featuresDomain = new CategoricalTensorDomain[String] {}

  /** Sparse feature vector of one token, indexed by `featuresDomain`. */
  class TokenFeatures extends FeatureVectorVariable[String] {
    def domain = featuresDomain
    // NOTE(review): presumably makes feature strings missing from a frozen domain
    // be ignored instead of raising an error -- confirm against FACTORIE.
    override def skipNonCategories = true
  }

  /** The trained classifier; stays `null` until [[train]] has been called. */
  var model: LinearMultiClassClassifier = null

  /**
   * Builds the feature vector for `t` from its neighbouring tokens.
   *
   * For every token returned by `prevWindow(3)` and `nextWindow(3)` two
   * features are added, both keyed by the signed offset
   * `t.positionInSection - neighbour.positionInSection`: the neighbour's
   * surface string and the neighbour's part-of-speech category.
   *
   * @param t the token to featurize
   * @return a fresh `TokenFeatures` instance
   */
  def extractFeatures(t: Token): TokenFeatures = {
    val f = new TokenFeatures
    for (tt <- t.prevWindow(3) ++ t.nextWindow(3)) {
      val dif = t.positionInSection - tt.positionInSection
      f += s"STRING@$dif=${tt.string}"
      // Use a distinct prefix and the plain category string so that POS features
      // cannot be confused with surface-form features.
      f += s"POS@$dif=${tt.posLabel.categoryValue}"
    }
    f
  }

  /**
   * Attaches a predicted [[ChunkLabel]] to every token of `document`.
   *
   * @param document the document to annotate (mutated in place)
   * @return the same `document` instance
   * @throws IllegalArgumentException if [[train]] has not been called yet
   */
  def process(document: Document): Document = {
    require(model != null, "Lecture1Chunker.process called before train: model is null")
    document.tokens.foreach { t =>
      val feats = extractFeatures(t)
      val best = model.classification(feats.value).bestLabelIndex
      t.attr += new ChunkLabel(ChunkLabelDomain.categories(best))
    }
    document
  }

  /**
   * Decides the gold label of a token: true when the token itself, or any of
   * its ancestors in the sentence's parse, has a part-of-speech category
   * starting with "N". The walk stops at a token whose parent is `null`.
   */
  @tailrec
  private final def isInChunk(s: Sentence, t: Token): Boolean = {
    if (t.posLabel.categoryValue.startsWith("N")) true
    else if (s.parse.parent(t) == null) false
    else isInChunk(s, s.parse.parent(t))
  }

  /**
   * Fits `model` on `trainSentences`, passing `testSentences` to the trainer
   * as held-out data.
   *
   * Gold labels are derived from the parse via `isInChunk`. The feature
   * domain is frozen after the training features are extracted and before
   * the test features are, so test data cannot add new feature categories.
   *
   * @param trainSentences sentences with POS labels and parses used for fitting
   * @param testSentences  sentences with POS labels and parses used as held-out data
   * @param rng            random source required in implicit scope
   *                       (presumably consumed by the trainer -- not visible here)
   */
  def train(trainSentences: Iterable[Sentence], testSentences: Iterable[Sentence])(implicit rng: scala.util.Random): Unit = {
    val trainLabels = trainSentences.flatMap(s => s.tokens.map(t => new ChunkLabel(if (isInChunk(s, t)) "IN" else "OUT"))).toSeq
    val testLabels = testSentences.flatMap(s => s.tokens.map(t => new ChunkLabel(if (isInChunk(s, t)) "IN" else "OUT"))).toSeq
    val trainFeatures = trainSentences.flatMap(s => s.tokens.map(extractFeatures)).toSeq
    featuresDomain.freeze()
    val testFeatures = testSentences.flatMap(s => s.tokens.map(extractFeatures)).toSeq
    val trainer = new OnlineLinearMultiClassTrainer()
    model = trainer.train(trainLabels, trainFeatures, testLabels, testFeatures)
  }

}
| 70 | + |
/**
 * Command-line entry point: trains a [[Lecture1Chunker]] on one OntoNotes
 * file, annotates a second one, and prints it in one-word-per-line form.
 */
object Lecture1Chunker {

  /** Training file used when no first command-line argument is given. */
  private val DefaultTrainFile = "/iesl/canvas/mccallum/data/ontonotes-en-1.1.0/trn-pmd/nw-wsj-trn.dep.pmd"

  /** Evaluation file used when no second command-line argument is given. */
  private val DefaultTestFile = "/iesl/canvas/mccallum/data/ontonotes-en-1.1.0/dev-pmd/nw-wsj-24.dep.pmd"

  /**
   * @param args optional overrides: `args(0)` is the training file and
   *             `args(1)` the test file; missing arguments fall back to the
   *             default paths above.
   * @throws NoSuchElementException if a file yields no documents
   */
  def main(args: Array[String]): Unit = {
    // Fixed seed so that repeated runs are reproducible.
    implicit val rng: scala.util.Random = new scala.util.Random(0)

    def argOrElse(index: Int, default: String): String =
      if (args.length > index) args(index) else default

    // Only the first document of each file is used.
    def firstDocument(path: String) =
      LoadOntonotes5.fromFilename(path).headOption.getOrElse(
        throw new NoSuchElementException(s"No documents could be loaded from $path"))

    val trainDoc = firstDocument(argOrElse(0, DefaultTrainFile))
    val testDoc = firstDocument(argOrElse(1, DefaultTestFile))

    val chunker = new Lecture1Chunker
    chunker.train(trainDoc.sentences, testDoc.sentences)
    chunker.process(testDoc)
    println(testDoc.owplString(Seq(chunker.tokenAnnotationString)))
  }
}
0 commit comments