Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions configs/code_clippy_6B.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@
"seq": 2048,
"cores_per_replica": 8,
"per_replica_batch": 1,
"gradient_accumulation_steps": 16,
"gradient_accumulation_steps": 32,

"warmup_steps": 3000,
"anneal_steps": 300000,
"lr": 1.2e-4,
"end_lr": 1.2e-5,
"warmup_steps": 6200,
"anneal_steps": 613800,
"lr": 1e-5,
"end_lr": 1e-6,
"weight_decay": 0.1,
"total_steps": 350000,
"total_steps": 620000,

"tpu_size": 8,

"bucket": "code-clippy-bucket",
"model_dir": "code_clippy_6B",
"model_dir": "code_clippy_6B_v2",

"train_set": "code_clippy.train.index",
"val_set": {
Expand All @@ -36,7 +36,7 @@
"ckpt_every": 500,
"keep_every": 10000,

"name": "code_clippy_6B",
"name": "code_clippy_6B_v2",
"wandb_project": "mesh-transformer-jax",
"comment": ""
"comment": "Decreased learning rate and increased gradient steps"
}
14 changes: 11 additions & 3 deletions device_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def save(network, step, bucket, path, mp, aux=None, keep_n=3, delete_old=True):


def train_step(network, data):

inputs = {
"obs": data[:, :, :-1],
"target": data[:, :, 1:],
Expand Down Expand Up @@ -324,9 +325,14 @@ def eval_step(network, data):
exit()

start = time.time()
loss, last_loss, grad_norm, grad_norm_micro = train_step(
network, train_dataset.get_samples()
)
try:
loss, last_loss, grad_norm, grad_norm_micro = train_step(
network, train_dataset.get_samples()
)
except:
print(f'Skipped this batch bc of faulty sample.\n File name:{train_dataset.get_state()}')
wandb.log(train_dataset.get_state(),step)
continue
step += 1

steps_per_sec = 1 / (time.time() - start)
Expand Down Expand Up @@ -393,6 +399,8 @@ def eval_step(network, data):
"train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())),
"sequences_processed": sequences_processed,
"tokens_processed": tokens_processed,
#visualize clipped gradients
# "clip_global_gradient_norm": clip_by_global_norm(1)[1]
}
wandb_stats.update(noise_scale_stats)

Expand Down
38 changes: 25 additions & 13 deletions generate_indexes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

from pathlib import Path

import os
parser = argparse.ArgumentParser()
parser.add_argument(
"--gs_project_id",
Expand All @@ -15,11 +15,13 @@
"--output_dir", type=str, help="Where the indexes will be stored locally."
)

parser.add_argument ("--load_bucket", type=bool, help="Whether to load the folder structure from the Bucket directly", default = False, nargs = '?')

args = parser.parse_args()
input_dir = Path(args.input_dir)
root_dir = input_dir.name
output_dir = Path(args.output_dir)

load_bucket = args.load_bucket
# get the list of tfrecords
# train_tfrecords = [
# str(f)
Expand All @@ -32,20 +34,30 @@
# if "valid" in str(f)
# ]
# construct index file paths in the format of a gs bucket
train_indexes = [
f"gs://{args.gs_project_id}/{root_dir}/{str(f).split(f'/{root_dir}/')[-1]}"
for f in input_dir.glob("**/*.tfrecords")
if "train" in str(f)
]
print(train_indexes[:5])
val_indexes = [
f"gs://{args.gs_project_id}/{root_dir}/{str(f).split(f'/{root_dir}/')[-1]}"
for f in input_dir.glob("**/*.tfrecords")
if "valid" in str(f)
]
if load_bucket == False:
train_indexes = [
f"gs://{args.gs_project_id}/{root_dir}/{str(f).split(f'/{root_dir}/')[-1]}"
for f in input_dir.glob("**/*.tfrecords")
if "train" in str(f)
]
print(train_indexes[:5])
val_indexes = [
f"gs://{args.gs_project_id}/{root_dir}/{str(f).split(f'/{root_dir}/')[-1]}"
for f in input_dir.glob("**/*.tfrecords")
if "valid" in str(f)
]

else:
list_files = os.popen(f'gsutil ls -r gs://{args.gs_project_id}/{root_dir}').read().split('\n')
train_indexes = [f for f in list_files if 'train' in str(f) and '.tfrecords' in str(f)]
val_indexes = [f for f in list_files if 'valid' in str(f) and '.tfrecords' in str(f)]

with open(output_dir / "code_clippy.train.index", "w") as f:
f.write("\n".join(train_indexes))

with open(output_dir / "code_clippy.val.index", "w") as f:
f.write("\n".join(val_indexes))




78 changes: 78 additions & 0 deletions load_data_to_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
This script loads the data from Google Cloud Storage to a Huggingface dataset repository.
"""

import argparse

from pathlib import Path
import os
import time
import shutil
# from smart_open import open
# from google.cloud import storage
# from google.cloud.exceptions import NotFound

parser = argparse.ArgumentParser()
parser.add_argument(
"--gs_project_id",
type=str,
help="Google Cloud Storage project ID",
default = 'code-clippy-bucket'
)

parser.add_argument(
"--output_dir", type=str, help="Where the files will be temprorarily stored locally.", default = "../code_clippy_github/"
)


parser.add_argument(
"--input_dir", type=str, help="Where the dataset is stored on the GCS.", default = "code-clippy-dataset"
)


args = parser.parse_args()
root_dir = args.input_dir
output_dir = Path(args.output_dir)
# load_bucket = args.load_bucket

os.chdir(output_dir)

list_files = os.popen(f'gsutil ls -r gs://{args.gs_project_id}/{root_dir}').read().split('\n')

uploaded_files = os.popen(f'ls -r {output_dir}').read().split('\n')

#print(uploaded_files)
json_files_list = [f for f in list_files if '.json.gz' in str(f) and str(f.split('/')[-1]) not in uploaded_files]

print(json_files_list)

commited_files = []

for commit_num , file_path in enumerate(json_files_list,1):

os.system( f'gsutil cp {file_path} {file_path.split("/")[-1]}')
time.sleep(0.5)
os.system( f'git add {file_path.split("/")[-1]}')

commited_files.append(file_path.split("/")[-1])

if commit_num % 20 == 0:
time.sleep(1)
os.system(f'git commit -m \" adding dataset from GCS {commit_num}\"')
time.sleep(1)
os.system(f'git push https://USERNAME:PASSWORD@huggingface.co/datasets/repo.git')

time.sleep(1)

while len(commited_files) > 0:
os.remove(f'{commited_files.pop(0)}')

print('Done Deleting')

if commit_num % 200 == 0:
os.chdir('..')
shutil.rmtree(f'{str(output_dir).split("/")[-1]}')
# os.remove(f'{str(output_dir).split("/")[-1]}')
os.system(f'GIT_LFS_SKIP_SMUDGE=1 git clone https://USERNAME:PASSWORD@huggingface.co/datasets/CodedotAI/code_clippy_github.git')
os.chdir(f'{str(output_dir).split("/")[-1]}')
print(f'Completion: {commit_num/len(json_files_list) * 100} %')
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy~=1.19.5
tqdm~=4.45.0
tqdm~=4.62.3
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
Expand All @@ -17,6 +17,6 @@ transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
pathy
fastapi~=0.74.1
uvicorn~=0.2.2
pathy~=0.6.1