import streamlit as st from pathlib import Path import pandas as pd import numpy as np from PIL import Image import random import matplotlib.pyplot as plt from music_engine.matcher import MusicMatcher # ---------------------------- # 1️⃣ Запуск Streamlit # ---------------------------- if __name__ == "__main__": import os if "STREAMLIT_RUN" not in os.environ: import sys import subprocess os.environ["STREAMLIT_RUN"] = "1" cmd = [ sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0" ] subprocess.run(cmd) sys.exit() # Словарь для отладки EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment", 4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"} st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide") @st.cache_resource def load_music_engine(): base_dir = Path(__file__).resolve().parent db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" model_path = base_dir / "music_engine" / "va_regressor.pkl" if not db_path.exists(): return None return MusicMatcher(db_path=db_path, model_path=model_path) matcher = load_music_engine() @st.cache_data def load_emoset_data(): csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") img_dir = Path("./dataset/EmoSet-118K/test/images") emb_path = Path("./src/emoset_test_embeddings.npy") lbl_path = Path("./src/emoset_test_labels.npy") if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): return None, None, None, None df = pd.read_csv(csv_path) image_list = df['filename'].tolist() embs = np.load(emb_path) lbls = np.load(lbl_path) return image_list, embs, lbls, img_dir image_files, embeddings, labels_array, images_path = load_emoset_data() # ---------------------------- # 2️⃣ Основной интерфейс # ---------------------------- if image_files is None: st.error("Ошибка загрузки данных EmoSet. Проверьте пути.") else: if 'round' not in st.session_state: st.session_state.round = 1 st.session_state.chosen_indices = [] st.session_state.current_options = random.sample(range(len(image_files)), 6) st.title("🖼️ Эмоциональный подбор музыки") if st.session_state.round <= 10: st.subheader(f"Раунд {st.session_state.round} из 10") st.write("Выберите изображение, соответствующее вашему настроению:") cols = st.columns(3) for i, idx in enumerate(st.session_state.current_options): with cols[i % 3]: img_name = image_files[idx] img = Image.open(images_path / img_name) st.image(img, use_container_width=True) # Информация для отладки if matcher: v_p, a_p = matcher.predict_va(embeddings[idx]) gt_label = EMO_NAMES.get(labels_array[idx], "unknown") st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}") if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True): st.session_state.chosen_indices.append(idx) st.session_state.round += 1 if st.session_state.round <= 10: st.session_state.current_options = random.sample(range(len(image_files)), 6) st.rerun() else: # РЕЗУЛЬТАТЫ st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.") all_v, all_a = [], [] for idx in st.session_state.chosen_indices: v, a = matcher.predict_va(embeddings[idx]) all_v.append(v) all_a.append(a) target_v, target_a = np.mean(all_v), np.mean(all_a) playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) col_left, col_right = st.columns([1, 2]) with col_left: st.header("📊 Ваш профиль") st.metric("Позитивность (Valence)", f"{target_v:.2f}") st.metric("Энергия (Arousal)", f"{target_a:.2f}") # График Рассела fig, ax = plt.subplots(figsize=(4, 4)) ax.set_xlim(1, 9); ax.set_ylim(1, 9) ax.axhline(5, color='gray', lw=1, ls='--') ax.axvline(5, color='gray', lw=1, ls='--') ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5) ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") ax.set_title("Карта эмоций (Модель Рассела)") st.pyplot(fig) with col_right: st.header("🎵 Рекомендованная музыка") for _, row in playlist.iterrows(): with st.container(border=True): c1, c2 = st.columns([1, 3]) with c1: st.write(f"**ID:** {int(row['song_id'])}") st.caption(f"L2: {row['distance']:.2f}") with c2: audio_path = matcher.get_audio_path(row['song_id']) if audio_path: st.audio(str(audio_path)) else: st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") if st.button("Начать заново", type="primary"): for key in list(st.session_state.keys()): del st.session_state[key] st.rerun()