import torch
from torch import nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import datasets
import pandas as pd
from utils import get_model_identifiers_from_yaml, add_dataset_index

def convert_raw_data_to_model_format(tokenizer, max_length, question, answer, model_configs):
    """Tokenize one QA pair into fixed-length input_ids, label, and attention_mask tensors."""
    question_start_token = model_configs['question_start_tag']
    question_end_token = model_configs['question_end_tag']
    answer_token = model_configs['answer_tag']
    new_question = question_start_token + question + question_end_token
    new_answer = answer_token + answer
    full_text = new_question + new_answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # pad input_ids with eos tokens and extend the attention mask with zeros up to max_length
    pad_length = max_length - len(encoded.input_ids)
    pad_input_ids = encoded['input_ids'] + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
    if len(encoded.input_ids) == max_length:
        label = encoded.input_ids
    else:
        # keep one eos as the end-of-answer target; ignore the remaining padding
        label = encoded['input_ids'] + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # change label to -100 for question tokens so only the answer is scored
    for i in range(num_question_tokens):
        label[i] = -100

    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)
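
# Usage sketch (hedged): how a single QA pair is converted. The tokenizer and
# model_configs are whatever the caller already has; the strings are made up.
def _demo_convert(tokenizer, model_configs):
    input_ids, label, attention_mask = convert_raw_data_to_model_format(
        tokenizer, 512, "Who wrote the novel?", "The author did.", model_configs)
    # all three tensors have length 512; question positions in `label` are -100
    print(input_ids.shape, label.shape, attention_mask.shape)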

class TextDatasetQA(Dataset):
    def __init__(self, data_path, tokenizer, model_family, max_length=512, split=None,
                 question_key='question', answer_key='answer'):
        super(TextDatasetQA, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = datasets.load_dataset(data_path, split)["train"]

        self.data = add_dataset_index(self.data)
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.qk = question_key
        self.ak = answer_key

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx][self.qk]
        answers = self.data[idx][self.ak]
        indices = self.data[idx]['index']
        if isinstance(answers, str):
            answers = [answers]

        pad_input_ids_list = []
        label_list = []
        pad_attention_mask_list = []

        for answer in answers:
            converted_data = convert_raw_data_to_model_format(
                self.tokenizer, self.max_length, question, answer, self.model_configs)
            pad_input_ids_list.append(converted_data[0])
            label_list.append(converted_data[1])
            pad_attention_mask_list.append(converted_data[2])

        return torch.stack(pad_input_ids_list).squeeze(), \
            torch.stack(label_list).squeeze(), \
            torch.stack(pad_attention_mask_list).squeeze(), \
            torch.tensor(indices)
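
# Minimal usage sketch for TextDatasetQA. The TOFU-style data path, model
# family, and split name are illustrative assumptions, not fixed by this file.
def _demo_text_dataset(tokenizer):
    ds = TextDatasetQA("locuslab/TOFU", tokenizer, model_family="llama2-7b",
                       max_length=512, split="full")
    input_ids, label, attention_mask, index = ds[0]
    print(input_ids.shape, index)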

class TextForgetDatasetQA2(Dataset):
    def __init__(self, data_path, tokenizer, model_family, max_length=512, split="forget10", loss_type="att_"):
        super(TextForgetDatasetQA2, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.forget_data = datasets.load_dataset(data_path, split)["train"]
        retain_split = "retain" + str(100 - int(split.replace("forget", ""))).zfill(2)
        self.retain_data = datasets.load_dataset(data_path, retain_split)["train"]

        # set aside the first 400 retain data points for evaluation; train on the rest
        data_f = pd.DataFrame(self.retain_data).iloc[400:].reset_index(drop=True)
        self.retain_data_train = datasets.Dataset.from_pandas(data_f)

        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.loss_type = loss_type

        if self.loss_type == "idk":
            self.split1, self.split2 = "idk", "retain"
            self.idontknowfile = "data/idontknow.jsonl"
            with open(self.idontknowfile, "r") as f:
                self.idk = f.readlines()

        ############### from qz
        elif 'att_' in self.loss_type:
            # precomputed critical (attention) words, one entry per forget example
            attention_words = torch.load('../tofu_attention/attention_idx' + split + '.pth')
            if len(attention_words) != len(self.forget_data):
                raise RuntimeError('The lengths of attention words do not match the dataset!')
            self.forget_data = self.forget_data.add_column(
                'critical_word', [attention_words[key] for key in attention_words])
            self.split1, self.split2 = "forget", "retain"
        ###############
        else:
            self.split1, self.split2 = "forget", "retain"

    def __len__(self):
        return len(self.forget_data)

    def __getitem__(self, idx):
        rets = []
        for data_type in [self.split1, self.split2]:
            # use questions from the forget set if the split is idk or forget;
            # pair each forget example with a randomly drawn retain example
            if data_type == "retain":
                data = self.retain_data_train
                idx = (idx + torch.randint(0, len(self.retain_data_train), (1,)).item()) % len(self.retain_data_train)
            else:
                data = self.forget_data

            question = data[idx]['question']
            answer = data[idx]['answer']
            if data_type == "idk":
                rand_pos = torch.randint(0, len(self.idk), (1,)).item()
                answer = self.idk[rand_pos].strip()

            ############### from qz: a copy of convert_raw_data_to_model_format;
            # the only additions are the two 'att_' branches below
            question_start_token = self.model_configs['question_start_tag']
            question_end_token = self.model_configs['question_end_tag']
            answer_token = self.model_configs['answer_tag']
            new_question = question_start_token + question + question_end_token
            new_answer = answer_token + answer
            full_text = new_question + new_answer
            num_question_tokens = len(self.tokenizer.tokenize(new_question, add_special_tokens=True))

            if data_type == "forget" and 'att_' in self.loss_type:
                # map each critical word to the answer-token positions it covers,
                # keeping only ascii characters so tokenizer markers are ignored
                attention_word = self.forget_data[idx]['critical_word']
                asciied_answer = [''.join([c for c in tok if c.isascii()])
                                  for tok in self.tokenizer.tokenize(new_answer)]
                critical_idx_tokens = [num_question_tokens + i
                                       for i, tok in enumerate(asciied_answer)
                                       if tok in attention_word and tok != ''
                                       and (len(tok) >= 2 or tok.isnumeric())]

            encoded = self.tokenizer(
                full_text,
                add_special_tokens=True,
                max_length=self.max_length,
                truncation=True,
            )

            pad_length = self.max_length - len(encoded.input_ids)
            pad_input_ids = encoded['input_ids'] + [self.tokenizer.eos_token_id] * pad_length
            pad_attention_mask = encoded['attention_mask'] + [0] * pad_length
            if len(encoded.input_ids) == self.max_length:
                label = encoded.input_ids
            else:
                label = encoded['input_ids'] + [self.tokenizer.eos_token_id] + [-100] * (pad_length - 1)

            # change label to -100 for question tokens
            for i in range(num_question_tokens):
                label[i] = -100
            if data_type == "forget" and 'att_' in self.loss_type:
                # additionally mask every non-critical answer token
                for i, ele in enumerate(label):
                    if i not in critical_idx_tokens:
                        label[i] = -100
            converted_data = torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)
            rets.append(converted_data)
        return rets
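
# Sketch of how one item from TextForgetDatasetQA2 might feed an unlearning
# step: gradient ascent on the forget split plus a standard retain loss. This
# is an illustrative assumption, not the repository's actual training loop.
def _demo_forget_step(model, item):
    (f_ids, f_label, f_mask), (r_ids, r_label, r_mask) = item
    forget_loss = model(input_ids=f_ids.unsqueeze(0), attention_mask=f_mask.unsqueeze(0),
                        labels=f_label.unsqueeze(0)).loss
    retain_loss = model(input_ids=r_ids.unsqueeze(0), attention_mask=r_mask.unsqueeze(0),
                        labels=r_label.unsqueeze(0)).loss
    return -forget_loss + retain_loss  # ascend on forget, descend on retain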

def collate_fn(batch):
    input_ids, attention_masks = zip(*batch)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=-100)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    return input_ids, attention_masks

def custom_data_collator(samples):
    input_ids = [s[0] for s in samples]
    labels = [s[1] for s in samples]
    attention_mask = [s[2] for s in samples]
    return torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)

def custom_data_collator_with_indices(samples):
    input_ids = [s[0] for s in samples]
    labels = [s[1] for s in samples]
    attention_mask = [s[2] for s in samples]
    indices = [s[3] for s in samples]
    return torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask), torch.stack(indices)
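
# Usage sketch: items are already padded to max_length, so the collators plug
# straight into a DataLoader (the batch size here is arbitrary).
def _demo_loader(dataset):
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, collate_fn=custom_data_collator_with_indices)
    input_ids, labels, attention_mask, indices = next(iter(loader))
    print(input_ids.shape)  # (4, max_length)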

def get_batch_loss(output, labels):
    # shift so that logits at position n predict token n + 1
    shifted_labels = labels[..., 1:].contiguous()
    output = output[..., :-1, :].contiguous()

    loss_function = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
    # get the sum loss for each sequence in a batch
    loss = loss_function(output.transpose(-1, -2), shifted_labels).sum(dim=-1)

    return loss
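
# Self-contained sketch of get_batch_loss on random logits; the vocabulary
# size and shapes are illustrative only.
def _demo_batch_loss():
    logits = torch.randn(2, 10, 32000)        # (batch, seq_len, vocab)
    labels = torch.randint(0, 32000, (2, 10))
    labels[:, :4] = -100                      # masked question positions
    print(get_batch_loss(logits, labels).shape)  # (2,): summed loss per sequence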

def model_mix(model, before, after, update_ratio):
    # linearly interpolate every parameter between two state dicts
    for name, parameter in model.named_parameters():
        parameter.data = update_ratio * before[name].cuda() + (1 - update_ratio) * after[name].cuda()
    return model
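
# Usage sketch (assumed semantics): `before` and `after` are state dicts, e.g.
# torch.load-ed checkpoints of the original and the unlearned model, and
# update_ratio in [0, 1] weights the `before` parameters. Requires CUDA, as above.
# mixed_model = model_mix(model, before_state, after_state, update_ratio=0.5)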

'''
import hydra, os
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed

@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg):
    # ------------ DDP PyTorch distributed training ----------- #

    num_devices = int(os.environ.get('WORLD_SIZE', 1))  # os.environ reads environment variables
    print(f"num_devices: {num_devices}")
    if os.environ.get('LOCAL_RANK') is not None:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        device_map = {'': local_rank}
    else:
        local_rank = 0

    os.environ["WANDB_DISABLED"] = "true"
    # --------------------------------------------- #

    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]  # Hugging Face model key
    if cfg.model_path is None:
        cfg.model_path = model_cfg["ft_model_path"]

    # save cfg in cfg.save_dir
    if local_rank == 0:
        with open(f"{cfg.save_dir}/config.yaml", "w") as file:
            # omegaconf.save(cfg, file)
            pass

    if os.path.exists(cfg.save_dir):
        print("Directory already exists")
        if not cfg.overwrite_dir:
            exit()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    torch_format_dataset = TextForgetDatasetQA2(cfg.data_path, tokenizer=tokenizer, model_family=cfg.model_family,
                                                max_length=500, split='forget01', loss_type='att_')

if __name__ == "__main__":
    main()
'''