diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java index 4c5c8402e..0aad42b3e 100644 --- a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java @@ -56,10 +56,10 @@ void parametersDoNotMatch() { java( """ import java.util.Collection; - + class Test { Class> test = (Class>) get(); - + Class get() { return null; } @@ -76,7 +76,7 @@ void primitiveCast() { java( """ import java.io.DataOutputStream; - + class Test { void m(DataOutputStream out) { out.writeByte((byte) 0xff); @@ -96,31 +96,31 @@ void genericTypeVariableCast() { java( """ import java.util.Iterator; - + class GenericNumberIterable implements Iterable { - + private final Iterable wrappedIterable; - + GenericNumberIterable(Iterable wrap) { this.wrappedIterable = wrap; } - + @Override public Iterator iterator() { final Iterator iter = wrappedIterable.iterator(); - + return new Iterator() { @Override public boolean hasNext() { return iter.hasNext(); } - + @Override @SuppressWarnings("unchecked") public T next() { return (T) iter.next(); } - + @Override public void remove() { throw new UnsupportedOperationException(); @@ -141,7 +141,7 @@ void changeTypeCastInReturn() { java( """ import java.util.*; - + class Test { public > T test() { return (T) get(); @@ -162,7 +162,7 @@ void nonSamParameter() { java( """ import java.util.*; - + class Test { public boolean foo() { return Objects.equals("x", (Comparable) (s) -> 1); @@ -210,7 +210,7 @@ void wildcardGenericsInTargetType() { java( """ import java.util.List; - + class Test { Object o = null; List l = (List) o; @@ -309,7 +309,7 @@ void downCastParameterizedTypes() { java( """ import java.util.List; - + class Test { Object o = (List) method(); Object o2 = (List) method(); @@ -322,7 +322,7 @@ List method() { """, """ import java.util.List; - + class Test { Object o = method(); Object o2 = method(); @@ -402,7 +402,7 @@ void lambdaWithComplexTypeInference() { import java.util.Map; import java.util.function.Supplier; import java.util.stream.Collectors; - + class Test { void method() { Object o2 = new MapDropdownChoice( @@ -414,7 +414,7 @@ void method() { }); } } - + class MapDropdownChoice { public MapDropdownChoice(Supplier> choiceMap) { } @@ -448,7 +448,7 @@ void castWildcard() { """ import java.util.ArrayList; import java.util.List; - + class Test { void method() { List list = new ArrayList<>(); @@ -468,7 +468,7 @@ void removeImport() { """ import java.util.ArrayList; import java.util.List; - + class Test { List method(List list) { return (ArrayList) list; @@ -477,7 +477,7 @@ List method(List list) { """, """ import java.util.List; - + class Test { List method(List list) { return list; @@ -498,7 +498,7 @@ void retainCastInMarshaller() { package org.glassfish.jaxb.core.marshaller; import java.io.IOException; import java.io.Writer; - + public interface CharacterEscapeHandler { void escape( char[] ch, int start, int length, boolean isAttVal, Writer out ) throws IOException;\s } @@ -517,7 +517,7 @@ public interface Marshaller { """ import javax.xml.bind.Marshaller; import org.glassfish.jaxb.core.marshaller.CharacterEscapeHandler; - + class Foo { void bar(Marshaller marshaller) { marshaller.setProperty("org.glassfish.jaxb.characterEscapeHandler", (CharacterEscapeHandler) (ch, start, length, isAttVal, out) -> { @@ -528,4 +528,36 @@ void bar(Marshaller marshaller) { ) ); } + + @Test + void dontRemoveNecessaryDowncast() { + rewriteRun( + spec -> spec.recipe(new RemoveRedundantTypeCast()), + // language=java + java( + """ + package com.helloworld; + + import java.util.Optional; + + public class Foo { + public interface Bar {} + + public static class BarImpl implements Bar {} + + private Bar getBar() { + return new BarImpl(); + } + + private BarImpl getBarImpl() { + return new BarImpl(); + } + + public Bar baz() { + return Optional.of((Bar) getBarImpl()).orElse(getBar()); + } + } + """)); + } } +