542 lines
25 KiB
Plaintext
542 lines
25 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0c00b67b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
"import warnings\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"import timm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "84c3657f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Training hyper-parameters and filesystem paths\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"# Derived from the host instead of a hard-coded 62: keep two cores for the main process.\n",
"# On a 64-core host this equals the previous value; smaller hosts are not oversubscribed.\n",
"NUM_WORKERS = max(1, (os.cpu_count() or 1) - 2)\n",
"SEED = 42\n",
"\n",
"# Seed torch (this also drives DataLoader shuffling and torchvision's random transforms)\n",
"# and numpy; cuDNN kernels may still make GPU runs not bit-exact\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Аппаратное ускорение: {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9f749add",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class EmoSetDataset(Dataset):\n",
"    \"\"\"One EmoSet split: `<root>/<split>/labels.csv` plus images in `<root>/<split>/images/`.\n",
"\n",
"    Args:\n",
"        root: dataset root directory with one sub-directory per split.\n",
"        split: split name; \"train\" enables random augmentation, any other\n",
"            value uses a deterministic resize + center crop.\n",
"        labels: optional explicit class list that fixes the class -> index order.\n",
"            Pass `train_ds.labels` for val/test so every split shares one mapping.\n",
"            Defaults to the sorted unique labels of this split's CSV.\n",
"\n",
"    Raises:\n",
"        ValueError: if the CSV contains labels that are not in the class list.\n",
"\n",
"    Items are `(image_tensor, label_index)` pairs; the CSV must provide\n",
"    `filename` and `label` columns.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, root: Path | str, split: str, labels: list[str] | None = None):\n",
"        self.root = Path(root) / split\n",
"        self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
"        # Class <-> index mappings. Derived per split by default, so two splits only agree\n",
"        # when they contain exactly the same classes; pass `labels` to pin the order.\n",
"        self.labels = sorted(self.df[\"label\"].unique()) if labels is None else list(labels)\n",
"        self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
"        self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
"        # Fail fast on labels outside the class list instead of a KeyError inside a worker\n",
"        unknown = set(self.df[\"label\"].unique()) - set(self.label2idx)\n",
"        if unknown:\n",
"            raise ValueError(f\"Labels in split {split!r} are missing from the class list: {sorted(unknown)}\")\n",
"\n",
"        # Shared tail: tensor conversion + normalization with ImageNet mean/std\n",
"        base_tf = [\n",
"            T.ToTensor(),\n",
"            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"        ]\n",
"\n",
"        # Random augmentation only for the training split (regularization against overfitting)\n",
"        if split == \"train\":\n",
"            self.transform = T.Compose([\n",
"                T.RandomResizedCrop(224),\n",
"                T.RandomHorizontalFlip(),\n",
"                *base_tf\n",
"            ])\n",
"        else:\n",
"            self.transform = T.Compose([\n",
"                T.Resize(256),\n",
"                T.CenterCrop(224),\n",
"                *base_tf\n",
"            ])\n",
"\n",
"    def __len__(self):\n",
"        return len(self.df)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        row = self.df.iloc[idx]\n",
"        img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
"        # Unreadable/corrupted files (PIL raises OSError subclasses) are replaced by a black\n",
"        # placeholder so one bad JPEG does not abort the epoch; the warning keeps them visible.\n",
"        try:\n",
"            with Image.open(img_path) as src:\n",
"                img = src.convert(\"RGB\")\n",
"        except OSError as exc:\n",
"            warnings.warn(f\"Could not read {img_path}: {exc}; using a black placeholder\")\n",
"            img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
"\n",
"        img_tensor = self.transform(img)\n",
"        label_idx = self.label2idx[row[\"label\"]]\n",
"\n",
"        return img_tensor, label_idx"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c8805341",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build the train/val splits\n",
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"# Each split builds its label -> index mapping from its own CSV; a class absent from one\n",
"# split would silently shift the indices, so require both splits to agree\n",
"assert val_ds.label2idx == train_ds.label2idx, (\n",
"    f\"Train/val class mismatch: {train_ds.labels} vs {val_ds.labels}\"\n",
")\n",
"\n",
"# pin_memory=True places batches in page-locked host memory for faster host->GPU copies\n",
"train_loader = DataLoader(\n",
"    train_ds,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=True,\n",
"    num_workers=NUM_WORKERS,\n",
"    pin_memory=True\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
"    val_ds,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=False,\n",
"    num_workers=NUM_WORKERS,\n",
"    pin_memory=True\n",
")\n",
"\n",
"print(f\"Индексированные классы: {train_ds.labels}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "dffce582",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"ResNet(\n",
|
|
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
|
" (layer1): Sequential(\n",
|
|
" (0): Bottleneck(\n",
|
|
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" (downsample): Sequential(\n",
|
|
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): Bottleneck(\n",
|
|
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (2): Bottleneck(\n",
|
|
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (layer2): Sequential(\n",
|
|
" (0): Bottleneck(\n",
|
|
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" (downsample): Sequential(\n",
|
|
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
|
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): Bottleneck(\n",
|
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (2): Bottleneck(\n",
|
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (3): Bottleneck(\n",
|
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (layer3): Sequential(\n",
|
|
" (0): Bottleneck(\n",
|
|
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" (downsample): Sequential(\n",
|
|
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
|
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (2): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (3): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (4): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (5): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (layer4): Sequential(\n",
|
|
" (0): Bottleneck(\n",
|
|
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" (downsample): Sequential(\n",
|
|
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
|
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): Bottleneck(\n",
|
|
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" (2): Bottleneck(\n",
|
|
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act1): ReLU(inplace=True)\n",
|
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (drop_block): Identity()\n",
|
|
" (act2): ReLU(inplace=True)\n",
|
|
" (aa): Identity()\n",
|
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (act3): ReLU(inplace=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
|
|
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# ResNet-50 with pretrained timm weights; the classification head is replaced\n",
"# by a new one sized to the number of dataset classes\n",
"model = timm.create_model(\n",
"    \"resnet50\",\n",
"    pretrained=True,\n",
"    num_classes=len(train_ds.labels)\n",
")\n",
"model = model.to(device)\n",
"\n",
"# Cross-entropy loss for multi-class classification\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# AdamW uses decoupled weight decay (weights are shrunk directly, not via an L2 term\n",
"# added to the loss); here it is applied to all parameters, including biases and BatchNorm\n",
"optimizer = torch.optim.AdamW(\n",
"    model.parameters(),\n",
"    lr=LR,\n",
"    weight_decay=1e-4\n",
")\n",
"\n",
"# Cosine annealing of the learning rate over EPOCHS scheduler steps\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
"    optimizer,\n",
"    T_max=EPOCHS\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "81a457ef",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def _run_epoch(current_model, loader, *, train, opt, loss_fn, dev, desc):\n",
"    \"\"\"Run one pass over `loader`; return (mean per-sample loss, accuracy).\n",
"\n",
"    train=True: train mode, `opt` is stepped after every batch.\n",
"    train=False: eval mode (Dropout off, BatchNorm uses running stats), autograd disabled.\n",
"    Metrics are accumulated as tensors on `dev` and read back once at the end, so the\n",
"    loop does not force a device synchronization on every batch.\n",
"    Assumes `loss_fn` returns the batch mean (the CrossEntropyLoss default).\n",
"\n",
"    Raises:\n",
"        ValueError: if `loader` yields no samples.\n",
"    \"\"\"\n",
"    current_model.train(train)  # train(False) is the same as eval()\n",
"    loss_sum = torch.zeros((), dtype=torch.float64, device=dev)\n",
"    correct = torch.zeros((), dtype=torch.int64, device=dev)\n",
"    seen = 0\n",
"\n",
"    with torch.set_grad_enabled(train):\n",
"        for imgs, labels in tqdm(loader, desc=desc, leave=False):\n",
"            # non_blocking copies can overlap with compute when the DataLoader pins memory\n",
"            imgs = imgs.to(dev, non_blocking=True)\n",
"            labels = labels.to(dev, non_blocking=True)\n",
"\n",
"            if train:\n",
"                opt.zero_grad()\n",
"            logits = current_model(imgs)\n",
"            loss = loss_fn(logits, labels)\n",
"            if train:\n",
"                loss.backward()\n",
"                opt.step()\n",
"\n",
"            batch_size = labels.size(0)\n",
"            loss_sum += loss.detach().double() * batch_size\n",
"            correct += (logits.argmax(dim=1) == labels).sum()\n",
"            seen += batch_size\n",
"\n",
"    if seen == 0:\n",
"        raise ValueError(\"loader yielded no samples\")\n",
"    return loss_sum.item() / seen, correct.item() / seen\n",
"\n",
"\n",
"def train_epoch(current_model, loader, opt=None, loss_fn=None, dev=None):\n",
"    \"\"\"Run one optimization epoch; return (train loss, train accuracy).\n",
"\n",
"    `opt`, `loss_fn` and `dev` default to the notebook globals `optimizer`,\n",
"    `criterion` and `device`, looked up at call time.\n",
"    \"\"\"\n",
"    return _run_epoch(\n",
"        current_model, loader, train=True,\n",
"        opt=optimizer if opt is None else opt,\n",
"        loss_fn=criterion if loss_fn is None else loss_fn,\n",
"        dev=device if dev is None else dev,\n",
"        desc=\"Тренировка\",\n",
"    )\n",
"\n",
"\n",
"def val_epoch(current_model, loader, loss_fn=None, dev=None):\n",
"    \"\"\"Evaluate without gradients in eval mode; return (val loss, val accuracy).\n",
"\n",
"    `loss_fn` and `dev` default to the notebook globals `criterion` and `device`,\n",
"    looked up at call time.\n",
"    \"\"\"\n",
"    return _run_epoch(\n",
"        current_model, loader, train=False,\n",
"        opt=None,\n",
"        loss_fn=criterion if loss_fn is None else loss_fn,\n",
"        dev=device if dev is None else dev,\n",
"        desc=\"Валидация\",\n",
"    )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "951aa9e3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# -inf guarantees that the first completed epoch always writes a checkpoint\n",
"best_val_acc = float(\"-inf\")\n",
"best_epoch = 0\n",
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
"\n",
"print(\"Старт процесса обучения...\")\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
"    # Learning rate used during this epoch (read before the scheduler advances it)\n",
"    epoch_lr = optimizer.param_groups[0][\"lr\"]\n",
"\n",
"    train_loss, train_acc = train_epoch(model, train_loader)\n",
"    val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
"    # Advance the cosine schedule once per epoch\n",
"    scheduler.step()\n",
"\n",
"    print(\n",
"        f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
"        f\"LR: {epoch_lr:.2e} | \"\n",
"        f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
"        f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
"    )\n",
"\n",
"    # Save weights whenever validation accuracy improves\n",
"    if val_acc > best_val_acc:\n",
"        best_val_acc = val_acc\n",
"        best_epoch = epoch\n",
"        torch.save(model.state_dict(), checkpoint_path)\n",
"        print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
"\n",
"print(\"Обучение завершено.\")\n",
"print(f\"Лучшая точность на валидации: {best_val_acc:.4f} (эпоха {best_epoch})\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "thesis-py3.11",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|