BERT的中文问答系统（20）

# NOTE(review): the Python program below was flattened onto a single line by the
# page extraction (all newlines and indentation are gone), so it is not valid
# Python as-is, and it is truncated mid-expression at `chatgpt_inputs[` at the end
# of the second line. Recover the original formatting and the missing remainder
# from the upstream source instead of reconstructing it by hand.
#
# Contents, in order:
# - imports;
# - PROJECT_ROOT / LOGS_DIR setup, which creates a `logs` directory next to this
#   file at import time;
# - setup_logging(), which logs at INFO level to both a StreamHandler and the file
#   logs/<YYYY-MM-DD>/<HH-MM-SS>/羲和.txt, and is called immediately at import;
# - the start of XihuaDataset, a torch Dataset that loads records with 'question',
#   'human_answers' and 'chatgpt_answers' keys from a .jsonl or .json file. It
#   tokenizes the question and the FIRST human / ChatGPT answer, each padded and
#   truncated to max_length (default 128).
# Inline Chinese comments: 获取项目根目录 = "get project root directory",
# 配置日志 = "configure logging", 数据集类 = "dataset class".
#
# Defects to fix once the formatting is restored:
# - load_data (.jsonl branch): the try/except wraps only `data.append(item)`,
#   which cannot raise InvalidLineError. jsonlines raises that error from the
#   reader while it is being iterated, so a single bad line propagates out of
#   __init__. Use `reader.iter(skip_invalid=True)`, or put the try around the
#   iteration itself.
# - load_data (.json branch): `open(file_path, 'r')` uses the locale's default
#   encoding; pass encoding='utf-8' for Chinese text. On a JSON decode error
#   (warning only), or for any other file extension (no message at all), an
#   empty list is returned, so the dataset is silently empty.
# - setup_logging: logging.FileHandler(log_file) also uses the locale's default
#   encoding, although the log messages are Chinese; pass encoding='utf-8'.
# - __getitem__: only the tokenizer calls are inside the try, so KeyError /
#   IndexError from the item lookups are not handled. The retry also recurses to
#   (idx + 1) % len(self.data) with no limit, so if every item fails it ends in
#   RecursionError.
# - `torch.cuda.amp.GradScaler` / `autocast` are deprecated in recent PyTorch
#   releases in favour of `torch.amp.GradScaler('cuda')` /
#   `torch.amp.autocast('cuda')`.
import os import json import jsonlines import torch import torch.optim as optim from torch.utils.data import Dataset, DataLoader, DistributedSampler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from transformers import BertModel, BertTokenizer import tkinter as tk from tkinter import filedialog, messagebox, scrolledtext, ttk import logging from difflib import SequenceMatcher from datetime import datetime from torch.cuda.amp import GradScaler, autocast import torch.multiprocessing as mp import psutil # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) # 配置日志 LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs') os.makedirs(LOGS_DIR, exist_ok=True) def setup_logging(): log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt')) os.makedirs(os.path.dirname(log_file), exist_ok=True) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_file), logging.StreamHandler() ] ) setup_logging() # 数据集类 class XihuaDataset(Dataset): def __init__(self, file_path, tokenizer, max_length=128): self.tokenizer = tokenizer self.max_length = max_length self.data = self.load_data(file_path) def load_data(self, file_path): data = [] if file_path.endswith('.jsonl'): with jsonlines.open(file_path) as reader: for i, item in enumerate(reader): try: data.append(item) except jsonlines.jsonlines.InvalidLineError as e: logging.warning(f"跳过无效行 { i + 1}: { e}") elif file_path.endswith('.json'): with open(file_path, 'r') as f: try: data = json.load(f) except json.JSONDecodeError as e: logging.warning(f"跳过无效文件 { file_path}: { e}") return data def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] question = item['question'] human_answer = item['human_answers'][0] chatgpt_answer = item['chatgpt_answers'][0] try: inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', 
truncation=True, max_length=self.max_length) human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length) chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length) except Exception as e: logging.warning(f"跳过无效项 { idx}: { e}") return self.__getitem__((idx + 1) % len(self.data)) return { 'input_ids': inputs['input_ids'].squeeze(), 'attention_mask': inputs['attention_mask'].squeeze(), 'human_input_ids': human_inputs['input_ids'].squeeze(), 'human_attention_mask': human_inputs['attention_mask'].squeeze(), 'chatgpt_input_ids': chatgpt_inputs[
上一篇：基于SpringBoot的足球场在线预约系统的设计与实现


下一篇：LeetCode 2747. 统计没有收到请求的服务器数目（滑动窗口，Java）