9 changes: 8 additions & 1 deletion README.md
@@ -264,7 +264,7 @@ in the `/examples/` directory.
### DMR (Docker Model Runner) provider options

When using the `dmr` provider, you can use the `provider_opts` key for DMR
runtime-specific (e.g. llama.cpp) options:
runtime-specific (e.g. llama.cpp/vllm) options and speculative decoding:

```yaml
models:
@@ -273,7 +273,12 @@ models:
model: ai/qwen3
max_tokens: 8192
provider_opts:
# general flags passed to the underlying model runtime
runtime_flags: ["--ngl=33", "--repeat-penalty=1.2", ...] # or comma/space-separated string
# speculative decoding for faster inference
speculative_draft_model: ai/qwen3:1B
speculative_num_tokens: 5
speculative_acceptance_rate: 0.8
```

The default base_url `cagent` will use for DMR providers is
@@ -283,6 +288,8 @@ settings](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-doc
on macOS and Windows, and via the command line on [Docker CE on
Linux](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-engine).

See the [DMR Provider documentation](docs/USAGE.md#dmr-docker-model-runner-provider-usage) for more details on runtime flags and speculative decoding options.

## Quickly generate agents and agent teams with `cagent new`

Using the command `cagent new` you can quickly generate agents or multi-agent
22 changes: 21 additions & 1 deletion docs/PROVIDERS.md
@@ -25,7 +25,7 @@ var ProviderAliases = map[string]Alias{

## Add custom config if needed (optional)

If your provider requires custom config, like Azure's `api_version`
If your provider requires custom config, like Azure's `api_version` or DMR's speculative decoding options:

```yaml
models:
@@ -41,6 +41,14 @@ models:
model: gpt-4o
provider_opts:
your_custom_option: your_custom_value
# DMR with speculative decoding
dmr_model:
provider: dmr
model: ai/qwen3:14B
provider_opts:
speculative_draft_model: ai/qwen3:1B
speculative_num_tokens: 5
speculative_acceptance_rate: 0.8
```

edit [`pkg/model/provider/openai/client.go`](https://github.com/docker/cagent/blob/main/pkg/model/provider/openai/client.go)
@@ -63,3 +71,15 @@ switch cfg.Provider { //nolint:gocritic
}
}
```

## DMR Provider-Specific Options

The DMR provider supports speculative decoding for faster inference. Configure it using `provider_opts`:

- `speculative_draft_model` (string): Model to use for draft predictions
- `speculative_num_tokens` (int): Number of tokens to generate speculatively
- `speculative_acceptance_rate` (float): Minimum acceptance-rate threshold for speculative tokens

All three options are passed to `docker model configure` as command-line flags (`--speculative-draft-model`, `--speculative-num-tokens`, and `--speculative-min-acceptance-rate`, respectively).

You can also pass any flag of the underlying model runtime (llama.cpp or vllm) using the `runtime_flags` option.
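
For the `dmr_model` example above (no `max_tokens`, so no `--context-size` flag), the resulting invocation is roughly the following — a sketch based on `buildDockerModelConfigureArgs` in `pkg/model/provider/dmr/client.go`, not standalone CLI documentation:

```
docker model configure \
  --speculative-draft-model=ai/qwen3:1B \
  --speculative-num-tokens=5 \
  --speculative-min-acceptance-rate=0.8 \
  ai/qwen3:14B
```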
25 changes: 24 additions & 1 deletion docs/USAGE.md
@@ -382,7 +382,30 @@ models:
runtime_flags: "--ngl=33 --repeat-penalty=1.2" # string accepted as well
```

Troubleshooting:
##### Speculative Decoding

DMR supports speculative decoding, which speeds up inference by having a smaller draft model propose tokens that the larger main model then verifies. Configure it using `provider_opts`:

```yaml
models:
qwen-with-speculative:
provider: dmr
model: ai/qwen3:14B
max_tokens: 8192
provider_opts:
speculative_draft_model: ai/qwen3:0.6B-F16 # Draft model for predictions
speculative_num_tokens: 16 # Number of tokens to generate speculatively
speculative_acceptance_rate: 0.8 # Acceptance rate threshold
```

All three speculative decoding options are passed to `docker model configure` as flags:
- `speculative_draft_model` → `--speculative-draft-model`
- `speculative_num_tokens` → `--speculative-num-tokens`
- `speculative_acceptance_rate` → `--speculative-min-acceptance-rate`

These options work alongside `max_tokens` (which sets `--context-size`) and `runtime_flags`.
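
Putting it together, the config above results in roughly this call (per `buildDockerModelConfigureArgs` in `pkg/model/provider/dmr/client.go`; any `runtime_flags` are appended after a trailing `--`):

```
docker model configure \
  --context-size=8192 \
  --speculative-draft-model=ai/qwen3:0.6B-F16 \
  --speculative-num-tokens=16 \
  --speculative-min-acceptance-rate=0.8 \
  ai/qwen3:14B
```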

##### Troubleshooting:

- Plugin not found: cagent will log a debug message and use the default base URL
- Endpoint empty in status: ensure the Model Runner is running, or set `base_url` manually
13 changes: 13 additions & 0 deletions examples/dmr.yaml
@@ -3,12 +3,25 @@
agents:
root:
model: qwen
# model: qwen_speculative
description: "Pirate-themed AI assistant"
instruction: Talk like a pirate
commands:
demo: "Hey tell me a story about docker containers"

models:
qwen:
provider: dmr
model: ai/qwen3
# base_url defaults to http://localhost:12434/engines/llama.cpp/v1
# use http://model-runner.docker.internal/engines/v1 if you run cagent from a container

# try this model for faster inference if you have enough memory
qwen_speculative:
provider: dmr
model: ai/qwen3
# The draft model should be a smaller, lower-latency variant of the main model
provider_opts:
speculative_draft_model: ai/qwen3:0.6B-Q4_K_M
speculative_num_tokens: 16 # (this is the llama.cpp default if omitted)
speculative_acceptance_rate: 0.8 # (this is the llama.cpp default if omitted)
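
To try either model, point cagent at the example file (assuming the standard `cagent run` entrypoint; adjust if your build differs):

```
cagent run examples/dmr.yaml
```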
102 changes: 92 additions & 10 deletions pkg/model/provider/dmr/client.go
@@ -98,14 +98,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth

// Build runtime flags from ModelConfig and engine
contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg)
configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
for _, w := range warnings {
slog.Warn(w)
}
slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags); err != nil {
slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "speculative_opts", specOpts, "engine", engine)
if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags, specOpts); err != nil {
slog.Debug("docker model configure skipped or failed", "error", err)
}

@@ -533,14 +533,22 @@ func ConvertParametersToSchema(params any) (any, error) {
return m, nil
}

func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string) {
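// speculativeDecodingOpts holds the speculative decoding settings parsed from provider_opts.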
type speculativeDecodingOpts struct {
draftModel string
numTokens int
acceptanceRate float64
}

func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) {
if cfg == nil {
return 0, nil
return 0, nil, nil
}

// Context length is now sourced from the standard max_tokens field
contextSize = cfg.MaxTokens

slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts)

if len(cfg.ProviderOpts) > 0 {
if v, ok := cfg.ProviderOpts["runtime_flags"]; ok {
switch t := v.(type) {
@@ -555,9 +563,72 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
runtimeFlags = append(runtimeFlags, parts...)
}
}

// Parse speculative decoding options
var hasDraftModel, hasNumTokens, hasAcceptanceRate bool
var draftModel string
var numTokens int
var acceptanceRate float64

if v, ok := cfg.ProviderOpts["speculative_draft_model"]; ok {
if s, ok := v.(string); ok && s != "" {
draftModel = s
hasDraftModel = true
}
}

if v, ok := cfg.ProviderOpts["speculative_num_tokens"]; ok {
switch t := v.(type) {
case float64:
numTokens = int(t)
hasNumTokens = true
case uint64:
numTokens = int(t)
hasNumTokens = true
case string:
s := strings.TrimSpace(t)
if s != "" {
if n, err := strconv.Atoi(s); err == nil {
numTokens = n
hasNumTokens = true
} else if f, err := strconv.ParseFloat(s, 64); err == nil {
numTokens = int(f)
hasNumTokens = true
}
}
}
}

if v, ok := cfg.ProviderOpts["speculative_acceptance_rate"]; ok {
switch t := v.(type) {
case float64:
acceptanceRate = t
hasAcceptanceRate = true
case uint64:
acceptanceRate = float64(t)
hasAcceptanceRate = true
case string:
s := strings.TrimSpace(t)
if s != "" {
if f, err := strconv.ParseFloat(s, 64); err == nil {
acceptanceRate = f
hasAcceptanceRate = true
}
}
}
}

// Only create specOpts if at least one field is set
if hasDraftModel || hasNumTokens || hasAcceptanceRate {
specOpts = &speculativeDecodingOpts{
draftModel: draftModel,
numTokens: numTokens,
acceptanceRate: acceptanceRate,
}
}
}

return contextSize, runtimeFlags
return contextSize, runtimeFlags, specOpts
}

func pullDockerModelIfNeeded(ctx context.Context, model string) error {
@@ -615,8 +686,8 @@ func modelExists(ctx context.Context, model string) bool {
return true
}

func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)
func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags, specOpts)

cmd := exec.CommandContext(ctx, "docker", args...)
slog.Debug("Running docker model configure", "model", model, "args", args)
@@ -631,12 +702,23 @@ }
}

// buildDockerModelConfigureArgs returns the argument vector passed to `docker` for model configuration.
// It formats context size and runtime flags consistently with the CLI contract.
func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags []string) []string {
// It formats context size, speculative decoding options, and runtime flags consistently with the CLI contract.
func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) []string {
args := []string{"model", "configure"}
if contextSize > 0 {
args = append(args, "--context-size="+strconv.Itoa(contextSize))
}
if specOpts != nil {
if specOpts.draftModel != "" {
args = append(args, "--speculative-draft-model="+specOpts.draftModel)
}
if specOpts.numTokens > 0 {
args = append(args, "--speculative-num-tokens="+strconv.Itoa(specOpts.numTokens))
}
if specOpts.acceptanceRate > 0 {
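// 'f' with precision -1 yields the shortest plain decimal form (e.g. "0.8", never scientific notation).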
args = append(args, "--speculative-min-acceptance-rate="+strconv.FormatFloat(specOpts.acceptanceRate, 'f', -1, 64))
}
}
args = append(args, model)
if len(runtimeFlags) > 0 {
args = append(args, "--")
76 changes: 74 additions & 2 deletions pkg/model/provider/dmr/client_test.go
@@ -32,7 +32,7 @@ func TestNewClientWithWrongType(t *testing.T) {
}

func TestBuildDockerConfigureArgs(t *testing.T) {
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"})
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"}, nil)

assert.Equal(t, []string{"model", "configure", "--context-size=8192", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.7", "--top-p", "0.9"}, args)
}
@@ -62,7 +62,7 @@ func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) {
// provider opts should be appended after derived flags so they can override by order
merged := append(derived, []string{"--threads", "6"}...)

args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged)
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged, nil)
assert.Equal(t, []string{"model", "configure", "--context-size=4096", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.6", "--top-p", "0.9", "--threads", "6"}, args)
}

@@ -83,3 +83,75 @@ func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
func floatPtr(f float64) *float64 {
return &f
}

func TestBuildDockerConfigureArgsWithSpeculativeDecoding(t *testing.T) {
specOpts := &speculativeDecodingOpts{
draftModel: "ai/qwen3:1B",
numTokens: 5,
acceptanceRate: 0.8,
}
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7"}, specOpts)

assert.Equal(t, []string{
"model", "configure",
"--context-size=8192",
"--speculative-draft-model=ai/qwen3:1B",
"--speculative-num-tokens=5",
"--speculative-min-acceptance-rate=0.8",
"ai/qwen3:14B-Q6_K",
"--",
"--temp", "0.7",
}, args)
}

func TestBuildDockerConfigureArgsWithPartialSpeculativeDecoding(t *testing.T) {
specOpts := &speculativeDecodingOpts{
draftModel: "ai/qwen3:1B",
numTokens: 5,
// acceptanceRate not set (0 value)
}
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 0, nil, specOpts)

assert.Equal(t, []string{
"model", "configure",
"--speculative-draft-model=ai/qwen3:1B",
"--speculative-num-tokens=5",
"ai/qwen3:14B-Q6_K",
}, args)
}

func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) {
cfg := &latest.ModelConfig{
MaxTokens: 4096,
ProviderOpts: map[string]any{
"speculative_draft_model": "ai/qwen3:1B",
"speculative_num_tokens": "5",
"speculative_acceptance_rate": "0.75",
"runtime_flags": []string{"--threads", "8"},
},
}

contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg)

assert.Equal(t, 4096, contextSize)
assert.Equal(t, []string{"--threads", "8"}, runtimeFlags)
require.NotNil(t, specOpts)
assert.Equal(t, "ai/qwen3:1B", specOpts.draftModel)
assert.Equal(t, 5, specOpts.numTokens)
assert.InEpsilon(t, 0.75, specOpts.acceptanceRate, 0.001)
}

func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) {
cfg := &latest.ModelConfig{
MaxTokens: 4096,
ProviderOpts: map[string]any{
"runtime_flags": []string{"--threads", "8"},
},
}

contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg)

assert.Equal(t, 4096, contextSize)
assert.Equal(t, []string{"--threads", "8"}, runtimeFlags)
assert.Nil(t, specOpts)
}