Skip to content

Commit

Permalink
Support for #13 and version bump, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tghamm committed Mar 5, 2024
1 parent affaa54 commit c39f451
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 148 deletions.
161 changes: 76 additions & 85 deletions Anthropic.SDK.Tests/Messages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,56 +39,49 @@ public async Task TestBasicClaude3ImageMessage()
{
string resourceName = "Anthropic.SDK.Tests.Red_Apple.jpg";

// Get the current assembly
Assembly assembly = Assembly.GetExecutingAssembly();

// Get a stream to the embedded resource
using (Stream stream = assembly.GetManifestResourceStream(resourceName))
await using Stream stream = assembly.GetManifestResourceStream(resourceName);
byte[] imageBytes;
using (var memoryStream = new MemoryStream())
{
// Read the stream into a byte array
byte[] imageBytes;
using (var memoryStream = new MemoryStream())
{
stream.CopyTo(memoryStream);
imageBytes = memoryStream.ToArray();
}

// Convert the byte array to a base64 string
string base64String = Convert.ToBase64String(imageBytes);
await stream.CopyToAsync(memoryStream);
imageBytes = memoryStream.ToArray();
}

string base64String = Convert.ToBase64String(imageBytes);

var client = new AnthropicClient();
var messages = new List<Message>();
messages.Add(new Message()
var client = new AnthropicClient();

var messages = new List<Message>();
messages.Add(new Message()
{
Role = RoleType.User,
Content = new dynamic[]
{
Role = RoleType.User,
Content = new dynamic[]
new ImageContent()
{
new ImageContent()
Source = new ImageSource()
{
Source = new ImageSource()
{
MediaType = "image/jpeg",
Data = base64String
}
},
new TextContent()
{
Text = "What is this a picture of?"
MediaType = "image/jpeg",
Data = base64String
}
},
new TextContent()
{
Text = "What is this a picture of?"
}
});
var parameters = new MessageParameters()
{
Messages = messages,
MaxTokens = 512,
Model = AnthropicModels.Claude3Sonnet,
Stream = false,
Temperature = 1.0m,
};
var res = await client.Messages.GetClaudeMessageAsync(parameters);

}

}
});
var parameters = new MessageParameters()
{
Messages = messages,
MaxTokens = 512,
Model = AnthropicModels.Claude3Opus,
Stream = false,
Temperature = 1.0m,
};
var res = await client.Messages.GetClaudeMessageAsync(parameters);
}

[TestMethod]
Expand All @@ -100,62 +93,60 @@ public async Task TestStreamingClaude3ImageMessage()
Assembly assembly = Assembly.GetExecutingAssembly();

// Get a stream to the embedded resource
using (Stream stream = assembly.GetManifestResourceStream(resourceName))
await using Stream stream = assembly.GetManifestResourceStream(resourceName);
// Read the stream into a byte array
byte[] imageBytes;
using (var memoryStream = new MemoryStream())
{
// Read the stream into a byte array
byte[] imageBytes;
using (var memoryStream = new MemoryStream())
{
stream.CopyTo(memoryStream);
imageBytes = memoryStream.ToArray();
}
await stream.CopyToAsync(memoryStream);
imageBytes = memoryStream.ToArray();
}

// Convert the byte array to a base64 string
string base64String = Convert.ToBase64String(imageBytes);
// Convert the byte array to a base64 string
string base64String = Convert.ToBase64String(imageBytes);

var client = new AnthropicClient();
var messages = new List<Message>();
messages.Add(new Message()
var client = new AnthropicClient();
var messages = new List<Message>();
messages.Add(new Message()
{
Role = RoleType.User,
Content = new dynamic[]
{
Role = RoleType.User,
Content = new dynamic[]
new ImageContent()
{
new ImageContent()
Source = new ImageSource()
{
Source = new ImageSource()
{
MediaType = "image/jpeg",
Data = base64String
}
},
new TextContent()
{
Text = "What is this a picture of?"
MediaType = "image/jpeg",
Data = base64String
}
}
});
var parameters = new MessageParameters()
{
Messages = messages,
MaxTokens = 512,
Model = AnthropicModels.Claude3Sonnet,
Stream = true,
Temperature = 1.0m,
};
var outputs = new List<MessageResponse>();
await foreach (var res in client.Messages.StreamClaudeMessageAsync(parameters))
{
if (res.Delta != null)
},
new TextContent()
{
Debug.Write(res.Delta.Text);
Text = "What is this a picture of?"
}

outputs.Add(res);
}

});
var parameters = new MessageParameters()
{
Messages = messages,
MaxTokens = 512,
Model = AnthropicModels.Claude3Opus,
Stream = true,
Temperature = 1.0m,
};
var outputs = new List<MessageResponse>();
await foreach (var res in client.Messages.StreamClaudeMessageAsync(parameters))
{
if (res.Delta != null)
{
Debug.Write(res.Delta.Text);
}

outputs.Add(res);
}

Debug.WriteLine(string.Empty);
Debug.WriteLine($@"Used Tokens - Input:{outputs.First().StreamStartMessage.Usage.InputTokens}.
Output: {outputs.Last().Usage.OutputTokens}");
}
}
}
8 changes: 4 additions & 4 deletions Anthropic.SDK/Anthropic.SDK.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
<PackageTags>Claude, AI, ML, API, Anthropic</PackageTags>
<Title>Claude API</Title>
<PackageReleaseNotes>
Adds token counter extension method helper.
Support for Messages Endpoint and Claude 3
</PackageReleaseNotes>
<PackageId>Anthropic.SDK</PackageId>
<Version>1.3.0</Version>
<AssemblyVersion>1.3.0.0</AssemblyVersion>
<FileVersion>1.3.0.0</FileVersion>
<Version>2.0.0</Version>
<AssemblyVersion>2.0.0.0</AssemblyVersion>
<FileVersion>2.0.0.0</FileVersion>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<PackageReadmeFile>README.md</PackageReadmeFile>
<ProduceReferenceAssembly>True</ProduceReferenceAssembly>
Expand Down
12 changes: 8 additions & 4 deletions Anthropic.SDK/Completions/CompletionsEndpoint.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace Anthropic.SDK.Completions
Expand All @@ -19,23 +21,25 @@ internal CompletionsEndpoint(AnthropicClient client) : base(client) { }
/// Makes a non-streaming call to the Claude completion API. Be sure to set stream to false in <param name="parameters"></param>.
/// </summary>
/// <param name="parameters"></param>
public async Task<CompletionResponse> GetClaudeCompletionAsync(SamplingParameters parameters)
/// <param name="ctx"></param>
public async Task<CompletionResponse> GetClaudeCompletionAsync(SamplingParameters parameters, CancellationToken ctx = default)
{
parameters.Stream = false;
ValidateParameters(parameters);
var response = await HttpRequest<CompletionResponse>(Url, HttpMethod.Post, parameters);
var response = await HttpRequest<CompletionResponse>(Url, HttpMethod.Post, parameters, ctx);
return response;
}

/// <summary>
/// Makes a streaming call to the Claude completion API using an IAsyncEnumerable. Be sure to set stream to true in <param name="parameters"></param>.
/// </summary>
/// <param name="parameters"></param>
public async IAsyncEnumerable<CompletionResponse> StreamClaudeCompletionAsync(SamplingParameters parameters)
/// <param name="ctx"></param>
public async IAsyncEnumerable<CompletionResponse> StreamClaudeCompletionAsync(SamplingParameters parameters, [EnumeratorCancellation] CancellationToken ctx = default)
{
parameters.Stream = true;
ValidateParameters(parameters);
await foreach (var result in HttpStreamingRequest<CompletionResponse>(Url, HttpMethod.Post, parameters))
await foreach (var result in HttpStreamingRequest<CompletionResponse>(Url, HttpMethod.Post, parameters, ctx))
{
yield return result;
}
Expand Down
56 changes: 38 additions & 18 deletions Anthropic.SDK/EndpointBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
using System.IO;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Security.Authentication;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;

namespace Anthropic.SDK
Expand Down Expand Up @@ -69,18 +71,22 @@ private string GetErrorMessage(string resultAsString, HttpResponseMessage respon
return $"Error at {name} ({description}) with HTTP status code: {response.StatusCode}. Content: {resultAsString ?? "<no content>"}";
}

protected async Task<T> HttpRequest<T>(string url = null, HttpMethod verb = null, object postData = null)
protected async Task<T> HttpRequest<T>(string url = null, HttpMethod verb = null, object postData = null, CancellationToken ctx = default)
{
var response = await HttpRequestRaw(url, verb, postData);
var response = await HttpRequestRaw(url, verb, postData, ctx: ctx);
#if NET6_0_OR_GREATER
string resultAsString = await response.Content.ReadAsStringAsync(ctx);
#else
string resultAsString = await response.Content.ReadAsStringAsync();

#endif
var res = await JsonSerializer.DeserializeAsync<T>(
new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)));
new MemoryStream(Encoding.UTF8.GetBytes(resultAsString)), cancellationToken: ctx);

return res;
}

private async Task<HttpResponseMessage> HttpRequestRaw(string url = null, HttpMethod verb = null, object postData = null, bool streaming = false)
private async Task<HttpResponseMessage> HttpRequestRaw(string url = null, HttpMethod verb = null,
object postData = null, bool streaming = false, CancellationToken ctx = default)
{
if (string.IsNullOrEmpty(url))
url = this.Url;
Expand Down Expand Up @@ -110,7 +116,7 @@ private async Task<HttpResponseMessage> HttpRequestRaw(string url = null, HttpMe
}

response = await client.SendAsync(req,
streaming ? HttpCompletionOption.ResponseHeadersRead : HttpCompletionOption.ResponseContentRead);
streaming ? HttpCompletionOption.ResponseHeadersRead : HttpCompletionOption.ResponseContentRead, ctx);

if (response.IsSuccessStatusCode)
{
Expand All @@ -120,7 +126,11 @@ private async Task<HttpResponseMessage> HttpRequestRaw(string url = null, HttpMe
{
try
{
#if NET6_0_OR_GREATER
resultAsString = await response.Content.ReadAsStringAsync(ctx);
#else
resultAsString = await response.Content.ReadAsStringAsync();
#endif
}
catch (Exception e)
{
Expand Down Expand Up @@ -148,12 +158,17 @@ private async Task<HttpResponseMessage> HttpRequestRaw(string url = null, HttpMe
}
}

protected async IAsyncEnumerable<T> HttpStreamingRequest<T>(string url = null, HttpMethod verb = null, object postData = null)
protected async IAsyncEnumerable<T> HttpStreamingRequest<T>(string url = null, HttpMethod verb = null,
object postData = null, [EnumeratorCancellation] CancellationToken ctx = default)
{
var response = await HttpRequestRaw(url, verb, postData, true);

var response = await HttpRequestRaw(url, verb, postData, true, ctx);

#if NET6_0_OR_GREATER
await using var stream = await response.Content.ReadAsStreamAsync(ctx);
#else
using var stream = await response.Content.ReadAsStreamAsync();
#endif

using StreamReader reader = new StreamReader(stream);
string line;
SseEvent currentEvent = new SseEvent();
Expand All @@ -175,28 +190,33 @@ protected async IAsyncEnumerable<T> HttpStreamingRequest<T>(string url = null, H
if (currentEvent.EventType == "completion")
{
var res = await JsonSerializer.DeserializeAsync<T>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: ctx);
yield return res;
}
else if (currentEvent.EventType == "error")
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: ctx);
throw new Exception(res.Error.Message);
}

// Reset the current event for the next one
currentEvent = new SseEvent();
}
}
}

protected async IAsyncEnumerable<T> HttpStreamingRequestMessages<T>(string url = null, HttpMethod verb = null, object postData = null)
protected async IAsyncEnumerable<T> HttpStreamingRequestMessages<T>(string url = null, HttpMethod verb = null,
object postData = null, [EnumeratorCancellation] CancellationToken ctx = default)
{
var response = await HttpRequestRaw(url, verb, postData, true);
var response = await HttpRequestRaw(url, verb, postData, true, ctx);


#if NET6_0_OR_GREATER
await using var stream = await response.Content.ReadAsStreamAsync(ctx);
#else
using var stream = await response.Content.ReadAsStreamAsync();
#endif
using StreamReader reader = new StreamReader(stream);
string line;
SseEvent currentEvent = new SseEvent();
Expand All @@ -215,18 +235,18 @@ protected async IAsyncEnumerable<T> HttpStreamingRequestMessages<T>(string url =
}
else // an empty line indicates the end of an event
{
if (currentEvent.EventType == "message_start" ||
currentEvent.EventType == "content_block_delta" ||
if (currentEvent.EventType == "message_start" ||
currentEvent.EventType == "content_block_delta" ||
currentEvent.EventType == "message_delta")
{
var res = await JsonSerializer.DeserializeAsync<T>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: ctx);
yield return res;
}
else if (currentEvent.EventType == "error")
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: ctx);
throw new Exception(res.Error.Message);
}

Expand Down
Loading

0 comments on commit c39f451

Please sign in to comment.