feat: improving finetuning
This commit is contained in:
+118
-114
@@ -2,6 +2,8 @@ import os
|
|||||||
import gc
|
import gc
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
|
import ctypes
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,9 +15,13 @@ from torch.amp import autocast, GradScaler
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import timm
|
import timm
|
||||||
|
|
||||||
# Конфигурация стенда и путей файловой системы
|
# Подавление предупреждений PIL для корректной работы tqdm
|
||||||
|
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
|
||||||
|
|
||||||
|
# Настройка устройства
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Пути к файлам
|
||||||
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
|
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
|
||||||
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
|
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
|
||||||
|
|
||||||
@@ -28,23 +34,23 @@ CLASS_MAPPING = {
|
|||||||
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
|
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
|
||||||
}
|
}
|
||||||
|
|
||||||
# Гиперпараметры конвейера обучения
|
# Параметры обучения
|
||||||
BATCH_SIZE = 82
|
BATCH_SIZE = 64
|
||||||
EPOCHS = 15
|
EPOCHS = 50
|
||||||
LR = 5e-5
|
LR = 5e-5
|
||||||
NUM_TRAIN_WORKERS = 48
|
NUM_TRAIN_WORKERS = 62
|
||||||
NUM_VAL_WORKERS = 18
|
NUM_VAL_WORKERS = 62
|
||||||
PATIENCE = 4
|
PATIENCE = 5
|
||||||
|
|
||||||
def prepare_dataset_index():
|
def prepare_dataset_index():
|
||||||
# Построение или загрузка индекса файлов для минимизации I/O операций по сети (NFS)
|
# Загрузка или создание индекса файлов
|
||||||
if CACHE_PATH.exists():
|
if CACHE_PATH.exists():
|
||||||
print(f"Загрузка карты файловой системы из кэша: {CACHE_PATH.name}")
|
print(f"Загрузка кэша: {CACHE_PATH.name}")
|
||||||
with open(CACHE_PATH, 'rb') as f:
|
with open(CACHE_PATH, 'rb') as f:
|
||||||
cache_data = pickle.load(f)
|
cache_data = pickle.load(f)
|
||||||
return cache_data['image_paths'], cache_data['labels']
|
return cache_data['image_paths'], cache_data['labels']
|
||||||
|
|
||||||
print(f"Сканирование сетевой директории {DATA_ROOT} (первичная индексация)...")
|
print(f"Сканирование директории {DATA_ROOT}...")
|
||||||
paths, labels = [], []
|
paths, labels = [], []
|
||||||
for img_path in DATA_ROOT.rglob('*.jpg'):
|
for img_path in DATA_ROOT.rglob('*.jpg'):
|
||||||
emotion_folder = img_path.parts[-3].lower()
|
emotion_folder = img_path.parts[-3].lower()
|
||||||
@@ -58,11 +64,15 @@ def prepare_dataset_index():
|
|||||||
return paths, labels
|
return paths, labels
|
||||||
|
|
||||||
class EmoSetDirectDataset(Dataset):
|
class EmoSetDirectDataset(Dataset):
|
||||||
# Датасет с отложенной аугментацией: декодирование на CPU, трансформации на GPU
|
# Датасет с загрузкой по требованию
|
||||||
def __init__(self, image_paths, labels):
|
def __init__(self, image_paths, labels):
|
||||||
self.image_paths = image_paths
|
self.image_paths = image_paths
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
self.base_transform = T.Resize((256, 256), antialias=True)
|
# Сохранение пропорций и центрирование
|
||||||
|
self.base_transform = T.Compose([
|
||||||
|
T.Resize(256, antialias=True),
|
||||||
|
T.CenterCrop(256)
|
||||||
|
])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.image_paths)
|
return len(self.image_paths)
|
||||||
@@ -73,16 +83,15 @@ class EmoSetDirectDataset(Dataset):
|
|||||||
image = image.to(torch.float32) / 255.0
|
image = image.to(torch.float32) / 255.0
|
||||||
image = self.base_transform(image)
|
image = self.base_transform(image)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Изолирование сбоев ввода-вывода (поврежденные файлы на сетевом диске)
|
# Обработка битых файлов
|
||||||
image = torch.zeros((3, 256, 256), dtype=torch.float32)
|
image = torch.zeros((3, 256, 256), dtype=torch.float32)
|
||||||
return image, self.labels[idx]
|
return image, self.labels[idx]
|
||||||
|
|
||||||
def build_gpu_transforms():
|
def build_gpu_transforms():
|
||||||
# Перенос матричных операций аугментации на тензорные ядра видеокарты
|
# Аугментации на GPU
|
||||||
train_tf = torch.nn.Sequential(
|
train_tf = torch.nn.Sequential(
|
||||||
T.RandomCrop((224, 224)),
|
T.RandomCrop((224, 224)),
|
||||||
T.RandomHorizontalFlip(p=0.5),
|
T.RandomHorizontalFlip(p=0.5),
|
||||||
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
|
|
||||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
).to(DEVICE)
|
).to(DEVICE)
|
||||||
|
|
||||||
@@ -94,11 +103,11 @@ def build_gpu_transforms():
|
|||||||
return train_tf, val_tf
|
return train_tf, val_tf
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(f"Инициализация конвейера обучения. Устройство: {DEVICE}")
|
print(f"Инициализация. Устройство: {DEVICE}")
|
||||||
|
|
||||||
all_paths, all_labels = prepare_dataset_index()
|
all_paths, all_labels = prepare_dataset_index()
|
||||||
|
|
||||||
# Фиксация сида для детерминированного разделения выборок при перезапусках скрипта
|
# Разделение выборки
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
combined = list(zip(all_paths, all_labels))
|
combined = list(zip(all_paths, all_labels))
|
||||||
random.shuffle(combined)
|
random.shuffle(combined)
|
||||||
@@ -110,14 +119,14 @@ if __name__ == "__main__":
|
|||||||
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
|
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
|
||||||
batch_size=BATCH_SIZE, shuffle=True,
|
batch_size=BATCH_SIZE, shuffle=True,
|
||||||
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
||||||
prefetch_factor=2, persistent_workers=False
|
prefetch_factor=3, persistent_workers=False
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
|
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
|
||||||
batch_size=BATCH_SIZE, shuffle=False,
|
batch_size=BATCH_SIZE, shuffle=False,
|
||||||
num_workers=NUM_VAL_WORKERS, pin_memory=True,
|
num_workers=NUM_VAL_WORKERS, pin_memory=True,
|
||||||
prefetch_factor=2, persistent_workers=False
|
prefetch_factor=3, persistent_workers=False
|
||||||
)
|
)
|
||||||
|
|
||||||
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
|
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
|
||||||
@@ -132,133 +141,128 @@ if __name__ == "__main__":
|
|||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
start_epoch = 1
|
start_epoch = 1
|
||||||
|
|
||||||
# Инициализация механизма отказоустойчивости и интеграция весов
|
# Загрузка весов
|
||||||
if RESUME_CHECKPOINT.exists():
|
if RESUME_CHECKPOINT.exists():
|
||||||
print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}")
|
print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
|
||||||
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
||||||
if 'scaler_state_dict' in checkpoint: scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
try:
|
||||||
if 'best_val_loss' in checkpoint: best_val_loss = checkpoint['best_val_loss']
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if 'scaler_state_dict' in checkpoint:
|
||||||
|
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
||||||
|
if 'best_val_loss' in checkpoint:
|
||||||
|
best_val_loss = checkpoint['best_val_loss']
|
||||||
start_epoch = checkpoint['epoch'] + 1
|
start_epoch = checkpoint['epoch'] + 1
|
||||||
elif PREVIOUS_WEIGHTS.exists():
|
elif PREVIOUS_WEIGHTS.exists():
|
||||||
print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}")
|
print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
|
||||||
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
|
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
|
||||||
else:
|
else:
|
||||||
print("Веса не найдены. Инициализация с ImageNet.")
|
print("Веса не найдены. Инициализация с ImageNet.")
|
||||||
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
|
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
|
||||||
|
|
||||||
try:
|
for epoch in range(start_epoch, EPOCHS + 1):
|
||||||
for epoch in range(start_epoch, EPOCHS + 1):
|
|
||||||
|
|
||||||
# Проход по обучающей выборке
|
# Обучение
|
||||||
model.train()
|
model.train()
|
||||||
running_loss, correct, total = 0.0, 0, 0
|
running_loss, correct, total = 0.0, 0, 0
|
||||||
|
|
||||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]")
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0)
|
||||||
for inputs, labels in pbar:
|
for inputs, labels in pbar:
|
||||||
try:
|
try:
|
||||||
inputs = inputs.to(DEVICE, non_blocking=True)
|
inputs = inputs.to(DEVICE, non_blocking=True)
|
||||||
labels = labels.to(DEVICE, non_blocking=True)
|
labels = labels.to(DEVICE, non_blocking=True)
|
||||||
inputs = gpu_train_tf(inputs)
|
inputs = gpu_train_tf(inputs)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Смешанная точность для экономии VRAM
|
# Смешанная точность
|
||||||
with autocast(device_type="cuda"):
|
with autocast(device_type="cuda"):
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
loss = criterion(outputs, labels)
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
||||||
running_loss += loss.item() * inputs.size(0)
|
running_loss += loss.item() * inputs.size(0)
|
||||||
_, predicted = outputs.max(1)
|
_, predicted = outputs.max(1)
|
||||||
total += labels.size(0)
|
total += labels.size(0)
|
||||||
correct += predicted.eq(labels).sum().item()
|
correct += predicted.eq(labels).sum().item()
|
||||||
|
|
||||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
||||||
|
|
||||||
except RuntimeError as memory_err:
|
except RuntimeError as memory_err:
|
||||||
# Подавление пиковых скачков потребления VRAM
|
# Очистка памяти при OOM
|
||||||
if "out of memory" in str(memory_err).lower():
|
if "out of memory" in str(memory_err).lower():
|
||||||
if 'outputs' in locals(): del outputs
|
if 'outputs' in locals(): del outputs
|
||||||
if 'loss' in locals(): del loss
|
if 'loss' in locals(): del loss
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad(set_to_none=True)
|
||||||
continue
|
continue
|
||||||
raise memory_err
|
raise memory_err
|
||||||
|
|
||||||
train_loss = running_loss / total if total > 0 else 0
|
train_loss = running_loss / total if total > 0 else 0
|
||||||
train_acc = correct / total if total > 0 else 0
|
train_acc = correct / total if total > 0 else 0
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Проход по валидационной выборке
|
# Валидация
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss, val_correct, val_total = 0.0, 0, 0
|
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False):
|
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing = 0):
|
||||||
val_inputs = val_inputs.to(DEVICE, non_blocking=True)
|
val_inputs = val_inputs.to(DEVICE, non_blocking=True)
|
||||||
val_labels = val_labels.to(DEVICE, non_blocking=True)
|
val_labels = val_labels.to(DEVICE, non_blocking=True)
|
||||||
val_inputs = gpu_val_tf(val_inputs)
|
val_inputs = gpu_val_tf(val_inputs)
|
||||||
|
|
||||||
with autocast(device_type="cuda"):
|
with autocast(device_type="cuda"):
|
||||||
val_outputs = model(val_inputs)
|
val_outputs = model(val_inputs)
|
||||||
v_loss = criterion(val_outputs, val_labels)
|
v_loss = criterion(val_outputs, val_labels)
|
||||||
|
|
||||||
val_loss += v_loss.item() * val_inputs.size(0)
|
val_loss += v_loss.item() * val_inputs.size(0)
|
||||||
_, val_predicted = val_outputs.max(1)
|
_, val_predicted = val_outputs.max(1)
|
||||||
val_total += val_labels.size(0)
|
val_total += val_labels.size(0)
|
||||||
val_correct += val_predicted.eq(val_labels).sum().item()
|
val_correct += val_predicted.eq(val_labels).sum().item()
|
||||||
|
|
||||||
epoch_val_loss = val_loss / val_total if val_total > 0 else 0
|
epoch_val_loss = val_loss / val_total if val_total > 0 else 0
|
||||||
epoch_val_acc = val_correct / val_total if val_total > 0 else 0
|
epoch_val_acc = val_correct / val_total if val_total > 0 else 0
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
|
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
|
||||||
|
|
||||||
# Оценка критериев ранней остановки и сохранение состояния сессии
|
# Ранняя остановка и сохранение
|
||||||
if epoch_val_loss < best_val_loss:
|
if epoch_val_loss < best_val_loss:
|
||||||
best_val_loss = epoch_val_loss
|
best_val_loss = epoch_val_loss
|
||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
|
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
|
||||||
else:
|
print("Сохранен новый лучший чекпоинт.")
|
||||||
epochs_no_improve += 1
|
else:
|
||||||
if epochs_no_improve >= PATIENCE and epoch >= 15:
|
epochs_no_improve += 1
|
||||||
print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.")
|
if epochs_no_improve >= PATIENCE and epoch >= 25:
|
||||||
break
|
print(f"Остановка: валидация не улучшается {PATIENCE} эпох.")
|
||||||
|
break
|
||||||
|
|
||||||
# Атомарное сохранение контекста
|
# Сохранение состояния
|
||||||
checkpoint_state = {
|
|
||||||
'epoch': epoch,
|
|
||||||
'model_state_dict': model.state_dict(),
|
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
|
||||||
'scheduler_state_dict': scheduler.state_dict(),
|
|
||||||
'scaler_state_dict': scaler.state_dict(),
|
|
||||||
'best_val_loss': best_val_loss
|
|
||||||
}
|
|
||||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nВыполнение прервано пользователем (SIGINT).")
|
|
||||||
print(f"Дамп памяти конвейера зафиксирован на эпохе {epoch}.")
|
|
||||||
checkpoint_state = {
|
checkpoint_state = {
|
||||||
'epoch': epoch, 'model_state_dict': model.state_dict(),
|
'epoch': epoch,
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'scheduler_state_dict': scheduler.state_dict(), 'scaler_state_dict': scaler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
|
'scaler_state_dict': scaler.state_dict(),
|
||||||
'best_val_loss': best_val_loss
|
'best_val_loss': best_val_loss
|
||||||
}
|
}
|
||||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
else:
|
if SAVE_MODEL_PATH.parent.exists():
|
||||||
if SAVE_MODEL_PATH.parent.exists():
|
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
||||||
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")
|
||||||
print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}")
|
if RESUME_CHECKPOINT.exists():
|
||||||
if RESUME_CHECKPOINT.exists():
|
RESUME_CHECKPOINT.unlink()
|
||||||
RESUME_CHECKPOINT.unlink()
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.8 MiB |
Reference in New Issue
Block a user