Commit 676e806

There is a model. It might even train.
corbt committed Nov 12, 2024
1 parent e68d74b commit 676e806
Showing 17 changed files with 2,073 additions and 324 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -167,4 +167,6 @@ last_run_prepared/
data/
wandb/
reward_model_output/
models/
models/
remote/
rclone.conf
95 changes: 95 additions & 0 deletions cache/cache.py
@@ -5,6 +5,8 @@
import asyncio
from typing import Optional, Callable, Tuple, Any
import os
import aioboto3
import io

BUST_CACHE = os.getenv("BUST_CACHE", "")
CACHE_ONLY = os.getenv("CACHE_ONLY", "false").lower() != "false"
@@ -157,6 +159,79 @@ async def delete_by_fn_id(self, fn_id: str) -> None:
conn.commit()


class S3Backend(CacheBackend):
def __init__(
self, bucket_name: str, prefix: str = "cache/", region_name: str = "us-east-1"
):
self.bucket_name = bucket_name
self.prefix = prefix.rstrip("/") + "/"
self.session = aioboto3.Session()
self.region_name = region_name

async def setup(self) -> None:
# Ensure bucket exists
async with self.session.client("s3", region_name=self.region_name) as s3:
try:
await s3.head_bucket(Bucket=self.bucket_name)
            except Exception as e:
                raise Exception(
                    f"Bucket {self.bucket_name} does not exist or is not accessible"
                ) from e

async def get(self, fn_id: str, arg_hash: str) -> Tuple[bool, Any]:
key = f"{self.prefix}{fn_id}/{arg_hash}"

async with self.session.client("s3", region_name=self.region_name) as s3:
try:
response = await s3.get_object(Bucket=self.bucket_name, Key=key)
async with response["Body"] as stream:
data = await stream.read()
return True, pickle.loads(data)
            except Exception:
                # Any failure here (missing key, read or unpickle error) is treated as a cache miss
                return False, None

async def set(self, fn_id: str, arg_hash: str, result: Any) -> None:
key = f"{self.prefix}{fn_id}/{arg_hash}"
pickled_data = pickle.dumps(result)

async with self.session.client("s3", region_name=self.region_name) as s3:
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=pickled_data)

async def delete(self, fn_id: str, arg_hash: str) -> None:
key = f"{self.prefix}{fn_id}/{arg_hash}"

async with self.session.client("s3", region_name=self.region_name) as s3:
await s3.delete_object(Bucket=self.bucket_name, Key=key)

async def delete_all(self) -> None:
async with self.session.client("s3", region_name=self.region_name) as s3:
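            # list_objects_v2 returns at most 1,000 keys per page, which matches the
            # per-request limit of delete_objects, so each page can be removed in one call.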
paginator = s3.get_paginator("list_objects_v2")
async for page in paginator.paginate(
Bucket=self.bucket_name, Prefix=self.prefix
):
if "Contents" in page:
objects = [{"Key": obj["Key"]} for obj in page["Contents"]]
if objects:
await s3.delete_objects(
Bucket=self.bucket_name, Delete={"Objects": objects}
)

async def delete_by_fn_id(self, fn_id: str) -> None:
prefix = f"{self.prefix}{fn_id}/"

async with self.session.client("s3", region_name=self.region_name) as s3:
paginator = s3.get_paginator("list_objects_v2")
async for page in paginator.paginate(
Bucket=self.bucket_name, Prefix=prefix
):
if "Contents" in page:
objects = [{"Key": obj["Key"]} for obj in page["Contents"]]
if objects:
await s3.delete_objects(
Bucket=self.bucket_name, Delete={"Objects": objects}
)


class Cache:
def __init__(self, backend: CacheBackend):
self.backend = backend
@@ -279,3 +354,23 @@ async def bust_all(self) -> None:
"""Busts the entire cache"""
await self.ensure_setup()
await self.backend.delete_all()

async def set(self, key: str, value: Any) -> None:
"""Directly set a value in the cache using a string key"""
await self.ensure_setup()
fn_id = "__direct"
arg_hash = hashlib.sha256(key.encode()).hexdigest()
await self.backend.set(fn_id, arg_hash, value)

async def get(self, key: str) -> Any:
"""
Directly get a value from the cache using a string key
Raises KeyError if the key hasn't been set
"""
await self.ensure_setup()
fn_id = "__direct"
arg_hash = hashlib.sha256(key.encode()).hexdigest()
cache_hit, result = await self.backend.get(fn_id, arg_hash)
if not cache_hit:
raise KeyError(f"No cache entry found for key: {key}")
return result
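
As a rough illustration (not part of the commit), the new S3 backend and the direct set()/get() helpers added above could be used together like this; the bucket name and cached values are placeholders, and working AWS credentials plus an existing bucket are assumed:

import asyncio

from cache import Cache, S3Backend


async def main() -> None:
    # Placeholder bucket; S3Backend.setup() only verifies it is reachable,
    # it does not create it.
    cache = Cache(S3Backend("my-cache-bucket", prefix="cache/", region_name="us-east-1"))

    # Direct key/value access via the new set()/get() helpers
    await cache.set("run-42/config", {"lr": 1e-4, "epochs": 3})
    print(await cache.get("run-42/config"))


asyncio.run(main())

Since setup() only checks that the bucket is reachable, the bucket has to be created out of band (or mocked, as the tests below do with moto).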
72 changes: 60 additions & 12 deletions cache/test_cache.py
@@ -1,23 +1,49 @@
import pytest
import asyncio
import os
from cache import Cache, SQLiteBackend
from cache import Cache, SQLiteBackend, S3Backend
from moto import mock_aws
import boto3


# Test fixtures
@pytest.fixture
def cache():
# Use temporary test database
db_path = "test_cache.db"
backend = SQLiteBackend(db_path)
cache = Cache(backend)
@pytest.fixture(
params=[
# ("sqlite", lambda: SQLiteBackend("test_cache.db")),
(
"s3",
lambda: S3Backend("test-bucket", prefix="test/", region_name="us-east-1"),
),
]
)
def cache(request):
backend_type, backend_factory = request.param

yield cache
if backend_type == "sqlite":
db_path = "test_cache.db"
backend = backend_factory()
cache = Cache(backend)

# Cleanup
asyncio.run(cache.bust_all())
if os.path.exists(db_path):
os.remove(db_path)
yield cache

# Cleanup
asyncio.run(cache.bust_all())
if os.path.exists(db_path):
os.remove(db_path)
elif backend_type == "s3":
mock = mock_aws()
mock.start()

print("CREATING BUCKET")
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="test-bucket")

backend = backend_factory()
cache = Cache(backend)

yield cache

mock.stop()


# Test functions to be cached
@@ -134,3 +160,25 @@ async def test_bust_entire_cache(cache):
add2_hit, _ = await cached_add2.read_cache(4, 5)
assert not add1_hit
assert not add2_hit


@pytest.mark.asyncio
async def test_direct_cache_operations(cache):
# Test setting and getting a value
await cache.set("test_key", "test_value")
result = await cache.get("test_key")
assert result == "test_value"

# Test getting a non-existent key
with pytest.raises(KeyError):
await cache.get("nonexistent_key")

# Test overwriting an existing key
await cache.set("test_key", "new_value")
result = await cache.get("test_key")
assert result == "new_value"

# Test that bust_all clears direct cache entries
await cache.bust_all()
with pytest.raises(KeyError):
await cache.get("test_key")
52 changes: 52 additions & 0 deletions mount-remote.sh
@@ -0,0 +1,52 @@
#!/bin/bash

# Source environment variables
source .env

# Check if REMOTE_BUCKET is set
if [ -z "$REMOTE_BUCKET" ]; then
echo "Error: REMOTE_BUCKET environment variable is not set"
exit 1
fi

# Create remote directory if it doesn't exist
mkdir -p ./remote

# Configure rclone
cat > rclone.conf << EOF
[s3]
type = s3
provider = AWS
env_auth = false
access_key_id = $AWS_ACCESS_KEY_ID
secret_access_key = $AWS_SECRET_ACCESS_KEY
region = us-west-2
endpoint = s3.us-west-2.amazonaws.com
location_constraint = us-west-2
EOF

# Start rclone NFS server in the background
# Using port 12345 (you can change this), enabling full cache mode for write support
rclone serve nfs s3:$REMOTE_BUCKET --addr :12345 --vfs-cache-mode=full --config rclone.conf &
RCLONE_PID=$!

# Wait a moment for the server to start
sleep 2

# Mount the NFS share
mount -t nfs -o port=12345,mountport=12345,tcp localhost:/ ./remote

# Trap script exit to cleanup
cleanup() {
echo "Cleaning up..."
umount ./remote
kill $RCLONE_PID
}
trap cleanup EXIT

echo "NFS mount is ready at ./remote"
echo "Press Ctrl+C to unmount and exit"

# Keep the script running
wait $RCLONE_PID

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -5,6 +5,8 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
# "bitsandbytes>=0.44.1",
"aioboto3>=13.2.0",
# "bitsandbytes>=0.44.1",
"datasets>=3.0.1",
"dicttoxml>=1.7.16",
@@ -15,13 +17,16 @@ dependencies = [
"jupyter>=1.1.1",
# "liger-kernel>=0.3.1",
"matplotlib>=3.9.2",
"modal>=0.65.48",
"moto[s3]>=5.0.20",
"numpy>=1.26.4",
"peft>=0.13.2",
"plotly>=5.24.1",
"polars>=1.9.0",
"pytest-asyncio>=0.24.0",
"pytest>=8.3.3",
"python-dotenv>=1.0.1",
"s3fs>=2024.6.1",
"schedulefree>=1.2.7",
"scikit-learn>=1.5.2",
"seaborn>=0.13.2",
