chore: change text output
This commit is contained in:
+1
-1
@@ -28,7 +28,7 @@ def load_image_processor():
|
||||
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
|
||||
print(f"Ошибка: Веса не найдены по пути: {model_path}")
|
||||
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
|
||||
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||
|
||||
|
||||
@@ -4,15 +4,12 @@ from PIL import Image
|
||||
import timm
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
# ТЕПЕРЬ BLIP-2
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
|
||||
class ImageProcessor:
|
||||
def __init__(self, model_path: Path | str):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# --- ПОТОК 1: ТВОЙ АВТОРСКИЙ ЭМБЕДДИНГ (Core) ---
|
||||
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||
if Path(model_path).exists():
|
||||
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
@@ -26,15 +23,13 @@ class ImageProcessor:
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# --- ПОТОК 2: СЕМАНТИЧЕСКИЙ ЭКСПЕРТ BLIP-2 ---
|
||||
print("⏳ Загрузка тяжелой артиллерии: BLIP-2...")
|
||||
# Используем версию opt-2.7b — она идеально сбалансирована для V100
|
||||
print("Загрузка BLIP-2...")
|
||||
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b",
|
||||
torch_dtype=torch.float16 # Обязательно для скорости на V100
|
||||
torch_dtype=torch.float16
|
||||
).to(self.device)
|
||||
print("✅ BLIP-2 и ResNet-50 готовы.")
|
||||
print("BLIP-2 и ResNet-50 готовы.")
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||
@@ -47,10 +42,8 @@ class ImageProcessor:
|
||||
"""Генерирует описание через BLIP-2."""
|
||||
img_rgb = image.convert('RGB')
|
||||
|
||||
# Инференс BLIP-2 требует float16 для V100
|
||||
inputs = self.blip_processor(images=img_rgb, return_tensors="pt").to(self.device, torch.float16)
|
||||
|
||||
# Генерируем описание
|
||||
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
|
||||
caption = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
return caption
|
||||
@@ -56,5 +56,5 @@ Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8
|
||||
return profile
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"⚠️ Ошибка связи с локальной LLM: {e}")
|
||||
print(f"Ошибка связи с локальной LLM: {e}")
|
||||
return None
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"id": "0336fd0c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -57,17 +57,17 @@
|
||||
"\n",
|
||||
"# 3. Загружаем текущую базу с V/A\n",
|
||||
"if not music_db_path.exists():\n",
|
||||
" print(f\"❌ ОШИБКА: Не найден файл {music_db_path}\")\n",
|
||||
" print(f\"ОШИБКА: Не найден файл {music_db_path}\")\n",
|
||||
"else:\n",
|
||||
" df_main = pd.read_csv(music_db_path)\n",
|
||||
" print(f\"✅ База загружена. Треков: {len(df_main)}\")\n",
|
||||
" print(f\"База загружена. Треков: {len(df_main)}\")\n",
|
||||
"\n",
|
||||
" # Подготавливаем новые колонки\n",
|
||||
" for col_name in target_columns.values():\n",
|
||||
" df_main[col_name] = np.nan\n",
|
||||
"\n",
|
||||
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
|
||||
" print(\"🔍 Собираем акустические признаки...\")\n",
|
||||
" print(\"Собираем акустические признаки...\")\n",
|
||||
" found_count = 0\n",
|
||||
" \n",
|
||||
" for index, row in df_main.iterrows():\n",
|
||||
@@ -95,7 +95,7 @@
|
||||
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
|
||||
" \n",
|
||||
" df_main.to_csv(output_path, index=False)\n",
|
||||
" print(f\"\\n🚀 ГОТОВО! Обогащенная база сохранена: {output_path}\")\n",
|
||||
" print(f\"\\nГотово! Обогащенная база сохранена: {output_path}\")\n",
|
||||
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
|
||||
" print(df_main.head())"
|
||||
]
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -112,7 +112,7 @@
|
||||
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
|
||||
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
|
||||
"\n",
|
||||
"print(\"\\n[УСПЕХ] Датасет DEAM готов к работе!\")\n"
|
||||
"print(\"\\nУСПЕХ! Датасет DEAM готов к работе!\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
import joblib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 1. Алфавитный маппинг EmoSet (строго из твоего кода)
|
||||
EMO_VA_MAP = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
}
|
||||
|
||||
def main():
|
||||
# Настраиваем пути (подразумевается, что скрипт лежит в src/music_engine)
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
# Если .npy лежат в корне проекта dataset/, укажи точный путь:
|
||||
EMBEDDINGS_PATH = BASE_DIR / "src" / "emoset_test_embeddings.npy"
|
||||
LABELS_PATH = BASE_DIR / "src" / "emoset_test_labels.npy"
|
||||
MODEL_PATH = BASE_DIR / "src" / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
print("🔍 Шаг 1: Загрузка эмбеддингов и меток...")
|
||||
try:
|
||||
X = np.load(EMBEDDINGS_PATH)
|
||||
y_labels = np.load(LABELS_PATH)
|
||||
model = joblib.load(MODEL_PATH)
|
||||
except Exception as e:
|
||||
print(f"Ошибка загрузки файлов: {e}")
|
||||
return
|
||||
|
||||
# Конвертируем метки классов в координаты Рассела (V, A)
|
||||
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
|
||||
|
||||
# Делаем сплит, как при обучении
|
||||
_, X_test, _, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
|
||||
|
||||
print("Шаг 2: Выполнение предсказаний через загруженный .pkl...")
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
# Считаем метрики
|
||||
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
|
||||
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
|
||||
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
|
||||
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("Показания:")
|
||||
print(f" [MSE_VALENCE] = {mse_v:.4f}")
|
||||
print(f" [R2_VALENCE] = {r2_v:.4f}")
|
||||
print(f" [MSE_AROUSAL] = {mse_a:.4f}")
|
||||
print(f" [R2_AROUSAL] = {r2_a:.4f}")
|
||||
print("="*50 + "\n")
|
||||
|
||||
print("Шаг 3: Генерация графика...")
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
||||
|
||||
# Отрисовка плоскости Валентности
|
||||
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Предсказания регрессора')
|
||||
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
|
||||
ax1.set_title('Ось Валентности (Valence)', fontsize=14, fontweight='bold')
|
||||
ax1.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
|
||||
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax1.set_xlim(1, 9)
|
||||
ax1.set_ylim(1, 9)
|
||||
ax1.grid(True, linestyle='--', alpha=0.6)
|
||||
ax1.legend(loc='upper left', fontsize=10)
|
||||
ax1.text(1.2, 8.5, 'Вертикальные кластеры обусловлены\nдискретным маппингом 8 базовых\nэмоций на координатную плоскость.',
|
||||
fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
|
||||
|
||||
# Отрисовка плоскости Возбуждения
|
||||
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Предсказания регрессора')
|
||||
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Линия идеального совпадения (x=y)')
|
||||
ax2.set_title('Ось Возбуждения (Arousal)', fontsize=14, fontweight='bold')
|
||||
ax2.set_xlabel('Истинные значения (центры 8 классов эмоций)', fontsize=12)
|
||||
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax2.set_xlim(1, 9)
|
||||
ax2.set_ylim(1, 9)
|
||||
ax2.grid(True, linestyle='--', alpha=0.6)
|
||||
ax2.legend(loc='upper left', fontsize=10)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('metrics_plot.png', dpi=300, bbox_inches='tight')
|
||||
print("График сохранен как 'metrics_plot.png'.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "1763c51e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -40,7 +40,7 @@
|
||||
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
|
||||
"\n",
|
||||
"if not source_path.exists():\n",
|
||||
" print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n",
|
||||
" print(f\"Исходный файл не найден по пути: {source_path}\")\n",
|
||||
"else:\n",
|
||||
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
|
||||
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
|
||||
@@ -57,7 +57,7 @@
|
||||
" # Сохраняем финальный файл\n",
|
||||
" clean_df.to_csv(output_path, index=False)\n",
|
||||
" \n",
|
||||
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
|
||||
" print(f\"Успех! База создана: {output_path}\")\n",
|
||||
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
|
||||
" print(\"Пример данных:\")\n",
|
||||
" print(clean_df.head())"
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"id": "84c3657f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -49,7 +49,7 @@
|
||||
"source": [
|
||||
"# === CONFIG ===\n",
|
||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
|
||||
"BATCH_SIZE = 64\n",
|
||||
"EPOCHS = 15\n",
|
||||
"LR = 3e-4\n",
|
||||
"NUM_WORKERS = 24\n",
|
||||
|
||||
@@ -9,7 +9,6 @@ from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
import joblib
|
||||
|
||||
# 1. Алфавитный маппинг EmoSet
|
||||
EMO_VA_MAP = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
@@ -32,9 +31,7 @@ y_labels = np.load(LABELS_PATH)
|
||||
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
|
||||
|
||||
# 2. НОВАЯ, ПРАВИЛЬНАЯ АРХИТЕКТУРА (Pipeline)
|
||||
print("Обучение масштабатора и RidgeCV регрессора...")
|
||||
# Pipeline гарантирует, что при предсказании в main.py новые векторы тоже будут масштабированы
|
||||
model = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
|
||||
@@ -42,23 +39,19 @@ model = Pipeline([
|
||||
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# 3. Диагностика и Оценка
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
mse = mean_squared_error(y_test, y_pred)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
|
||||
print(f"\n[УСПЕХ] Обучение завершено!")
|
||||
print(f"\nУспех! Обучение завершено!")
|
||||
print(f"MSE: {mse:.4f}")
|
||||
print(f"R^2 Score: {r2:.4f}")
|
||||
|
||||
# === ТОТ САМЫЙ ТЕСТ НА КОЛЛАПС ===
|
||||
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
|
||||
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||
# ===============================================
|
||||
|
||||
# 4. Сохранение (Pipeline сохраняется целиком со StandardScaler)
|
||||
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(model, output_model_path)
|
||||
|
||||
@@ -42,7 +42,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
||||
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||
st.rerun()
|
||||
else:
|
||||
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||
st.success("Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||
|
||||
all_v, all_a = [], []
|
||||
for idx in st.session_state.ds_chosen_indices:
|
||||
@@ -56,7 +56,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
||||
col_left, col_right = st.columns([1, 2])
|
||||
|
||||
with col_left:
|
||||
st.header("📊 Ваш профиль")
|
||||
st.header("Ваш профиль")
|
||||
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def render_live_tab(matcher, image_processor):
|
||||
st.image(img, use_container_width=True)
|
||||
with st.spinner("VLM Анализ..."):
|
||||
caption = image_processor.describe_scene(img)
|
||||
st.caption(f"👁️ *{caption.capitalize()}*")
|
||||
st.caption(f"*{caption.capitalize()}*")
|
||||
all_objects.append(caption)
|
||||
|
||||
if st.button("🎵 Сгенерировать саундтрек", type="primary", use_container_width=True):
|
||||
@@ -51,7 +51,7 @@ def render_live_tab(matcher, image_processor):
|
||||
with st.spinner("Поиск треков в базе DEAM..."):
|
||||
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=5)
|
||||
|
||||
st.success("✅ Кросс-модальный анализ завершен!")
|
||||
st.success("Кросс-модальный анализ завершен!")
|
||||
|
||||
# ВЫВОД РЕЗУЛЬТАТОВ
|
||||
col_left, col_right = st.columns([1, 2])
|
||||
|
||||
Reference in New Issue
Block a user