Skip to content

Commit

Permalink
fix tail calls, comformence and add a helper (#14)
Browse files Browse the repository at this point in the history
* resolve references

(cherry picked from commit 66737d9443e4e9bb168d95d74d1a5f74c0dd6927)

* fix tail call

* fix conformence
  • Loading branch information
safchain authored Nov 19, 2024
1 parent cc302cb commit adf57ad
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 36 deletions.
1 change: 1 addition & 0 deletions ebpf/include/baloum.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static int (*baloum_call)(struct baloum_ctx *ctx, const char *section) = (void *
static int (*baloum_strcmp)(const char *s1, const char *s2) = (void *)0xfffd;
static int (*baloum_memcmp)(const void *b1, const void *b2, __u32 size) = (void *)0xfffc;
static int (*baloum_sleep)(__u64 ns) = (void *)0xfffb;
static int (*baloum_memcpy)(const void *b1, const void *b2, __u32 size) = (void *)0xfffa;

#define assert_memcmp(b1, b2, s, msg) \
if (baloum_memcmp(b1, b2, s) != 0) \
Expand Down
35 changes: 32 additions & 3 deletions pkg/baloum/fncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ const (

// static int (*baloum_sleep)(__u64 ns) = (void *)0xfffb;
FnSleep = asm.BuiltinFunc(0xfffb)

// static int (*baloum_memcpy)(const void *b1, const void *b2, __u32 size) = (void *)0xfffa;
FnMemCpy = asm.BuiltinFunc(0xfffa)
)

var (
Expand All @@ -51,6 +54,7 @@ var (
FnStrCmp: FnStrCmpImpl,
FnMemCmp: FnMemCmpImpl,
FnSleep: FnSleepImpl,
FnMemCpy: FnMemCpyImpl,

// bpf helpers
asm.FnTracePrintk: FnTracePrintkImpl,
Expand Down Expand Up @@ -147,6 +151,20 @@ func FnMemCmpImpl(vm *VM, inst *asm.Instruction) error {
return nil
}

func FnMemCpyImpl(vm *VM, inst *asm.Instruction) error {
code := ErrorCode
vm.regs[asm.R0] = uint64(code)

size := vm.regs[asm.R3]

srcBytes, err := vm.getBytes(vm.regs[asm.R2], size)
if err != nil {
return err
}

return vm.setBytes(vm.regs[asm.R1], srcBytes, size)
}

var (
reFmt = regexp.MustCompile("(%[^%])")
)
Expand Down Expand Up @@ -374,7 +392,10 @@ func FnGetSmpProcessorIdImpl(vm *VM, inst *asm.Instruction) error {
}

func FnTailCallImpl(vm *VM, inst *asm.Instruction) error {
vm.regs[asm.R0] = 0
if vm.tailCails >= 32 {
return errors.New("maximum tail calls reach")
}
vm.tailCails++

_map := vm.maps.GetMapById(int(vm.regs[asm.R2]))
if _map == nil {
Expand Down Expand Up @@ -407,15 +428,23 @@ func FnTailCallImpl(vm *VM, inst *asm.Instruction) error {
return errors.New("value size not supported")
}

if int(fd) >= len(vm.programs) {
if fd == 0 {
return errors.New("program not found")
}

progIndex := fd - 1

if progIndex > len(vm.programs) {
return errors.New("out of bound")
}

program := vm.programs[fd]
program := vm.programs[progIndex]
if program.Type != vm.progType {
return errors.New("program types differ")
}

vm.regs[asm.R0] = 0

ret, err := vm.RunInstructions(vm.ctx, program.Instructions)
vm.regs[asm.R0] = uint64(ret)

Expand Down
4 changes: 2 additions & 2 deletions pkg/baloum/map_prog_array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestTailCall(t *testing.T) {
}

vm := NewVM(spec, Opts{Fncs: fncs, Logger: suggar})
err = vm.LoadMaps("test/tail_call")
err = vm.LoadMapsUsedBy("test/tail_call")
if err != nil {
log.Fatal(err)
}
Expand All @@ -72,7 +72,7 @@ func TestTailCall(t *testing.T) {

var ctx StdContext
code, err := vm.RunProgram(&ctx, "test/tail_call")
assert.Equal(t, 72, code)
assert.Equal(t, int64(72), code)
assert.Nil(t, err)

data, err := vm.Map("data").Lookup(uint64(0))
Expand Down
14 changes: 8 additions & 6 deletions pkg/baloum/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"runtime"
"time"

"github.com/cilium/ebpf"
"github.com/cilium/ebpf/asm"
)

Expand All @@ -36,12 +37,13 @@ type Fncs struct {
}

type Opts struct {
StackSize int
Fncs Fncs
RawFncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error
Logger Logger
CPUs int
Observer Observer
StackSize int
Fncs Fncs
RawFncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error
Logger Logger
CPUs int
Observer Observer
ProgramType ebpf.ProgramType
}

func defaultKtimeGetNS(vm *VM) (uint64, error) {
Expand Down
136 changes: 111 additions & 25 deletions pkg/baloum/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

const (
ErrorCode = int(-1)
ErrorCode = int64(-1)

// bitmasks
fetchBit = 0x01
Expand All @@ -41,17 +41,18 @@ type vmState struct {
type program []asm.Instruction

type VM struct {
Spec *ebpf.CollectionSpec
Opts Opts
stack []byte
heap *Heap
regs Regs
fncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error
strs map[string]uint64
maps *MapCollection
programs []*ebpf.ProgramSpec
progType ebpf.ProgramType
ctx Context
Spec *ebpf.CollectionSpec
Opts Opts
stack []byte
heap *Heap
regs Regs
fncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error
strs map[string]uint64
maps *MapCollection
programs []*ebpf.ProgramSpec
progType ebpf.ProgramType
ctx Context
tailCails int
}

func NewVM(spec *ebpf.CollectionSpec, opts Opts) *VM {
Expand Down Expand Up @@ -210,6 +211,21 @@ func (vm *VM) setUint8(addr uint64, value uint8) error {
return nil
}

func (vm *VM) setBytes(addr uint64, value []byte, size uint64) error {
bytes, err := vm.getBytes(addr, size)
if err != nil {
return err
}

if int(size) > len(bytes) {
return errors.New("not enough space")
}

copy(bytes, value[:size])

return nil
}

func (vm *VM) atomicUint64(addr uint64, inc uint64, imm int64) (uint64, bool, error) {
value, err := vm.getUint64(addr)
if err != nil {
Expand Down Expand Up @@ -329,6 +345,50 @@ func (vm *VM) initFncs() {
}
}

func resolveSymbolReferences(insts asm.Instructions) asm.Instructions {
var resolved asm.Instructions

symbols := make(map[string]int)

for offset, ins := range insts {
if symbol := ins.Symbol(); symbol != "" {
symbols[symbol] = offset
}
}

for i, inst := range insts {
resolved = append(resolved, inst)

if ref := inst.Reference(); ref != "" {
offset, exists := symbols[ref]
if exists {
var inc int

// correct with size of instruction size
delta := offset - i - 1
if delta > 0 {
for j := 0; j != delta; j++ {
if insts[i+j].Size() > 8 {
inc++
}
}
} else {
for j := 0; j != delta; j-- {
if insts[i+j].Size() > 8 {
inc--
}
}
}

inst.Offset = int16(delta + inc)
resolved[i] = inst
}
}
}

return resolved
}

func normalizeInsts(insts []asm.Instruction) []asm.Instruction {
var normInsts []asm.Instruction

Expand All @@ -343,7 +403,9 @@ func normalizeInsts(insts []asm.Instruction) []asm.Instruction {
return normInsts
}

func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) {
func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int64, error) {
// prepare the instruction
insts = resolveSymbolReferences(insts)
insts = normalizeInsts(insts)

state := vm.saveState()
Expand Down Expand Up @@ -451,9 +513,9 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error)
case asm.ArSh.Op(asm.RegSource):
vm.regs[inst.Dst] = uint64(int64(vm.regs[inst.Dst]) >> uint64(vm.regs[inst.Src]%64))
case asm.ArSh.Op32(asm.ImmSource):
vm.regs[inst.Dst] = uint64(int32(vm.regs[inst.Dst]) >> uint32(uint32(inst.Constant)%32))
vm.regs[inst.Dst] = uint64(uint32(int32(vm.regs[inst.Dst]) >> (uint32(inst.Constant) % 32)))
case asm.ArSh.Op32(asm.RegSource):
vm.regs[inst.Dst] = uint64(int32(vm.regs[inst.Dst]) >> uint32(vm.regs[inst.Src]%32))
vm.regs[inst.Dst] = uint64(uint32(int32(vm.regs[inst.Dst]) >> (vm.regs[inst.Src] % 32)))

//
case asm.StoreMemOp(asm.DWord):
Expand Down Expand Up @@ -837,7 +899,7 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error)

// if tail call endup here
if builtin == asm.FnTailCall {
return int(int32(vm.regs[asm.R0])), nil
return int64(int32(vm.regs[asm.R0])), nil
}
} else {
return ErrorCode, fmt.Errorf("unknown function: `%v`", inst.Src)
Expand All @@ -857,7 +919,7 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error)

//
case asm.Exit.Op(asm.ImmSource):
return int(int32(vm.regs[asm.R0])), nil
return int64(vm.regs[asm.R0]), nil
default:
if opcode.Class().IsALU() && opcode.ALUOp() == asm.Swap {
buff := make([]byte, 8)
Expand Down Expand Up @@ -885,18 +947,29 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error)
} else {
return ErrorCode, fmt.Errorf("unknown op: %v", inst)
}

}
}

return ErrorCode, errors.New("unexpected error")
}

func (vm *VM) LoadMap(name string) error {
return vm.maps.LoadMap(vm.Spec, name)
func (vm *VM) LoadMap(name string) (*Map, error) {
if err := vm.maps.LoadMap(vm.Spec, name); err != nil {
return nil, err
}
return vm.maps.mapByName[name], nil
}

func (vm *VM) LoadMaps(section ...string) error {
func (vm *VM) LoadMaps(names ...string) error {
for _, name := range names {
if _, err := vm.LoadMap(name); err != nil {
return err
}
}
return nil
}

func (vm *VM) LoadMapsUsedBy(section ...string) error {
return vm.maps.LoadMaps(vm.Spec, section...)
}

Expand Down Expand Up @@ -925,9 +998,18 @@ func (vm *VM) loadSection(section string) (*ebpf.ProgramSpec, error) {
return program, nil
}

func (vm *VM) Program(name string) (*ebpf.ProgramSpec, uint32) {
for i, programSpec := range vm.programs {
if programSpec.Name == name {
return programSpec, uint32(i)
}
}
return nil, 0
}

func (vm *VM) AddProgram(program *ebpf.ProgramSpec) uint32 {
// FD is the index in the map of programs
fd := uint32(len(vm.programs))
// FD is the index in the map of programs + 1
fd := uint32(len(vm.programs)) + 1
vm.programs = append(vm.programs, program)

return fd
Expand All @@ -943,14 +1025,18 @@ func (vm *VM) LoadProgram(section string) (uint32, error) {
return fd, nil
}

func (vm *VM) RunProgram(ctx Context, section string) (int, error) {
func (vm *VM) RunProgram(ctx Context, section string, programType ...ebpf.ProgramType) (int64, error) {
program, err := vm.loadSection(section)
if err != nil {
return ErrorCode, err
}

// keep current type and context
vm.progType = program.Type
if len(programType) > 0 {
vm.progType = programType[0]
} else {
vm.progType = program.Type
}
vm.ctx = ctx

return vm.RunInstructions(ctx, program.Instructions)
Expand Down

0 comments on commit adf57ad

Please sign in to comment.