Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
version = "0.4.4"
version = "0.5.0"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand All @@ -13,14 +13,13 @@ AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d"
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Expand All @@ -43,19 +42,18 @@ AlgorithmsInterface = "0.1"
BackendSelection = "0.1.6"
Combinatorics = "1"
DataGraphs = "0.4"
DiagonalArrays = "0.3.31"
Dictionaries = "0.4.5"
FunctionImplementations = "0.4.1"
Graphs = "1.13.1"
ITensorBase = "0.6.2"
LinearAlgebra = "1.10"
MacroTools = "0.5.16"
MatrixAlgebraKit = "0.6"
NamedDimsArrays = "0.15.8"
NamedGraphs = "0.11"
Random = "1.10"
SimpleTraits = "0.9.5"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.9.5"
TensorAlgebra = "0.9.7"
TensorOperations = "5.3.1"
TermInterface = "2"
TypeParameterAccessors = "0.4.4"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ path = ".."
[compat]
Documenter = "1"
ITensorFormatter = "0.2.27"
ITensorNetworksNext = "0.4"
ITensorNetworksNext = "0.5"
Literate = "2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
path = ".."

[compat]
ITensorNetworksNext = "0.4"
ITensorNetworksNext = "0.5"
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
module ITensorNetworksNextTensorOperationsExt

using BackendSelection: @Algorithm_str, Algorithm
using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments
using ITensorNetworksNext.LazyNamedDimsArrays:
LazyNamedDimsArrays, ismul, substitute, symnameddims
using NamedDimsArrays: inds
using ITensorBase: denamed, inds
using ITensorNetworksNext.LazyITensors.TermInterface: arguments
using ITensorNetworksNext.LazyITensors: LazyITensors, ismul, substitute, symnameddims
using TensorOperations: TensorOperations, optimaltree

function contraction_tree_to_expr(f, tree)
Expand All @@ -15,12 +14,12 @@ function contraction_tree_to_expr(f, tree)
end
end

function LazyNamedDimsArrays.optimize_contraction_order(alg::Algorithm"optimal", a)
function LazyITensors.optimize_contraction_order(alg::Algorithm"optimal", a)
@assert ismul(a)
ts = arguments(a)
inds_network = collect.(inds.(ts))
# Converting dims to Float64 to minimize overflow issues
inds_to_dims = Dict(i => Float64(length(i)) for i in reduce(∪, inds_network))
inds_to_dims = Dict(i => Float64(length(denamed(i))) for i in reduce(∪, inds_network))
tree, _ = optimaltree(inds_network, inds_to_dims)
return contraction_tree_to_expr(i -> ts[i], tree)
end
Expand Down
2 changes: 1 addition & 1 deletion src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end

include("select_algorithm.jl")
include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl")
include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl")
include("LazyITensors/LazyITensors.jl")
include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
Expand Down
12 changes: 12 additions & 0 deletions src/LazyITensors/LazyITensors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module LazyITensors

include("baseextensions.jl")
include("itensorbaseextensions.jl")
include("applied.jl")
include("lazyinterface.jl")
include("lazybroadcast.jl")
include("lazyitensor.jl")
include("symbolicitensor.jl")
include("evaluation_order.jl")

end
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using AbstractTrees: AbstractTrees
using TermInterface: TermInterface, arguments, iscall, operation
using TypeParameterAccessors: unspecify_type_parameters

# Generic functionality for Applied types, like `Mul`, `Add`, etc.
ismul(a) = iscall(a) && operation(a) ≡ *
Expand All @@ -20,9 +19,9 @@ function maketerm_applied(type, head, args, metadata)
@assert head ≡ operation(term)
return term
end
map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a)))
map_arguments_applied(f, a) = Base.typename(typeof(a)).wrapper(map(f, arguments(a)))
function hash_applied(a, h::UInt64)
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
h = hash(Symbol(Base.typename(typeof(a)).wrapper), h)
for arg in arguments(a)
h = hash(arg, h)
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NamedDimsArrays: denamed, dimnames, inds
using ITensorBase: denamed, dimnames, inds
using TermInterface: arguments, arity, operation

# The time complexity of evaluating `f(args...)`.
Expand All @@ -14,22 +14,22 @@ function input_space_complexity(f, args...)
return error("Not implemented.")
end

using NamedDimsArrays: AbstractNamedDimsArray
using ITensorBase: AbstractITensor
function time_complexity(
::typeof(*), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray
::typeof(*), t1::AbstractITensor, t2::AbstractITensor
)
return prod(length ∘ denamed, (inds(t1) ∪ inds(t2)))
end
function time_complexity(
::typeof(+), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray
::typeof(+), t1::AbstractITensor, t2::AbstractITensor
)
@assert issetequal(dimnames(t1), dimnames(t2))
return prod(denamed, size(t1))
end
function time_complexity(::typeof(*), c::Number, t::AbstractNamedDimsArray)
function time_complexity(::typeof(*), c::Number, t::AbstractITensor)
return prod(denamed, size(t))
end
function time_complexity(::typeof(*), t::AbstractNamedDimsArray, c::Number)
function time_complexity(::typeof(*), t::AbstractITensor, c::Number)
return time_complexity(*, c, t)
end

Expand Down
23 changes: 23 additions & 0 deletions src/LazyITensors/itensorbaseextensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using ITensorBase: ITensor, denamed, inds
# Defined to avoid type piracy.
# TODO: Define a proper hash function
# in ITensorBase.jl, maybe one that is
# independent of the order of dimensions.
function _hash(a::ITensor, h::UInt64)
h = hash(:ITensor, h)
h = hash(denamed(a), h)
for i in inds(a)
h = hash(i, h)
end
return h
end
function _hash(x, h::UInt64)
return hash(x, h)
end

using AbstractTrees: AbstractTrees
# Custom version of `AbstractTrees.printnode` to
# avoid type piracy when overloading on `AbstractITensor`.
# Method specializations (`LazyITensor`, `SymbolicITensor`) live in
# `lazyitensor.jl` and `symbolicitensor.jl`.
printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x)
13 changes: 13 additions & 0 deletions src/LazyITensors/lazybroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Lazy broadcasting.
struct LazyITensorStyle <: Base.Broadcast.AbstractArrayStyle{Any} end
function Broadcast.broadcasted(::LazyITensorStyle, f, as...)
return error("Arbitrary broadcasting not supported for LazyITensor.")
end
# Linear operations.
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(+), a1, a2) = a1 + a2
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a1, a2) = a1 - a2
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), c::Number, a) = c * a
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a, c::Number) = a * c
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(*), a::Number, b::Number) = a * b
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(/), a, c::Number) = a / c
Broadcast.broadcasted(::LazyITensorStyle, ::typeof(-), a) = -a
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NamedDimsArrays: denamed, dimnames, inds
using ITensorBase: denamed, dimnames, inds
using TermInterface: iscall, maketerm, operation, sorted_arguments
using WrappedUnions: unwrap

Expand All @@ -23,19 +23,13 @@ opwalk(opmap, a) = walk(opmap, identity, a)
argwalk(argmap, a) = walk(identity, argmap, a)

# Generic lazy functionality.
using FunctionImplementations: AbstractArrayImplementationStyle
struct LazyNamedDimsArrayImplementationStyle <: AbstractArrayImplementationStyle end
const lazy_style = LazyNamedDimsArrayImplementationStyle()

const maketerm_lazy = lazy_style(maketerm)
function maketerm_lazy(type::Type, head, args, metadata)
if head ≡ *
return type(maketerm(Mul, head, args, metadata))
else
return error("Only mul supported right now.")
end
end
const getindex_lazy = lazy_style(getindex)
function getindex_lazy(a::AbstractArray, I...)
u = unwrap(a)
if !iscall(u)
Expand All @@ -44,7 +38,6 @@ function getindex_lazy(a::AbstractArray, I...)
return error("Indexing into expression not supported.")
end
end
const arguments_lazy = lazy_style(arguments)
function arguments_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -56,17 +49,12 @@ function arguments_lazy(a)
end
end
using TermInterface: children
const children_lazy = lazy_style(children)
children_lazy(a) = arguments(a)
using TermInterface: head
const head_lazy = lazy_style(head)
head_lazy(a) = operation(a)
const iscall_lazy = lazy_style(iscall)
iscall_lazy(a) = iscall(unwrap(a))
using TermInterface: isexpr
const isexpr_lazy = lazy_style(isexpr)
isexpr_lazy(a) = iscall(a)
const operation_lazy = lazy_style(operation)
function operation_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -77,7 +65,6 @@ function operation_lazy(a)
return error("Variant not supported.")
end
end
const sorted_arguments_lazy = lazy_style(sorted_arguments)
function sorted_arguments_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -89,12 +76,9 @@ function sorted_arguments_lazy(a)
end
end
using TermInterface: sorted_children
const sorted_children_lazy = lazy_style(sorted_children)
sorted_children_lazy(a) = sorted_arguments(a)
const ismul_lazy = lazy_style(ismul)
ismul_lazy(a) = ismul(unwrap(a))
using AbstractTrees: AbstractTrees
const abstracttrees_children_lazy = lazy_style(AbstractTrees.children)
function abstracttrees_children_lazy(a)
if !iscall(a)
return ()
Expand All @@ -103,7 +87,6 @@ function abstracttrees_children_lazy(a)
end
end
using AbstractTrees: nodevalue
const nodevalue_lazy = lazy_style(nodevalue)
function nodevalue_lazy(a)
if !iscall(a)
return unwrap(a)
Expand All @@ -112,11 +95,8 @@ function nodevalue_lazy(a)
end
end
using Base.Broadcast: materialize
const materialize_lazy = lazy_style(materialize)
materialize_lazy(a) = argwalk(unwrap, a)
const copy_lazy = lazy_style(copy)
copy_lazy(a) = materialize(a)
const equals_lazy = lazy_style(==)
function equals_lazy(a1, a2)
u1, u2 = unwrap.((a1, a2))
if !iscall(u1) && !iscall(u2)
Expand All @@ -127,7 +107,6 @@ function equals_lazy(a1, a2)
return false
end
end
const isequal_lazy = lazy_style(isequal)
function isequal_lazy(a1, a2)
u1, u2 = unwrap.((a1, a2))
if !iscall(u1) && !iscall(u2)
Expand All @@ -138,13 +117,11 @@ function isequal_lazy(a1, a2)
return false
end
end
const hash_lazy = lazy_style(hash)
function hash_lazy(a, h::UInt64)
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
# Use `_hash`, which defines a custom hash for NamedDimsArray.
h = hash(Symbol(Base.typename(typeof(a)).wrapper), h)
# Use `_hash`, which defines a custom hash for ITensor.
return _hash(unwrap(a), h)
end
const map_arguments_lazy = lazy_style(map_arguments)
function map_arguments_lazy(f, a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -156,21 +133,18 @@ function map_arguments_lazy(f, a)
end
end
function substitute end
const substitute_lazy = lazy_style(substitute)
function substitute_lazy(a, substitutions::AbstractDict)
haskey(substitutions, a) && return substitutions[a]
!iscall(a) && return a
return map_arguments(arg -> substitute(arg, substitutions), a)
end
substitute_lazy(a, substitutions) = substitute(a, Dict(substitutions))
using AbstractTrees: printnode
const printnode_lazy = lazy_style(printnode)
function printnode_lazy(io, a)
# Use `printnode_nameddims` to avoid type piracy,
# since it overloads on `AbstractNamedDimsArray`.
# since it overloads on `AbstractITensor`.
return printnode_nameddims(io, unwrap(a))
end
const show_lazy = lazy_style(show)
function show_lazy(io::IO, a)
if !iscall(a)
return show(io, unwrap(a))
Expand All @@ -184,12 +158,9 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
!iscall(a) ? show(io, mime, unwrap(a)) : show(io, a)
return nothing
end
const add_lazy = lazy_style(+)
add_lazy(a1, a2) = error("Not implemented.")
const sub_lazy = lazy_style(-)
sub_lazy(a) = error("Not implemented.")
sub_lazy(a1, a2) = error("Not implemented.")
const mul_lazy = lazy_style(*)
function mul_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -212,8 +183,7 @@ mul_lazy(a1, a2::Number) = error("Not implemented.")
mul_lazy(a1::Number, a2::Number) = a1 * a2
div_lazy(a1, a2::Number) = error("Not implemented.")

# NamedDimsArrays.jl interface.
const dimnames_lazy = lazy_style(dimnames)
# ITensorBase.jl named-tensor interface.
function dimnames_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -224,7 +194,6 @@ function dimnames_lazy(a)
return error("Variant not supported.")
end
end
const inds_lazy = lazy_style(inds)
function inds_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -235,7 +204,6 @@ function inds_lazy(a)
return error("Variant not supported.")
end
end
const denamed_lazy = lazy_style(denamed)
function denamed_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand Down
Loading