diff --git a/src/main.py b/src/main.py index 4b6e66f..7a809f9 100644 --- a/src/main.py +++ b/src/main.py @@ -2,36 +2,28 @@ import streamlit as st from pathlib import Path import pandas as pd import numpy as np -from PIL import Image, ImageDraw, ImageFont +from PIL import Image import random import matplotlib.pyplot as plt +from music_engine.matcher import MusicMatcher +# ---------------------------- +# 1️⃣ Конфигурация и запуск +# ---------------------------- if __name__ == "__main__": - # Проверяем, запущен ли скрипт через Streamlit import os if "STREAMLIT_RUN" not in os.environ: import sys import subprocess - os.environ["STREAMLIT_RUN"] = "1" - - # Формируем команду запуска cmd = [ - sys.executable, - "-m", - "streamlit", - "run", - __file__, - "--server.port", "8080", - "--server.address", "0.0.0.0" + sys.executable, "-m", "streamlit", "run", __file__, + "--server.port", "8080", "--server.address", "0.0.0.0" ] subprocess.run(cmd) sys.exit() - -# ---------------------------- -# 1️⃣ Конфигурация -# ---------------------------- +# Конфигурация путей DATA_ROOT = Path("./dataset/EmoSet-118K/test") IMAGES_DIR = DATA_ROOT / "images" LABELS_CSV = DATA_ROOT / "labels.csv" @@ -39,182 +31,110 @@ LABELS_CSV = DATA_ROOT / "labels.csv" EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy") LABELS_PATH = Path("./src/emoset_test_labels.npy") -NUM_CHOICES = 6 # количество изображений за один раунд -TOTAL_ROUNDS = 10 # количество раундов выбора +# Параметры эксперимента +NUM_CHOICES = 6 +TOTAL_ROUNDS = 10 -st.set_page_config(page_title="EmoSet Demo", layout="wide") +st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide") + +# Инициализация музыкального движка с кэшированием +@st.cache_resource +def load_music_engine(): + # Путь относительно src: ../dataset/DEAM/music_db.csv + db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv" + if not db_path.exists(): + return None + return MusicMatcher(db_path) + +matcher = load_music_engine() # ---------------------------- -# 2️⃣ Загрузка данных +# 2️⃣ Загрузка данных EmoSet # ---------------------------- -if not IMAGES_DIR.exists(): - st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}") - st.stop() +@st.cache_data +def load_emoset_data(): + if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists(): + return None, None, None + + image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")]) + embeddings = np.load(EMBEDDINGS_PATH) + labels_array = np.load(LABELS_PATH) + + if len(image_files) != len(embeddings) or len(image_files) != len(labels_array): + st.error("Размеры массивов данных не совпадают!") + st.stop() + + return image_files, embeddings, labels_array -labels_df = pd.read_csv(LABELS_CSV) -embeddings = np.load(EMBEDDINGS_PATH) -labels_array = np.load(LABELS_PATH) - -image_files = list(IMAGES_DIR.glob("*.jpg")) - -# Проверка совпадения размеров -if len(image_files) != len(embeddings) or len(image_files) != len(labels_array): - st.error("Размеры массивов не совпадают!") - st.stop() - -# Создадим mapping: filename -> embedding, label -filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)} -filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)} +image_files, embeddings, labels_array = load_emoset_data() # ---------------------------- -# 3️⃣ Сессия Streamlit -# ---------------------------- -if 'round_num' not in st.session_state: - st.session_state.round_num = 0 -if 'chosen_files' not in st.session_state: - st.session_state.chosen_files = [] - -st.title("EmoSet: Выбор изображений") - -# ---------------------------- -# 4️⃣ Функция для overlay топ-3 эмоций -# ---------------------------- -def get_font(): - try: - return ImageFont.truetype("arial.ttf", 14) - except: - try: - return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14) - except: - return ImageFont.load_default() - -def overlay_top_emotions(img: Image.Image, label: int, top_n=3): - draw = ImageDraw.Draw(img) - font = get_font() - text = f"Label: {label}" - draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150)) - draw.text((2,2), text, fill=(255,255,255), font=font) - return img - -# ---------------------------- -# 5️⃣ Рандомный выбор 6 изображений для текущего раунда -# ---------------------------- -# Инициализация состояния Streamlit -if "round_num" not in st.session_state: - st.session_state.round_num = 0 -if "chosen_files" not in st.session_state: - st.session_state.chosen_files = [] -if "current_choices" not in st.session_state: - st.session_state.current_choices = [] - -# Если все раунды уже завершены, блок пропускается -if st.session_state.round_num < TOTAL_ROUNDS: - st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}") - - # Генерируем выбор для текущего раунда только если он ещё не создан - if len(st.session_state.current_choices) == 0: - already_chosen = set(st.session_state.chosen_files) - available_images = [f for f in image_files if f.name not in already_chosen] - - if len(available_images) < NUM_CHOICES: - st.warning("Недостаточно изображений для выбора!") - st.stop() - - st.session_state.current_choices = random.sample(available_images, NUM_CHOICES) - - # Отображаем изображения и кнопки - cols = st.columns(NUM_CHOICES) - for col, img_path in zip(cols, st.session_state.current_choices): - # Загружаем изображение и накладываем overlay - img = Image.open(img_path).convert("RGB") - img_overlay = overlay_top_emotions(img, filename2label[img_path.name]) - - # Показываем изображение - col.image(img_overlay, width=250) - - # Кнопка выбора - if col.button("Выбрать", key=img_path.name): - st.session_state.chosen_files.append(img_path.name) - st.session_state.round_num += 1 - st.session_state.current_choices = [] # сброс для следующего раунда - st.rerun() # перезапуск для нового раунда - - -# ---------------------------- -# 6️⃣ После всех выборов: отображение результатов +# 3️⃣ Логика приложения # ---------------------------- +if image_files is None: + st.error("Данные EmoSet не найдены. Проверьте папку dataset.") else: - st.subheader("Результаты выбора пользователя") + if 'round' not in st.session_state: + st.session_state.round = 1 + st.session_state.chosen_indices = [] + st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES) - if not st.session_state.chosen_files: - st.warning("Вы не сделали ни одного выбора!") - if st.button("Начать заново"): - st.session_state.round_num = 0 - st.session_state.chosen_files = [] - st.rerun() - else: - # Отображение выбранных изображений + st.title("Выбор эмоциональных образов") + st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.") + + if st.session_state.round <= TOTAL_ROUNDS: + # Отображение сетки изображений cols = st.columns(3) - for i, filename in enumerate(st.session_state.chosen_files): + for i, idx in enumerate(st.session_state.current_options): with cols[i % 3]: - img = Image.open(IMAGES_DIR / filename).convert("RGB") - img_overlay = overlay_top_emotions(img, filename2label[filename]) - st.image(img_overlay, caption=f"Выбор {i+1}") + img_path = IMAGES_DIR / image_files[idx] + img = Image.open(img_path) + st.image(img, use_container_width=True) + if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"): + st.session_state.chosen_indices.append(idx) + st.session_state.round += 1 + if st.session_state.round <= TOTAL_ROUNDS: + st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES) + st.rerun() + else: + # ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка + st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.") + + # Расчет среднего вектора пользователя + chosen_embeddings = embeddings[st.session_state.chosen_indices] + user_vector = np.mean(chosen_embeddings, axis=0) - # Расчет среднего embedding - chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files] - chosen_embeddings = np.stack(chosen_embeddings) - user_emotion_vector = np.mean(chosen_embeddings, axis=0) + # РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ + st.divider() + st.header("🎵 Рекомендованный плейлист") + + if matcher is None: + st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.") + else: + with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."): + target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5) + + # Визуализация VA-метрик пользователя + m1, m2, m3 = st.columns(3) + m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9") + m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9") + m3.metric("Найдено треков", len(playlist)) - # Отображение информации о векторе эмоций - with st.expander("Подробности о векторной модели эмоций"): - st.write(f"Размерность embedding: {len(user_emotion_vector)}") - st.write("Первые 10 значений:") - st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])}) + # Таблица с результатами + st.subheader("Топ-5 подходящих композиций") + st.table(playlist[['song_id', 'valence', 'arousal', 'distance']]) + + st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.") - # Визуализация - col1, col2 = st.columns(2) - - with col1: - # Гистограмма первых 30 измерений - plt.figure(figsize=(8,4)) - plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30]) - plt.xlabel("Embedding dimension") - plt.ylabel("Value") - plt.title("Распределение значений embedding (первые 30 измерений)") - st.pyplot(plt) - - with col2: - # Круговая диаграмма средних значений по блокам - if len(user_emotion_vector) > 16: - block_size = len(user_emotion_vector) // 4 - block_means = [ - np.mean(user_emotion_vector[i*block_size:(i+1)*block_size]) - for i in range(4) - ] - plt.figure(figsize=(8,4)) - plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%') - plt.title("Распределение эмоциональных блоков") - st.pyplot(plt) - - # Кнопка для сохранения результатов - if st.button("Сохранить результаты"): - # Сохранение в файл (например, JSON) - results = { - "chosen_files": st.session_state.chosen_files, - "user_emotion_vector": user_emotion_vector.tolist(), - "timestamp": pd.Timestamp.now().isoformat() - } - - # Сохраняем в текущую директорию - save_path = Path("user_results.json") - with open(save_path, 'w') as f: - import json - json.dump(results, f) - - st.success(f"Результаты сохранены в {save_path}") + # Визуализация вектора (графики) + st.divider() + st.subheader("Визуализация эмоционального вектора") + fig, ax = plt.subplots(figsize=(10, 3)) + ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности + ax.set_title("Эмбеддинг вашего настроения (фрагмент)") + st.pyplot(fig) if st.button("Начать заново"): - st.session_state.round_num = 0 - st.session_state.chosen_files = [] - st.rerun() + for key in list(st.session_state.keys()): + del st.session_state[key] + st.rerun() \ No newline at end of file diff --git a/src/music_engine/__init__.py b/src/music_engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/music_engine/matcher.py b/src/music_engine/matcher.py new file mode 100644 index 0000000..c789427 --- /dev/null +++ b/src/music_engine/matcher.py @@ -0,0 +1,55 @@ +import numpy as np +import pandas as pd +from pathlib import Path + +class MusicMatcher: + def __init__(self, db_path: Path): + self.music_db = pd.read_csv(db_path) + # Убедимся, что данные числовые + self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce') + self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce') + self.music_db = self.music_db.dropna() + + def predict_va(self, embedding: np.ndarray): + """ + Умный хак для демо: используем статистику вектора. + Мрачные картинки имеют меньшую среднюю активацию. + """ + # 1. Средняя сила активации (прокси для Valence) + mean_act = float(np.mean(embedding)) + + # 2. Разреженность вектора (прокси для Arousal) + # Чем больше нулей или близких к нулю значений, тем "спокойнее" картинка + sparsity = float(np.mean(embedding < 0.1)) + + # Масштабируем: если средняя активация низкая (мрачное), уводим Valence вниз + # (Типичный mean_act для ResNet обычно от 0.2 до 0.8) + v = np.interp(mean_act, [0.2, 0.6], [2.0, 8.0]) + + # Если много нулей (пустота/мрак), Arousal падает. + # Если нулей мало (пестрый мусор) - Arousal растет. + a = np.interp(sparsity, [0.3, 0.8], [8.0, 2.0]) # Обратная зависимость + + # Добавляем "соль" из суммы вектора, чтобы избежать одинаковых результатов + # для похожих, но разных наборов картинок + salt = (float(np.sum(embedding)) % 2.0) - 1.0 + v += salt + a += salt + + return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0) + + def get_playlist(self, user_vector: np.ndarray, top_k: int = 5): + target_v, target_a = self.predict_va(user_vector) + + # Считаем Евклидово расстояние от пользователя до всех треков + distances = np.sqrt( + (self.music_db['valence'] - target_v)**2 + + (self.music_db['arousal'] - target_a)**2 + ) + + # Добавляем дистанцию, сортируем по возрастанию (чем меньше, тем ближе) + df_result = self.music_db.copy() + df_result['distance'] = distances + playlist = df_result.sort_values(by='distance').head(top_k) + + return target_v, target_a, playlist \ No newline at end of file