Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment with replacing textual representation in classfile with method that builds the model #274

Open
wants to merge 13 commits into
base: code-reflection
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ private static <O extends Op & Op.Invokable> void generateMethod(MethodHandles.L
iop.body().blocks(), cob, lambdaSink, quotable).generate()));
}

public static byte[] addOpByteCodeToClassFile(MethodHandles.Lookup lookup, ClassModel cm, String methodName, FuncOp builderOp) {
var bytes = generateClassData(lookup, methodName, builderOp);
var builderMethod = ClassFile.of().parse(bytes).methods().stream()
.filter(mm -> mm.methodName().equalsString(methodName)).findFirst().get();
var newBytes = ClassFile.of().build(cm.thisClass().asSymbol(), cb -> {
for (var ce : cm) {
cb.with(ce);
}
cb.with(builderMethod);
});
return newBytes;
}

private record Slot(int slot, TypeKind typeKind) {}
private record ExceptionRegionWithBlocks(ExceptionRegionEnter ere, BitSet blocks) {}

Expand Down Expand Up @@ -848,7 +861,7 @@ private void generate() {
.dup();
processOperands(op);
cob.invokespecial(
((JavaType) op.resultType()).toNominalDescriptor(),
jt.toNominalDescriptor(),
ConstantDescs.INIT_NAME,
MethodRef.toNominalDescriptor(op.constructorType())
.changeReturnType(ConstantDescs.CD_void));
Expand All @@ -860,8 +873,27 @@ private void generate() {
push(op.result());
}
case InvokeOp op -> {
// @@@ var args
processOperands(op);
if (op.isVarArgs()) {
processOperands(op.operands().subList(0, op.operands().size() - op.varArgOperands().size()));
var varArgOperands = op.varArgOperands();
cob.loadConstant(varArgOperands.size());
var compType = ((ArrayType) op.invokeDescriptor().type().parameterTypes().getLast()).componentType();
var typeKind = TypeKind.fromDescriptor(compType.toNominalDescriptor().descriptorString());
if (TypeKind.REFERENCE.equals(typeKind)) {
var cd = ClassDesc.ofDescriptor(compType.toNominalDescriptor().descriptorString());
cob.anewarray(cd);
} else {
cob.newarray(typeKind);
}
for (int j = 0; j < varArgOperands.size(); j++) {
cob.dup();
cob.loadConstant(j);
load(varArgOperands.get(j));
cob.arrayStore(typeKind);
}
} else {
processOperands(op);
}
// Resolve referenced class to determine if interface
MethodRef md = op.invokeDescriptor();
JavaType refType = (JavaType)md.refType();
Expand All @@ -873,9 +905,6 @@ private void generate() {
}
// Determine invoke opcode
final boolean isInterface = refClass.isInterface();
if (op.isVarArgs()) {
throw new UnsupportedOperationException("invoke varargs unsupported: " + op.invokeDescriptor());
}
Opcode invokeOpcode = switch (op.invokeKind()) {
case STATIC ->
Opcode.INVOKESTATIC;
Expand Down
89 changes: 89 additions & 0 deletions test/jdk/java/lang/reflect/code/bytecode/TestVarArg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import org.testng.Assert;
import org.testng.annotations.Test;

import java.lang.classfile.ClassFile;
import java.lang.classfile.components.ClassPrinter;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.lang.reflect.code.OpTransformer;
import java.lang.reflect.code.bytecode.BytecodeGenerator;
import java.lang.reflect.code.op.CoreOp;
import java.lang.runtime.CodeReflection;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.Stream;

/*
* @test
* @enablePreview
* @run testng TestVarArg
*
*/
public class TestVarArg {

@Test
void test() throws Throwable {
var f = getFuncOp("f");
f.writeTo(System.out);

var lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
lf.writeTo(System.out);

var bytes = BytecodeGenerator.generateClassData(MethodHandles.lookup(), f);
var classModel = ClassFile.of().parse(bytes);
ClassPrinter.toYaml(classModel, ClassPrinter.Verbosity.TRACE_ALL, System.out::print);

MethodHandle mh = BytecodeGenerator.generate(MethodHandles.lookup(), lf);
Assert.assertEquals(mh.invoke(), f());
}

@CodeReflection
static String f() {
String r = "";
String ls = System.lineSeparator();

r += ls + h(1);
r += ls + h(2, 3);
r += ls + h(4, (byte) 5);

r += ls + k(Byte.MIN_VALUE, Byte.MAX_VALUE);

r += ls + j("s1", "s2", "s3");

r += ls + w(8, 9);

r += k();

r += w(11L, 12L);

r += w(21.0, 22.0);

return r;
}

static String h(int i, int... s) {
return i + ", " + Arrays.toString(s);
}

static String k(byte... s) {
return Arrays.toString(s);
}

static String j(String i, String... s) {
return i + ", " + Arrays.toString(s);
}

static <T extends Number> String w(T... ts) {
return Arrays.toString(ts);
}

private CoreOp.FuncOp getFuncOp(String name) {
Optional<Method> om = Stream.of(this.getClass().getDeclaredMethods())
.filter(m -> m.getName().equals(name))
.findFirst();

Method m = om.get();
return m.getCodeModel().get();
}
}
9 changes: 9 additions & 0 deletions test/jdk/java/lang/reflect/code/writer/IR.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import java.lang.runtime.CodeReflection;

public class IR {

@CodeReflection
static String add(String a, int b) {
return a + b;
}
}
139 changes: 139 additions & 0 deletions test/jdk/java/lang/reflect/code/writer/TestOpMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import java.io.IOException;
import java.lang.classfile.*;
import java.lang.classfile.components.ClassPrinter;
import java.lang.classfile.constantpool.FieldRefEntry;
import java.lang.classfile.constantpool.StringEntry;
import java.lang.classfile.instruction.*;
import java.lang.constant.ConstantDescs;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.code.Op;
import java.lang.reflect.code.bytecode.BytecodeGenerator;
import java.lang.reflect.code.interpreter.Interpreter;
import java.lang.reflect.code.op.ExtendedOp;
import java.lang.reflect.code.op.OpFactory;
import java.lang.reflect.code.parser.OpParser;
import java.lang.reflect.code.type.*;
import java.lang.reflect.code.writer.OpBuilder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

import static java.lang.reflect.code.op.CoreOp.FuncOp;

public class TestOpMethod {

public static void main(String[] args) throws IOException, ClassNotFoundException {
for (var arg : args) {
var path = Path.of(arg);
var originalBytes = Files.readAllBytes(path);
var newBytes = TestOpMethod.replaceOpFieldWithBuilderMethod(originalBytes);
System.out.printf("%s %d %d%n", arg, originalBytes.length, newBytes.length);
// TODO a script that runs the tool for many classes
// TODO reduce size if possible
}
}

static byte[] replaceOpFieldWithBuilderMethod(byte[] classData) {
return replaceOpFieldWithBuilderMethod(ClassFile.of().parse(classData));
}

record OpFieldAndIR(FieldRefEntry opField, String ir) {
}

static byte[] replaceOpFieldWithBuilderMethod(ClassModel classModel) {
var opFieldsAndIRs = new ArrayList<OpFieldAndIR>();
var classTransform = ClassTransform.dropping(e -> e instanceof FieldModel fm && fm.fieldName().stringValue().endsWith("$op")).andThen(
ClassTransform.transformingMethods(mm -> mm.methodName().equalsString(ConstantDescs.CLASS_INIT_NAME), (mb, me) -> {
if (!(me instanceof CodeModel codeModel)) {
mb.with(me);
return;
}
mb.withCode(cob -> {
ConstantInstruction.LoadConstantInstruction ldc = null;
for (CodeElement e : codeModel) {
if (ldc != null && e instanceof FieldInstruction fi && fi.opcode() == Opcode.PUTSTATIC && fi.owner().equals(classModel.thisClass()) && fi.name().stringValue().endsWith("$op")) {
opFieldsAndIRs.add(new OpFieldAndIR(fi.field(), ((StringEntry) ldc.constantEntry()).stringValue()));
ldc = null;
} else {
if (ldc != null) {
cob.with(ldc);
ldc = null;
}
switch (e) {
case ConstantInstruction.LoadConstantInstruction lci when lci.constantEntry() instanceof StringEntry ->
ldc = lci;
case LineNumber _, CharacterRange _, LocalVariable _, LocalVariableType _ -> {
}
default -> cob.with(e);
}
}
}
});
})).andThen(ClassTransform.endHandler(clb -> {
for (var opFieldAndIR : opFieldsAndIRs) {
var funcOp = ((FuncOp) OpParser.fromStringOfFuncOp(opFieldAndIR.ir()));
var builderOp = OpBuilder.createBuilderFunction(funcOp);
testBuilderOp(builderOp, opFieldAndIR.ir());
var opFieldName = opFieldAndIR.opField().name().stringValue();
var methodName = builderMethodName(opFieldName);
byte[] bytes = BytecodeGenerator.generateClassData(MethodHandles.lookup(), methodName, builderOp);
var builderMethod = ClassFile.of().parse(bytes).methods().stream()
.filter(mm -> mm.methodName().equalsString(methodName)).findFirst().orElseThrow();
clb.with(builderMethod);
}
}));
var newBytes = ClassFile.of(ClassFile.ConstantPoolSharingOption.NEW_POOL).transformClass(classModel, classTransform);
testBuilderMethods(newBytes, opFieldsAndIRs);
return newBytes;
}

static void testBuilderOp(FuncOp builderOp, String expectedIR) {
var op = (Op) Interpreter.invoke(builderOp, ExtendedOp.FACTORY, CoreTypeFactory.CORE_TYPE_FACTORY);
assert expectedIR.equals(op.toText());
}

static void testBuilderMethods(byte[] classData, List<OpFieldAndIR> opFieldsAndIRs) {
MethodHandles.Lookup lookup = null;
try {
lookup = MethodHandles.lookup().defineHiddenClass(classData, true);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
for (var opFieldAndIR : opFieldsAndIRs) {
var opFieldName = opFieldAndIR.opField().name().stringValue();
var methodName = builderMethodName(opFieldName);
var functionType = FunctionType.functionType(JavaType.type(Op.class), JavaType.type(OpFactory.class),
JavaType.type(TypeElementFactory.class));
MethodHandle mh = null;
try {
mh = lookup.findStatic(lookup.lookupClass(),
methodName,
MethodRef.toNominalDescriptor(functionType).resolveConstantDesc(lookup));
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
Op builtOp = null;
try {
builtOp = ((Op) mh.invoke(ExtendedOp.FACTORY, CoreTypeFactory.CORE_TYPE_FACTORY));
} catch (Throwable e) {
throw new RuntimeException(e);
}
assert builtOp.toText().equals(opFieldAndIR.ir());
}
}

static String builderMethodName(String opFieldName) {
// e.g. A::add(int, int)int$op ---> add(int, int)int$op
return opFieldName.substring(opFieldName.indexOf(':') + 2);
}

static void print(byte[] bytes) {
print(ClassFile.of().parse(bytes));
}

static void print(ClassModel cm) {
ClassPrinter.toYaml(cm, ClassPrinter.Verbosity.TRACE_ALL, System.out::print);
}
}