Refactored main.py

This commit is contained in:
zin
2026-05-06 20:12:14 +00:00
parent 95595a5a5e
commit 5290554d70
3 changed files with 143 additions and 126 deletions
+33
View File
@@ -0,0 +1,33 @@
import streamlit as st
from pathlib import Path
import pandas as pd
import numpy as np
from music_engine.matcher import MusicMatcher
@st.cache_resource
def load_music_engine():
    """Create the MusicMatcher from the track database and regressor model.

    Both paths are resolved relative to this file, not the working
    directory: the CSV under ``dataset/DEAM`` one level above this file's
    folder, and the pickle under ``music_engine`` next to this file.

    Returns:
        A ``MusicMatcher`` instance, or ``None`` when the database CSV does
        not exist. Only the database path is checked; the model path is
        handed to ``MusicMatcher`` as-is.
    """
    here = Path(__file__).resolve().parent
    music_db = here.parent / "dataset" / "DEAM" / "music_db.csv"
    regressor = here / "music_engine" / "va_regressor.pkl"
    return (
        MusicMatcher(db_path=music_db, model_path=regressor)
        if music_db.exists()
        else None
    )
@st.cache_data
def load_emoset_data():
    """Load the EmoSet test split used by the first tab.

    Returns:
        A 4-tuple ``(image_list, embs, lbls, img_dir)``:
        the ``filename`` column of the labels CSV as a list, the arrays
        loaded from the embeddings and labels ``.npy`` files, and the path
        of the image directory. If any of the three files or the image
        directory is missing, ``(None, None, None, None)`` is returned.
    """
    # NOTE(review): these paths are relative to the current working
    # directory, unlike load_music_engine which anchors on __file__ —
    # confirm the app is always launched from the repository root.
    csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
    img_dir = Path("./dataset/EmoSet-118K/test/images")
    emb_path = Path("./src/emoset_test_embeddings.npy")
    lbl_path = Path("./src/emoset_test_labels.npy")

    # img_dir is returned and later used to open images, so it has to be
    # validated together with the files; otherwise a missing image folder
    # slips through here and only fails when the first image is opened.
    required = (csv_path, emb_path, lbl_path, img_dir)
    if not all(path.exists() for path in required):
        return None, None, None, None

    df = pd.read_csv(csv_path)
    image_list = df['filename'].tolist()
    embs = np.load(emb_path)
    lbls = np.load(lbl_path)
    return image_list, embs, lbls, img_dir
+21 -126
View File
@@ -1,145 +1,40 @@
import streamlit as st import streamlit as st
from pathlib import Path import sys
import pandas as pd import os
import numpy as np import subprocess
from PIL import Image from data_loader import load_music_engine, load_emoset_data
import random from tabs.tab_dataset import render_dataset_tab
import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ---------------------------- # ----------------------------
# 1️⃣ Запуск Streamlit # 1️⃣ Запуск приложения
# ---------------------------- # ----------------------------
if __name__ == "__main__": if __name__ == "__main__":
import os
if "STREAMLIT_RUN" not in os.environ: if "STREAMLIT_RUN" not in os.environ:
import sys
import subprocess
os.environ["STREAMLIT_RUN"] = "1" os.environ["STREAMLIT_RUN"] = "1"
cmd = [ cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
sys.executable, "-m", "streamlit", "run", __file__,
"--server.port", "8080", "--server.address", "0.0.0.0"
]
subprocess.run(cmd) subprocess.run(cmd)
sys.exit() sys.exit()
# Словарь для отладки st.set_page_config(page_title="Thesis Demo", layout="wide")
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide")
@st.cache_resource
def load_music_engine():
base_dir = Path(__file__).resolve().parent
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = base_dir / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
# ----------------------------
# 2️⃣ Инициализация движка и данных
# ----------------------------
matcher = load_music_engine() matcher = load_music_engine()
@st.cache_data
def load_emoset_data():
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
img_dir = Path("./dataset/EmoSet-118K/test/images")
emb_path = Path("./src/emoset_test_embeddings.npy")
lbl_path = Path("./src/emoset_test_labels.npy")
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
return None, None, None, None
df = pd.read_csv(csv_path)
image_list = df['filename'].tolist()
embs = np.load(emb_path)
lbls = np.load(lbl_path)
return image_list, embs, lbls, img_dir
image_files, embeddings, labels_array, images_path = load_emoset_data() image_files, embeddings, labels_array, images_path = load_emoset_data()
# ---------------------------- # ----------------------------
# 2️⃣ Основной интерфейс # 3️⃣ Интерфейс и Вкладки
# ---------------------------- # ----------------------------
if image_files is None: st.title("🖼️ Эмоциональный генератор плейлистов")
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
else:
if 'round' not in st.session_state:
st.session_state.round = 1
st.session_state.chosen_indices = []
st.session_state.current_options = random.sample(range(len(image_files)), 6)
st.title("🖼️ Эмоциональный подбор музыки") # Создаем две вкладки
tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"])
if st.session_state.round <= 10: with tab1:
st.subheader(f"Раунд {st.session_state.round} из 10") render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
st.write("Выберите изображение, соответствующее вашему настроению:")
cols = st.columns(3)
for i, idx in enumerate(st.session_state.current_options):
with cols[i % 3]:
img_name = image_files[idx]
img = Image.open(images_path / img_name)
st.image(img, use_container_width=True)
# Информация для отладки
if matcher:
v_p, a_p = matcher.predict_va(embeddings[idx])
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
st.session_state.chosen_indices.append(idx)
st.session_state.round += 1
if st.session_state.round <= 10:
st.session_state.current_options = random.sample(range(len(image_files)), 6)
st.rerun()
else:
# РЕЗУЛЬТАТЫ
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
all_v, all_a = [], []
for idx in st.session_state.chosen_indices:
v, a = matcher.predict_va(embeddings[idx])
all_v.append(v)
all_a.append(a)
target_v, target_a = np.mean(all_v), np.mean(all_a)
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
col_left, col_right = st.columns([1, 2]) with tab2:
st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.")
with col_left: st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.")
st.header("📊 Ваш профиль") # TODO: render_live_tab(matcher)
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
# График Рассела
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
ax.axhline(5, color='gray', lw=1, ls='--')
ax.axvline(5, color='gray', lw=1, ls='--')
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
ax.set_title("Карта эмоций (Модель Рассела)")
st.pyplot(fig)
with col_right:
st.header("🎵 Рекомендованная музыка")
for _, row in playlist.iterrows():
with st.container(border=True):
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**ID:** {int(row['song_id'])}")
st.caption(f"L2: {row['distance']:.2f}")
with c2:
audio_path = matcher.get_audio_path(row['song_id'])
if audio_path:
st.audio(str(audio_path))
else:
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
if st.button("Начать заново", type="primary"):
for key in list(st.session_state.keys()): del st.session_state[key]
st.rerun()
+89
View File
@@ -0,0 +1,89 @@
import streamlit as st
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Label index -> emotion name; looked up when building the debug caption
# shown under each candidate image.
EMO_NAMES = {
    0: "amusement",
    1: "anger",
    2: "awe",
    3: "contentment",
    4: "disgust",
    5: "excitement",
    6: "fear",
    7: "sadness",
}
# Flow parameters of the dataset tab.
_DS_TOTAL_ROUNDS = 10
_DS_OPTIONS_PER_ROUND = 6
# Session-state keys owned by this tab (prefixed so other tabs are untouched).
_DS_STATE_KEYS = ("ds_round", "ds_chosen_indices", "ds_current_options")


def _ds_sample_options(n_images):
    """Return distinct random image indices for one round.

    Caps the sample size at ``n_images`` so a dataset with fewer images than
    options per round does not make ``random.sample`` raise ValueError.
    """
    return random.sample(range(n_images), min(_DS_OPTIONS_PER_ROUND, n_images))


def _ds_reset_state():
    """Drop this tab's session-state keys so the next run starts at round 1."""
    for key in _DS_STATE_KEYS:
        st.session_state.pop(key, None)


def _ds_render_round(matcher, image_files, embeddings, labels_array, images_path):
    """Show the current round's candidate images and record the user's pick."""
    state = st.session_state
    st.write("Выберите изображение, соответствующее вашему настроению:")
    st.subheader(f"Раунд {state.ds_round} из {_DS_TOTAL_ROUNDS}")

    cols = st.columns(3)
    for i, idx in enumerate(state.ds_current_options):
        with cols[i % 3]:
            # Close the file handle once Streamlit has rendered the image.
            with Image.open(images_path / image_files[idx]) as img:
                st.image(img, use_container_width=True)

            # Debug caption: ground-truth label next to predicted V/A.
            if matcher:
                v_p, a_p = matcher.predict_va(embeddings[idx])
                gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
                st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")

            if st.button(f"Выбрать образ {i+1}", key=f"btn_ds_{idx}", use_container_width=True):
                state.ds_chosen_indices.append(idx)
                state.ds_round += 1
                # No new options are needed once the last round is done.
                if state.ds_round <= _DS_TOTAL_ROUNDS:
                    state.ds_current_options = _ds_sample_options(len(image_files))
                st.rerun()


def _ds_render_profile(target_v, target_a):
    """Render the valence/arousal metrics and the scatter plot."""
    st.header("📊 Ваш профиль")
    st.metric("Позитивность (Valence)", f"{target_v:.2f}")
    st.metric("Энергия (Arousal)", f"{target_a:.2f}")

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlim(1, 9)
    ax.set_ylim(1, 9)
    ax.axhline(5, color='gray', lw=1, ls='--')
    ax.axvline(5, color='gray', lw=1, ls='--')
    ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
    ax.set_xlabel("Valence")
    ax.set_ylabel("Arousal")
    st.pyplot(fig)
    # pyplot keeps every figure it creates alive until closed; without this
    # each rerun of the script would leak one figure.
    plt.close(fig)


def _ds_render_playlist(matcher, playlist):
    """Render one bordered card with an audio player per recommended track."""
    st.header("🎵 Рекомендованная музыка")
    for _, row in playlist.iterrows():
        with st.container(border=True):
            c1, c2 = st.columns([1, 3])
            with c1:
                st.write(f"**ID:** {int(row['song_id'])}")
                st.caption(f"L2 Dist: {row['distance']:.2f}")
            with c2:
                audio_path = matcher.get_audio_path(row['song_id'])
                if audio_path:
                    st.audio(str(audio_path))
                else:
                    st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")


def _ds_render_results(matcher, embeddings):
    """Show the averaged emotional profile and the matching playlist."""
    if not matcher:
        # The selection rounds work without a matcher (only the debug caption
        # is skipped), so this branch is reachable; fail with a message
        # instead of an AttributeError on None.
        st.error("Музыкальный движок не загружен: подбор треков недоступен.")
    else:
        st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")

        predictions = [
            matcher.predict_va(embeddings[idx])
            for idx in st.session_state.ds_chosen_indices
        ]
        # The target point is the mean valence/arousal over all picks.
        target_v = np.mean([v for v, _ in predictions])
        target_a = np.mean([a for _, a in predictions])
        playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)

        col_left, col_right = st.columns([1, 2])
        with col_left:
            _ds_render_profile(target_v, target_a)
        with col_right:
            _ds_render_playlist(matcher, playlist)

    if st.button("Начать заново", type="primary"):
        _ds_reset_state()
        st.rerun()


def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path):
    """Render the EmoSet tab: image-picking rounds, then a music playlist.

    Args:
        matcher: Object providing ``predict_va``, ``find_nearest_tracks`` and
            ``get_audio_path``; may be falsy/None when the engine failed to
            load, in which case predictions and the playlist are skipped.
        image_files: Sequence of image file names, or None if loading failed.
        embeddings: Indexable collection of per-image embeddings, aligned
            with ``image_files``.
        labels_array: Indexable collection of per-image labels used as keys
            into ``EMO_NAMES``.
        images_path: Directory path joined with each file name via ``/``.
    """
    if image_files is None or len(image_files) == 0:
        st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
        return

    # State is initialised once per session and kept under ds_-prefixed keys.
    if 'ds_round' not in st.session_state:
        st.session_state.ds_round = 1
        st.session_state.ds_chosen_indices = []
        st.session_state.ds_current_options = _ds_sample_options(len(image_files))

    if st.session_state.ds_round <= _DS_TOTAL_ROUNDS:
        _ds_render_round(matcher, image_files, embeddings, labels_array, images_path)
    else:
        _ds_render_results(matcher, embeddings)