Added music_engine and updated UI

This commit is contained in:
zin
2026-05-06 18:11:31 +00:00
parent aaa482bc6b
commit 4e192b7bc4
3 changed files with 153 additions and 178 deletions
+91 -171
View File
@@ -2,36 +2,28 @@ import streamlit as st
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageFont from PIL import Image
import random import random
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ----------------------------
# 1️⃣ Конфигурация и запуск
# ----------------------------
if __name__ == "__main__": if __name__ == "__main__":
# Проверяем, запущен ли скрипт через Streamlit
import os import os
if "STREAMLIT_RUN" not in os.environ: if "STREAMLIT_RUN" not in os.environ:
import sys import sys
import subprocess import subprocess
os.environ["STREAMLIT_RUN"] = "1" os.environ["STREAMLIT_RUN"] = "1"
# Формируем команду запуска
cmd = [ cmd = [
sys.executable, sys.executable, "-m", "streamlit", "run", __file__,
"-m", "--server.port", "8080", "--server.address", "0.0.0.0"
"streamlit",
"run",
__file__,
"--server.port", "8080",
"--server.address", "0.0.0.0"
] ]
subprocess.run(cmd) subprocess.run(cmd)
sys.exit() sys.exit()
# Конфигурация путей
# ----------------------------
# 1️⃣ Конфигурация
# ----------------------------
DATA_ROOT = Path("./dataset/EmoSet-118K/test") DATA_ROOT = Path("./dataset/EmoSet-118K/test")
IMAGES_DIR = DATA_ROOT / "images" IMAGES_DIR = DATA_ROOT / "images"
LABELS_CSV = DATA_ROOT / "labels.csv" LABELS_CSV = DATA_ROOT / "labels.csv"
@@ -39,182 +31,110 @@ LABELS_CSV = DATA_ROOT / "labels.csv"
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy") EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
LABELS_PATH = Path("./src/emoset_test_labels.npy") LABELS_PATH = Path("./src/emoset_test_labels.npy")
NUM_CHOICES = 6 # количество изображений за один раунд # Параметры эксперимента
TOTAL_ROUNDS = 10 # количество раундов выбора NUM_CHOICES = 6
TOTAL_ROUNDS = 10
st.set_page_config(page_title="EmoSet Demo", layout="wide") st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide")
# Инициализация музыкального движка с кэшированием
@st.cache_resource
def load_music_engine():
# Путь относительно src: ../dataset/DEAM/music_db.csv
db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv"
if not db_path.exists():
return None
return MusicMatcher(db_path)
matcher = load_music_engine()
# ---------------------------- # ----------------------------
# 2️⃣ Загрузка данных # 2️⃣ Загрузка данных EmoSet
# ---------------------------- # ----------------------------
if not IMAGES_DIR.exists(): @st.cache_data
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}") def load_emoset_data():
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists():
return None, None, None
image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")])
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов данных не совпадают!")
st.stop() st.stop()
labels_df = pd.read_csv(LABELS_CSV) return image_files, embeddings, labels_array
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
image_files = list(IMAGES_DIR.glob("*.jpg")) image_files, embeddings, labels_array = load_emoset_data()
# Проверка совпадения размеров
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов не совпадают!")
st.stop()
# Создадим mapping: filename -> embedding, label
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
# ---------------------------- # ----------------------------
# 3️⃣ Сессия Streamlit # 3️⃣ Логика приложения
# ----------------------------
if 'round_num' not in st.session_state:
st.session_state.round_num = 0
if 'chosen_files' not in st.session_state:
st.session_state.chosen_files = []
st.title("EmoSet: Выбор изображений")
# ----------------------------
# 4️⃣ Функция для overlay топ-3 эмоций
# ----------------------------
def get_font():
try:
return ImageFont.truetype("arial.ttf", 14)
except:
try:
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
except:
return ImageFont.load_default()
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
draw = ImageDraw.Draw(img)
font = get_font()
text = f"Label: {label}"
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
draw.text((2,2), text, fill=(255,255,255), font=font)
return img
# ----------------------------
# 5️⃣ Рандомный выбор 6 изображений для текущего раунда
# ----------------------------
# Инициализация состояния Streamlit
if "round_num" not in st.session_state:
st.session_state.round_num = 0
if "chosen_files" not in st.session_state:
st.session_state.chosen_files = []
if "current_choices" not in st.session_state:
st.session_state.current_choices = []
# Если все раунды уже завершены, блок пропускается
if st.session_state.round_num < TOTAL_ROUNDS:
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
# Генерируем выбор для текущего раунда только если он ещё не создан
if len(st.session_state.current_choices) == 0:
already_chosen = set(st.session_state.chosen_files)
available_images = [f for f in image_files if f.name not in already_chosen]
if len(available_images) < NUM_CHOICES:
st.warning("Недостаточно изображений для выбора!")
st.stop()
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
# Отображаем изображения и кнопки
cols = st.columns(NUM_CHOICES)
for col, img_path in zip(cols, st.session_state.current_choices):
# Загружаем изображение и накладываем overlay
img = Image.open(img_path).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
# Показываем изображение
col.image(img_overlay, width=250)
# Кнопка выбора
if col.button("Выбрать", key=img_path.name):
st.session_state.chosen_files.append(img_path.name)
st.session_state.round_num += 1
st.session_state.current_choices = [] # сброс для следующего раунда
st.rerun() # перезапуск для нового раунда
# ----------------------------
# 6️⃣ После всех выборов: отображение результатов
# ---------------------------- # ----------------------------
if image_files is None:
st.error("Данные EmoSet не найдены. Проверьте папку dataset.")
else: else:
st.subheader("Результаты выбора пользователя") if 'round' not in st.session_state:
st.session_state.round = 1
st.session_state.chosen_indices = []
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
if not st.session_state.chosen_files: st.title("Выбор эмоциональных образов")
st.warning("Вы не сделали ни одного выбора!") st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.")
if st.button("Начать заново"):
st.session_state.round_num = 0 if st.session_state.round <= TOTAL_ROUNDS:
st.session_state.chosen_files = [] # Отображение сетки изображений
cols = st.columns(3)
for i, idx in enumerate(st.session_state.current_options):
with cols[i % 3]:
img_path = IMAGES_DIR / image_files[idx]
img = Image.open(img_path)
st.image(img, use_container_width=True)
if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"):
st.session_state.chosen_indices.append(idx)
st.session_state.round += 1
if st.session_state.round <= TOTAL_ROUNDS:
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
st.rerun() st.rerun()
else: else:
# Отображение выбранных изображений # ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка
cols = st.columns(3) st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.")
for i, filename in enumerate(st.session_state.chosen_files):
with cols[i % 3]:
img = Image.open(IMAGES_DIR / filename).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[filename])
st.image(img_overlay, caption=f"Выбор {i+1}")
# Расчет среднего embedding # Расчет среднего вектора пользователя
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files] chosen_embeddings = embeddings[st.session_state.chosen_indices]
chosen_embeddings = np.stack(chosen_embeddings) user_vector = np.mean(chosen_embeddings, axis=0)
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
# Отображение информации о векторе эмоций # РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ
with st.expander("Подробности о векторной модели эмоций"): st.divider()
st.write(f"Размерность embedding: {len(user_emotion_vector)}") st.header("🎵 Рекомендованный плейлист")
st.write("Первые 10 значений:")
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
# Визуализация if matcher is None:
col1, col2 = st.columns(2) st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.")
else:
with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."):
target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5)
with col1: # Визуализация VA-метрик пользователя
# Гистограмма первых 30 измерений m1, m2, m3 = st.columns(3)
plt.figure(figsize=(8,4)) m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9")
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30]) m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9")
plt.xlabel("Embedding dimension") m3.metric("Найдено треков", len(playlist))
plt.ylabel("Value")
plt.title("Распределение значений embedding (первые 30 измерений)")
st.pyplot(plt)
with col2: # Таблица с результатами
# Круговая диаграмма средних значений по блокам st.subheader("Топ-5 подходящих композиций")
if len(user_emotion_vector) > 16: st.table(playlist[['song_id', 'valence', 'arousal', 'distance']])
block_size = len(user_emotion_vector) // 4
block_means = [
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
for i in range(4)
]
plt.figure(figsize=(8,4))
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
plt.title("Распределение эмоциональных блоков")
st.pyplot(plt)
# Кнопка для сохранения результатов st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.")
if st.button("Сохранить результаты"):
# Сохранение в файл (например, JSON)
results = {
"chosen_files": st.session_state.chosen_files,
"user_emotion_vector": user_emotion_vector.tolist(),
"timestamp": pd.Timestamp.now().isoformat()
}
# Сохраняем в текущую директорию # Визуализация вектора (графики)
save_path = Path("user_results.json") st.divider()
with open(save_path, 'w') as f: st.subheader("Визуализация эмоционального вектора")
import json fig, ax = plt.subplots(figsize=(10, 3))
json.dump(results, f) ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности
ax.set_title("Эмбеддинг вашего настроения (фрагмент)")
st.success(f"Результаты сохранены в {save_path}") st.pyplot(fig)
if st.button("Начать заново"): if st.button("Начать заново"):
st.session_state.round_num = 0 for key in list(st.session_state.keys()):
st.session_state.chosen_files = [] del st.session_state[key]
st.rerun() st.rerun()
View File
+55
View File
@@ -0,0 +1,55 @@
import numpy as np
import pandas as pd
from pathlib import Path
class MusicMatcher:
def __init__(self, db_path: Path):
self.music_db = pd.read_csv(db_path)
# Убедимся, что данные числовые
self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce')
self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce')
self.music_db = self.music_db.dropna()
def predict_va(self, embedding: np.ndarray):
"""
Умный хак для демо: используем статистику вектора.
Мрачные картинки имеют меньшую среднюю активацию.
"""
# 1. Средняя сила активации (прокси для Valence)
mean_act = float(np.mean(embedding))
# 2. Разреженность вектора (прокси для Arousal)
# Чем больше нулей или близких к нулю значений, тем "спокойнее" картинка
sparsity = float(np.mean(embedding < 0.1))
# Масштабируем: если средняя активация низкая (мрачное), уводим Valence вниз
# (Типичный mean_act для ResNet обычно от 0.2 до 0.8)
v = np.interp(mean_act, [0.2, 0.6], [2.0, 8.0])
# Если много нулей (пустота/мрак), Arousal падает.
# Если нулей мало (пестрый мусор) - Arousal растет.
a = np.interp(sparsity, [0.3, 0.8], [8.0, 2.0]) # Обратная зависимость
# Добавляем "соль" из суммы вектора, чтобы избежать одинаковых результатов
# для похожих, но разных наборов картинок
salt = (float(np.sum(embedding)) % 2.0) - 1.0
v += salt
a += salt
return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0)
def get_playlist(self, user_vector: np.ndarray, top_k: int = 5):
target_v, target_a = self.predict_va(user_vector)
# Считаем Евклидово расстояние от пользователя до всех треков
distances = np.sqrt(
(self.music_db['valence'] - target_v)**2 +
(self.music_db['arousal'] - target_a)**2
)
# Добавляем дистанцию, сортируем по возрастанию (чем меньше, тем ближе)
df_result = self.music_db.copy()
df_result['distance'] = distances
playlist = df_result.sort_values(by='distance').head(top_k)
return target_v, target_a, playlist