Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
508 changes: 508 additions & 0 deletions api/fine-tuning/cmd/e2e-serving-test/main.go

Large diffs are not rendered by default.

142 changes: 95 additions & 47 deletions api/fine-tuning/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"

image "github.com/0glabs/0g-serving-broker/common/docker"
"github.com/0glabs/0g-serving-broker/common/log"
Expand All @@ -20,10 +22,13 @@ import (
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/db"
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/handler"
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/services"
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/serving"
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/storage"
"github.com/0glabs/0g-serving-broker/fine-tuning/internal/utils"
"github.com/0glabs/0g-serving-broker/fine-tuning/monitor"
"github.com/docker/docker/client"
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus/promhttp"
)

//go:generate swag fmt
Expand All @@ -42,8 +47,6 @@ func Main() {
panic(err)
}

// Initialize data directory for task storage
// Uses configured dataDir or falls back to os.TempDir()
utils.SetDataDir(cfg.Service.DataDir)
logger.Infof("Data directory set to: %s", utils.GetDataDir())

Expand All @@ -55,13 +58,13 @@ func Main() {
defer cancel()
imageChan := buildImageIfNeeded(ctx, cfg, logger)

services, err := initializeServices(ctx, cfg, logger)
appServices, err := initializeServices(ctx, cfg, logger)
if err != nil {
panic(err)
}
defer services.contract.Close()
defer appServices.contract.Close()

if err := runApplication(ctx, cfg, services, logger, imageChan); err != nil {
if err := runApplication(ctx, cfg, appServices, logger, imageChan); err != nil {
panic(err)
}
}
Expand Down Expand Up @@ -121,14 +124,11 @@ func buildImageIfNeeded(ctx context.Context, config *config.Config, logger log.L
if buildImage {
logger.Debugf("build image %s", imageName)

// Check if transformer files exist in the embedded location
embeddedPath := "/fine-tuning/execution/transformer"

// Prepare bridge directory for Docker daemon access
if _, err := os.Stat(embeddedPath); err == nil {
logger.Infof("Found embedded transformer files at %s", embeddedPath)

// Clean bridge directory contents but don't remove the directory itself (it may be mounted)
bridgeDir := constant.FineTuningDockerfilePath
if entries, err := os.ReadDir(bridgeDir); err == nil {
for _, entry := range entries {
Expand All @@ -139,13 +139,11 @@ func buildImageIfNeeded(ctx context.Context, config *config.Config, logger log.L
}
}

// Ensure bridge directory exists
if err := os.MkdirAll(bridgeDir, 0755); err != nil {
logger.Errorf("failed to create bridge directory: %v", err)
return
}

// Copy transformer files to bridge directory
logger.Infof("Copying transformer files to bridge directory: %s", bridgeDir)
if err := copyDirectory(embeddedPath, bridgeDir); err != nil {
logger.Errorf("failed to copy transformer files: %v", err)
Expand All @@ -156,7 +154,6 @@ func buildImageIfNeeded(ctx context.Context, config *config.Config, logger log.L
logger.Warnf("Embedded transformer files not found at %s, checking bridge directory", embeddedPath)
}

// Build image using the bridge directory (constant.FineTuningDockerfilePath now points to /tmp/transformer-bridge)
logger.Infof("Building image from: %s", constant.FineTuningDockerfilePath)
err := image.ImageBuild(ctx, cli, constant.FineTuningDockerfilePath, imageName, logger)
if err != nil {
Expand All @@ -173,11 +170,11 @@ func buildImageIfNeeded(ctx context.Context, config *config.Config, logger log.L
}

func initializeServices(ctx context.Context, cfg *config.Config, logger log.Logger) (*ApplicationServices, error) {
db, err := db.NewDB(cfg, logger)
database, err := db.NewDB(cfg, logger)
if err != nil {
return nil, err
}
if err := db.Migrate(); err != nil {
if err := database.Migrate(); err != nil {
return nil, err
}

Expand All @@ -199,7 +196,6 @@ func initializeServices(ctx context.Context, cfg *config.Config, logger log.Logg
return nil, err
}

// Sync TEE quote to initialize Address before creating contract
logger.Info("syncing TEE quote during service initialization")
if err := teeService.SyncQuote(ctx, os.Getenv("NETWORK") != "hardhat"); err != nil {
return nil, err
Expand All @@ -210,78 +206,126 @@ func initializeServices(ctx context.Context, cfg *config.Config, logger log.Logg
return nil, err
}

ctrl := ctrl.New(db, cfg, contract, teeService, logger)
ctrlInst := ctrl.New(database, cfg, contract, teeService, logger)

setup, err := services.NewSetup(db, cfg, contract, logger, storageClient, teeService)
setup, err := services.NewSetup(database, cfg, contract, logger, storageClient, teeService)
if err != nil {
return nil, err
}

executor, err := services.NewExecutor(db, cfg, contract, logger)
executor, err := services.NewExecutor(database, cfg, contract, logger)
if err != nil {
return nil, err
}

finalizer, err := services.NewFinalizer(db, cfg, contract, logger, storageClient, teeService)
finalizer, err := services.NewFinalizer(database, cfg, contract, logger, storageClient, teeService)
if err != nil {
return nil, err
}

settlement, err := services.NewSettlement(db, contract, cfg, teeService, logger)
settlement, err := services.NewSettlement(database, contract, cfg, teeService, logger)
if err != nil {
return nil, err
}

return &ApplicationServices{
db: db,
db: database,
storageClient: storageClient,
contract: contract,
teeService: teeService,
ctrl: ctrl,
ctrl: ctrlInst,
setup: setup,
executor: executor,
finalizer: finalizer,
settlement: settlement,
}, nil
}

func runApplication(ctx context.Context, cfg *config.Config, services *ApplicationServices, logger log.Logger, imageChan <-chan bool) error {
if err := services.db.MarkInProgressTasksAsFailed(); err != nil {
func runApplication(ctx context.Context, cfg *config.Config, svc *ApplicationServices, logger log.Logger, imageChan <-chan bool) error {
if err := svc.db.MarkInProgressTasksAsFailed(); err != nil {
return err
}

if err := services.ctrl.SyncServices(ctx); err != nil {
if err := svc.ctrl.SyncServices(ctx); err != nil {
return err
}

if err := services.finalizer.Start(ctx); err != nil {
if err := svc.finalizer.Start(ctx); err != nil {
return err
}

if err := services.executor.Start(ctx); err != nil {
if err := svc.executor.Start(ctx); err != nil {
return err
}

if err := services.setup.Start(ctx); err != nil {
if err := svc.setup.Start(ctx); err != nil {
return err
}

engine := gin.New()
h := handler.New(services.ctrl, logger, cfg.RateLimitRPS, cfg.RateLimitBurst)

var wg sync.WaitGroup

if cfg.Monitor.Enable {
monitor.Init(cfg.Service.ServingUrl, ctx)
engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
engine.Use(monitor.TrackMetrics())
wg.Add(1)
go func() {
defer wg.Done()
startTaskStatePoller(ctx, svc.db, logger)
}()
}

var servingProxy *serving.Proxy
if cfg.Serving.Enable {
servingMgr := serving.NewManager(svc.db, serving.ServingConfig{
Enable: cfg.Serving.Enable,
BaseModelPath: cfg.Serving.BaseModelPath,
InferenceGPUIDs: cfg.Serving.InferenceGPUIDs,
VLLMPort: cfg.Serving.VLLMPort,
MaxLoraRank: cfg.Serving.MaxLoraRank,
MaxLoraModules: cfg.Serving.MaxLoraModules,
MaxCpuLoras: cfg.Serving.MaxCpuLoras,
LoraModulesDir: cfg.Serving.LoraModulesDir,
OffloadAfterMinutes: cfg.Serving.OffloadAfterMinutes,
EnableColdStorage: cfg.Serving.EnableColdStorage,
ModelLoadTimeoutSeconds: cfg.Serving.ModelLoadTimeoutSeconds,
GpuMemoryUtilization: cfg.Serving.GpuMemoryUtilization,
}, logger, svc.storageClient)
if err := servingMgr.Start(ctx); err != nil {
return err
}
defer func() {
if err := servingMgr.Stop(); err != nil {
logger.Warnf("failed to stop vLLM: %v", err)
}
}()

registry := serving.NewRegistry(svc.contract, servingMgr, serving.RegistryConfig{
InputPrice: cfg.Serving.InputPrice,
OutputPrice: cfg.Serving.OutputPrice,
}, logger)
registry.Start(ctx)

servingProxy = serving.NewProxy(servingMgr, logger)
logger.Info("LoRA serving module initialized")
}

h := handler.New(svc.ctrl, logger, cfg.RateLimitRPS, cfg.RateLimitBurst, servingProxy)
h.Register(engine)

if _, ok := <-imageChan; !ok {
return errors.New("image build failed")
}

if err := services.settlement.Start(ctx); err != nil {
if err := svc.settlement.Start(ctx); err != nil {
return err
}

stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

// Listen and Serve, config port with PORT=X
go func() {
logger.Info("starting http server...")
if err := engine.Run(); err != nil {
Expand All @@ -292,40 +336,53 @@ func runApplication(ctx context.Context, cfg *config.Config, services *Applicati

<-stop
logger.Info("shutting down server...")
wg.Wait()
return nil
}

// copyDirectory recursively copies a directory from src to dst

// startTaskStatePoller blocks until ctx is cancelled, and on every tick of a
// 30-second ticker reads the per-state task counts from database and hands
// them to monitor.UpdateTaskStateGauge. A failed count is logged as a warning
// and that tick is skipped; the loop keeps running.
func startTaskStatePoller(ctx context.Context, database *db.DB, logger log.Logger) {
	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()

	// refresh performs one poll; on error the gauge is left untouched.
	refresh := func() {
		stateCounts, countErr := database.CountTasksByState()
		if countErr != nil {
			logger.Warnf("failed to count tasks by state for metrics: %v", countErr)
			return
		}
		monitor.UpdateTaskStateGauge(stateCounts)
	}

	for {
		select {
		case <-tick.C:
			refresh()
		case <-ctx.Done():
			return
		}
	}
}

func copyDirectory(src, dst string) error {
// Get file info of source
srcInfo, err := os.Stat(src)
if err != nil {
return err
}

// Create destination directory
if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil {
return err
}

// Read source directory
entries, err := os.ReadDir(src)
if err != nil {
return err
}

// Copy each entry
for _, entry := range entries {
srcPath := filepath.Join(src, entry.Name())
dstPath := filepath.Join(dst, entry.Name())

if entry.IsDir() {
// Recursively copy subdirectory
if err := copyDirectory(srcPath, dstPath); err != nil {
return err
}
} else {
// Copy file
if err := copyFile(srcPath, dstPath); err != nil {
return err
}
Expand All @@ -335,33 +392,24 @@ func copyDirectory(src, dst string) error {
return nil
}

// copyFile copies the contents of the file at src into dst.
//
// dst is created if it does not exist and truncated if it does; when it is
// created, the mode passed to os.OpenFile is the source file's mode. The
// first error encountered (open, stat, create, copy, or closing dst) is
// returned.
func copyFile(src, dst string) (err error) {
	srcFile, err := os.Open(src)
	if err != nil {
		return err
	}
	// The source is only read, so an error from closing it is not actionable.
	defer srcFile.Close()

	srcInfo, err := srcFile.Stat()
	if err != nil {
		return err
	}

	dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, srcInfo.Mode())
	if err != nil {
		return err
	}
	// Closing a written file can report an error that the writes themselves
	// did not; surface it unless an earlier error is already being returned.
	defer func() {
		if closeErr := dstFile.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(dstFile, srcFile)
	return err
}
Loading