Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions docs/auth/byok.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,116 @@ provider: {

> **Note:** The `bearerToken` option accepts a **static token string** only. The SDK does not refresh this token automatically. If your token expires, requests will fail and you'll need to create a new session with a fresh token.

## Custom Model Listing

When using BYOK, the CLI server may not know which models your provider supports. You can supply a custom `onListModels` handler at the client level so that `client.listModels()` returns your provider's models in the standard `ModelInfo` format. This lets downstream consumers discover available models without querying the CLI.

<details open>
<summary><strong>Node.js / TypeScript</strong></summary>

```typescript
import { CopilotClient } from "@github/copilot-sdk";
import type { ModelInfo } from "@github/copilot-sdk";

const client = new CopilotClient({
onListModels: () => [
{
id: "my-custom-model",
name: "My Custom Model",
capabilities: {
supports: { vision: false, reasoningEffort: false },
limits: { max_context_window_tokens: 128000 },
},
},
],
});
```

</details>

<details>
<summary><strong>Python</strong></summary>

```python
from copilot import CopilotClient
from copilot.types import ModelInfo, ModelCapabilities, ModelSupports, ModelLimits

client = CopilotClient({
"on_list_models": lambda: [
ModelInfo(
id="my-custom-model",
name="My Custom Model",
capabilities=ModelCapabilities(
supports=ModelSupports(vision=False, reasoning_effort=False),
limits=ModelLimits(max_context_window_tokens=128000),
),
)
],
})
```

</details>

<details>
<summary><strong>Go</strong></summary>

```go
package main

import (
"context"
copilot "github.com/github/copilot-sdk/go"
)

func main() {
client := copilot.NewClient(&copilot.ClientOptions{
OnListModels: func(ctx context.Context) ([]copilot.ModelInfo, error) {
return []copilot.ModelInfo{
{
ID: "my-custom-model",
Name: "My Custom Model",
Capabilities: copilot.ModelCapabilities{
Supports: copilot.ModelSupports{Vision: false, ReasoningEffort: false},
Limits: copilot.ModelLimits{MaxContextWindowTokens: 128000},
},
},
}, nil
},
})
_ = client
}
```

</details>

<details>
<summary><strong>.NET</strong></summary>

```csharp
using GitHub.Copilot.SDK;

var client = new CopilotClient(new CopilotClientOptions
{
OnListModels = (ct) => Task.FromResult(new List<ModelInfo>
{
new()
{
Id = "my-custom-model",
Name = "My Custom Model",
Capabilities = new ModelCapabilities
{
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
}
}
})
});
```

</details>

Results are cached after the first call, just like the default behavior. The handler completely replaces the CLI's `models.list` RPC — no fallback to the server occurs.

## Limitations

When using BYOK, be aware of these limitations:
Expand Down
29 changes: 20 additions & 9 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
private int? _negotiatedProtocolVersion;
private List<ModelInfo>? _modelsCache;
private readonly SemaphoreSlim _modelsCacheLock = new(1, 1);
private readonly Func<CancellationToken, Task<List<ModelInfo>>>? _onListModels;
private readonly List<Action<SessionLifecycleEvent>> _lifecycleHandlers = [];
private readonly Dictionary<string, List<Action<SessionLifecycleEvent>>> _typedLifecycleHandlers = [];
private readonly object _lifecycleHandlersLock = new();
Expand Down Expand Up @@ -136,6 +137,7 @@ public CopilotClient(CopilotClientOptions? options = null)
}

_logger = _options.Logger ?? NullLogger.Instance;
_onListModels = _options.OnListModels;

// Parse CliUrl if provided
if (!string.IsNullOrEmpty(_options.CliUrl))
Expand Down Expand Up @@ -624,9 +626,6 @@ public async Task<GetAuthStatusResponse> GetAuthStatusAsync(CancellationToken ca
/// <exception cref="InvalidOperationException">Thrown when the client is not connected or not authenticated.</exception>
public async Task<List<ModelInfo>> ListModelsAsync(CancellationToken cancellationToken = default)
{
var connection = await EnsureConnectedAsync(cancellationToken);

// Use semaphore for async locking to prevent race condition with concurrent calls
await _modelsCacheLock.WaitAsync(cancellationToken);
try
{
Expand All @@ -636,14 +635,26 @@ public async Task<List<ModelInfo>> ListModelsAsync(CancellationToken cancellatio
return [.. _modelsCache]; // Return a copy to prevent cache mutation
}

// Cache miss - fetch from backend while holding lock
var response = await InvokeRpcAsync<GetModelsResponse>(
connection.Rpc, "models.list", [], cancellationToken);
List<ModelInfo> models;
if (_onListModels is not null)
{
// Use custom handler instead of CLI RPC
models = await _onListModels(cancellationToken);
}
else
{
var connection = await EnsureConnectedAsync(cancellationToken);

// Cache miss - fetch from backend while holding lock
var response = await InvokeRpcAsync<GetModelsResponse>(
connection.Rpc, "models.list", [], cancellationToken);
models = response.Models;
}

// Update cache before releasing lock
_modelsCache = response.Models;
// Update cache before releasing lock (copy to prevent external mutation)
_modelsCache = [.. models];

return [.. response.Models]; // Return a copy to prevent cache mutation
return [.. models]; // Return a copy to prevent cache mutation
}
finally
{
Expand Down
9 changes: 9 additions & 0 deletions dotnet/src/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ protected CopilotClientOptions(CopilotClientOptions? other)
Port = other.Port;
UseLoggedInUser = other.UseLoggedInUser;
UseStdio = other.UseStdio;
OnListModels = other.OnListModels;
}

/// <summary>
Expand Down Expand Up @@ -136,6 +137,14 @@ public string? GithubToken
/// </summary>
public bool? UseLoggedInUser { get; set; }

/// <summary>
/// Custom handler for listing available models.
/// When provided, <c>ListModelsAsync()</c> calls this handler instead of
/// querying the CLI server. Useful in BYOK mode to return models
/// available from your custom provider.
/// </summary>
public Func<CancellationToken, Task<List<ModelInfo>>>? OnListModels { get; set; }

/// <summary>
/// Creates a shallow clone of this <see cref="CopilotClientOptions"/> instance.
/// </summary>
Expand Down
100 changes: 100 additions & 0 deletions dotnet/test/ClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,104 @@ public async Task Should_Throw_When_ResumeSession_Called_Without_PermissionHandl
Assert.Contains("OnPermissionRequest", ex.Message);
Assert.Contains("is required", ex.Message);
}

[Fact]
public async Task ListModels_WithCustomHandler_CallsHandler()
{
var customModels = new List<ModelInfo>
{
new()
{
Id = "my-custom-model",
Name = "My Custom Model",
Capabilities = new ModelCapabilities
{
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
}
}
};

var callCount = 0;
await using var client = new CopilotClient(new CopilotClientOptions
{
OnListModels = (ct) =>
{
callCount++;
return Task.FromResult(customModels);
}
});
await client.StartAsync();

var models = await client.ListModelsAsync();
Assert.Equal(1, callCount);
Assert.Single(models);
Assert.Equal("my-custom-model", models[0].Id);
}
Comment on lines +278 to +310
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new tests exercise OnListModels, but they still call StartAsync(). Since the intended behavior is that ListModelsAsync() works without a CLI connection when OnListModels is provided, add a test that calls ListModelsAsync() before StartAsync() to lock in that guarantee.

Copilot uses AI. Check for mistakes.

[Fact]
public async Task ListModels_WithCustomHandler_CachesResults()
{
var customModels = new List<ModelInfo>
{
new()
{
Id = "cached-model",
Name = "Cached Model",
Capabilities = new ModelCapabilities
{
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
}
}
};

var callCount = 0;
await using var client = new CopilotClient(new CopilotClientOptions
{
OnListModels = (ct) =>
{
callCount++;
return Task.FromResult(customModels);
}
});
await client.StartAsync();

await client.ListModelsAsync();
await client.ListModelsAsync();
Assert.Equal(1, callCount); // Only called once due to caching
}

[Fact]
public async Task ListModels_WithCustomHandler_WorksWithoutStart()
{
var customModels = new List<ModelInfo>
{
new()
{
Id = "no-start-model",
Name = "No Start Model",
Capabilities = new ModelCapabilities
{
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
}
}
};

var callCount = 0;
await using var client = new CopilotClient(new CopilotClientOptions
{
OnListModels = (ct) =>
{
callCount++;
return Task.FromResult(customModels);
}
});

var models = await client.ListModelsAsync();
Assert.Equal(1, callCount);
Assert.Single(models);
Assert.Equal("no-start-model", models[0].Id);
}
}
51 changes: 33 additions & 18 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type Client struct {
processErrorPtr *error
osProcess atomic.Pointer[os.Process]
negotiatedProtocolVersion int
onListModels func(ctx context.Context) ([]ModelInfo, error)

// RPC provides typed server-scoped RPC methods.
// This field is nil until the client is connected via Start().
Expand Down Expand Up @@ -188,6 +189,9 @@ func NewClient(options *ClientOptions) *Client {
if options.UseLoggedInUser != nil {
opts.UseLoggedInUser = options.UseLoggedInUser
}
if options.OnListModels != nil {
client.onListModels = options.OnListModels
}
}

// Default Env to current environment if not set
Expand Down Expand Up @@ -1035,40 +1039,51 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err
// Results are cached after the first successful call to avoid rate limiting.
// The cache is cleared when the client disconnects.
func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) {
if c.client == nil {
return nil, fmt.Errorf("client not connected")
}

// Use mutex for locking to prevent race condition with concurrent calls
c.modelsCacheMux.Lock()
defer c.modelsCacheMux.Unlock()

// Check cache (already inside lock)
if c.modelsCache != nil {
// Return a copy to prevent cache mutation
result := make([]ModelInfo, len(c.modelsCache))
copy(result, c.modelsCache)
return result, nil
}

// Cache miss - fetch from backend while holding lock
result, err := c.client.Request("models.list", listModelsRequest{})
if err != nil {
return nil, err
}
var models []ModelInfo
if c.onListModels != nil {
// Use custom handler instead of CLI RPC
var err error
models, err = c.onListModels(ctx)
if err != nil {
return nil, err
}
} else {
if c.client == nil {
return nil, fmt.Errorf("client not connected")
}
// Cache miss - fetch from backend while holding lock
result, err := c.client.Request("models.list", listModelsRequest{})
if err != nil {
return nil, err
}

var response listModelsResponse
if err := json.Unmarshal(result, &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal models response: %w", err)
var response listModelsResponse
if err := json.Unmarshal(result, &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal models response: %w", err)
}
models = response.Models
}

// Update cache before releasing lock
c.modelsCache = response.Models
// Update cache before releasing lock (copy to prevent external mutation)
cache := make([]ModelInfo, len(models))
copy(cache, models)
c.modelsCache = cache

// Return a copy to prevent cache mutation
models := make([]ModelInfo, len(response.Models))
copy(models, response.Models)
return models, nil
result := make([]ModelInfo, len(models))
copy(result, models)
return result, nil
}

// minProtocolVersion is the minimum protocol version this SDK can communicate with.
Expand Down
Loading
Loading