Refactored main.py
This commit is contained in:
@@ -0,0 +1,33 @@
|
|||||||
|
import streamlit as st
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from music_engine.matcher import MusicMatcher
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def load_music_engine():
|
||||||
|
"""Загрузка базы данных и модели регрессора."""
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||||
|
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
||||||
|
if not db_path.exists():
|
||||||
|
return None
|
||||||
|
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||||
|
|
||||||
|
@st.cache_data
|
||||||
|
def load_emoset_data():
|
||||||
|
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
|
||||||
|
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
|
||||||
|
img_dir = Path("./dataset/EmoSet-118K/test/images")
|
||||||
|
emb_path = Path("./src/emoset_test_embeddings.npy")
|
||||||
|
lbl_path = Path("./src/emoset_test_labels.npy")
|
||||||
|
|
||||||
|
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
image_list = df['filename'].tolist()
|
||||||
|
embs = np.load(emb_path)
|
||||||
|
lbls = np.load(lbl_path)
|
||||||
|
|
||||||
|
return image_list, embs, lbls, img_dir
|
||||||
+21
-126
@@ -1,145 +1,40 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from pathlib import Path
|
import sys
|
||||||
import pandas as pd
|
import os
|
||||||
import numpy as np
|
import subprocess
|
||||||
from PIL import Image
|
from data_loader import load_music_engine, load_emoset_data
|
||||||
import random
|
from tabs.tab_dataset import render_dataset_tab
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from music_engine.matcher import MusicMatcher
|
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# 1️⃣ Запуск Streamlit
|
# 1️⃣ Запуск приложения
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import os
|
|
||||||
if "STREAMLIT_RUN" not in os.environ:
|
if "STREAMLIT_RUN" not in os.environ:
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
os.environ["STREAMLIT_RUN"] = "1"
|
os.environ["STREAMLIT_RUN"] = "1"
|
||||||
cmd = [
|
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
|
||||||
sys.executable, "-m", "streamlit", "run", __file__,
|
|
||||||
"--server.port", "8080", "--server.address", "0.0.0.0"
|
|
||||||
]
|
|
||||||
subprocess.run(cmd)
|
subprocess.run(cmd)
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# Словарь для отладки
|
st.set_page_config(page_title="Thesis Demo", layout="wide")
|
||||||
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
|
|
||||||
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
|
|
||||||
|
|
||||||
st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide")
|
|
||||||
|
|
||||||
@st.cache_resource
|
|
||||||
def load_music_engine():
|
|
||||||
base_dir = Path(__file__).resolve().parent
|
|
||||||
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
|
||||||
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
|
||||||
if not db_path.exists():
|
|
||||||
return None
|
|
||||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 2️⃣ Инициализация движка и данных
|
||||||
|
# ----------------------------
|
||||||
matcher = load_music_engine()
|
matcher = load_music_engine()
|
||||||
|
|
||||||
@st.cache_data
|
|
||||||
def load_emoset_data():
|
|
||||||
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
|
|
||||||
img_dir = Path("./dataset/EmoSet-118K/test/images")
|
|
||||||
emb_path = Path("./src/emoset_test_embeddings.npy")
|
|
||||||
lbl_path = Path("./src/emoset_test_labels.npy")
|
|
||||||
|
|
||||||
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
|
||||||
return None, None, None, None
|
|
||||||
|
|
||||||
df = pd.read_csv(csv_path)
|
|
||||||
image_list = df['filename'].tolist()
|
|
||||||
embs = np.load(emb_path)
|
|
||||||
lbls = np.load(lbl_path)
|
|
||||||
|
|
||||||
return image_list, embs, lbls, img_dir
|
|
||||||
|
|
||||||
image_files, embeddings, labels_array, images_path = load_emoset_data()
|
image_files, embeddings, labels_array, images_path = load_emoset_data()
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# 2️⃣ Основной интерфейс
|
# 3️⃣ Интерфейс и Вкладки
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
if image_files is None:
|
st.title("🖼️ Эмоциональный генератор плейлистов")
|
||||||
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
|
|
||||||
else:
|
|
||||||
if 'round' not in st.session_state:
|
|
||||||
st.session_state.round = 1
|
|
||||||
st.session_state.chosen_indices = []
|
|
||||||
st.session_state.current_options = random.sample(range(len(image_files)), 6)
|
|
||||||
|
|
||||||
st.title("🖼️ Эмоциональный подбор музыки")
|
# Создаем две вкладки
|
||||||
|
tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"])
|
||||||
|
|
||||||
if st.session_state.round <= 10:
|
with tab1:
|
||||||
st.subheader(f"Раунд {st.session_state.round} из 10")
|
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
|
||||||
st.write("Выберите изображение, соответствующее вашему настроению:")
|
|
||||||
|
|
||||||
cols = st.columns(3)
|
|
||||||
for i, idx in enumerate(st.session_state.current_options):
|
|
||||||
with cols[i % 3]:
|
|
||||||
img_name = image_files[idx]
|
|
||||||
img = Image.open(images_path / img_name)
|
|
||||||
st.image(img, use_container_width=True)
|
|
||||||
|
|
||||||
# Информация для отладки
|
|
||||||
if matcher:
|
|
||||||
v_p, a_p = matcher.predict_va(embeddings[idx])
|
|
||||||
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
|
|
||||||
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
|
|
||||||
|
|
||||||
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
|
|
||||||
st.session_state.chosen_indices.append(idx)
|
|
||||||
st.session_state.round += 1
|
|
||||||
if st.session_state.round <= 10:
|
|
||||||
st.session_state.current_options = random.sample(range(len(image_files)), 6)
|
|
||||||
st.rerun()
|
|
||||||
else:
|
|
||||||
# РЕЗУЛЬТАТЫ
|
|
||||||
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
|
||||||
|
|
||||||
all_v, all_a = [], []
|
|
||||||
for idx in st.session_state.chosen_indices:
|
|
||||||
v, a = matcher.predict_va(embeddings[idx])
|
|
||||||
all_v.append(v)
|
|
||||||
all_a.append(a)
|
|
||||||
|
|
||||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
|
||||||
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
|
|
||||||
|
|
||||||
col_left, col_right = st.columns([1, 2])
|
with tab2:
|
||||||
|
st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.")
|
||||||
with col_left:
|
st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.")
|
||||||
st.header("📊 Ваш профиль")
|
# TODO: render_live_tab(matcher)
|
||||||
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
|
||||||
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
|
||||||
|
|
||||||
# График Рассела
|
|
||||||
fig, ax = plt.subplots(figsize=(4, 4))
|
|
||||||
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
|
|
||||||
ax.axhline(5, color='gray', lw=1, ls='--')
|
|
||||||
ax.axvline(5, color='gray', lw=1, ls='--')
|
|
||||||
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
|
|
||||||
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
|
|
||||||
ax.set_title("Карта эмоций (Модель Рассела)")
|
|
||||||
st.pyplot(fig)
|
|
||||||
|
|
||||||
with col_right:
|
|
||||||
st.header("🎵 Рекомендованная музыка")
|
|
||||||
for _, row in playlist.iterrows():
|
|
||||||
with st.container(border=True):
|
|
||||||
c1, c2 = st.columns([1, 3])
|
|
||||||
with c1:
|
|
||||||
st.write(f"**ID:** {int(row['song_id'])}")
|
|
||||||
st.caption(f"L2: {row['distance']:.2f}")
|
|
||||||
with c2:
|
|
||||||
audio_path = matcher.get_audio_path(row['song_id'])
|
|
||||||
if audio_path:
|
|
||||||
st.audio(str(audio_path))
|
|
||||||
else:
|
|
||||||
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
|
|
||||||
|
|
||||||
if st.button("Начать заново", type="primary"):
|
|
||||||
for key in list(st.session_state.keys()): del st.session_state[key]
|
|
||||||
st.rerun()
|
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
|
||||||
|
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
|
||||||
|
|
||||||
|
def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path):
|
||||||
|
if image_files is None:
|
||||||
|
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Инициализация состояния именно для этой вкладки
|
||||||
|
if 'ds_round' not in st.session_state:
|
||||||
|
st.session_state.ds_round = 1
|
||||||
|
st.session_state.ds_chosen_indices = []
|
||||||
|
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||||
|
|
||||||
|
st.write("Выберите изображение, соответствующее вашему настроению:")
|
||||||
|
|
||||||
|
if st.session_state.ds_round <= 10:
|
||||||
|
st.subheader(f"Раунд {st.session_state.ds_round} из 10")
|
||||||
|
|
||||||
|
cols = st.columns(3)
|
||||||
|
for i, idx in enumerate(st.session_state.ds_current_options):
|
||||||
|
with cols[i % 3]:
|
||||||
|
img_name = image_files[idx]
|
||||||
|
img = Image.open(images_path / img_name)
|
||||||
|
st.image(img, use_container_width=True)
|
||||||
|
|
||||||
|
if matcher:
|
||||||
|
v_p, a_p = matcher.predict_va(embeddings[idx])
|
||||||
|
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
|
||||||
|
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
|
||||||
|
|
||||||
|
if st.button(f"Выбрать образ {i+1}", key=f"btn_ds_{idx}", use_container_width=True):
|
||||||
|
st.session_state.ds_chosen_indices.append(idx)
|
||||||
|
st.session_state.ds_round += 1
|
||||||
|
if st.session_state.ds_round <= 10:
|
||||||
|
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||||
|
st.rerun()
|
||||||
|
else:
|
||||||
|
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||||
|
|
||||||
|
all_v, all_a = [], []
|
||||||
|
for idx in st.session_state.ds_chosen_indices:
|
||||||
|
v, a = matcher.predict_va(embeddings[idx])
|
||||||
|
all_v.append(v)
|
||||||
|
all_a.append(a)
|
||||||
|
|
||||||
|
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||||
|
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
|
||||||
|
|
||||||
|
col_left, col_right = st.columns([1, 2])
|
||||||
|
|
||||||
|
with col_left:
|
||||||
|
st.header("📊 Ваш профиль")
|
||||||
|
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||||
|
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(4, 4))
|
||||||
|
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
|
||||||
|
ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--')
|
||||||
|
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
|
||||||
|
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
|
||||||
|
st.pyplot(fig)
|
||||||
|
|
||||||
|
with col_right:
|
||||||
|
st.header("🎵 Рекомендованная музыка")
|
||||||
|
for _, row in playlist.iterrows():
|
||||||
|
with st.container(border=True):
|
||||||
|
c1, c2 = st.columns([1, 3])
|
||||||
|
with c1:
|
||||||
|
st.write(f"**ID:** {int(row['song_id'])}")
|
||||||
|
st.caption(f"L2 Dist: {row['distance']:.2f}")
|
||||||
|
with c2:
|
||||||
|
audio_path = matcher.get_audio_path(row['song_id'])
|
||||||
|
if audio_path:
|
||||||
|
st.audio(str(audio_path))
|
||||||
|
else:
|
||||||
|
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
|
||||||
|
|
||||||
|
if st.button("Начать заново", type="primary"):
|
||||||
|
st.session_state.pop('ds_round', None)
|
||||||
|
st.session_state.pop('ds_chosen_indices', None)
|
||||||
|
st.session_state.pop('ds_current_options', None)
|
||||||
|
st.rerun()
|
||||||
Reference in New Issue
Block a user