Added music_engine and updated UI

This commit is contained in:
zin
2026-05-06 18:11:31 +00:00
parent aaa482bc6b
commit 4e192b7bc4
3 changed files with 153 additions and 178 deletions
+90 -170
View File
@@ -2,36 +2,28 @@ import streamlit as st
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from PIL import Image
import random
import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ----------------------------
# 1️⃣ Конфигурация и запуск
# ----------------------------
if __name__ == "__main__":
# Проверяем, запущен ли скрипт через Streamlit
import os
if "STREAMLIT_RUN" not in os.environ:
import sys
import subprocess
os.environ["STREAMLIT_RUN"] = "1"
# Формируем команду запуска
cmd = [
sys.executable,
"-m",
"streamlit",
"run",
__file__,
"--server.port", "8080",
"--server.address", "0.0.0.0"
sys.executable, "-m", "streamlit", "run", __file__,
"--server.port", "8080", "--server.address", "0.0.0.0"
]
subprocess.run(cmd)
sys.exit()
# ----------------------------
# 1️⃣ Конфигурация
# ----------------------------
# Конфигурация путей
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
IMAGES_DIR = DATA_ROOT / "images"
LABELS_CSV = DATA_ROOT / "labels.csv"
@@ -39,182 +31,110 @@ LABELS_CSV = DATA_ROOT / "labels.csv"
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
LABELS_PATH = Path("./src/emoset_test_labels.npy")
NUM_CHOICES = 6 # количество изображений за один раунд
TOTAL_ROUNDS = 10 # количество раундов выбора
# Параметры эксперимента
NUM_CHOICES = 6
TOTAL_ROUNDS = 10
st.set_page_config(page_title="EmoSet Demo", layout="wide")
st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide")
# Инициализация музыкального движка с кэшированием
@st.cache_resource
def load_music_engine():
# Путь относительно src: ../dataset/DEAM/music_db.csv
db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv"
if not db_path.exists():
return None
return MusicMatcher(db_path)
matcher = load_music_engine()
# ----------------------------
# 2️⃣ Загрузка данных
# 2️⃣ Загрузка данных EmoSet
# ----------------------------
if not IMAGES_DIR.exists():
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}")
st.stop()
@st.cache_data
def load_emoset_data():
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists():
return None, None, None
labels_df = pd.read_csv(LABELS_CSV)
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")])
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
image_files = list(IMAGES_DIR.glob("*.jpg"))
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов данных не совпадают!")
st.stop()
# Проверка совпадения размеров
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов не совпадают!")
st.stop()
return image_files, embeddings, labels_array
# Создадим mapping: filename -> embedding, label
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
image_files, embeddings, labels_array = load_emoset_data()
# ----------------------------
# 3️⃣ Сессия Streamlit
# ----------------------------
if 'round_num' not in st.session_state:
st.session_state.round_num = 0
if 'chosen_files' not in st.session_state:
st.session_state.chosen_files = []
st.title("EmoSet: Выбор изображений")
# ----------------------------
# 4️⃣ Функция для overlay топ-3 эмоций
# ----------------------------
def get_font():
try:
return ImageFont.truetype("arial.ttf", 14)
except:
try:
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
except:
return ImageFont.load_default()
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
draw = ImageDraw.Draw(img)
font = get_font()
text = f"Label: {label}"
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
draw.text((2,2), text, fill=(255,255,255), font=font)
return img
# ----------------------------
# 5️⃣ Рандомный выбор 6 изображений для текущего раунда
# ----------------------------
# Инициализация состояния Streamlit
if "round_num" not in st.session_state:
st.session_state.round_num = 0
if "chosen_files" not in st.session_state:
st.session_state.chosen_files = []
if "current_choices" not in st.session_state:
st.session_state.current_choices = []
# Если все раунды уже завершены, блок пропускается
if st.session_state.round_num < TOTAL_ROUNDS:
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
# Генерируем выбор для текущего раунда только если он ещё не создан
if len(st.session_state.current_choices) == 0:
already_chosen = set(st.session_state.chosen_files)
available_images = [f for f in image_files if f.name not in already_chosen]
if len(available_images) < NUM_CHOICES:
st.warning("Недостаточно изображений для выбора!")
st.stop()
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
# Отображаем изображения и кнопки
cols = st.columns(NUM_CHOICES)
for col, img_path in zip(cols, st.session_state.current_choices):
# Загружаем изображение и накладываем overlay
img = Image.open(img_path).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
# Показываем изображение
col.image(img_overlay, width=250)
# Кнопка выбора
if col.button("Выбрать", key=img_path.name):
st.session_state.chosen_files.append(img_path.name)
st.session_state.round_num += 1
st.session_state.current_choices = [] # сброс для следующего раунда
st.rerun() # перезапуск для нового раунда
# ----------------------------
# 6️⃣ После всех выборов: отображение результатов
# 3️⃣ Логика приложения
# ----------------------------
if image_files is None:
st.error("Данные EmoSet не найдены. Проверьте папку dataset.")
else:
st.subheader("Результаты выбора пользователя")
if 'round' not in st.session_state:
st.session_state.round = 1
st.session_state.chosen_indices = []
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
if not st.session_state.chosen_files:
st.warning("Вы не сделали ни одного выбора!")
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
st.rerun()
else:
# Отображение выбранных изображений
st.title("Выбор эмоциональных образов")
st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.")
if st.session_state.round <= TOTAL_ROUNDS:
# Отображение сетки изображений
cols = st.columns(3)
for i, filename in enumerate(st.session_state.chosen_files):
for i, idx in enumerate(st.session_state.current_options):
with cols[i % 3]:
img = Image.open(IMAGES_DIR / filename).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[filename])
st.image(img_overlay, caption=f"Выбор {i+1}")
img_path = IMAGES_DIR / image_files[idx]
img = Image.open(img_path)
st.image(img, use_container_width=True)
if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"):
st.session_state.chosen_indices.append(idx)
st.session_state.round += 1
if st.session_state.round <= TOTAL_ROUNDS:
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
st.rerun()
else:
# ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка
st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.")
# Расчет среднего embedding
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files]
chosen_embeddings = np.stack(chosen_embeddings)
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
# Расчет среднего вектора пользователя
chosen_embeddings = embeddings[st.session_state.chosen_indices]
user_vector = np.mean(chosen_embeddings, axis=0)
# Отображение информации о векторе эмоций
with st.expander("Подробности о векторной модели эмоций"):
st.write(f"Размерность embedding: {len(user_emotion_vector)}")
st.write("Первые 10 значений:")
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
# РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ
st.divider()
st.header("🎵 Рекомендованный плейлист")
# Визуализация
col1, col2 = st.columns(2)
if matcher is None:
st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.")
else:
with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."):
target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5)
with col1:
# Гистограмма первых 30 измерений
plt.figure(figsize=(8,4))
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30])
plt.xlabel("Embedding dimension")
plt.ylabel("Value")
plt.title("Распределение значений embedding (первые 30 измерений)")
st.pyplot(plt)
# Визуализация VA-метрик пользователя
m1, m2, m3 = st.columns(3)
m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9")
m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9")
m3.metric("Найдено треков", len(playlist))
with col2:
# Круговая диаграмма средних значений по блокам
if len(user_emotion_vector) > 16:
block_size = len(user_emotion_vector) // 4
block_means = [
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
for i in range(4)
]
plt.figure(figsize=(8,4))
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
plt.title("Распределение эмоциональных блоков")
st.pyplot(plt)
# Таблица с результатами
st.subheader("Топ-5 подходящих композиций")
st.table(playlist[['song_id', 'valence', 'arousal', 'distance']])
# Кнопка для сохранения результатов
if st.button("Сохранить результаты"):
# Сохранение в файл (например, JSON)
results = {
"chosen_files": st.session_state.chosen_files,
"user_emotion_vector": user_emotion_vector.tolist(),
"timestamp": pd.Timestamp.now().isoformat()
}
st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.")
# Сохраняем в текущую директорию
save_path = Path("user_results.json")
with open(save_path, 'w') as f:
import json
json.dump(results, f)
st.success(f"Результаты сохранены в {save_path}")
# Визуализация вектора (графики)
st.divider()
st.subheader("Визуализация эмоционального вектора")
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности
ax.set_title("Эмбеддинг вашего настроения (фрагмент)")
st.pyplot(fig)
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
for key in list(st.session_state.keys()):
del st.session_state[key]
st.rerun()
View File
+55
View File
@@ -0,0 +1,55 @@
import numpy as np
import pandas as pd
from pathlib import Path
class MusicMatcher:
def __init__(self, db_path: Path):
self.music_db = pd.read_csv(db_path)
# Убедимся, что данные числовые
self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce')
self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce')
self.music_db = self.music_db.dropna()
def predict_va(self, embedding: np.ndarray):
"""
Умный хак для демо: используем статистику вектора.
Мрачные картинки имеют меньшую среднюю активацию.
"""
# 1. Средняя сила активации (прокси для Valence)
mean_act = float(np.mean(embedding))
# 2. Разреженность вектора (прокси для Arousal)
# Чем больше нулей или близких к нулю значений, тем "спокойнее" картинка
sparsity = float(np.mean(embedding < 0.1))
# Масштабируем: если средняя активация низкая (мрачное), уводим Valence вниз
# (Типичный mean_act для ResNet обычно от 0.2 до 0.8)
v = np.interp(mean_act, [0.2, 0.6], [2.0, 8.0])
# Если много нулей (пустота/мрак), Arousal падает.
# Если нулей мало (пестрый мусор) - Arousal растет.
a = np.interp(sparsity, [0.3, 0.8], [8.0, 2.0]) # Обратная зависимость
# Добавляем "соль" из суммы вектора, чтобы избежать одинаковых результатов
# для похожих, но разных наборов картинок
salt = (float(np.sum(embedding)) % 2.0) - 1.0
v += salt
a += salt
return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0)
def get_playlist(self, user_vector: np.ndarray, top_k: int = 5):
target_v, target_a = self.predict_va(user_vector)
# Считаем Евклидово расстояние от пользователя до всех треков
distances = np.sqrt(
(self.music_db['valence'] - target_v)**2 +
(self.music_db['arousal'] - target_a)**2
)
# Добавляем дистанцию, сортируем по возрастанию (чем меньше, тем ближе)
df_result = self.music_db.copy()
df_result['distance'] = distances
playlist = df_result.sort_values(by='distance').head(top_k)
return target_v, target_a, playlist