"""Score free-text model responses against multiple-choice options.

For every ``<id>.txt`` response in ``txt_dir`` that has a matching
``<id>.json`` question in ``json_dir``, the response and each choice are
prefixed with the question and embedded with a mean-pooled transformer
encoder. The choice whose L2-normalized embedding has the largest dot product
(cosine similarity) with the response is taken as the model's answer.

A per-question report is written to ``show_detail_path``, and the overall and
per-part accuracy rates are printed.
"""
import json
import os
import string

import torch
import torch.nn.functional as F

txt_dir = './dev_output'
json_dir = './dev_data'
show_detail_path = './show_detail'
model_path = '/disk6/hyhong/QWen/Sbert_pretrain'

# Category tags searched for in each response's file name. A file counts
# towards every tag it contains.
CATEGORY_TAGS = ('-a-', '-b-', '-c-', '-d-', '-e-')

# Category tags that make up each reported part, in print order.
PART_TAGS = {
    'PART 1': ('-d-',),
    'PART 2': ('-a-', '-b-', '-c-'),
    'PART 3': ('-e-',),
}


def encode_number_to_char(number, labels='ABCD'):
    """Return the answer letter for a 0-based choice index.

    Args:
        number: 0-based index of the selected option.
        labels: Letters assigned to consecutive indices. The default keeps the
            original four-option A-D mapping.

    Returns:
        The letter for ``number``, or ``None`` if the index has no label.
    """
    return dict(enumerate(labels)).get(number, None)


def mean_pooling(model_output, attention_mask):
    """Average token embeddings, ignoring positions masked out by ``attention_mask``.

    Args:
        model_output: Model output whose first element holds the token
            embeddings, shaped like ``attention_mask`` plus a trailing hidden
            dimension.
        attention_mask: 1 for real tokens and 0 for padding.

    Returns:
        One pooled embedding per sequence (the token axis is summed away).
    """
    token_embeddings = model_output[0]  # first element holds all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # The clamp avoids division by zero for an all-padding row.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def _collect_files(directory, extension):
    """Map each file stem in *directory* ending with *extension* to its full path."""
    return {
        os.path.splitext(name)[0]: os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(extension)
    }


def _safe_rate(correct, total):
    """Return ``correct / total``, or NaN when there were no questions."""
    return correct / total if total else float('nan')


def _score_choices(tokenizer, model, question, response, choices):
    """Return the cosine similarity between the response and each choice.

    Both texts are prefixed with ``question + " Answer:"``. The response and
    all choices are encoded in one batch, so the response embedding is computed
    only once.

    Returns:
        A list of floats, one per entry in ``choices``.
    """
    prefix = question + " Answer:"
    texts = [prefix + response] + [prefix + choice for choice in choices]
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoded)
    embeddings = F.normalize(mean_pooling(output, encoded['attention_mask']), p=2, dim=1)
    # Row 0 is the response; rows 1.. are the choices.
    return (embeddings[1:] @ embeddings[0]).tolist()


def _write_detail(path, json_data, answer, response, predicted):
    """Write a human-readable report for one question to *path* (UTF-8).

    Every raw choice is listed, whatever the number of options.
    """
    lines = [json_data["audio_url"], "Question:" + json_data["question"], "Choice:"]
    lines.extend(json_data["choice"])
    lines.append("Correct answer:" + answer)
    lines.append("Model respond:" + response)
    lines.append(f"Model answer:{predicted}")
    with open(path, 'w', encoding='utf-8') as detail_txt:
        detail_txt.write('\n'.join(lines))


def main():
    """Evaluate every response/question pair and print the accuracy rates."""
    # Deferred so the helpers above can be imported without loading transformers.
    from transformers import AutoModel, AutoTokenizer

    txt_files = _collect_files(txt_dir, '.txt')
    json_files = _collect_files(json_dir, '.json')
    os.makedirs(show_detail_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()

    totals = dict.fromkeys(CATEGORY_TAGS, 0)
    corrects = dict.fromkeys(CATEGORY_TAGS, 0)
    question_num = 0
    correct_num = 0

    for txt_id, txt_path in sorted(txt_files.items()):
        json_path = json_files.get(txt_id)
        if json_path is None:
            continue
        with open(json_path, 'r', encoding='utf-8') as json_file:
            json_data = json.load(json_file)
        with open(txt_path, 'r', encoding='utf-8') as txt_file:
            response = txt_file.read().strip()

        # Drop the leading 3-character option label (e.g. "A. ") from each choice.
        choices = [choice[3:] for choice in json_data.get("choice", [])]
        if not choices:
            print(f"Skipping {txt_id}: question has no choices")
            continue

        # Match tags on the file name only, so directory names cannot match.
        tags = [tag for tag in CATEGORY_TAGS if tag in txt_id]
        question_num += 1
        for tag in tags:
            totals[tag] += 1

        answer = json_data["answer"][0]
        scores = _score_choices(tokenizer, model, json_data["question"], response, choices)
        # First index of the highest score, matching list.index(max(...)).
        best_index = max(range(len(scores)), key=scores.__getitem__)
        predicted = encode_number_to_char(best_index, string.ascii_uppercase)

        detail_path = os.path.join(show_detail_path, os.path.basename(txt_path))
        _write_detail(detail_path, json_data, answer, response, predicted)

        if predicted == answer:
            correct_num += 1
            for tag in tags:
                corrects[tag] += 1

    print('Overall accuracy rate:')
    print(_safe_rate(correct_num, question_num))
    for part, part_tags in PART_TAGS.items():
        print(f'{part} accuracy rate:')
        print(_safe_rate(sum(corrects[t] for t in part_tags), sum(totals[t] for t in part_tags)))


if __name__ == "__main__":
    main()