Added music_engine and updated UI

This commit is contained in:
zin
2026-05-06 18:11:31 +00:00
parent aaa482bc6b
commit 4e192b7bc4
3 changed files with 153 additions and 178 deletions
+98 -178
View File
@@ -2,36 +2,28 @@ import streamlit as st
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from PIL import Image
import random
import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ----------------------------
# 1️⃣ Конфигурация и запуск
# ----------------------------
if __name__ == "__main__":
# Проверяем, запущен ли скрипт через Streamlit
import os
if "STREAMLIT_RUN" not in os.environ:
import sys
import subprocess
os.environ["STREAMLIT_RUN"] = "1"
# Формируем команду запуска
cmd = [
sys.executable,
"-m",
"streamlit",
"run",
__file__,
"--server.port", "8080",
"--server.address", "0.0.0.0"
sys.executable, "-m", "streamlit", "run", __file__,
"--server.port", "8080", "--server.address", "0.0.0.0"
]
subprocess.run(cmd)
sys.exit()
# ----------------------------
# 1️⃣ Конфигурация
# ----------------------------
# Конфигурация путей
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
IMAGES_DIR = DATA_ROOT / "images"
LABELS_CSV = DATA_ROOT / "labels.csv"
@@ -39,182 +31,110 @@ LABELS_CSV = DATA_ROOT / "labels.csv"
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
LABELS_PATH = Path("./src/emoset_test_labels.npy")
NUM_CHOICES = 6 # количество изображений за один раунд
TOTAL_ROUNDS = 10 # количество раундов выбора
# Параметры эксперимента
NUM_CHOICES = 6
TOTAL_ROUNDS = 10
st.set_page_config(page_title="EmoSet Demo", layout="wide")
st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide")
# Инициализация музыкального движка с кэшированием
@st.cache_resource
def load_music_engine():
# Путь относительно src: ../dataset/DEAM/music_db.csv
db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv"
if not db_path.exists():
return None
return MusicMatcher(db_path)
matcher = load_music_engine()
# ----------------------------
# 2️⃣ Загрузка данных
# 2️⃣ Загрузка данных EmoSet
# ----------------------------
if not IMAGES_DIR.exists():
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}")
st.stop()
@st.cache_data
def load_emoset_data():
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists():
return None, None, None
image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")])
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов данных не совпадают!")
st.stop()
return image_files, embeddings, labels_array
labels_df = pd.read_csv(LABELS_CSV)
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
image_files = list(IMAGES_DIR.glob("*.jpg"))
# Проверка совпадения размеров
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов не совпадают!")
st.stop()
# Создадим mapping: filename -> embedding, label
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
image_files, embeddings, labels_array = load_emoset_data()
# ----------------------------
# 3️⃣ Сессия Streamlit
# ----------------------------
if 'round_num' not in st.session_state:
st.session_state.round_num = 0
if 'chosen_files' not in st.session_state:
st.session_state.chosen_files = []
st.title("EmoSet: Выбор изображений")
# ----------------------------
# 4️⃣ Функция для overlay топ-3 эмоций
# ----------------------------
def get_font():
try:
return ImageFont.truetype("arial.ttf", 14)
except:
try:
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
except:
return ImageFont.load_default()
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
draw = ImageDraw.Draw(img)
font = get_font()
text = f"Label: {label}"
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
draw.text((2,2), text, fill=(255,255,255), font=font)
return img
# ----------------------------
# 5️⃣ Рандомный выбор 6 изображений для текущего раунда
# ----------------------------
# Инициализация состояния Streamlit
if "round_num" not in st.session_state:
st.session_state.round_num = 0
if "chosen_files" not in st.session_state:
st.session_state.chosen_files = []
if "current_choices" not in st.session_state:
st.session_state.current_choices = []
# Если все раунды уже завершены, блок пропускается
if st.session_state.round_num < TOTAL_ROUNDS:
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
# Генерируем выбор для текущего раунда только если он ещё не создан
if len(st.session_state.current_choices) == 0:
already_chosen = set(st.session_state.chosen_files)
available_images = [f for f in image_files if f.name not in already_chosen]
if len(available_images) < NUM_CHOICES:
st.warning("Недостаточно изображений для выбора!")
st.stop()
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
# Отображаем изображения и кнопки
cols = st.columns(NUM_CHOICES)
for col, img_path in zip(cols, st.session_state.current_choices):
# Загружаем изображение и накладываем overlay
img = Image.open(img_path).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
# Показываем изображение
col.image(img_overlay, width=250)
# Кнопка выбора
if col.button("Выбрать", key=img_path.name):
st.session_state.chosen_files.append(img_path.name)
st.session_state.round_num += 1
st.session_state.current_choices = [] # сброс для следующего раунда
st.rerun() # перезапуск для нового раунда
# ----------------------------
# 6️⃣ После всех выборов: отображение результатов
# 3️⃣ Логика приложения
# ----------------------------
if image_files is None:
st.error("Данные EmoSet не найдены. Проверьте папку dataset.")
else:
st.subheader("Результаты выбора пользователя")
if 'round' not in st.session_state:
st.session_state.round = 1
st.session_state.chosen_indices = []
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
if not st.session_state.chosen_files:
st.warning("Вы не сделали ни одного выбора!")
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
st.rerun()
else:
# Отображение выбранных изображений
st.title("Выбор эмоциональных образов")
st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.")
if st.session_state.round <= TOTAL_ROUNDS:
# Отображение сетки изображений
cols = st.columns(3)
for i, filename in enumerate(st.session_state.chosen_files):
for i, idx in enumerate(st.session_state.current_options):
with cols[i % 3]:
img = Image.open(IMAGES_DIR / filename).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[filename])
st.image(img_overlay, caption=f"Выбор {i+1}")
img_path = IMAGES_DIR / image_files[idx]
img = Image.open(img_path)
st.image(img, use_container_width=True)
if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"):
st.session_state.chosen_indices.append(idx)
st.session_state.round += 1
if st.session_state.round <= TOTAL_ROUNDS:
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
st.rerun()
else:
# ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка
st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.")
# Расчет среднего вектора пользователя
chosen_embeddings = embeddings[st.session_state.chosen_indices]
user_vector = np.mean(chosen_embeddings, axis=0)
# Расчет среднего embedding
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files]
chosen_embeddings = np.stack(chosen_embeddings)
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
# РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ
st.divider()
st.header("🎵 Рекомендованный плейлист")
if matcher is None:
st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.")
else:
with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."):
target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5)
# Визуализация VA-метрик пользователя
m1, m2, m3 = st.columns(3)
m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9")
m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9")
m3.metric("Найдено треков", len(playlist))
# Отображение информации о векторе эмоций
with st.expander("Подробности о векторной модели эмоций"):
st.write(f"Размерность embedding: {len(user_emotion_vector)}")
st.write("Первые 10 значений:")
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
# Таблица с результатами
st.subheader("Топ-5 подходящих композиций")
st.table(playlist[['song_id', 'valence', 'arousal', 'distance']])
st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.")
# Визуализация
col1, col2 = st.columns(2)
with col1:
# Гистограмма первых 30 измерений
plt.figure(figsize=(8,4))
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30])
plt.xlabel("Embedding dimension")
plt.ylabel("Value")
plt.title("Распределение значений embedding (первые 30 измерений)")
st.pyplot(plt)
with col2:
# Круговая диаграмма средних значений по блокам
if len(user_emotion_vector) > 16:
block_size = len(user_emotion_vector) // 4
block_means = [
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
for i in range(4)
]
plt.figure(figsize=(8,4))
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
plt.title("Распределение эмоциональных блоков")
st.pyplot(plt)
# Кнопка для сохранения результатов
if st.button("Сохранить результаты"):
# Сохранение в файл (например, JSON)
results = {
"chosen_files": st.session_state.chosen_files,
"user_emotion_vector": user_emotion_vector.tolist(),
"timestamp": pd.Timestamp.now().isoformat()
}
# Сохраняем в текущую директорию
save_path = Path("user_results.json")
with open(save_path, 'w') as f:
import json
json.dump(results, f)
st.success(f"Результаты сохранены в {save_path}")
# Визуализация вектора (графики)
st.divider()
st.subheader("Визуализация эмоционального вектора")
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности
ax.set_title("Эмбеддинг вашего настроения (фрагмент)")
st.pyplot(fig)
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
st.rerun()
for key in list(st.session_state.keys()):
del st.session_state[key]
st.rerun()