Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pkg/cli/init-templates/pipeline/cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@

build:
gpu: false
cog_runtime: true
python_requirements: requirements.txt
predict: "main.py:run"
12 changes: 10 additions & 2 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
buildFast = cfg.Build.Fast
}

// TODO[md]: this is a temporary hack to propagate a procedure flag through the build system without
// touching every function signature with another param. The cogpacks refactor addresses this.
if pipelinesImage {
cfg.Build.ProcedureMode = true
}

client := registry.NewRegistryClient()
if buildFast || pipelinesImage {
imageName = config.DockerImageName(projectDir)
Expand Down Expand Up @@ -258,6 +264,8 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
console.Info("")
console.Infof("Starting Docker image %s and running setup()...", imageName)

logHandler := predict.NewLogHandler()

predictor, err := predict.NewPredictor(ctx, command.RunOptions{
GPUs: gpus,
Image: imageName,
Expand All @@ -281,7 +289,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
}()

timeout := time.Duration(setupTimeout) * time.Second
if err := predictor.Start(ctx, os.Stderr, timeout); err != nil {
if err := predictor.Start(ctx, logHandler, timeout); err != nil {
// Only retry if we're using a GPU but but the user didn't explicitly select a GPU with --gpus
// If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it
if gpus == "all" && errors.Is(err, docker.ErrMissingDeviceDriver) {
Expand All @@ -297,7 +305,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
return err
}

if err := predictor.Start(ctx, os.Stderr, timeout); err != nil {
if err := predictor.Start(ctx, logHandler, timeout); err != nil {
return err
}
} else {
Expand Down
4 changes: 3 additions & 1 deletion pkg/cli/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
console.Info("")
console.Infof("Starting Docker image %s...", imageName)

logHandler := predict.NewLogHandler()

predictor, err := predict.NewPredictor(ctx, command.RunOptions{
GPUs: gpus,
Image: imageName,
Expand All @@ -140,7 +142,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
}
}()

if err := predictor.Start(ctx, os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil {
if err := predictor.Start(ctx, logHandler, time.Duration(setupTimeout)*time.Second); err != nil {
return err
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type Build struct {
CogRuntime *bool `json:"cog_runtime,omitempty" yaml:"cog_runtime,omitempty"`
PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides,omitempty"`

ProcedureMode bool `json:"-" yaml:"-"`

pythonRequirementsContent []string
}

Expand Down
67 changes: 67 additions & 0 deletions pkg/dockerfile/coglet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package dockerfile

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)

// GitHubRelease represents a GitHub release response
type GitHubRelease struct {
TagName string `json:"tag_name"`
Assets []struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
} `json:"assets"`
}

// GetLatestCogletWheelURL fetches the latest coglet wheel URL from GitHub releases
func GetLatestCogletWheelURL(ctx context.Context) (string, error) {
// Create HTTP client with timeout
client := &http.Client{
Timeout: 30 * time.Second,
}

// GitHub API URL for latest release
apiURL := "https://api.github.com/repos/replicate/cog-runtime/releases/latest"

// Create request with context
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}

// Add headers
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "cog-cli")

// Make the request
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to fetch release data: %w", err)
}
defer resp.Body.Close()

// Check status code
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
}

// Parse JSON response
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return "", fmt.Errorf("failed to parse release data: %w", err)
}

// Find coglet wheel in assets
for _, asset := range release.Assets {
if strings.HasSuffix(asset.Name, ".whl") && strings.Contains(asset.Name, "coglet") {
return asset.BrowserDownloadURL, nil
}
}

return "", fmt.Errorf("no coglet wheel found in latest release %s", release.TagName)
}
13 changes: 11 additions & 2 deletions pkg/dockerfile/standard_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (g *StandardGenerator) installCog() (string, error) {
return "", nil
}

if g.Config.Build.CogRuntime != nil && *g.Config.Build.CogRuntime {
if g.Config.Build.CogRuntime != nil && *g.Config.Build.CogRuntime || g.Config.Build.ProcedureMode {
return g.installCogRuntime()
}

Expand All @@ -485,10 +485,19 @@ func (g *StandardGenerator) installCogRuntime() (string, error) {
if !CheckMajorMinorOnly(g.Config.Build.PythonVersion) {
return "", fmt.Errorf("Python version must be <major>.<minor>")
}

cogletURL := PinnedCogletURL
if g.Config.Build.ProcedureMode {
// if we're building a procedure, use the latest coglet release
if latestURL, err := GetLatestCogletWheelURL(context.TODO()); err == nil {
cogletURL = latestURL
}
}

cmds := []string{
"ENV R8_COG_VERSION=coglet",
"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
"RUN pip install " + PinnedCogletURL,
"RUN pip install " + cogletURL,
}
return strings.Join(cmds, "\n"), nil
}
Expand Down
115 changes: 115 additions & 0 deletions pkg/predict/log_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package predict

import (
"bufio"
"encoding/json"
"os"
"strings"
"sync"

"github.com/replicate/cog/pkg/util/console"
)

// LogEntry represents a structured log entry from the container
type LogEntry struct {
Severity string `json:"severity"`
Timestamp string `json:"timestamp"`
Logger string `json:"logger"`
Caller string `json:"caller"`
Message string `json:"message"`
// Additional fields are ignored but preserved
}

// LogHandler implements io.Writer and processes container stderr output
// It parses JSON logs and routes them to appropriate console levels,
// while handling unstructured logs gracefully.
type LogHandler struct {
mu sync.Mutex
}

// NewLogHandler creates a new LogHandler
func NewLogHandler() *LogHandler {
return &LogHandler{}
}

// Write implements io.Writer interface
func (lh *LogHandler) Write(p []byte) (n int, err error) {
lh.mu.Lock()
defer lh.mu.Unlock()

// TEMPORARY: Tee raw output to stderr for debugging
os.Stderr.WriteString("RAW: " + string(p))

// Process each line
scanner := bufio.NewScanner(strings.NewReader(string(p)))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}

lh.processLine(line)
}

return len(p), nil
}

// processLine processes a single log line
func (lh *LogHandler) processLine(line string) {
// Try to parse as JSON first
if lh.TryParseJSONLog(line) {
return
}

// Handle unstructured logs
lh.handleUnstructuredLog(line)
}

// TryParseJSONLog attempts to parse the line as a JSON log entry
// Returns true if successfully parsed and handled
// This is exported for testing purposes
func (lh *LogHandler) TryParseJSONLog(line string) bool {
var entry LogEntry
if err := json.Unmarshal([]byte(line), &entry); err != nil {
return false
}

// Route based on severity level
switch strings.ToLower(entry.Severity) {
case "debug":
console.Debug(entry.Message)
case "info":
console.Debug(entry.Message) // Info logs from container go to debug level
case "warn", "warning":
console.Warn(entry.Message)
case "error":
console.Error(entry.Message)
case "fatal":
console.Error(entry.Message) // Fatal logs from container go to error level
default:
// Unknown severity, treat as info
console.Debug(entry.Message)
}

return true
}

// handleUnstructuredLog handles non-JSON log lines
func (lh *LogHandler) handleUnstructuredLog(line string) {
// Check for common error patterns
lowerLine := strings.ToLower(line)

// Route based on content patterns
switch {
case strings.Contains(lowerLine, "error") || strings.Contains(lowerLine, "failed") || strings.Contains(lowerLine, "exception"):
console.Error(line)
case strings.Contains(lowerLine, "warning") || strings.Contains(lowerLine, "warn"):
console.Warn(line)
case strings.Contains(lowerLine, "debug"):
console.Debug(line)
default:
// Default to debug level for unstructured logs
// This prevents cluttering the user's output with container internals
console.Debug(line)
}
}
Loading
Loading