diff --git a/Project.toml b/Project.toml index 8296a16..e0f891e 100644 --- a/Project.toml +++ b/Project.toml @@ -13,11 +13,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] DeviceSparseArraysJLArraysExt = "JLArrays" -DeviceSparseArraysReactantExt = "Reactant" [compat] AcceleratedKernels = "0.4" @@ -26,6 +24,5 @@ ArrayInterface = "7" JLArrays = "0.3" KernelAbstractions = "0.9" LinearAlgebra = "1" -Reactant = "0.2.164" SparseArrays = "1" julia = "1.10" diff --git a/ext/DeviceSparseArraysReactantExt.jl b/ext/DeviceSparseArraysReactantExt.jl deleted file mode 100644 index 443ba00..0000000 --- a/ext/DeviceSparseArraysReactantExt.jl +++ /dev/null @@ -1,12 +0,0 @@ -module DeviceSparseArraysReactantExt - -import DeviceSparseArrays -import Reactant - -DeviceSparseArrays._check_type( - ::Type{T}, - ::Reactant.RArray{Reactant.TracedRNumber{T}}, -) where {T} = true -DeviceSparseArrays._get_eltype(::Reactant.RArray{Reactant.TracedRNumber{T}}) where {T} = T - -end diff --git a/src/helpers.jl b/src/helpers.jl index 91aea6f..025b005 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -1,11 +1,3 @@ -#= -A method to check that an AbstractArray is of a given element type. -This is needed because we can implement new methods for different arrays (e.g., Reactant.jl) -=# -_check_type(::Type{T}, v::AbstractArray{T}) where {T} = true -_check_type(::Type{T}, v::AbstractArray) where {T} = false - -_get_eltype(::AbstractArray{T}) where {T} = T - +# Helper functions to call AcceleratedKernels methods _sortperm_AK(x) = AcceleratedKernels.sortperm(x) _cumsum_AK(x) = AcceleratedKernels.cumsum(x) diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 82f22a8..3c6240b 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -1,7 +1,7 @@ # DeviceSparseMatrixCOO implementation """ - DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector,ColIndT<:AbstractVector,NzValT<:AbstractVector} <: AbstractDeviceSparseMatrix{Tv,Ti} + DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector{Ti},ColIndT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractDeviceSparseMatrix{Tv,Ti} Coordinate (COO) sparse matrix with generic storage vectors for row indices, column indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array @@ -16,17 +16,18 @@ types) enable dispatch on device characteristics. """ struct DeviceSparseMatrixCOO{ Tv, - Ti<:Integer, - RowIndT<:AbstractVector, - ColIndT<:AbstractVector, - NzValT<:AbstractVector, + Ti, + RowIndT<:AbstractVector{Ti}, + ColIndT<:AbstractVector{Ti}, + NzValT<:AbstractVector{Tv}, } <: AbstractDeviceSparseMatrix{Tv,Ti} m::Int n::Int rowind::RowIndT colind::ColIndT nzval::NzValT - function DeviceSparseMatrixCOO{Tv,Ti,RowIndT,ColIndT,NzValT}( + + function DeviceSparseMatrixCOO( m::Integer, n::Integer, rowind::RowIndT, @@ -34,10 +35,10 @@ struct DeviceSparseMatrixCOO{ nzval::NzValT, ) where { Tv, - Ti<:Integer, - RowIndT<:AbstractVector, - ColIndT<:AbstractVector, - NzValT<:AbstractVector, + Ti, + RowIndT<:AbstractVector{Ti}, + ColIndT<:AbstractVector{Ti}, + NzValT<:AbstractVector{Tv}, } get_backend(rowind) == get_backend(colind) == get_backend(nzval) || throw(ArgumentError("All storage vectors must be on the same device/backend.")) @@ -46,39 +47,19 @@ struct DeviceSparseMatrixCOO{ n >= 0 || throw(ArgumentError("n must be non-negative")) SparseArrays.sparse_check_Ti(m, n, Ti) - _check_type(Ti, rowind) || throw(ArgumentError("rowind must be of type $Ti")) - _check_type(Ti, colind) || throw(ArgumentError("colind must be of type $Ti")) - _check_type(Tv, nzval) || throw(ArgumentError("nzval must be of type $Tv")) - length(rowind) == length(colind) == length(nzval) || throw(ArgumentError("rowind, colind, and nzval must have same length")) - return new(Int(m), Int(n), rowind, colind, nzval) + return new{Tv,Ti,RowIndT,ColIndT,NzValT}( + Int(m), + Int(n), + copy(rowind), + copy(colind), + copy(nzval), + ) end end -function DeviceSparseMatrixCOO( - m::Integer, - n::Integer, - rowind::RowIndT, - colind::ColIndT, - nzval::NzValT, -) where { - RowIndT<:AbstractVector{Ti}, - ColIndT<:AbstractVector{Ti}, - NzValT<:AbstractVector{Tv}, -} where {Ti<:Integer,Tv} - Ti2 = _get_eltype(rowind) - Tv2 = _get_eltype(nzval) - DeviceSparseMatrixCOO{Tv2,Ti2,RowIndT,ColIndT,NzValT}( - m, - n, - copy(rowind), - copy(colind), - copy(nzval), - ) -end - # Conversion from SparseMatrixCSC to COO function DeviceSparseMatrixCOO(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} m, n = size(A) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 68091a0..a399b11 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -1,7 +1,7 @@ # DeviceSparseMatrixCSC implementation """ - DeviceSparseMatrixCSC{Tv,Ti,ColPtrT= 0 || throw(ArgumentError("The number of elements must be non-negative.")) length(nzind) == length(nzval) || throw(ArgumentError("index and value vectors must be the same length")) - return new(Int(n), copy(nzind), copy(nzval)) - end -end -# Param inference constructor -function DeviceSparseVector( - n::Integer, - nzind::IndT, - nzval::ValT, -) where {IndT<:AbstractVector{Ti},ValT<:AbstractVector{Tv}} where {Ti<:Integer,Tv} - DeviceSparseVector{Tv,Ti,IndT,ValT}(n, nzind, nzval) + return new{Tv,Ti,IndT,ValT}(Int(n), copy(nzind), copy(nzval)) + end end # Conversions