Skip to content

Commit

Permalink
Merge pull request #5315 from inception-project/feature/5314-Option-f…
Browse files Browse the repository at this point in the history
…or-the-assistant-to-check-annotations

Issue #5314: Option for the assistant to check annotations
  • Loading branch information
reckart authored Feb 28, 2025
2 parents 98acc2c + 5360ce7 commit d92f02f
Show file tree
Hide file tree
Showing 20 changed files with 814 additions and 67 deletions.
24 changes: 20 additions & 4 deletions inception/inception-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-imls-llm-support</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-scheduling</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-schema-api</artifactId>
Expand All @@ -101,6 +97,10 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-api-editor</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-support-bootstrap</artifactId>
</dependency>

<dependency>
<groupId>com.knuddels</groupId>
Expand Down Expand Up @@ -137,6 +137,14 @@
<groupId>com.networknt</groupId>
<artifactId>json-schema-validator</artifactId>
</dependency>
<dependency>
<groupId>com.github.victools</groupId>
<artifactId>jsonschema-module-jackson</artifactId>
</dependency>
<dependency>
<groupId>com.github.victools</groupId>
<artifactId>jsonschema-generator</artifactId>
</dependency>

<dependency>
<groupId>org.apache.wicket</groupId>
Expand Down Expand Up @@ -166,6 +174,14 @@
<groupId>org.wicketstuff</groupId>
<artifactId>wicketstuff-jquery-ui</artifactId>
</dependency>
<dependency>
<groupId>org.wicketstuff</groupId>
<artifactId>wicketstuff-annotationeventdispatcher</artifactId>
</dependency>
<dependency>
<groupId>org.danekja</groupId>
<artifactId>jdk-serializable-functional</artifactId>
</dependency>

<dependency>
<groupId>org.apache.uima</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.List;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.inception.assistant.model.MCallResponse;
import de.tudarmstadt.ukp.inception.assistant.model.MMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MTextMessage;

public interface AssistantService
Expand All @@ -32,10 +34,22 @@ public interface AssistantService
void processUserMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
MTextMessage... aTransientMessage);

void processAgentMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
MTextMessage... aContextMessages);

MTextMessage processInternalMessageSync(String aSessionOwner, Project aProject,
MTextMessage aMessage)
throws IOException;

<T> MCallResponse<T> processInternalCallSync(String aSessionOwner, Project aProject,
Class<T> aType, MTextMessage aMessage)
throws IOException;

void clearConversation(String aSessionOwner, Project aProject);

void setDebugMode(String aSessionOwner, Project aProject, boolean aObject);

boolean isDebugMode(String aSessionOwner, Project aProject);

void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
package de.tudarmstadt.ukp.inception.assistant;

import static de.tudarmstadt.ukp.inception.assistant.model.MChatRoles.SYSTEM;
import static de.tudarmstadt.ukp.inception.support.json.JSONUtil.toPrettyJsonString;
import static java.lang.Math.floorDiv;
import static java.lang.String.join;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.unmodifiableList;
import static org.apache.commons.lang3.ArrayUtils.isNotEmpty;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -51,6 +54,7 @@
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.assistant.model.MCallResponse;
import de.tudarmstadt.ukp.inception.assistant.model.MChatMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MRemoveConversationCommand;
Expand Down Expand Up @@ -147,6 +151,7 @@ public List<MTextMessage> getAllChatMessages(String aSessionOwner, Project aProj
return state.getMessages().stream() //
.filter(MTextMessage.class::isInstance) //
.map(MTextMessage.class::cast) //
.filter(msg -> state.isDebugMode() || !msg.internal()) //
.toList();
}

Expand All @@ -165,18 +170,19 @@ public List<MTextMessage> getChatMessages(String aSessionOwner, Project aProject

void recordMessage(String aSessionOwner, Project aProject, MChatMessage aMessage)
{
if (!properties.isDevMode() && aMessage.ephemeral()) {
if (!isDebugMode(aSessionOwner, aProject) && aMessage.ephemeral()) {
return;
}

var state = getState(aSessionOwner, aProject);
state.upsertMessage(aMessage);
}

void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
@Override
public void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
{
if (aMessage instanceof MChatMessage chatMessage) {
if (!properties.isDevMode() && chatMessage.internal()) {
if (!isDebugMode(aSessionOwner, aProject) && chatMessage.internal()) {
return;
}
}
Expand All @@ -190,27 +196,109 @@ void dispatchMessage(String aSessionOwner, Project aProject, MMessage aMessage)
public void clearConversation(String aSessionOwner, Project aProject)
{
synchronized (states) {
states.keySet().removeIf(key -> aSessionOwner.equals(key.user())
&& Objects.equals(aProject.getId(), key.projectId));
states.entrySet().stream() //
.filter(e -> aSessionOwner.equals(e.getKey().user())
&& Objects.equals(aProject.getId(), e.getKey().projectId)) //
.map(Entry::getValue) //
.forEach(state -> state.clearMessages());
}

dispatchMessage(aSessionOwner, aProject, new MRemoveConversationCommand());
}

@Override
public void setDebugMode(String aSessionOwner, Project aProject, boolean aOnOff)
{
synchronized (states) {
getState(aSessionOwner, aProject).setDebugMode(aOnOff);
}
}

@Override
public boolean isDebugMode(String aSessionOwner, Project aProject)
{
synchronized (states) {
return getState(aSessionOwner, aProject).isDebugMode();
}
}

@Override
public MTextMessage processInternalMessageSync(String aSessionOwner, Project aProject,
MTextMessage aMessage)
throws IOException
{
Validate.isTrue(aMessage.internal());

if (properties.isDevMode()) {
if (isDebugMode(aSessionOwner, aProject)) {
recordMessage(aSessionOwner, aProject, aMessage);
dispatchMessage(aSessionOwner, aProject, aMessage);
}

var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);
return assistant.generate(asList(aMessage));
return assistant.chat(asList(aMessage));
}

@Override
public <T> MCallResponse<T> processInternalCallSync(String aSessionOwner, Project aProject,
Class<T> aType, MTextMessage aMessage)
throws IOException
{
Validate.isTrue(aMessage.internal());

if (isDebugMode(aSessionOwner, aProject)) {
recordMessage(aSessionOwner, aProject, aMessage);
dispatchMessage(aSessionOwner, aProject, aMessage);
}

var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);
var result = assistant.call(aType, asList(aMessage));

if (isDebugMode(aSessionOwner, aProject)) {
var resultMessage = MTextMessage.builder() //
.withRole(SYSTEM).internal().ephemeral() //
.withActor(aMessage.actor()) //
.withMessage("```json\n" + toPrettyJsonString(result.payload()) + "\n```") //
.withPerformance(result.performance()) //
.build();
recordMessage(aSessionOwner, aProject, resultMessage);
dispatchMessage(aSessionOwner, aProject, resultMessage);
}

return result;
}

@Override
public void processAgentMessage(String aSessionOwner, Project aProject, MTextMessage aMessage,
MTextMessage... aContextMessages)
{
var assistant = new ChatContext(properties, ollamaClient, aSessionOwner, aProject);

// Dispatch message early so the front-end can enter waiting state
dispatchMessage(aSessionOwner, aProject, aMessage);

try {
var systemMessages = generateSystemMessages();

recordMessage(aSessionOwner, aProject, aMessage);

var recentConversation = limitConversationToContextLength(systemMessages, emptyList(),
emptyList(), aMessage, properties.getChat().getContextLength());

var responseMessage = assistant.chat(recentConversation,
(id, r) -> handleStreamedMessageFragment(aSessionOwner, aProject, id, r));

recordMessage(aSessionOwner, aProject, responseMessage);

dispatchMessage(aSessionOwner, aProject, responseMessage.withoutContent());
}
catch (IOException e) {
var errorMessage = MTextMessage.builder() //
.withActor("Error").withRole(SYSTEM).internal().ephemeral() //
.withMessage("Error: " + e.getMessage()) //
.build();
recordMessage(aSessionOwner, aProject, errorMessage);
dispatchMessage(aSessionOwner, aProject, errorMessage);
}
}

@Override
Expand Down Expand Up @@ -238,7 +326,7 @@ public void processUserMessage(String aSessionOwner, Project aProject, MTextMess
// We record the message only now to ensure it is not included in the listMessages above
recordMessage(aSessionOwner, aProject, aMessage);

if (properties.isDevMode()) {
if (isDebugMode(aSessionOwner, aProject)) {
for (var msg : ephemeralMessages) {
recordMessage(aSessionOwner, aProject, msg);
dispatchMessage(aSessionOwner, aProject, msg);
Expand All @@ -249,7 +337,7 @@ public void processUserMessage(String aSessionOwner, Project aProject, MTextMess
ephemeralMessages, conversationMessages, aMessage,
properties.getChat().getContextLength());

var responseMessage = assistant.generate(recentConversation,
var responseMessage = assistant.chat(recentConversation,
(id, r) -> handleStreamedMessageFragment(aSessionOwner, aProject, id, r));

recordMessage(aSessionOwner, aProject, responseMessage);
Expand Down Expand Up @@ -448,10 +536,28 @@ private void clearState(String aSessionOwner)
private static class AssistentState
{
private LinkedList<MMessage> messages = new LinkedList<>();
private boolean debugMode;

public List<MMessage> getMessages()
{
return new ArrayList<>(messages);
return unmodifiableList(new ArrayList<>(messages));
}

public void clearMessages()
{
synchronized (messages) {
messages.clear();
}
}

public void setDebugMode(boolean aOnOff)
{
debugMode = aOnOff;
}

public boolean isDebugMode()
{
return debugMode;
}

public void upsertMessage(MMessage aMessage)
Expand Down
Loading

0 comments on commit d92f02f

Please sign in to comment.