Added music_engine and updated UI
This commit is contained in:
+98
-178
@@ -2,36 +2,28 @@ import streamlit as st
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import Image
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
from music_engine.matcher import MusicMatcher
|
||||
|
||||
# ----------------------------
|
||||
# 1️⃣ Конфигурация и запуск
|
||||
# ----------------------------
|
||||
if __name__ == "__main__":
|
||||
# Проверяем, запущен ли скрипт через Streamlit
|
||||
import os
|
||||
if "STREAMLIT_RUN" not in os.environ:
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
os.environ["STREAMLIT_RUN"] = "1"
|
||||
|
||||
# Формируем команду запуска
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"streamlit",
|
||||
"run",
|
||||
__file__,
|
||||
"--server.port", "8080",
|
||||
"--server.address", "0.0.0.0"
|
||||
sys.executable, "-m", "streamlit", "run", __file__,
|
||||
"--server.port", "8080", "--server.address", "0.0.0.0"
|
||||
]
|
||||
subprocess.run(cmd)
|
||||
sys.exit()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 1️⃣ Конфигурация
|
||||
# ----------------------------
|
||||
# Конфигурация путей
|
||||
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
|
||||
IMAGES_DIR = DATA_ROOT / "images"
|
||||
LABELS_CSV = DATA_ROOT / "labels.csv"
|
||||
@@ -39,182 +31,110 @@ LABELS_CSV = DATA_ROOT / "labels.csv"
|
||||
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
|
||||
LABELS_PATH = Path("./src/emoset_test_labels.npy")
|
||||
|
||||
NUM_CHOICES = 6 # количество изображений за один раунд
|
||||
TOTAL_ROUNDS = 10 # количество раундов выбора
|
||||
# Параметры эксперимента
|
||||
NUM_CHOICES = 6
|
||||
TOTAL_ROUNDS = 10
|
||||
|
||||
st.set_page_config(page_title="EmoSet Demo", layout="wide")
|
||||
st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide")
|
||||
|
||||
# Инициализация музыкального движка с кэшированием
|
||||
@st.cache_resource
|
||||
def load_music_engine():
|
||||
# Путь относительно src: ../dataset/DEAM/music_db.csv
|
||||
db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
if not db_path.exists():
|
||||
return None
|
||||
return MusicMatcher(db_path)
|
||||
|
||||
matcher = load_music_engine()
|
||||
|
||||
# ----------------------------
|
||||
# 2️⃣ Загрузка данных
|
||||
# 2️⃣ Загрузка данных EmoSet
|
||||
# ----------------------------
|
||||
if not IMAGES_DIR.exists():
|
||||
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}")
|
||||
st.stop()
|
||||
@st.cache_data
|
||||
def load_emoset_data():
|
||||
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists():
|
||||
return None, None, None
|
||||
|
||||
image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")])
|
||||
embeddings = np.load(EMBEDDINGS_PATH)
|
||||
labels_array = np.load(LABELS_PATH)
|
||||
|
||||
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
|
||||
st.error("Размеры массивов данных не совпадают!")
|
||||
st.stop()
|
||||
|
||||
return image_files, embeddings, labels_array
|
||||
|
||||
labels_df = pd.read_csv(LABELS_CSV)
|
||||
embeddings = np.load(EMBEDDINGS_PATH)
|
||||
labels_array = np.load(LABELS_PATH)
|
||||
|
||||
image_files = list(IMAGES_DIR.glob("*.jpg"))
|
||||
|
||||
# Проверка совпадения размеров
|
||||
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
|
||||
st.error("Размеры массивов не совпадают!")
|
||||
st.stop()
|
||||
|
||||
# Создадим mapping: filename -> embedding, label
|
||||
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
|
||||
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
|
||||
image_files, embeddings, labels_array = load_emoset_data()
|
||||
|
||||
# ----------------------------
|
||||
# 3️⃣ Сессия Streamlit
|
||||
# ----------------------------
|
||||
if 'round_num' not in st.session_state:
|
||||
st.session_state.round_num = 0
|
||||
if 'chosen_files' not in st.session_state:
|
||||
st.session_state.chosen_files = []
|
||||
|
||||
st.title("EmoSet: Выбор изображений")
|
||||
|
||||
# ----------------------------
|
||||
# 4️⃣ Функция для overlay топ-3 эмоций
|
||||
# ----------------------------
|
||||
def get_font():
|
||||
try:
|
||||
return ImageFont.truetype("arial.ttf", 14)
|
||||
except:
|
||||
try:
|
||||
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
|
||||
except:
|
||||
return ImageFont.load_default()
|
||||
|
||||
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
|
||||
draw = ImageDraw.Draw(img)
|
||||
font = get_font()
|
||||
text = f"Label: {label}"
|
||||
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
|
||||
draw.text((2,2), text, fill=(255,255,255), font=font)
|
||||
return img
|
||||
|
||||
# ----------------------------
|
||||
# 5️⃣ Рандомный выбор 6 изображений для текущего раунда
|
||||
# ----------------------------
|
||||
# Инициализация состояния Streamlit
|
||||
if "round_num" not in st.session_state:
|
||||
st.session_state.round_num = 0
|
||||
if "chosen_files" not in st.session_state:
|
||||
st.session_state.chosen_files = []
|
||||
if "current_choices" not in st.session_state:
|
||||
st.session_state.current_choices = []
|
||||
|
||||
# Если все раунды уже завершены, блок пропускается
|
||||
if st.session_state.round_num < TOTAL_ROUNDS:
|
||||
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
|
||||
|
||||
# Генерируем выбор для текущего раунда только если он ещё не создан
|
||||
if len(st.session_state.current_choices) == 0:
|
||||
already_chosen = set(st.session_state.chosen_files)
|
||||
available_images = [f for f in image_files if f.name not in already_chosen]
|
||||
|
||||
if len(available_images) < NUM_CHOICES:
|
||||
st.warning("Недостаточно изображений для выбора!")
|
||||
st.stop()
|
||||
|
||||
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
|
||||
|
||||
# Отображаем изображения и кнопки
|
||||
cols = st.columns(NUM_CHOICES)
|
||||
for col, img_path in zip(cols, st.session_state.current_choices):
|
||||
# Загружаем изображение и накладываем overlay
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
|
||||
|
||||
# Показываем изображение
|
||||
col.image(img_overlay, width=250)
|
||||
|
||||
# Кнопка выбора
|
||||
if col.button("Выбрать", key=img_path.name):
|
||||
st.session_state.chosen_files.append(img_path.name)
|
||||
st.session_state.round_num += 1
|
||||
st.session_state.current_choices = [] # сброс для следующего раунда
|
||||
st.rerun() # перезапуск для нового раунда
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 6️⃣ После всех выборов: отображение результатов
|
||||
# 3️⃣ Логика приложения
|
||||
# ----------------------------
|
||||
if image_files is None:
|
||||
st.error("Данные EmoSet не найдены. Проверьте папку dataset.")
|
||||
else:
|
||||
st.subheader("Результаты выбора пользователя")
|
||||
if 'round' not in st.session_state:
|
||||
st.session_state.round = 1
|
||||
st.session_state.chosen_indices = []
|
||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
||||
|
||||
if not st.session_state.chosen_files:
|
||||
st.warning("Вы не сделали ни одного выбора!")
|
||||
if st.button("Начать заново"):
|
||||
st.session_state.round_num = 0
|
||||
st.session_state.chosen_files = []
|
||||
st.rerun()
|
||||
else:
|
||||
# Отображение выбранных изображений
|
||||
st.title("Выбор эмоциональных образов")
|
||||
st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.")
|
||||
|
||||
if st.session_state.round <= TOTAL_ROUNDS:
|
||||
# Отображение сетки изображений
|
||||
cols = st.columns(3)
|
||||
for i, filename in enumerate(st.session_state.chosen_files):
|
||||
for i, idx in enumerate(st.session_state.current_options):
|
||||
with cols[i % 3]:
|
||||
img = Image.open(IMAGES_DIR / filename).convert("RGB")
|
||||
img_overlay = overlay_top_emotions(img, filename2label[filename])
|
||||
st.image(img_overlay, caption=f"Выбор {i+1}")
|
||||
img_path = IMAGES_DIR / image_files[idx]
|
||||
img = Image.open(img_path)
|
||||
st.image(img, use_container_width=True)
|
||||
if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"):
|
||||
st.session_state.chosen_indices.append(idx)
|
||||
st.session_state.round += 1
|
||||
if st.session_state.round <= TOTAL_ROUNDS:
|
||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
||||
st.rerun()
|
||||
else:
|
||||
# ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка
|
||||
st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.")
|
||||
|
||||
# Расчет среднего вектора пользователя
|
||||
chosen_embeddings = embeddings[st.session_state.chosen_indices]
|
||||
user_vector = np.mean(chosen_embeddings, axis=0)
|
||||
|
||||
# Расчет среднего embedding
|
||||
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files]
|
||||
chosen_embeddings = np.stack(chosen_embeddings)
|
||||
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
|
||||
# РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ
|
||||
st.divider()
|
||||
st.header("🎵 Рекомендованный плейлист")
|
||||
|
||||
if matcher is None:
|
||||
st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.")
|
||||
else:
|
||||
with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."):
|
||||
target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5)
|
||||
|
||||
# Визуализация VA-метрик пользователя
|
||||
m1, m2, m3 = st.columns(3)
|
||||
m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9")
|
||||
m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9")
|
||||
m3.metric("Найдено треков", len(playlist))
|
||||
|
||||
# Отображение информации о векторе эмоций
|
||||
with st.expander("Подробности о векторной модели эмоций"):
|
||||
st.write(f"Размерность embedding: {len(user_emotion_vector)}")
|
||||
st.write("Первые 10 значений:")
|
||||
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
|
||||
# Таблица с результатами
|
||||
st.subheader("Топ-5 подходящих композиций")
|
||||
st.table(playlist[['song_id', 'valence', 'arousal', 'distance']])
|
||||
|
||||
st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.")
|
||||
|
||||
# Визуализация
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
# Гистограмма первых 30 измерений
|
||||
plt.figure(figsize=(8,4))
|
||||
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30])
|
||||
plt.xlabel("Embedding dimension")
|
||||
plt.ylabel("Value")
|
||||
plt.title("Распределение значений embedding (первые 30 измерений)")
|
||||
st.pyplot(plt)
|
||||
|
||||
with col2:
|
||||
# Круговая диаграмма средних значений по блокам
|
||||
if len(user_emotion_vector) > 16:
|
||||
block_size = len(user_emotion_vector) // 4
|
||||
block_means = [
|
||||
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
|
||||
for i in range(4)
|
||||
]
|
||||
plt.figure(figsize=(8,4))
|
||||
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
|
||||
plt.title("Распределение эмоциональных блоков")
|
||||
st.pyplot(plt)
|
||||
|
||||
# Кнопка для сохранения результатов
|
||||
if st.button("Сохранить результаты"):
|
||||
# Сохранение в файл (например, JSON)
|
||||
results = {
|
||||
"chosen_files": st.session_state.chosen_files,
|
||||
"user_emotion_vector": user_emotion_vector.tolist(),
|
||||
"timestamp": pd.Timestamp.now().isoformat()
|
||||
}
|
||||
|
||||
# Сохраняем в текущую директорию
|
||||
save_path = Path("user_results.json")
|
||||
with open(save_path, 'w') as f:
|
||||
import json
|
||||
json.dump(results, f)
|
||||
|
||||
st.success(f"Результаты сохранены в {save_path}")
|
||||
# Визуализация вектора (графики)
|
||||
st.divider()
|
||||
st.subheader("Визуализация эмоционального вектора")
|
||||
fig, ax = plt.subplots(figsize=(10, 3))
|
||||
ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности
|
||||
ax.set_title("Эмбеддинг вашего настроения (фрагмент)")
|
||||
st.pyplot(fig)
|
||||
|
||||
if st.button("Начать заново"):
|
||||
st.session_state.round_num = 0
|
||||
st.session_state.chosen_files = []
|
||||
st.rerun()
|
||||
for key in list(st.session_state.keys()):
|
||||
del st.session_state[key]
|
||||
st.rerun()
|
||||
@@ -0,0 +1,55 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
class MusicMatcher:
|
||||
def __init__(self, db_path: Path):
|
||||
self.music_db = pd.read_csv(db_path)
|
||||
# Убедимся, что данные числовые
|
||||
self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce')
|
||||
self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce')
|
||||
self.music_db = self.music_db.dropna()
|
||||
|
||||
def predict_va(self, embedding: np.ndarray):
|
||||
"""
|
||||
Умный хак для демо: используем статистику вектора.
|
||||
Мрачные картинки имеют меньшую среднюю активацию.
|
||||
"""
|
||||
# 1. Средняя сила активации (прокси для Valence)
|
||||
mean_act = float(np.mean(embedding))
|
||||
|
||||
# 2. Разреженность вектора (прокси для Arousal)
|
||||
# Чем больше нулей или близких к нулю значений, тем "спокойнее" картинка
|
||||
sparsity = float(np.mean(embedding < 0.1))
|
||||
|
||||
# Масштабируем: если средняя активация низкая (мрачное), уводим Valence вниз
|
||||
# (Типичный mean_act для ResNet обычно от 0.2 до 0.8)
|
||||
v = np.interp(mean_act, [0.2, 0.6], [2.0, 8.0])
|
||||
|
||||
# Если много нулей (пустота/мрак), Arousal падает.
|
||||
# Если нулей мало (пестрый мусор) - Arousal растет.
|
||||
a = np.interp(sparsity, [0.3, 0.8], [8.0, 2.0]) # Обратная зависимость
|
||||
|
||||
# Добавляем "соль" из суммы вектора, чтобы избежать одинаковых результатов
|
||||
# для похожих, но разных наборов картинок
|
||||
salt = (float(np.sum(embedding)) % 2.0) - 1.0
|
||||
v += salt
|
||||
a += salt
|
||||
|
||||
return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0)
|
||||
|
||||
def get_playlist(self, user_vector: np.ndarray, top_k: int = 5):
|
||||
target_v, target_a = self.predict_va(user_vector)
|
||||
|
||||
# Считаем Евклидово расстояние от пользователя до всех треков
|
||||
distances = np.sqrt(
|
||||
(self.music_db['valence'] - target_v)**2 +
|
||||
(self.music_db['arousal'] - target_a)**2
|
||||
)
|
||||
|
||||
# Добавляем дистанцию, сортируем по возрастанию (чем меньше, тем ближе)
|
||||
df_result = self.music_db.copy()
|
||||
df_result['distance'] = distances
|
||||
playlist = df_result.sort_values(by='distance').head(top_k)
|
||||
|
||||
return target_v, target_a, playlist
|
||||
Reference in New Issue
Block a user