Files
Thesis/src/main.py
T
2026-05-06 19:48:18 +00:00

145 lines
5.8 KiB
Python

import streamlit as st
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import random
import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ----------------------------
# 1️⃣ Launch Streamlit
# ----------------------------
def build_streamlit_command(script, port="8080", address="0.0.0.0"):
    """Return the argv that serves *script* via ``python -m streamlit run``.

    Uses the current interpreter (``sys.executable``) so the Streamlit child
    runs in the same environment as this process.

    Args:
        script: Path of the Streamlit script to serve.
        port: Value for ``--server.port`` (str or int).
        address: Value for ``--server.address``.

    Returns:
        A list of strings suitable for ``subprocess.run``.
    """
    import sys

    return [
        sys.executable, "-m", "streamlit", "run", str(script),
        "--server.port", str(port), "--server.address", str(address),
    ]


if __name__ == "__main__":
    import os

    # Streamlit re-executes this script as __main__ too; the env var marks the
    # child process so it renders the app instead of launching itself again.
    if "STREAMLIT_RUN" not in os.environ:
        import subprocess
        import sys

        os.environ["STREAMLIT_RUN"] = "1"
        try:
            completed = subprocess.run(build_streamlit_command(__file__), check=False)
        except KeyboardInterrupt:
            # Ctrl+C also reaches the server; exit quietly with the SIGINT convention.
            sys.exit(130)
        # Report the server's real exit status instead of always signalling success.
        sys.exit(completed.returncode)
# Debug lookup: EmoSet class label -> emotion name, shown in the per-image captions.
EMO_NAMES = dict(enumerate((
    "amusement", "anger", "awe", "contentment",
    "disgust", "excitement", "fear", "sadness",
)))
# Wide layout gives the 3-column image grid and the results columns more room.
# NOTE: keep this the first Streamlit call in the script (older Streamlit requires it).
st.set_page_config(
    layout="wide",
    page_title="Thesis Demo: Image-Music",
)
@st.cache_resource
def load_music_engine():
    """Construct the ``MusicMatcher`` used for valence/arousal prediction and track lookup.

    Paths are resolved relative to this file: the track database is
    ``<parent>/dataset/DEAM/music_db.csv`` and the regressor is
    ``music_engine/va_regressor.pkl`` next to this script. The result is cached
    with ``st.cache_resource`` so reruns reuse the same object.

    Returns:
        A ``MusicMatcher``, or ``None`` when the track database CSV does not exist.
    """
    src_dir = Path(__file__).resolve().parent
    music_db = src_dir.parent.joinpath("dataset", "DEAM", "music_db.csv")
    if music_db.exists():
        return MusicMatcher(
            db_path=music_db,
            model_path=src_dir.joinpath("music_engine", "va_regressor.pkl"),
        )
    return None


matcher = load_music_engine()
@st.cache_data
def load_emoset_data():
    """Load the EmoSet-118K test split used for the image-choice rounds.

    Paths are resolved relative to this file (which lives in ``src/``), so the
    app finds its data no matter which working directory it is launched from.
    The dataset is expected in ``<parent>/dataset/EmoSet-118K/test`` and the
    precomputed ``.npy`` files next to this script.

    Returns:
        ``(image_list, embeddings, labels, img_dir)``:
        ``image_list`` is the ``filename`` column of ``labels.csv``;
        ``embeddings`` and ``labels`` are the arrays from the ``.npy`` files,
        row-aligned with ``image_list``; ``img_dir`` is the image folder.
        Returns four ``None`` values when a required file or the image folder
        is missing, or when the array lengths disagree with the CSV.
    """
    src_dir = Path(__file__).resolve().parent
    split_dir = src_dir.parent / "dataset" / "EmoSet-118K" / "test"
    csv_path = split_dir / "labels.csv"
    img_dir = split_dir / "images"
    emb_path = src_dir / "emoset_test_embeddings.npy"
    lbl_path = src_dir / "emoset_test_labels.npy"

    required_files = (csv_path, emb_path, lbl_path)
    if not all(p.exists() for p in required_files) or not img_dir.is_dir():
        return None, None, None, None

    image_list = pd.read_csv(csv_path)["filename"].tolist()
    embs = np.load(emb_path)
    lbls = np.load(lbl_path)
    # The UI indexes filenames, embeddings and labels with the same row index,
    # so a length mismatch would crash or silently pair the wrong data.
    if not (len(embs) == len(lbls) == len(image_list)):
        return None, None, None, None
    return image_list, embs, lbls, img_dir


image_files, embeddings, labels_array, images_path = load_emoset_data()
# ----------------------------
# 2️⃣ Main interface
# ----------------------------
NUM_ROUNDS = 10  # image choices made before the emotional profile is computed
OPTIONS_PER_ROUND = 6  # candidate images offered in each round
GRID_COLUMNS = 3  # images per row in the selection grid


def _sample_options():
    """Pick up to OPTIONS_PER_ROUND distinct random indices into ``image_files``.

    Capped at the dataset size so a small split cannot make ``random.sample``
    raise ``ValueError``.
    """
    count = min(OPTIONS_PER_ROUND, len(image_files))
    return random.sample(range(len(image_files)), count)


def _render_selection_round():
    """Render the current round's image grid.

    Clicking an image's button appends its index to
    ``st.session_state.chosen_indices``, advances the round counter, draws new
    options if rounds remain, and reruns the script.
    """
    st.subheader(f"Раунд {st.session_state.round} из {NUM_ROUNDS}")
    st.write("Выберите изображение, соответствующее вашему настроению:")
    cols = st.columns(GRID_COLUMNS)
    for i, idx in enumerate(st.session_state.current_options):
        with cols[i % GRID_COLUMNS]:
            img_name = image_files[idx]
            try:
                # Close the file handle once st.image has consumed the image.
                with Image.open(images_path / img_name) as img:
                    st.image(img, use_container_width=True)
            except FileNotFoundError:
                st.warning(f"Файл {img_name} не найден")
            # Debug info: ground-truth EmoSet label vs. predicted valence/arousal.
            if matcher is not None:
                v_p, a_p = matcher.predict_va(embeddings[idx])
                gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
                st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
            if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
                st.session_state.chosen_indices.append(idx)
                st.session_state.round += 1
                if st.session_state.round <= NUM_ROUNDS:
                    st.session_state.current_options = _sample_options()
                st.rerun()


def _render_results():
    """Show the averaged valence/arousal profile and the recommended tracks.

    The profile is the mean of ``matcher.predict_va`` over all chosen images;
    the playlist is ``matcher.find_nearest_tracks`` (top 5) for that point.
    Shows an error instead when the music engine is unavailable.
    """
    if matcher is None:
        # load_music_engine() returns None when the track database is missing.
        st.error("Музыкальная база не найдена: рекомендации недоступны.")
        return

    st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
    predictions = [matcher.predict_va(embeddings[idx]) for idx in st.session_state.chosen_indices]
    target_v = np.mean([v for v, _ in predictions])
    target_a = np.mean([a for _, a in predictions])
    playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)

    col_left, col_right = st.columns([1, 2])
    with col_left:
        st.header("📊 Ваш профиль")
        st.metric("Позитивность (Valence)", f"{target_v:.2f}")
        st.metric("Энергия (Arousal)", f"{target_a:.2f}")
        # Russell's circumplex: valence (x) vs. arousal (y), axes spanning 1-9,
        # dashed quadrant lines at the midpoint 5.
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.set_xlim(1, 9)
        ax.set_ylim(1, 9)
        ax.axhline(5, color="gray", lw=1, ls="--")
        ax.axvline(5, color="gray", lw=1, ls="--")
        ax.scatter(target_v, target_a, color="red", s=150, edgecolors="white", zorder=5)
        ax.set_xlabel("Valence")
        ax.set_ylabel("Arousal")
        ax.set_title("Карта эмоций (Модель Рассела)")
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this each rerun leaks one.
        plt.close(fig)
    with col_right:
        st.header("🎵 Рекомендованная музыка")
        for _, row in playlist.iterrows():
            with st.container(border=True):
                c1, c2 = st.columns([1, 3])
                with c1:
                    st.write(f"**ID:** {int(row['song_id'])}")
                    st.caption(f"L2: {row['distance']:.2f}")
                with c2:
                    audio_path = matcher.get_audio_path(row["song_id"])
                    if audio_path:
                        st.audio(str(audio_path))
                    else:
                        st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")


# `not image_files` also catches an empty CSV, not just a failed load (None).
if not image_files:
    st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
else:
    if "round" not in st.session_state:
        st.session_state.round = 1
        st.session_state.chosen_indices = []
        st.session_state.current_options = _sample_options()
    st.title("🖼️ Эмоциональный подбор музыки")
    if st.session_state.round <= NUM_ROUNDS:
        _render_selection_round()
    else:
        _render_results()
        # Kept outside _render_results so a restart is possible even on errors.
        if st.button("Начать заново", type="primary"):
            # Wipe all per-session state so the next run starts again at round 1.
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            st.rerun()