diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl
index 7e69dd39..d908924e 100644
--- a/src/ITensorNetworks.jl
+++ b/src/ITensorNetworks.jl
@@ -36,7 +36,7 @@ include("caches/beliefpropagationcache.jl")
 include("contraction_tree_to_graph.jl")
 include("gauging.jl")
 include("utils.jl")
-include("ITensorsExt/itensorutils.jl")
+include("ITensorsExtensions/ITensorsExtensions.jl")
 include("solvers/local_solvers/eigsolve.jl")
 include("solvers/local_solvers/exponentiate.jl")
 include("solvers/local_solvers/dmrg_x.jl")
diff --git a/src/ITensorsExt/itensorutils.jl b/src/ITensorsExt/itensorutils.jl
deleted file mode 100644
index e8ce2e26..00000000
--- a/src/ITensorsExt/itensorutils.jl
+++ /dev/null
@@ -1,90 +0,0 @@
-using LinearAlgebra: pinv
-using ITensors.NDTensors:
-  Block,
-  Tensor,
-  blockdim,
-  blockoffsets,
-  diaglength,
-  getdiagindex,
-  nzblocks,
-  setdiagindex!,
-  tensor,
-  DiagBlockSparseTensor,
-  DenseTensor,
-  BlockOffsets
-
-function NDTensors.blockoffsets(dense::DenseTensor)
-  return BlockOffsets{ndims(dense)}([Block(ntuple(Returns(1), ndims(dense)))], [0])
-end
-function NDTensors.nzblocks(dense::DenseTensor)
-  return nzblocks(blockoffsets(dense))
-end
-NDTensors.blockdim(ind::Int, ::Block{1}) = ind
-NDTensors.blockdim(i::Index{Int}, b::Integer) = blockdim(i, Block(b))
-NDTensors.blockdim(i::Index{Int}, b::Block) = blockdim(space(i), b)
-
-LinearAlgebra.isdiag(it::ITensor) = isdiag(tensor(it))
-
-function map_diag!(f::Function, it_destination::ITensor, it_source::ITensor)
-  return itensor(map_diag!(f, tensor(it_destination), tensor(it_source)))
-end
-map_diag(f::Function, it::ITensor) = map_diag!(f, copy(it), it)
-
-function map_diag!(f::Function, t_destination::Tensor, t_source::Tensor)
-  for i in 1:diaglength(t_destination)
-    setdiagindex!(t_destination, f(getdiagindex(t_source, i)), i)
-  end
-  return t_destination
-end
-map_diag(f::Function, t::Tensor) = map_diag!(f, copy(t), t)
-
-# Convenience functions
-sqrt_diag(it::ITensor) = map_diag(sqrt, it)
-inv_diag(it::ITensor) = map_diag(inv, it)
-invsqrt_diag(it::ITensor) = map_diag(inv ∘ sqrt, it)
-pinv_diag(it::ITensor) = map_diag(pinv, it)
-pinvsqrt_diag(it::ITensor) = map_diag(pinv ∘ sqrt, it)
-
-# Analagous to `denseblocks`.
-# Extract the diagonal entries into a diagonal tensor.
-function diagblocks(D::Tensor)
-  nzblocksD = nzblocks(D)
-  T = DiagBlockSparseTensor(eltype(D), nzblocksD, inds(D))
-  for b in nzblocksD
-    for n in 1:diaglength(D)
-      setdiagindex!(T, getdiagindex(D, n), n)
-    end
-  end
-  return T
-end
-
-diagblocks(it::ITensor) = itensor(diagblocks(tensor(it)))
-
-"""Given a vector of ITensors, separate them into groups of commuting itensors (i.e. itensors in the same group do not share any common indices)"""
-function group_commuting_itensors(its::Vector{ITensor})
-  remaining_its = copy(its)
-  it_groups = Vector{ITensor}[]
-
-  while !isempty(remaining_its)
-    cur_group = ITensor[]
-    cur_indices = Index[]
-    inds_to_remove = []
-    for i in 1:length(remaining_its)
-      it = remaining_its[i]
-      it_inds = inds(it)
-
-      if all([i ∉ cur_indices for i in it_inds])
-        push!(cur_group, it)
-        push!(cur_indices, it_inds...)
-        push!(inds_to_remove, i)
-      end
-    end
-    remaining_its = ITensor[
-      remaining_its[i] for
-      i in setdiff([i for i in 1:length(remaining_its)], inds_to_remove)
-    ]
-    push!(it_groups, cur_group)
-  end
-
-  return it_groups
-end
diff --git a/src/ITensorsExtensions/ITensorsExtensions.jl b/src/ITensorsExtensions/ITensorsExtensions.jl
index 66350c8f..5b58e663 100644
--- a/src/ITensorsExtensions/ITensorsExtensions.jl
+++ b/src/ITensorsExtensions/ITensorsExtensions.jl
@@ -12,7 +12,9 @@ using ITensors:
   map_diag,
   noncommonind,
   noprime,
+  replaceind,
   replaceinds,
+  sim,
   space,
   sqrt_decomp
 using ITensors.NDTensors:
@@ -52,16 +54,24 @@ invsqrt_diag(it::ITensor) = map_diag(inv ∘ sqrt, it)
 pinv_diag(it::ITensor) = map_diag(pinv, it)
 pinvsqrt_diag(it::ITensor) = map_diag(pinv ∘ sqrt, it)
 
-function map_itensor(
-  f::Function, A::ITensor, lind=first(inds(A)); regularization=nothing, kwargs...
-)
-  USV = svd(A, lind; kwargs...)
-  U, S, V, spec, u, v = USV
-  S = map_diag(s -> f(s + regularization), S)
-  sqrtDL, δᵤᵥ, sqrtDR = sqrt_decomp(S, u, v)
-  sqrtDR = denseblocks(sqrtDR) * denseblocks(δᵤᵥ)
-  L, R = U * sqrtDL, V * sqrtDR
-  return L * R
+# Eigendecompose `A` between `linds` and `rinds`; returns `(Ul, D, dag(U))`.
+# TODO: Make this work for non-hermitian A (currently requires `ishermitian=true`).
+function eigendecomp(A::ITensor, linds, rinds; ishermitian=false, kwargs...)
+  ishermitian || throw(ArgumentError("eigendecomp requires ishermitian=true"))
+  D, U = eigen(A, linds, rinds; ishermitian, kwargs...)
+  ul, ur = noncommonind(D, U), commonind(D, U)
+  Ul = replaceinds(U, vcat(rinds, ur), vcat(linds, ul))  # `U` moved onto (linds, ul)
+  return Ul, D, dag(U)
+end
+
+# Apply `f` to the eigenvalues of `A`. A diagonal `A` has `f` mapped over its diagonal;
+# otherwise `A` is split by `eigendecomp(A, inds...)` and rebuilt around the mapped `D`.
+# `ishermitian` and `kwargs` are passed through to `eigendecomp`.
+function map_eigvals(f::Function, A::ITensor, inds...; ishermitian=false, kwargs...)
+  isdiag(A) && return map_diag(f, A)
+  left, spectrum, right = eigendecomp(A, inds...; ishermitian, kwargs...)
+  mapped = map_diag(f, spectrum)
+  return left * mapped * right
 end
 
 # Analagous to `denseblocks`.
diff --git a/src/apply.jl b/src/apply.jl
index 948ccb7c..9405e550 100644
--- a/src/apply.jl
+++ b/src/apply.jl
@@ -11,6 +11,8 @@ using ITensors:
   contract,
   dag,
   denseblocks,
+  factorize,
+  factorize_svd,
   hasqns,
   isdiag,
   noprime,
@@ -23,61 +25,9 @@ using ITensors.ContractionSequenceOptimization: optimal_contraction_sequence
 using ITensors.ITensorMPS: siteinds
 using KrylovKit: linsolve
 using LinearAlgebra: eigen, norm, svd
-using NamedGraphs: NamedEdge
+using NamedGraphs: NamedEdge, has_edge
 using Observers: Observers
 
-function sqrt_and_inv_sqrt(
-  A::ITensor; ishermitian=false, cutoff=nothing, regularization=nothing
-)
-  if isdiag(A)
-    A = map_diag(x -> x + regularization, A)
-    sqrtA = sqrt_diag(A)
-    inv_sqrtA = inv_diag(sqrtA)
-    return sqrtA, inv_sqrtA
-  end
-  @assert ishermitian
-  D, U = eigen(A; ishermitian, cutoff)
-  D = map_diag(x -> x + regularization, D)
-  sqrtD = sqrt_diag(D)
-  # sqrtA = U * sqrtD * prime(dag(U))
-  sqrtA = noprime(sqrtD * U)
-  inv_sqrtD = inv_diag(sqrtD)
-  # inv_sqrtA = U * inv_sqrtD * prime(dag(U))
-  inv_sqrtA = noprime(inv_sqrtD * dag(U))
-  return sqrtA, inv_sqrtA
-end
-
-function symmetric_factorize(
-  A::ITensor, inds...; (observer!)=nothing, tags="", svd_kwargs...
-)
-  if !isnothing(observer!)
-    Observers.insert_function!(
-      observer!, "singular_values" => (; singular_values) -> singular_values
-    )
-  end
-  U, S, V = svd(A, inds...; lefttags=tags, righttags=tags, svd_kwargs...)
-  u = commonind(S, U)
-  v = commonind(S, V)
-  sqrtS = sqrt_diag(S)
-  Fu = U * sqrtS
-  Fv = V * sqrtS
-  if hasqns(A)
-    # Hack to make a generalized (non-diagonal) `δ` tensor.
-    # TODO: Make this easier, `ITensors.δ` doesn't work here.
-    δᵤᵥ = copy(S)
-    ITensors.data(δᵤᵥ) .= true
-    Fu *= dag(δᵤᵥ)
-    S = denseblocks(S)
-    S *= prime(dag(δᵤᵥ), u)
-    S = diagblocks(S)
-  else
-    Fu = replaceinds(Fu, v => u)
-    S = replaceinds(S, v => u')
-  end
-  Observers.update!(observer!; singular_values=S)
-  return Fu, Fv
-end
-
 function full_update_bp(
   o,
   ψ,
@@ -86,7 +36,7 @@ function full_update_bp(
   nfullupdatesweeps=10,
   print_fidelity_loss=false,
   envisposdef=false,
-  (observer!)=nothing,
+  (singular_values!)=nothing,
   symmetrize=false,
   apply_kwargs...,
 )
@@ -117,8 +67,13 @@ function full_update_bp(
     apply_kwargs...,
   )
   if symmetrize
-    Rᵥ₁, Rᵥ₂ = symmetric_factorize(
-      Rᵥ₁ * Rᵥ₂, inds(Rᵥ₁); tags=edge_tag(v⃗[1] => v⃗[2]), observer!, apply_kwargs...
+    Rᵥ₁, Rᵥ₂ = factorize_svd(
+      Rᵥ₁ * Rᵥ₂,
+      inds(Rᵥ₁);
+      ortho="none",
+      tags=edge_tag(v⃗[1] => v⃗[2]),
+      singular_values!,
+      apply_kwargs...,
     )
   end
   ψᵥ₁ = Qᵥ₁ * Rᵥ₁
@@ -126,19 +81,31 @@ function full_update_bp(
   return ψᵥ₁, ψᵥ₂
 end
 
-function simple_update_bp_full(o, ψ, v⃗; envs, (observer!)=nothing, apply_kwargs...)
+function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...)
   cutoff = 10 * eps(real(scalartype(ψ)))
-  regularization = 10 * eps(real(scalartype(ψ)))
   envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
   envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
-  sqrt_and_inv_sqrt_envs_v1 =
-    sqrt_and_inv_sqrt.(envs_v1; ishermitian=true, cutoff, regularization)
-  sqrt_and_inv_sqrt_envs_v2 =
-    sqrt_and_inv_sqrt.(envs_v2; ishermitian=true, cutoff, regularization)
-  sqrt_envs_v1 = first.(sqrt_and_inv_sqrt_envs_v1)
-  inv_sqrt_envs_v1 = last.(sqrt_and_inv_sqrt_envs_v1)
-  sqrt_envs_v2 = first.(sqrt_and_inv_sqrt_envs_v2)
-  inv_sqrt_envs_v2 = last.(sqrt_and_inv_sqrt_envs_v2)
+  @assert all(ndims(env) == 2 for env in vcat(envs_v1, envs_v2))
+  sqrt_envs_v1 = [
+    ITensorsExtensions.map_eigvals(
+      sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v1
+  ]
+  sqrt_envs_v2 = [
+    ITensorsExtensions.map_eigvals(
+      sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v2
+  ]
+  inv_sqrt_envs_v1 = [
+    ITensorsExtensions.map_eigvals(
+      inv ∘ sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v1
+  ]
+  inv_sqrt_envs_v2 = [
+    ITensorsExtensions.map_eigvals(
+      inv ∘ sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v2
+  ]
   ψᵥ₁ᵥ₂_tn = [ψ[v⃗[1]]; ψ[v⃗[2]]; sqrt_envs_v1; sqrt_envs_v2]
   ψᵥ₁ᵥ₂ = contract(ψᵥ₁ᵥ₂_tn; sequence=contraction_sequence(ψᵥ₁ᵥ₂_tn; alg="optimal"))
   oψ = apply(o, ψᵥ₁ᵥ₂)
@@ -151,32 +118,44 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (observer!)=nothing, apply_kwa
   v1_inds = [v1_inds; siteinds(ψ, v⃗[1])]
   v2_inds = [v2_inds; siteinds(ψ, v⃗[2])]
   e = v⃗[1] => v⃗[2]
-  ψᵥ₁, ψᵥ₂ = symmetric_factorize(oψ, v1_inds; tags=edge_tag(e), observer!, apply_kwargs...)
+  ψᵥ₁, ψᵥ₂ = factorize_svd(
+    oψ, v1_inds; ortho="none", tags=edge_tag(e), singular_values!, apply_kwargs...
+  )
   for inv_sqrt_env_v1 in inv_sqrt_envs_v1
-    # TODO: `dag` here?
-    ψᵥ₁ *= inv_sqrt_env_v1
+    ψᵥ₁ *= dag(inv_sqrt_env_v1)
   end
   for inv_sqrt_env_v2 in inv_sqrt_envs_v2
-    # TODO: `dag` here?
-    ψᵥ₂ *= inv_sqrt_env_v2
+    ψᵥ₂ *= dag(inv_sqrt_env_v2)
   end
   return ψᵥ₁, ψᵥ₂
 end
 
 # Reduced version
-function simple_update_bp(o, ψ, v⃗; envs, (observer!)=nothing, apply_kwargs...)
+function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...)
   cutoff = 10 * eps(real(scalartype(ψ)))
-  regularization = 10 * eps(real(scalartype(ψ)))
   envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
   envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
-  sqrt_and_inv_sqrt_envs_v1 =
-    sqrt_and_inv_sqrt.(envs_v1; ishermitian=true, cutoff, regularization)
-  sqrt_and_inv_sqrt_envs_v2 =
-    sqrt_and_inv_sqrt.(envs_v2; ishermitian=true, cutoff, regularization)
-  sqrt_envs_v1 = first.(sqrt_and_inv_sqrt_envs_v1)
-  inv_sqrt_envs_v1 = last.(sqrt_and_inv_sqrt_envs_v1)
-  sqrt_envs_v2 = first.(sqrt_and_inv_sqrt_envs_v2)
-  inv_sqrt_envs_v2 = last.(sqrt_and_inv_sqrt_envs_v2)
+  @assert all(ndims(env) == 2 for env in vcat(envs_v1, envs_v2))
+  sqrt_envs_v1 = [
+    ITensorsExtensions.map_eigvals(
+      sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v1
+  ]
+  sqrt_envs_v2 = [
+    ITensorsExtensions.map_eigvals(
+      sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v2
+  ]
+  inv_sqrt_envs_v1 = [
+    ITensorsExtensions.map_eigvals(
+      inv ∘ sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v1
+  ]
+  inv_sqrt_envs_v2 = [
+    ITensorsExtensions.map_eigvals(
+      inv ∘ sqrt, env, inds(env)[1], inds(env)[2]; cutoff, ishermitian=true
+    ) for env in envs_v2
+  ]
   ψᵥ₁ = contract([ψ[v⃗[1]]; sqrt_envs_v1])
   ψᵥ₂ = contract([ψ[v⃗[2]]; sqrt_envs_v2])
   sᵥ₁ = siteinds(ψ, v⃗[1])
@@ -187,12 +166,16 @@ function simple_update_bp(o, ψ, v⃗; envs, (observer!)=nothing, apply_kwargs..
   rᵥ₂ = commoninds(Qᵥ₂, Rᵥ₂)
   oR = apply(o, Rᵥ₁ * Rᵥ₂)
   e = v⃗[1] => v⃗[2]
-  Rᵥ₁, Rᵥ₂ = symmetric_factorize(
-    oR, unioninds(rᵥ₁, sᵥ₁); tags=edge_tag(e), observer!, apply_kwargs...
+  Rᵥ₁, Rᵥ₂ = factorize_svd(
+    oR,
+    unioninds(rᵥ₁, sᵥ₁);
+    ortho="none",
+    tags=edge_tag(e),
+    singular_values!,
+    apply_kwargs...,
   )
-  # TODO: `dag` here?
-  Qᵥ₁ = contract([Qᵥ₁; inv_sqrt_envs_v1])
-  Qᵥ₂ = contract([Qᵥ₂; inv_sqrt_envs_v2])
+  Qᵥ₁ = contract([Qᵥ₁; dag.(inv_sqrt_envs_v1)])
+  Qᵥ₂ = contract([Qᵥ₂; dag.(inv_sqrt_envs_v2)])
   ψᵥ₁ = Qᵥ₁ * Rᵥ₁
   ψᵥ₂ = Qᵥ₂ * Rᵥ₂
   return ψᵥ₁, ψᵥ₂
@@ -207,7 +190,7 @@ function ITensors.apply(
   nfullupdatesweeps=10,
   print_fidelity_loss=false,
   envisposdef=false,
-  (observer!)=nothing,
+  (singular_values!)=nothing,
   variational_optimization_only=false,
   symmetrize=false,
   reduced=true,
@@ -243,15 +226,15 @@ function ITensors.apply(
         nfullupdatesweeps,
         print_fidelity_loss,
         envisposdef,
-        observer!,
+        singular_values!,
         symmetrize,
         apply_kwargs...,
       )
     else
       if reduced
-        ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, observer!, apply_kwargs...)
+        ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, singular_values!, apply_kwargs...)
       else
-        ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, observer!, apply_kwargs...)
+        ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, singular_values!, apply_kwargs...)
       end
     end
     if normalize
@@ -367,13 +350,13 @@ function ITensors.apply(o, ψ::VidalITensorNetwork; normalize=false, apply_kwarg
 
     for vn in neighbors(ψ, src(e))
       if (vn != dst(e))
-        ψv1 = noprime(ψv1 * inv_diag(bond_tensor(ψ, vn => src(e))))
+        ψv1 = noprime(ψv1 * ITensorsExtensions.inv_diag(bond_tensor(ψ, vn => src(e))))
       end
     end
 
     for vn in neighbors(ψ, dst(e))
       if (vn != src(e))
-        ψv2 = noprime(ψv2 * inv_diag(bond_tensor(ψ, vn => dst(e))))
+        ψv2 = noprime(ψv2 * ITensorsExtensions.inv_diag(bond_tensor(ψ, vn => dst(e))))
       end
     end
 
diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl
index 35769012..282b9ee8 100644
--- a/src/caches/beliefpropagationcache.jl
+++ b/src/caches/beliefpropagationcache.jl
@@ -8,7 +8,7 @@ using ITensors: dir
 using ITensors.ITensorMPS: ITensorMPS
 using NamedGraphs: boundary_partitionedges, partitionvertices, partitionedges
 
-default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
+default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e]
 default_messages(ptn::PartitionedGraph) = Dictionary()
 default_message_norm(m::ITensor) = norm(m)
 function default_message_update(contract_list::Vector{ITensor}; kwargs...)
diff --git a/src/gauging.jl b/src/gauging.jl
index 4d2c4f6a..eb82c277 100644
--- a/src/gauging.jl
+++ b/src/gauging.jl
@@ -40,7 +40,7 @@ function ITensorNetwork(
 
   for e in edges(ψ)
     vsrc, vdst = src(e), dst(e)
-    root_S = sqrt_diag(bond_tensor(ψ_vidal, e))
+    root_S = ITensorsExtensions.sqrt_diag(bond_tensor(ψ_vidal, e))
     setindex_preserve_graph!(ψ, noprime(root_S * ψ[vsrc]), vsrc)
     setindex_preserve_graph!(ψ, noprime(root_S * ψ[vdst]), vdst)
   end
@@ -88,11 +88,12 @@ function vidalitensornetwork_preserve_cache(
     Y_D, Y_U = eigen(
       only(message(cache, reverse(pe))); ishermitian=true, cutoff=message_cutoff
     )
-    X_D, Y_D = map_diag(x -> x + regularization, X_D),
-    map_diag(x -> x + regularization, Y_D)
+    X_D, Y_D = ITensorsExtensions.map_diag(x -> x + regularization, X_D),
+    ITensorsExtensions.map_diag(x -> x + regularization, Y_D)
 
-    rootX_D, rootY_D = sqrt_diag(X_D), sqrt_diag(Y_D)
-    inv_rootX_D, inv_rootY_D = invsqrt_diag(X_D), invsqrt_diag(Y_D)
+    rootX_D, rootY_D = ITensorsExtensions.sqrt_diag(X_D), ITensorsExtensions.sqrt_diag(Y_D)
+    inv_rootX_D, inv_rootY_D = ITensorsExtensions.invsqrt_diag(X_D),
+    ITensorsExtensions.invsqrt_diag(Y_D)
     rootX = X_U * rootX_D * prime(dag(X_U))
     rootY = Y_U * rootY_D * prime(dag(Y_U))
     inv_rootX = X_U * inv_rootX_D * prime(dag(X_U))
diff --git a/src/treetensornetworks/opsum_to_ttn.jl b/src/treetensornetworks/opsum_to_ttn.jl
index 9b555547..51d8da07 100644
--- a/src/treetensornetworks/opsum_to_ttn.jl
+++ b/src/treetensornetworks/opsum_to_ttn.jl
@@ -2,7 +2,7 @@ using Graphs: degree, is_tree
 using ITensors: flux, has_fermion_string, itensor, ops, removeqns, space, val
 using ITensors.ITensorMPS: ITensorMPS, cutoff, linkdims, truncate!
 using ITensors.LazyApply: Prod, Sum, coefficient
-using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
+using ITensors.NDTensors: Block, blockdim, maxdim, nblocks, nnzblocks
 using ITensors.Ops: Op, OpSum
 using NamedGraphs: degrees, is_leaf, vertex_path
 using StaticArrays: MVector
diff --git a/test/test_apply.jl b/test/test_apply.jl
index d4c408e0..fab04ceb 100644
--- a/test/test_apply.jl
+++ b/test/test_apply.jl
@@ -19,7 +19,7 @@ using Test: @test, @testset
 
 @testset "apply" begin
   Random.seed!(5623)
-  g_dims = (2, 3)
+  g_dims = (2, 2)
   n = prod(g_dims)
   g = named_grid(g_dims)
   s = siteinds("S=1/2", g)
diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl
index cc5c6217..7012a1b5 100644
--- a/test/test_itensornetwork.jl
+++ b/test/test_itensornetwork.jl
@@ -34,6 +34,7 @@ using ITensors:
   scalartype,
   sim,
   uniqueinds
+using ITensors.NDTensors: NDTensors, dim
 using ITensorNetworks:
   ITensorNetworks,
   ⊗,
@@ -53,7 +54,6 @@ using ITensorNetworks:
   ttn
 using LinearAlgebra: factorize
 using NamedGraphs: NamedEdge, incident_edges, named_comb_tree, named_grid
-using NDTensors: NDTensors, dim
 using Random: Random, randn!
 using Test: @test, @test_broken, @testset
 
diff --git a/test/test_itensorsextensions.jl b/test/test_itensorsextensions.jl
new file mode 100644
index 00000000..b2438780
--- /dev/null
+++ b/test/test_itensorsextensions.jl
@@ -0,0 +1,76 @@
+@eval module $(gensym())
+using ITensors:
+  ITensors,
+  ITensor,
+  Index,
+  QN,
+  dag,
+  delta,
+  inds,
+  noprime,
+  op,
+  prime,
+  randomITensor,
+  replaceind,
+  replaceinds,
+  sim
+using ITensorNetworks.ITensorsExtensions: map_eigvals
+using Random: Random
+using Test: @test, @testset
+
+Random.seed!(1234)
+@testset "ITensorsExtensions" begin
+  @testset "Test map eigvals without QNS (eltype=$elt, dim=$n)" for elt in (
+      Float32, Float64, Complex{Float32}, Complex{Float64}
+    ),
+    n in (2, 3, 5, 10)
+    # Build a Hermitian positive semi-definite matrix `A * A'` and wrap it as `P`.
+    i, j = Index(n, "i"), Index(n, "j")
+    linds, rinds = Index[i], Index[j]
+    A = randn(elt, (n, n))
+    A = A * A'
+    P = ITensor(A, i, j)
+    sqrtP = map_eigvals(sqrt, P, linds, rinds; ishermitian=true)
+    inv_P = dag(map_eigvals(inv, P, linds, rinds; ishermitian=true))
+    inv_sqrtP = dag(map_eigvals(inv ∘ sqrt, P, linds, rinds; ishermitian=true))
+    # Check sqrt: contracting `sqrtP` with its dagger over `j` must reproduce `P`.
+    sqrtPdag = replaceind(dag(sqrtP), i, i')
+    P2 = replaceind(sqrtP * sqrtPdag, i', j)
+    @test P2 ≈ P
+    # Check inv: `inv_P` (with `i` primed) contracted with `P` must be the identity.
+    invP = replaceind(inv_P, i, i')
+    I = invP * P
+    @test I ≈ delta(elt, inds(I))
+    # Check inv ∘ sqrt: its contraction with `sqrtP` must be the identity.
+    inv_sqrtP = replaceind(inv_sqrtP, i, i')
+    I = inv_sqrtP * sqrtP
+    @test I ≈ delta(elt, inds(I))
+  end
+
+  @testset "Test map eigvals with QNS (eltype=$elt, dim=$n)" for elt in (
+      Float32, Float64, Complex{Float32}, Complex{Float64}
+    ),
+    n in (2, 3, 5, 10)
+    # Build a Hermitian positive semi-definite `P = A * dag(A)` with indices `(i, i')`.
+    i, j = Index.(([QN() => n], [QN() => n]))
+    A = randomITensor(elt, i, j)
+    P = A * prime(dag(A), i)
+    sqrtP = map_eigvals(sqrt, P, i, i'; ishermitian=true)
+    inv_P = dag(map_eigvals(inv, P, i, i'; ishermitian=true))
+    inv_sqrtP = dag(map_eigvals(inv ∘ sqrt, P, i, i'; ishermitian=true))
+    # Check sqrt: `sqrtP` times its dagger (contracted over `i`) must reproduce `P`.
+    new_ind = noprime(sim(i'))
+    sqrtPdag = replaceind(dag(sqrtP), i', new_ind)
+    P2 = replaceind(sqrtP * sqrtPdag, new_ind, i)
+    @test P2 ≈ P
+    # Check inv: `inv_P` contracted with `P` must give the identity `op("I", i)`.
+    inv_P = replaceind(inv_P, i', new_ind)
+    I = replaceind(inv_P * P, new_ind, i)
+    @test I ≈ op("I", i)
+    # Check inv ∘ sqrt: its contraction with `sqrtP` must also give the identity.
+    inv_sqrtP = replaceind(inv_sqrtP, i', new_ind)
+    I = replaceind(inv_sqrtP * sqrtP, new_ind, i)
+    @test I ≈ op("I", i)
+  end
+end
+end