diff --git a/src/dataset_paths_cache.pkl b/src/dataset_paths_cache.pkl deleted file mode 100644 index 010b852..0000000 Binary files a/src/dataset_paths_cache.pkl and /dev/null differ diff --git a/src/emoset_resnet50_best.pth b/src/emoset_resnet50_best.pth index 2098582..450ee8d 100644 Binary files a/src/emoset_resnet50_best.pth and b/src/emoset_resnet50_best.pth differ diff --git a/src/scripts/00_setup_env.sh b/src/scripts/00_setup_env.sh index 77b4591..24c8d01 100644 --- a/src/scripts/00_setup_env.sh +++ b/src/scripts/00_setup_env.sh @@ -1,5 +1,6 @@ #!/bin/bash +# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера # Остановка скрипта при возникновении любой ошибки set -e diff --git a/src/scripts/21_train_images.ipynb b/src/scripts/21_train_images.ipynb index 83c0a4a..98cf8a0 100644 --- a/src/scripts/21_train_images.ipynb +++ b/src/scripts/21_train_images.ipynb @@ -44,7 +44,7 @@ "BATCH_SIZE = 64\n", "EPOCHS = 15\n", "LR = 3e-4\n", - "NUM_WORKERS = 40\n", + "NUM_WORKERS = 62\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Аппаратное ускорение: {device}\")" diff --git a/src/scripts/21_train_images.py b/src/scripts/21_train_images.py new file mode 100644 index 0000000..06548cb --- /dev/null +++ b/src/scripts/21_train_images.py @@ -0,0 +1,184 @@ +import os +import random +import warnings +from pathlib import Path +from PIL import Image +import pandas as pd +import numpy as np +from tqdm import tqdm + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +import torchvision.transforms as T +import timm + +# Подавление предупреждений цветовых профилей +warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*") + +# Настройки окружения +DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K") +BATCH_SIZE = 64 +EPOCHS = 30 +LR = 5e-5 +NUM_WORKERS = 62 +PATIENCE = 7 + +# Маппинг классов +CLASS_MAPPING = { + "amusement": 0, "anger": 1, "awe": 2, "contentment": 3, + "disgust": 4, "excitement": 5, "fear": 6, "sadness": 7 +} + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Устройство: {DEVICE}") + +# Фиксация генераторов псевдослучайных чисел +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +set_seed() + +# Инициализация структур данных +class EmoSetDataset(Dataset): + def __init__(self, root: Path | str, split: str, transform=None): + self.root = Path(root) / split + self.df = pd.read_csv(self.root / "labels.csv") + self.transform = transform + + # Фильтрация датафрейма + self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True) + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + img_path = self.root / "images" / row["filename"] + + try: + img = Image.open(img_path).convert("RGB") + except Exception: + img = Image.new("RGB", (256, 256), (0, 0, 0)) + + if self.transform: + img_tensor = self.transform(img) + else: + img_tensor = T.ToTensor()(img) + + label_idx = CLASS_MAPPING[row["label"]] + return img_tensor, label_idx + +# Трансформации +base_tf = [ + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +] + +train_transform = T.Compose([ + T.Resize(256, antialias=True), + T.RandomCrop(224), + T.RandomHorizontalFlip(), + *base_tf +]) + +val_transform = T.Compose([ + T.Resize(256, antialias=True), + T.CenterCrop(224), + *base_tf +]) + +train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform) +val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform) + +train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True) +val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) + +# Инициализация модели и оптимизатора +model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3) +model.to(DEVICE) + +criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS) + +# Логика эпохи обучения +def train_epoch(current_model, loader): + current_model.train() + total_loss, correct_preds, total_samples = 0.0, 0, 0 + + for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0): + imgs, labels = imgs.to(DEVICE), labels.to(DEVICE) + + optimizer.zero_grad(set_to_none=True) + logits = current_model(imgs) + loss = criterion(logits, labels) + + loss.backward() + optimizer.step() + + total_loss += loss.item() * imgs.size(0) + preds = logits.argmax(dim=1) + correct_preds += (preds == labels).sum().item() + total_samples += labels.size(0) + + return total_loss / total_samples, correct_preds / total_samples + +# Логика эпохи валидации +@torch.no_grad() +def val_epoch(current_model, loader): + current_model.eval() + total_loss, correct_preds, total_samples = 0.0, 0, 0 + + for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0): + imgs, labels = imgs.to(DEVICE), labels.to(DEVICE) + + logits = current_model(imgs) + loss = criterion(logits, labels) + + total_loss += loss.item() * imgs.size(0) + preds = logits.argmax(dim=1) + correct_preds += (preds == labels).sum().item() + total_samples += labels.size(0) + + return total_loss / total_samples, correct_preds / total_samples + +if __name__ == "__main__": + best_val_acc = 0.0 + best_val_loss = float('inf') + epochs_no_improve = 0 + checkpoint_path = "./emosetV2_resnet50_best.pth" + + print("Старт обучения.") + + for epoch in range(1, EPOCHS + 1): + train_loss, train_acc = train_epoch(model, train_loader) + val_loss, val_acc = val_epoch(model, val_loader) + + scheduler.step() + + print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}") + + # Сохранение лучших весов по Accuracy + if val_acc > best_val_acc: + best_val_acc = val_acc + torch.save(model.state_dict(), checkpoint_path) + print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})") + + # Оценка переобучения по Loss (Early Stopping) + if val_loss < best_val_loss: + best_val_loss = val_loss + epochs_no_improve = 0 + else: + epochs_no_improve += 1 + if epochs_no_improve >= PATIENCE: + print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.") + break + + print("Процесс завершен.") \ No newline at end of file diff --git a/src/scripts/31_finetune_2.41M.py b/src/scripts/31_finetune_2.41M.py deleted file mode 100644 index 9ae6b18..0000000 --- a/src/scripts/31_finetune_2.41M.py +++ /dev/null @@ -1,268 +0,0 @@ -import os -import gc -import pickle -import random -import ctypes -import warnings -from pathlib import Path - -import torch -import torch.nn as nn -from torch.utils.data import Dataset, DataLoader -import torchvision.transforms as T -import torchvision.io as tv_io -from torch.amp import autocast, GradScaler -from tqdm import tqdm -import timm - -# Подавление предупреждений PIL для корректной работы tqdm -warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*") - -# Настройка устройства -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# Пути к файлам -DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M") -CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl") - -PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth") -RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth") -SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth") - -CLASS_MAPPING = { - "amusement": 0, "anger": 1, "awe": 2, "contentment": 3, - "disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7 -} - -# Параметры обучения -BATCH_SIZE = 64 -EPOCHS = 50 -LR = 5e-5 -NUM_TRAIN_WORKERS = 62 -NUM_VAL_WORKERS = 62 -PATIENCE = 5 - -def prepare_dataset_index(): - # Загрузка или создание индекса файлов - if CACHE_PATH.exists(): - print(f"Загрузка кэша: {CACHE_PATH.name}") - with open(CACHE_PATH, 'rb') as f: - cache_data = pickle.load(f) - return cache_data['image_paths'], cache_data['labels'] - - print(f"Сканирование директории {DATA_ROOT}...") - paths, labels = [], [] - for img_path in DATA_ROOT.rglob('*.jpg'): - emotion_folder = img_path.parts[-3].lower() - if emotion_folder in CLASS_MAPPING: - paths.append(str(img_path)) - labels.append(CLASS_MAPPING[emotion_folder]) - - with open(CACHE_PATH, 'wb') as f: - pickle.dump({'image_paths': paths, 'labels': labels}, f) - - return paths, labels - -class EmoSetDirectDataset(Dataset): - # Датасет с загрузкой по требованию - def __init__(self, image_paths, labels): - self.image_paths = image_paths - self.labels = labels - # Сохранение пропорций и центрирование - self.base_transform = T.Compose([ - T.Resize(256, antialias=True), - T.CenterCrop(256) - ]) - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, idx): - try: - image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB) - image = image.to(torch.float32) / 255.0 - image = self.base_transform(image) - except Exception: - # Обработка битых файлов - image = torch.zeros((3, 256, 256), dtype=torch.float32) - return image, self.labels[idx] - -def build_gpu_transforms(): - # Аугментации на GPU - train_tf = torch.nn.Sequential( - T.RandomCrop((224, 224)), - T.RandomHorizontalFlip(p=0.5), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ).to(DEVICE) - - val_tf = torch.nn.Sequential( - T.CenterCrop((224, 224)), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ).to(DEVICE) - - return train_tf, val_tf - -if __name__ == "__main__": - print(f"Инициализация. Устройство: {DEVICE}") - - all_paths, all_labels = prepare_dataset_index() - - # Разделение выборки - random.seed(42) - combined = list(zip(all_paths, all_labels)) - random.shuffle(combined) - all_paths, all_labels = zip(*combined) - - split_idx = int(len(all_paths) * 0.95) - - train_loader = DataLoader( - EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]), - batch_size=BATCH_SIZE, shuffle=True, - num_workers=NUM_TRAIN_WORKERS, pin_memory=True, - prefetch_factor=3, persistent_workers=False - ) - - val_loader = DataLoader( - EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]), - batch_size=BATCH_SIZE, shuffle=False, - num_workers=NUM_VAL_WORKERS, pin_memory=True, - prefetch_factor=3, persistent_workers=False - ) - - gpu_train_tf, gpu_val_tf = build_gpu_transforms() - - model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS) - scaler = GradScaler() - - best_val_loss = float('inf') - epochs_no_improve = 0 - start_epoch = 1 - - # Загрузка весов - if RESUME_CHECKPOINT.exists(): - print(f"Восстановление из: {RESUME_CHECKPOINT.name}") - checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE) - model.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - try: - scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - except Exception: - pass - - if 'scaler_state_dict' in checkpoint: - scaler.load_state_dict(checkpoint['scaler_state_dict']) - if 'best_val_loss' in checkpoint: - best_val_loss = checkpoint['best_val_loss'] - start_epoch = checkpoint['epoch'] + 1 - elif PREVIOUS_WEIGHTS.exists(): - print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}") - model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE)) - else: - print("Веса не найдены. Инициализация с ImageNet.") - model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE) - - for epoch in range(start_epoch, EPOCHS + 1): - - # Обучение - model.train() - running_loss, correct, total = 0.0, 0, 0 - - pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0) - for inputs, labels in pbar: - try: - inputs = inputs.to(DEVICE, non_blocking=True) - labels = labels.to(DEVICE, non_blocking=True) - inputs = gpu_train_tf(inputs) - - optimizer.zero_grad(set_to_none=True) - - # Смешанная точность - with autocast(device_type="cuda"): - outputs = model(inputs) - loss = criterion(outputs, labels) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - - running_loss += loss.item() * inputs.size(0) - _, predicted = outputs.max(1) - total += labels.size(0) - correct += predicted.eq(labels).sum().item() - - pbar.set_postfix({'loss': f"{loss.item():.4f}"}) - - except RuntimeError as memory_err: - # Очистка памяти при OOM - if "out of memory" in str(memory_err).lower(): - if 'outputs' in locals(): del outputs - if 'loss' in locals(): del loss - torch.cuda.empty_cache() - optimizer.zero_grad(set_to_none=True) - continue - raise memory_err - - train_loss = running_loss / total if total > 0 else 0 - train_acc = correct / total if total > 0 else 0 - - gc.collect() - torch.cuda.empty_cache() - - # Валидация - model.eval() - val_loss, val_correct, val_total = 0.0, 0, 0 - - with torch.no_grad(): - for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing = 0): - val_inputs = val_inputs.to(DEVICE, non_blocking=True) - val_labels = val_labels.to(DEVICE, non_blocking=True) - val_inputs = gpu_val_tf(val_inputs) - - with autocast(device_type="cuda"): - val_outputs = model(val_inputs) - v_loss = criterion(val_outputs, val_labels) - - val_loss += v_loss.item() * val_inputs.size(0) - _, val_predicted = val_outputs.max(1) - val_total += val_labels.size(0) - val_correct += val_predicted.eq(val_labels).sum().item() - - epoch_val_loss = val_loss / val_total if val_total > 0 else 0 - epoch_val_acc = val_correct / val_total if val_total > 0 else 0 - - scheduler.step() - print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}") - - # Ранняя остановка и сохранение - if epoch_val_loss < best_val_loss: - best_val_loss = epoch_val_loss - epochs_no_improve = 0 - torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth")) - print("Сохранен новый лучший чекпоинт.") - else: - epochs_no_improve += 1 - if epochs_no_improve >= PATIENCE and epoch >= 25: - print(f"Остановка: валидация не улучшается {PATIENCE} эпох.") - break - - # Сохранение состояния - checkpoint_state = { - 'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'scaler_state_dict': scaler.state_dict(), - 'best_val_loss': best_val_loss - } - torch.save(checkpoint_state, RESUME_CHECKPOINT) - gc.collect() - - if SAVE_MODEL_PATH.parent.exists(): - torch.save(model.state_dict(), SAVE_MODEL_PATH) - print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}") - if RESUME_CHECKPOINT.exists(): - RESUME_CHECKPOINT.unlink() \ No newline at end of file diff --git a/src/scripts/removed_31_finetune_2.41M.py b/src/scripts/removed_31_finetune_2.41M.py new file mode 100644 index 0000000..34c0f0e --- /dev/null +++ b/src/scripts/removed_31_finetune_2.41M.py @@ -0,0 +1,319 @@ +import os +import random +import warnings +from collections import defaultdict +from pathlib import Path +from PIL import Image, ImageFile +import pandas as pd +import numpy as np +from tqdm import tqdm + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +import torchvision.transforms as T +from torch.amp import autocast, GradScaler +import timm + +# Подавление предупреждений и защита от битых "хвостов" JPEG +warnings.filterwarnings("ignore") +ImageFile.LOAD_TRUNCATED_IMAGES = True + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Устройство: {DEVICE}") + +# --- ПУТИ --- +TRAIN_ROOT = Path("./dataset/Original-2.41M") +ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения) +VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val") + +SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth") +RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth") +PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth") + +CLASS_MAPPING = { + "amusement": 0, "anger": 1, "awe": 2, "contentment": 3, + "disgust": 4, "excitement": 5, "fear": 6, "sadness": 7 +} + +# --- НАСТРОЙКИ --- +TOTAL_BATCH_SIZE = 64 +BATCH_NOISY = 48 # 75% батча - новые данные 2.41M +BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K + +EPOCHS_PER_FOLDER = 15 +PATIENCE = 5 +LR = 1e-6 +NUM_TRAIN_WORKERS = 32 +NUM_VAL_WORKERS = 32 + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + +# --- 1. ТРАНСФОРМАЦИИ --- +train_transform = T.Compose([ + T.Resize(256), + T.RandomResizedCrop(224, scale=(0.8, 1.0)), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +val_transform = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +# --- 2. ДАТАСЕТЫ --- +class ChunkTrainDataset(Dataset): + def __init__(self, paths, transform): + self.paths = paths + self.transform = transform + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + try: + img = Image.open(path).convert('RGB') + tensor = self.transform(img) + label = CLASS_MAPPING.get(path.parts[-3].lower(), 0) + return tensor, label + except Exception: + return torch.zeros((3, 224, 224)), 0 + +class CsvDataset(Dataset): + def __init__(self, root, transform): + self.root = Path(root) + self.df = pd.read_csv(self.root / "labels.csv") + self.transform = transform + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + path = self.root / "images" / row["filename"] + try: + img = Image.open(path).convert('RGB') + tensor = self.transform(img) + label = CLASS_MAPPING.get(row["label"].lower(), 0) + return tensor, label + except Exception: + return torch.zeros((3, 224, 224)), 0 + +# --- 3. СБОР ДАННЫХ --- +def prepare_chunks(): + print("\nСканирование датасета 2.41M...") + chunk_dict = defaultdict(list) + for path in TRAIN_ROOT.rglob('*.jpg'): + emotion = path.parts[-3].lower() + if emotion not in CLASS_MAPPING: + continue + folder_str = path.parts[-2] + if folder_str.isdigit(): + chunk_dict[int(folder_str)].append(path) + + sorted_chunks = sorted(chunk_dict.keys()) + print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}") + return chunk_dict, sorted_chunks + # --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ --- +if __name__ == "__main__": + chunk_dict, sorted_chunks = prepare_chunks() + + # Валидационный датасет (только чистые данные) + val_loader = DataLoader( + CsvDataset(VAL_118K_ROOT, val_transform), + batch_size=TOTAL_BATCH_SIZE, shuffle=False, + num_workers=NUM_VAL_WORKERS, pin_memory=True + ) + + # ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания) + # Используем prefetch_factor и persistent_workers для устранения рывков CPU + anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform) + anchor_loader = DataLoader( + anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True, + num_workers=16, pin_memory=True, drop_last=True, + prefetch_factor=2, persistent_workers=False + ) + + # Инициализация модели + model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE) + if PRETRAINED_PATH.exists(): + model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE)) + print(f"Базовые веса загружены из {PRETRAINED_PATH.name}") + + # Размораживаем всю модель + for param in model.parameters(): + param.requires_grad = True + + # Дифференцированный оптимизатор + backbone_params = [p for n, p in model.named_parameters() if "fc" not in n] + fc_params = [p for n, p in model.named_parameters() if "fc" in n] + + optimizer = torch.optim.AdamW([ + {'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы + {'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора + ], weight_decay=1e-3) + + # Label Smoothing помогает игнорировать мусор в разметке 2.41M + criterion = nn.CrossEntropyLoss(label_smoothing=0.15) + scaler = GradScaler() + + # --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ --- + start_stage = 0 + start_epoch = 1 + best_val_loss = float('inf') + + if RESUME_CHECKPOINT.exists(): + print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...") + checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE) + model.load_state_dict(checkpoint['model_state_dict']) + + try: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + except Exception as e: + print(f"Оптимизатор сброшен: {e}") + + best_val_loss = checkpoint['best_val_loss'] + start_stage = checkpoint['stage'] + start_epoch = checkpoint['epoch'] + 1 + print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n") + else: + # --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) --- + # Выполняется только если мы начинаем с нуля + print("\n[Проверка базовых весов перед обучением (Epoch 0)]") + model.eval() + val_loss, val_correct, val_total = 0.0, 0, 0 + with torch.no_grad(): + for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0): + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + with autocast(device_type="cuda"): + outputs = model(inputs) + v_loss = criterion(outputs, labels) + val_loss += v_loss.item() * inputs.size(0) + _, pred = outputs.max(1) + val_total += labels.size(0) + val_correct += pred.eq(labels).sum().item() + + best_val_loss = val_loss / val_total + baseline_acc = val_correct / val_total + print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n") + + # ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ + current_train_paths = [] + for s in range(start_stage): + current_train_paths.extend(chunk_dict[sorted_chunks[s]]) + + print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).") + + # ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ + for stage in range(start_stage, len(sorted_chunks)): + chunk_id = sorted_chunks[stage] + print(f"\n{'='*50}") + print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'") + + # Накопление и перемешивание + current_train_paths.extend(chunk_dict[chunk_id]) + random.shuffle(current_train_paths) + print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}") + + # ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH + train_loader = DataLoader( + ChunkTrainDataset(current_train_paths, train_transform), + batch_size=BATCH_NOISY, shuffle=True, + num_workers=NUM_TRAIN_WORKERS, pin_memory=True, + worker_init_fn=worker_init_fn, drop_last=True, + prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU + ) + + epochs_no_improve = 0 + first_epoch = start_epoch if stage == start_stage else 1 + + # Инициализация итератора якорей + anchor_iter = iter(anchor_loader) + + # ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА + for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1): + model.train() + train_loss, train_correct, train_total = 0.0, 0, 0 + + for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0): + + # Достаем якорный чистый батч + try: + anc_inputs, anc_labels = next(anchor_iter) + except StopIteration: + anchor_iter = iter(anchor_loader) + anc_inputs, anc_labels = next(anchor_iter) + + # СМЕШИВАЕМ БАТЧИ (Грязные + Чистые) + inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE) + labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE) + + optimizer.zero_grad(set_to_none=True) + with autocast(device_type="cuda"): + outputs = model(inputs) + loss = criterion(outputs, labels) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + train_loss += loss.item() * inputs.size(0) + _, pred = outputs.max(1) + train_total += labels.size(0) + train_correct += pred.eq(labels).sum().item() + + # ВАЛИДАЦИЯ + model.eval() + val_loss, val_correct, val_total = 0.0, 0, 0 + with torch.no_grad(): + for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0): + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + with autocast(device_type="cuda"): + outputs = model(inputs) + v_loss = criterion(outputs, labels) + val_loss += v_loss.item() * inputs.size(0) + _, pred = outputs.max(1) + val_total += labels.size(0) + val_correct += pred.eq(labels).sum().item() + + avg_train_loss = train_loss / train_total + avg_train_acc = train_correct / train_total + avg_val_loss = val_loss / val_total + avg_val_acc = val_correct / val_total + + print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}") + + # СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + epochs_no_improve = 0 + torch.save(model.state_dict(), SAVE_MODEL_PATH) + print("--> Обновлены лучшие веса") + else: + epochs_no_improve += 1 + + # АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ + checkpoint_state = { + 'stage': stage, + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'best_val_loss': best_val_loss + } + torch.save(checkpoint_state, RESUME_CHECKPOINT) + os.sync() # Защита от отключения электричества + print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.") + + # РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА + if epochs_no_improve >= PATIENCE: + print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...") + break + + # Сброс счетчика стартовой эпохи после прохождения восстановительного этапа + start_epoch = 1 \ No newline at end of file