chore: change text output

This commit is contained in:
zin
2026-05-28 17:15:33 +00:00
parent 39a68bc3c3
commit af3c5a953e
12 changed files with 114 additions and 36 deletions
+5 -5
View File
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "0336fd0c",
"metadata": {},
"outputs": [
@@ -57,17 +57,17 @@
"\n",
"# 3. Загружаем текущую базу с V/A\n",
"if not music_db_path.exists():\n",
" print(f\"ОШИБКА: Не найден файл {music_db_path}\")\n",
" print(f\"ОШИБКА: Не найден файл {music_db_path}\")\n",
"else:\n",
" df_main = pd.read_csv(music_db_path)\n",
" print(f\"База загружена. Треков: {len(df_main)}\")\n",
" print(f\"База загружена. Треков: {len(df_main)}\")\n",
"\n",
" # Подготавливаем новые колонки\n",
" for col_name in target_columns.values():\n",
" df_main[col_name] = np.nan\n",
"\n",
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
" print(\"🔍 Собираем акустические признаки...\")\n",
" print(\"Собираем акустические признаки...\")\n",
" found_count = 0\n",
" \n",
" for index, row in df_main.iterrows():\n",
@@ -95,7 +95,7 @@
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
" \n",
" df_main.to_csv(output_path, index=False)\n",
" print(f\"\\n🚀 ГОТОВО! Обогащенная база сохранена: {output_path}\")\n",
" print(f\"\\nГотово! Обогащенная база сохранена: {output_path}\")\n",
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
" print(df_main.head())"
]
+2 -2
View File
@@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -112,7 +112,7 @@
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
"\n",
"print(\"\\n[УСПЕХ] Датасет DEAM готов к работе!\")\n"
"print(\"\\nУСПЕХ! Датасет DEAM готов к работе!\")\n"
]
}
],
+92
View File
@@ -0,0 +1,92 @@
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
import matplotlib.pyplot as plt
# 1. Алфавитный маппинг EmoSet (строго из твоего кода)
EMO_VA_MAP = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
def main():
# Настраиваем пути (подразумевается, что скрипт лежит в src/music_engine)
BASE_DIR = Path(__file__).resolve().parent.parent.parent
# Если .npy лежат в корне проекта dataset/, укажи точный путь:
EMBEDDINGS_PATH = BASE_DIR / "src" / "emoset_test_embeddings.npy"
LABELS_PATH = BASE_DIR / "src" / "emoset_test_labels.npy"
MODEL_PATH = BASE_DIR / "src" / "music_engine" / "va_regressor.pkl"
print("🔍 Шаг 1: Загрузка эмбеддингов и меток...")
try:
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)
model = joblib.load(MODEL_PATH)
except Exception as e:
print(f"Ошибка загрузки файлов: {e}")
return
# Конвертируем метки классов в координаты Рассела (V, A)
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
# Делаем сплит, как при обучении
_, X_test, _, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
print("Шаг 2: Выполнение предсказаний через загруженный .pkl...")
y_pred = model.predict(X_test)
# Считаем метрики
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
print("\n" + "="*50)
print("Показания:")
print(f" [MSE_VALENCE] = {mse_v:.4f}")
print(f" [R2_VALENCE] = {r2_v:.4f}")
print(f" [MSE_AROUSAL] = {mse_a:.4f}")
print(f" [R2_AROUSAL] = {r2_a:.4f}")
print("="*50 + "\n")
print("Шаг 3: Генерация графика...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
# Отрисовка плоскости Валентности
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Предсказания регрессора')
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
ax1.set_title('Ось Валентности (Valence)', fontsize=14, fontweight='bold')
ax1.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
ax1.set_xlim(1, 9)
ax1.set_ylim(1, 9)
ax1.grid(True, linestyle='--', alpha=0.6)
ax1.legend(loc='upper left', fontsize=10)
ax1.text(1.2, 8.5, 'Вертикальные кластеры обусловлены\nдискретным маппингом 8 базовых\nэмоций на координатную плоскость.',
fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
# Отрисовка плоскости Возбуждения
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Предсказания регрессора')
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
ax2.set_title('Ось Возбуждения (Arousal)', fontsize=14, fontweight='bold')
ax2.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
ax2.set_xlim(1, 9)
ax2.set_ylim(1, 9)
ax2.grid(True, linestyle='--', alpha=0.6)
ax2.legend(loc='upper left', fontsize=10)
plt.tight_layout()
plt.savefig('metrics_plot.png', dpi=300, bbox_inches='tight')
print("График сохранен как 'metrics_plot.png'.")
if __name__ == "__main__":
main()
+3 -3
View File
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "1763c51e",
"metadata": {},
"outputs": [
@@ -40,7 +40,7 @@
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
"\n",
"if not source_path.exists():\n",
" print(f\"Исходный файл не найден по пути: {source_path}\")\n",
" print(f\"Исходный файл не найден по пути: {source_path}\")\n",
"else:\n",
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
@@ -57,7 +57,7 @@
" # Сохраняем финальный файл\n",
" clean_df.to_csv(output_path, index=False)\n",
" \n",
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
" print(f\"Успех! База создана: {output_path}\")\n",
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
" print(\"Пример данных:\")\n",
" print(clean_df.head())"
+2 -2
View File
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "84c3657f",
"metadata": {},
"outputs": [
@@ -49,7 +49,7 @@
"source": [
"# === CONFIG ===\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
"BATCH_SIZE = 64\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 24\n",
+1 -8
View File
@@ -9,7 +9,6 @@ from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
# 1. Алфавитный маппинг EmoSet
EMO_VA_MAP = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
@@ -32,9 +31,7 @@ y_labels = np.load(LABELS_PATH)
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
# 2. НОВАЯ, ПРАВИЛЬНАЯ АРХИТЕКТУРА (Pipeline)
print("Обучение масштабатора и RidgeCV регрессора...")
# Pipeline гарантирует, что при предсказании в main.py новые векторы тоже будут масштабированы
model = Pipeline([
('scaler', StandardScaler()),
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
@@ -42,23 +39,19 @@ model = Pipeline([
model.fit(X_train, y_train)
# 3. Диагностика и Оценка
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"\n[УСПЕХ] Обучение завершено!")
print(f"\nУспех! Обучение завершено!")
print(f"MSE: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")
# === ТОТ САМЫЙ ТЕСТ НА КОЛЛАПС ===
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
# ===============================================
# 4. Сохранение (Pipeline сохраняется целиком со StandardScaler)
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)