Skip to content

Commit 53c3cca

Browse files
committed
Add CTOR_HEAD injection point
1 parent 855d67c commit 53c3cca

File tree

5 files changed

+216
-3
lines changed

5 files changed

+216
-3
lines changed
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Minecraft Development for IntelliJ
3+
*
4+
* https://mcdev.io/
5+
*
6+
* Copyright (C) 2024 minecraft-dev
7+
*
8+
* This program is free software: you can redistribute it and/or modify
9+
* it under the terms of the GNU Lesser General Public License as published
10+
* by the Free Software Foundation, version 3.0 only.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU Lesser General Public License
18+
* along with this program. If not, see <https://www.gnu.org/licenses/>.
19+
*/
20+
21+
package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint
22+
23+
import com.demonwav.mcdev.platform.mixin.reference.MixinSelector
24+
import com.demonwav.mcdev.platform.mixin.util.findOrConstructSourceMethod
25+
import com.demonwav.mcdev.platform.mixin.util.findSuperConstructorCall
26+
import com.demonwav.mcdev.platform.mixin.util.isConstructor
27+
import com.demonwav.mcdev.util.createLiteralExpression
28+
import com.demonwav.mcdev.util.enumValueOfOrNull
29+
import com.demonwav.mcdev.util.findContainingClass
30+
import com.intellij.codeInsight.lookup.LookupElementBuilder
31+
import com.intellij.openapi.editor.Editor
32+
import com.intellij.openapi.project.Project
33+
import com.intellij.psi.JavaPsiFacade
34+
import com.intellij.psi.PsiAnnotation
35+
import com.intellij.psi.PsiClass
36+
import com.intellij.psi.PsiElement
37+
import com.intellij.psi.PsiExpression
38+
import com.intellij.psi.PsiField
39+
import com.intellij.psi.PsiLiteral
40+
import com.intellij.psi.PsiMethod
41+
import com.intellij.psi.PsiMethodCallExpression
42+
import com.intellij.psi.PsiMethodReferenceExpression
43+
import com.intellij.psi.PsiReferenceExpression
44+
import com.intellij.psi.PsiStatement
45+
import com.intellij.psi.codeStyle.CodeStyleManager
46+
import com.intellij.psi.util.PsiUtil
47+
import com.intellij.psi.util.parentOfType
48+
import com.intellij.util.JavaPsiConstructorUtil
49+
import org.objectweb.asm.Opcodes
50+
import org.objectweb.asm.tree.ClassNode
51+
import org.objectweb.asm.tree.FieldInsnNode
52+
import org.objectweb.asm.tree.MethodNode
53+
54+
class CtorHeadInjectionPoint : InjectionPoint<PsiElement>() {
55+
override fun onCompleted(editor: Editor, reference: PsiLiteral) {
56+
val at = reference.parentOfType<PsiAnnotation>() ?: return
57+
val project = reference.project
58+
at.setDeclaredAttributeValue(
59+
"unsafe",
60+
JavaPsiFacade.getElementFactory(project).createLiteralExpression(true)
61+
)
62+
CodeStyleManager.getInstance(project).reformat(at)
63+
}
64+
65+
override fun createNavigationVisitor(
66+
at: PsiAnnotation,
67+
target: MixinSelector?,
68+
targetClass: PsiClass
69+
): NavigationVisitor {
70+
val args = AtResolver.getArgs(at)
71+
val enforce = args["enforce"]?.let { enumValueOfOrNull<EnforceMode>(it) } ?: EnforceMode.DEFAULT
72+
return MyNavigationVisitor(enforce)
73+
}
74+
75+
override fun doCreateCollectVisitor(
76+
at: PsiAnnotation,
77+
target: MixinSelector?,
78+
targetClass: ClassNode,
79+
mode: CollectVisitor.Mode
80+
): CollectVisitor<PsiElement> {
81+
val args = AtResolver.getArgs(at)
82+
val enforce = args["enforce"]?.let { enumValueOfOrNull<EnforceMode>(it) } ?: EnforceMode.DEFAULT
83+
return MyCollectVisitor(at.project, targetClass, mode, enforce)
84+
}
85+
86+
override fun createLookup(
87+
targetClass: ClassNode,
88+
result: CollectVisitor.Result<PsiElement>
89+
): LookupElementBuilder? {
90+
return null
91+
}
92+
93+
private enum class EnforceMode {
94+
DEFAULT, POST_DELEGATE, POST_INIT
95+
}
96+
97+
private class MyCollectVisitor(
98+
project: Project,
99+
clazz: ClassNode,
100+
mode: Mode,
101+
private val enforce: EnforceMode,
102+
) : HeadInjectionPoint.MyCollectVisitor(project, clazz, mode) {
103+
override fun accept(methodNode: MethodNode) {
104+
val insns = methodNode.instructions ?: return
105+
106+
if (!methodNode.isConstructor) {
107+
super.accept(methodNode)
108+
return
109+
}
110+
111+
val superCtorCall = methodNode.findSuperConstructorCall() ?: run {
112+
super.accept(methodNode)
113+
return
114+
}
115+
116+
if (enforce == EnforceMode.POST_DELEGATE) {
117+
val insn = superCtorCall.next ?: return
118+
addResult(insn, methodNode.findOrConstructSourceMethod(clazz, project))
119+
return
120+
}
121+
122+
// Although Mumfrey's original intention was to target the last *unique* field store,
123+
// i.e. ignore duplicate field stores that occur later, due to a bug in the implementation
124+
// it simply finds the last PUTFIELD whose owner is the target class. Mumfrey now says he
125+
// doesn't want to change the implementation in case of breaking mixins that rely on this
126+
// behavior, so it is now effectively intended, so it's what we'll use here.
127+
val lastFieldStore = generateSequence(insns.last) { it.previous }
128+
.takeWhile { it !== superCtorCall }
129+
.firstOrNull { insn ->
130+
insn.opcode == Opcodes.PUTFIELD &&
131+
(insn as FieldInsnNode).owner == clazz.name
132+
} ?: superCtorCall
133+
134+
val lastFieldStoreNext = lastFieldStore.next ?: return
135+
addResult(lastFieldStoreNext, methodNode.findOrConstructSourceMethod(clazz, project))
136+
}
137+
}
138+
139+
private class MyNavigationVisitor(private val enforce: EnforceMode) : NavigationVisitor() {
140+
private var isConstructor = true
141+
private var firstStatement = true
142+
private lateinit var elementToReturn: PsiElement
143+
144+
override fun visitStart(executableElement: PsiElement) {
145+
isConstructor = executableElement is PsiMethod && executableElement.isConstructor
146+
elementToReturn = executableElement
147+
}
148+
149+
override fun visitExpression(expression: PsiExpression) {
150+
if (firstStatement) {
151+
elementToReturn = expression
152+
firstStatement = false
153+
}
154+
super.visitExpression(expression)
155+
}
156+
157+
override fun visitStatement(statement: PsiStatement) {
158+
if (firstStatement) {
159+
elementToReturn = statement
160+
firstStatement = false
161+
}
162+
super.visitStatement(statement)
163+
}
164+
165+
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
166+
super.visitMethodCallExpression(expression)
167+
if (isConstructor) {
168+
if (JavaPsiConstructorUtil.isChainedConstructorCall(expression) ||
169+
JavaPsiConstructorUtil.isSuperConstructorCall(expression)
170+
) {
171+
elementToReturn = expression
172+
}
173+
}
174+
}
175+
176+
override fun visitReferenceExpression(expression: PsiReferenceExpression) {
177+
super.visitReferenceExpression(expression)
178+
if (isConstructor &&
179+
enforce != EnforceMode.POST_DELEGATE &&
180+
expression !is PsiMethodReferenceExpression &&
181+
PsiUtil.isAccessedForWriting(expression)
182+
) {
183+
val resolvedField = expression.resolve()
184+
if (resolvedField is PsiField && resolvedField.containingClass == expression.findContainingClass()) {
185+
elementToReturn = expression
186+
}
187+
}
188+
}
189+
190+
override fun visitEnd(executableElement: PsiElement) {
191+
addResult(elementToReturn)
192+
}
193+
}
194+
}

src/main/kotlin/platform/mixin/handlers/injectionPoint/HeadInjectionPoint.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ class HeadInjectionPoint : InjectionPoint<PsiElement>() {
5757
return null
5858
}
5959

60-
private class MyCollectVisitor(
61-
private val project: Project,
62-
private val clazz: ClassNode,
60+
internal open class MyCollectVisitor(
61+
protected val project: Project,
62+
protected val clazz: ClassNode,
6363
mode: Mode,
6464
) : CollectVisitor<PsiElement>(mode) {
6565
override fun accept(methodNode: MethodNode) {

src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ abstract class NavigationVisitor : JavaRecursiveElementVisitor() {
294294
result += element
295295
}
296296

297+
open fun visitStart(executableElement: PsiElement) {
298+
}
299+
297300
open fun visitEnd(executableElement: PsiElement) {
298301
}
299302

@@ -304,6 +307,7 @@ abstract class NavigationVisitor : JavaRecursiveElementVisitor() {
304307

305308
override fun visitMethod(method: PsiMethod) {
306309
if (!hasVisitedAnything) {
310+
visitStart(method)
307311
super.visitMethod(method)
308312
visitEnd(method)
309313
}
@@ -312,6 +316,7 @@ abstract class NavigationVisitor : JavaRecursiveElementVisitor() {
312316
override fun visitAnonymousClass(aClass: PsiAnonymousClass) {
313317
// do not recurse into anonymous classes
314318
if (!hasVisitedAnything) {
319+
visitStart(aClass)
315320
super.visitAnonymousClass(aClass)
316321
visitEnd(aClass)
317322
}
@@ -320,13 +325,17 @@ abstract class NavigationVisitor : JavaRecursiveElementVisitor() {
320325
override fun visitClass(aClass: PsiClass) {
321326
// do not recurse into inner classes
322327
if (!hasVisitedAnything) {
328+
visitStart(aClass)
323329
super.visitClass(aClass)
324330
visitEnd(aClass)
325331
}
326332
}
327333

328334
override fun visitMethodReferenceExpression(expression: PsiMethodReferenceExpression) {
329335
val hadVisitedAnything = hasVisitedAnything
336+
if (!hadVisitedAnything) {
337+
visitStart(expression)
338+
}
330339
super.visitMethodReferenceExpression(expression)
331340
if (!hadVisitedAnything) {
332341
visitEnd(expression)
@@ -336,6 +345,7 @@ abstract class NavigationVisitor : JavaRecursiveElementVisitor() {
336345
override fun visitLambdaExpression(expression: PsiLambdaExpression) {
337346
// do not recurse into lambda expressions
338347
if (!hasVisitedAnything) {
348+
visitStart(expression)
339349
super.visitLambdaExpression(expression)
340350
visitEnd(expression)
341351
}

src/main/kotlin/util/utils.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,11 @@ fun <S : CharSequence, R> S.ifNotBlank(block: (S) -> R): R? {
388388

389389
return null
390390
}
391+
392+
inline fun <reified T : Enum<T>> enumValueOfOrNull(str: String): T? {
393+
return try {
394+
enumValueOf<T>(str)
395+
} catch (e: IllegalArgumentException) {
396+
null
397+
}
398+
}

src/main/resources/META-INF/plugin.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
<injectionPoint atCode="HEAD" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.HeadInjectionPoint" />
152152
<injectionPoint atCode="RETURN" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.ReturnInjectionPoint" />
153153
<injectionPoint atCode="TAIL" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.TailInjectionPoint" />
154+
<injectionPoint atCode="CTOR_HEAD" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.CtorHeadInjectionPoint" />
154155
<injectionPoint atCode="LOAD" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.LoadInjectionPoint" />
155156
<injectionPoint atCode="STORE" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.StoreInjectionPoint" />
156157
<injectionPoint atCode="CONSTANT" implementation="com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.ConstantInjectionPoint" />

0 commit comments

Comments
 (0)