{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f1e2a90",
   "metadata": {},
   "source": [
    "# EmoSet-118K: fine-tuning ResNet-50 for visual emotion classification\n",
    "\n",
    "Fine-tunes a pretrained timm ResNet-50 on the `train` split of EmoSet-118K, evaluates on `val` after every epoch and keeps the weights with the best validation accuracy at `CHECKPOINT_PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c00b67b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torchvision.transforms as T\n",
    "import timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c3657f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters and filesystem paths\n",
    "DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
    "CHECKPOINT_PATH = Path(\"../emoset_resnet50_best.pth\")\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 15\n",
    "LR = 3e-4\n",
    "WEIGHT_DECAY = 1e-4\n",
    "SEED = 42\n",
    "# Up to 62 DataLoader workers as in the original setup, but never more than the CPUs available\n",
    "NUM_WORKERS = min(62, os.cpu_count() or 0)\n",
    "\n",
    "# Seed Python, NumPy and PyTorch RNGs so shuffling, augmentation and head init are repeatable\n",
    "# (up to nondeterministic GPU kernels)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Аппаратное ускорение: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7c9d12",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "Every split is indexed with the training split's class list, so label indices are identical across splits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f749add",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmoSetDataset(Dataset):\n",
    "    \"\"\"EmoSet split stored as `<root>/<split>/labels.csv` plus `<root>/<split>/images/`.\n",
    "\n",
    "    `labels.csv` must have a `filename` column (path relative to `images/`)\n",
    "    and a `label` column (class name).\n",
    "\n",
    "    Args:\n",
    "        root: dataset root directory.\n",
    "        split: split sub-directory; `\"train\"` enables random augmentation.\n",
    "        labels: ordered class names defining the class -> index mapping. Pass the\n",
    "            training split's `labels` to the other splits so all splits share the\n",
    "            same indices. Defaults to the sorted unique labels of this split.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the split contains a label that is not in `labels`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root: Path | str, split: str, labels: list[str] | None = None):\n",
    "        self.root = Path(root) / split\n",
    "        self.df = pd.read_csv(self.root / \"labels.csv\")\n",
    "\n",
    "        # Class <-> index mappings\n",
    "        self.labels = list(labels) if labels is not None else sorted(self.df[\"label\"].unique())\n",
    "        self.label2idx = {name: i for i, name in enumerate(self.labels)}\n",
    "        self.idx2label = {i: name for name, i in self.label2idx.items()}\n",
    "\n",
    "        # Fail at construction time (not mid-epoch) on labels the model has no output for\n",
    "        unknown = set(self.df[\"label\"]).difference(self.label2idx)\n",
    "        if unknown:\n",
    "            raise ValueError(f\"Split {split!r} has labels missing from the class list: {sorted(unknown)}\")\n",
    "\n",
    "        # Resolve image paths and integer targets once instead of on every __getitem__ call\n",
    "        self.image_paths = [self.root / \"images\" / name for name in self.df[\"filename\"]]\n",
    "        self.targets = [self.label2idx[name] for name in self.df[\"label\"]]\n",
    "\n",
    "        normalize = [\n",
    "            T.ToTensor(),\n",
    "            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "        if split == \"train\":\n",
    "            # Random augmentation only for the training split (regularization against overfitting)\n",
    "            self.transform = T.Compose([T.RandomResizedCrop(224), T.RandomHorizontalFlip(), *normalize])\n",
    "        else:\n",
    "            # Deterministic resize + center crop for validation / test\n",
    "            self.transform = T.Compose([T.Resize(256), T.CenterCrop(224), *normalize])\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.targets)\n",
    "\n",
    "    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:\n",
    "        img_path = self.image_paths[idx]\n",
    "        try:\n",
    "            with Image.open(img_path) as raw:\n",
    "                img = raw.convert(\"RGB\")\n",
    "        except FileNotFoundError:\n",
    "            # A missing file means a wrong path or dataset layout: fail loudly instead of\n",
    "            # silently training on placeholder images\n",
    "            raise\n",
    "        except Exception as exc:\n",
    "            # The dataset contains corrupted JPEG files; substitute a black image so the\n",
    "            # epoch can continue, but report which file was replaced\n",
    "            warnings.warn(f\"Unreadable image {img_path} ({exc!r}); using a black placeholder\")\n",
    "            img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
    "\n",
    "        return self.transform(img), self.targets[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8805341",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
    "# The validation split must use the training class order so its indices match the model outputs\n",
    "val_ds = EmoSetDataset(DATA_ROOT, \"val\", labels=train_ds.labels)\n",
    "\n",
    "# Pinned host memory speeds up host-to-GPU copies; it is useless without CUDA\n",
    "loader_kwargs = dict(\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=device.type == \"cuda\",\n",
    ")\n",
    "train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)\n",
    "val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)\n",
    "\n",
    "print(f\"Индексированные классы: {train_ds.labels}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d2e4f61",
   "metadata": {},
   "source": [
    "## Model, loss and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dffce582",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: before the thesis defense, review the optimizers\n",
    "# Pretrained timm ResNet-50 whose classification head has one output per class\n",
    "model = timm.create_model(\n",
    "    \"resnet50\",\n",
    "    pretrained=True,\n",
    "    num_classes=len(train_ds.labels),\n",
    ").to(device)\n",
    "\n",
    "# Multi-class classification loss\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# AdamW applies decoupled weight decay (not an L2 penalty added to the loss)\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
    "\n",
    "# Cosine annealing of the learning rate over EPOCHS scheduler steps (stepped once per epoch)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a8b6c03",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a457ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_epoch(current_model, loader, criterion, device, optimizer=None, desc=\"\"):\n",
    "    \"\"\"Run one full pass over `loader`.\n",
    "\n",
    "    Trains (forward, backward, optimizer step) when `optimizer` is given;\n",
    "    otherwise evaluates in eval mode with gradients disabled.\n",
    "\n",
    "    Returns:\n",
    "        (mean loss per sample, accuracy); both NaN when `loader` yields no samples.\n",
    "    \"\"\"\n",
    "    training = optimizer is not None\n",
    "    # train() / eval() switch Dropout and BatchNorm between training and inference behavior\n",
    "    current_model.train(training)\n",
    "    total_loss = 0.0\n",
    "    correct_preds = 0\n",
    "    total_samples = 0\n",
    "\n",
    "    with torch.set_grad_enabled(training):\n",
    "        for imgs, labels in tqdm(loader, desc=desc, leave=False):\n",
    "            imgs = imgs.to(device, non_blocking=True)\n",
    "            labels = labels.to(device, non_blocking=True)\n",
    "\n",
    "            logits = current_model(imgs)\n",
    "            loss = criterion(logits, labels)\n",
    "\n",
    "            if training:\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            batch_size = labels.size(0)\n",
    "            # Assumes criterion averages over the batch (CrossEntropyLoss default reduction=\"mean\")\n",
    "            total_loss += loss.item() * batch_size\n",
    "            correct_preds += (logits.argmax(dim=1) == labels).sum().item()\n",
    "            total_samples += batch_size\n",
    "\n",
    "    if total_samples == 0:\n",
    "        return float(\"nan\"), float(\"nan\")\n",
    "    return total_loss / total_samples, correct_preds / total_samples\n",
    "\n",
    "\n",
    "def train_epoch(current_model, loader, optimizer, criterion, device):\n",
    "    \"\"\"Train `current_model` for one epoch; returns (mean loss, accuracy).\"\"\"\n",
    "    return _run_epoch(current_model, loader, criterion, device, optimizer=optimizer, desc=\"Тренировка\")\n",
    "\n",
    "\n",
    "def val_epoch(current_model, loader, criterion, device):\n",
    "    \"\"\"Evaluate `current_model` without gradients; returns (mean loss, accuracy).\"\"\"\n",
    "    return _run_epoch(current_model, loader, criterion, device, desc=\"Валидация\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "951aa9e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_val_acc = 0.0\n",
    "\n",
    "print(\"Старт процесса обучения...\")\n",
    "\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)\n",
    "    val_loss, val_acc = val_epoch(model, val_loader, criterion, device)\n",
    "\n",
    "    # One scheduler step per epoch, matching T_max=EPOCHS\n",
    "    scheduler.step()\n",
    "\n",
    "    print(\n",
    "        f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
    "        f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
    "        f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
    "    )\n",
    "\n",
    "    # Save weights whenever validation accuracy improves\n",
    "    if val_acc > best_val_acc:\n",
    "        best_val_acc = val_acc\n",
    "        torch.save(model.state_dict(), CHECKPOINT_PATH)\n",
    "        print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
    "\n",
    "print(\"Обучение завершено.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "thesis-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}