Integrated Music_engine and update Debug UI
This commit is contained in:
@@ -1,61 +1,65 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from sklearn.linear_model import Ridge
|
||||
from sklearn.linear_model import RidgeCV
|
||||
from sklearn.multioutput import MultiOutputRegressor
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
import joblib
|
||||
|
||||
# 1. Эталонный маппинг EmoSet -> Valence/Arousal (шкала 1-9)
|
||||
# Убедись, что индексы ключей (0-7) совпадают с тем, как они размечены в твоем labels.npy
|
||||
# Стандартный порядок EmoSet:
|
||||
# 1. Алфавитный маппинг EmoSet
|
||||
EMO_VA_MAP = {
|
||||
0: (7.5, 6.5), # amusement (радость/веселье) - позитивно, средне-активно
|
||||
1: (6.5, 5.0), # awe (трепет/восхищение) - позитивно, спокойно
|
||||
2: (7.0, 3.0), # contentment (удовлетворение) - позитивно, очень спокойно
|
||||
3: (8.0, 8.0), # excitement (возбуждение) - очень позитивно, очень активно
|
||||
4: (2.0, 8.0), # anger (гнев) - негативно, очень активно
|
||||
5: (3.0, 6.0), # disgust (отвращение) - негативно, средне-активно
|
||||
6: (2.5, 7.5), # fear (страх) - негативно, очень активно
|
||||
7: (2.0, 2.0), # sadness (грусть) - негативно, пассивно
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
}
|
||||
|
||||
# 2. Загрузка данных
|
||||
# Укажи пути к твоим эмбеддингам и меткам (можно взять train или test, для демо не так критично)
|
||||
EMBEDDINGS_PATH = Path("../emoset_test_embeddings.npy")
|
||||
LABELS_PATH = Path("../emoset_test_labels.npy")
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"
|
||||
|
||||
print("Загрузка данных...")
|
||||
X = np.load(EMBEDDINGS_PATH)
|
||||
y_labels = np.load(LABELS_PATH)
|
||||
|
||||
# Преобразуем дискретные метки в целевые координаты V-A
|
||||
print("Формирование целевых координат (Valence, Arousal)...")
|
||||
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
|
||||
|
||||
# Разделение на train/val
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
|
||||
|
||||
# 3. Обучение модели
|
||||
print("Обучение Ridge регрессора...")
|
||||
# Ridge отлично справляется с многомерными эмбеддингами, избегая переобучения
|
||||
base_estimator = Ridge(alpha=1.0)
|
||||
model = MultiOutputRegressor(base_estimator)
|
||||
# 2. НОВАЯ, ПРАВИЛЬНАЯ АРХИТЕКТУРА (Pipeline)
|
||||
print("Обучение масштабатора и RidgeCV регрессора...")
|
||||
# Pipeline гарантирует, что при предсказании в main.py новые векторы тоже будут масштабированы
|
||||
model = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
|
||||
])
|
||||
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# 4. Оценка
|
||||
# 3. Диагностика и Оценка
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
mse = mean_squared_error(y_test, y_pred)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
|
||||
print(f"Обучение завершено!")
|
||||
print(f"MSE (Среднеквадратичная ошибка): {mse:.4f}")
|
||||
print(f"R^2 Score (Коэффициент детерминации): {r2:.4f}")
|
||||
print(f"\n[УСПЕХ] Обучение завершено!")
|
||||
print(f"MSE: {mse:.4f}")
|
||||
print(f"R^2 Score: {r2:.4f}")
|
||||
|
||||
# 5. Сохранение
|
||||
output_model_path = Path("../src/music_engine/va_regressor.pkl")
|
||||
# === ТОТ САМЫЙ ТЕСТ НА КОЛЛАПС ===
|
||||
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
|
||||
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||
# ===============================================
|
||||
|
||||
# 4. Сохранение (Pipeline сохраняется целиком со StandardScaler)
|
||||
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(model, output_model_path)
|
||||
print(f"Модель сохранена в: {output_model_path}")
|
||||
print(f"\nМодель сохранена в: {output_model_path}")
|
||||
Reference in New Issue
Block a user