diff --git a/src/norecompile.jl b/src/norecompile.jl
index 8446c8bcc..d9288547d 100644
--- a/src/norecompile.jl
+++ b/src/norecompile.jl
@@ -57,14 +57,14 @@ function wrapfun_iip(ff,
     dualT = dualgen(T)
     dualT1 = ArrayInterface.promote_eltype(T1, dualT)
     dualT2 = ArrayInterface.promote_eltype(T2, dualT)
-    dualT4 = dualgen(promote_type(T, T4))
+    dualT4 = promote_dual(dualgen(T4), dualT)
 
-    iip_arglists = (Tuple{T1, T2, T3, T4},
-        Tuple{dualT1, dualT2, T3, T4},
-        Tuple{dualT1, T2, T3, dualT4},
-        Tuple{dualT1, dualT2, T3, dualT4})
+    iip_arglists = (Tuple{T1, T2, T3, T4},    # primal
+        Tuple{dualT1, dualT2, T3, T4},        # vjp
+        Tuple{dualT1, T2, T3, dualT4},        # tgrad
+        )
 
-    iip_returnlists = ntuple(x -> Nothing, 4)
+    iip_returnlists = ntuple(x -> Nothing, length(iip_arglists))
 
     fwt = map(iip_arglists, iip_returnlists) do A, R
         FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))