Refactored paths

This commit is contained in:
zin
2026-05-06 18:22:54 +00:00
parent 4e192b7bc4
commit dd22ee09a4
8 changed files with 61 additions and 0 deletions
+61
View File
@@ -0,0 +1,61 @@
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
# 1. Эталонный маппинг EmoSet -> Valence/Arousal (шкала 1-9)
# Убедись, что индексы ключей (0-7) совпадают с тем, как они размечены в твоем labels.npy
# Стандартный порядок EmoSet:
EMO_VA_MAP = {
0: (7.5, 6.5), # amusement (радость/веселье) - позитивно, средне-активно
1: (6.5, 5.0), # awe (трепет/восхищение) - позитивно, спокойно
2: (7.0, 3.0), # contentment (удовлетворение) - позитивно, очень спокойно
3: (8.0, 8.0), # excitement (возбуждение) - очень позитивно, очень активно
4: (2.0, 8.0), # anger (гнев) - негативно, очень активно
5: (3.0, 6.0), # disgust (отвращение) - негативно, средне-активно
6: (2.5, 7.5), # fear (страх) - негативно, очень активно
7: (2.0, 2.0), # sadness (грусть) - негативно, пассивно
}
# 2. Загрузка данных
# Укажи пути к твоим эмбеддингам и меткам (можно взять train или test, для демо не так критично)
EMBEDDINGS_PATH = Path("../emoset_test_embeddings.npy")
LABELS_PATH = Path("../emoset_test_labels.npy")
print("Загрузка данных...")
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)
# Преобразуем дискретные метки в целевые координаты V-A
print("Формирование целевых координат (Valence, Arousal)...")
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
# Разделение на train/val
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
# 3. Обучение модели
print("Обучение Ridge регрессора...")
# Ridge отлично справляется с многомерными эмбеддингами, избегая переобучения
base_estimator = Ridge(alpha=1.0)
model = MultiOutputRegressor(base_estimator)
model.fit(X_train, y_train)
# 4. Оценка
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Обучение завершено!")
print(f"MSE (Среднеквадратичная ошибка): {mse:.4f}")
print(f"R^2 Score (Коэффициент детерминации): {r2:.4f}")
# 5. Сохранение
output_model_path = Path("../src/music_engine/va_regressor.pkl")
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
print(f"Модель сохранена в: {output_model_path}")