Commit f62cc97

Add files via upload

1 parent c21a52a commit f62cc97

10 files changed: +2566 -0 lines changed
basher1.py

Lines changed: 39 additions & 0 deletions
import os, random, argparse, time

parser = argparse.ArgumentParser(description='Launcher for unlearning (forget) training runs',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('loss', type=str,
                    help='npo npo2 npov2 - v5')
parser.add_argument('--setting', type=str,
                    help='forget01 forget05 forget10')
parser.add_argument('--model', type=str,
                    help='phi llama')
parser.add_argument('--cuda_id', type=int,
                    help='0~7')
parser.add_argument('--hyper', type=int)
args = parser.parse_args()

# All three splits currently share the same schedule; the branches are kept
# so the values can diverge per split later.
if args.setting == 'forget10':
    save_steps = 5
    epoch = 5
elif args.setting == 'forget05':
    save_steps = 5
    epoch = 5
elif args.setting == 'forget01':
    save_steps = 5
    epoch = 5
else:
    raise RuntimeError(f'unknown setting: {args.setting}')

if args.model == 'phi':
    lr = 2e-5
elif args.model == 'llama':
    lr = 1e-5
else:
    raise RuntimeError(f'unknown model: {args.model}')

# One torchrun launch per hyper-parameter value; a random master port avoids
# collisions when several launchers run on the same machine.
for param in [args.hyper]:
    if args.model == 'phi':
        os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2.py --config-name=forget.yaml split={args.setting} model_family=phi lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch}')
    elif args.model == 'llama':
        os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2.py --config-name=forget.yaml split={args.setting} model_family=llama2-7b lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch}')
    time.sleep(1)
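For reference, a typical invocation of this launcher (illustrative values only; the positional loss and the flags follow the argparse definitions above):

    python basher1.py npo --setting forget05 --model phi --cuda_id 0 --hyper 1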

basher2.py

Lines changed: 54 additions & 0 deletions
import os, random, argparse, time

parser = argparse.ArgumentParser(description='Launcher for unlearning (forget) runs plus checkpoint re-evaluation',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('loss', type=str,
                    help='npo npo2 npov2 - v5')
parser.add_argument('--setting', type=str,
                    help='forget01 forget05 forget10')
parser.add_argument('--model', type=str,
                    help='phi llama')
parser.add_argument('--cuda_id', type=int,
                    help='0~7')
parser.add_argument('--hyper', type=int)
args = parser.parse_args()

# All three splits currently share the same schedule.
if args.setting == 'forget10':
    save_steps = 1000
    epoch = 1
elif args.setting == 'forget05':
    save_steps = 1000
    epoch = 1
elif args.setting == 'forget01':
    save_steps = 1000
    epoch = 1
else:
    raise RuntimeError(f'unknown setting: {args.setting}')

if args.model == 'phi':
    lr = 2e-5
    lr_str = '2e-05'
    model = 'phi'
elif args.model == 'llama':
    lr = 1e-5
    lr_str = '1e-05'
    model = 'llama2-7b'
else:
    raise RuntimeError(f'unknown model: {args.model}')

# Initial run.
param = args.hyper
if args.model == 'phi':
    os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2_ge.py --config-name=forget_ge.yaml split={args.setting} model_family=phi lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch}')
elif args.model == 'llama':
    os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2_ge.py --config-name=forget_ge.yaml split={args.setting} model_family=llama2-7b lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch}')
time.sleep(1)

# Re-launch from every saved checkpoint, every 5 steps up to the cap.
cap = 62 if args.setting == 'forget05' else 130
for iteration in range(5, cap, 5):
    # Both loss variants use the same checkpoint path layout.
    path = f'icml/{model}/{args.loss}_{lr_str}_{args.setting}_5_0.0_{param}/checkpoint-{iteration}'
    if args.model == 'phi':
        os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2_ge.py --config-name=forget_ge.yaml split={args.setting} model_family=phi lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch} model_path_cur={path}')
    elif args.model == 'llama':
        os.system(f'CUDA_VISIBLE_DEVICES={args.cuda_id} torchrun --nproc_per_node=1 --master_port={random.randint(0,60000)} forget2_ge.py --config-name=forget_ge.yaml split={args.setting} model_family=llama2-7b lr={lr} forget_loss={args.loss} save_steps={save_steps} hyper_param={param} num_epochs={epoch} model_path_cur={path}')
    time.sleep(1)
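The resume loop walks checkpoints 5, 10, ... up to 60 for forget05 (cap 62) and up to 125 for the other splits (cap 130). For illustration, with loss npo, setting forget05, model phi, and hyper 1, the resumed paths are:

    icml/phi/npo_2e-05_forget05_5_0.0_1/checkpoint-5
    icml/phi/npo_2e-05_forget05_5_0.0_1/checkpoint-10
    ...
    icml/phi/npo_2e-05_forget05_5_0.0_1/checkpoint-60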

data_module.py

Lines changed: 248 additions & 0 deletions
import torch
import pdb
from torch import nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import datasets
import pandas as pd
from utils import get_model_identifiers_from_yaml, add_dataset_index

def convert_raw_data_to_model_format(tokenizer, max_length, question, answer, model_configs):
    question_start_token, question_end_token, answer_token = model_configs['question_start_tag'], model_configs['question_end_tag'], model_configs['answer_tag']
    new_question = question_start_token + question + question_end_token
    new_answer = answer_token + answer
    full_text = new_question + new_answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # Pad input ids with eos and extend the attention mask with zeros.
    pad_length = max_length - len(encoded.input_ids)
    pad_input_ids = encoded['input_ids'] + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
    if len(encoded.input_ids) == max_length:
        label = encoded.input_ids
    else:
        # Supervise one trailing eos so the model learns to stop, ignore the rest.
        label = encoded['input_ids'] + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # Change label to -100 for question tokens so the loss covers only the answer.
    for i in range(num_question_tokens):
        label[i] = -100

    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)
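A small worked example of the label layout produced above (hypothetical numbers: max_length=8, a 5-token encoding of which the first 2 are question tokens, eos id 2):

    input_ids      = [q1,   q2,   a1, a2, a3, 2,  2,    2]     # padded with eos
    attention_mask = [1,    1,    1,  1,  1,  0,  0,    0]
    label          = [-100, -100, a1, a2, a3, 2, -100, -100]   # one eos supervised, question masked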
class TextDatasetQA(Dataset):
    def __init__(self, data_path, tokenizer, model_family, max_length=512, split=None, question_key='question', answer_key='answer'):
        super(TextDatasetQA, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # data_len = len(datasets.load_dataset(data_path, split)["train"])
        # self.data = datasets.load_dataset(data_path, split)["train"].select(range(min(100, data_len)))
        self.data = datasets.load_dataset(data_path, split)["train"]

        self.data = add_dataset_index(self.data)
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.qk = question_key
        self.ak = answer_key

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx][self.qk]
        answers = self.data[idx][self.ak]
        indices = self.data[idx]['index']
        if isinstance(answers, str):
            answers = [answers]

        pad_input_ids_list = []
        label_list = []
        pad_attention_mask_list = []

        for answer in answers:
            converted_data = convert_raw_data_to_model_format(self.tokenizer, self.max_length, question, answer, self.model_configs)
            pad_input_ids_list.append(converted_data[0])
            label_list.append(converted_data[1])
            pad_attention_mask_list.append(converted_data[2])

        return torch.stack(pad_input_ids_list).squeeze(), \
            torch.stack(label_list).squeeze(), \
            torch.stack(pad_attention_mask_list).squeeze(), \
            torch.tensor(indices)
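A minimal usage sketch (assumptions: a TOFU-style dataset such as locuslab/TOFU with question/answer columns, and the custom_data_collator_with_indices defined further down in this file):

    from torch.utils.data import DataLoader

    ds = TextDatasetQA('locuslab/TOFU', tokenizer, 'phi', max_length=512, split='forget01')
    loader = DataLoader(ds, batch_size=4, collate_fn=custom_data_collator_with_indices)
    input_ids, labels, attention_mask, indices = next(iter(loader))  # each [4, 512]; indices [4]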
class TextForgetDatasetQA2(Dataset):
    def __init__(self, data_path, tokenizer, model_family, max_length=512, split="forget10", loss_type="att_"):
        super(TextForgetDatasetQA2, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.forget_data = datasets.load_dataset(data_path, split)["train"]
        retain_split = "retain" + str(100 - int(split.replace("forget", ""))).zfill(2)
        self.retain_data = datasets.load_dataset(data_path, retain_split)["train"]

        # Set aside the first 400 retain examples for evaluation; train on the rest.
        data_f = pd.DataFrame(self.retain_data).iloc[400:].reset_index(drop=True)
        self.retain_data_train = datasets.Dataset.from_pandas(data_f)

        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.loss_type = loss_type

        if self.loss_type == "idk":
            self.split1, self.split2 = "idk", "retain"
            self.idontknowfile = "data/idontknow.jsonl"
            self.idk = open(self.idontknowfile, "r").readlines()

        ############### from qz
        elif 'att_' in self.loss_type:
            attention_words = torch.load('../tofu_attention/attention_idx' + split + '.pth')
            if len(attention_words) != len(self.forget_data):
                raise RuntimeError('The lengths of attention words do not match the dataset!')
            self.forget_data = self.forget_data.add_column('critical_word', [attention_words[_] for _ in attention_words])
            self.split1, self.split2 = "forget", "retain"
        ###############
        else:
            self.split1, self.split2 = "forget", "retain"

    def __len__(self):
        return len(self.forget_data)

    def __getitem__(self, idx):
        rets = []
        for data_type in [self.split1, self.split2]:
            # Use questions from the forget set if the split is idk or forget;
            # for retain, draw a random example via a shifted index.
            if data_type == "retain":
                data = self.retain_data_train
                idx = (idx + torch.randint(0, len(self.retain_data_train), (1,)).item()) % len(self.retain_data_train)
            else:
                data = self.forget_data

            question = data[idx]['question']
            answer = data[idx]['answer']
            if data_type == "idk":
                rand_pos = torch.randint(0, len(self.idk), (1,)).item()
                answer = self.idk[rand_pos].strip()

            ############### from qz: a copy of convert_raw_data_to_model_format,
            # extended by the 'att_' branches below.
            question_start_token, question_end_token, answer_token = self.model_configs['question_start_tag'], self.model_configs['question_end_tag'], self.model_configs['answer_tag']
            new_question = question_start_token + question + question_end_token
            new_answer = answer_token + answer
            full_text = new_question + new_answer
            num_question_tokens = len(self.tokenizer.tokenize(new_question, add_special_tokens=True))

            if data_type == "forget" and 'att_' in self.loss_type:
                # Strip non-ascii characters from each answer token, then mark the
                # positions whose cleaned token appears among the critical words.
                attention_word = self.forget_data[idx]['critical_word']
                asciied_answer = [''.join([c for c in tok if c.isascii()]) for tok in self.tokenizer.tokenize(new_answer)]
                critical_idx_tokens = [num_question_tokens + pos for pos, tok in enumerate(asciied_answer) if tok in attention_word and tok != '' and (len(tok) >= 2 or tok.isnumeric())]

            encoded = self.tokenizer(
                full_text,
                add_special_tokens=True,
                max_length=self.max_length,
                truncation=True,
            )

            pad_length = self.max_length - len(encoded.input_ids)
            pad_input_ids = encoded['input_ids'] + [self.tokenizer.eos_token_id] * pad_length
            pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
            if len(encoded.input_ids) == self.max_length:
                label = encoded.input_ids
            else:
                label = encoded['input_ids'] + [self.tokenizer.eos_token_id] + [-100] * (pad_length - 1)

            # Change label to -100 for question tokens.
            for i in range(num_question_tokens):
                label[i] = -100
            if data_type == "forget" and 'att_' in self.loss_type:
                # For the attention loss, supervise only the critical tokens.
                for pos, ele in enumerate(label):
                    if pos not in critical_idx_tokens:
                        label[pos] = -100
            converted_data = torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)
            rets.append(converted_data)
        return rets
def collate_fn(batch):
    input_ids, attention_masks = zip(*batch)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=-100)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    return input_ids, attention_masks

def custom_data_collator(samples):
    input_ids = [s[0] for s in samples]
    labels = [s[1] for s in samples]
    attention_mask = [s[2] for s in samples]
    return torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)

def custom_data_collator_with_indices(samples):
    input_ids = [s[0] for s in samples]
    labels = [s[1] for s in samples]
    attention_mask = [s[2] for s in samples]
    indices = [s[3] for s in samples]
    return torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask), torch.stack(indices)

def get_batch_loss(output, labels):
    # Shift so that tokens before position n predict token n.
    shifted_labels = labels[..., 1:].contiguous()
    output = output[..., :-1, :].contiguous()

    loss_function = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
    # Get the sum loss for each sequence in a batch.
    loss = loss_function(output.transpose(-1, -2), shifted_labels).sum(dim=-1)

    return loss
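A minimal shape check for get_batch_loss (hypothetical sizes; output is expected to be logits of shape [batch, seq_len, vocab], labels of shape [batch, seq_len] with -100 on ignored positions):

    logits = torch.randn(2, 6, 32)            # batch=2, seq_len=6, vocab=32
    labels = torch.randint(0, 32, (2, 6))
    labels[:, :3] = -100                      # e.g. question tokens are ignored
    per_seq = get_batch_loss(logits, labels)  # shape [2]: summed CE per sequence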
def model_mix(model, before, after, update_ratio):
    # Linear interpolation between two parameter snapshots:
    # param = update_ratio * before + (1 - update_ratio) * after.
    for name, parameter in model.named_parameters():
        parameter.data = update_ratio * before[name].cuda() + (1 - update_ratio) * after[name].cuda()
    return model
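A usage sketch for model_mix (hypothetical: before and after are parameter snapshots taken around an unlearning step; update_ratio=0.3 keeps 30% of the earlier weights):

    before = {k: v.detach().clone() for k, v in model.named_parameters()}
    # ... run an unlearning step ...
    after = {k: v.detach().clone() for k, v in model.named_parameters()}
    model = model_mix(model, before, after, update_ratio=0.3)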
'''
import hydra, os
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed

@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    # ------------ DDP PyTorch distributed training ----------- #

    num_devices = int(os.environ.get('WORLD_SIZE', 1))  # read from the environment
    print(f"num_devices: {num_devices}")
    if os.environ.get('LOCAL_RANK') is not None:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        device_map = {'': local_rank}
    else:
        local_rank = 0

    os.environ["WANDB_DISABLED"] = "true"
    # --------------------------------------------- #

    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]  # huggingface key
    if cfg.model_path is None:
        cfg.model_path = model_cfg["ft_model_path"]

    # save cfg in cfg.save_dir
    if local_rank == 0:
        with open(f"{cfg.save_dir}/config.yaml", "w") as file:
            # omegaconf.save(cfg, file)
            pass

    if os.path.exists(cfg.save_dir):
        print("Directory already exists")
        if not cfg.overwrite_dir:
            exit()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    torch_format_dataset = TextForgetDatasetQA2(cfg.data_path, tokenizer=tokenizer, model_family=cfg.model_family, max_length=500, split='forget01', loss_type='att_')

if __name__ == "__main__":
    main()
'''
