Skip to content

Commit

Permalink
Change onehot to oneelement (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Feb 3, 2025
1 parent e9f1be2 commit 7a1ed53
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 14 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorBase"
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.13"
version = "0.1.14"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -16,9 +16,11 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[weakdeps]
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"

[extensions]
ITensorBaseDiagonalArraysExt = "DiagonalArrays"
ITensorBaseSparseArraysBaseExt = ["NamedDimsArrays", "SparseArraysBase"]

[compat]
Accessors = "0.1.39"
Expand All @@ -28,6 +30,7 @@ FillArrays = "1.13.0"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.5"
NamedDimsArrays = "0.4"
SparseArraysBase = "0.2.11"
UnallocatedArrays = "0.1.1"
UnspecifiedTypes = "0.1.1"
VectorInterface = "0.5.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module ITensorBaseSparseArraysBaseExt

using ITensorBase: ITensor, Index
using NamedDimsArrays: dename
using SparseArraysBase: SparseArraysBase, oneelement

function SparseArraysBase.oneelement(
value, index::NTuple{N,Int}, ax::NTuple{N,Index}
) where {N}
return ITensor(oneelement(value, index, only.(axes.(dename.(ax)))), ax)
end

end
25 changes: 12 additions & 13 deletions src/quirks.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
# TODO: Define this properly.
# TODO: Rename this to `dual`.
dag(i::Index) = i
# TODO: Define this properly.
# TODO: Rename this to `dual`.
dag(a::ITensor) = a
# TODO: Deprecate.

# TODO: Deprecate, just use `Int(length(i))` or
# `unname(length(i))` directly.
# Conversion to `Int` is used in case the output is named.
dim(i::Index) = Int(length(i))
# TODO: Deprecate.
# Conversion to `Int` is used in case the output is named.
# TODO: Deprecate, just use `Int(length(i))` or
# `unname(length(i))` directly.
dim(a::AbstractITensor) = Int(length(a))

# TODO: Define this properly.
hasqns(i::Index) = false
# TODO: Maybe rename to `isgraded(i::Index) = isgraded(dename(i))`.
hasqns(::Index) = false
# TODO: Define this properly.
hasqns(i::AbstractITensor) = false
# TODO: Maybe rename to `isgraded(a) = all(isgraded, axes(a))`.
hasqns(::AbstractITensor) = false

# This seems to be needed to get broadcasting working.
# TODO: Investigate this and see if we can get rid of it.
Base.Broadcast.extrude(a::AbstractITensor) = a

# TODO: Generalize this.
# Maybe define it as `oneelement`, and base it on
# `FillArrays.OneElement` (https://juliaarrays.github.io/FillArrays.jl/stable/#FillArrays.OneElement).
function onehot(elt::Type{<:Number}, iv::Pair{<:Index,<:Int})
a = ITensor(first(iv))
a[last(iv)] = one(elt)
return a
end
onehot(iv::Pair{<:Index,<:Int}) = onehot(Bool, iv)

# TODO: This is just a stand-in for truncated SVD
# that only makes use of `maxdim`, just to get some
# functionality running in `ITensorMPS.jl`.
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
22 changes: 22 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using ITensorBase:
ITensorBase, ITensor, Index, gettag, hastag, inds, plev, prime, settag, tags, unsettag
using DiagonalArrays: δ, delta, diagview
using NamedDimsArrays: dename, name, named
using SparseArraysBase: oneelement
using Test: @test, @test_broken, @testset

@testset "ITensorBase" begin
Expand Down Expand Up @@ -53,4 +54,25 @@ using Test: @test, @test_broken, @testset
@test diagview(dename(a)) == ones(2)
end
end
@testset "oneelement" begin
i = Index(3)
a = oneelement(i => 2)
@test a isa ITensor
@test ndims(a) == 1
@test issetequal(inds(a), (i,))
@test eltype(a) === Bool
@test a[1] == 0
@test a[2] == 1
@test a[3] == 0

i = Index(3)
a = oneelement(Float32, i => 2)
@test a isa ITensor
@test ndims(a) == 1
@test issetequal(inds(a), (i,))
@test eltype(a) === Float32
@test a[1] == 0
@test a[2] == 1
@test a[3] == 0
end
end

0 comments on commit 7a1ed53

Please sign in to comment.