Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 20, 2023
1 parent 59899f3 commit 0372e72
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
2 changes: 2 additions & 0 deletions alr_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from alr_transformer.model import ALRTransformer
from alr_transformer.alr_block import ALRBlock
4 changes: 0 additions & 4 deletions alr_transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,3 @@ def forward(self, x):
x = self.transformer(x)
return self.to_logits(x)

x = torch.randn(1, 1024, 512)
model = ALRTransformer(512, 6, 10000, 64, 8, 4)
model(x).shape
# torch.Size([1, 1024, 10000])
17 changes: 17 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
from alr_transformer import ALRTransformer

x = torch.randint(0, 100000, (1, 2048))

model = ALRTransformer(
dim = 512,
depth = 6,
num_tokens = 100000,
dim_head = 64,
heads = 8,
ff_mult = 4
)

out = model(x)
print(out)
print(out.shape)

0 comments on commit 0372e72

Please sign in to comment.