feat: commiting model

This commit is contained in:
zin
2026-06-08 14:49:45 +00:00
parent daba573b2c
commit 14968dd4d4
7 changed files with 505 additions and 269 deletions
Binary file not shown.
Binary file not shown.
+1
View File
@@ -1,5 +1,6 @@
#!/bin/bash #!/bin/bash
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
# Остановка скрипта при возникновении любой ошибки # Остановка скрипта при возникновении любой ошибки
set -e set -e
+1 -1
View File
@@ -44,7 +44,7 @@
"BATCH_SIZE = 64\n", "BATCH_SIZE = 64\n",
"EPOCHS = 15\n", "EPOCHS = 15\n",
"LR = 3e-4\n", "LR = 3e-4\n",
"NUM_WORKERS = 40\n", "NUM_WORKERS = 62\n",
"\n", "\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Аппаратное ускорение: {device}\")" "print(f\"Аппаратное ускорение: {device}\")"
+184
View File
@@ -0,0 +1,184 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
# Подавление предупреждений цветовых профилей
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройки окружения
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
NUM_WORKERS = 62
PATIENCE = 7
# Маппинг классов
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Фиксация генераторов псевдослучайных чисел
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Трансформации
base_tf = [
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
train_transform = T.Compose([
T.Resize(256, antialias=True),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
*base_tf
])
val_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(224),
*base_tf
])
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# Инициализация модели и оптимизатора
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Логика эпохи обучения
def train_epoch(current_model, loader):
current_model.train()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
logits = current_model(imgs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
# Логика эпохи валидации
@torch.no_grad()
def val_epoch(current_model, loader):
current_model.eval()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
logits = current_model(imgs)
loss = criterion(logits, labels)
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
if __name__ == "__main__":
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_path = "./emosetV2_resnet50_best.pth"
print("Старт обучения.")
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = train_epoch(model, train_loader)
val_loss, val_acc = val_epoch(model, val_loader)
scheduler.step()
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
# Сохранение лучших весов по Accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), checkpoint_path)
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
# Оценка переобучения по Loss (Early Stopping)
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
break
print("Процесс завершен.")
-268
View File
@@ -1,268 +0,0 @@
import os
import gc
import pickle
import random
import ctypes
import warnings
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm
# Подавление предупреждений PIL для корректной работы tqdm
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройка устройства
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Пути к файлам
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
}
# Параметры обучения
BATCH_SIZE = 64
EPOCHS = 50
LR = 5e-5
NUM_TRAIN_WORKERS = 62
NUM_VAL_WORKERS = 62
PATIENCE = 5
def prepare_dataset_index():
# Загрузка или создание индекса файлов
if CACHE_PATH.exists():
print(f"Загрузка кэша: {CACHE_PATH.name}")
with open(CACHE_PATH, 'rb') as f:
cache_data = pickle.load(f)
return cache_data['image_paths'], cache_data['labels']
print(f"Сканирование директории {DATA_ROOT}...")
paths, labels = [], []
for img_path in DATA_ROOT.rglob('*.jpg'):
emotion_folder = img_path.parts[-3].lower()
if emotion_folder in CLASS_MAPPING:
paths.append(str(img_path))
labels.append(CLASS_MAPPING[emotion_folder])
with open(CACHE_PATH, 'wb') as f:
pickle.dump({'image_paths': paths, 'labels': labels}, f)
return paths, labels
class EmoSetDirectDataset(Dataset):
# Датасет с загрузкой по требованию
def __init__(self, image_paths, labels):
self.image_paths = image_paths
self.labels = labels
# Сохранение пропорций и центрирование
self.base_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(256)
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
try:
image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
image = image.to(torch.float32) / 255.0
image = self.base_transform(image)
except Exception:
# Обработка битых файлов
image = torch.zeros((3, 256, 256), dtype=torch.float32)
return image, self.labels[idx]
def build_gpu_transforms():
# Аугментации на GPU
train_tf = torch.nn.Sequential(
T.RandomCrop((224, 224)),
T.RandomHorizontalFlip(p=0.5),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(DEVICE)
val_tf = torch.nn.Sequential(
T.CenterCrop((224, 224)),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(DEVICE)
return train_tf, val_tf
if __name__ == "__main__":
print(f"Инициализация. Устройство: {DEVICE}")
all_paths, all_labels = prepare_dataset_index()
# Разделение выборки
random.seed(42)
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
all_paths, all_labels = zip(*combined)
split_idx = int(len(all_paths) * 0.95)
train_loader = DataLoader(
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
prefetch_factor=3, persistent_workers=False
)
val_loader = DataLoader(
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_VAL_WORKERS, pin_memory=True,
prefetch_factor=3, persistent_workers=False
)
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = GradScaler()
best_val_loss = float('inf')
epochs_no_improve = 0
start_epoch = 1
# Загрузка весов
if RESUME_CHECKPOINT.exists():
print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
try:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
except Exception:
pass
if 'scaler_state_dict' in checkpoint:
scaler.load_state_dict(checkpoint['scaler_state_dict'])
if 'best_val_loss' in checkpoint:
best_val_loss = checkpoint['best_val_loss']
start_epoch = checkpoint['epoch'] + 1
elif PREVIOUS_WEIGHTS.exists():
print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
else:
print("Веса не найдены. Инициализация с ImageNet.")
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
for epoch in range(start_epoch, EPOCHS + 1):
# Обучение
model.train()
running_loss, correct, total = 0.0, 0, 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0)
for inputs, labels in pbar:
try:
inputs = inputs.to(DEVICE, non_blocking=True)
labels = labels.to(DEVICE, non_blocking=True)
inputs = gpu_train_tf(inputs)
optimizer.zero_grad(set_to_none=True)
# Смешанная точность
with autocast(device_type="cuda"):
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
running_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
except RuntimeError as memory_err:
# Очистка памяти при OOM
if "out of memory" in str(memory_err).lower():
if 'outputs' in locals(): del outputs
if 'loss' in locals(): del loss
torch.cuda.empty_cache()
optimizer.zero_grad(set_to_none=True)
continue
raise memory_err
train_loss = running_loss / total if total > 0 else 0
train_acc = correct / total if total > 0 else 0
gc.collect()
torch.cuda.empty_cache()
# Валидация
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing = 0):
val_inputs = val_inputs.to(DEVICE, non_blocking=True)
val_labels = val_labels.to(DEVICE, non_blocking=True)
val_inputs = gpu_val_tf(val_inputs)
with autocast(device_type="cuda"):
val_outputs = model(val_inputs)
v_loss = criterion(val_outputs, val_labels)
val_loss += v_loss.item() * val_inputs.size(0)
_, val_predicted = val_outputs.max(1)
val_total += val_labels.size(0)
val_correct += val_predicted.eq(val_labels).sum().item()
epoch_val_loss = val_loss / val_total if val_total > 0 else 0
epoch_val_acc = val_correct / val_total if val_total > 0 else 0
scheduler.step()
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
# Ранняя остановка и сохранение
if epoch_val_loss < best_val_loss:
best_val_loss = epoch_val_loss
epochs_no_improve = 0
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
print("Сохранен новый лучший чекпоинт.")
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE and epoch >= 25:
print(f"Остановка: валидация не улучшается {PATIENCE} эпох.")
break
# Сохранение состояния
checkpoint_state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}
torch.save(checkpoint_state, RESUME_CHECKPOINT)
gc.collect()
if SAVE_MODEL_PATH.parent.exists():
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")
if RESUME_CHECKPOINT.exists():
RESUME_CHECKPOINT.unlink()
+319
View File
@@ -0,0 +1,319 @@
import os
import random
import warnings
from collections import defaultdict
from pathlib import Path
from PIL import Image, ImageFile
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.amp import autocast, GradScaler
import timm
# Подавление предупреждений и защита от битых "хвостов" JPEG
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# --- ПУТИ ---
TRAIN_ROOT = Path("./dataset/Original-2.41M")
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения)
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# --- НАСТРОЙКИ ---
TOTAL_BATCH_SIZE = 64
BATCH_NOISY = 48 # 75% батча - новые данные 2.41M
BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K
EPOCHS_PER_FOLDER = 15
PATIENCE = 5
LR = 1e-6
NUM_TRAIN_WORKERS = 32
NUM_VAL_WORKERS = 32
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
# --- 1. ТРАНСФОРМАЦИИ ---
train_transform = T.Compose([
T.Resize(256),
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# --- 2. ДАТАСЕТЫ ---
class ChunkTrainDataset(Dataset):
def __init__(self, paths, transform):
self.paths = paths
self.transform = transform
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
path = self.paths[idx]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(path.parts[-3].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
class CsvDataset(Dataset):
def __init__(self, root, transform):
self.root = Path(root)
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
path = self.root / "images" / row["filename"]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(row["label"].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
# --- 3. СБОР ДАННЫХ ---
def prepare_chunks():
print("\nСканирование датасета 2.41M...")
chunk_dict = defaultdict(list)
for path in TRAIN_ROOT.rglob('*.jpg'):
emotion = path.parts[-3].lower()
if emotion not in CLASS_MAPPING:
continue
folder_str = path.parts[-2]
if folder_str.isdigit():
chunk_dict[int(folder_str)].append(path)
sorted_chunks = sorted(chunk_dict.keys())
print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
return chunk_dict, sorted_chunks
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
if __name__ == "__main__":
chunk_dict, sorted_chunks = prepare_chunks()
# Валидационный датасет (только чистые данные)
val_loader = DataLoader(
CsvDataset(VAL_118K_ROOT, val_transform),
batch_size=TOTAL_BATCH_SIZE, shuffle=False,
num_workers=NUM_VAL_WORKERS, pin_memory=True
)
# ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания)
# Используем prefetch_factor и persistent_workers для устранения рывков CPU
anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
anchor_loader = DataLoader(
anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
num_workers=16, pin_memory=True, drop_last=True,
prefetch_factor=2, persistent_workers=False
)
# Инициализация модели
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
if PRETRAINED_PATH.exists():
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")
# Размораживаем всю модель
for param in model.parameters():
param.requires_grad = True
# Дифференцированный оптимизатор
backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
fc_params = [p for n, p in model.named_parameters() if "fc" in n]
optimizer = torch.optim.AdamW([
{'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы
{'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора
], weight_decay=1e-3)
# Label Smoothing помогает игнорировать мусор в разметке 2.41M
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
scaler = GradScaler()
# --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ ---
start_stage = 0
start_epoch = 1
best_val_loss = float('inf')
if RESUME_CHECKPOINT.exists():
print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
try:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
except Exception as e:
print(f"Оптимизатор сброшен: {e}")
best_val_loss = checkpoint['best_val_loss']
start_stage = checkpoint['stage']
start_epoch = checkpoint['epoch'] + 1
print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
else:
# --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) ---
# Выполняется только если мы начинаем с нуля
print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
best_val_loss = val_loss / val_total
baseline_acc = val_correct / val_total
print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")
# ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ
current_train_paths = []
for s in range(start_stage):
current_train_paths.extend(chunk_dict[sorted_chunks[s]])
print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")
# ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ
for stage in range(start_stage, len(sorted_chunks)):
chunk_id = sorted_chunks[stage]
print(f"\n{'='*50}")
print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")
# Накопление и перемешивание
current_train_paths.extend(chunk_dict[chunk_id])
random.shuffle(current_train_paths)
print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")
# ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH
train_loader = DataLoader(
ChunkTrainDataset(current_train_paths, train_transform),
batch_size=BATCH_NOISY, shuffle=True,
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
worker_init_fn=worker_init_fn, drop_last=True,
prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU
)
epochs_no_improve = 0
first_epoch = start_epoch if stage == start_stage else 1
# Инициализация итератора якорей
anchor_iter = iter(anchor_loader)
# ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА
for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
model.train()
train_loss, train_correct, train_total = 0.0, 0, 0
for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
# Достаем якорный чистый батч
try:
anc_inputs, anc_labels = next(anchor_iter)
except StopIteration:
anchor_iter = iter(anchor_loader)
anc_inputs, anc_labels = next(anchor_iter)
# СМЕШИВАЕМ БАТЧИ (Грязные + Чистые)
inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)
optimizer.zero_grad(set_to_none=True)
with autocast(device_type="cuda"):
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item() * inputs.size(0)
_, pred = outputs.max(1)
train_total += labels.size(0)
train_correct += pred.eq(labels).sum().item()
# ВАЛИДАЦИЯ
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
avg_train_loss = train_loss / train_total
avg_train_acc = train_correct / train_total
avg_val_loss = val_loss / val_total
avg_val_acc = val_correct / val_total
print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")
# СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
epochs_no_improve = 0
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print("--> Обновлены лучшие веса")
else:
epochs_no_improve += 1
# АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ
checkpoint_state = {
'stage': stage,
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss
}
torch.save(checkpoint_state, RESUME_CHECKPOINT)
os.sync() # Защита от отключения электричества
print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")
# РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
break
# Сброс счетчика стартовой эпохи после прохождения восстановительного этапа
start_epoch = 1