Add EmoSet training and evaluation
This commit is contained in:
537
src/acc_images_model.ipynb
Normal file
537
src/acc_images_model.ipynb
Normal file
File diff suppressed because one or more lines are too long
614
src/download emo_dataset.ipynb
Normal file
614
src/download emo_dataset.ipynb
Normal file
@@ -0,0 +1,614 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "09f9237a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
|
||||
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
|
||||
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
|
||||
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
|
||||
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
|
||||
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
|
||||
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
|
||||
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
|
||||
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
|
||||
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
|
||||
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
|
||||
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
|
||||
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
|
||||
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
|
||||
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
|
||||
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
|
||||
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
|
||||
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
|
||||
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
|
||||
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
|
||||
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
|
||||
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
|
||||
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
|
||||
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
|
||||
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
|
||||
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
|
||||
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
|
||||
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install datasets tqdm pillow requests\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "6f0b2e2c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "868d872a109d49f9966f2f19985e7048",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "06741794289540849ad179c5966dcab8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e47aad5270144913996cb5b226213ab9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "30d1492a948245e3b6b58e92218cd760",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "931823b458cb4696b459e9011537cf1e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "846f4245b16d4cc096a43c940590ad11",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "172981d77fc941cfa32c05f5a34bf742",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2baa935ed3524a73883909752cb15907",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d9c0baac101b449794155392f07b49c3",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3a5b966f79314e069251462bff82395f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "422974f938924910a0712b30a9c2bd84",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f155a08427094de7ad1a5884e623db2b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a94a4621d19f45f690e0064fee83767b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "50f55b00a27b4213b573b398e5b0d708",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0c5815040f0a4a31903348a8327811a5",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"DatasetDict({\n",
|
||||
" train: Dataset({\n",
|
||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||
" num_rows: 94481\n",
|
||||
" })\n",
|
||||
" val: Dataset({\n",
|
||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||
" num_rows: 5905\n",
|
||||
" })\n",
|
||||
" test: Dataset({\n",
|
||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||
" num_rows: 17716\n",
|
||||
" })\n",
|
||||
"})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
from datasets import load_dataset
from pathlib import Path
from PIL import Image
import requests

# Directory where the exported dataset will be saved.
DATA_DIR = Path("../dataset/EmoSet-118K")
DATA_DIR.mkdir(exist_ok=True, parents=True)

# Download the dataset via Hugging Face (train/val/test parquet shards;
# cached by the `datasets` library on subsequent runs).
ds = load_dataset("Woleek/EmoSet-118K")

print(ds)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "052ab073",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
from tqdm import tqdm
from pathlib import Path

def save_split(split: str) -> None:
    """Export one split of the global `ds` DatasetDict to disk.

    Writes JPEG images under DATA_DIR/<split>/images/ and a
    DATA_DIR/<split>/labels.csv mapping filename -> emotion label.

    Args:
        split: split name present in `ds` ("train", "val" or "test").
    """
    split_dir = DATA_DIR / split
    img_dir = split_dir / "images"
    img_dir.mkdir(parents=True, exist_ok=True)

    labels_path = split_dir / "labels.csv"

    # Open labels.csv once for the whole split. The original re-opened the
    # file in append mode for every single example (~94k opens on train).
    with open(labels_path, "w") as f:
        f.write("filename,label\n")
        for example in tqdm(ds[split]):
            img = example["image"]  # already a PIL.Image
            label = example["emotion"]
            image_id = example["image_id"]

            fname = f"{image_id}.jpg"
            # JPEG cannot encode alpha/palette images; normalize to RGB so
            # .save() does not raise OSError on RGBA/P-mode inputs.
            if img.mode != "RGB":
                img = img.convert("RGB")
            img.save(img_dir / fname)
            f.write(f"{fname},{label}\n")
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "a74ceedf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
|
||||
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
|
||||
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
# Materialize every split to disk (images + labels.csv).
save_split("train")
save_split("val")
save_split("test")
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "thesis-py3.11",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
BIN
src/emoset_resnet50_best.pth
Normal file
BIN
src/emoset_resnet50_best.pth
Normal file
Binary file not shown.
0
src/extract_embeddings.ipynb
Normal file
0
src/extract_embeddings.ipynb
Normal file
114
src/prepare_dataset_to_train.ipynb
Normal file
114
src/prepare_dataset_to_train.ipynb
Normal file
@@ -0,0 +1,114 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d70d8e32",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from concurrent.futures import ProcessPoolExecutor\n",
|
||||
"import pandas as pd\n",
|
||||
"from pathlib import Path\n",
|
||||
"from PIL import Image\n",
|
||||
"import torch\n",
|
||||
"from torchvision import transforms\n",
|
||||
"from tqdm import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "31b0fa82",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# Root of the exported EmoSet-118K dataset (written by the download notebook).
DATA_ROOT = Path("../dataset/EmoSet-118K")
# Standard ImageNet-style preprocessing: resize to the 224x224 input expected
# by torchvision classification backbones, then normalize per channel with
# the ImageNet means/stds.
TRANSFORM = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "1a17ecf5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "PicklingError",
|
||||
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
|
||||
"\nThe above exception was the direct cause of the following exception:\n",
|
||||
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
|
||||
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
def process_row(filename: str, label, split_dir: Path, tensor_dir: Path) -> dict:
    """Load one image, apply TRANSFORM, persist the tensor, return a record.

    Takes plain picklable arguments (filename, label) instead of a pandas
    itertuples row: the row class is a dynamically generated namedtuple
    ("pandas.core.frame.Pandas") that cannot be pickled, which made every
    ProcessPoolExecutor submission fail with PicklingError in the original.

    Args:
        filename: image file name relative to split_dir.
        label: class label copied through to the returned record.
        split_dir: directory containing the source images.
        tensor_dir: directory where the .pt tensor is written.

    Returns:
        {"tensor_path": <path to saved tensor>, "label": <label>}.
    """
    img_path = split_dir / filename
    img = Image.open(img_path).convert("RGB")
    tensor = TRANSFORM(img)
    tensor_path = tensor_dir / f"{filename}.pt"
    torch.save(tensor, tensor_path)
    return {"tensor_path": str(tensor_path), "label": label}

for split in ["train", "val", "test"]:
    split_dir = DATA_ROOT / split / "images"
    tensor_dir = DATA_ROOT / split / "tensors"
    tensor_dir.mkdir(exist_ok=True, parents=True)

    df = pd.read_csv(DATA_ROOT / split / "labels.csv")

    results = []
    with ProcessPoolExecutor(max_workers=12) as executor:
        # Submit plain (filename, label) values — both picklable — rather
        # than the itertuples row objects, which are not.
        futures = [
            executor.submit(process_row, row.filename, row.label, split_dir, tensor_dir)
            for row in df.itertuples(index=False)
        ]
        for future in tqdm(futures):
            results.append(future.result())

    new_df = pd.DataFrame(results)
    new_df.to_csv(DATA_ROOT / split / "labels_tensor.csv", index=False)
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "thesis-py3.11",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
192
src/test.ipynb
Normal file
192
src/test.ipynb
Normal file
@@ -0,0 +1,192 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ca08df84",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cuda\n",
|
||||
"Step 0/1000, Loss: 1.0042\n",
|
||||
"Step 10/1000, Loss: 1.0012\n",
|
||||
"Step 20/1000, Loss: 0.9928\n",
|
||||
"Step 30/1000, Loss: 0.9793\n",
|
||||
"Step 40/1000, Loss: 0.9584\n",
|
||||
"Step 50/1000, Loss: 0.9277\n",
|
||||
"Step 60/1000, Loss: 0.8837\n",
|
||||
"Step 70/1000, Loss: 0.8089\n",
|
||||
"Step 80/1000, Loss: 0.7500\n",
|
||||
"Step 90/1000, Loss: 0.6431\n",
|
||||
"Step 100/1000, Loss: 0.5583\n",
|
||||
"Step 110/1000, Loss: 0.4861\n",
|
||||
"Step 120/1000, Loss: 0.4373\n",
|
||||
"Step 130/1000, Loss: 0.3796\n",
|
||||
"Step 140/1000, Loss: 0.3441\n",
|
||||
"Step 150/1000, Loss: 0.3014\n",
|
||||
"Step 160/1000, Loss: 0.2682\n",
|
||||
"Step 170/1000, Loss: 0.2513\n",
|
||||
"Step 180/1000, Loss: 0.2077\n",
|
||||
"Step 190/1000, Loss: 0.1861\n",
|
||||
"Step 200/1000, Loss: 0.1674\n",
|
||||
"Step 210/1000, Loss: 0.1552\n",
|
||||
"Step 220/1000, Loss: 0.1284\n",
|
||||
"Step 230/1000, Loss: 0.1103\n",
|
||||
"Step 240/1000, Loss: 0.1114\n",
|
||||
"Step 250/1000, Loss: 0.0993\n",
|
||||
"Step 260/1000, Loss: 0.0870\n",
|
||||
"Step 270/1000, Loss: 0.1048\n",
|
||||
"Step 280/1000, Loss: 0.0787\n",
|
||||
"Step 290/1000, Loss: 0.0666\n",
|
||||
"Step 300/1000, Loss: 0.0603\n",
|
||||
"Step 310/1000, Loss: 0.0668\n",
|
||||
"Step 320/1000, Loss: 0.0466\n",
|
||||
"Step 330/1000, Loss: 0.0398\n",
|
||||
"Step 340/1000, Loss: 0.0421\n",
|
||||
"Step 350/1000, Loss: 0.0494\n",
|
||||
"Step 360/1000, Loss: 0.0381\n",
|
||||
"Step 370/1000, Loss: 0.0329\n",
|
||||
"Step 380/1000, Loss: 0.0529\n",
|
||||
"Step 390/1000, Loss: 0.0306\n",
|
||||
"Step 400/1000, Loss: 0.0257\n",
|
||||
"Step 410/1000, Loss: 0.0292\n",
|
||||
"Step 420/1000, Loss: 0.0260\n",
|
||||
"Step 430/1000, Loss: 0.0225\n",
|
||||
"Step 440/1000, Loss: 0.0262\n",
|
||||
"Step 450/1000, Loss: 0.0220\n",
|
||||
"Step 460/1000, Loss: 0.0206\n",
|
||||
"Step 470/1000, Loss: 0.0354\n",
|
||||
"Step 480/1000, Loss: 0.0271\n",
|
||||
"Step 490/1000, Loss: 0.0214\n",
|
||||
"Step 500/1000, Loss: 0.0186\n",
|
||||
"Step 510/1000, Loss: 0.0153\n",
|
||||
"Step 520/1000, Loss: 0.0141\n",
|
||||
"Step 530/1000, Loss: 0.0349\n",
|
||||
"Step 540/1000, Loss: 0.0186\n",
|
||||
"Step 550/1000, Loss: 0.0157\n",
|
||||
"Step 560/1000, Loss: 0.0142\n",
|
||||
"Step 570/1000, Loss: 0.0345\n",
|
||||
"Step 580/1000, Loss: 0.0166\n",
|
||||
"Step 590/1000, Loss: 0.0156\n",
|
||||
"Step 600/1000, Loss: 0.0136\n",
|
||||
"Step 610/1000, Loss: 0.0434\n",
|
||||
"Step 620/1000, Loss: 0.0210\n",
|
||||
"Step 630/1000, Loss: 0.0170\n",
|
||||
"Step 640/1000, Loss: 0.0177\n",
|
||||
"Step 650/1000, Loss: 0.0312\n",
|
||||
"Step 660/1000, Loss: 0.0219\n",
|
||||
"Step 670/1000, Loss: 0.0169\n",
|
||||
"Step 680/1000, Loss: 0.0149\n",
|
||||
"Step 690/1000, Loss: 0.0172\n",
|
||||
"Step 700/1000, Loss: 0.0131\n",
|
||||
"Step 710/1000, Loss: 0.0257\n",
|
||||
"Step 720/1000, Loss: 0.0158\n",
|
||||
"Step 730/1000, Loss: 0.0152\n",
|
||||
"Step 740/1000, Loss: 0.0166\n",
|
||||
"Step 750/1000, Loss: 0.0301\n",
|
||||
"Step 760/1000, Loss: 0.0187\n",
|
||||
"Step 770/1000, Loss: 0.0179\n",
|
||||
"Step 780/1000, Loss: 0.0177\n",
|
||||
"Step 790/1000, Loss: 0.0168\n",
|
||||
"Step 800/1000, Loss: 0.0297\n",
|
||||
"Step 810/1000, Loss: 0.0182\n",
|
||||
"Step 820/1000, Loss: 0.0158\n",
|
||||
"Step 830/1000, Loss: 0.0143\n",
|
||||
"Step 840/1000, Loss: 0.0145\n",
|
||||
"Step 850/1000, Loss: 0.0175\n",
|
||||
"Step 860/1000, Loss: 0.0200\n",
|
||||
"Step 870/1000, Loss: 0.0157\n",
|
||||
"Step 880/1000, Loss: 0.0133\n",
|
||||
"Step 890/1000, Loss: 0.0159\n",
|
||||
"Step 900/1000, Loss: 0.0190\n",
|
||||
"Step 910/1000, Loss: 0.0146\n",
|
||||
"Step 920/1000, Loss: 0.0132\n",
|
||||
"Step 930/1000, Loss: 0.0161\n",
|
||||
"Step 940/1000, Loss: 0.0134\n",
|
||||
"Step 950/1000, Loss: 0.0249\n",
|
||||
"Step 960/1000, Loss: 0.0146\n",
|
||||
"Step 970/1000, Loss: 0.0116\n",
|
||||
"Step 980/1000, Loss: 0.0136\n",
|
||||
"Step 990/1000, Loss: 0.0133\n",
|
||||
"Total time: 66.18 s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
# GPU stress test: trains an oversized MLP on random data to keep the
# device busy for a sustained period and report total wall-clock time.
import torch
import torch.nn as nn
import torch.optim as optim
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Deliberately huge dimensions to generate heavy, long-running load.
N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10
batch_size = 16_384  # large batch
steps = 1000         # many iterations, so the load lasts long enough to observe

# Random training data resident on the device.
x = torch.randn(N, D_in, device=device, dtype=torch.float32)
y = torch.randn(N, D_out, device=device, dtype=torch.float32)

# Build the MLP as Linear+ReLU pairs followed by the output projection.
hidden_pairs = [(D_in, H1), (H1, H2), (H2, H3)]
modules = []
for fan_in, fan_out in hidden_pairs:
    modules.append(nn.Linear(fan_in, fan_out))
    modules.append(nn.ReLU())
modules.append(nn.Linear(H3, D_out))
model = nn.Sequential(*modules).to(device)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

start = time.time()
for t in range(steps):
    # Sample a random mini-batch directly on the device.
    idx = torch.randint(0, N, (batch_size,), device=device)
    x_batch, y_batch = x[idx], y[idx]

    y_pred = model(x_batch)
    loss = loss_fn(y_pred, y_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 10 == 0:
        # Periodic progress report so the run can be monitored live.
        print(f"Step {t}/{steps}, Loss: {loss.item():.4f}")

end = time.time()
print(f"Total time: {end-start:.2f} s")
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (my-python-project)",
|
||||
"language": "python",
|
||||
"name": "my-python-project"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
759
src/train_images.ipynb
Normal file
759
src/train_images.ipynb
Normal file
@@ -0,0 +1,759 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9336560f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "0c00b67b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"import torchvision.transforms as T\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"from pathlib import Path\n",
|
||||
"from PIL import Image\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"import timm\n",
|
||||
"import numpy as np\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "84c3657f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'cuda'"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# === CONFIG ===\n",
|
||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
|
||||
"EPOCHS = 15\n",
|
||||
"LR = 3e-4\n",
|
||||
"NUM_WORKERS = 24\n",
|
||||
"\n",
|
||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"DEVICE\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "9f749add",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
class EmoSetDataset(Dataset):
    """EmoSet image-classification dataset for a single split.

    Reads ``<root>/<split>/labels.csv`` (columns: filename, label),
    builds a deterministic label <-> index mapping from the sorted
    unique labels, and serves 224x224 ImageNet-normalized image tensors.
    """

    def __init__(self, root, split):
        self.root = Path(root) / split
        self.df = pd.read_csv(self.root / "labels.csv")

        # Sorted so class-index assignment is stable across runs.
        self.labels = sorted(self.df["label"].unique())
        self.label2idx = {label: index for index, label in enumerate(self.labels)}
        self.idx2label = {index: label for label, index in self.label2idx.items()}

        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            # ImageNet statistics — matches the pretrained backbone.
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = self.root / "images" / row["filename"]

        image = Image.open(img_path).convert("RGB")
        return self.transform(image), self.label2idx[row["label"]]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c8805341",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
def _make_loader(dataset, shuffle):
    """DataLoader with the shared batching / worker configuration."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

train_ds = EmoSetDataset(DATA_ROOT, "train")
val_ds = EmoSetDataset(DATA_ROOT, "val")

# Shuffle only the training split; validation order stays fixed.
train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)

print("Classes:", train_ds.labels)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "dffce582",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ResNet(\n",
|
||||
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||||
" (layer1): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer2): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer3): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (4): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (5): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer4): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
|
||||
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
# Pretrained ResNet-50 backbone from timm, with the classifier head
# sized to the number of EmoSet emotion classes (8 per the cell above).
model = timm.create_model(
    "resnet50",
    pretrained=True,
    num_classes=len(train_ds.labels)
)

# Move all parameters/buffers to the training device before building
# the optimizer in the next cell.
model.to(DEVICE)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "81a457ef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# Standard multi-class classification loss.
criterion = nn.CrossEntropyLoss()

# AdamW: Adam with decoupled weight-decay regularization.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=1e-4
)

# Cosine learning-rate decay over the whole run; the loop steps the
# scheduler once per epoch, so T_max is the epoch count.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "951aa9e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def train_epoch(model, loader, optim=None, loss_fn=None, device=None):
    """Run one training pass over `loader`.

    Args:
        model: network to train (switched to train mode).
        loader: iterable yielding (images, labels) batches.
        optim: optimizer to step; defaults to the notebook-global `optimizer`.
        loss_fn: loss function; defaults to the notebook-global `criterion`.
        device: target device; defaults to the notebook-global `DEVICE`.

    Returns:
        (mean loss per sample, accuracy) over the whole epoch.
    """
    # Fall back to the notebook-level globals only when not supplied, so the
    # function no longer hard-depends on hidden state and can be tested.
    optim = optim if optim is not None else optimizer
    loss_fn = loss_fn if loss_fn is not None else criterion
    device = device if device is not None else DEVICE

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for imgs, labels in tqdm(loader, leave=False):
        imgs = imgs.to(device)
        labels = labels.to(device)

        optim.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labels)

        loss.backward()
        optim.step()

        # Weight the batch loss by its size so the epoch mean is exact
        # even when the final batch is smaller.
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "fb7e9398",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@torch.no_grad()
def val_epoch(model, loader, loss_fn=None, device=None):
    """Evaluate `model` over `loader` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: iterable yielding (images, labels) batches.
        loss_fn: loss function; defaults to the notebook-global `criterion`.
        device: target device; defaults to the notebook-global `DEVICE`.

    Returns:
        (mean loss per sample, accuracy) over the whole pass.
    """
    # Fall back to the notebook-level globals only when not supplied, so the
    # function no longer hard-depends on hidden state and can be tested.
    loss_fn = loss_fn if loss_fn is not None else criterion
    device = device if device is not None else DEVICE

    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        logits = model(imgs)
        loss = loss_fn(logits, labels)

        # Size-weighted loss so the epoch mean is exact with ragged batches.
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "9e870e5d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/1477 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"best_val_acc = 0.0\n",
|
||||
"\n",
|
||||
"for epoch in range(1, EPOCHS + 1):\n",
|
||||
" train_loss, train_acc = train_epoch(model, train_loader)\n",
|
||||
" val_loss, val_acc = val_epoch(model, val_loader)\n",
|
||||
"\n",
|
||||
" scheduler.step()\n",
|
||||
"\n",
|
||||
" print(\n",
|
||||
" f\"Epoch {epoch:02d} | \"\n",
|
||||
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
|
||||
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if val_acc > best_val_acc:\n",
|
||||
" best_val_acc = val_acc\n",
|
||||
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7796ef11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "thesis-py3.11",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user