Skip to content

Rework GoStructInitializationInspection #2826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions src/com/goide/completion/GoStructLiteralCompletion.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import com.goide.psi.*;
import com.goide.psi.impl.GoPsiImplUtil;
import com.intellij.psi.PsiElement;
import com.intellij.util.ObjectUtils;
import com.intellij.util.containers.ContainerUtil;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -61,8 +59,8 @@ enum Variants {

@NotNull
static Variants allowedVariants(@Nullable GoReferenceExpression structFieldReference) {
GoValue value = parent(structFieldReference, GoValue.class);
GoElement element = parent(value, GoElement.class);
GoValue value = GoPsiTreeUtil.getDirectParentOfType(structFieldReference, GoValue.class);
GoElement element = GoPsiTreeUtil.getDirectParentOfType(value, GoElement.class);
if (element != null && element.getKey() != null) {
return Variants.NONE;
}
Expand All @@ -75,7 +73,7 @@ static Variants allowedVariants(@Nullable GoReferenceExpression structFieldRefer
boolean hasValueInitializers = false;
boolean hasFieldValueInitializers = false;

GoLiteralValue literalValue = parent(element, GoLiteralValue.class);
GoLiteralValue literalValue = GoPsiTreeUtil.getDirectParentOfType(element, GoLiteralValue.class);
List<GoElement> fieldInitializers = literalValue != null ? literalValue.getElementList() : Collections.emptyList();
for (GoElement initializer : fieldInitializers) {
if (initializer == element) {
Expand Down Expand Up @@ -105,9 +103,4 @@ static Set<String> alreadyAssignedFields(@Nullable GoLiteralValue literal) {
return identifier != null ? identifier.getText() : null;
});
}

@Contract("null,_->null")
private static <T> T parent(@Nullable PsiElement of, @NotNull Class<T> parentClass) {
return ObjectUtils.tryCast(of != null ? of.getParent() : null, parentClass);
}
}
142 changes: 97 additions & 45 deletions src/com/goide/inspections/GoStructInitializationInspection.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,33 @@
import com.goide.util.GoUtil;
import com.intellij.codeInspection.*;
import com.intellij.codeInspection.ui.SingleCheckboxOptionsPanel;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Comparing;
import com.intellij.openapi.util.InvalidDataException;
import com.intellij.openapi.util.WriteExternalException;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.ObjectUtils;
import org.jdom.Element;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.swing.*;
import java.util.List;

import static com.intellij.util.containers.ContainerUtil.emptyList;
import static com.intellij.util.containers.ContainerUtil.list;
import static java.lang.Math.min;
import static java.util.stream.Collectors.toList;
import static java.util.stream.IntStream.range;

public class GoStructInitializationInspection extends GoInspectionBase {
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct field";
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct fields";
private static final GoReplaceWithNamedStructFieldQuickFix QUICK_FIX = new GoReplaceWithNamedStructFieldQuickFix();
public boolean reportLocalStructs;
/**
* @deprecated use reportLocalStructs
* @deprecated use {@link #reportLocalStructs}
*/
@SuppressWarnings("WeakerAccess") public Boolean reportImportedStructs;

Expand All @@ -49,67 +57,111 @@ public class GoStructInitializationInspection extends GoInspectionBase {
protected GoVisitor buildGoVisitor(@NotNull ProblemsHolder holder, @NotNull LocalInspectionToolSession session) {
return new GoVisitor() {
@Override
public void visitLiteralValue(@NotNull GoLiteralValue o) {
if (PsiTreeUtil.getParentOfType(o, GoReturnStatement.class, GoShortVarDeclaration.class, GoAssignmentStatement.class) == null) {
return;
}
PsiElement parent = o.getParent();
GoType refType = GoPsiImplUtil.getLiteralType(parent, false);
if (refType instanceof GoStructType) {
processStructType(holder, o, (GoStructType)refType);
public void visitLiteralValue(@NotNull GoLiteralValue literalValue) {
GoStructType structType = getLiteralStructType(literalValue);
if (structType == null || !isStructImportedOrLocalAllowed(structType, literalValue)) return;

List<GoElement> elements = literalValue.getElementList();
List<GoNamedElement> definitions = getFieldDefinitions(structType);

if (!areElementsKeysMatchesDefinitions(elements, definitions)) return;
registerProblemsForElementsWithoutKeys(elements, definitions.size());
}

private void registerProblemsForElementsWithoutKeys(@NotNull List<GoElement> elements, int definitionsCount) {
for (int i = 0; i < min(elements.size(), definitionsCount); i++) {
if (elements.get(i).getKey() != null) continue;
holder.registerProblem(elements.get(i), "Unnamed field initialization", ProblemHighlightType.WEAK_WARNING, QUICK_FIX);
}
}
};
}

@Override
public JComponent createOptionsPanel() {
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
}
@Contract("null -> null")
private static GoStructType getLiteralStructType(@Nullable GoLiteralValue literalValue) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this example is invalid

type S struct {
	r, t int
}

type B struct{
	S
}

var  _  = []B{  S: {2,  3}}

GoCompositeLit parentLit = GoPsiTreeUtil.getDirectParentOfType(literalValue, GoCompositeLit.class);
if (parentLit != null && !isStructLit(parentLit)) return null;

private void processStructType(@NotNull ProblemsHolder holder, @NotNull GoLiteralValue element, @NotNull GoStructType structType) {
if (reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), element.getContainingFile())) {
processLiteralValue(holder, element, structType.getFieldDeclarationList());
}
GoStructType litType = ObjectUtils.tryCast(GoPsiImplUtil.getLiteralType(literalValue, parentLit == null), GoStructType.class);
GoNamedElement definition = getFieldDefinition(GoPsiTreeUtil.getDirectParentOfType(literalValue, GoValue.class));
return definition != null && litType != null ? getUnderlyingStructType(definition.getGoType(null)) : litType;
}

private static void processLiteralValue(@NotNull ProblemsHolder holder,
@NotNull GoLiteralValue o,
@NotNull List<GoFieldDeclaration> fields) {
List<GoElement> vals = o.getElementList();
for (int elemId = 0; elemId < vals.size(); elemId++) {
ProgressManager.checkCanceled();
GoElement element = vals.get(elemId);
if (element.getKey() == null && elemId < fields.size()) {
String structFieldName = getFieldName(fields.get(elemId));
LocalQuickFix[] fixes = structFieldName != null ? new LocalQuickFix[]{new GoReplaceWithNamedStructFieldQuickFix(structFieldName)}
: LocalQuickFix.EMPTY_ARRAY;
holder.registerProblem(element, "Unnamed field initialization", ProblemHighlightType.GENERIC_ERROR_OR_WARNING, fixes);
}
}
@Nullable
private static GoNamedElement getFieldDefinition(@Nullable GoValue value) {
GoKey key = PsiTreeUtil.getPrevSiblingOfType(value, GoKey.class);
GoFieldName fieldName = key != null ? key.getFieldName() : null;
PsiElement field = fieldName != null ? fieldName.resolve() : null;
return GoPsiImplUtil.isFieldDefinition(field) ? ObjectUtils.tryCast(field, GoNamedElement.class) : null;
}

@Nullable
private static String getFieldName(@NotNull GoFieldDeclaration declaration) {
List<GoFieldDefinition> list = declaration.getFieldDefinitionList();
GoFieldDefinition fieldDefinition = ContainerUtil.getFirstItem(list);
return fieldDefinition != null ? fieldDefinition.getIdentifier().getText() : null;
@Contract("null -> null")
private static GoStructType getUnderlyingStructType(@Nullable GoType type) {
return type != null ? ObjectUtils.tryCast(type.getUnderlyingType(), GoStructType.class) : null;
}

private static boolean isStructLit(@NotNull GoCompositeLit compositeLit) {
return getUnderlyingStructType(compositeLit.getGoType(null)) != null;
}

private boolean isStructImportedOrLocalAllowed(@NotNull GoStructType structType, @NotNull GoLiteralValue literalValue) {
return reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), literalValue.getContainingFile());
}

private static boolean areElementsKeysMatchesDefinitions(@NotNull List<GoElement> elements, @NotNull List<GoNamedElement> definitions) {
return range(0, elements.size()).allMatch(i -> isNullOrNamesEqual(elements.get(i).getKey(), GoPsiImplUtil.getByIndex(definitions, i)));
}

@Contract("null, _ -> true; !null, null -> false")
private static boolean isNullOrNamesEqual(@Nullable GoKey key, @Nullable GoNamedElement elementToCompare) {
return key == null || elementToCompare != null && Comparing.equal(key.getText(), elementToCompare.getName());
}

@NotNull
private static List<GoNamedElement> getFieldDefinitions(@Nullable GoStructType type) {
return type != null ? type.getFieldDeclarationList().stream()
.flatMap(declaration -> getFieldDefinitions(declaration).stream())
.collect(toList()) : emptyList();
}

@NotNull
private static List<? extends GoNamedElement> getFieldDefinitions(@NotNull GoFieldDeclaration declaration) {
GoAnonymousFieldDefinition anonymousDefinition = declaration.getAnonymousFieldDefinition();
return anonymousDefinition != null ? list(anonymousDefinition) : declaration.getFieldDefinitionList();
}

@Override
public JComponent createOptionsPanel() {
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
}

private static class GoReplaceWithNamedStructFieldQuickFix extends LocalQuickFixBase {
private String myStructField;

public GoReplaceWithNamedStructFieldQuickFix(@NotNull String structField) {
public GoReplaceWithNamedStructFieldQuickFix() {
super(REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME);
myStructField = structField;
}

@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
PsiElement startElement = descriptor.getStartElement();
if (startElement instanceof GoElement) {
startElement.replace(GoElementFactory.createLiteralValueElement(project, myStructField, startElement.getText()));
}
PsiElement element = ObjectUtils.tryCast(descriptor.getStartElement(), GoElement.class);
GoLiteralValue literal = element != null && element.isValid() ? PsiTreeUtil.getParentOfType(element, GoLiteralValue.class) : null;

List<GoElement> elements = literal != null ? literal.getElementList() : emptyList();
List<GoNamedElement> definitions = getFieldDefinitions(getLiteralStructType(literal));
if (!areElementsKeysMatchesDefinitions(elements, definitions)) return;
addKeysToElements(project, elements, definitions);
}
}

private static void addKeysToElements(@NotNull Project project,
@NotNull List<GoElement> elements,
@NotNull List<GoNamedElement> definitions) {
for (int i = 0; i < min(elements.size(), definitions.size()); i++) {
GoElement element = elements.get(i);
String fieldDefinitionName = definitions.get(i).getName();
GoValue value = fieldDefinitionName != null && element.getKey() == null ? element.getValue() : null;
if (value != null) element.replace(GoElementFactory.createLiteralValueElement(project, fieldDefinitionName, value.getText()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,7 @@ private static GoReferenceExpression unwrapParensAndCast(@Nullable PsiElement e)

@Contract("null -> false")
private static boolean isFieldReferenceExpression(@Nullable PsiElement element) {
return element instanceof GoReferenceExpression && isFieldDefinition(((GoReferenceExpression)element).resolve());
}

@Contract("null -> false")
private static boolean isFieldDefinition(@Nullable PsiElement element) {
return element instanceof GoFieldDefinition || element instanceof GoAnonymousFieldDefinition;
return element instanceof GoReferenceExpression && GoPsiImplUtil.isFieldDefinition(((GoReferenceExpression)element).resolve());
}

private static boolean isAssignedInPreviousStatement(@NotNull GoExpression referenceExpression,
Expand Down Expand Up @@ -182,7 +177,7 @@ private static GoCompositeLit getStructLiteral(@NotNull GoReferenceExpression fi
@NotNull GoAssignmentStatement structAssignment) {
GoVarDefinition varDefinition = ObjectUtils.tryCast(resolveQualifier(fieldReferenceExpression), GoVarDefinition.class);
PsiElement field = fieldReferenceExpression.resolve();
if (varDefinition == null || !isFieldDefinition(field) || !hasStructTypeWithField(varDefinition, (GoNamedElement)field)) {
if (varDefinition == null || !GoPsiImplUtil.isFieldDefinition(field) || !hasStructTypeWithField(varDefinition, (GoNamedElement)field)) {
return null;
}

Expand Down Expand Up @@ -218,7 +213,7 @@ private static boolean isUninitializedFieldReferenceExpression(@Nullable GoRefer
if (fieldReferenceExpression == null) return false;
GoLiteralValue literalValue = structLiteral.getLiteralValue();
PsiElement resolve = fieldReferenceExpression.resolve();
return literalValue != null && isFieldDefinition(resolve) &&
return literalValue != null && GoPsiImplUtil.isFieldDefinition(resolve) &&
!exists(literalValue.getElementList(), element -> isFieldInitialization(element, resolve));
}

Expand Down
7 changes: 7 additions & 0 deletions src/com/goide/psi/GoPsiTreeUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import com.intellij.psi.stubs.StubElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.PsiUtilCore;
import com.intellij.util.ObjectUtils;
import com.intellij.util.SmartList;
import com.intellij.util.containers.ContainerUtil;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -155,5 +157,10 @@ private static PsiElement findNotWhiteSpaceElementAtOffset(@NotNull GoFile file,
}
return element;
}

@Contract("null,_->null")
public static <T> T getDirectParentOfType(@Nullable PsiElement element, @NotNull Class<T> aClass) {
return element != null ? ObjectUtils.tryCast(element.getParent(), aClass) : null;
}
}

2 changes: 1 addition & 1 deletion src/com/goide/psi/impl/GoElementFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ public static GoType createType(@NotNull Project project, @NotNull String text)
return PsiTreeUtil.findChildOfType(file, GoType.class);
}

public static PsiElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
public static GoElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
GoFile file = createFileFromText(project, "package a; var _ = struct { a string } { " + key + ": " + value + " }");
return PsiTreeUtil.findChildOfType(file, GoElement.class);
}
Expand Down
11 changes: 9 additions & 2 deletions src/com/goide/psi/impl/GoPsiImplUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,12 @@ public static GoType getLiteralType(@Nullable PsiElement context, boolean consid
@Nullable
public static GoValue getParentGoValue(@NotNull PsiElement element) {
PsiElement place = element;
while ((place = PsiTreeUtil.getParentOfType(place, GoLiteralValue.class)) != null) {
do {
if (place.getParent() instanceof GoValue) {
return (GoValue)place.getParent();
}
}
while ((place = PsiTreeUtil.getParentOfType(place, GoLiteralValue.class)) != null);
return null;
}

Expand Down Expand Up @@ -1468,7 +1469,8 @@ public static GoExpression getValue(@NotNull GoConstDefinition definition) {
return getByIndex(((GoConstSpec)parent).getExpressionList(), index);
}

private static <T> T getByIndex(@NotNull List<T> list, int index) {
@Nullable
public static <T> T getByIndex(@NotNull List<T> list, int index) {
return 0 <= index && index < list.size() ? list.get(index) : null;
}

Expand Down Expand Up @@ -1699,4 +1701,9 @@ public static GoExpression getRightExpression(@NotNull GoAssignmentStatement ass
int fieldIndex = assignment.getLeftHandExprList().getExpressionList().indexOf(leftExpression);
return getByIndex(assignment.getExpressionList(), fieldIndex);
}

@Contract("null -> false")
public static boolean isFieldDefinition(@Nullable PsiElement element) {
return element instanceof GoFieldDefinition || element instanceof GoAnonymousFieldDefinition;
}
}
11 changes: 11 additions & 0 deletions testData/inspections/struct-initialization/anonField-after.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package foo

type S struct {
X string
string
Y int
}
func main() {
var s S
s = S{X: "X", string: "a", Y: 1}
}
11 changes: 11 additions & 0 deletions testData/inspections/struct-initialization/anonField.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package foo

type S struct {
X string
string
Y int
}
func main() {
var s S
s = S{<caret><weak_warning descr="Unnamed field initialization">"X"</weak_warning>, <weak_warning descr="Unnamed field initialization">"a"</weak_warning>, Y: 1}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{X: 1, Y: 0, 2}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{<weak_warning descr="Unnamed field initialization"><caret>1</weak_warning>, <weak_warning descr="Unnamed field initialization">0</weak_warning>, 2}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{<caret>1, 0, X: 2}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo

type S struct {
t int
}

func main() {
var _ = []S{ {t: 1} }
}
Loading