Skip to content

Commit

Permalink
#5318 - Ability to update recommenders incrementally
Browse files Browse the repository at this point in the history
- Introduce APIs necessary for incremental updates
  • Loading branch information
reckart committed Mar 2, 2025
1 parent 3ee308e commit 5066b4a
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 24 deletions.
4 changes: 4 additions & 0 deletions inception/inception-recommendation-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-preferences</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-api-annotation</artifactId>
</dependency>

<dependency>
<groupId>org.apache.uima</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_IS_PREDICTION;
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_EXPLANATION_SUFFIX;
import static de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService.FEATURE_NAME_SCORE_SUFFIX;
import static java.util.Collections.emptyList;
import static org.apache.uima.fit.util.CasUtil.getType;

import java.io.IOException;
Expand All @@ -31,6 +32,7 @@
import org.apache.uima.cas.Feature;
import org.apache.uima.cas.Type;

import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
Expand Down Expand Up @@ -239,4 +241,21 @@ public void exportModel(RecommenderContext aContext, OutputStream aOutput) throw
{
throw new UnsupportedOperationException("Model export not supported");
}

/**
 * Derive training instances from the given annotation event so they can be used for an
 * incremental model update. Engines that support incremental training override this method;
 * the base implementation produces no instances.
 *
 * @param aEvent
 *            the annotation event to derive training instances from.
 * @return the training instances derived from the event — empty by default.
 */
public List<TrainingInstance> generateIncrementalTrainingInstances(AnnotationEvent aEvent)
{
    // Default: this engine does not support incremental training
    return List.of();
}

/**
 * Store the given incremental training data into the given recommender context. The idea is
 * that the engine then picks the training data up from the context on the next training run. It
 * may then choose to use only the incremental training data during that run instead of training
 * from scratch.
 *
 * @param aRecommenderContext
 *            the context into which to store the training data.
 * @param aIncrementalTrainingData
 *            the training instances to store.
 */
public void putIncrementalTrainingData(RecommenderContext aRecommenderContext,
        List<TrainingInstance> aIncrementalTrainingData)
{
    // Nothing to do by default — engines supporting incremental training override this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,33 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.model;

import static java.util.Collections.unmodifiableMap;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument;
import de.tudarmstadt.ukp.inception.annotation.events.AnnotationEvent;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance;
import de.tudarmstadt.ukp.inception.rendering.model.Range;

public class DirtySpot
{
private final SourceDocument document;
private final String user;
private final Range affectedRange;
private final Map<Recommender, List<TrainingInstance>> incrementalTrainingData;

/**
 * @param aEvent
 *            the annotation event that rendered the document dirty.
 * @param aIncrementalTrainingData
 *            per-recommender training instances derived from the event.
 */
public DirtySpot(AnnotationEvent aEvent,
        Map<Recommender, List<TrainingInstance>> aIncrementalTrainingData)
{
    document = aEvent.getDocument();
    user = aEvent.getDocumentOwner();
    affectedRange = aEvent.getAffectedRange();
    // Defensive copy (order-preserving) wrapped unmodifiable so the spot stays immutable
    incrementalTrainingData = unmodifiableMap(new LinkedHashMap<>(aIncrementalTrainingData));
}

public Range getAffectedRange()
Expand Down Expand Up @@ -62,6 +71,11 @@ public Project getProject()
return document.getProject();
}

/**
 * @return the per-recommender incremental training data recorded for this dirty spot; the map
 *         is unmodifiable.
 */
public Map<Recommender, List<TrainingInstance>> getIncrementalTrainingData()
{
    return incrementalTrainingData;
}

@Override
public int hashCode()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static java.util.Optional.empty;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
import static org.apache.commons.collections4.CollectionUtils.isEmpty;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
Expand Down Expand Up @@ -110,6 +111,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.model.SuggestionGroup;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingInstance;
import de.tudarmstadt.ukp.inception.recommendation.config.RecommenderServiceAutoConfiguration;
import de.tudarmstadt.ukp.inception.recommendation.event.RecommenderDeletedEvent;
import de.tudarmstadt.ukp.inception.recommendation.event.RecommenderUpdatedEvent;
Expand Down Expand Up @@ -730,7 +732,38 @@ public void onAnnotation(AnnotationEvent aEvent)
requestCycle.setMetaData(DIRTIES, dirties);
}

dirties.add(new DirtySpot(aEvent));
var incrementalTrainingData = new LinkedHashMap<Recommender, List<TrainingInstance>>();
var sessionOwner = userRepository.getCurrentUser();
var recommenders = getActiveRecommenders(sessionOwner, aEvent.getProject()).stream() //
.map(EvaluatedRecommender::getRecommender) //
.filter(Recommender::isEnabled) //
.filter(rec -> rec.getLayer() != null) //
.filter(rec -> rec.getLayer().isEnabled()) //
.filter(rec -> rec.getFeature().isEnabled()) //
.filter(rec -> rec.getLayer().equals(aEvent.getLayer())) //
.toList();

for (var recommender : recommenders) {
try {
var maybeFactory = getRecommenderFactory(recommender);
if (maybeFactory.isEmpty()) {
continue;
}

var factory = maybeFactory.get();
var engine = factory.build(recommender);

incrementalTrainingData.computeIfAbsent(recommender, $ -> new ArrayList<>()) //
.addAll(engine.generateIncrementalTrainingInstances(aEvent));
}
catch (Exception e) {
LOG.warn("Unable to collect incremental training data for active recommender {}",
recommender);
continue;
}
}

dirties.add(new DirtySpot(aEvent, incrementalTrainingData));
}

/*
Expand Down Expand Up @@ -946,15 +979,17 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve
return;
}

var user = userRepository.get(aSessionOwner);
// do not trigger training during when viewing others' work
if (user == null || !user.equals(userRepository.getCurrentUser())) {
// Do not trigger training during when viewing others' work
var sessionOwner = userRepository.get(aSessionOwner);
if (sessionOwner == null || !sessionOwner.equals(userRepository.getCurrentUser())) {
return;
}

commitIncrementalTrainingData(aSessionOwner, aDirties);

// Update the task count
var count = trainingTaskCounter.computeIfAbsent(
new RecommendationStateKey(user.getUsername(), aProject.getId()),
new RecommendationStateKey(aSessionOwner, aProject.getId()),
_key -> new AtomicInteger(0));

// If there is no active recommender at all then let's try hard to make one active by
Expand All @@ -964,42 +999,102 @@ private void triggerTraining(String aSessionOwner, Project aProject, String aEve
}

if (aForceSelection || (count.getAndIncrement() % TRAININGS_PER_SELECTION == 0)) {
// If it is time for a selection task, we just start a selection task.
// The selection task then will start the training once its finished,
// i.e. we do not start it here.
schedulingService.enqueue(SelectionTask.builder() //
.withSessionOwner(user) //
.withProject(aProject) //
.withTrigger(aEventName) //
.withCurrentDocument(aCurrentDocument) //
.withDataOwner(aDataOwner) //
.build());
triggerSelectionRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName);
return;
}

var state = getState(aSessionOwner, aProject);
synchronized (state) {
state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1);
state.setPredictionsSinceLastEvaluation(0);
triggerTrainingRun(sessionOwner, aProject, aCurrentDocument, aDataOwner, aEventName);
}

/**
 * Aggregate the incremental training data carried by the given dirty spots and hand it over to
 * each recommender's engine via its context, so the next training run can pick it up.
 *
 * @param aSessionOwner
 *            the user owning the recommender session.
 * @param aDirties
 *            the dirty spots carrying per-recommender incremental training instances.
 */
private void commitIncrementalTrainingData(String aSessionOwner, Set<DirtySpot> aDirties)
{
    if (isEmpty(aDirties)) {
        return;
    }

    // Merge the per-spot training data into a single list per recommender
    var aggregatedIncrementalTrainingData = new LinkedHashMap<Recommender, List<TrainingInstance>>();
    for (var dirtySpot : aDirties) {
        for (var entry : dirtySpot.getIncrementalTrainingData().entrySet()) {
            var recommender = entry.getKey();
            aggregatedIncrementalTrainingData
                    .computeIfAbsent(recommender, $ -> new ArrayList<>()) //
                    .addAll(entry.getValue());
        }
    }

    for (var entry : aggregatedIncrementalTrainingData.entrySet()) {
        var recommender = entry.getKey();
        // Without a context there is no place to store the data - skip this recommender
        var maybeContext = getContext(aSessionOwner, recommender);
        if (maybeContext.isEmpty()) {
            continue;
        }

        try {
            var maybeFactory = getRecommenderFactory(recommender);
            if (maybeFactory.isEmpty()) {
                continue;
            }

            var factory = maybeFactory.get();
            var engine = factory.build(recommender);
            engine.putIncrementalTrainingData(maybeContext.get(), entry.getValue());
        }
        catch (Exception e) {
            // Best-effort per recommender: log (with stack trace) and carry on with the rest.
            // Note: we are committing here, not collecting — message adjusted accordingly.
            LOG.warn("Unable to commit incremental training data for active recommender {}",
                    recommender, e);
        }
    }
}

/**
 * Schedule a training task for the given session owner unless recommenders are suspended for
 * the project, then advance the evaluation countdown in the recommendation state.
 *
 * @param aSessionOwner
 *            the user owning the recommender session.
 * @param aProject
 *            the project to train in.
 * @param aDocument
 *            the document currently open in the editor.
 * @param aDataOwner
 *            the owner of the annotations to train on.
 * @param aTrigger
 *            a description of the event triggering the run (for logging/diagnostics).
 */
private void triggerTrainingRun(User aSessionOwner, Project aProject, SourceDocument aDocument,
        String aDataOwner, String aTrigger)
{
    if (isSuspended(aSessionOwner.getUsername(), aProject)) {
        return;
    }

    schedulingService.enqueue(TrainingTask.builder() //
            .withSessionOwner(aSessionOwner) //
            .withProject(aProject) //
            .withTrigger(aTrigger) //
            .withCurrentDocument(aDocument) //
            .withDataOwner(aDataOwner) //
            .build());

    var state = getState(aSessionOwner.getUsername(), aProject);
    synchronized (state) {
        // Count this training run towards the next evaluation (selection) cycle
        int predictions = state.getPredictionsSinceLastEvaluation() + 1;
        state.setPredictionsSinceLastEvaluation(predictions);
        state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - predictions - 1);
    }
}

/**
 * Kick off a recommender selection run for the given user and project. The selection task
 * itself triggers a training run once it completes, so no training task is scheduled here.
 *
 * @param aSessionOwner
 *            the user owning the recommender session.
 * @param aProject
 *            the project to evaluate recommenders in.
 * @param aDocument
 *            the document currently open in the editor.
 * @param aDataOwner
 *            the owner of the annotations to evaluate on.
 * @param aEventName
 *            a description of the event triggering the run (for logging/diagnostics).
 */
private void triggerSelectionRun(User aSessionOwner, Project aProject, SourceDocument aDocument,
        String aDataOwner, String aEventName)
{
    var userName = aSessionOwner.getUsername();
    if (isSuspended(userName, aProject)) {
        return;
    }

    // If it is time for a selection task, we just start a selection task.
    // The selection task then will start the training once its finished,
    // i.e. we do not start it here.
    var selectionTask = SelectionTask.builder() //
            .withSessionOwner(aSessionOwner) //
            .withProject(aProject) //
            .withTrigger(aEventName) //
            .withCurrentDocument(aDocument) //
            .withDataOwner(aDataOwner) //
            .build();
    schedulingService.enqueue(selectionTask);

    var state = getState(userName, aProject);
    synchronized (state) {
        // Starting an evaluation resets the countdown until the next one
        state.setPredictionsUntilNextEvaluation(TRAININGS_PER_SELECTION - 1);
        state.setPredictionsSinceLastEvaluation(0);
    }
}

@Override
public List<LogMessageGroup> getLog(String aSessionOwner, Project aProject)
{
Expand Down

0 comments on commit 5066b4a

Please sign in to comment.