28 changes: 21 additions & 7 deletions finetune/finetune.py
@@ -260,13 +260,13 @@ def process_batch(
images = []

conv[0]['image'] = Image.open(conv[len(conv)-2]['image']).convert("RGB")
conv[len(conv)-2]['image'] = None
# conv[len(conv)-2]['image'] = None

for message in conv:
loss_mask_val = (
False if message["role"] in ("system", "user", "observation") else True
)

print(message)
new_input_ids_all = tokenizer.apply_chat_template(
[message], tokenize=True, return_dict=True, padding=True
)
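For reference, the image-handling change in this hunk reduces to the pattern below; a minimal sketch, assuming conv is a list of chat-message dicts where the second-to-last message holds the image path (attach_image is a hypothetical helper name, not part of the diff):

from PIL import Image

def attach_image(conv):
    # Load the image referenced by the second-to-last message as RGB
    # and attach the PIL object to the conversation's first message.
    idx = len(conv) - 2
    conv[0]["image"] = Image.open(conv[idx]["image"]).convert("RGB")
    # The change keeps the original path in place instead of clearing it
    # (the old assignment is now commented out in the diff):
    # conv[idx]["image"] = None
    return conv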
@@ -342,16 +342,19 @@ def process_batch_eval(
batched_images = []

for conv in batched_conv:

print(conv)
print("----------------divide--------------------")
idx = len(conv) - 2
conv[0]['image'] = Image.open(conv[idx]["image"]).convert("RGB")
conv[idx]['image'] = None
# conv[idx]['image'] = None

new_input_ids_all = tokenizer.apply_chat_template(
conv, tokenize=True, return_dict=True, padding=True
)

input_ids = new_input_ids_all["input_ids"][0]
print("input_ids")
print(input_ids)
attention_mask = new_input_ids_all["attention_mask"][0]
position_ids = list(range(len(input_ids)))
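The fields consumed here all come from tokenizer.apply_chat_template(..., return_dict=True); a rough sketch of the expected output, assuming a multimodal GLM-style chat template that also returns preprocessed image tensors (the hunk below reads new_input_ids_all["images"][0] accordingly):

out = tokenizer.apply_chat_template(
    conv, tokenize=True, return_dict=True, padding=True
)
input_ids = out["input_ids"][0]             # token ids for the full dialogue
attention_mask = out["attention_mask"][0]   # 1 for real tokens, 0 for padding
position_ids = list(range(len(input_ids)))  # plain 0..N-1 positions
images = out["images"][0]                   # preprocessed image tensor(s)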

@@ -371,14 +374,14 @@ def process_batch_eval(
output_segment = input_ids[
dialogue_parts[end_idx - 1] : dialogue_parts[end_idx]
]

output_segment.append(151336) # Add EOS token

batched_input_ids.append(input_segment[:max_input_length])
batched_attention_mask.append(attention_segment[:max_input_length])
batched_position_ids.append(position_segment[:max_input_length])
batched_output_ids.append(output_segment[:max_output_length])
batched_images.append(new_input_ids_all["images"][0])


del (
batched_conv,
@@ -389,7 +392,6 @@ def process_batch_eval(
output_segment,
)
torch.cuda.empty_cache()

return {
"input_ids": batched_input_ids,
"attention_mask": batched_attention_mask,
@@ -430,12 +432,24 @@ def compute_metrics(eval_preds: EvalPrediction, tokenizer):
batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
# Select the even-indexed elements of batched_pred_ids and the odd-indexed elements of batched_label_ids
even_pred_ids = [batched_pred_ids[i] for i in range(len(batched_pred_ids)) if i % 2 == 0]
odd_label_ids = [batched_label_ids[i] for i in range(len(batched_label_ids)) if i % 2 == 1]
for pred_ids, label_ids in zip(even_pred_ids, odd_label_ids):
if (hasattr(pred_ids, 'size') and pred_ids.size == 0) or (not hasattr(pred_ids, 'size') and len(pred_ids) == 0):
print("Error: Empty prediction!")
continue
pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
if not pred_tokens or not label_tokens:
print(f"Warning: Empty tokens. Prediction: {pred_tokens}, Label: {label_tokens}")
# If empty, tokenization may have failed or the data is bad; take an
# appropriate measure (e.g., skip the sample or fill in a default value)
pred_tokens = ["[EMPTY]"]  # or fill in some other default value
label_tokens = ["[EMPTY]"]
# continue  # skip the current sample, or otherwise handle the empty case
scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
for k, v in scores[0].items():
metrics_dct[k].append(round(v["f"] * 100, 4))
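To make the new pairing logic concrete, here is a self-contained toy run of the even/odd selection and the ROUGE scoring; the Rouge import is an assumption (whichever rouge package the script already imports, with get_scores taking whitespace-joined token strings), and the sample texts are made up:

import jieba
from rouge import Rouge  # assumed to match the script's existing import

# Toy decoded texts standing in for tokenizer.decode(...) output.
preds = ["the cat sat on the mat", "ignored", "a quick brown fox", "ignored"]
labels = ["ignored", "the cat sat on a mat", "ignored", "quick brown foxes"]

even_preds = preds[0::2]   # even indices: 0, 2, ...
odd_labels = labels[1::2]  # odd indices: 1, 3, ...

rouge = Rouge()
for pred_txt, label_txt in zip(even_preds, odd_labels):
    pred_tokens = [t for t in jieba.cut(pred_txt) if t.strip()] or ["[EMPTY]"]
    label_tokens = [t for t in jieba.cut(label_txt) if t.strip()] or ["[EMPTY]"]
    scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
    print(round(scores[0]["rouge-l"]["f"] * 100, 4))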