"""Fine-tune a timm ResNet-50 emotion classifier on the 2.41M-image dataset.

The script indexes images under DATA_ROOT (cached in CACHE_PATH), splits them
95/5 into train/validation with a fixed seed, and trains with mixed precision,
cosine LR annealing, early stopping and resumable checkpoints.
"""
import os
import gc
import pickle
import random
import ctypes
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm

# Silence a noisy PIL warning so it does not break the tqdm progress bars.
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Device selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File locations
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")

# Folder name -> class index. "sad" and "sadness" intentionally share index 7.
CLASS_MAPPING = {
    "amusement": 0,
    "anger": 1,
    "awe": 2,
    "contentment": 3,
    "disgust": 4,
    "excitement": 5,
    "fear": 6,
    "sad": 7,
    "sadness": 7
}

# Training hyper-parameters
BATCH_SIZE = 64
EPOCHS = 50
LR = 5e-5
NUM_TRAIN_WORKERS = 62
NUM_VAL_WORKERS = 62
PATIENCE = 5


def prepare_dataset_index():
    """Return ``(image_paths, labels)`` for every usable image.

    If CACHE_PATH exists the index is loaded from it. Otherwise DATA_ROOT is
    scanned recursively for ``*.jpg`` files; the class is taken from the name
    of the directory two levels above each file (``parts[-3]``), and files
    whose folder is not in CLASS_MAPPING are skipped. A non-empty scan result
    is written to CACHE_PATH.

    Returns:
        Two parallel sequences: image paths as strings and integer labels.
    """
    if CACHE_PATH.exists():
        print(f"Загрузка кэша: {CACHE_PATH.name}")
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # The cache must be a trusted, locally generated file.
        with open(CACHE_PATH, 'rb') as f:
            cache_data = pickle.load(f)
        return cache_data['image_paths'], cache_data['labels']

    print(f"Сканирование директории {DATA_ROOT}...")
    entries = []
    for img_path in DATA_ROOT.rglob('*.jpg'):
        emotion_folder = img_path.parts[-3].lower()
        label = CLASS_MAPPING.get(emotion_folder)
        if label is not None:
            entries.append((str(img_path), label))

    # rglob order depends on the filesystem; sorting makes the seeded
    # train/val split reproducible even when the cache is rebuilt.
    entries.sort()
    paths = [path for path, _ in entries]
    labels = [label for _, label in entries]

    # Do not persist an empty index: it would mask the dataset on every
    # later run until the cache file is deleted by hand.
    if paths:
        with open(CACHE_PATH, 'wb') as f:
            pickle.dump({'image_paths': paths, 'labels': labels}, f)
    return paths, labels


class EmoSetDirectDataset(Dataset):
    """Dataset that decodes images from disk on demand.

    Each item is ``(image, label)`` where ``image`` is a float32 tensor scaled
    to [0, 1] and brought to 3x256x256 by an aspect-preserving resize followed
    by a center crop.
    """

    def __init__(self, image_paths, labels):
        """Store the parallel sequences of file paths and integer labels."""
        self.image_paths = image_paths
        self.labels = labels
        # Resize keeps the aspect ratio; the center crop makes the result square.
        self.base_transform = T.Compose([
            T.Resize(256, antialias=True),
            T.CenterCrop(256)
        ])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the preprocessed image and label at position ``idx``.

        A file that cannot be read or decoded is replaced by an all-zero
        image (with its original label) so one bad file does not abort an
        epoch.
        """
        try:
            image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
            image = image.to(torch.float32) / 255.0
            image = self.base_transform(image)
        except Exception:
            # Deliberate best-effort fallback for corrupt files.
            image = torch.zeros((3, 256, 256), dtype=torch.float32)
        return image, self.labels[idx]


def build_gpu_transforms():
    """Build the train and validation transforms that run on DEVICE.

    Returns:
        ``(train_tf, val_tf)``: train applies a random 224x224 crop and a
        random horizontal flip; validation applies a 224x224 center crop.
        Both normalize with the mean/std constants given below.
    """
    # NOTE(review): applied to a whole batch tensor, these transforms appear to
    # draw one random crop/flip per batch rather than per sample — confirm.
    train_tf = torch.nn.Sequential(
        T.RandomCrop((224, 224)),
        T.RandomHorizontalFlip(p=0.5),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(DEVICE)

    val_tf = torch.nn.Sequential(
        T.CenterCrop((224, 224)),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(DEVICE)
    return train_tf, val_tf


def _make_loader(paths, labels, *, shuffle, num_workers):
    """Wrap the given samples in a DataLoader with the script's settings."""
    return DataLoader(
        EmoSetDirectDataset(paths, labels),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=3,
        persistent_workers=False,
    )


def _train_one_epoch(model, loader, gpu_tf, criterion, optimizer, scaler, use_amp, desc):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Batches that hit a CUDA out-of-memory error are skipped; any other
    RuntimeError propagates. Both metrics are 0 if no batch completed.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=desc, smoothing=0)
    for inputs, labels in pbar:
        outputs = loss = None
        hit_oom = False
        try:
            inputs = inputs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            inputs = gpu_tf(inputs)

            optimizer.zero_grad(set_to_none=True)

            # Mixed precision (only active on CUDA)
            with autocast(device_type=DEVICE.type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({'loss': f"{batch_loss:.4f}"})
        except RuntimeError as memory_err:
            if "out of memory" not in str(memory_err).lower():
                raise
            hit_oom = True

        if hit_oom:
            # Recover outside the except block: while the exception is alive
            # its traceback keeps the failed step's tensors referenced, so
            # the cache could not actually be released there.
            outputs = loss = None
            torch.cuda.empty_cache()
            optimizer.zero_grad(set_to_none=True)
            continue

    mean_loss = running_loss / total if total > 0 else 0
    accuracy = correct / total if total > 0 else 0
    return mean_loss, accuracy


def _evaluate(model, loader, gpu_tf, criterion, use_amp, desc):
    """Run one validation pass and return ``(mean_loss, accuracy)``.

    Both metrics are 0 if the loader yields no samples.
    """
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for val_inputs, val_labels in tqdm(loader, desc=desc, smoothing=0):
            val_inputs = val_inputs.to(DEVICE, non_blocking=True)
            val_labels = val_labels.to(DEVICE, non_blocking=True)
            val_inputs = gpu_tf(val_inputs)

            with autocast(device_type=DEVICE.type, enabled=use_amp):
                val_outputs = model(val_inputs)
                v_loss = criterion(val_outputs, val_labels)

            val_loss += v_loss.item() * val_inputs.size(0)
            val_predicted = val_outputs.argmax(dim=1)
            val_total += val_labels.size(0)
            val_correct += val_predicted.eq(val_labels).sum().item()

    mean_loss = val_loss / val_total if val_total > 0 else 0
    accuracy = val_correct / val_total if val_total > 0 else 0
    return mean_loss, accuracy


def main():
    """Train the classifier, resuming from RESUME_CHECKPOINT when present."""
    print(f"Инициализация. Устройство: {DEVICE}")

    all_paths, all_labels = prepare_dataset_index()
    if not all_paths:
        # Fail with a clear message instead of an opaque unpacking error below.
        raise SystemExit(f"В {DATA_ROOT} не найдено изображений для обучения.")

    # Seeded shuffle and 95/5 train/validation split
    random.seed(42)
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)
    all_paths, all_labels = zip(*combined)

    split_idx = int(len(all_paths) * 0.95)

    train_loader = _make_loader(
        all_paths[:split_idx], all_labels[:split_idx],
        shuffle=True, num_workers=NUM_TRAIN_WORKERS,
    )
    val_loader = _make_loader(
        all_paths[split_idx:], all_labels[split_idx:],
        shuffle=False, num_workers=NUM_VAL_WORKERS,
    )

    gpu_train_tf, gpu_val_tf = build_gpu_transforms()

    # Decide where the weights come from BEFORE building the optimizer. The
    # optimizer holds references to the parameters it was constructed with,
    # so replacing the model afterwards would leave it updating a discarded
    # network while the real one never trains.
    resume = RESUME_CHECKPOINT.exists()
    use_previous = not resume and PREVIOUS_WEIGHTS.exists()
    use_imagenet = not resume and not use_previous
    if use_imagenet:
        print("Веса не найдены. Инициализация с ImageNet.")

    num_classes = max(CLASS_MAPPING.values()) + 1
    model = timm.create_model(
        'resnet50', pretrained=use_imagenet, num_classes=num_classes
    ).to(DEVICE)

    use_amp = DEVICE.type == "cuda"
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler(DEVICE.type, enabled=use_amp)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    start_epoch = 1

    # Weight loading
    if resume:
        print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except Exception as sched_err:
            # Best effort: keep going with a fresh schedule, but say so.
            print(f"Не удалось восстановить состояние планировщика: {sched_err}")
        # A scaler saved while disabled has an empty state dict, which
        # load_state_dict would reject; the truthiness check skips it.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        # Older checkpoints lack this key; default keeps them loadable.
        epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
        start_epoch = checkpoint['epoch'] + 1
    elif use_previous:
        print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))

    best_model_path = SAVE_MODEL_PATH.with_name(
        f"{SAVE_MODEL_PATH.stem}_best{SAVE_MODEL_PATH.suffix}"
    )

    for epoch in range(start_epoch, EPOCHS + 1):
        # Training
        train_loss, _train_acc = _train_one_epoch(
            model, train_loader, gpu_train_tf, criterion, optimizer, scaler,
            use_amp, f"Epoch {epoch}/{EPOCHS} [Train]",
        )

        gc.collect()
        torch.cuda.empty_cache()

        # Validation
        epoch_val_loss, epoch_val_acc = _evaluate(
            model, val_loader, gpu_val_tf, criterion, use_amp,
            f"Epoch {epoch}/{EPOCHS} [Val]",
        )

        scheduler.step()

        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")

        # Best-model tracking and early stopping
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print("Сохранен новый лучший чекпоинт.")
        else:
            epochs_no_improve += 1

        # Early stopping is only allowed from epoch 25 onwards.
        if epochs_no_improve >= PATIENCE and epoch >= 25:
            print(f"Остановка: валидация не улучшается {PATIENCE} эпох.")
            break

        # Resumable state
        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss,
            'epochs_no_improve': epochs_no_improve,
        }
        torch.save(checkpoint_state, RESUME_CHECKPOINT)
        gc.collect()

    if SAVE_MODEL_PATH.parent.exists():
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
    print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")

    # Training finished (or stopped early): the resume state is obsolete.
    RESUME_CHECKPOINT.unlink(missing_ok=True)


if __name__ == "__main__":
    main()