| import os
|
| import json
|
| from transformers import AutoTokenizer, AutoModel
|
| import torch
|
| import torch.nn.functional as F
|
| from transformers import AutoTokenizer, AutoModel
|
| import torch
|
| import torch.nn.functional as F
|
|
|
def encode_number_to_char(number):
    """Translate a zero-based choice index into its letter label.

    0 -> 'A', 1 -> 'B', 2 -> 'C', 3 -> 'D'; any other value yields None.
    """
    labels = dict(enumerate('ABCD'))
    return labels.get(number)
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring padding.

    model_output: transformers model output whose first element is the
        last hidden state — assumed (batch, seq_len, hidden); TODO confirm.
    attention_mask: (batch, seq_len) mask with 1 at real-token positions.

    Returns a (batch, hidden) tensor of masked mean embeddings; the
    denominator is clamped so all-padding rows cannot divide by zero.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
|
|
|
|
|
# ---------------------------------------------------------------------------
# Evaluate multiple-choice answers: for each transcript in txt_dir, embed
# "question + model response" and "question + each choice" with an SBERT
# model, pick the choice with the highest cosine similarity, and compare it
# against the gold answer letter from the matching JSON file.  A per-question
# detail report is written to show_detail_path.
# ---------------------------------------------------------------------------

txt_dir = './dev_output'
json_dir = './dev_data'
show_detail_path = './show_detail'

# Map file stem -> full path for both sides so pairs can be joined by stem.
txt_files = {os.path.splitext(file)[0]: os.path.join(txt_dir, file)
             for file in os.listdir(txt_dir) if file.endswith('.txt')}
json_files = {os.path.splitext(file)[0]: os.path.join(json_dir, file)
              for file in os.listdir(json_dir) if file.endswith('.json')}

tokenizer = AutoTokenizer.from_pretrained('/disk6/hyhong/QWen/Sbert_pretrain')
model = AutoModel.from_pretrained('/disk6/hyhong/QWen/Sbert_pretrain')

# Ensure the report directory exists before writing into it.
os.makedirs(show_detail_path, exist_ok=True)

# Per-subset counters keyed by the filename marker.  Markers map onto the
# printed parts: '-d-' (huckleberry) is PART 1, '-a-'/'-b-'/'-c-'
# (l3d/tau/neu) together form PART 2, '-e-' (india) is PART 3.
markers = ('-a-', '-b-', '-c-', '-d-', '-e-')
subset_total = {m: 0 for m in markers}
subset_correct = {m: 0 for m in markers}

question_num = 0
correct_num = 0

for txt_id, txt_path in txt_files.items():
    if txt_id not in json_files:
        continue

    with open(json_files[txt_id], 'r', encoding='utf-8') as json_file:
        json_data = json.load(json_file)
    question_num += 1

    # Identify the subset from the filename; assumes exactly one marker
    # appears per path (TODO confirm against the dataset naming scheme).
    marker = next((m for m in markers if m in txt_path), None)
    if marker is not None:
        subset_total[marker] += 1

    # Choices carry a 3-character prefix (e.g. "A. ") that is stripped
    # before embedding.
    choices = [choice[3:] for choice in json_data.get("choice", [])]
    if not choices:
        # Previously this fell through to a NameError/ValueError; skip
        # files with no usable choices instead.
        continue

    with open(txt_path, 'r', encoding='utf-8') as txt_file:
        txt_sentence = txt_file.read().strip()

    question = json_data["question"]
    # Gold answer is the leading letter of the answer string.
    answer = json_data["answer"][0]

    detail_path = os.path.join(show_detail_path, os.path.basename(txt_path))
    with open(detail_path, 'w', encoding='utf-8') as detail_txt:
        detail_txt.write(json_data["audio_url"] + '\n')
        detail_txt.write("Question:" + json_data["question"] + '\n')
        detail_txt.write("Choice:" + '\n')
        # Write every choice.  (The previous if-chain skipped middle
        # choices: with 3 choices it wrote only [0] and [2], with 4 only
        # [0] and [3].)
        for choice_text in json_data["choice"]:
            detail_txt.write(choice_text + '\n')
        detail_txt.write("Correct answer:" + answer + '\n')

        # Score each choice by the dot product of the L2-normalized mean
        # embeddings of (question + response) and (question + choice),
        # i.e. their cosine similarity.
        score = []
        for choice in choices:
            pair = [question + " Answer:" + txt_sentence,
                    question + " Answer:" + choice]
            encoded_input = tokenizer(pair, padding=True, truncation=True,
                                      return_tensors='pt')
            with torch.no_grad():
                model_output = model(**encoded_input)
            sentence_embeddings = mean_pooling(
                model_output, encoded_input['attention_mask'])
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
            score.append(torch.dot(sentence_embeddings[0],
                                   sentence_embeddings[1]))

        max_index = score.index(max(score))
        answer_qwen = encode_number_to_char(max_index)

        detail_txt.write("Model respond:" + txt_sentence + '\n')
        detail_txt.write("Model answer:" + answer_qwen)

    if answer_qwen == answer:
        correct_num += 1
        if marker is not None:
            subset_correct[marker] += 1

correct_rate_all = correct_num / question_num
correct_rate_part1 = subset_correct['-d-'] / subset_total['-d-']
correct_rate_part2 = ((subset_correct['-a-'] + subset_correct['-b-']
                       + subset_correct['-c-'])
                      / (subset_total['-a-'] + subset_total['-b-']
                         + subset_total['-c-']))
correct_rate_part3 = subset_correct['-e-'] / subset_total['-e-']

print('Overall accuracy rate:')
print(correct_rate_all)

print('PART 1 accuracy rate:')
print(correct_rate_part1)
print('PART 2 accuracy rate:')
print(correct_rate_part2)
print('PART 3 accuracy rate:')
print(correct_rate_part3)
|
|
|
|
|
|
|
|
|