Skip to content

Commit bc23670

Browse files
committed
Add coder support
1 parent c9c4a88 commit bc23670

19 files changed

+492
-88
lines changed

internal/server/path.go

+12
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ func handlePath(json any, paths *[]string, fn func(string, *[]string) (string, e
3232
return nil, err
3333
}
3434
xs[i] = o
35+
} else {
36+
o, err := handlePath(xs[i], paths, fn)
37+
if err != nil {
38+
return nil, err
39+
}
40+
xs[i] = o
3541
}
3642
}
3743
return xs, nil
@@ -43,6 +49,12 @@ func handlePath(json any, paths *[]string, fn func(string, *[]string) (string, e
4349
return nil, err
4450
}
4551
m[key] = o
52+
} else {
53+
o, err := handlePath(m[key], paths, fn)
54+
if err != nil {
55+
return nil, err
56+
}
57+
m[key] = o
4658
}
4759
}
4860
return m, nil

internal/tests/coder_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package tests
2+
3+
import (
4+
"testing"
5+
6+
"github.com/replicate/cog-runtime/internal/server"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestPredictionDataclassCoderSucceeded(t *testing.T) {
12+
if *legacyCog {
13+
// Compat: legacy Cog does not support custom coder
14+
t.SkipNow()
15+
}
16+
ct := NewCogTest(t, "dataclass")
17+
assert.NoError(t, ct.Start())
18+
19+
hc := ct.WaitForSetup()
20+
assert.Equal(t, server.StatusReady.String(), hc.Status)
21+
assert.Equal(t, server.SetupSucceeded, hc.Setup.Status)
22+
23+
resp := ct.Prediction(map[string]any{
24+
"account": map[string]any{
25+
"id": 0,
26+
"name": "John",
27+
"address": map[string]any{"street": "Smith", "zip": 12345},
28+
"credentials": map[string]any{"password": "foo", "pubkey": b64encode("bar")},
29+
},
30+
})
31+
32+
output := map[string]any{
33+
"account": map[string]any{
34+
"id": 100.0,
35+
"name": "JOHN",
36+
"address": map[string]any{"street": "SMITH", "zip": 22345.0},
37+
"credentials": map[string]any{"password": "**********", "pubkey": b64encode("*bar*")},
38+
},
39+
}
40+
ct.AssertResponse(resp, server.PredictionSucceeded, output, "")
41+
42+
ct.Shutdown()
43+
assert.NoError(t, ct.Cleanup())
44+
}
45+
46+
func TestPredictionChatCoderSucceeded(t *testing.T) {
47+
if *legacyCog {
48+
// Compat: legacy Cog does not support custom coder
49+
t.SkipNow()
50+
}
51+
ct := NewCogTest(t, "chat")
52+
assert.NoError(t, ct.Start())
53+
54+
hc := ct.WaitForSetup()
55+
assert.Equal(t, server.StatusReady.String(), hc.Status)
56+
assert.Equal(t, server.SetupSucceeded, hc.Setup.Status)
57+
58+
resp := ct.Prediction(map[string]any{"msg": map[string]any{"role": "assistant", "content": "bar"}})
59+
output := map[string]any{"role": "assistant", "content": "*bar*"}
60+
ct.AssertResponse(resp, server.PredictionSucceeded, output, "")
61+
62+
ct.Shutdown()
63+
assert.NoError(t, ct.Cleanup())
64+
}

pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,17 @@ dev = [
2222
]
2323

2424
test = [
25+
'openai',
2526
'pytest-cov',
2627
'pytest',
2728
'pytest-asyncio',
2829
'tqdm',
2930
]
3031

32+
provided = [
33+
'pydantic',
34+
]
35+
3136
[build-system]
3237
requires = ['setuptools', 'setuptools-scm']
3338
build-backend = 'setuptools.build_meta'

python/cog/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
BaseModel,
44
BasePredictor,
55
CancelationException,
6+
Coder,
67
ConcatenateIterator,
78
Input,
89
Path,
@@ -16,6 +17,7 @@
1617
'BaseModel',
1718
'BasePredictor',
1819
'CancelationException',
20+
'Coder',
1921
'ConcatenateIterator',
2022
'Input',
2123
'Path',

python/cog/coder/__init__.py

Whitespace-only changes.

python/cog/coder/dataclass_coder.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import dataclasses
2+
from typing import Any, Optional, Type
3+
4+
from coglet import api
5+
6+
7+
class DataclassCoder(api.Coder):
8+
@staticmethod
9+
def factory(cls: Type) -> Optional[api.Coder]:
10+
if dataclasses.is_dataclass(cls):
11+
return DataclassCoder(cls)
12+
else:
13+
return None
14+
15+
def __init__(self, cls: Type):
16+
assert dataclasses.is_dataclass(cls)
17+
self.cls = cls
18+
19+
def encode(self, x: Any) -> dict[str, Any]:
20+
# Secret is a dataclass and dataclasses.asdict recursively converts its internals
21+
return self._to_dict(self.cls, x)
22+
23+
def _to_dict(self, cls: Type, x: Any) -> dict[str, Any]:
24+
r: dict[str, Any] = {}
25+
for f in dataclasses.fields(cls):
26+
v = getattr(x, f.name)
27+
# Keep Path and Secret as is and let json.dumps(default=fn) handle them
28+
if f.type is api.Path:
29+
v = api.Path(v)
30+
elif f.type is api.Secret:
31+
v = api.Secret(v)
32+
elif dataclasses.is_dataclass(v):
33+
v = self._to_dict(f.type, v) # type: ignore
34+
r[f.name] = v
35+
return r
36+
37+
def decode(self, x: dict[str, Any]) -> Any:
38+
kwargs = self._from_dict(self.cls, x)
39+
return self.cls(**kwargs) # type: ignore
40+
41+
def _from_dict(self, cls: Type, x: dict[str, Any]) -> Any:
42+
r: dict[str, Any] = {}
43+
for f in dataclasses.fields(cls):
44+
if f.name not in x:
45+
continue
46+
elif f.type is api.Path:
47+
r[f.name] = api.Path(x[f.name])
48+
# Secret is a dataclass and must be handled before other dataclasses
49+
elif f.type is api.Secret:
50+
r[f.name] = api.Secret(x[f.name])
51+
elif dataclasses.is_dataclass(f.type):
52+
kwargs = self._from_dict(f.type, x[f.name]) # type: ignore
53+
r[f.name] = f.type(**kwargs) # type: ignore
54+
else:
55+
r[f.name] = x[f.name]
56+
return r

python/cog/coder/json_coder.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import typing
2+
from typing import Any, Optional, Type
3+
4+
from coglet import api
5+
6+
7+
class JsonCoder(api.Coder):
8+
@staticmethod
9+
def factory(cls: Type) -> Optional[api.Coder]:
10+
if typing.get_origin(cls) is dict is dict and typing.get_args(cls)[0] is str:
11+
return JsonCoder()
12+
else:
13+
return None
14+
15+
def encode(self, x: Any) -> dict[str, Any]:
16+
return x
17+
18+
def decode(self, x: dict[str, Any]) -> Any:
19+
return x

python/cog/coder/pydantic_coder.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import inspect
2+
from typing import Any, Type
3+
4+
from pydantic import BaseModel
5+
6+
from coglet import api
7+
8+
9+
class BaseModelCoder(api.Coder):
10+
@staticmethod
11+
def factory(cls: Type):
12+
if cls is not BaseModel and any(c is BaseModel for c in inspect.getmro(cls)):
13+
return BaseModelCoder(cls)
14+
else:
15+
return None
16+
17+
def __init__(self, cls: Type[BaseModel]):
18+
self.cls = cls
19+
20+
def encode(self, x: BaseModel) -> dict[str, Any]:
21+
return x.model_dump(exclude_unset=True)
22+
23+
def decode(self, x: dict[str, Any]) -> BaseModel:
24+
return self.cls.model_construct(**x)

0 commit comments

Comments
 (0)