feat: improving finetuning

This commit is contained in:
zin
2026-06-06 21:06:21 +00:00
parent 8648e52106
commit daba573b2c
2 changed files with 131 additions and 127 deletions
+118 -114
View File
@@ -2,6 +2,8 @@ import os
import gc import gc
import pickle import pickle
import random import random
import ctypes
import warnings
from pathlib import Path from pathlib import Path
import torch import torch
@@ -13,9 +15,13 @@ from torch.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import timm import timm
# Конфигурация стенда и путей файловой системы # Подавление предупреждений PIL для корректной работы tqdm
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройка устройства
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Пути к файлам
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M") DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl") CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
@@ -28,23 +34,23 @@ CLASS_MAPPING = {
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7 "disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
} }
# Гиперпараметры конвейера обучения # Параметры обучения
BATCH_SIZE = 82 BATCH_SIZE = 64
EPOCHS = 15 EPOCHS = 50
LR = 5e-5 LR = 5e-5
NUM_TRAIN_WORKERS = 48 NUM_TRAIN_WORKERS = 62
NUM_VAL_WORKERS = 18 NUM_VAL_WORKERS = 62
PATIENCE = 4 PATIENCE = 5
def prepare_dataset_index(): def prepare_dataset_index():
# Построение или загрузка индекса файлов для минимизации I/O операций по сети (NFS) # Загрузка или создание индекса файлов
if CACHE_PATH.exists(): if CACHE_PATH.exists():
print(f"Загрузка карты файловой системы из кэша: {CACHE_PATH.name}") print(f"Загрузка кэша: {CACHE_PATH.name}")
with open(CACHE_PATH, 'rb') as f: with open(CACHE_PATH, 'rb') as f:
cache_data = pickle.load(f) cache_data = pickle.load(f)
return cache_data['image_paths'], cache_data['labels'] return cache_data['image_paths'], cache_data['labels']
print(f"Сканирование сетевой директории {DATA_ROOT} (первичная индексация)...") print(f"Сканирование директории {DATA_ROOT}...")
paths, labels = [], [] paths, labels = [], []
for img_path in DATA_ROOT.rglob('*.jpg'): for img_path in DATA_ROOT.rglob('*.jpg'):
emotion_folder = img_path.parts[-3].lower() emotion_folder = img_path.parts[-3].lower()
@@ -58,11 +64,15 @@ def prepare_dataset_index():
return paths, labels return paths, labels
class EmoSetDirectDataset(Dataset): class EmoSetDirectDataset(Dataset):
# Датасет с отложенной аугментацией: декодирование на CPU, трансформации на GPU # Датасет с загрузкой по требованию
def __init__(self, image_paths, labels): def __init__(self, image_paths, labels):
self.image_paths = image_paths self.image_paths = image_paths
self.labels = labels self.labels = labels
self.base_transform = T.Resize((256, 256), antialias=True) # Сохранение пропорций и центрирование
self.base_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(256)
])
def __len__(self): def __len__(self):
return len(self.image_paths) return len(self.image_paths)
@@ -73,16 +83,15 @@ class EmoSetDirectDataset(Dataset):
image = image.to(torch.float32) / 255.0 image = image.to(torch.float32) / 255.0
image = self.base_transform(image) image = self.base_transform(image)
except Exception: except Exception:
# Изолирование сбоев ввода-вывода (поврежденные файлы на сетевом диске) # Обработка битых файлов
image = torch.zeros((3, 256, 256), dtype=torch.float32) image = torch.zeros((3, 256, 256), dtype=torch.float32)
return image, self.labels[idx] return image, self.labels[idx]
def build_gpu_transforms(): def build_gpu_transforms():
# Перенос матричных операций аугментации на тензорные ядра видеокарты # Аугментации на GPU
train_tf = torch.nn.Sequential( train_tf = torch.nn.Sequential(
T.RandomCrop((224, 224)), T.RandomCrop((224, 224)),
T.RandomHorizontalFlip(p=0.5), T.RandomHorizontalFlip(p=0.5),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(DEVICE) ).to(DEVICE)
@@ -94,11 +103,11 @@ def build_gpu_transforms():
return train_tf, val_tf return train_tf, val_tf
if __name__ == "__main__": if __name__ == "__main__":
print(f"Инициализация конвейера обучения. Устройство: {DEVICE}") print(f"Инициализация. Устройство: {DEVICE}")
all_paths, all_labels = prepare_dataset_index() all_paths, all_labels = prepare_dataset_index()
# Фиксация сида для детерминированного разделения выборок при перезапусках скрипта # Разделение выборки
random.seed(42) random.seed(42)
combined = list(zip(all_paths, all_labels)) combined = list(zip(all_paths, all_labels))
random.shuffle(combined) random.shuffle(combined)
@@ -110,14 +119,14 @@ if __name__ == "__main__":
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]), EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
batch_size=BATCH_SIZE, shuffle=True, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_TRAIN_WORKERS, pin_memory=True, num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
prefetch_factor=2, persistent_workers=False prefetch_factor=3, persistent_workers=False
) )
val_loader = DataLoader( val_loader = DataLoader(
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]), EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
batch_size=BATCH_SIZE, shuffle=False, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_VAL_WORKERS, pin_memory=True, num_workers=NUM_VAL_WORKERS, pin_memory=True,
prefetch_factor=2, persistent_workers=False prefetch_factor=3, persistent_workers=False
) )
gpu_train_tf, gpu_val_tf = build_gpu_transforms() gpu_train_tf, gpu_val_tf = build_gpu_transforms()
@@ -132,133 +141,128 @@ if __name__ == "__main__":
epochs_no_improve = 0 epochs_no_improve = 0
start_epoch = 1 start_epoch = 1
# Инициализация механизма отказоустойчивости и интеграция весов # Загрузка весов
if RESUME_CHECKPOINT.exists(): if RESUME_CHECKPOINT.exists():
print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}") print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE) checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'scaler_state_dict' in checkpoint: scaler.load_state_dict(checkpoint['scaler_state_dict']) try:
if 'best_val_loss' in checkpoint: best_val_loss = checkpoint['best_val_loss'] scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
except Exception:
pass
if 'scaler_state_dict' in checkpoint:
scaler.load_state_dict(checkpoint['scaler_state_dict'])
if 'best_val_loss' in checkpoint:
best_val_loss = checkpoint['best_val_loss']
start_epoch = checkpoint['epoch'] + 1 start_epoch = checkpoint['epoch'] + 1
elif PREVIOUS_WEIGHTS.exists(): elif PREVIOUS_WEIGHTS.exists():
print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}") print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE)) model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
else: else:
print("Веса не найдены. Инициализация с ImageNet.") print("Веса не найдены. Инициализация с ImageNet.")
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE) model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
try: for epoch in range(start_epoch, EPOCHS + 1):
for epoch in range(start_epoch, EPOCHS + 1):
# Проход по обучающей выборке # Обучение
model.train() model.train()
running_loss, correct, total = 0.0, 0, 0 running_loss, correct, total = 0.0, 0, 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]") pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0)
for inputs, labels in pbar: for inputs, labels in pbar:
try: try:
inputs = inputs.to(DEVICE, non_blocking=True) inputs = inputs.to(DEVICE, non_blocking=True)
labels = labels.to(DEVICE, non_blocking=True) labels = labels.to(DEVICE, non_blocking=True)
inputs = gpu_train_tf(inputs) inputs = gpu_train_tf(inputs)
optimizer.zero_grad() optimizer.zero_grad(set_to_none=True)
# Смешанная точность для экономии VRAM # Смешанная точность
with autocast(device_type="cuda"): with autocast(device_type="cuda"):
outputs = model(inputs) outputs = model(inputs)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
running_loss += loss.item() * inputs.size(0) running_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1) _, predicted = outputs.max(1)
total += labels.size(0) total += labels.size(0)
correct += predicted.eq(labels).sum().item() correct += predicted.eq(labels).sum().item()
pbar.set_postfix({'loss': f"{loss.item():.4f}"}) pbar.set_postfix({'loss': f"{loss.item():.4f}"})
except RuntimeError as memory_err: except RuntimeError as memory_err:
# Подавление пиковых скачков потребления VRAM # Очистка памяти при OOM
if "out of memory" in str(memory_err).lower(): if "out of memory" in str(memory_err).lower():
if 'outputs' in locals(): del outputs if 'outputs' in locals(): del outputs
if 'loss' in locals(): del loss if 'loss' in locals(): del loss
torch.cuda.empty_cache() torch.cuda.empty_cache()
optimizer.zero_grad() optimizer.zero_grad(set_to_none=True)
continue continue
raise memory_err raise memory_err
train_loss = running_loss / total if total > 0 else 0 train_loss = running_loss / total if total > 0 else 0
train_acc = correct / total if total > 0 else 0 train_acc = correct / total if total > 0 else 0
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Проход по валидационной выборке # Валидация
model.eval() model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0 val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad(): with torch.no_grad():
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False): for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing = 0):
val_inputs = val_inputs.to(DEVICE, non_blocking=True) val_inputs = val_inputs.to(DEVICE, non_blocking=True)
val_labels = val_labels.to(DEVICE, non_blocking=True) val_labels = val_labels.to(DEVICE, non_blocking=True)
val_inputs = gpu_val_tf(val_inputs) val_inputs = gpu_val_tf(val_inputs)
with autocast(device_type="cuda"): with autocast(device_type="cuda"):
val_outputs = model(val_inputs) val_outputs = model(val_inputs)
v_loss = criterion(val_outputs, val_labels) v_loss = criterion(val_outputs, val_labels)
val_loss += v_loss.item() * val_inputs.size(0) val_loss += v_loss.item() * val_inputs.size(0)
_, val_predicted = val_outputs.max(1) _, val_predicted = val_outputs.max(1)
val_total += val_labels.size(0) val_total += val_labels.size(0)
val_correct += val_predicted.eq(val_labels).sum().item() val_correct += val_predicted.eq(val_labels).sum().item()
epoch_val_loss = val_loss / val_total if val_total > 0 else 0 epoch_val_loss = val_loss / val_total if val_total > 0 else 0
epoch_val_acc = val_correct / val_total if val_total > 0 else 0 epoch_val_acc = val_correct / val_total if val_total > 0 else 0
scheduler.step() scheduler.step()
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}") print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
# Оценка критериев ранней остановки и сохранение состояния сессии # Ранняя остановка и сохранение
if epoch_val_loss < best_val_loss: if epoch_val_loss < best_val_loss:
best_val_loss = epoch_val_loss best_val_loss = epoch_val_loss
epochs_no_improve = 0 epochs_no_improve = 0
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth")) torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
else: print("Сохранен новый лучший чекпоинт.")
epochs_no_improve += 1 else:
if epochs_no_improve >= PATIENCE and epoch >= 15: epochs_no_improve += 1
print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.") if epochs_no_improve >= PATIENCE and epoch >= 25:
break print(f"Остановка: валидация не улучшается {PATIENCE} эпох.")
break
# Атомарное сохранение контекста # Сохранение состояния
checkpoint_state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}
torch.save(checkpoint_state, RESUME_CHECKPOINT)
gc.collect()
except KeyboardInterrupt:
print("\nВыполнение прервано пользователем (SIGINT).")
print(f"Дамп памяти конвейера зафиксирован на эпохе {epoch}.")
checkpoint_state = { checkpoint_state = {
'epoch': epoch, 'model_state_dict': model.state_dict(), 'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(), 'scaler_state_dict': scaler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss 'best_val_loss': best_val_loss
} }
torch.save(checkpoint_state, RESUME_CHECKPOINT) torch.save(checkpoint_state, RESUME_CHECKPOINT)
gc.collect()
else: if SAVE_MODEL_PATH.parent.exists():
if SAVE_MODEL_PATH.parent.exists(): torch.save(model.state_dict(), SAVE_MODEL_PATH)
torch.save(model.state_dict(), SAVE_MODEL_PATH) print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")
print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}") if RESUME_CHECKPOINT.exists():
if RESUME_CHECKPOINT.exists(): RESUME_CHECKPOINT.unlink()
RESUME_CHECKPOINT.unlink()
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 MiB