Files
Thesis/src/scripts/21_train_images.ipynb
T
2026-06-08 14:49:45 +00:00

542 lines
25 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"import timm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84c3657f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Конфигурация параметров обучения и путей файловой системы\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 62\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Аппаратное ускорение: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
" def __init__(self, root: Path | str, split: str):\n",
" self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
" # Формирование словарей маппинга классов\n",
" self.labels = sorted(self.df[\"label\"].unique())\n",
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
" # Базовые трансформации для валидации и теста\n",
" base_tf = [\n",
" T.ToTensor(),\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" ]\n",
"\n",
" # Внедрение аугментации исключительно для обучающей выборки (предотвращение переобучения)\n",
" if split == \"train\":\n",
" self.transform = T.Compose([\n",
" T.RandomResizedCrop(224),\n",
" T.RandomHorizontalFlip(),\n",
" *base_tf\n",
" ])\n",
" else:\n",
" self.transform = T.Compose([\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" *base_tf\n",
" ])\n",
"\n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
" def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
" # Обработка возможных исключений ввода-вывода (поврежденные JPEG-файлы в датасете)\n",
" try:\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" except Exception:\n",
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
"\n",
" img_tensor = self.transform(img)\n",
" label_idx = self.label2idx[row[\"label\"]]\n",
" \n",
" return img_tensor, label_idx"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8805341",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
]
}
],
"source": [
"# Подготовка объектов выборки\n",
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"# Инициализация итераторов с закреплением памяти (pin_memory) для ускорения передачи на GPU\n",
"train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" val_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=False,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"print(f\"Индексированные классы: {train_ds.labels}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dffce582",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TODO перед защитой, повторить оптимизаторы\n",
"# Загрузка предобученной архитектуры ResNet-50 с заменой классификационного слоя\n",
"model = timm.create_model(\n",
" \"resnet50\",\n",
" pretrained=True,\n",
" num_classes=len(train_ds.labels)\n",
")\n",
"model.to(device)\n",
"\n",
"# Функция потерь для многоклассовой классификации\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# Оптимизатор AdamW с L2-регуляризацией (weight_decay) для повышения обобщающей способности\n",
"optimizer = torch.optim.AdamW(\n",
" model.parameters(),\n",
" lr=LR,\n",
" weight_decay=1e-4\n",
")\n",
"\n",
"# Планировщик скорости обучения: косинусный отжиг\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer,\n",
" T_max=EPOCHS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(current_model, loader):\n",
" current_model.train()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in tqdm(loader, desc=\"Тренировка\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total_samples, correct_preds / total_samples\n",
"\n",
"@torch.no_grad()\n",
"def val_epoch(current_model, loader):\n",
" # Перевод модели в режим инференса (отключение Dropout и фиксация BatchNorm)\n",
" current_model.eval()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in tqdm(loader, desc=\"Валидация\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total_samples, correct_preds / total_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": [
"best_val_acc = 0.0\n",
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
"\n",
"print(\"Старт процесса обучения...\")\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" # Обновление шага планировщика\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
" f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
" f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" # Экспорт весов при улучшении целевой метрики\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), checkpoint_path)\n",
" print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
"\n",
"print(\"Обучение завершено.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}