diff --git a/ebpf/include/baloum.h b/ebpf/include/baloum.h index b74e24c..c423876 100644 --- a/ebpf/include/baloum.h +++ b/ebpf/include/baloum.h @@ -16,6 +16,7 @@ static int (*baloum_call)(struct baloum_ctx *ctx, const char *section) = (void * static int (*baloum_strcmp)(const char *s1, const char *s2) = (void *)0xfffd; static int (*baloum_memcmp)(const void *b1, const void *b2, __u32 size) = (void *)0xfffc; static int (*baloum_sleep)(__u64 ns) = (void *)0xfffb; +static int (*baloum_memcpy)(const void *b1, const void *b2, __u32 size) = (void *)0xfffa; #define assert_memcmp(b1, b2, s, msg) \ if (baloum_memcmp(b1, b2, s) != 0) \ diff --git a/pkg/baloum/fncs.go b/pkg/baloum/fncs.go index 805b476..3525d9f 100644 --- a/pkg/baloum/fncs.go +++ b/pkg/baloum/fncs.go @@ -41,6 +41,9 @@ const ( // static int (*baloum_sleep)(__u64 ns) = (void *)0xfffb; FnSleep = asm.BuiltinFunc(0xfffb) + + // static int (*baloum_memcpy)(const void *b1, const void *b2, __u32 size) = (void *)0xfffa; + FnMemCpy = asm.BuiltinFunc(0xfffa) ) var ( @@ -51,6 +54,7 @@ var ( FnStrCmp: FnStrCmpImpl, FnMemCmp: FnMemCmpImpl, FnSleep: FnSleepImpl, + FnMemCpy: FnMemCpyImpl, // bpf helpers asm.FnTracePrintk: FnTracePrintkImpl, @@ -147,6 +151,20 @@ func FnMemCmpImpl(vm *VM, inst *asm.Instruction) error { return nil } +func FnMemCpyImpl(vm *VM, inst *asm.Instruction) error { + code := ErrorCode + vm.regs[asm.R0] = uint64(code) + + size := vm.regs[asm.R3] + + srcBytes, err := vm.getBytes(vm.regs[asm.R2], size) + if err != nil { + return err + } + + return vm.setBytes(vm.regs[asm.R1], srcBytes, size) +} + var ( reFmt = regexp.MustCompile("(%[^%])") ) @@ -374,7 +392,10 @@ func FnGetSmpProcessorIdImpl(vm *VM, inst *asm.Instruction) error { } func FnTailCallImpl(vm *VM, inst *asm.Instruction) error { - vm.regs[asm.R0] = 0 + if vm.tailCails >= 32 { + return errors.New("maximum tail calls reach") + } + vm.tailCails++ _map := vm.maps.GetMapById(int(vm.regs[asm.R2])) if _map == nil { @@ -407,15 +428,23 @@ func FnTailCallImpl(vm *VM, inst *asm.Instruction) error { return errors.New("value size not supported") } - if int(fd) >= len(vm.programs) { + if fd == 0 { + return errors.New("program not found") + } + + progIndex := fd - 1 + + if progIndex > len(vm.programs) { return errors.New("out of bound") } - program := vm.programs[fd] + program := vm.programs[progIndex] if program.Type != vm.progType { return errors.New("program types differ") } + vm.regs[asm.R0] = 0 + ret, err := vm.RunInstructions(vm.ctx, program.Instructions) vm.regs[asm.R0] = uint64(ret) diff --git a/pkg/baloum/map_prog_array_test.go b/pkg/baloum/map_prog_array_test.go index 10e84cc..8a0a5d2 100644 --- a/pkg/baloum/map_prog_array_test.go +++ b/pkg/baloum/map_prog_array_test.go @@ -51,7 +51,7 @@ func TestTailCall(t *testing.T) { } vm := NewVM(spec, Opts{Fncs: fncs, Logger: suggar}) - err = vm.LoadMaps("test/tail_call") + err = vm.LoadMapsUsedBy("test/tail_call") if err != nil { log.Fatal(err) } @@ -72,7 +72,7 @@ func TestTailCall(t *testing.T) { var ctx StdContext code, err := vm.RunProgram(&ctx, "test/tail_call") - assert.Equal(t, 72, code) + assert.Equal(t, int64(72), code) assert.Nil(t, err) data, err := vm.Map("data").Lookup(uint64(0)) diff --git a/pkg/baloum/opts.go b/pkg/baloum/opts.go index 0a51af4..1a474ba 100644 --- a/pkg/baloum/opts.go +++ b/pkg/baloum/opts.go @@ -20,6 +20,7 @@ import ( "runtime" "time" + "github.com/cilium/ebpf" "github.com/cilium/ebpf/asm" ) @@ -36,12 +37,13 @@ type Fncs struct { } type Opts struct { - StackSize int - Fncs Fncs - RawFncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error - Logger Logger - CPUs int - Observer Observer + StackSize int + Fncs Fncs + RawFncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error + Logger Logger + CPUs int + Observer Observer + ProgramType ebpf.ProgramType } func defaultKtimeGetNS(vm *VM) (uint64, error) { diff --git a/pkg/baloum/vm.go b/pkg/baloum/vm.go index e3bfdd7..702bfc9 100644 --- a/pkg/baloum/vm.go +++ b/pkg/baloum/vm.go @@ -27,7 +27,7 @@ import ( ) const ( - ErrorCode = int(-1) + ErrorCode = int64(-1) // bitmasks fetchBit = 0x01 @@ -41,17 +41,18 @@ type vmState struct { type program []asm.Instruction type VM struct { - Spec *ebpf.CollectionSpec - Opts Opts - stack []byte - heap *Heap - regs Regs - fncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error - strs map[string]uint64 - maps *MapCollection - programs []*ebpf.ProgramSpec - progType ebpf.ProgramType - ctx Context + Spec *ebpf.CollectionSpec + Opts Opts + stack []byte + heap *Heap + regs Regs + fncs map[asm.BuiltinFunc]func(*VM, *asm.Instruction) error + strs map[string]uint64 + maps *MapCollection + programs []*ebpf.ProgramSpec + progType ebpf.ProgramType + ctx Context + tailCails int } func NewVM(spec *ebpf.CollectionSpec, opts Opts) *VM { @@ -210,6 +211,21 @@ func (vm *VM) setUint8(addr uint64, value uint8) error { return nil } +func (vm *VM) setBytes(addr uint64, value []byte, size uint64) error { + bytes, err := vm.getBytes(addr, size) + if err != nil { + return err + } + + if int(size) > len(bytes) { + return errors.New("not enough space") + } + + copy(bytes, value[:size]) + + return nil +} + func (vm *VM) atomicUint64(addr uint64, inc uint64, imm int64) (uint64, bool, error) { value, err := vm.getUint64(addr) if err != nil { @@ -329,6 +345,50 @@ func (vm *VM) initFncs() { } } +func resolveSymbolReferences(insts asm.Instructions) asm.Instructions { + var resolved asm.Instructions + + symbols := make(map[string]int) + + for offset, ins := range insts { + if symbol := ins.Symbol(); symbol != "" { + symbols[symbol] = offset + } + } + + for i, inst := range insts { + resolved = append(resolved, inst) + + if ref := inst.Reference(); ref != "" { + offset, exists := symbols[ref] + if exists { + var inc int + + // correct with size of instruction size + delta := offset - i - 1 + if delta > 0 { + for j := 0; j != delta; j++ { + if insts[i+j].Size() > 8 { + inc++ + } + } + } else { + for j := 0; j != delta; j-- { + if insts[i+j].Size() > 8 { + inc-- + } + } + } + + inst.Offset = int16(delta + inc) + resolved[i] = inst + } + } + } + + return resolved +} + func normalizeInsts(insts []asm.Instruction) []asm.Instruction { var normInsts []asm.Instruction @@ -343,7 +403,9 @@ func normalizeInsts(insts []asm.Instruction) []asm.Instruction { return normInsts } -func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) { +func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int64, error) { + // prepare the instruction + insts = resolveSymbolReferences(insts) insts = normalizeInsts(insts) state := vm.saveState() @@ -451,9 +513,9 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) case asm.ArSh.Op(asm.RegSource): vm.regs[inst.Dst] = uint64(int64(vm.regs[inst.Dst]) >> uint64(vm.regs[inst.Src]%64)) case asm.ArSh.Op32(asm.ImmSource): - vm.regs[inst.Dst] = uint64(int32(vm.regs[inst.Dst]) >> uint32(uint32(inst.Constant)%32)) + vm.regs[inst.Dst] = uint64(uint32(int32(vm.regs[inst.Dst]) >> (uint32(inst.Constant) % 32))) case asm.ArSh.Op32(asm.RegSource): - vm.regs[inst.Dst] = uint64(int32(vm.regs[inst.Dst]) >> uint32(vm.regs[inst.Src]%32)) + vm.regs[inst.Dst] = uint64(uint32(int32(vm.regs[inst.Dst]) >> (vm.regs[inst.Src] % 32))) // case asm.StoreMemOp(asm.DWord): @@ -837,7 +899,7 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) // if tail call endup here if builtin == asm.FnTailCall { - return int(int32(vm.regs[asm.R0])), nil + return int64(int32(vm.regs[asm.R0])), nil } } else { return ErrorCode, fmt.Errorf("unknown function: `%v`", inst.Src) @@ -857,7 +919,7 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) // case asm.Exit.Op(asm.ImmSource): - return int(int32(vm.regs[asm.R0])), nil + return int64(vm.regs[asm.R0]), nil default: if opcode.Class().IsALU() && opcode.ALUOp() == asm.Swap { buff := make([]byte, 8) @@ -885,18 +947,29 @@ func (vm *VM) RunInstructions(ctx Context, insts []asm.Instruction) (int, error) } else { return ErrorCode, fmt.Errorf("unknown op: %v", inst) } - } } return ErrorCode, errors.New("unexpected error") } -func (vm *VM) LoadMap(name string) error { - return vm.maps.LoadMap(vm.Spec, name) +func (vm *VM) LoadMap(name string) (*Map, error) { + if err := vm.maps.LoadMap(vm.Spec, name); err != nil { + return nil, err + } + return vm.maps.mapByName[name], nil } -func (vm *VM) LoadMaps(section ...string) error { +func (vm *VM) LoadMaps(names ...string) error { + for _, name := range names { + if _, err := vm.LoadMap(name); err != nil { + return err + } + } + return nil +} + +func (vm *VM) LoadMapsUsedBy(section ...string) error { return vm.maps.LoadMaps(vm.Spec, section...) } @@ -925,9 +998,18 @@ func (vm *VM) loadSection(section string) (*ebpf.ProgramSpec, error) { return program, nil } +func (vm *VM) Program(name string) (*ebpf.ProgramSpec, uint32) { + for i, programSpec := range vm.programs { + if programSpec.Name == name { + return programSpec, uint32(i) + } + } + return nil, 0 +} + func (vm *VM) AddProgram(program *ebpf.ProgramSpec) uint32 { - // FD is the index in the map of programs - fd := uint32(len(vm.programs)) + // FD is the index in the map of programs + 1 + fd := uint32(len(vm.programs)) + 1 vm.programs = append(vm.programs, program) return fd @@ -943,14 +1025,18 @@ func (vm *VM) LoadProgram(section string) (uint32, error) { return fd, nil } -func (vm *VM) RunProgram(ctx Context, section string) (int, error) { +func (vm *VM) RunProgram(ctx Context, section string, programType ...ebpf.ProgramType) (int64, error) { program, err := vm.loadSection(section) if err != nil { return ErrorCode, err } // keep current type and context - vm.progType = program.Type + if len(programType) > 0 { + vm.progType = programType[0] + } else { + vm.progType = program.Type + } vm.ctx = ctx return vm.RunInstructions(ctx, program.Instructions)