diff --git a/Project.toml b/Project.toml index ff53f3f..a33f6ea 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Eric P. Hanson"] version = "1.1.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -21,12 +22,6 @@ Test = "1" Flux = "0.16" julia = "1.10" -[weakdeps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" - -[extensions] -AdaptExt = "Adapt" - [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -34,4 +29,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Adapt", "Aqua", "Flux", "Random", "Test"] +test = ["Aqua", "Flux", "Random", "Test"] diff --git a/ext/AdaptExt.jl b/ext/AdaptExt.jl deleted file mode 100644 index fa4ae00..0000000 --- a/ext/AdaptExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module AdaptExt -using Adapt, AllocArrays - -# AllocArrays should be seen as "storage" like CuArray, not a wrapper like Transpose, -# since we want to recurse through models and convert arrays to AllocArrays, so that -# in the forward pass we can use our bump allocator -Adapt.adapt_storage(::Type{<:AllocArray}, xs::AbstractArray) = AllocArray(xs) - -end # AdaptExt diff --git a/src/AllocArrays.jl b/src/AllocArrays.jl index 6151a56..7aed488 100644 --- a/src/AllocArrays.jl +++ b/src/AllocArrays.jl @@ -4,6 +4,7 @@ using ScopedValues: ScopedValue, with using Bumper using ConcurrentUtilities using PrecompileTools +using Adapt # Two array types export AllocArray @@ -25,6 +26,7 @@ include("AllocArray.jl") include("CheckedAllocArray.jl") include("alloc_interface.jl") include("autoscaling_alloc_buffer.jl") +include("adapt.jl") const CURRENT_ALLOCATOR = ScopedValue{Allocator}(DEFAULT_ALLOCATOR) diff --git a/src/adapt.jl b/src/adapt.jl new file mode 100644 index 0000000..b3e7e13 
# src/adapt.jl: Adapt.jl integration for AllocArrays.

# AllocArrays should be seen as "storage" like CuArray, not a wrapper like Transpose,
# since we want to recurse through models and convert arrays to AllocArrays, so that
# in the forward pass we can use our bump allocator
Adapt.adapt_storage(::Type{<:AllocArray}, xs::AbstractArray) = copy_to_alloc(xs)

"""
    copy_to_alloc(xs::AbstractArray{T})

Copy the elements of `xs` into a freshly obtained destination array and return it.

The destination comes from `similar(AllocArray{T}, size(xs))`, so it has the element
type and size of `xs`; `xs` itself is not modified.
"""
function copy_to_alloc(xs::AbstractArray{T}) where {T}
    arr = similar(AllocArray{T}, size(xs))
    copyto!(arr, xs)
    return arr
end

"""
    CachedUpdatelessFunction(f; allocator=BumperAllocator())
    CachedUpdatelessFunction{T}(f; allocator=BumperAllocator())

Callable wrapper pairing a function (or callable object) `f` with an `allocator`.

Calling the wrapper as `c(args...; kw...)`:

1. runs inside `with_allocator(c.allocator)`;
2. converts `c.f`, the positional arguments and the keyword arguments with
   `Adapt.adapt(T, ...)`;
3. calls the adapted function and converts its result with `Adapt.adapt(Array, ...)`;
4. always calls `reset!(c.allocator)` afterwards, including when an error is thrown.

`T` is the adaptor handed to `Adapt.adapt`; the one-parameter-free constructor uses
`AllocArray`.

# Fields
- `f`: the wrapped callable.
- `allocator`: the allocator activated for, and reset after, every call.
"""
struct CachedUpdatelessFunction{T,F,A}
    f::F
    allocator::A
end
function CachedUpdatelessFunction{T}(f; allocator=BumperAllocator()) where {T}
    return CachedUpdatelessFunction{T,typeof(f),typeof(allocator)}(f, allocator)
end
function CachedUpdatelessFunction(f; allocator=BumperAllocator())
    return CachedUpdatelessFunction{AllocArray}(f; allocator)
end

function (c::CachedUpdatelessFunction{T})(args...; kw...) where {T}
    return with_allocator(c.allocator) do
        try
            # Bind adapted values to fresh names. Assigning to `args`/`kw` from inside
            # this closure would turn them into captured, reassigned variables, which
            # Julia boxes and which defeats type inference for the whole call.
            adapted_f = Adapt.adapt(T, c.f)
            adapted_args = Adapt.adapt(T, args)
            # `kw` is a `Base.Pairs`; adapt its underlying NamedTuple (`values(kw)`),
            # which Adapt maps over element-wise.
            # NOTE(review): a `Base.Pairs` appears to have no Adapt rule of its own and
            # would be passed through unconverted -- confirm against Adapt's rules.
            adapted_kw = Adapt.adapt(T, values(kw))
            # Convert the result to `Array` before `finally` resets the allocator.
            return Adapt.adapt(Array, adapted_f(adapted_args...; adapted_kw...))
        finally
            reset!(c.allocator)
        end
    end
end

# Print the type name, appending `{T}` only when `T` is not the default `AllocArray`.
function _print_typename(io::IO, ::CachedUpdatelessFunction{T}) where {T}
    print(io, CachedUpdatelessFunction)
    if T !== AllocArray
        print(io, "{$T}")
    end
    return nothing
end

# Multi-line display: the type name, then the wrapped function and the allocator, each
# rendered with the same MIME type.
function Base.show(io::IO, mime::MIME"text/plain", c::CachedUpdatelessFunction)
    _print_typename(io, c)
    print(io, " with function\n")
    show(io, mime, c.f)
    print(io, "\nand allocator ")
    show(io, mime, c.allocator)
    return nothing
end

# Compact display in constructor-like form: `CachedUpdatelessFunction(f, allocator)`.
function Base.show(io::IO, c::CachedUpdatelessFunction)
    _print_typename(io, c)
    print(io, "(")
    show(io, c.f)
    print(io, ", ")
    show(io, c.allocator)
    print(io, ")")
    return nothing
end