ref: refactor before checkout
This commit is contained in:
@@ -2,30 +2,30 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "8523d028",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"from pathlib import Path\n",
|
||||
"from PIL import Image\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"import torchvision.transforms as T\n",
|
||||
"import timm\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n"
|
||||
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"id": "e0781b02",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -41,25 +41,26 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Конфигурация путей и параметров инференса\n",
|
||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||
"MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n",
|
||||
"\n",
|
||||
"BATCH_SIZE = 64\n",
|
||||
"NUM_WORKERS = 4\n",
|
||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"\n",
|
||||
"DEVICE\n"
|
||||
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Аппаратное ускорение: {DEVICE}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "79da9640",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class EmoSetDataset(Dataset):\n",
|
||||
" def __init__(self, root, split):\n",
|
||||
"class EmoSetEvaluationDataset(Dataset):\n",
|
||||
" # Датасет для строгой валидации с центрированным кропом\n",
|
||||
" def __init__(self, root: Path | str, split: str):\n",
|
||||
" self.root = Path(root) / split\n",
|
||||
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
||||
"\n",
|
||||
@@ -67,13 +68,12 @@
|
||||
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
||||
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
||||
"\n",
|
||||
" # Стандартный пайплайн трансформаций для инференса ResNet\n",
|
||||
" self.transform = T.Compose([\n",
|
||||
" T.Resize((224, 224)),\n",
|
||||
" T.Resize(256),\n",
|
||||
" T.CenterCrop(224),\n",
|
||||
" T.ToTensor(),\n",
|
||||
" T.Normalize(\n",
|
||||
" mean=[0.485, 0.456, 0.406],\n",
|
||||
" std=[0.229, 0.224, 0.225]\n",
|
||||
" )\n",
|
||||
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
@@ -81,15 +81,23 @@
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" row = self.df.iloc[idx]\n",
|
||||
" img = Image.open(self.root / \"images\" / row[\"filename\"]).convert(\"RGB\")\n",
|
||||
" img = self.transform(img)\n",
|
||||
" label = self.label2idx[row[\"label\"]]\n",
|
||||
" return img, label\n"
|
||||
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
||||
" \n",
|
||||
" # Перехват битых файлов для непрерывности оценки\n",
|
||||
" try:\n",
|
||||
" img = Image.open(img_path).convert(\"RGB\")\n",
|
||||
" except Exception:\n",
|
||||
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
|
||||
" \n",
|
||||
" img_tensor = self.transform(img)\n",
|
||||
" label_idx = self.label2idx[row[\"label\"]]\n",
|
||||
" \n",
|
||||
" return img_tensor, label_idx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "12201756",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -103,8 +111,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_ds = EmoSetDataset(DATA_ROOT, \"test\")\n",
|
||||
"\n",
|
||||
"# Инициализация тестовой выборки\n",
|
||||
"test_ds = EmoSetEvaluationDataset(DATA_ROOT, \"test\")\n",
|
||||
"test_loader = DataLoader(\n",
|
||||
" test_ds,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
@@ -113,13 +121,13 @@
|
||||
" pin_memory=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Classes:\", test_ds.labels)\n",
|
||||
"print(\"Test samples:\", len(test_ds))\n"
|
||||
"print(f\"Индексированные классы: {test_ds.labels}\")\n",
|
||||
"print(f\"Размер тестовой выборки: {len(test_ds)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"id": "7e3dc1d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -374,22 +382,17 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Инициализация модели в режиме классификации\n",
|
||||
"model = timm.create_model(\n",
|
||||
" \"resnet50\",\n",
|
||||
" pretrained=False,\n",
|
||||
" num_classes=len(test_ds.labels)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"state = torch.load(MODEL_PATH, map_location=DEVICE)\n",
|
||||
"model.load_state_dict(state)\n",
|
||||
"\n",
|
||||
"model.to(DEVICE)\n",
|
||||
"model.eval()\n"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"id": "b42a84f1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -402,27 +405,16 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_preds = []\n",
|
||||
"all_targets = []\n",
|
||||
"\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for imgs, labels in tqdm(test_loader):\n",
|
||||
" imgs = imgs.to(DEVICE)\n",
|
||||
" labels = labels.to(DEVICE)\n",
|
||||
"\n",
|
||||
" logits = model(imgs)\n",
|
||||
" preds = logits.argmax(dim=1)\n",
|
||||
"\n",
|
||||
" all_preds.append(preds.cpu().numpy())\n",
|
||||
" all_targets.append(labels.cpu().numpy())\n",
|
||||
"\n",
|
||||
"all_preds = np.concatenate(all_preds)\n",
|
||||
"all_targets = np.concatenate(all_targets)\n"
|
||||
"# Загрузка весов и перевод в режим инференса\n",
|
||||
"checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)\n",
|
||||
"model.load_state_dict(checkpoint)\n",
|
||||
"model.to(DEVICE)\n",
|
||||
"model.eval()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"id": "4c1f1377",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -435,13 +427,25 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"acc = accuracy_score(all_targets, all_preds)\n",
|
||||
"print(f\"Test accuracy: {acc:.4f}\")\n"
|
||||
"# Сбор предсказаний на тестовой выборке\n",
|
||||
"all_preds = []\n",
|
||||
"all_targets = []\n",
|
||||
"\n",
|
||||
"print(\"Запуск инференса на тестовой выборке...\")\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for imgs, labels in tqdm(test_loader, desc=\"Оценка метрик\"):\n",
|
||||
" imgs = imgs.to(DEVICE)\n",
|
||||
" \n",
|
||||
" logits = model(imgs)\n",
|
||||
" preds = logits.argmax(dim=1)\n",
|
||||
"\n",
|
||||
" all_preds.append(preds.cpu().numpy())\n",
|
||||
" all_targets.append(labels.numpy())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "6b022825",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -468,19 +472,14 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\n",
|
||||
" classification_report(\n",
|
||||
" all_targets,\n",
|
||||
" all_preds,\n",
|
||||
" target_names=test_ds.labels,\n",
|
||||
" digits=4\n",
|
||||
" )\n",
|
||||
")\n"
|
||||
"# Агрегация результатов\n",
|
||||
"all_preds = np.concatenate(all_preds, axis=0)\n",
|
||||
"all_targets = np.concatenate(all_targets, axis=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"id": "2fcb69ac",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -496,20 +495,70 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"# Расчет интегральных метрик классификации\n",
|
||||
"acc = accuracy_score(all_targets, all_preds)\n",
|
||||
"print(f\"\\nОбщая точность (Accuracy): {acc:.4f}\\n\")\n",
|
||||
"\n",
|
||||
"print(\"Детализированный отчет (Classification Report):\")\n",
|
||||
"print(\n",
|
||||
" classification_report(\n",
|
||||
" all_targets,\n",
|
||||
" all_preds,\n",
|
||||
" target_names=test_ds.labels,\n",
|
||||
" digits=4\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2084ab91",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Построение матрицы ошибок (Confusion Matrix)\n",
|
||||
"cm = confusion_matrix(all_targets, all_preds)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(8, 8))\n",
|
||||
"plt.imshow(cm)\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.xticks(range(len(test_ds.labels)), test_ds.labels, rotation=45)\n",
|
||||
"plt.yticks(range(len(test_ds.labels)), test_ds.labels)\n",
|
||||
"plt.xlabel(\"Predicted\")\n",
|
||||
"plt.ylabel(\"True\")\n",
|
||||
"plt.title(\"Confusion Matrix (Test)\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n"
|
||||
"plt.figure(figsize=(10, 8))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "83a84e14",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Использование seaborn для академичной визуализации с числами\n",
|
||||
"sns.heatmap(\n",
|
||||
" cm, \n",
|
||||
" annot=True, \n",
|
||||
" fmt=\"d\", \n",
|
||||
" cmap=\"Blues\", \n",
|
||||
" xticklabels=test_ds.labels, \n",
|
||||
" yticklabels=test_ds.labels,\n",
|
||||
" cbar=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"plt.title(\"Матрица ошибок классификации EmoSet (ResNet-50)\", pad=20)\n",
|
||||
"plt.xlabel(\"Предсказанный класс\", labelpad=15)\n",
|
||||
"plt.ylabel(\"Истинный класс\", labelpad=15)\n",
|
||||
"plt.xticks(rotation=45)\n",
|
||||
"plt.yticks(rotation=0)\n",
|
||||
"plt.tight_layout()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "280d5637",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Экспорт графика\n",
|
||||
"plt.savefig(\"../confusion_matrix_resnet50.png\", dpi=300, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -0,0 +1,96 @@
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
# Калибровочные координаты центров эмоциональных классов в пространстве Рассела [1.0 - 9.0]
|
||||
EMOTION_TO_VA_COORDS = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
}
|
||||
|
||||
def evaluate_regression_model():
|
||||
# Инициализация путей к артефактам пайплайна
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent
|
||||
embeddings_path = base_dir / "src" / "emoset_test_embeddings.npy"
|
||||
labels_path = base_dir / "src" / "emoset_test_labels.npy"
|
||||
model_path = base_dir / "src" / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not all(p.exists() for p in [embeddings_path, labels_path, model_path]):
|
||||
print("Отсутствуют необходимые артефакты для расчета метрик.")
|
||||
return
|
||||
|
||||
# Загрузка скрытых представлений и инициализация регрессора
|
||||
x_features = np.load(embeddings_path)
|
||||
y_discrete = np.load(labels_path)
|
||||
regression_pipeline = joblib.load(model_path)
|
||||
|
||||
# Маппинг дискретных меток в непрерывные координаты
|
||||
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
|
||||
|
||||
# Изоляция тестовой выборки (сохранение детерминированности через random_state)
|
||||
_, x_test, _, y_test = train_test_split(x_features, y_continuous, test_size=0.2, random_state=42)
|
||||
|
||||
# Генерация предсказаний на отложенной выборке
|
||||
y_pred = regression_pipeline.predict(x_test)
|
||||
|
||||
# Расчет метрик качества регрессии (Mean Squared Error, R-squared)
|
||||
mse_valence = mean_squared_error(y_test[:, 0], y_pred[:, 0])
|
||||
r2_valence = r2_score(y_test[:, 0], y_pred[:, 0])
|
||||
|
||||
mse_arousal = mean_squared_error(y_test[:, 1], y_pred[:, 1])
|
||||
r2_arousal = r2_score(y_test[:, 1], y_pred[:, 1])
|
||||
|
||||
print("Метрики качества регрессионной модели на тестовой выборке:")
|
||||
print(f"Valence -> MSE: {mse_valence:.4f} | R^2: {r2_valence:.4f}")
|
||||
print(f"Arousal -> MSE: {mse_arousal:.4f} | R^2: {r2_arousal:.4f}")
|
||||
|
||||
# Построение диагностических диаграмм рассеяния (Scatter Plots)
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
||||
|
||||
# Конфигурация подграфика: Ось Валентности
|
||||
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Прогноз регрессора')
|
||||
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
|
||||
ax1.set_title('Диаграмма рассеяния: Valence (Позитивность)', fontsize=14, fontweight='bold')
|
||||
ax1.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
|
||||
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax1.set_xlim(1, 9)
|
||||
ax1.set_ylim(1, 9)
|
||||
ax1.grid(True, linestyle='--', alpha=0.6)
|
||||
ax1.legend(loc='upper left', fontsize=10)
|
||||
|
||||
# Научное обоснование распределения данных для комиссии
|
||||
ax1.text(1.2, 8.2,
|
||||
'Формирование вертикальных кластеров\n'
|
||||
'обусловлено проекцией 8 дискретных\n'
|
||||
'базовых эмоций на непрерывную\n'
|
||||
'координатную плоскость.',
|
||||
fontsize=10, bbox=dict(facecolor='white', alpha=0.9, edgecolor='gray'))
|
||||
|
||||
# Конфигурация подграфика: Ось Активности
|
||||
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Прогноз регрессора')
|
||||
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
|
||||
ax2.set_title('Диаграмма рассеяния: Arousal (Активность)', fontsize=14, fontweight='bold')
|
||||
ax2.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
|
||||
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax2.set_xlim(1, 9)
|
||||
ax2.set_ylim(1, 9)
|
||||
ax2.grid(True, linestyle='--', alpha=0.6)
|
||||
ax2.legend(loc='upper left', fontsize=10)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('regression_metrics_plot.png', dpi=300, bbox_inches='tight')
|
||||
print("Диагностические графики экспортированы в regression_metrics_plot.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate_regression_model()
|
||||
@@ -1,92 +0,0 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
import joblib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 1. Алфавитный маппинг EmoSet (строго из твоего кода)
|
||||
EMO_VA_MAP = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
}
|
||||
|
||||
def main():
|
||||
# Настраиваем пути (подразумевается, что скрипт лежит в src/music_engine)
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
# Если .npy лежат в корне проекта dataset/, укажи точный путь:
|
||||
EMBEDDINGS_PATH = BASE_DIR / "src" / "emoset_test_embeddings.npy"
|
||||
LABELS_PATH = BASE_DIR / "src" / "emoset_test_labels.npy"
|
||||
MODEL_PATH = BASE_DIR / "src" / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
print("🔍 Шаг 1: Загрузка эмбеддингов и меток...")
|
||||
try:
|
||||
X = np.load(EMBEDDINGS_PATH)
|
||||
y_labels = np.load(LABELS_PATH)
|
||||
model = joblib.load(MODEL_PATH)
|
||||
except Exception as e:
|
||||
print(f"Ошибка загрузки файлов: {e}")
|
||||
return
|
||||
|
||||
# Конвертируем метки классов в координаты Рассела (V, A)
|
||||
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
|
||||
|
||||
# Делаем сплит, как при обучении
|
||||
_, X_test, _, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
|
||||
|
||||
print("Шаг 2: Выполнение предсказаний через загруженный .pkl...")
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
# Считаем метрики
|
||||
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
|
||||
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
|
||||
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
|
||||
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("Показания:")
|
||||
print(f" [MSE_VALENCE] = {mse_v:.4f}")
|
||||
print(f" [R2_VALENCE] = {r2_v:.4f}")
|
||||
print(f" [MSE_AROUSAL] = {mse_a:.4f}")
|
||||
print(f" [R2_AROUSAL] = {r2_a:.4f}")
|
||||
print("="*50 + "\n")
|
||||
|
||||
print("Шаг 3: Генерация графика...")
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
||||
|
||||
# Отрисовка плоскости Валентности
|
||||
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Предсказания регрессора')
|
||||
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
|
||||
ax1.set_title('Ось Валентности (Valence)', fontsize=14, fontweight='bold')
|
||||
ax1.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
|
||||
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax1.set_xlim(1, 9)
|
||||
ax1.set_ylim(1, 9)
|
||||
ax1.grid(True, linestyle='--', alpha=0.6)
|
||||
ax1.legend(loc='upper left', fontsize=10)
|
||||
ax1.text(1.2, 8.5, 'Вертикальные кластеры обусловлены\nдискретным маппингом 8 базовых\nэмоций на координатную плоскость.',
|
||||
fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
|
||||
|
||||
# Отрисовка плоскости Возбуждения
|
||||
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Предсказания регрессора')
|
||||
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
|
||||
ax2.set_title('Ось Возбуждения (Arousal)', fontsize=14, fontweight='bold')
|
||||
ax2.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
|
||||
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax2.set_xlim(1, 9)
|
||||
ax2.set_ylim(1, 9)
|
||||
ax2.grid(True, linestyle='--', alpha=0.6)
|
||||
ax2.legend(loc='upper left', fontsize=10)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('metrics_plot.png', dpi=300, bbox_inches='tight')
|
||||
print("График сохранен как 'metrics_plot.png'.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user