Files
Thesis/src/scripts/31_finetune_2.41M.py
T
2026-06-06 21:06:21 +00:00

268 lines
10 KiB
Python

import os
import gc
import pickle
import random
import ctypes
import warnings
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm
# Silence the JPEG decoder's "Unknown Adobe color transform code" warning, which
# some dataset files trigger and which would otherwise break tqdm's progress bars.
# NOTE(review): images are decoded with torchvision.io (not PIL) in this script —
# confirm which library actually emits this warning.
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Compute device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File locations. Everything hangs off one project root, which defaults to the
# original author's checkout and can be overridden with the THESIS_ROOT env var.
PROJECT_ROOT = Path(os.environ.get("THESIS_ROOT", "/home/zin/projects/Thesis"))
SRC_DIR = PROJECT_ROOT / "src"
DATA_ROOT = PROJECT_ROOT / "dataset" / "Original-2.41M"
CACHE_PATH = SRC_DIR / "dataset_paths_cache.pkl"                  # pickled file index
PREVIOUS_WEIGHTS = SRC_DIR / "emoset_resnet50_best.pth"           # starting weights, if present
RESUME_CHECKPOINT = SRC_DIR / "emoset_resnet50_resume.pth"        # per-epoch resume state
SAVE_MODEL_PATH = SRC_DIR / "emoset_resnet50_finetuned_2_41M.pth" # final weights

# Emotion folder name (lower-cased) -> class id. "sad" and "sadness" are two
# spellings of the same class and intentionally share id 7, giving 8 classes.
CLASS_MAPPING = {
    "amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
    "disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7,
}

# Training hyper-parameters.
BATCH_SIZE = 64
EPOCHS = 50
LR = 5e-5
# DataLoader worker processes; never more than the machine has cores.
_CPU_COUNT = os.cpu_count() or 1
NUM_TRAIN_WORKERS = min(62, _CPU_COUNT)
NUM_VAL_WORKERS = min(62, _CPU_COUNT)
# Early-stopping patience, in epochs without validation-loss improvement.
PATIENCE = 5
def prepare_dataset_index(data_root=None, cache_path=None, class_mapping=None):
    """Return ``(image_paths, labels)`` for every labelled ``*.jpg`` under the data root.

    The label of an image is the name of the directory two levels above it
    (``.../<emotion>/<subdir>/<image>.jpg``), lower-cased and looked up in
    ``class_mapping``; images whose folder is not in the mapping are skipped.
    A fresh scan is pickled to ``cache_path`` and that cache is returned as-is
    on later runs.

    Args:
        data_root: Directory to scan. Defaults to ``DATA_ROOT``.
        cache_path: Location of the pickle cache. Defaults to ``CACHE_PATH``.
        class_mapping: Folder name -> class id. Defaults to ``CLASS_MAPPING``.

    Returns:
        ``(paths, labels)``: path strings and integer class ids. A fresh scan
        is in sorted path order, so the seeded train/val split built on top of
        it does not depend on filesystem enumeration order.

    Raises:
        RuntimeError: If a fresh scan finds no labelled images. Nothing is
            cached in that case, so fixing the data path takes effect on the
            next run instead of an empty index being reloaded forever.
    """
    data_root = DATA_ROOT if data_root is None else Path(data_root)
    cache_path = CACHE_PATH if cache_path is None else Path(cache_path)
    class_mapping = CLASS_MAPPING if class_mapping is None else class_mapping

    if cache_path.exists():
        print(f"Загрузка кэша: {cache_path.name}")
        # NOTE: pickle can execute arbitrary code on load — only point this at
        # a cache file written by this script.
        with open(cache_path, 'rb') as f:
            cache_data = pickle.load(f)
        return cache_data['image_paths'], cache_data['labels']

    print(f"Сканирование директории {data_root}...")
    paths, labels = [], []
    for img_path in sorted(data_root.rglob('*.jpg')):
        # Too shallow to have a "<emotion>/<subdir>/<file>" layout.
        if len(img_path.parts) < 3:
            continue
        label = class_mapping.get(img_path.parts[-3].lower())
        if label is not None:
            paths.append(str(img_path))
            labels.append(label)

    if not paths:
        raise RuntimeError(f"В {data_root} не найдено размеченных изображений *.jpg")

    with open(cache_path, 'wb') as f:
        pickle.dump({'image_paths': paths, 'labels': labels}, f)
    return paths, labels
class EmoSetDirectDataset(Dataset):
    """Map-style dataset that decodes each image from disk when it is requested.

    Every item is ``(image, label)``. ``image`` is a float32 tensor in [0, 1]
    of shape ``(3, image_size, image_size)``: the image is read as RGB, its
    shorter side is resized to ``image_size`` (aspect ratio preserved) and it
    is centre-cropped to a square. Random augmentation and normalisation are
    left to the batch-level GPU transforms.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class ids, aligned with ``image_paths``.
        image_size: Side of the returned square image. The placeholder used for
            unreadable files is built with the same size, so batch shapes
            always agree.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, image_size=256):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) differ in length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.image_size = image_size
        # Keep aspect ratio (resize the shorter side), then centre the square crop.
        self.base_transform = T.Compose([
            T.Resize(image_size, antialias=True),
            T.CenterCrop(image_size),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
            image = self.base_transform(image.to(torch.float32) / 255.0)
        except Exception:
            # Deliberate best-effort: a corrupt/unreadable file becomes an
            # all-zero placeholder so one bad file cannot abort an epoch.
            # NOTE(review): the placeholder keeps the file's real label and is
            # trained on as that class.
            image = torch.zeros((3, self.image_size, self.image_size), dtype=torch.float32)
        return image, self.labels[idx]
# ImageNet channel statistics used to normalise inputs for the ResNet-50 backbone.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


class _BatchRandomCropFlip(nn.Module):
    """Independent random crop and horizontal flip for every image of a (B, C, H, W) batch.

    torchvision's RandomCrop / RandomHorizontalFlip draw one set of random
    parameters per call. Applied to a whole batch, they would crop every image
    at the same offset and flip the batch all-or-nothing. Here each sample gets
    its own offset and its own flip decision.
    """

    def __init__(self, size, flip_p=0.5):
        super().__init__()
        self.size = tuple(size)
        self.flip_p = flip_p

    def forward(self, x):
        b, c, h, w = x.shape
        th, tw = self.size
        if h < th or w < tw:
            raise ValueError(f"input {h}x{w} is smaller than crop {th}x{tw}")
        # Offsets uniform in [0, h - th] / [0, w - tw], like torchvision's RandomCrop.
        top = torch.randint(0, h - th + 1, (b,), device=x.device)
        left = torch.randint(0, w - tw + 1, (b,), device=x.device)
        rows = top[:, None] + torch.arange(th, device=x.device)    # (b, th)
        cols = left[:, None] + torch.arange(tw, device=x.device)   # (b, tw)
        x = x.gather(2, rows[:, None, :, None].expand(b, c, th, w))
        x = x.gather(3, cols[:, None, None, :].expand(b, c, th, tw))
        flip = torch.rand(b, device=x.device) < self.flip_p
        return torch.where(flip[:, None, None, None], x.flip(-1), x)


class _BatchCenterCrop(nn.Module):
    """Centre crop of the last two dimensions (same offset rule as torchvision's CenterCrop)."""

    def __init__(self, size):
        super().__init__()
        self.size = tuple(size)

    def forward(self, x):
        h, w = x.shape[-2:]
        th, tw = self.size
        if h < th or w < tw:
            raise ValueError(f"input {h}x{w} is smaller than crop {th}x{tw}")
        top = int(round((h - th) / 2.0))
        left = int(round((w - tw) / 2.0))
        return x[..., top:top + th, left:left + tw]


class _ChannelNormalize(nn.Module):
    """Per-channel ``(x - mean) / std``; statistics are buffers so ``.to(device)`` moves them."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std, dtype=torch.float32).view(-1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


def build_gpu_transforms(crop_size=224, device=None):
    """Build the batch-level train/validation transforms.

    Both callables take a float batch ``(B, 3, H, W)`` in [0, 1] that is
    already on ``device``.
    - train: per-sample random ``crop_size`` crop, per-sample horizontal flip
      (p=0.5), then ImageNet normalisation.
    - val: deterministic centre ``crop_size`` crop, then ImageNet normalisation.

    Args:
        crop_size: Side of the square output crop.
        device: Where the transforms (normalisation buffers) live. Defaults to ``DEVICE``.

    Returns:
        ``(train_tf, val_tf)`` as ``nn.Module`` callables.
    """
    device = DEVICE if device is None else device
    size = (crop_size, crop_size)
    train_tf = nn.Sequential(
        _BatchRandomCropFlip(size, flip_p=0.5),
        _ChannelNormalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ).to(device)
    val_tf = nn.Sequential(
        _BatchCenterCrop(size),
        _ChannelNormalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ).to(device)
    return train_tf, val_tf
# Early stopping is only allowed once at least this many epochs have run.
_MIN_EPOCHS_BEFORE_STOP = 25


def _split_dataset(paths, labels, train_fraction=0.95, seed=42):
    """Shuffle ``(path, label)`` pairs with a fixed seed and split them into train/val.

    A private ``random.Random(seed)`` produces the same permutation as
    ``random.seed(seed); random.shuffle(...)`` without reseeding the global RNG.

    Returns:
        ``((train_paths, train_labels), (val_paths, val_labels))`` as tuples.

    Raises:
        RuntimeError: If the index is empty.
    """
    combined = list(zip(paths, labels))
    if not combined:
        raise RuntimeError("Индекс датасета пуст: нечего обучать.")
    random.Random(seed).shuffle(combined)
    shuffled_paths, shuffled_labels = zip(*combined)
    split_idx = int(len(shuffled_paths) * train_fraction)
    return ((shuffled_paths[:split_idx], shuffled_labels[:split_idx]),
            (shuffled_paths[split_idx:], shuffled_labels[split_idx:]))


def _make_loader(paths, labels, *, shuffle, num_workers):
    """Wrap ``EmoSetDirectDataset`` in a DataLoader with the script's shared settings."""
    return DataLoader(
        EmoSetDirectDataset(paths, labels),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up CUDA host-to-device copies.
        pin_memory=DEVICE.type == "cuda",
        prefetch_factor=3,
        persistent_workers=False,
    )


def _train_one_epoch(model, loader, transform, criterion, optimizer, scaler, epoch):
    """Train ``model`` for one pass over ``loader`` with mixed precision on CUDA.

    A batch that fails with an out-of-memory RuntimeError is skipped (gradients
    cleared, CUDA cache emptied) instead of aborting the run. Any other
    RuntimeError propagates with its original traceback.

    Returns:
        ``(mean_loss, accuracy)`` over the batches that completed, or ``(0, 0)``
        if none did.
    """
    model.train()
    amp_enabled = DEVICE.type == "cuda"
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", smoothing=0)
    for inputs, labels in pbar:
        out_of_memory = False
        try:
            inputs = inputs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            inputs = transform(inputs)
            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type=DEVICE.type, enabled=amp_enabled):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            batch_loss = loss.item()
            running_loss += batch_loss * labels.size(0)
            total += labels.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            pbar.set_postfix({'loss': f"{batch_loss:.4f}"})
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise
            out_of_memory = True
        if out_of_memory:
            # Recover outside the except clause: while the exception is alive its
            # traceback keeps the failed forward pass's tensors referenced.
            outputs = loss = None
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
    if total == 0:
        return 0, 0
    return running_loss / total, correct / total


def _evaluate(model, loader, transform, criterion, epoch):
    """Return ``(mean_loss, accuracy)`` of ``model`` on ``loader`` without tracking gradients."""
    model.eval()
    amp_enabled = DEVICE.type == "cuda"
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", smoothing=0):
            inputs = transform(inputs.to(DEVICE, non_blocking=True))
            labels = labels.to(DEVICE, non_blocking=True)
            with autocast(device_type=DEVICE.type, enabled=amp_enabled):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            total_loss += loss.item() * labels.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return 0, 0
    return total_loss / total, correct / total


def main():
    """Fine-tune ResNet-50 on the dataset index, with resume support and early stopping.

    Starting weights, by priority: ``RESUME_CHECKPOINT`` (full training state),
    then ``PREVIOUS_WEIGHTS``, then ImageNet-pretrained timm weights. The best
    model by validation loss is saved next to ``SAVE_MODEL_PATH`` with a
    ``_best`` suffix. The last model is saved to ``SAVE_MODEL_PATH``, after
    which the resume checkpoint is deleted.
    """
    print(f"Инициализация. Устройство: {DEVICE}")
    use_cuda = DEVICE.type == "cuda"

    all_paths, all_labels = prepare_dataset_index()
    (train_paths, train_labels), (val_paths, val_labels) = _split_dataset(all_paths, all_labels)
    train_loader = _make_loader(train_paths, train_labels, shuffle=True, num_workers=NUM_TRAIN_WORKERS)
    val_loader = _make_loader(val_paths, val_labels, shuffle=False, num_workers=NUM_VAL_WORKERS)
    gpu_train_tf, gpu_val_tf = build_gpu_transforms()

    # Choose the starting weights BEFORE building the optimizer. The previous
    # version re-created `model` for the ImageNet fallback after AdamW had
    # already captured the old model's parameters, so the model that was
    # trained and saved never actually received optimizer updates.
    resuming = RESUME_CHECKPOINT.exists()
    load_previous = not resuming and PREVIOUS_WEIGHTS.exists()
    use_imagenet = not resuming and not load_previous
    if use_imagenet:
        print("Веса не найдены. Инициализация с ImageNet.")
    # "sad" and "sadness" share an id, so this is 8 with the current mapping.
    num_classes = len(set(CLASS_MAPPING.values()))
    model = timm.create_model('resnet50', pretrained=use_imagenet, num_classes=num_classes).to(DEVICE)
    if load_previous:
        print(f"Загрузка базовых весов: {PREVIOUS_WEIGHTS.name}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    # Loss scaling is only needed for CUDA fp16 autocast; when disabled it is a pass-through.
    scaler = GradScaler(enabled=use_cuda)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    start_epoch = 1
    if resuming:
        print(f"Восстановление из: {RESUME_CHECKPOINT.name}")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except Exception as err:
            # Still best-effort, but reported: training continues with a freshly
            # constructed scheduler, so the LR schedule may differ from the original run.
            print(f"Не удалось восстановить состояние планировщика: {err}")
        # A disabled (CPU) scaler saves an empty dict, which an enabled scaler rejects.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        # Older checkpoints lack this key; they restart the patience counter as before.
        epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
        start_epoch = checkpoint['epoch'] + 1

    SAVE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    RESUME_CHECKPOINT.parent.mkdir(parents=True, exist_ok=True)
    best_model_path = SAVE_MODEL_PATH.with_name(f"{SAVE_MODEL_PATH.stem}_best{SAVE_MODEL_PATH.suffix}")

    for epoch in range(start_epoch, EPOCHS + 1):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, gpu_train_tf, criterion, optimizer, scaler, epoch
        )
        gc.collect()
        torch.cuda.empty_cache()

        val_loss, val_acc = _evaluate(model, val_loader, gpu_val_tf, criterion, epoch)
        scheduler.step()
        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Early stopping and best-model tracking.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print("Сохранен новый лучший чекпоинт.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE and epoch >= _MIN_EPOCHS_BEFORE_STOP:
                print(f"Остановка: валидация не улучшается {epochs_no_improve} эпох.")
                break

        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss,
            'epochs_no_improve': epochs_no_improve,
        }
        # Write to a temporary file and rename it, so a crash mid-save cannot
        # corrupt the only resume point.
        tmp_path = RESUME_CHECKPOINT.with_name(RESUME_CHECKPOINT.name + ".tmp")
        torch.save(checkpoint_state, tmp_path)
        tmp_path.replace(RESUME_CHECKPOINT)
        gc.collect()

    torch.save(model.state_dict(), SAVE_MODEL_PATH)
    print(f"Обучение завершено. Файл: {SAVE_MODEL_PATH.name}")
    RESUME_CHECKPOINT.unlink(missing_ok=True)


if __name__ == "__main__":
    main()