ref: refactor before chekout
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.linear_model import RidgeCV
|
||||
from sklearn.multioutput import MultiOutputRegressor
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
# Проекция дискретных классов эмоций на непрерывное пространство Рассела (Valence, Arousal)
|
||||
# Значения откалиброваны в диапазоне [1.0, 9.0]
|
||||
EMOTION_TO_VA_COORDS = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
}
|
||||
|
||||
def train_va_regressor():
|
||||
# Настройка путей
|
||||
base_dir = Path(__file__).resolve().parent.parent
|
||||
embeddings_path = base_dir / "emoset_test_embeddings.npy"
|
||||
labels_path = base_dir / "emoset_test_labels.npy"
|
||||
model_output_path = base_dir / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not embeddings_path.exists() or not labels_path.exists():
|
||||
print(f"Артефакты признаков не найдены в директории: {base_dir}")
|
||||
return
|
||||
|
||||
print("Загрузка вектора признаков и меток классов...")
|
||||
x_features = np.load(embeddings_path)
|
||||
y_discrete = np.load(labels_path)
|
||||
|
||||
# Трансформация целевой переменной: классы -> непрерывные координаты V/A
|
||||
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
|
||||
|
||||
x_train, x_test, y_train, y_test = train_test_split(
|
||||
x_features, y_continuous, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Построение пайплайна: Z-масштабирование и L2-регуляризованная регрессия
|
||||
# RidgeCV автоматически подбирает оптимальный гиперпараметр alpha (силу регуляризации)
|
||||
print("Инициализация и обучение пайплайна RidgeCV...")
|
||||
regression_pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
|
||||
])
|
||||
|
||||
regression_pipeline.fit(x_train, y_train)
|
||||
|
||||
# Оценка обобщающей способности модели
|
||||
y_pred = regression_pipeline.predict(x_test)
|
||||
|
||||
mse_score = mean_squared_error(y_test, y_pred)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
|
||||
print("Обучение завершено. Метрики качества на тестовой выборке:")
|
||||
print(f" - MSE: {mse_score:.4f}")
|
||||
print(f" - R^2: {r2:.4f}")
|
||||
|
||||
# Диагностика дисперсии предсказаний
|
||||
v_min, v_max = y_pred[:, 0].min(), y_pred[:, 0].max()
|
||||
a_min, a_max = y_pred[:, 1].min(), y_pred[:, 1].max()
|
||||
print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
|
||||
print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")
|
||||
|
||||
# Экспорт обученного пайплайна
|
||||
model_output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(regression_pipeline, model_output_path)
|
||||
print(f"Пайплайн сохранен: {model_output_path.name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_va_regressor()
|
||||
Reference in New Issue
Block a user