feat: commiting model
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
|
||||||
# Остановка скрипта при возникновении любой ошибки
|
# Остановка скрипта при возникновении любой ошибки
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
"BATCH_SIZE = 64\n",
|
"BATCH_SIZE = 64\n",
|
||||||
"EPOCHS = 15\n",
|
"EPOCHS = 15\n",
|
||||||
"LR = 3e-4\n",
|
"LR = 3e-4\n",
|
||||||
"NUM_WORKERS = 40\n",
|
"NUM_WORKERS = 62\n",
|
||||||
"\n",
|
"\n",
|
||||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"print(f\"Аппаратное ускорение: {device}\")"
|
"print(f\"Аппаратное ускорение: {device}\")"
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import timm
|
||||||
|
|
||||||
|
# Подавление предупреждений цветовых профилей
|
||||||
|
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
|
||||||
|
|
||||||
|
# Настройки окружения
|
||||||
|
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
EPOCHS = 30
|
||||||
|
LR = 5e-5
|
||||||
|
NUM_WORKERS = 62
|
||||||
|
PATIENCE = 7
|
||||||
|
|
||||||
|
# Маппинг классов
|
||||||
|
CLASS_MAPPING = {
|
||||||
|
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||||
|
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
|
||||||
|
}
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Устройство: {DEVICE}")
|
||||||
|
|
||||||
|
# Фиксация генераторов псевдослучайных чисел
|
||||||
|
def set_seed(seed=42):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
set_seed()
|
||||||
|
|
||||||
|
# Инициализация структур данных
|
||||||
|
class EmoSetDataset(Dataset):
|
||||||
|
def __init__(self, root: Path | str, split: str, transform=None):
|
||||||
|
self.root = Path(root) / split
|
||||||
|
self.df = pd.read_csv(self.root / "labels.csv")
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
# Фильтрация датафрейма
|
||||||
|
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.df)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
row = self.df.iloc[idx]
|
||||||
|
img_path = self.root / "images" / row["filename"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
img = Image.open(img_path).convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
img = Image.new("RGB", (256, 256), (0, 0, 0))
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
img_tensor = self.transform(img)
|
||||||
|
else:
|
||||||
|
img_tensor = T.ToTensor()(img)
|
||||||
|
|
||||||
|
label_idx = CLASS_MAPPING[row["label"]]
|
||||||
|
return img_tensor, label_idx
|
||||||
|
|
||||||
|
# Трансформации
|
||||||
|
base_tf = [
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
]
|
||||||
|
|
||||||
|
train_transform = T.Compose([
|
||||||
|
T.Resize(256, antialias=True),
|
||||||
|
T.RandomCrop(224),
|
||||||
|
T.RandomHorizontalFlip(),
|
||||||
|
*base_tf
|
||||||
|
])
|
||||||
|
|
||||||
|
val_transform = T.Compose([
|
||||||
|
T.Resize(256, antialias=True),
|
||||||
|
T.CenterCrop(224),
|
||||||
|
*base_tf
|
||||||
|
])
|
||||||
|
|
||||||
|
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
|
||||||
|
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
|
||||||
|
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
|
||||||
|
|
||||||
|
# Инициализация модели и оптимизатора
|
||||||
|
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
|
||||||
|
model.to(DEVICE)
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
||||||
|
|
||||||
|
# Логика эпохи обучения
|
||||||
|
def train_epoch(current_model, loader):
|
||||||
|
current_model.train()
|
||||||
|
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||||
|
|
||||||
|
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
|
||||||
|
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
logits = current_model(imgs)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * imgs.size(0)
|
||||||
|
preds = logits.argmax(dim=1)
|
||||||
|
correct_preds += (preds == labels).sum().item()
|
||||||
|
total_samples += labels.size(0)
|
||||||
|
|
||||||
|
return total_loss / total_samples, correct_preds / total_samples
|
||||||
|
|
||||||
|
# Логика эпохи валидации
|
||||||
|
@torch.no_grad()
|
||||||
|
def val_epoch(current_model, loader):
|
||||||
|
current_model.eval()
|
||||||
|
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||||
|
|
||||||
|
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
|
||||||
|
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
|
||||||
|
logits = current_model(imgs)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
|
||||||
|
total_loss += loss.item() * imgs.size(0)
|
||||||
|
preds = logits.argmax(dim=1)
|
||||||
|
correct_preds += (preds == labels).sum().item()
|
||||||
|
total_samples += labels.size(0)
|
||||||
|
|
||||||
|
return total_loss / total_samples, correct_preds / total_samples
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
best_val_acc = 0.0
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
epochs_no_improve = 0
|
||||||
|
checkpoint_path = "./emosetV2_resnet50_best.pth"
|
||||||
|
|
||||||
|
print("Старт обучения.")
|
||||||
|
|
||||||
|
for epoch in range(1, EPOCHS + 1):
|
||||||
|
train_loss, train_acc = train_epoch(model, train_loader)
|
||||||
|
val_loss, val_acc = val_epoch(model, val_loader)
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
|
||||||
|
|
||||||
|
# Сохранение лучших весов по Accuracy
|
||||||
|
if val_acc > best_val_acc:
|
||||||
|
best_val_acc = val_acc
|
||||||
|
torch.save(model.state_dict(), checkpoint_path)
|
||||||
|
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
|
||||||
|
|
||||||
|
# Оценка переобучения по Loss (Early Stopping)
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
epochs_no_improve = 0
|
||||||
|
else:
|
||||||
|
epochs_no_improve += 1
|
||||||
|
if epochs_no_improve >= PATIENCE:
|
||||||
|
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Процесс завершен.")
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
import os
|
|
||||||
import gc
|
|
||||||
import pickle
|
|
||||||
import random
|
|
||||||
import ctypes
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import torchvision.io as tv_io
|
|
||||||
from torch.amp import autocast, GradScaler
|
|
||||||
from tqdm import tqdm
|
|
||||||
import timm
|
|
||||||
|
|
||||||
# Подавление предупреждений PIL для корректной работы tqdm
|
|
||||||
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
|
|
||||||
|
|
||||||
# Настройка устройства
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# Пути к файлам
|
|
||||||
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
|
|
||||||
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
|
|
||||||
|
|
||||||
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
|
|
||||||
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
|
|
||||||
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")
|
|
||||||
|
|
||||||
CLASS_MAPPING = {
|
|
||||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
|
||||||
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
|
|
||||||
}
|
|
||||||
|
|
||||||
# Параметры обучения
|
|
||||||
BATCH_SIZE = 64
|
|
||||||
EPOCHS = 50
|
|
||||||
LR = 5e-5
|
|
||||||
NUM_TRAIN_WORKERS = 62
|
|
||||||
NUM_VAL_WORKERS = 62
|
|
||||||
PATIENCE = 5
|
|
||||||
|
|
||||||
def prepare_dataset_index():
|
|
||||||
# Загрузка или создание индекса файлов
|
|
||||||
if CACHE_PATH.exists():
|
|
||||||
print(f"Загрузка кэша: {CACHE_PATH.name}")
|
|
||||||
with open(CACHE_PATH, 'rb') as f:
|
|
||||||
cache_data = pickle.load(f)
|
|
||||||
return cache_data['image_paths'], cache_data['labels']
|
|
||||||
|
|
||||||
print(f"Сканирование директории {DATA_ROOT}...")
|
|
||||||
paths, labels = [], []
|
|
||||||
for img_path in DATA_ROOT.rglob('*.jpg'):
|
|
||||||
emotion_folder = img_path.parts[-3].lower()
|
|
||||||
if emotion_folder in CLASS_MAPPING:
|
|
||||||
paths.append(str(img_path))
|
|
||||||
labels.append(CLASS_MAPPING[emotion_folder])
|
|
||||||
|
|
||||||
with open(CACHE_PATH, 'wb') as f:
|
|
||||||
pickle.dump({'image_paths': paths, 'labels': labels}, f)
|
|
||||||
|
|
||||||
return paths, labels
|
|
||||||
|
|
||||||
class EmoSetDirectDataset(Dataset):
|
|
||||||
# Датасет с загрузкой по требованию
|
|
||||||
def __init__(self, image_paths, labels):
|
|
||||||
self.image_paths = image_paths
|
|
||||||
self.labels = labels
|
|
||||||
# Сохранение пропорций и центрирование
|
|
||||||
self.base_transform = T.Compose([
|
|
||||||
T.Resize(256, antialias=True),
|
|
||||||
T.CenterCrop(256)
|
|
||||||
])
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.image_paths)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
try:
|
|
||||||
image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
|
|
||||||
image = image.to(torch.float32) / 255.0
|
|
||||||
image = self.base_transform(image)
|
|
||||||
except Exception:
|
|
||||||
# Обработка битых файлов
|
|
||||||
image = torch.zeros((3, 256, 256), dtype=torch.float32)
|
|
||||||
return image, self.labels[idx]
|
|
||||||
|
|
||||||
def build_gpu_transforms():
|
|
||||||
# Аугментации на GPU
|
|
||||||
train_tf = torch.nn.Sequential(
|
|
||||||
T.RandomCrop((224, 224)),
|
|
||||||
T.RandomHorizontalFlip(p=0.5),
|
|
||||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
||||||
).to(DEVICE)
|
|
||||||
|
|
||||||
val_tf = torch.nn.Sequential(
|
|
||||||
T.CenterCrop((224, 224)),
|
|
||||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
||||||
).to(DEVICE)
|
|
||||||
|
|
||||||
return train_tf, val_tf
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(f"Инициализация. Устройство: {DEVICE}")
|
|
||||||
|
|
||||||
all_paths, all_labels = prepare_dataset_index()
|
|
||||||
|
|
||||||
# Разделение выборки
|
|
||||||
random.seed(42)
|
|
||||||
combined = list(zip(all_paths, all_labels))
|
|
||||||
random.shuffle(combined)
|
|
||||||
all_paths, all_labels = zip(*combined)
|
|
||||||
|
|
||||||
split_idx = int(len(all_paths) * 0.95)
|
|
||||||
|
|
||||||
train_loader = DataLoader(
|
|
||||||
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
|
|
||||||
batch_size=BATCH_SIZE, shuffle=True,
|
|
||||||
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
|
||||||
prefetch_factor=3, persistent_workers=False
|
|
||||||
)
|
|
||||||
|
|
||||||
val_loader = DataLoader(
|
|
||||||
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
|
|
||||||
batch_size=BATCH_SIZE, shuffle=False,
|
|
||||||
num_workers=NUM_VAL_WORKERS, pin_memory=True,
|
|
||||||
prefetch_factor=3, persistent_workers=False
|
|
||||||
)
|
|
||||||
|
|
||||||
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
|
|
||||||
|
|
||||||
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
|
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
|
||||||
scaler = GradScaler()
|
|
||||||
|
|
||||||
best_val_loss = float('inf')
|
|
||||||
epochs_no_improve = 0
|
|
||||||
start_epoch = 1
|
|
||||||
|
|
||||||
# Загрузка весов
|
|
||||||
if RESUME_CHECKPOINT.exists():
|
|
||||||
print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
|
|
||||||
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
||||||
|
|
||||||
try:
|
|
||||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if 'scaler_state_dict' in checkpoint:
|
|
||||||
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
||||||
if 'best_val_loss' in checkpoint:
|
|
||||||
best_val_loss = checkpoint['best_val_loss']
|
|
||||||
start_epoch = checkpoint['epoch'] + 1
|
|
||||||
elif PREVIOUS_WEIGHTS.exists():
|
|
||||||
print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
|
|
||||||
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
|
|
||||||
else:
|
|
||||||
print("Веса не найдены. Инициализация с ImageNet.")
|
|
||||||
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
|
|
||||||
|
|
||||||
for epoch in range(start_epoch, EPOCHS + 1):
|
|
||||||
|
|
||||||
# Обучение
|
|
||||||
model.train()
|
|
||||||
running_loss, correct, total = 0.0, 0, 0
|
|
||||||
|
|
||||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0)
|
|
||||||
for inputs, labels in pbar:
|
|
||||||
try:
|
|
||||||
inputs = inputs.to(DEVICE, non_blocking=True)
|
|
||||||
labels = labels.to(DEVICE, non_blocking=True)
|
|
||||||
inputs = gpu_train_tf(inputs)
|
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
|
|
||||||
# Смешанная точность
|
|
||||||
with autocast(device_type="cuda"):
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
|
|
||||||
running_loss += loss.item() * inputs.size(0)
|
|
||||||
_, predicted = outputs.max(1)
|
|
||||||
total += labels.size(0)
|
|
||||||
correct += predicted.eq(labels).sum().item()
|
|
||||||
|
|
||||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
|
||||||
|
|
||||||
except RuntimeError as memory_err:
|
|
||||||
# Очистка памяти при OOM
|
|
||||||
if "out of memory" in str(memory_err).lower():
|
|
||||||
if 'outputs' in locals(): del outputs
|
|
||||||
if 'loss' in locals(): del loss
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
continue
|
|
||||||
raise memory_err
|
|
||||||
|
|
||||||
train_loss = running_loss / total if total > 0 else 0
|
|
||||||
train_acc = correct / total if total > 0 else 0
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Валидация
|
|
||||||
model.eval()
|
|
||||||
val_loss, val_correct, val_total = 0.0, 0, 0
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing = 0):
|
|
||||||
val_inputs = val_inputs.to(DEVICE, non_blocking=True)
|
|
||||||
val_labels = val_labels.to(DEVICE, non_blocking=True)
|
|
||||||
val_inputs = gpu_val_tf(val_inputs)
|
|
||||||
|
|
||||||
with autocast(device_type="cuda"):
|
|
||||||
val_outputs = model(val_inputs)
|
|
||||||
v_loss = criterion(val_outputs, val_labels)
|
|
||||||
|
|
||||||
val_loss += v_loss.item() * val_inputs.size(0)
|
|
||||||
_, val_predicted = val_outputs.max(1)
|
|
||||||
val_total += val_labels.size(0)
|
|
||||||
val_correct += val_predicted.eq(val_labels).sum().item()
|
|
||||||
|
|
||||||
epoch_val_loss = val_loss / val_total if val_total > 0 else 0
|
|
||||||
epoch_val_acc = val_correct / val_total if val_total > 0 else 0
|
|
||||||
|
|
||||||
scheduler.step()
|
|
||||||
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
|
|
||||||
|
|
||||||
# Ранняя остановка и сохранение
|
|
||||||
if epoch_val_loss < best_val_loss:
|
|
||||||
best_val_loss = epoch_val_loss
|
|
||||||
epochs_no_improve = 0
|
|
||||||
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
|
|
||||||
print("Сохранен новый лучший чекпоинт.")
|
|
||||||
else:
|
|
||||||
epochs_no_improve += 1
|
|
||||||
if epochs_no_improve >= PATIENCE and epoch >= 25:
|
|
||||||
print(f"Остановка: валидация не улучшается {PATIENCE} эпох.")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Сохранение состояния
|
|
||||||
checkpoint_state = {
|
|
||||||
'epoch': epoch,
|
|
||||||
'model_state_dict': model.state_dict(),
|
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
|
||||||
'scheduler_state_dict': scheduler.state_dict(),
|
|
||||||
'scaler_state_dict': scaler.state_dict(),
|
|
||||||
'best_val_loss': best_val_loss
|
|
||||||
}
|
|
||||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
if SAVE_MODEL_PATH.parent.exists():
|
|
||||||
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
|
||||||
print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")
|
|
||||||
if RESUME_CHECKPOINT.exists():
|
|
||||||
RESUME_CHECKPOINT.unlink()
|
|
||||||
@@ -0,0 +1,319 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from torch.amp import autocast, GradScaler
|
||||||
|
import timm
|
||||||
|
|
||||||
|
# Подавление предупреждений и защита от битых "хвостов" JPEG
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Устройство: {DEVICE}")
|
||||||
|
|
||||||
|
# --- ПУТИ ---
|
||||||
|
TRAIN_ROOT = Path("./dataset/Original-2.41M")
|
||||||
|
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения)
|
||||||
|
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
|
||||||
|
|
||||||
|
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
|
||||||
|
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
|
||||||
|
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
|
||||||
|
|
||||||
|
CLASS_MAPPING = {
|
||||||
|
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||||
|
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- НАСТРОЙКИ ---
|
||||||
|
TOTAL_BATCH_SIZE = 64
|
||||||
|
BATCH_NOISY = 48 # 75% батча - новые данные 2.41M
|
||||||
|
BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K
|
||||||
|
|
||||||
|
EPOCHS_PER_FOLDER = 15
|
||||||
|
PATIENCE = 5
|
||||||
|
LR = 1e-6
|
||||||
|
NUM_TRAIN_WORKERS = 32
|
||||||
|
NUM_VAL_WORKERS = 32
|
||||||
|
|
||||||
|
def worker_init_fn(worker_id):
|
||||||
|
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||||
|
|
||||||
|
# --- 1. ТРАНСФОРМАЦИИ ---
|
||||||
|
train_transform = T.Compose([
|
||||||
|
T.Resize(256),
|
||||||
|
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
|
||||||
|
T.RandomHorizontalFlip(),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
val_transform = T.Compose([
|
||||||
|
T.Resize(256),
|
||||||
|
T.CenterCrop(224),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# --- 2. ДАТАСЕТЫ ---
|
||||||
|
class ChunkTrainDataset(Dataset):
|
||||||
|
def __init__(self, paths, transform):
|
||||||
|
self.paths = paths
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
path = self.paths[idx]
|
||||||
|
try:
|
||||||
|
img = Image.open(path).convert('RGB')
|
||||||
|
tensor = self.transform(img)
|
||||||
|
label = CLASS_MAPPING.get(path.parts[-3].lower(), 0)
|
||||||
|
return tensor, label
|
||||||
|
except Exception:
|
||||||
|
return torch.zeros((3, 224, 224)), 0
|
||||||
|
|
||||||
|
class CsvDataset(Dataset):
|
||||||
|
def __init__(self, root, transform):
|
||||||
|
self.root = Path(root)
|
||||||
|
self.df = pd.read_csv(self.root / "labels.csv")
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.df)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
row = self.df.iloc[idx]
|
||||||
|
path = self.root / "images" / row["filename"]
|
||||||
|
try:
|
||||||
|
img = Image.open(path).convert('RGB')
|
||||||
|
tensor = self.transform(img)
|
||||||
|
label = CLASS_MAPPING.get(row["label"].lower(), 0)
|
||||||
|
return tensor, label
|
||||||
|
except Exception:
|
||||||
|
return torch.zeros((3, 224, 224)), 0
|
||||||
|
|
||||||
|
# --- 3. СБОР ДАННЫХ ---
|
||||||
|
def prepare_chunks():
|
||||||
|
print("\nСканирование датасета 2.41M...")
|
||||||
|
chunk_dict = defaultdict(list)
|
||||||
|
for path in TRAIN_ROOT.rglob('*.jpg'):
|
||||||
|
emotion = path.parts[-3].lower()
|
||||||
|
if emotion not in CLASS_MAPPING:
|
||||||
|
continue
|
||||||
|
folder_str = path.parts[-2]
|
||||||
|
if folder_str.isdigit():
|
||||||
|
chunk_dict[int(folder_str)].append(path)
|
||||||
|
|
||||||
|
sorted_chunks = sorted(chunk_dict.keys())
|
||||||
|
print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
|
||||||
|
return chunk_dict, sorted_chunks
|
||||||
|
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
|
||||||
|
if __name__ == "__main__":
|
||||||
|
chunk_dict, sorted_chunks = prepare_chunks()
|
||||||
|
|
||||||
|
# Валидационный датасет (только чистые данные)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
CsvDataset(VAL_118K_ROOT, val_transform),
|
||||||
|
batch_size=TOTAL_BATCH_SIZE, shuffle=False,
|
||||||
|
num_workers=NUM_VAL_WORKERS, pin_memory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания)
|
||||||
|
# Используем prefetch_factor и persistent_workers для устранения рывков CPU
|
||||||
|
anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
|
||||||
|
anchor_loader = DataLoader(
|
||||||
|
anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
|
||||||
|
num_workers=16, pin_memory=True, drop_last=True,
|
||||||
|
prefetch_factor=2, persistent_workers=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Инициализация модели
|
||||||
|
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
|
||||||
|
if PRETRAINED_PATH.exists():
|
||||||
|
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
|
||||||
|
print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")
|
||||||
|
|
||||||
|
# Размораживаем всю модель
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
|
# Дифференцированный оптимизатор
|
||||||
|
backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
|
||||||
|
fc_params = [p for n, p in model.named_parameters() if "fc" in n]
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW([
|
||||||
|
{'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы
|
||||||
|
{'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора
|
||||||
|
], weight_decay=1e-3)
|
||||||
|
|
||||||
|
# Label Smoothing помогает игнорировать мусор в разметке 2.41M
|
||||||
|
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
|
||||||
|
scaler = GradScaler()
|
||||||
|
|
||||||
|
# --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ ---
|
||||||
|
start_stage = 0
|
||||||
|
start_epoch = 1
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
|
||||||
|
if RESUME_CHECKPOINT.exists():
|
||||||
|
print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
|
||||||
|
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
try:
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Оптимизатор сброшен: {e}")
|
||||||
|
|
||||||
|
best_val_loss = checkpoint['best_val_loss']
|
||||||
|
start_stage = checkpoint['stage']
|
||||||
|
start_epoch = checkpoint['epoch'] + 1
|
||||||
|
print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
|
||||||
|
else:
|
||||||
|
# --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) ---
|
||||||
|
# Выполняется только если мы начинаем с нуля
|
||||||
|
print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
|
||||||
|
model.eval()
|
||||||
|
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
with autocast(device_type="cuda"):
|
||||||
|
outputs = model(inputs)
|
||||||
|
v_loss = criterion(outputs, labels)
|
||||||
|
val_loss += v_loss.item() * inputs.size(0)
|
||||||
|
_, pred = outputs.max(1)
|
||||||
|
val_total += labels.size(0)
|
||||||
|
val_correct += pred.eq(labels).sum().item()
|
||||||
|
|
||||||
|
best_val_loss = val_loss / val_total
|
||||||
|
baseline_acc = val_correct / val_total
|
||||||
|
print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")
|
||||||
|
|
||||||
|
# ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ
|
||||||
|
current_train_paths = []
|
||||||
|
for s in range(start_stage):
|
||||||
|
current_train_paths.extend(chunk_dict[sorted_chunks[s]])
|
||||||
|
|
||||||
|
print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")
|
||||||
|
|
||||||
|
# ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ
|
||||||
|
for stage in range(start_stage, len(sorted_chunks)):
|
||||||
|
chunk_id = sorted_chunks[stage]
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")
|
||||||
|
|
||||||
|
# Накопление и перемешивание
|
||||||
|
current_train_paths.extend(chunk_dict[chunk_id])
|
||||||
|
random.shuffle(current_train_paths)
|
||||||
|
print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")
|
||||||
|
|
||||||
|
# ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH
|
||||||
|
train_loader = DataLoader(
|
||||||
|
ChunkTrainDataset(current_train_paths, train_transform),
|
||||||
|
batch_size=BATCH_NOISY, shuffle=True,
|
||||||
|
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
||||||
|
worker_init_fn=worker_init_fn, drop_last=True,
|
||||||
|
prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU
|
||||||
|
)
|
||||||
|
|
||||||
|
epochs_no_improve = 0
|
||||||
|
first_epoch = start_epoch if stage == start_stage else 1
|
||||||
|
|
||||||
|
# Инициализация итератора якорей
|
||||||
|
anchor_iter = iter(anchor_loader)
|
||||||
|
|
||||||
|
# ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА
|
||||||
|
for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
|
||||||
|
model.train()
|
||||||
|
train_loss, train_correct, train_total = 0.0, 0, 0
|
||||||
|
|
||||||
|
for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
|
||||||
|
|
||||||
|
# Достаем якорный чистый батч
|
||||||
|
try:
|
||||||
|
anc_inputs, anc_labels = next(anchor_iter)
|
||||||
|
except StopIteration:
|
||||||
|
anchor_iter = iter(anchor_loader)
|
||||||
|
anc_inputs, anc_labels = next(anchor_iter)
|
||||||
|
|
||||||
|
# СМЕШИВАЕМ БАТЧИ (Грязные + Чистые)
|
||||||
|
inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
|
||||||
|
labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
with autocast(device_type="cuda"):
|
||||||
|
outputs = model(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
train_loss += loss.item() * inputs.size(0)
|
||||||
|
_, pred = outputs.max(1)
|
||||||
|
train_total += labels.size(0)
|
||||||
|
train_correct += pred.eq(labels).sum().item()
|
||||||
|
|
||||||
|
# ВАЛИДАЦИЯ
|
||||||
|
model.eval()
|
||||||
|
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
with autocast(device_type="cuda"):
|
||||||
|
outputs = model(inputs)
|
||||||
|
v_loss = criterion(outputs, labels)
|
||||||
|
val_loss += v_loss.item() * inputs.size(0)
|
||||||
|
_, pred = outputs.max(1)
|
||||||
|
val_total += labels.size(0)
|
||||||
|
val_correct += pred.eq(labels).sum().item()
|
||||||
|
|
||||||
|
avg_train_loss = train_loss / train_total
|
||||||
|
avg_train_acc = train_correct / train_total
|
||||||
|
avg_val_loss = val_loss / val_total
|
||||||
|
avg_val_acc = val_correct / val_total
|
||||||
|
|
||||||
|
print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")
|
||||||
|
|
||||||
|
# СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ
|
||||||
|
if avg_val_loss < best_val_loss:
|
||||||
|
best_val_loss = avg_val_loss
|
||||||
|
epochs_no_improve = 0
|
||||||
|
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
||||||
|
print("--> Обновлены лучшие веса")
|
||||||
|
else:
|
||||||
|
epochs_no_improve += 1
|
||||||
|
|
||||||
|
# АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ
|
||||||
|
checkpoint_state = {
|
||||||
|
'stage': stage,
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'best_val_loss': best_val_loss
|
||||||
|
}
|
||||||
|
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
||||||
|
os.sync() # Защита от отключения электричества
|
||||||
|
print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")
|
||||||
|
|
||||||
|
# РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА
|
||||||
|
if epochs_no_improve >= PATIENCE:
|
||||||
|
print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Сброс счетчика стартовой эпохи после прохождения восстановительного этапа
|
||||||
|
start_epoch = 1
|
||||||
Reference in New Issue
Block a user