diff --git a/inception/inception-recommendation-api/pom.xml b/inception/inception-recommendation-api/pom.xml index 47117f12745..0a0cc58d6e9 100644 --- a/inception/inception-recommendation-api/pom.xml +++ b/inception/inception-recommendation-api/pom.xml @@ -62,6 +62,10 @@ de.tudarmstadt.ukp.inception.app inception-preferences + + de.tudarmstadt.ukp.inception.app + inception-api-annotation + org.apache.uima diff --git a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java index 82af94891b2..b60a7e1b8c4 100644 --- a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java +++ b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java @@ -21,6 +21,7 @@ import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_IS_PREDICTION; import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_EXPLANATION_SUFFIX; import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_SUFFIX; +import static java.util.Collections.emptyList; import static org.apache.uima.fit.util.CasUtil.getType; import java.io.IOException; @@ -31,6 +32,7 @@ import org.apache.uima.cas.Feature; import org.apache.uima.cas.Type; +import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent; import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter; import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult; import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender; @@ -239,4 +241,21 @@ public void exportModel(RecommenderContext aContext, 
OutputStream aOutput) throw { throw new UnsupportedOperationException("Model export not supported"); } + + public List generateIncrementalTrainingInstances(AnnotationEvent aEvent) + { + return emptyList(); + } + + /** + * Store the given incremental training data into the given recommender context. The idea is + * that the engine then picks the training data up from the context on the next training run. It + * may then choose to use only the incremental training data during that run instead of training + * from scratch. + */ + public void putIncrementalTrainingData(RecommenderContext aRecommenderContext, + List aIncrementalTrainingData) + { + // Nothing to do + } } diff --git a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java index 5bb9712a64d..990cf80b24f 100644 --- a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java +++ b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/model/DirtySpot.java @@ -17,11 +17,18 @@ */ package de.tudarmstadt.ukp.inception.recommendation.model; +import static java.util.Collections.unmodifiableMap; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; import de.tudarmstadt.ukp.clarin.webanno.model.Project; import de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument; import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent; +import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender; +import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance; import de.tudarmstadt.ukp.inception.rendering.model.Range; public class DirtySpot @@ -29,12 +36,14 @@ public class DirtySpot private final SourceDocument document; private final String user; private final 
Range affectedRange; + private final Map> incrementalTrainingData; - public DirtySpot(AnnotationEvent aEvent) + public DirtySpot(AnnotationEvent aEvent, Map> aIncrementalTrainingData) { document = aEvent.getDocument(); user = aEvent.getDocumentOwner(); affectedRange = aEvent.getAffectedRange(); + incrementalTrainingData = unmodifiableMap(new LinkedHashMap<>(aIncrementalTrainingData)); } public Range getAffectedRange() @@ -62,6 +71,11 @@ public Project getProject() return document.getProject(); } + public Map> getIncrementalTrainingData() + { + return incrementalTrainingData; + } + @Override public int hashCode() { diff --git a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java index 23a6f98921b..40b1ae9b39d 100644 --- a/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java +++ b/inception/inception-recommendation/src/main/java/de/tudarmstadt/ukp/inception/recommendation/service/RecommendationServiceImpl.java @@ -27,6 +27,7 @@ import static java.util.Optional.empty; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; +import static org.apache.commons.collections4.CollectionUtils.isEmpty; import java.io.IOException; import java.lang.invoke.MethodHandles; @@ -110,6 +111,7 @@ import de.tudarmstadt.ukp.inception.recommendation.api.model.SuggestionGroup; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext; +import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance; import de.tudarmstadt.ukp.inception.recommendation.config.RecommenderServiceAutoConfiguration; import 
de.tudarmstadt.ukp.inception.recommendation.event.RecommenderDeletedEvent; import de.tudarmstadt.ukp.inception.recommendation.event.RecommenderUpdatedEvent; @@ -730,7 +732,38 @@ public void onAnnotation(AnnotationEvent aEvent) requestCycle.setMetaData(DIRTIES, dirties); } - dirties.add(new DirtySpot(aEvent)); + var incrementalTrainingData = new LinkedHashMap>(); + var sessionOwner = userRepository.getCurrentUser(); + var recommenders = getActiveRecommenders(sessionOwner, aEvent.getProject()).stream() // + .map(EvaluatedRecommender::getRecommender) // + .filter(Recommender::isEnabled) // + .filter(rec -> rec.getLayer() != null) // + .filter(rec -> rec.getLayer().isEnabled()) // + .filter(rec -> rec.getFeature().isEnabled()) // + .filter(rec -> rec.getLayer().equals(aEvent.getLayer())) // + .toList(); + + for (var recommender : recommenders) { + try { + var maybeFactory = getRecommenderFactory(recommender); + if (maybeFactory.isEmpty()) { + continue; + } + + var factory = maybeFactory.get(); + var engine = factory.build(recommender); + + incrementalTrainingData.computeIfAbsent(recommender, $ -> new ArrayList<>()) // + .addAll(engine.generateIncrementalTrainingInstances(aEvent)); + } + catch (Exception e) { + LOG.warn("Unable to collect incremental training data for active recommender {}", + recommender); + continue; + } + } + + dirties.add(new DirtySpot(aEvent, incrementalTrainingData)); } /* @@ -946,15 +979,17 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve return; } - var user = userRepository.get(aSessionOwner); - // do not trigger training during when viewing others' work - if (user == null || !user.equals(userRepository.getCurrentUser())) { + // Do not trigger training when viewing others' work + var sessionOwner = userRepository.get(aSessionOwner); + if (sessionOwner == null || !sessionOwner.equals(userRepository.getCurrentUser())) { return; } + commitIncrementalTrainingData(aSessionOwner, aDirties); + // Update the 
task count var count = trainingTaskCounter.computeIfAbsent( - new RecommendationStateKey(user.getUsername(), aProject.getId()), + new RecommendationStateKey(aSessionOwner, aProject.getId()), _key -> new AtomicInteger(0)); // If there is no active recommender at all then let's try hard to make one active by @@ -964,35 +999,70 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve } if (aForceSelection || (count.getAndIncrement() % TRAININGS_PER_SELECTION == 0)) { - // If it is time for a selection task, we just start a selection task. - // The selection task then will start the training once its finished, - // i.e. we do not start it here. - schedulingService.enqueue(SelectionTask.builder() // - .withSessionOwner(user) // - .withProject(aProject) // - .withTrigger(aEventName) // - .withCurrentDocument(aCurrentDocument) // - .withDataOwner(aDataOwner) // - .build()); + triggerSelectionRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName); + return; + } - var state = getState(aSessionOwner, aProject); - synchronized (state) { - state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1); - state.setPredictionsSinceLastEvaluation(0); + triggerTrainingRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName); + } + + private void commitIncrementalTrainingData(String aSessionOwner, Set aDirties) + { + if (isEmpty(aDirties)) { + return; + } + + var aggregatedIncrementalTrainingData = new LinkedHashMap>(); + for (var dirtySpot : aDirties) { + for (var entry : dirtySpot.getIncrementalTrainingData().entrySet()) { + var recommender = entry.getKey(); + aggregatedIncrementalTrainingData + .computeIfAbsent(recommender, $ -> new ArrayList<>()) // + .addAll(entry.getValue()); + } + } + + for (var entry : aggregatedIncrementalTrainingData.entrySet()) { + var recommender = entry.getKey(); + var maybeContext = getContext(aSessionOwner, recommender); + if (maybeContext.isEmpty()) { + continue; + } + + try { + var 
maybeFactory = getRecommenderFactory(recommender); + if (maybeFactory.isEmpty()) { + continue; + } + + var factory = maybeFactory.get(); + var engine = factory.build(recommender); + engine.putIncrementalTrainingData(maybeContext.get(), entry.getValue()); } + catch (Exception e) { + LOG.warn("Unable to commit incremental training data for active recommender {}", + recommender); + continue; + } + } + } + private void triggerTrainingRun(User aSessionOwner, Project aProject, SourceDocument aDocument, + String aDataOwner, String aTrigger) + { + if (isSuspended(aSessionOwner.getUsername(), aProject)) { return; } schedulingService.enqueue(TrainingTask.builder() // - .withSessionOwner(user) // + .withSessionOwner(aSessionOwner) // .withProject(aProject) // - .withTrigger(aEventName) // - .withCurrentDocument(aCurrentDocument) // + .withTrigger(aTrigger) // + .withCurrentDocument(aDocument) // .withDataOwner(aDataOwner) // .build()); - var state = getState(aSessionOwner, aProject); + var state = getState(aSessionOwner.getUsername(), aProject); synchronized (state) { int predictions = state.getPredictionsSinceLastEvaluation() + 1; state.setPredictionsSinceLastEvaluation(predictions); @@ -1000,6 +1070,31 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve } } + private void triggerSelectionRun(User aSessionOwner, Project aProject, SourceDocument aDocument, + String aDataOwner, String aEventName) + { + if (isSuspended(aSessionOwner.getUsername(), aProject)) { + return; + } + + // If it is time for a selection task, we just start a selection task. + // The selection task then will start the training once it is finished, + // i.e. we do not start it here. 
+ schedulingService.enqueue(SelectionTask.builder() // + .withSessionOwner(aSessionOwner) // + .withProject(aProject) // + .withTrigger(aEventName) // + .withCurrentDocument(aDocument) // + .withDataOwner(aDataOwner) // + .build()); + + var state = getState(aSessionOwner.getUsername(), aProject); + synchronized (state) { + state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1); + state.setPredictionsSinceLastEvaluation(0); + } + } + @Override public List getLog(String aSessionOwner, Project aProject) {