From 9cea5f5042ce298ed35eb97d70e59402fc2e5c70 Mon Sep 17 00:00:00 2001 From: Andrew Tolmach Date: Tue, 20 Dec 2022 14:15:59 -0800 Subject: [PATCH] improve x86 emulation so it can handle final programs up through lambda --- interp_x86/convert_x86.py | 6 +- interp_x86/eval_x86.py | 122 +++++++++++++++++++++++++++++++------- utils.py | 12 ++++ 3 files changed, 118 insertions(+), 22 deletions(-) diff --git a/interp_x86/convert_x86.py b/interp_x86/convert_x86.py index 336451e..76cab0e 100644 --- a/interp_x86/convert_x86.py +++ b/interp_x86/convert_x86.py @@ -27,7 +27,7 @@ def convert_arg(arg): return Tree('mem_a', [convert_int(offset), reg]) case ByteReg(id): return Tree('reg_a', [id]) - case GlobalValue(id): + case Global(id): return Tree('global_val_a', [id, 'rip']) case _: raise Exception('convert_arg: unhandled ' + repr(arg)) @@ -38,10 +38,14 @@ def convert_instr(instr): return Tree(instr, [convert_arg(arg) for arg in args]) case Callq(func, args): return Tree('callq', [func]) + case IndirectCallq(func, args): + return Tree('indirect_callq', [convert_arg(func)]) case Jump(label): return Tree('jmp', [label]) case JumpIf(cc, label): return Tree('j' + cc, [label]) + case IndirectJump(func): + return Tree('indirect_jmp', [convert_arg(func)]) case _: raise Exception('error in convert_instr, unhandled ' + repr(instr)) diff --git a/interp_x86/eval_x86.py b/interp_x86/eval_x86.py index a1a9b6a..85d55f2 100644 --- a/interp_x86/eval_x86.py +++ b/interp_x86/eval_x86.py @@ -1,5 +1,8 @@ # Author: Joe Near # License: GPLv3 +# Notice of modifications as per license: +# Modified 9/22 Andrew Tolmach to use 64-bit machine arithmetic +# Modified 10/22 Andrew Tolmach to support test that callee-save registers are preserved by top-level function from collections import defaultdict from dataclasses import dataclass @@ -10,9 +13,9 @@ from parser_x86 import x86_parser, x86_parser_instrs -def interp_x86(program): +def interp_x86(program, check_regs = False): x86_program = convert_program(program) - emu = X86Emulator(logging=False) + emu = X86Emulator(logging=False, check_regs=check_regs) x86_output = emu.eval_program(x86_program) for s in x86_output: print(s, end='') @@ -21,21 +24,34 @@ def interp_x86(program): class FunPointer: fun_name: str +@dataclass +class Bogus: + pass + class X86Emulator: - def __init__(self, logging=True): + def __init__(self, logging=True, check_regs=False): self.registers = defaultdict(lambda: None) self.memory = defaultdict(lambda: None) self.variables = defaultdict(lambda: None) self.logging = logging - self.registers['rbp'] = 1000 - self.registers['rsp'] = 1000 - + self.check_regs = check_regs + self.callee_saves = ['rsp','rbp','rbx','r12','r13','r14'] # omit 'r15' + bogus = Bogus() + self.initial_callee_save_values = [1000,1000,bogus,bogus,bogus,bogus,bogus] + for r,v in zip(self.callee_saves, self.initial_callee_save_values): + self.registers[r] = v self.global_vals = {} def log(self, s): if self.logging: print(s) + def trash_caller_saves(self): + caller_saves = ['rax', 'rcx', 'rdx', 'rsi', 'rdi', 'r8', 'r9', 'r10'] # omit 'r11' + if self.check_regs: + for r in caller_saves: + self.registers[r] = Bogus() + def parse_and_eval_program(self, s): p = x86_parser.parse(s) @@ -70,6 +86,11 @@ def eval_program(self, p): self.log(f'OUTPUT: {output}') self.log('========== FINISHED EXECUTION ==============================') + if self.check_regs: + for r,v in zip(self.callee_saves, self.initial_callee_save_values): + if self.registers[r] != v: + print(f'OOPS: initial callee save value overwritten: {r} {self.registers[r]}') # will show as a diff with .golden + return output def eval_instructions(self, s): @@ -248,12 +269,42 @@ def eval_instrs(self, instrs, blocks, output): v2 = self.eval_arg(a2) self.store_arg(a2, sub64(v2, v1)) + elif instr.data == 'imulq': + a1, a2 = instr.children + v1 = self.eval_arg(a1) + v2 = self.eval_arg(a2) + self.store_arg(a2, mul64(v1, v2)) + elif instr.data == 'xorq': a1, a2 = instr.children v1 = self.eval_arg(a1) v2 = self.eval_arg(a2) self.store_arg(a2, xor64(v1, v2)) + elif instr.data == 'andq': + a1, a2 = instr.children + v1 = self.eval_arg(a1) + v2 = self.eval_arg(a2) + self.store_arg(a2, and64(v1, v2)) + + elif instr.data == 'orq': + a1, a2 = instr.children + v1 = self.eval_arg(a1) + v2 = self.eval_arg(a2) + self.store_arg(a2, or64(v1, v2)) + + elif instr.data == 'sarq': + a1, a2 = instr.children + v1 = self.eval_arg(a1) + v2 = self.eval_arg(a2) + self.store_arg(a2, shiftra64(v2, v1)) + + elif instr.data == 'salq': + a1, a2 = instr.children + v1 = self.eval_arg(a1) + v2 = self.eval_arg(a2) + self.store_arg(a2, shiftl64(v2, v1)) + elif instr.data == 'negq': a1 = instr.children[0] v1 = self.eval_arg(a1) @@ -318,14 +369,16 @@ def eval_instrs(self, instrs, blocks, output): output.append(self.registers['rdi']) if self.logging: print(self.print_state()) + self.trash_caller_saves() elif target == label_name('read_int'): + self.trash_caller_saves() self.registers['rax'] = input_int() self.log(f'CALL TO read_int: {self.registers["rax"]}') if self.logging: print(self.print_state()) - elif target == 'initialize': + elif target == label_name('initialize'): self.log(f'CALL TO initialize: {self.registers["rdi"]}, {self.registers["rsi"]}') rootstack_size = self.registers['rdi'] heap_size = self.registers['rsi'] @@ -337,23 +390,24 @@ def eval_instrs(self, instrs, blocks, output): fromspace_end = fromspace_begin + heap_size self.global_vals = { **self.global_vals, - 'rootstack_begin': rs_begin, - 'rootstack_end': rs_end, - 'free_ptr': fromspace_begin, - 'fromspace_begin': fromspace_begin, - 'fromspace_end': fromspace_end + label_name('rootstack_begin'): rs_begin, + label_name('rootstack_end'): rs_end, + label_name('free_ptr'): fromspace_begin, + label_name('fromspace_begin'): fromspace_begin, + label_name('fromspace_end'): fromspace_end } if self.logging: print(self.print_state()) + self.trash_caller_saves() - elif target == 'collect': + elif target == label_name('collect'): self.log(f'CALL TO collect: need {self.registers["rsi"]} bytes') needed = self.registers["rsi"] - fsb = self.global_vals['fromspace_begin'] - fse = self.global_vals['fromspace_end'] + fsb = self.global_vals[label_name('fromspace_begin')] + fse = self.global_vals[label_name('fromspace_end')] current_space = fse - fsb @@ -362,13 +416,37 @@ def eval_instrs(self, instrs, blocks, output): new_space = new_space * 2 new_fse = fsb + new_space - self.global_vals['fromspace_end'] = new_fse + self.global_vals[label_name('fromspace_end')] = new_fse if self.logging: print(self.print_state()) + self.trash_caller_saves() + + # this seems to be buggy + # elif target == label_name('get_vec_length'): + # tag = self.registers["rsi"] + # self.log(f'CALL TO get_vec_length: tag {tag}') + + # TAG_VECOF_RSHIFT = 62 + # TAG_VECOF_LENGTH_RSHIFT = 2 + # TAG_VEC_LENGTH_MASK = 126 # 1111110 + # TAG_VEC_LENGTH_RSHIFT = 1 + # # is_vecof + # if and64(1, shiftra64(tag,TAG_VECOF_RSHIFT)) != 0: + # # get_vecof_length + # res = shiftra64(shiftra64(shiftl64(tag,2),2),TAG_VECOF_LENGTH_RSHIFT) + # else: + # # get_vector_length + # res = shiftra64(and64(tag, TAG_VEC_LENGTH_MASK),TAG_VEC_LENGTH_RSHIFT) + # self.trash_caller_saves() + # self.registers['rax'] = res + # if self.logging: + # print(self.print_state()) else: + self.registers['rsp'] = self.registers['rsp'] - 8 # simulate push of RA self.eval_instrs(blocks[target], blocks, output) + self.registers['rsp'] = self.registers['rsp'] + 8 # and pop of RA elif instr.data == 'retq': return @@ -380,24 +458,26 @@ def eval_instrs(self, instrs, blocks, output): if v1 == v2: self.registers['EFLAGS'] = 'e' - elif v2 < v1: + elif type(v1) == int and type(v2) == int and v2 < v1: self.registers['EFLAGS'] = 'l' - elif v2 > v1: + elif type(v1) == int and type(v2) == int and v2 > v1: self.registers['EFLAGS'] = 'g' - else: - raise RuntimeError(f'failed comparison: {instr}') +# else: +# raise RuntimeError(f'failed comparison: {instr} {v1} {v2}') elif instr.data == 'leaq': a1, a2 = instr.children v1 = self.eval_arg(a1) - assert isinstance(v1, FunPointer) +# assert isinstance(v1, FunPointer) self.store_arg(a2, v1) elif instr.data == 'indirect_callq': v = self.eval_arg(instr.children[0]) assert isinstance(v, FunPointer) target = v.fun_name + self.registers['rsp'] = self.registers['rsp'] - 8 self.eval_instrs(blocks[target], blocks, output) + self.registers['rsp'] = self.registers['rsp'] + 8 elif instr.data == 'indirect_jmp': v = self.eval_arg(instr.children[0]) diff --git a/utils.py b/utils.py index a6f629c..13cb17a 100644 --- a/utils.py +++ b/utils.py @@ -1075,6 +1075,18 @@ def mul64(x,y): def neg64(x): return to_signed(-x) +def shiftra64(x,y): + return to_signed(x>>y) + +def shiftl64(x,y): + return to_signed(x<