diff --git a/inception/inception-recommendation-api/pom.xml b/inception/inception-recommendation-api/pom.xml
index 47117f12745..0a0cc58d6e9 100644
--- a/inception/inception-recommendation-api/pom.xml
+++ b/inception/inception-recommendation-api/pom.xml
@@ -62,6 +62,10 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-preferences</artifactId>
</dependency>
+<dependency>
+  <groupId>de.tudarmstadt.ukp.inception.app</groupId>
+  <artifactId>inception-api-annotation</artifactId>
+</dependency>
<dependency>
<groupId>org.apache.uima</groupId>
diff --git a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java
index 82af94891b2..b60a7e1b8c4 100644
--- a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java
+++ b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java
@@ -21,6 +21,7 @@
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_IS_PREDICTION;
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_EXPLANATION_SUFFIX;
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_SUFFIX;
+import static java.util.Collections.emptyList;
import static org.apache.uima.fit.util.CasUtil.getType;
import java.io.IOException;
@@ -31,6 +32,7 @@
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.Type;
+import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
@@ -239,4 +241,21 @@ public void exportModel(RecommenderContext aContext, OutputStream aOutput) throw
{
throw new UnsupportedOperationException("Model export not supported");
}
+
+ public List<TrainingInstance> generateIncrementalTrainingInstances(AnnotationEvent aEvent)
+ {
+ return emptyList();
+ }
+
+ /**
+ * Store the given incremental training data into the given recommender context. The idea is
+ * that the engine then picks the training data up from the context on the next training run. It
+ * may then choose to use only the incremental training data during that run instead of training
+ * from scratch.
+ */
+ public void putIncrementalTrainingData(RecommenderContext aRecommenderContext,
+ List<TrainingInstance> aIncrementalTrainingData)
+ {
+ // Nothing to do
+ }
}
diff --git a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java
index 5bb9712a64d..990cf80b24f 100644
--- a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java
+++ b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java
@@ -17,11 +17,18 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.model;
+import static java.util.Collections.unmodifiableMap;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
import java.util.Objects;
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument;
import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent;
+import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
+import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance;
import de.tudarmstadt.ukp.inception.rendering.model.Range;
public class DirtySpot
@@ -29,12 +36,14 @@ public class DirtySpot
private final SourceDocument document;
private final String user;
private final Range affectedRange;
+ private final Map<Recommender, List<TrainingInstance>> incrementalTrainingData;
- public DirtySpot(AnnotationEvent aEvent)
+ public DirtySpot(AnnotationEvent aEvent, Map<Recommender, List<TrainingInstance>> aIncrementalTrainingData)
{
document = aEvent.getDocument();
user = aEvent.getDocumentOwner();
affectedRange = aEvent.getAffectedRange();
+ incrementalTrainingData = unmodifiableMap(new LinkedHashMap<>(aIncrementalTrainingData));
}
public Range getAffectedRange()
@@ -62,6 +71,11 @@ public Project getProject()
return document.getProject();
}
+ public Map<Recommender, List<TrainingInstance>> getIncrementalTrainingData()
+ {
+ return incrementalTrainingData;
+ }
+
@Override
public int hashCode()
{
diff --git a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java
index 23a6f98921b..40b1ae9b39d 100644
--- a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java
+++ b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java
@@ -27,6 +27,7 @@
import static java.util.Optional.empty;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
+import static org.apache.commons.collections4.CollectionUtils.isEmpty;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
@@ -110,6 +111,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.model.SuggestionGroup;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
+import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance;
import de.tudarmstadt.ukp.inception.recommendation.config.RecommenderServiceAutoConfiguration;
import de.tudarmstadt.ukp.inception.recommendation.event.RecommenderDeletedEvent;
import de.tudarmstadt.ukp.inception.recommendation.event.RecommenderUpdatedEvent;
@@ -730,7 +732,38 @@ public void onAnnotation(AnnotationEvent aEvent)
requestCycle.setMetaData(DIRTIES, dirties);
}
- dirties.add(new DirtySpot(aEvent));
+ var incrementalTrainingData = new LinkedHashMap<Recommender, List<TrainingInstance>>();
+ var sessionOwner = userRepository.getCurrentUser();
+ var recommenders = getActiveRecommenders(sessionOwner, aEvent.getProject()).stream() //
+ .map(EvaluatedRecommender::getRecommender) //
+ .filter(Recommender::isEnabled) //
+ .filter(rec -> rec.getLayer() != null) //
+ .filter(rec -> rec.getLayer().isEnabled()) //
+ .filter(rec -> rec.getFeature().isEnabled()) //
+ .filter(rec -> rec.getLayer().equals(aEvent.getLayer())) //
+ .toList();
+
+ for (var recommender : recommenders) {
+ try {
+ var maybeFactory = getRecommenderFactory(recommender);
+ if (maybeFactory.isEmpty()) {
+ continue;
+ }
+
+ var factory = maybeFactory.get();
+ var engine = factory.build(recommender);
+
+ incrementalTrainingData.computeIfAbsent(recommender, $ -> new ArrayList<>()) //
+ .addAll(engine.generateIncrementalTrainingInstances(aEvent));
+ }
+ catch (Exception e) {
+ LOG.warn("Unable to collect incremental training data for active recommender {}",
+ recommender, e);
+ continue;
+ }
+ }
+
+ dirties.add(new DirtySpot(aEvent, incrementalTrainingData));
}
/*
@@ -946,15 +979,17 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve
return;
}
- var user = userRepository.get(aSessionOwner);
- // do not trigger training during when viewing others' work
- if (user == null || !user.equals(userRepository.getCurrentUser())) {
+ // Do not trigger training when viewing others' work
+ var sessionOwner = userRepository.get(aSessionOwner);
+ if (sessionOwner == null || !sessionOwner.equals(userRepository.getCurrentUser())) {
return;
}
+ commitIncrementalTrainingData(aSessionOwner, aDirties);
+
// Update the task count
var count = trainingTaskCounter.computeIfAbsent(
- new RecommendationStateKey(user.getUsername(), aProject.getId()),
+ new RecommendationStateKey(aSessionOwner, aProject.getId()),
_key -> new AtomicInteger(0));
// If there is no active recommender at all then let's try hard to make one active by
@@ -964,35 +999,70 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve
}
if (aForceSelection || (count.getAndIncrement() % TRAININGS_PER_SELECTION == 0)) {
- // If it is time for a selection task, we just start a selection task.
- // The selection task then will start the training once its finished,
- // i.e. we do not start it here.
- schedulingService.enqueue(SelectionTask.builder() //
- .withSessionOwner(user) //
- .withProject(aProject) //
- .withTrigger(aEventName) //
- .withCurrentDocument(aCurrentDocument) //
- .withDataOwner(aDataOwner) //
- .build());
+ triggerSelectionRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName);
+ return;
+ }
- var state = getState(aSessionOwner, aProject);
- synchronized (state) {
- state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1);
- state.setPredictionsSinceLastEvaluation(0);
+ triggerTrainingRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName);
+ }
+
+ private void commitIncrementalTrainingData(String aSessionOwner, Set<DirtySpot> aDirties)
+ {
+ if (isEmpty(aDirties)) {
+ return;
+ }
+
+ var aggregatedIncrementalTrainingData = new LinkedHashMap<Recommender, List<TrainingInstance>>();
+ for (var dirtySpot : aDirties) {
+ for (var entry : dirtySpot.getIncrementalTrainingData().entrySet()) {
+ var recommender = entry.getKey();
+ aggregatedIncrementalTrainingData
+ .computeIfAbsent(recommender, $ -> new ArrayList<>()) //
+ .addAll(entry.getValue());
+ }
+ }
+
+ for (var entry : aggregatedIncrementalTrainingData.entrySet()) {
+ var recommender = entry.getKey();
+ var maybeContext = getContext(aSessionOwner, recommender);
+ if (maybeContext.isEmpty()) {
+ continue;
+ }
+
+ try {
+ var maybeFactory = getRecommenderFactory(recommender);
+ if (maybeFactory.isEmpty()) {
+ continue;
+ }
+
+ var factory = maybeFactory.get();
+ var engine = factory.build(recommender);
+ engine.putIncrementalTrainingData(maybeContext.get(), entry.getValue());
}
+ catch (Exception e) {
+ LOG.warn("Unable to commit incremental training data for active recommender {}",
+ recommender, e);
+ continue;
+ }
+ }
+ }
+ private void triggerTrainingRun(User aSessionOwner, Project aProject, SourceDocument aDocument,
+ String aDataOwner, String aTrigger)
+ {
+ if (isSuspended(aSessionOwner.getUsername(), aProject)) {
return;
}
schedulingService.enqueue(TrainingTask.builder() //
- .withSessionOwner(user) //
+ .withSessionOwner(aSessionOwner) //
.withProject(aProject) //
- .withTrigger(aEventName) //
- .withCurrentDocument(aCurrentDocument) //
+ .withTrigger(aTrigger) //
+ .withCurrentDocument(aDocument) //
.withDataOwner(aDataOwner) //
.build());
- var state = getState(aSessionOwner, aProject);
+ var state = getState(aSessionOwner.getUsername(), aProject);
synchronized (state) {
int predictions = state.getPredictionsSinceLastEvaluation() + 1;
state.setPredictionsSinceLastEvaluation(predictions);
@@ -1000,6 +1070,31 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve
}
}
+ private void triggerSelectionRun(User aSessionOwner, Project aProject, SourceDocument aDocument,
+ String aDataOwner, String aEventName)
+ {
+ if (isSuspended(aSessionOwner.getUsername(), aProject)) {
+ return;
+ }
+
+ // If it is time for a selection task, we just start a selection task.
+ // The selection task then will start the training once its finished,
+ // i.e. we do not start it here.
+ schedulingService.enqueue(SelectionTask.builder() //
+ .withSessionOwner(aSessionOwner) //
+ .withProject(aProject) //
+ .withTrigger(aEventName) //
+ .withCurrentDocument(aDocument) //
+ .withDataOwner(aDataOwner) //
+ .build());
+
+ var state = getState(aSessionOwner.getUsername(), aProject);
+ synchronized (state) {
+ state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1);
+ state.setPredictionsSinceLastEvaluation(0);
+ }
+ }
+
@Override
public List getLog(String aSessionOwner, Project aProject)
{