Skip to content

add example for LocalAi #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions local-ai-example/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-examples</artifactId>
        <version>0.30.0</version>
    </parent>

    <artifactId>local-ai-example</artifactId>
    <packaging>jar</packaging>

    <name>local-ai-example</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Single place to keep every langchain4j artifact on the same release. -->
        <langchain4j.version>0.30.0</langchain4j.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <!-- 3.8.1 predates JUnit 4 annotations; 4.13.2 is the final JUnit 4 release. -->
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>
        <!-- "compile" is the default scope, so it is omitted below. -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-local-ai</artifactId>
            <!-- Was hard-coded to 0.29.1, out of step with the other 0.30.0 modules. -->
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>com.github.docker-java</groupId>
            <artifactId>docker-java-api</artifactId>
            <version>3.3.6</version>
        </dependency>
        <dependency>
            <groupId>org.testcontainers</groupId>
            <artifactId>testcontainers</artifactId>
            <version>1.19.7</version>
        </dependency>
    </dependencies>
</project>
99 changes: 99 additions & 0 deletions local-ai-example/src/main/java/AbstractLocalAiInfrastructure.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.command.InspectContainerResponse;
import com.github.dockerjava.api.model.Image;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.utility.DockerImageName;
import org.testcontainers.utility.MountableFile;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Base class for the LocalAI examples: starts a LocalAI Testcontainers container once per JVM,
 * provisions it with a chat model and an embedding model on first run, and commits the
 * provisioned container to a local image ("tc-local-ai") so subsequent runs skip the slow
 * model downloads. Subclasses use {@code localAi.getBaseUrl()} to reach the server.
 *
 * Declared {@code abstract}: it exists only to be extended by the example classes and should
 * never be instantiated directly.
 */
public abstract class AbstractLocalAiInfrastructure {

    // Upstream image, used only until a provisioned local image has been committed.
    private static final String LOCAL_AI_IMAGE = "localai/localai:latest";

    // Repository name of the locally committed, pre-provisioned image.
    private static final String LOCAL_IMAGE_NAME = "tc-local-ai";

    // Local image reference, tagged with the same version part as the upstream image.
    private static final String LOCAL_LOCAL_AI_IMAGE = String.format("%s:%s", LOCAL_IMAGE_NAME, DockerImageName.parse(LOCAL_AI_IMAGE).getVersionPart());

    // Commands run inside a fresh container to download the chat model (ggml-gpt4all-j)
    // and the embedding model (ggml-model-q4_0) into LocalAI's models directory.
    private static final List<String[]> CMDS = Arrays.asList(
            new String[]{"curl", "-o", "/build/models/ggml-gpt4all-j", "https://gpt4all.io/models/ggml-gpt4all-j.bin"},
            new String[]{"curl", "-Lo", "/build/models/ggml-model-q4_0", "https://huggingface.co/LangChain4j/localai-embeddings/resolve/main/ggml-model-q4_0"});

    // Shared container instance; started eagerly so every example subclass can use it.
    static final LocalAiContainer localAi;

    static {
        // Prefer the provisioned local image when it exists, otherwise fall back to upstream.
        localAi = new LocalAiContainer(new LocalAi(LOCAL_AI_IMAGE, LOCAL_LOCAL_AI_IMAGE).resolve());
        localAi.start();
        createImage(localAi, LOCAL_LOCAL_AI_IMAGE);
    }

    /**
     * Commits the running container to {@code localImageName} unless the container was
     * already started from that image, or the image already exists locally.
     */
    static void createImage(GenericContainer<?> container, String localImageName) {
        DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName());
        if (!dockerImageName.equals(DockerImageName.parse(localImageName))) {
            DockerClient dockerClient = DockerClientFactory.instance().client();
            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
            if (images.isEmpty()) {
                DockerImageName imageModel = DockerImageName.parse(localImageName);
                // Empty sessionId label prevents Testcontainers' Ryuk reaper from
                // deleting the committed image when this JVM exits.
                dockerClient.commitCmd(container.getContainerId())
                        .withRepository(imageModel.getUnversionedPart())
                        .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
                        .withTag(imageModel.getVersionPart())
                        .exec();
            }
        }
    }

    /** LocalAI container exposing the HTTP API on port 8080. */
    static class LocalAiContainer extends GenericContainer<LocalAiContainer> {

        public LocalAiContainer(DockerImageName image) {
            super(image);
            withExposedPorts(8080);
            // Never try to pull the locally committed image from a registry.
            withImagePullPolicy(dockerImageName -> !dockerImageName.getUnversionedPart().startsWith(LOCAL_IMAGE_NAME));
        }

        @Override
        protected void containerIsStarted(InspectContainerResponse containerInfo) {
            // Only provision when running from the upstream image; the local image
            // already contains the models and config.
            if (!DockerImageName.parse(getDockerImageName()).equals(DockerImageName.parse(LOCAL_LOCAL_AI_IMAGE))) {
                try {
                    for (String[] cmd : CMDS) {
                        execInContainer(cmd);
                    }
                    copyFileToContainer(MountableFile.forClasspathResource("ggml-model-q4_0.yaml"), "/build/models/ggml-model-q4_0.yaml");
                } catch (IOException e) {
                    throw new RuntimeException("Error downloading the model", e);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag before surfacing the failure.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Error downloading the model", e);
                }
            }
        }

        /** Base URL of the LocalAI HTTP API, e.g. {@code http://localhost:32768}. */
        public String getBaseUrl() {
            return "http://" + getHost() + ":" + getMappedPort(8080);
        }
    }

    /** Resolves which image to start from: the provisioned local one if present, else upstream. */
    static class LocalAi {

        private final String baseImage;

        private final String localImageName;

        LocalAi(String baseImage, String localImageName) {
            this.baseImage = baseImage;
            this.localImageName = localImageName;
        }

        protected DockerImageName resolve() {
            DockerImageName dockerImageName = DockerImageName.parse(this.baseImage);
            DockerClient dockerClient = DockerClientFactory.instance().client();
            List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(this.localImageName).exec();
            if (images.isEmpty()) {
                return dockerImageName;
            }
            return DockerImageName.parse(this.localImageName);
        }

    }

}
37 changes: 37 additions & 0 deletions local-ai-example/src/main/java/LocalAiChatModelExamples.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.output.Response;

import java.util.Collections;
import java.util.List;

/**
 * Chat-model examples running against the LocalAI container started by
 * {@link AbstractLocalAiInfrastructure}.
 */
public class LocalAiChatModelExamples extends AbstractLocalAiInfrastructure {

    // Shared chat model pointing at the running LocalAI container.
    static ChatLanguageModel model = LocalAiChatModel.builder()
            .baseUrl(localAi.getBaseUrl())
            .modelName("ggml-gpt4all-j")
            .maxTokens(3)
            .logRequests(true)
            .logResponses(true)
            .build();

    static class Simple_Prompt {
        public static void main(String[] args) {
            // One-shot generation from a plain string prompt.
            String reply = model.generate("better go home and weave a net than to stand by the pond longing for fish.");
            System.out.println(reply);
        }
    }

    static class Simple_Message_Prompt {
        public static void main(String[] args) {
            // Same prompt, but wrapped as an explicit single-element message list.
            List<ChatMessage> conversation = Collections.singletonList(
                    UserMessage.from("better go home and weave a net than to stand by the pond longing for fish."));
            Response<AiMessage> reply = model.generate(conversation);
            System.out.println(reply);
        }
    }
}
37 changes: 37 additions & 0 deletions local-ai-example/src/main/java/LocalAiEmbeddingModelExamples.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
import dev.langchain4j.model.output.Response;
import org.testcontainers.shaded.com.google.common.collect.Lists;

import java.util.Arrays;
import java.util.List;

/**
 * Embedding-model examples running against the LocalAI container started by
 * {@link AbstractLocalAiInfrastructure}.
 */
public class LocalAiEmbeddingModelExamples extends AbstractLocalAiInfrastructure {

    // Shared embedding model pointing at the running LocalAI container.
    static EmbeddingModel embeddingModel = LocalAiEmbeddingModel.builder()
            .baseUrl(localAi.getBaseUrl())
            .modelName("ggml-model-q4_0")
            .logRequests(true)
            .logResponses(true)
            .build();

    static class Simple_Embed {
        public static void main(String[] args) {
            // Embed a single string and print the resulting vector.
            Response<Embedding> response = embeddingModel.embed("better go home and weave a net than to stand by the pond longing for fish.");

            System.out.println(response.content());
        }
    }

    static class List_Embed {
        public static void main(String[] args) {
            TextSegment textSegment1 = TextSegment.from("better go home and weave a net than ");
            TextSegment textSegment2 = TextSegment.from("to stand by the pond longing for fish.");
            // Use java.util.Arrays rather than Testcontainers' shaded Guava
            // (org.testcontainers.shaded.*): that relocated package is an internal
            // implementation detail, not a supported API, and may vanish between versions.
            Response<List<Embedding>> listResponse = embeddingModel.embedAll(Arrays.asList(textSegment1, textSegment2));

            listResponse.content().stream().map(Embedding::dimension).forEach(System.out::println);
        }
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiStreamingChatModel;
import dev.langchain4j.model.output.Response;

/**
 * Streaming chat example running against the LocalAI container started by
 * {@link AbstractLocalAiInfrastructure}.
 */
public class LocalAiStreamingChatModelExamples extends AbstractLocalAiInfrastructure {

    // Shared streaming chat model pointing at the running LocalAI container.
    static StreamingChatLanguageModel model = LocalAiStreamingChatModel.builder()
            .baseUrl(localAi.getBaseUrl())
            .modelName("ggml-gpt4all-j")
            .maxTokens(50)
            .logRequests(true)
            .logResponses(true)
            .build();

    static class Simple_Prompt {
        public static void main(String[] args) {
            // Handler that logs each streamed token, the final response, and any failure.
            StreamingResponseHandler<AiMessage> handler = new StreamingResponseHandler<AiMessage>() {

                @Override
                public void onNext(String token) {
                    System.out.println("onNext(): " + token);
                }

                @Override
                public void onComplete(Response<AiMessage> response) {
                    System.out.println("onComplete(): " + response);
                }

                @Override
                public void onError(Throwable error) {
                    error.printStackTrace();
                }
            };

            model.generate("Tell me a poem by Li Bai", handler);
        }
    }
}