-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run.py
68 lines (51 loc) · 1.87 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import jax
import jax.numpy as np
from jax import nn
import chex
import objax
import numpy as onp
import requests
import os
import matplotlib.pyplot as plt
from pathlib import Path
from gpt import SimpleTokenizer, Data, GPT, ADAM, progress_bar_callback
from gpt import sample
from jax import make_jaxpr
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------

def download_tinyshakespeare(out_path: Path, name: str):
    """Download the Tiny Shakespeare corpus and save it as ``out_path / name``.

    Any existing file at the target path is overwritten.

    Args:
        out_path: Directory to write into; created (including parents) if missing.
        name: File name of the saved corpus inside ``out_path``.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never silently saved as training data.
        requests.exceptions.Timeout: If the server does not respond in time.
    """
    out_path.mkdir(parents=True, exist_ok=True)
    # see https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Without a timeout, requests can wait forever on a stalled connection;
    # 30 s is an arbitrary but generous bound.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(out_path / name, 'w', encoding='utf-8') as f:
        f.write(response.text)


def open_txt_file(file: Path) -> str:
    """Return the entire contents of the text file *file*, decoded as UTF-8.

    The encoding is pinned (matching ``download_tinyshakespeare`` above) so the
    result does not depend on the platform's default locale encoding.
    Accepts a ``Path`` or a plain ``str`` path.
    """
    return Path(file).read_text(encoding='utf-8')


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# True: train on Tiny Shakespeare (downloaded into data/ on first use).
# False: train on the local file data/beatles.txt instead.
USE_TINY_SHAKESPEARE = True

BLOCK_SIZE = 128
BATCH_SIZE = 64
EMBEDDING_SIZE = 192
LAYERS = 6
NUM_HEADS = 6
MAX_ITERS = 1000
LEARNING_RATE = 0.0005


def main():
    """Train a GPT on a text corpus, save its variables, sample text and plot the learning curve."""
    # Debug aid: enables JAX's tracer-leak checking for everything below.
    # JAX documents that this can cost performance, so consider dropping it
    # once the model is known to be leak-free.
    with jax.checking_leaks():
        data_dir = Path('data')
        if USE_TINY_SHAKESPEARE:
            corpus_path = data_dir / 'tiny_shakespeare.txt'
            # Only hit the network when the corpus is not cached locally yet.
            if not corpus_path.exists():
                download_tinyshakespeare(data_dir, corpus_path.name)
        else:
            corpus_path = data_dir / 'beatles.txt'
        data = open_txt_file(corpus_path)

        tokenizer = SimpleTokenizer()
        tokenizer.train(data)
        data_enc = tokenizer.encode(data)
        data_obj = Data(data_enc, BLOCK_SIZE, BATCH_SIZE)

        # Inspect one batch, and the objective of the untrained model on it.
        x, t = data_obj.batch()
        print(x.shape, t.shape)

        gpt = GPT(tokenizer.vocab_size, BLOCK_SIZE, EMBEDDING_SIZE,
                  NUM_HEADS, LAYERS, seed=0)
        trainer = ADAM(gpt)
        print(trainer.obj_fn(x, t))

        lc_arr = trainer.train(data_obj, LEARNING_RATE, MAX_ITERS,
                               callback=progress_bar_callback(MAX_ITERS))
        objax.io.save_var_collection('model.npz', gpt.vars())

        # Presumably (prompt, number of tokens to generate, ...) -- see gpt.sample.
        print(sample('let it be', 1000, gpt, tokenizer, seed=0))

        # The first 10 entries are skipped, presumably so the early high values
        # do not dominate the log-scale plot.
        plt.plot(lc_arr[10:])
        plt.yscale('log')
        plt.show()


if __name__ == '__main__':
    main()