Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions run_clip_RN50_Imagesonly.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/bin/bash
# Slurm batch script: resume CLIP RN50 (images-only tower) training on
# iNaturalist-with-captions from epoch 28. Submit with:
#   export WANDB_API_KEY=...   # required, see check below
#   sbatch run_clip_RN50_Imagesonly.sbatch
# NOTE(review): the logs/ directory referenced by --output must exist at
# submission time — sbatch will not create it.

#SBATCH --output=logs/out_%A_%j.log
#SBATCH --nodes=1
#SBATCH -p GPU-shared
#SBATCH --ntasks-per-node=4
#SBATCH --gpus=v100-32:4
#SBATCH --cpus-per-task=4
#SBATCH --time=20:00:00
#SBATCH --mem=128GB
#SBATCH --job-name=finish_train_network
#SBATCH --mail-type=BEGIN,END
#SBATCH --mail-user=mnk2978@nyu.edu

module purge

# Fail fast on errors/unset variables; trace commands for debugging.
# WARNING: 'set -x' echoes the fully expanded srun command line (including
# the injected API key) into the job log — keep logs private.
set -euo pipefail
set -x

# Debug: confirm which job this log belongs to.
echo "SLURM_JOB_NAME=${SLURM_JOB_NAME}"

# --- Distributed-training environment ---
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK}"
# Random rendezvous port to avoid collisions with other jobs on the node.
export MASTER_PORT="$(shuf -i 10000-65500 -n 1)"
export WORLD_SIZE=$(( SLURM_NNODES * SLURM_NTASKS_PER_NODE ))
echo "WORLD_SIZE=${WORLD_SIZE}"
# First host in the allocation acts as the rendezvous master.
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
export MASTER_ADDR
echo "MASTER_ADDR=${MASTER_ADDR}"
export CUDA_VISIBLE_DEVICES=0,1,2,3

# SECURITY: never hard-code credentials in a script under version control.
# The key must be exported in the submission environment; abort if missing.
: "${WANDB_API_KEY:?WANDB_API_KEY must be exported before running sbatch}"

# Trained on iNaturalist with captions; images-only model, resuming epoch 28.
# The key is spliced into the inner command by the outer shell so it reaches
# the shell started inside the singularity container.
srun --accel-bind=v \
/bin/bash /ocean/projects/cis220069p/mnk2989/setup_singularity.bash \
/bin/bash -c \
'export WANDB_API_KEY='"${WANDB_API_KEY}"'; export PYTHONPATH="$PYTHONPATH:/ocean/projects/cis220069p/mnk2989/vlhub-forked/src"; python /ocean/projects/cis220069p/mnk2989/vlhub-forked/src/training/main.py \
--train-data="/ocean/projects/cis220069p/mnk2989/vlhub-forked/metadata/inat2021_with_labels.csv" \
--report-to wandb \
--csv-separator "," \
--inat2021 "/ocean/projects/cis220069p/mnk2989/data/datasets/" \
--zeroshot-frequency=1 \
--integer-labels \
--csv-caption-key title \
--save-frequency 4 \
--warmup 2000 \
--batch-size=128 \
--precision=fp32 \
--epochs=32 \
--resume "/ocean/projects/cis220069p/mnk2989/logs/2022_12_03-02_02_56-model_RN50-imagesonly-lr_0.0005-b_256-j_4-p_fp32/checkpoints/epoch_28.pt" \
--workers=4 --model=RN50-imagesonly --local-loss --gather-with-grad --ds-filter inat_classnames'
54 changes: 54 additions & 0 deletions run_clip_ResNet50.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/bin/bash
# Slurm batch script: train CLIP RN50 from scratch on the iNaturalist
# captioned webdataset shards. Submit with:
#   export WANDB_API_KEY=...   # required, see check below
#   sbatch run_clip_ResNet50.sbatch
# NOTE(review): the logs/ directory referenced by --output must exist at
# submission time — sbatch will not create it.

#SBATCH --output=logs/out_%A_%j.log
#SBATCH --nodes=1
#SBATCH -p GPU-shared
#SBATCH --ntasks-per-node=4
#SBATCH --gpus=v100-32:4
#SBATCH --cpus-per-task=4
#SBATCH --time=40:00:00
#SBATCH --mem=128GB
#SBATCH --job-name=finish_train_network
#SBATCH --mail-type=BEGIN,END
#SBATCH --mail-user=mnk2978@nyu.edu

module purge

# Fail fast on errors/unset variables; trace commands for debugging.
# WARNING: 'set -x' echoes the fully expanded srun command line (including
# the injected API key) into the job log — keep logs private.
set -euo pipefail
set -x

# Debug: confirm which job this log belongs to.
echo "SLURM_JOB_NAME=${SLURM_JOB_NAME}"

# --- Distributed-training environment ---
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK}"
# Random rendezvous port to avoid collisions with other jobs on the node.
export MASTER_PORT="$(shuf -i 10000-65500 -n 1)"
export WORLD_SIZE=$(( SLURM_NNODES * SLURM_NTASKS_PER_NODE ))
echo "WORLD_SIZE=${WORLD_SIZE}"
# First host in the allocation acts as the rendezvous master.
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
export MASTER_ADDR
echo "MASTER_ADDR=${MASTER_ADDR}"
export CUDA_VISIBLE_DEVICES=0,1,2,3

# SECURITY: never hard-code credentials in a script under version control.
# The key must be exported in the submission environment; abort if missing.
: "${WANDB_API_KEY:?WANDB_API_KEY must be exported before running sbatch}"

# Trained on iNaturalist with captions (webdataset shards 000000..000362).
# The key is spliced into the inner command by the outer shell so it reaches
# the shell started inside the singularity container.
srun --accel-bind=v \
/bin/bash /ocean/projects/cis220069p/mnk2989/setup_singularity.bash \
/bin/bash -c \
'export WANDB_API_KEY='"${WANDB_API_KEY}"'; export PYTHONPATH="$PYTHONPATH:/ocean/projects/cis220069p/mnk2989/vlhub/src"; python /ocean/projects/cis220069p/mnk2989/vlhub/src/training/main.py \
--save-frequency 4 \
--report-to wandb \
--dataset-type webdataset \
--train-data "/ocean/projects/cis220069p/mnk2989/data/datasets/inat_captions/{000000..000362}.tar" \
--train-num-samples 2600000 \
--inat2021 "/ocean/projects/cis220069p/mnk2989/data/datasets/" \
--zeroshot-frequency=4 \
--warmup 2000 \
--batch-size=256 \
--wd=0.1 \
--epochs=32 \
--workers=4 \
--model=RN50 \
--seed 0 \
--local-loss \
--lr 5e-4 \
--gather-with-grad'
56 changes: 56 additions & 0 deletions run_clip_timm_RN50.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/bin/bash
# Slurm batch script: train CLIP with a timm ResNet-50 image tower
# (pretrained and locked) on the iNaturalist captioned webdataset shards.
# Submit with:
#   export WANDB_API_KEY=...   # required, see check below
#   sbatch run_clip_timm_RN50.sbatch
# NOTE(review): the logs/ directory referenced by --output must exist at
# submission time — sbatch will not create it.

#SBATCH --output=logs/out_%A_%j.log
#SBATCH --nodes=1
#SBATCH -p GPU-shared
#SBATCH --ntasks-per-node=4
#SBATCH --gpus=v100-32:4
#SBATCH --cpus-per-task=4
#SBATCH --time=40:00:00
#SBATCH --mem=128GB
#SBATCH --job-name=finish_train_network
#SBATCH --mail-type=BEGIN,END
#SBATCH --mail-user=mnk2978@nyu.edu

module purge

# Fail fast on errors/unset variables; trace commands for debugging.
# WARNING: 'set -x' echoes the fully expanded srun command line (including
# the injected API key) into the job log — keep logs private.
set -euo pipefail
set -x

# Debug: confirm which job this log belongs to.
echo "SLURM_JOB_NAME=${SLURM_JOB_NAME}"

# --- Distributed-training environment ---
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK}"
# Random rendezvous port to avoid collisions with other jobs on the node.
export MASTER_PORT="$(shuf -i 10000-65500 -n 1)"
export WORLD_SIZE=$(( SLURM_NNODES * SLURM_NTASKS_PER_NODE ))
echo "WORLD_SIZE=${WORLD_SIZE}"
# First host in the allocation acts as the rendezvous master.
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
export MASTER_ADDR
echo "MASTER_ADDR=${MASTER_ADDR}"
export CUDA_VISIBLE_DEVICES=0,1,2,3

# SECURITY: never hard-code credentials in a script under version control.
# The key must be exported in the submission environment; abort if missing.
: "${WANDB_API_KEY:?WANDB_API_KEY must be exported before running sbatch}"

# Trained on iNaturalist with captions; frozen pretrained timm image tower.
# The key is spliced into the inner command by the outer shell so it reaches
# the shell started inside the singularity container.
srun --accel-bind=v \
/bin/bash /ocean/projects/cis220069p/mnk2989/setup_singularity.bash \
/bin/bash -c \
'export WANDB_API_KEY='"${WANDB_API_KEY}"'; export PYTHONPATH="$PYTHONPATH:/ocean/projects/cis220069p/mnk2989/vlhub/src"; python /ocean/projects/cis220069p/mnk2989/vlhub/src/training/main.py \
--save-frequency 4 \
--report-to wandb \
--dataset-type webdataset \
--train-data "/ocean/projects/cis220069p/mnk2989/data/datasets/inat_captions/{000000..000362}.tar" \
--train-num-samples 2600000 \
--inat2021 "/ocean/projects/cis220069p/mnk2989/data/datasets/" \
--zeroshot-frequency=4 \
--warmup 2000 \
--batch-size=256 \
--wd=0.1 \
--epochs=32 \
--workers=4 \
--model=timm-resnet50 \
--pretrained-image \
--lock-image \
--seed 0 \
--local-loss \
--lr 5e-4 \
--gather-with-grad'
2 changes: 1 addition & 1 deletion src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def create_model(
assert False, 'pretrained image towers currently only supported for timm models'

model = CLIP(**model_cfg)

pretrained_cfg = {}
if pretrained:
checkpoint_path = ''
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/RN50-imagesonly.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"embed_dim": 10000,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1000,
"heads": 8,
"layers": 12
}
}
38 changes: 20 additions & 18 deletions src/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,13 @@ def token_strip_func(texts):
texts = torch.tensor(tlist)
return texts

def clean_integer_label(label, singleclass, strict):
def clean_integer_label(label, singleclass, strict, ds):
if ds is None:
ds = [0]*1000
if isinstance(label, float):
label = int(label)
if isinstance(label, int):
if label < 0 or label > 999:
if label < 0 or label > len(ds) - 1:
logging.info("Integer label {} out of acceptable range, mapping to 0".format(label))
label = 0
if singleclass:
Expand All @@ -153,7 +155,7 @@ def clean_integer_label(label, singleclass, strict):
label_updated = []
for l in label:
intl = int(l)
if intl < 0 or intl > 999:
if intl < 0 or intl > len(ds) - 1:
logging.info("Integer label {} out of acceptable range, mapping to 0".format(intl))
label_updated.append(0)
else:
Expand Down Expand Up @@ -209,13 +211,13 @@ def __init__(self, input_filename, transforms, img_key, caption_key, csvfilter,
df = df[df['title'].str.len() > 0]
logging.debug("Done. Length is now {}".format(len(df)))
logging.debug(df.head())
elif csvfilter != "":
logging.debug('Filtering captions. Original dataset size is {}'.format(len(df)))
df['is_synset'] = df[caption_key].progress_apply(synset_ds, ngram=3, ds=csvfilter, cipher=False, simplecaptions=False, strict=strict, shift=shift, metacaptions=metacaptions)
logging.debug(df['is_synset'].head())
df = df[df['is_synset']].drop(columns=['is_synset'])
logging.debug("Done. Length is now {}".format(len(df)))
logging.debug(df.head())
# elif csvfilter != "":
# logging.debug('Filtering captions. Original dataset size is {}'.format(len(df)))
# df['is_synset'] = df[caption_key].progress_apply(synset_ds, ngram=10, ds=csvfilter, cipher=False, simplecaptions=False, strict=strict, shift=shift, metacaptions=metacaptions)
# logging.debug(df['is_synset'].head())
# df = df[df['is_synset']].drop(columns=['is_synset'])
# logging.debug("Done. Length is now {}".format(len(df)))
# logging.debug(df.head())
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
Expand Down Expand Up @@ -255,7 +257,7 @@ def __getitem__(self, idx):
#if isinstance(texts, str) and not texts.is_numeric():
#assert(False, "Integer labels cannot be computed on the fly for a CSV dataset")
#texts = [synset_ds(clean_captions(str(texts)), 3, self.csvfilter, False, False, self.strict, False, True, None) for t in texts]
texts = clean_integer_label(self.captions[idx], not self.multiclass, self.strict)
texts = clean_integer_label(self.captions[idx], not self.multiclass, self.strict, self.csvfilter)
return images, texts
if self.scrambled:
texts = scramble_txt(texts)
Expand Down Expand Up @@ -346,10 +348,10 @@ def preprocess_txt(text, token_scrambled, token_strip):
def filter_preprocess_txt(text, ds, scrambled, dscipher, simplecaptions, strict, shift, integer_labels, multiclass, metacaptions):
if bool(ds):
if integer_labels:
text = clean_captions(str(text))
text = synset_ds(text, 3, ds, False, False, strict, False, integer_labels, metacaptions)
text = clean_captions(str(text)) # just does lower case
text = synset_ds(text, 3, ds, False, False, strict, False, integer_labels, metacaptions) # It should return an integer
if text:
text = clean_integer_label(text, not multiclass, strict)
text = clean_integer_label(text, not multiclass, strict, ds)
else:
text = ""
else:
Expand Down Expand Up @@ -396,7 +398,7 @@ def shift_cipher(s, shift):
WARNING: can return string or bool, depending on arguments provided
"""

def synset_ds(s, ngram=3, ds=None, cipher=False, simplecaptions=False, strict=False, shift=None, integer_labels=False, metacaptions=None):
def synset_ds(s, ngram=5, ds=None, cipher=False, simplecaptions=False, strict=False, shift=None, integer_labels=False, metacaptions=None):
flag = False
s = list(lemmatizer.lemmatize(t) for t in s.split(" "))
str_s = " ".join(w for w in s)
Expand Down Expand Up @@ -460,7 +462,7 @@ def synset_ds(s, ngram=3, ds=None, cipher=False, simplecaptions=False, strict=Fa
elif d and str_s.find(gram) == -1:
str_s += " {}".format(gram)
flag=True

if len(str_s) > 76:
str_s = str_s[:75]

Expand All @@ -472,7 +474,7 @@ def synset_ds(s, ngram=3, ds=None, cipher=False, simplecaptions=False, strict=Fa
elif shift:
str_s = shift_cipher(str_s, shift)
return str_s

return flag

def get_dataset_size(shards):
Expand Down Expand Up @@ -909,7 +911,7 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, total=
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
else:
# last batches are partial, eval is done on single (master) node
num_batches = math.ceil(num_samples / args.batch_size)
num_batches = math.ceil(num_samples / args.batch_size)
dataloader = wds.WebLoader(
dataset,
batch_size=None,
Expand Down
6 changes: 5 additions & 1 deletion src/training/inat_zeroshot_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import pandas as pd
import ast
from pathlib import Path
import nltk
from clean_filter_captions import ds_val_getter, clean_captions
lemmatizer = nltk.stem.WordNetLemmatizer()

try:
METPATH = "./metadata/inat2021-categories.csv"
Expand All @@ -8,7 +13,6 @@
df = pd.read_csv(METPATH)

inat_classnames = df['label'].tolist()

inat_template = [
lambda c: f'a bad photo of a {c}.',
lambda c: f'a photo of many {c}.',
Expand Down
3 changes: 0 additions & 3 deletions src/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w

#MAIN TRAINING LOOP
for i, batch in enumerate(dataloader):

#HOUSEKEEPING
# if args.ds_filter and args.debug:
# for b in batch[1].tolist():
# if b not in batchset:
# batchset.append(b)
Expand All @@ -159,7 +157,6 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
step = num_batches_per_epoch * epoch + i
scheduler(step)
images, texts = batch
#logging.info(texts)
texts = texts.to(device=device, non_blocking=True)
images = images.to(device=device, non_blocking=True)
data_time_m.update(time.time() - end)
Expand Down
24 changes: 15 additions & 9 deletions src/training/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def l2norm(t):
return F.normalize(t, dim = -1, p = 2)

def zero_shot_classifier(model, classnames, templates, args):
logging.debug("In zero-shot-classifer, classnames are {}".format(classnames))
# logging.debug("In zero-shot-classifer, classnames are {}".format(classnames))
with torch.no_grad():
zeroshot_weights = []
for classname in tqdm(classnames):
Expand All @@ -41,7 +41,7 @@ def zero_shot_classifier(model, classnames, templates, args):
random.shuffle(l)
res.append(" ".join(l).strip())
texts = tokenize(texts).to(args.device) # tokenize
logging.debug("In zero-shot-classifer, tokens are {}".format(classnames))
# logging.debug("In zero-shot-classifer, tokens are {}".format(texts))
if args.distributed and not args.horovod:
if args.model in ["coca"]:
images = torch.rand(len(texts), 3, 224, 224).to(args.device)
Expand Down Expand Up @@ -182,7 +182,7 @@ def run(model, classifier, dataloader, args, idx=None, split=None):
image_features = model.encode_image(images)
image_features = F.normalize(image_features, dim=-1)
logits = 100. * image_features @ classifier

# measure accuracy with objectnet adjustments
if split == "objectnet" and args.integer_labels:
with open("./metadata/imagenet_to_objectnet.json","r") as f:
Expand Down Expand Up @@ -285,7 +285,6 @@ def imageNetIDToObjectNetID(prediction_class):
prediction_class[i] = -1

def zero_shot_eval(model, data, epoch, args):
#logging.debug(data)

results = {}
classifier = None
Expand All @@ -302,11 +301,18 @@ def zero_shot_eval(model, data, epoch, args):
# inat_classnames = to_upper(inat_classnames)
# elif args.zs_lower:
# inat_classnames = to_lower(inat_classnames)
logging.info("Starting zero-shot inat2021.")
logging.info('Building zero-shot classifier')
classifier = zero_shot_classifier(model, inat_classnames, inat_template, args)

logging.info('Using classifier')

isint = (args.integer_labels or args.linear_probe)
# usecaps = args.caption_subset and not isint
if isint:
args.classnames = inat_classnames
classifier = None
# return classifier
else:
logging.info('Building zero-shot classifier')
classifier = zero_shot_classifier(model, inat_classnames, inat_template, args)
# classifier = None
logging.info('Using classifier')
top1, top5 = run(model, classifier, data['inat2021'].dataloader, args)
results['inat2021-top1'] = top1
results['inat2021-top5'] = top5
Expand Down