From 5290554d70653d7d8c8797898516892822a29c14 Mon Sep 17 00:00:00 2001 From: zin Date: Wed, 6 May 2026 20:12:14 +0000 Subject: [PATCH] Refactored main.py --- src/data_loader.py | 33 +++++++++ src/main.py | 147 ++++++---------------------------------- src/tabs/tab_dataset.py | 89 ++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 126 deletions(-) create mode 100644 src/data_loader.py create mode 100644 src/tabs/tab_dataset.py diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 0000000..917b162 --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,33 @@ +import streamlit as st +from pathlib import Path +import pandas as pd +import numpy as np +from music_engine.matcher import MusicMatcher + +@st.cache_resource +def load_music_engine(): + """Загрузка базы данных и модели регрессора.""" + base_dir = Path(__file__).resolve().parent + db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" + model_path = base_dir / "music_engine" / "va_regressor.pkl" + if not db_path.exists(): + return None + return MusicMatcher(db_path=db_path, model_path=model_path) + +@st.cache_data +def load_emoset_data(): + """Загрузка тестовой выборки EmoSet для первой вкладки.""" + csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") + img_dir = Path("./dataset/EmoSet-118K/test/images") + emb_path = Path("./src/emoset_test_embeddings.npy") + lbl_path = Path("./src/emoset_test_labels.npy") + + if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): + return None, None, None, None + + df = pd.read_csv(csv_path) + image_list = df['filename'].tolist() + embs = np.load(emb_path) + lbls = np.load(lbl_path) + + return image_list, embs, lbls, img_dir \ No newline at end of file diff --git a/src/main.py b/src/main.py index 2940e5a..5c38129 100644 --- a/src/main.py +++ b/src/main.py @@ -1,145 +1,40 @@ import streamlit as st -from pathlib import Path -import pandas as pd -import numpy as np -from PIL import Image -import random -import matplotlib.pyplot as plt -from music_engine.matcher import MusicMatcher +import sys +import os +import subprocess +from data_loader import load_music_engine, load_emoset_data +from tabs.tab_dataset import render_dataset_tab # ---------------------------- -# 1️⃣ Запуск Streamlit +# 1️⃣ Запуск приложения # ---------------------------- if __name__ == "__main__": - import os if "STREAMLIT_RUN" not in os.environ: - import sys - import subprocess os.environ["STREAMLIT_RUN"] = "1" - cmd = [ - sys.executable, "-m", "streamlit", "run", __file__, - "--server.port", "8080", "--server.address", "0.0.0.0" - ] + cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"] subprocess.run(cmd) sys.exit() -# Словарь для отладки -EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment", - 4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"} - -st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide") - -@st.cache_resource -def load_music_engine(): - base_dir = Path(__file__).resolve().parent - db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" - model_path = base_dir / "music_engine" / "va_regressor.pkl" - if not db_path.exists(): - return None - return MusicMatcher(db_path=db_path, model_path=model_path) +st.set_page_config(page_title="Thesis Demo", layout="wide") +# ---------------------------- +# 2️⃣ Инициализация движка и данных +# ---------------------------- matcher = load_music_engine() - -@st.cache_data -def load_emoset_data(): - csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") - img_dir = Path("./dataset/EmoSet-118K/test/images") - emb_path = Path("./src/emoset_test_embeddings.npy") - lbl_path = Path("./src/emoset_test_labels.npy") - - if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): - return None, None, None, None - - df = pd.read_csv(csv_path) - image_list = df['filename'].tolist() - embs = np.load(emb_path) - lbls = np.load(lbl_path) - - return image_list, embs, lbls, img_dir - image_files, embeddings, labels_array, images_path = load_emoset_data() # ---------------------------- -# 2️⃣ Основной интерфейс +# 3️⃣ Интерфейс и Вкладки # ---------------------------- -if image_files is None: - st.error("Ошибка загрузки данных EmoSet. Проверьте пути.") -else: - if 'round' not in st.session_state: - st.session_state.round = 1 - st.session_state.chosen_indices = [] - st.session_state.current_options = random.sample(range(len(image_files)), 6) +st.title("🖼️ Эмоциональный генератор плейлистов") - st.title("🖼️ Эмоциональный подбор музыки") +# Создаем две вкладки +tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"]) - if st.session_state.round <= 10: - st.subheader(f"Раунд {st.session_state.round} из 10") - st.write("Выберите изображение, соответствующее вашему настроению:") - - cols = st.columns(3) - for i, idx in enumerate(st.session_state.current_options): - with cols[i % 3]: - img_name = image_files[idx] - img = Image.open(images_path / img_name) - st.image(img, use_container_width=True) - - # Информация для отладки - if matcher: - v_p, a_p = matcher.predict_va(embeddings[idx]) - gt_label = EMO_NAMES.get(labels_array[idx], "unknown") - st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}") - - if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True): - st.session_state.chosen_indices.append(idx) - st.session_state.round += 1 - if st.session_state.round <= 10: - st.session_state.current_options = random.sample(range(len(image_files)), 6) - st.rerun() - else: - # РЕЗУЛЬТАТЫ - st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.") - - all_v, all_a = [], [] - for idx in st.session_state.chosen_indices: - v, a = matcher.predict_va(embeddings[idx]) - all_v.append(v) - all_a.append(a) - - target_v, target_a = np.mean(all_v), np.mean(all_a) - playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) +with tab1: + render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path) - col_left, col_right = st.columns([1, 2]) - - with col_left: - st.header("📊 Ваш профиль") - st.metric("Позитивность (Valence)", f"{target_v:.2f}") - st.metric("Энергия (Arousal)", f"{target_a:.2f}") - - # График Рассела - fig, ax = plt.subplots(figsize=(4, 4)) - ax.set_xlim(1, 9); ax.set_ylim(1, 9) - ax.axhline(5, color='gray', lw=1, ls='--') - ax.axvline(5, color='gray', lw=1, ls='--') - ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5) - ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") - ax.set_title("Карта эмоций (Модель Рассела)") - st.pyplot(fig) - - with col_right: - st.header("🎵 Рекомендованная музыка") - for _, row in playlist.iterrows(): - with st.container(border=True): - c1, c2 = st.columns([1, 3]) - with c1: - st.write(f"**ID:** {int(row['song_id'])}") - st.caption(f"L2: {row['distance']:.2f}") - with c2: - audio_path = matcher.get_audio_path(row['song_id']) - if audio_path: - st.audio(str(audio_path)) - else: - st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") - - if st.button("Начать заново", type="primary"): - for key in list(st.session_state.keys()): del st.session_state[key] - st.rerun() \ No newline at end of file +with tab2: + st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.") + st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.") + # TODO: render_live_tab(matcher) \ No newline at end of file diff --git a/src/tabs/tab_dataset.py b/src/tabs/tab_dataset.py new file mode 100644 index 0000000..42b93d5 --- /dev/null +++ b/src/tabs/tab_dataset.py @@ -0,0 +1,89 @@ +import streamlit as st +import random +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + +EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment", + 4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"} + +def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path): + if image_files is None: + st.error("Ошибка загрузки данных EmoSet. Проверьте пути.") + return + + # Инициализация состояния именно для этой вкладки + if 'ds_round' not in st.session_state: + st.session_state.ds_round = 1 + st.session_state.ds_chosen_indices = [] + st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) + + st.write("Выберите изображение, соответствующее вашему настроению:") + + if st.session_state.ds_round <= 10: + st.subheader(f"Раунд {st.session_state.ds_round} из 10") + + cols = st.columns(3) + for i, idx in enumerate(st.session_state.ds_current_options): + with cols[i % 3]: + img_name = image_files[idx] + img = Image.open(images_path / img_name) + st.image(img, use_container_width=True) + + if matcher: + v_p, a_p = matcher.predict_va(embeddings[idx]) + gt_label = EMO_NAMES.get(labels_array[idx], "unknown") + st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}") + + if st.button(f"Выбрать образ {i+1}", key=f"btn_ds_{idx}", use_container_width=True): + st.session_state.ds_chosen_indices.append(idx) + st.session_state.ds_round += 1 + if st.session_state.ds_round <= 10: + st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) + st.rerun() + else: + st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.") + + all_v, all_a = [], [] + for idx in st.session_state.ds_chosen_indices: + v, a = matcher.predict_va(embeddings[idx]) + all_v.append(v) + all_a.append(a) + + target_v, target_a = np.mean(all_v), np.mean(all_a) + playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) + + col_left, col_right = st.columns([1, 2]) + + with col_left: + st.header("📊 Ваш профиль") + st.metric("Позитивность (Valence)", f"{target_v:.2f}") + st.metric("Энергия (Arousal)", f"{target_a:.2f}") + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.set_xlim(1, 9); ax.set_ylim(1, 9) + ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--') + ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5) + ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") + st.pyplot(fig) + + with col_right: + st.header("🎵 Рекомендованная музыка") + for _, row in playlist.iterrows(): + with st.container(border=True): + c1, c2 = st.columns([1, 3]) + with c1: + st.write(f"**ID:** {int(row['song_id'])}") + st.caption(f"L2 Dist: {row['distance']:.2f}") + with c2: + audio_path = matcher.get_audio_path(row['song_id']) + if audio_path: + st.audio(str(audio_path)) + else: + st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") + + if st.button("Начать заново", type="primary"): + st.session_state.pop('ds_round', None) + st.session_state.pop('ds_chosen_indices', None) + st.session_state.pop('ds_current_options', None) + st.rerun() \ No newline at end of file