diff --git a/base/src/main/java/org/arend/typechecking/visitor/DefinitionTypechecker.java b/base/src/main/java/org/arend/typechecking/visitor/DefinitionTypechecker.java index 8b28e05b1..89afa61cd 100644 --- a/base/src/main/java/org/arend/typechecking/visitor/DefinitionTypechecker.java +++ b/base/src/main/java/org/arend/typechecking/visitor/DefinitionTypechecker.java @@ -1709,7 +1709,26 @@ private List typecheckFunctionBody(FunctionDefinition typedDef, C } if (!def.isRecursive()) { - typedDef.setResultType(termResult.type); + Expression newType = termResult.type; + if ((typedDef.isSFunc() || kind == FunctionKind.CONS) && typedDef.getResultType() != null) { + Expression normNewType = newType.normalize(NormalizationMode.WHNF); + Expression oldType = typedDef.getResultType().normalize(NormalizationMode.WHNF); + if (oldType instanceof ClassCallExpression oldClassCall && normNewType instanceof ClassCallExpression newClassCall) { + Map impls = new LinkedHashMap<>(); + for (Map.Entry entry : newClassCall.getImplementedHere().entrySet()) { + if (oldClassCall.isImplemented(entry.getKey())) { + impls.put(entry.getKey(), entry.getValue()); + } + } + if (impls.size() != newClassCall.getImplementedHere().size()) { + newClassCall = new ClassCallExpression(newClassCall.getDefinition(), newClassCall.getLevels(), impls, newClassCall.getDefinition().getSort(), newClassCall.getDefinition().getUniverseKind()); + newClassCall.updateHasUniverses(); + typechecker.fixClassExtSort(newClassCall, def.getResultType()); + newType = newClassCall; + } + } + } + typedDef.setResultType(newType); } if (termResult.expression != null) { typedDef.setBody(termResult.expression);