Skip to content

Commit

Permalink
Issue #5314: Option for the assistant to check annotations
Browse files Browse the repository at this point in the history
- Added check option accessible by right-clicking on a span annotation
- Added watch mode that continually monitors new annotations and comments if it thinks something needs to be changed
- Added new queuing mode to the scheduling service which just executes tasks one after the other
- Added debug mode toggle to the assistant sidebar
  • Loading branch information
reckart committed Feb 28, 2025
1 parent 98acc2c commit 676d54f
Show file tree
Hide file tree
Showing 20 changed files with 798 additions and 63 deletions.
4 changes: 4 additions & 0 deletions inception/inception-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@
<groupId>org.wicketstuff</groupId>
<artifactId>wicketstuff-jquery-ui</artifactId>
</dependency>
<dependency>
<groupId>org.wicketstuff</groupId>
<artifactId>wicketstuff-annotationeventdispatcher</artifactId>
</dependency>

<dependency>
<groupId>org.apache.uima</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.List;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.inception.assistant.model.MCallResponse;
import de.tudarmstadt.ukp.inception.assistant.model.MMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;

public interface AssistantService
Expand All @@ -32,10 +34,22 @@ public interface AssistantService
/**
 * Processes a message written by the user in the conversation belonging to the given session
 * owner and project.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aMessage
 *            the user message to process.
 * @param aTransientMessage
 *            optional additional messages (may be omitted). NOTE(review): presumably these are
 *            only used for this request and are not a permanent part of the conversation -
 *            confirm against the implementation.
 */
void processUserMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
MTextMessage... aTransientMessage);

/**
 * Processes a message that originates from an agent rather than from the user. The message is
 * dispatched to the front-end, recorded and sent to the assistant; the assistant response is
 * recorded and dispatched as well.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aMessage
 *            the agent message to process.
 * @param aContextMessages
 *            optional additional context messages (may be omitted).
 */
void processAgentMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
MTextMessage... aContextMessages);

/**
 * Sends an internal message to the assistant and waits for the response.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aMessage
 *            the message to send. NOTE(review): the implementation validates that the message
 *            is flagged as internal - confirm this is part of the contract.
 * @return the response message produced by the assistant.
 * @throws IOException
 *             if the communication with the assistant fails.
 */
MTextMessage processInternalMessageSync(String aSessionOwner, Project aProject,
MTextMessage aMessage)
throws IOException;

/**
 * Sends an internal message to the assistant and waits for a structured response whose payload
 * is an instance of the given type.
 *
 * @param <T>
 *            the payload type.
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aType
 *            the class describing the expected payload.
 * @param aMessage
 *            the message to send. NOTE(review): the implementation validates that the message
 *            is flagged as internal - confirm this is part of the contract.
 * @return the response including the payload and performance metrics.
 * @throws IOException
 *             if the communication with the assistant fails.
 */
<T> MCallResponse<T> processInternalCallSync(String aSessionOwner, Project aProject,
Class<T> aType, MTextMessage aMessage)
throws IOException;

/**
 * Clears the conversation of the given session owner in the given project.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 */
void clearConversation(String aSessionOwner, Project aProject);

/**
 * Turns the debug mode of the assistant session on or off. While debug mode is enabled,
 * ephemeral messages are recorded and internal messages are dispatched instead of being
 * dropped.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aOnOff
 *            {@code true} to enable debug mode, {@code false} to disable it.
 */
void setDebugMode(String aSessionOwner, Project aProject, boolean aOnOff);

/**
 * Reports whether debug mode is enabled for the assistant session of the given session owner
 * in the given project.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @return {@code true} if debug mode is enabled.
 */
boolean isDebugMode(String aSessionOwner, Project aProject);

/**
 * Dispatches a message for the conversation of the given session owner in the given project.
 * NOTE(review): the implementation drops internal chat messages unless debug mode is enabled;
 * how the remaining messages are delivered is not visible here - confirm.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aMessage
 *            the message to dispatch.
 */
void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
package de.tudarmstadt.ukp.inception.assistant;

import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.SYSTEM;
import static de.tudarmstadt.ukp.inception.support.json.JSONUtil.toPrettyJsonString;
import static java.lang.Math.floorDiv;
import static java.lang.String.join;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.unmodifiableList;
import static org.apache.commons.lang3.ArrayUtils.isNotEmpty;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -51,6 +54,7 @@
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.assistant.model.MCallResponse;
import de.tudarmstadt.ukp.inception.assistant.model.MChatMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MRemoveConversationCommand;
Expand Down Expand Up @@ -147,6 +151,7 @@ public List<MTextMessage> getAllChatMessages(String aSessionOwner, Project aProj
return state.getMessages().stream() //
.filter(MTextMessage.class::isInstance) //
.map(MTextMessage.class::cast) //
.filter(msg -> state.isDebugMode() || !msg.internal()) //
.toList();
}

Expand All @@ -165,18 +170,19 @@ public List<MTextMessage> getChatMessages(String aSessionOwner, Project aProject

void recordMessage(String aSessionOwner, Project aProject, MChatMessage aMessage)
{
if (!properties.isDevMode() && aMessage.ephemeral()) {
if (!isDebugMode(aSessionOwner, aProject) && aMessage.ephemeral()) {
return;
}

var state = getState(aSessionOwner, aProject);
state.upsertMessage(aMessage);
}

void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
@Override
public void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
{
if (aMessage instanceof MChatMessage chatMessage) {
if (!properties.isDevMode() && chatMessage.internal()) {
if (!isDebugMode(aSessionOwner, aProject) && chatMessage.internal()) {
return;
}
}
Expand All @@ -190,27 +196,109 @@ void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
public void clearConversation(String aSessionOwner, Project aProject)
{
synchronized (states) {
states.keySet().removeIf(key -> aSessionOwner.equals(key.user())
&& Objects.equals(aProject.getId(), key.projectId));
states.entrySet().stream() //
.filter(e -> aSessionOwner.equals(e.getKey().user())
&& Objects.equals(aProject.getId(), e.getKey().projectId)) //
.map(Entry::getValue) //
.forEach(state -> state.clearMessages());
}

dispatchMessage(aSessionOwner, aProject, new MRemoveConversationCommand());
}

/**
 * Stores the debug mode flag in the state associated with the given session owner and project.
 * The state lookup and the update both happen while holding the lock on the state map.
 */
@Override
public void setDebugMode(String aSessionOwner, Project aProject, boolean aOnOff)
{
    synchronized (states) {
        var sessionState = getState(aSessionOwner, aProject);
        sessionState.setDebugMode(aOnOff);
    }
}

/**
 * Reads the debug mode flag from the state associated with the given session owner and
 * project. The state lookup and the read both happen while holding the lock on the state map.
 */
@Override
public boolean isDebugMode(String aSessionOwner, Project aProject)
{
    synchronized (states) {
        var sessionState = getState(aSessionOwner, aProject);
        return sessionState.isDebugMode();
    }
}

@Override
public MTextMessage processInternalMessageSync(String aSessionOwner, Project aProject,
MTextMessage aMessage)
throws IOException
{
Validate.isTrue(aMessage.internal());

if (properties.isDevMode()) {
if (isDebugMode(aSessionOwner, aProject)) {
recordMessage(aSessionOwner, aProject, aMessage);
dispatchMessage(aSessionOwner, aProject, aMessage);
}

var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);
return assistant.generate(asList(aMessage));
return assistant.chat(asList(aMessage));
}

/**
 * Sends an internal message to the assistant and returns the structured response. In debug
 * mode, the request and a pretty-printed JSON rendering of the response payload are also
 * recorded and dispatched so that they show up in the conversation.
 *
 * @param aType
 *            the class describing the expected payload.
 * @param aMessage
 *            the message to send; must be flagged as internal.
 * @return the structured response produced by the assistant.
 * @throws IOException
 *             if the call to the assistant or the rendering of the payload fails.
 */
@Override
public <T> MCallResponse<T> processInternalCallSync(String aSessionOwner, Project aProject,
        Class<T> aType, MTextMessage aMessage)
    throws IOException
{
    Validate.isTrue(aMessage.internal());

    // The internal request only becomes part of the conversation while debugging
    if (isDebugMode(aSessionOwner, aProject)) {
        recordMessage(aSessionOwner, aProject, aMessage);
        dispatchMessage(aSessionOwner, aProject, aMessage);
    }

    var chatContext = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);
    var callResponse = chatContext.call(aType, asList(aMessage));

    // The debug flag is deliberately read again after the call has completed
    if (!isDebugMode(aSessionOwner, aProject)) {
        return callResponse;
    }

    var renderedPayload = "```json\n" + toPrettyJsonString(callResponse.payload()) + "\n```";
    var debugMessage = MTextMessage.builder() //
            .withRole(SYSTEM).internal().ephemeral() //
            .withActor(aMessage.actor()) //
            .withMessage(renderedPayload) //
            .withPerformance(callResponse.performance()) //
            .build();
    recordMessage(aSessionOwner, aProject, debugMessage);
    dispatchMessage(aSessionOwner, aProject, debugMessage);

    return callResponse;
}

/**
 * Processes a message originating from an agent: the message is dispatched and recorded, sent
 * to the assistant together with the system messages and the given context messages, and the
 * assistant response is recorded and dispatched.
 *
 * @param aSessionOwner
 *            the user owning the assistant session.
 * @param aProject
 *            the project the conversation belongs to.
 * @param aMessage
 *            the agent message to process.
 * @param aContextMessages
 *            optional additional messages that are passed to the assistant along with the agent
 *            message; may be omitted or {@code null}.
 */
@Override
public void processAgentMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
        MTextMessage... aContextMessages)
{
    var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);

    // Dispatch message early so the front-end can enter waiting state
    dispatchMessage(aSessionOwner, aProject, aMessage);

    try {
        var systemMessages = generateSystemMessages();

        recordMessage(aSessionOwner, aProject, aMessage);

        // Forward the caller-provided context instead of silently ignoring it. The varargs
        // array may be absent or null, in which case no context is added.
        List<MTextMessage> contextMessages = isNotEmpty(aContextMessages)
                ? asList(aContextMessages)
                : emptyList();

        var recentConversation = limitConversationToContextLength(systemMessages,
                contextMessages, emptyList(), aMessage,
                properties.getChat().getContextLength());

        // Response fragments are streamed to the front-end while they are generated
        var responseMessage = assistant.chat(recentConversation,
                (id, r) -> handleStreamedMessageFragment(aSessionOwner, aProject, id, r));

        recordMessage(aSessionOwner, aProject, responseMessage);

        // The content was already streamed, so the final message is sent without it
        dispatchMessage(aSessionOwner, aProject, responseMessage.withoutContent());
    }
    catch (IOException e) {
        // Report the failure as an internal, ephemeral message
        var errorMessage = MTextMessage.builder() //
                .withActor("Error").withRole(SYSTEM).internal().ephemeral() //
                .withMessage("Error: " + e.getMessage()) //
                .build();
        recordMessage(aSessionOwner, aProject, errorMessage);
        dispatchMessage(aSessionOwner, aProject, errorMessage);
    }
}

@Override
Expand Down Expand Up @@ -238,7 +326,7 @@ public void processUserMessage(String aSessionOwner, Project aProject, MTextMess
// We record the message only now to ensure it is not included in the listMessages above
recordMessage(aSessionOwner, aProject, aMessage);

if (properties.isDevMode()) {
if (isDebugMode(aSessionOwner, aProject)) {
for (var msg : ephemeralMessages) {
recordMessage(aSessionOwner, aProject, msg);
dispatchMessage(aSessionOwner, aProject, msg);
Expand All @@ -249,7 +337,7 @@ public void processUserMessage(String aSessionOwner, Project aProject, MTextMess
ephemeralMessages, conversationMessages, aMessage,
properties.getChat().getContextLength());

var responseMessage = assistant.generate(recentConversation,
var responseMessage = assistant.chat(recentConversation,
(id, r) -> handleStreamedMessageFragment(aSessionOwner, aProject, id, r));

recordMessage(aSessionOwner, aProject, responseMessage);
Expand Down Expand Up @@ -448,10 +536,28 @@ private void clearState(String aSessionOwner)
private static class AssistentState
{
private LinkedList<MMessage> messages = new LinkedList<>();
private boolean debugMode;

public List<MMessage> getMessages()
{
return new ArrayList<>(messages);
return unmodifiableList(new ArrayList<>(messages));
}

/**
 * Removes all messages from this state while holding the lock on the message list.
 */
public void clearMessages()
{
    synchronized (this.messages) {
        this.messages.clear();
    }
}

/**
 * Sets the debug mode flag of this state.
 *
 * @param aOnOff
 *            the new value of the flag.
 */
public void setDebugMode(boolean aOnOff)
{
    this.debugMode = aOnOff;
}

/**
 * @return the current value of the debug mode flag of this state.
 */
public boolean isDebugMode()
{
    return this.debugMode;
}

public void upsertMessage(MMessage aMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package de.tudarmstadt.ukp.inception.assistant;

import static com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON;
import static com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12;
import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.ASSISTANT;
import static java.lang.System.currentTimeMillis;

Expand All @@ -26,8 +28,15 @@
import java.util.UUID;
import java.util.function.BiConsumer;

import com.github.victools.jsonschema.generator.Option;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.assistant.model.MCallResponse;
import de.tudarmstadt.ukp.inception.assistant.model.MPerformanceMetrics;
import de.tudarmstadt.ukp.inception.assistant.model.MReference;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;
Expand All @@ -37,13 +46,15 @@
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaChatResponse;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaOptions;
import de.tudarmstadt.ukp.inception.support.json.JSONUtil;

public class ChatContext
{
private final AssistantProperties properties;
private final OllamaClient ollamaClient;
private final String sessionOwner;
private final Project project;
private SchemaGenerator generator;

public ChatContext(AssistantProperties aProperties, OllamaClient aOllamaClient,
String aSessionOwner, Project aProject)
Expand All @@ -52,6 +63,10 @@ public ChatContext(AssistantProperties aProperties, OllamaClient aOllamaClient,
ollamaClient = aOllamaClient;
sessionOwner = aSessionOwner;
project = aProject;
generator = new SchemaGenerator(new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON) //
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT) //
.with(new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED)) //
.build());
}

public Project getProject()
Expand All @@ -64,12 +79,12 @@ public String getSessionOwner()
return sessionOwner;
}

public MTextMessage generate(List<MTextMessage> aMessasges) throws IOException
public MTextMessage chat(List<MTextMessage> aMessasges) throws IOException
{
return generate(aMessasges, null);
return chat(aMessasges, null);
}

public MTextMessage generate(List<MTextMessage> aMessasges,
public MTextMessage chat(List<MTextMessage> aMessasges,
BiConsumer<UUID, MTextMessage> aCallback)
throws IOException
{
Expand Down Expand Up @@ -121,6 +136,55 @@ public MTextMessage generate(List<MTextMessage> aMessasges,
.build();
}

/**
 * Sends the given messages to the chat model and requests a response that conforms to a JSON
 * schema generated from the given result class. The response content is then deserialized
 * into an instance of that class.
 *
 * @param aResult
 *            the class describing the expected payload.
 * @param aMessasges
 *            the messages to send to the model.
 * @return the response carrying the payload, performance metrics and the references collected
 *         from the given messages.
 * @throws IOException
 *             if the request fails or the response cannot be deserialized.
 */
public <T> MCallResponse<T> call(Class<T> aResult, List<MTextMessage> aMessasges)
    throws IOException
{
    var responseSchema = generator.generateSchema(aResult);

    var id = UUID.randomUUID();
    var chat = properties.getChat();

    var ollamaMessages = aMessasges.stream() //
            .map(msg -> new OllamaChatMessage(msg.role(), msg.message())) //
            .toList();

    var chatRequest = OllamaChatRequest.builder() //
            .withModel(chat.getModel()) //
            .withStream(true) //
            .withMessages(ollamaMessages) //
            .withFormat(responseSchema) //
            .withOption(OllamaOptions.NUM_CTX, chat.getContextLength()) //
            .withOption(OllamaOptions.TOP_P, chat.getTopP()) //
            .withOption(OllamaOptions.TOP_K, chat.getTopK()) //
            .withOption(OllamaOptions.REPEAT_PENALTY, chat.getRepeatPenalty()) //
            .withOption(OllamaOptions.TEMPERATURE, chat.getTemperature()) //
            .build();

    // Collect the references of all messages; a later reference replaces an earlier one
    // with the same ID while keeping the position of the first occurrence
    var referencesById = new LinkedHashMap<String, MReference>();
    for (var message : aMessasges) {
        for (var reference : message.references()) {
            referencesById.put(reference.id(), reference);
        }
    }

    // Generate the actual response and measure how long it takes
    var begin = currentTimeMillis();
    var chatResponse = ollamaClient.chat(properties.getUrl(), chatRequest, null);
    var tokenCount = chatResponse.getEvalCount();
    var end = currentTimeMillis();

    var payload = JSONUtil.fromJsonString(aResult, chatResponse.getMessage().content());

    var performance = MPerformanceMetrics.builder() //
            .withDuration(end - begin) //
            .withTokens(tokenCount) //
            .build();

    // Build the complete response including the final metrics and all references
    return MCallResponse.builder(aResult) //
            .withId(id) //
            .withActor(properties.getNickname()) //
            .withRole(ASSISTANT) //
            .withPayload(payload) //
            .withPerformance(performance) //
            .withReferences(referencesById.values()) //
            .build();
}

private void streamMessage(BiConsumer<UUID, MTextMessage> aCallback, UUID responseId,
OllamaChatResponse msg)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ public interface AssistantProperties

AssistantEmbeddingProperties getEmbedding();

boolean isDevMode();

String getNickname();

AssitantUserGuideProperties getUserGuide();
Expand Down
Loading

0 comments on commit 676d54f

Please sign in to comment.