Add EmoSet training and evaluation

This commit is contained in:
zin
2026-01-13 10:00:04 +00:00
parent ea505b4f9c
commit 00a8472c8d
547 changed files with 62318 additions and 0 deletions

537
src/acc_images_model.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,614 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09f9237a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
]
}
],
"source": [
"!pip install datasets tqdm pillow requests\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f0b2e2c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "868d872a109d49f9966f2f19985e7048",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06741794289540849ad179c5966dcab8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e47aad5270144913996cb5b226213ab9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30d1492a948245e3b6b58e92218cd760",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "931823b458cb4696b459e9011537cf1e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "846f4245b16d4cc096a43c940590ad11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172981d77fc941cfa32c05f5a34bf742",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2baa935ed3524a73883909752cb15907",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9c0baac101b449794155392f07b49c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a5b966f79314e069251462bff82395f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "422974f938924910a0712b30a9c2bd84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f155a08427094de7ad1a5884e623db2b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a94a4621d19f45f690e0064fee83767b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50f55b00a27b4213b573b398e5b0d708",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c5815040f0a4a31903348a8327811a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 94481\n",
" })\n",
" val: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 5905\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 17716\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"# куда сохраняем датасет\n",
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
"\n",
"# загружаем через Hugging Face\n",
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
"\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "052ab073",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"from pathlib import Path\n",
"\n",
"def save_split(split):\n",
" split_dir = DATA_DIR / split\n",
" img_dir = split_dir / \"images\"\n",
" img_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" labels_path = split_dir / \"labels.csv\"\n",
"\n",
" # перезаписываем labels.csv\n",
" with open(labels_path, \"w\") as f:\n",
" f.write(\"filename,label\\n\")\n",
"\n",
" for example in tqdm(ds[split]):\n",
" img = example[\"image\"] # уже PIL.Image\n",
" label = example[\"emotion\"]\n",
" image_id = example[\"image_id\"]\n",
"\n",
" fname = f\"{image_id}.jpg\"\n",
" img.save(img_dir / fname)\n",
"\n",
" with open(labels_path, \"a\") as f:\n",
" f.write(f\"{fname},{label}\\n\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a74ceedf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
]
}
],
"source": [
"save_split(\"train\")\n",
"save_split(\"val\")\n",
"save_split(\"test\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Binary file not shown.

View File

View File

@@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d70d8e32",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "31b0fa82",
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"TRANSFORM = transforms.Compose([\n",
" transforms.Resize((224,224)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a17ecf5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
]
},
{
"ename": "PicklingError",
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
]
}
],
"source": [
"def process_row(row, split_dir, tensor_dir):\n",
" img_path = split_dir / row.filename\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" tensor = TRANSFORM(img)\n",
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
" torch.save(tensor, tensor_path)\n",
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
"\n",
"for split in [\"train\",\"val\",\"test\"]:\n",
" split_dir = DATA_ROOT / split / \"images\"\n",
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
"\n",
" results = []\n",
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
" for f in tqdm(futures):\n",
" results.append(f.result())\n",
"\n",
" new_df = pd.DataFrame(results)\n",
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

192
src/test.ipynb Normal file
View File

@@ -0,0 +1,192 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ca08df84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Step 0/1000, Loss: 1.0042\n",
"Step 10/1000, Loss: 1.0012\n",
"Step 20/1000, Loss: 0.9928\n",
"Step 30/1000, Loss: 0.9793\n",
"Step 40/1000, Loss: 0.9584\n",
"Step 50/1000, Loss: 0.9277\n",
"Step 60/1000, Loss: 0.8837\n",
"Step 70/1000, Loss: 0.8089\n",
"Step 80/1000, Loss: 0.7500\n",
"Step 90/1000, Loss: 0.6431\n",
"Step 100/1000, Loss: 0.5583\n",
"Step 110/1000, Loss: 0.4861\n",
"Step 120/1000, Loss: 0.4373\n",
"Step 130/1000, Loss: 0.3796\n",
"Step 140/1000, Loss: 0.3441\n",
"Step 150/1000, Loss: 0.3014\n",
"Step 160/1000, Loss: 0.2682\n",
"Step 170/1000, Loss: 0.2513\n",
"Step 180/1000, Loss: 0.2077\n",
"Step 190/1000, Loss: 0.1861\n",
"Step 200/1000, Loss: 0.1674\n",
"Step 210/1000, Loss: 0.1552\n",
"Step 220/1000, Loss: 0.1284\n",
"Step 230/1000, Loss: 0.1103\n",
"Step 240/1000, Loss: 0.1114\n",
"Step 250/1000, Loss: 0.0993\n",
"Step 260/1000, Loss: 0.0870\n",
"Step 270/1000, Loss: 0.1048\n",
"Step 280/1000, Loss: 0.0787\n",
"Step 290/1000, Loss: 0.0666\n",
"Step 300/1000, Loss: 0.0603\n",
"Step 310/1000, Loss: 0.0668\n",
"Step 320/1000, Loss: 0.0466\n",
"Step 330/1000, Loss: 0.0398\n",
"Step 340/1000, Loss: 0.0421\n",
"Step 350/1000, Loss: 0.0494\n",
"Step 360/1000, Loss: 0.0381\n",
"Step 370/1000, Loss: 0.0329\n",
"Step 380/1000, Loss: 0.0529\n",
"Step 390/1000, Loss: 0.0306\n",
"Step 400/1000, Loss: 0.0257\n",
"Step 410/1000, Loss: 0.0292\n",
"Step 420/1000, Loss: 0.0260\n",
"Step 430/1000, Loss: 0.0225\n",
"Step 440/1000, Loss: 0.0262\n",
"Step 450/1000, Loss: 0.0220\n",
"Step 460/1000, Loss: 0.0206\n",
"Step 470/1000, Loss: 0.0354\n",
"Step 480/1000, Loss: 0.0271\n",
"Step 490/1000, Loss: 0.0214\n",
"Step 500/1000, Loss: 0.0186\n",
"Step 510/1000, Loss: 0.0153\n",
"Step 520/1000, Loss: 0.0141\n",
"Step 530/1000, Loss: 0.0349\n",
"Step 540/1000, Loss: 0.0186\n",
"Step 550/1000, Loss: 0.0157\n",
"Step 560/1000, Loss: 0.0142\n",
"Step 570/1000, Loss: 0.0345\n",
"Step 580/1000, Loss: 0.0166\n",
"Step 590/1000, Loss: 0.0156\n",
"Step 600/1000, Loss: 0.0136\n",
"Step 610/1000, Loss: 0.0434\n",
"Step 620/1000, Loss: 0.0210\n",
"Step 630/1000, Loss: 0.0170\n",
"Step 640/1000, Loss: 0.0177\n",
"Step 650/1000, Loss: 0.0312\n",
"Step 660/1000, Loss: 0.0219\n",
"Step 670/1000, Loss: 0.0169\n",
"Step 680/1000, Loss: 0.0149\n",
"Step 690/1000, Loss: 0.0172\n",
"Step 700/1000, Loss: 0.0131\n",
"Step 710/1000, Loss: 0.0257\n",
"Step 720/1000, Loss: 0.0158\n",
"Step 730/1000, Loss: 0.0152\n",
"Step 740/1000, Loss: 0.0166\n",
"Step 750/1000, Loss: 0.0301\n",
"Step 760/1000, Loss: 0.0187\n",
"Step 770/1000, Loss: 0.0179\n",
"Step 780/1000, Loss: 0.0177\n",
"Step 790/1000, Loss: 0.0168\n",
"Step 800/1000, Loss: 0.0297\n",
"Step 810/1000, Loss: 0.0182\n",
"Step 820/1000, Loss: 0.0158\n",
"Step 830/1000, Loss: 0.0143\n",
"Step 840/1000, Loss: 0.0145\n",
"Step 850/1000, Loss: 0.0175\n",
"Step 860/1000, Loss: 0.0200\n",
"Step 870/1000, Loss: 0.0157\n",
"Step 880/1000, Loss: 0.0133\n",
"Step 890/1000, Loss: 0.0159\n",
"Step 900/1000, Loss: 0.0190\n",
"Step 910/1000, Loss: 0.0146\n",
"Step 920/1000, Loss: 0.0132\n",
"Step 930/1000, Loss: 0.0161\n",
"Step 940/1000, Loss: 0.0134\n",
"Step 950/1000, Loss: 0.0249\n",
"Step 960/1000, Loss: 0.0146\n",
"Step 970/1000, Loss: 0.0116\n",
"Step 980/1000, Loss: 0.0136\n",
"Step 990/1000, Loss: 0.0133\n",
"Total time: 66.18 s\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import time\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Using device:\", device)\n",
"\n",
"# Огромные параметры\n",
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
"batch_size = 16_384 # большой батч\n",
"steps = 1000 # много итераций для длительной нагрузки\n",
"\n",
"# Случайные данные на GPU\n",
"x = torch.randn(N, D_in, device=device, dtype=torch.float32)\n",
"y = torch.randn(N, D_out, device=device, dtype=torch.float32)\n",
"\n",
"model = nn.Sequential(\n",
" nn.Linear(D_in, H1),\n",
" nn.ReLU(),\n",
" nn.Linear(H1, H2),\n",
" nn.ReLU(),\n",
" nn.Linear(H2, H3),\n",
" nn.ReLU(),\n",
" nn.Linear(H3, D_out)\n",
").to(device)\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"start = time.time()\n",
"for t in range(steps):\n",
" idx = torch.randint(0, N, (batch_size,), device=device)\n",
" x_batch = x[idx]\n",
" y_batch = y[idx]\n",
"\n",
" y_pred = model(x_batch)\n",
" loss = loss_fn(y_pred, y_batch)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if t % 10 == 0:\n",
" # замедляем вывод, чтобы можно было наблюдать\n",
" print(f\"Step {t}/{steps}, Loss: {loss.item():.4f}\")\n",
"\n",
"end = time.time()\n",
"print(f\"Total time: {end-start:.2f} s\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (my-python-project)",
"language": "python",
"name": "my-python-project"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

759
src/train_images.ipynb Normal file
View File

@@ -0,0 +1,759 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9336560f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"import timm\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "84c3657f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# === CONFIG ===\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64  # fits comfortably in V100 memory\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 24\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"DEVICE\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
"    \"\"\"One EmoSet split backed by ``labels.csv`` and an ``images/`` folder.\n",
"\n",
"    Each CSV row supplies a ``filename`` and a string ``label``; labels\n",
"    are mapped to contiguous integer ids in sorted order.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, root, split):\n",
"        self.root = Path(root) / split\n",
"        self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
"        self.labels = sorted(self.df[\"label\"].unique())\n",
"        self.label2idx = {label: i for i, label in enumerate(self.labels)}\n",
"        self.idx2label = {i: label for i, label in enumerate(self.labels)}\n",
"\n",
"        # Standard ImageNet preprocessing (matches the pretrained backbone).\n",
"        self.transform = T.Compose([\n",
"            T.Resize((224, 224)),\n",
"            T.ToTensor(),\n",
"            T.Normalize(\n",
"                mean=[0.485, 0.456, 0.406],\n",
"                std=[0.229, 0.224, 0.225]\n",
"            )\n",
"        ])\n",
"\n",
"    def __len__(self):\n",
"        return len(self.df)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        row = self.df.iloc[idx]\n",
"        img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
"        image = Image.open(img_path).convert(\"RGB\")\n",
"        tensor = self.transform(image)\n",
"\n",
"        target = self.label2idx[row[\"label\"]]\n",
"        return tensor, target\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8805341",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
]
}
],
"source": [
"def make_loader(ds, shuffle):\n",
"    \"\"\"Build a DataLoader with the shared batching/worker settings.\"\"\"\n",
"    return DataLoader(\n",
"        ds,\n",
"        batch_size=BATCH_SIZE,\n",
"        shuffle=shuffle,\n",
"        num_workers=NUM_WORKERS,\n",
"        pin_memory=True\n",
"    )\n",
"\n",
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"train_loader = make_loader(train_ds, shuffle=True)\n",
"val_loader = make_loader(val_ds, shuffle=False)\n",
"\n",
"print(\"Classes:\", train_ds.labels)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dffce582",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ResNet-50 pretrained on ImageNet, with the classifier head resized\n",
"# to the number of EmoSet emotion classes.\n",
"num_classes = len(train_ds.labels)\n",
"model = timm.create_model(\"resnet50\", pretrained=True, num_classes=num_classes)\n",
"model.to(DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy on logits; AdamW with mild weight decay and a cosine\n",
"# schedule that anneals the learning rate over the full training run.\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, loader):\n",
"    \"\"\"Run one optimization pass over ``loader``.\n",
"\n",
"    Returns ``(mean_loss, accuracy)`` over all samples. Uses the\n",
"    notebook-level ``optimizer``, ``criterion`` and ``DEVICE``.\n",
"    \"\"\"\n",
"    model.train()\n",
"    running_loss, n_correct, n_seen = 0.0, 0, 0\n",
"\n",
"    for imgs, labels in tqdm(loader, leave=False):\n",
"        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
"\n",
"        optimizer.zero_grad()\n",
"        logits = model(imgs)\n",
"        loss = criterion(logits, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        batch = labels.size(0)\n",
"        running_loss += loss.item() * batch\n",
"        n_correct += (logits.argmax(dim=1) == labels).sum().item()\n",
"        n_seen += batch\n",
"\n",
"    return running_loss / n_seen, n_correct / n_seen\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb7e9398",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def val_epoch(model, loader):\n",
"    \"\"\"Evaluate ``model`` on ``loader`` without gradient tracking.\n",
"\n",
"    Returns ``(mean_loss, accuracy)`` over all samples. Uses the\n",
"    notebook-level ``criterion`` and ``DEVICE``.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    running_loss, n_correct, n_seen = 0.0, 0, 0\n",
"\n",
"    for imgs, labels in loader:\n",
"        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
"\n",
"        logits = model(imgs)\n",
"        loss = criterion(logits, labels)\n",
"\n",
"        batch = labels.size(0)\n",
"        running_loss += loss.item() * batch\n",
"        n_correct += (logits.argmax(dim=1) == labels).sum().item()\n",
"        n_seen += batch\n",
"\n",
"    return running_loss / n_seen, n_correct / n_seen\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9e870e5d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1477 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
]
}
],
"source": [
"best_val_acc = 0.0\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
"    train_loss, train_acc = train_epoch(model, train_loader)\n",
"    val_loss, val_acc = val_epoch(model, val_loader)\n",
"    scheduler.step()\n",
"\n",
"    summary = (\n",
"        f\"Epoch {epoch:02d} | \"\n",
"        f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
"        f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
"    )\n",
"    print(summary)\n",
"\n",
"    # Checkpoint whenever validation accuracy improves.\n",
"    if val_acc > best_val_acc:\n",
"        best_val_acc = val_acc\n",
"        torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7796ef11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}