Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Eric P. Hanson"]
version = "1.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand All @@ -21,17 +22,11 @@ Test = "1"
Flux = "0.16"
julia = "1.10"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

[extensions]
AdaptExt = "Adapt"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Adapt", "Aqua", "Flux", "Random", "Test"]
test = ["Aqua", "Flux", "Random", "Test"]
9 changes: 0 additions & 9 deletions ext/AdaptExt.jl

This file was deleted.

2 changes: 2 additions & 0 deletions src/AllocArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ScopedValues: ScopedValue, with
using Bumper
using ConcurrentUtilities
using PrecompileTools
using Adapt

# Two array types
export AllocArray
Expand All @@ -25,6 +26,7 @@ include("AllocArray.jl")
include("CheckedAllocArray.jl")
include("alloc_interface.jl")
include("autoscaling_alloc_buffer.jl")
include("adapt.jl")

const CURRENT_ALLOCATOR = ScopedValue{Allocator}(DEFAULT_ALLOCATOR)

Expand Down
59 changes: 59 additions & 0 deletions src/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# AllocArrays should be seen as "storage" like CuArray, not a wrapper like Transpose,
# since we want to recurse through models and convert arrays to AllocArrays, so that
# in the forward pass we can use our bump allocator
Adapt.adapt_storage(::Type{<:AllocArray}, xs::AbstractArray) = copy_to_alloc(xs)

function copy_to_alloc(xs::AbstractArray{T}) where {T}
arr = similar(AllocArray{T}, size(xs))
copyto!(arr, xs)
return arr
end

struct CachedUpdatelessFunction{T,F,A}
f::F
allocator::A
end
function CachedUpdatelessFunction{T}(f; allocator=BumperAllocator()) where {T}
return CachedUpdatelessFunction{T,typeof(f),typeof(allocator)}(f, allocator)
end
function CachedUpdatelessFunction(f; allocator=BumperAllocator())
return CachedUpdatelessFunction{AllocArray}(f; allocator)
end

function (c::CachedUpdatelessFunction{T})(args...; kw...) where {T}
with_allocator(c.allocator) do
try
tmp_f = Adapt.adapt(T, c.f)
args = Adapt.adapt(T, args)
kw = Adapt.adapt(T, kw)
return Adapt.adapt(Array, tmp_f(args...; kw...))
finally
reset!(c.allocator)
end
end
end

function Base.show(io::IO, mime::MIME"text/plain", c::CachedUpdatelessFunction{T}) where {T}
print(io, CachedUpdatelessFunction)
if T != AllocArray
print(io, "{$T}")
end
print(io, " with function\n")
show(io, mime, c.f)
print(io, "\nand allocator ")
show(io, mime, c.allocator)
return nothing
end

function Base.show(io::IO, c::CachedUpdatelessFunction{T}) where {T}
print(io, CachedUpdatelessFunction)
if T != AllocArray
print(io, "{$T}")
end
print(io, "(")
show(io, c.f)
print(io, ", ")
show(io, c.allocator)
print(io, ")")
return nothing
end
Loading