80 lines
3.6 KiB
Python
80 lines
3.6 KiB
Python
import joblib
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
|
|
from sklearn.linear_model import RidgeCV
|
|
from sklearn.multioutput import MultiOutputRegressor
|
|
from sklearn.preprocessing import StandardScaler
|
|
from sklearn.pipeline import Pipeline
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.metrics import mean_squared_error, r2_score
|
|
|
|
# Проекция дискретных классов эмоций на непрерывное пространство Рассела (Valence, Arousal)
|
|
# Значения откалиброваны в диапазоне [1.0, 9.0]
|
|
EMOTION_TO_VA_COORDS = {
|
|
0: (7.5, 6.5), # amusement
|
|
1: (2.0, 8.0), # anger
|
|
2: (6.5, 5.0), # awe
|
|
3: (7.0, 3.0), # contentment
|
|
4: (3.0, 6.0), # disgust
|
|
5: (8.0, 8.0), # excitement
|
|
6: (2.5, 7.5), # fear
|
|
7: (2.0, 2.0), # sadness
|
|
}
|
|
|
|
def train_va_regressor():
|
|
# Настройка путей
|
|
base_dir = Path(__file__).resolve().parent.parent
|
|
embeddings_path = base_dir / "emoset_test_embeddings.npy"
|
|
labels_path = base_dir / "emoset_test_labels.npy"
|
|
model_output_path = base_dir / "music_engine" / "va_regressor.pkl"
|
|
|
|
if not embeddings_path.exists() or not labels_path.exists():
|
|
print(f"Артефакты признаков не найдены в директории: {base_dir}")
|
|
return
|
|
|
|
print("Загрузка вектора признаков и меток классов...")
|
|
x_features = np.load(embeddings_path)
|
|
y_discrete = np.load(labels_path)
|
|
|
|
# Трансформация целевой переменной: классы -> непрерывные координаты V/A
|
|
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
|
|
|
|
x_train, x_test, y_train, y_test = train_test_split(
|
|
x_features, y_continuous, test_size=0.2, random_state=42
|
|
)
|
|
|
|
# Построение пайплайна: Z-масштабирование и L2-регуляризованная регрессия
|
|
# RidgeCV автоматически подбирает оптимальный гиперпараметр alpha (силу регуляризации)
|
|
print("Инициализация и обучение пайплайна RidgeCV...")
|
|
regression_pipeline = Pipeline([
|
|
('scaler', StandardScaler()),
|
|
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
|
|
])
|
|
|
|
regression_pipeline.fit(x_train, y_train)
|
|
|
|
# Оценка обобщающей способности модели
|
|
y_pred = regression_pipeline.predict(x_test)
|
|
|
|
mse_score = mean_squared_error(y_test, y_pred)
|
|
r2 = r2_score(y_test, y_pred)
|
|
|
|
print("Обучение завершено. Метрики качества на тестовой выборке:")
|
|
print(f" - MSE: {mse_score:.4f}")
|
|
print(f" - R^2: {r2:.4f}")
|
|
|
|
# Диагностика дисперсии предсказаний
|
|
v_min, v_max = y_pred[:, 0].min(), y_pred[:, 0].max()
|
|
a_min, a_max = y_pred[:, 1].min(), y_pred[:, 1].max()
|
|
print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
|
|
print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")
|
|
|
|
# Экспорт обученного пайплайна
|
|
model_output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
joblib.dump(regression_pipeline, model_output_path)
|
|
print(f"Пайплайн сохранен: {model_output_path.name}")
|
|
|
|
if __name__ == "__main__":
|
|
train_va_regressor() |