Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TypeParameterAccessors"
uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.1"
version = "0.3.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
15 changes: 7 additions & 8 deletions src/base/abstractarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
struct Self end
position(a, ::Self) = Position(0)

Check warning on line 2 in src/base/abstractarray.jl

View check run for this annotation

Codecov / codecov/patch

src/base/abstractarray.jl#L2

Added line #L2 was not covered by tests
position(::Type, ::Self) = Position(0)
function set_type_parameters(type::Type, ::Self, param)
return error("Can't set the parent type of an unwrapped array type.")

Check warning on line 5 in src/base/abstractarray.jl

View check run for this annotation

Codecov / codecov/patch

src/base/abstractarray.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
end

position(::Type{AbstractArray}, ::typeof(eltype)) = Position(1)
position(::Type{AbstractArray}, ::typeof(ndims)) = Position(2)
default_type_parameters(::Type{AbstractArray}) = (Float64, 1)
Expand All @@ -9,14 +16,6 @@
position(::Type{<:BitArray}, ::typeof(ndims)) = Position(1)
default_type_parameters(::Type{<:BitArray}) = (1,)

struct Self end
position(a, ::Self) = Position(0)
position(::Type{T}, ::Self) where {T} = Position(0)

function set_type_parameters(type::Type, ::Self, param)
return error("Can't set the parent type of an unwrapped array type.")
end

function set_eltype(array::AbstractArray, param)
return convert(set_eltype(typeof(array), param), array)
end
Expand Down
58 changes: 55 additions & 3 deletions src/type_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,62 @@
position(object, name) = position(typeof(object), name)
position(::Type, pos::Int) = Position(pos)
position(::Type, pos::Position) = pos

function position(type::Type, name)
base_type = unspecify_type_parameters(type)
base_type === type && error("`position` not defined for $type and $name.")
return position(base_type, name)
type′ = unspecify_type_parameters(type)
if type === type′
# Fallback definition that determines the
# position automatically from the supertype of
# the type.
return position_from_supertype(type′, name)
end
return position(type′, name)
end

# Automatically determine the position of a type parameter of a type given
# a supertype and the name of the parameter.
function position_from_supertype(type::Type, name)
type′ = unspecify_type_parameters(type)
supertype_pos = position(supertype(type′), name)
return position_from_supertype_position(type′, supertype_pos)
end

# Automatically determine the position of a type parameter of a type given
# the supertype and the position of the corresponding parameter in the supertype.
@generated function position_from_supertype_position(
::Type{T}, supertype_pos::Position
) where {T}
T′ = unspecify_type_parameters(T)
# The type parameters of the type as TypeVars.
# TODO: Ideally we would use `get_type_parameters`
# but that sometimes loses TypeVar names:
# https://github.com/ITensor/TypeParameterAccessors.jl/issues/30
type_params = Base.unwrap_unionall(T′).parameters
# The type parameters of the immediate supertype as TypeVars.
# This has TypeVars with names that correspond to the names of
# the TypeVars of the type parameters of `T`, for example:
# ```julia
# julia> struct MyArray{B,A} <: AbstractArray{A,B} end
#
# julia> Base.unwrap_unionall(MyArray).parameters
# svec(B, A)
#
# julia> Base.unwrap_unionall(supertype(MyArray)).parameters
# svec(A, B)
# ```
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
supertype_param = supertype_params[Int(supertype_pos)]
pos = findfirst(param -> (param.name == supertype_param.name), type_params)
if isnothing(pos)
return error("Position not found.")
end
return :(@inline; $(Position(pos)))
end

# Automatically determine the position of a type parameter of a type given
# a supertype and the name of the parameter.
function position_from_supertype(type::Type, supertype_target::Type, name)
return position_from_supertype(type, supertype_target, position(supertype_target, name))

Check warning on line 85 in src/type_parameters.jl

View check run for this annotation

Codecov / codecov/patch

src/type_parameters.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
end

function positions(::Type{T}, pos::Tuple) where {T}
Expand Down
34 changes: 33 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ using JLArrays: JLArray, JLMatrix, JLVector
using Test: @test, @test_throws, @test_broken, @testset
using TestExtras: @constinferred
using TypeParameterAccessors:
set_type_parameters, specify_type_parameters, type_parameters, unspecify_type_parameters
TypeParameterAccessors,
Position,
set_type_parameters,
specify_type_parameters,
type_parameters,
unspecify_type_parameters

const anyarrayts = (
(arrayt=Array, matrixt=Matrix, vectort=Vector),
Expand Down Expand Up @@ -78,3 +83,30 @@ const anyarrayts = (
(3, Float32)
end
end

@testset "Automatic fallback for position" begin
struct MyArray{B,A} <: AbstractArray{A,B} end
@test @constinferred(TypeParameterAccessors.position(MyArray, eltype)) == Position(2)
@test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, eltype)) ==
Position(2)
@test @constinferred(TypeParameterAccessors.position(MyArray, ndims)) == Position(1)
@test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, ndims)) ==
Position(1)

struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end
@test @constinferred(TypeParameterAccessors.position(MyVector, eltype)) == Position(3)
@test @constinferred(
TypeParameterAccessors.position(MyVector{Int,(1, 2),Float32}, eltype)
) == Position(3)
@test_throws ErrorException TypeParameterAccessors.position(MyVector, ndims)
@test_throws ErrorException TypeParameterAccessors.position(
MyVector{Int,(1, 2),Float32}, ndims
)

struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end
@test_throws ErrorException TypeParameterAccessors.position(MyBoolArray, eltype)
@test_throws ErrorException TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, eltype)
@test @constinferred(TypeParameterAccessors.position(MyBoolArray, ndims)) == Position(4)
@test @constinferred(TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, ndims)) ==
Position(4)
end
Loading