Integrated Music_engine and update Debug UI
This commit is contained in:
+73
-59
@@ -31,39 +31,41 @@ LABELS_CSV = DATA_ROOT / "labels.csv"
|
||||
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
|
||||
LABELS_PATH = Path("./src/emoset_test_labels.npy")
|
||||
|
||||
# Параметры эксперимента
|
||||
NUM_CHOICES = 6
|
||||
TOTAL_ROUNDS = 10
|
||||
|
||||
st.set_page_config(page_title="EmoSet & Music Recommendation Demo", layout="wide")
|
||||
# Словарь для расшифровки меток (алфавитный порядок EmoSet)
|
||||
EMO_NAMES = {
|
||||
0: "amusement (веселье)",
|
||||
1: "anger (гнев)",
|
||||
2: "awe (трепет)",
|
||||
3: "contentment (удовлетворение)",
|
||||
4: "disgust (отвращение)",
|
||||
5: "excitement (возбуждение)",
|
||||
6: "fear (страх)",
|
||||
7: "sadness (грусть)"
|
||||
}
|
||||
|
||||
st.set_page_config(page_title="Debug Mode: EmoSet & Music", layout="wide")
|
||||
|
||||
# Инициализация музыкального движка с кэшированием
|
||||
@st.cache_resource
|
||||
def load_music_engine():
|
||||
# Путь относительно src: ../dataset/DEAM/music_db.csv
|
||||
db_path = Path(__file__).parent.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
if not db_path.exists():
|
||||
return None
|
||||
return MusicMatcher(db_path)
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
||||
if not db_path.exists(): return None
|
||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||
|
||||
matcher = load_music_engine()
|
||||
|
||||
# ----------------------------
|
||||
# 2️⃣ Загрузка данных EmoSet
|
||||
# ----------------------------
|
||||
@st.cache_data
|
||||
def load_emoset_data():
|
||||
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists():
|
||||
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists() or not LABELS_CSV.exists():
|
||||
return None, None, None
|
||||
|
||||
image_files = sorted([f.name for f in IMAGES_DIR.glob("*.jpg")])
|
||||
df = pd.read_csv(LABELS_CSV)
|
||||
image_files = df['filename'].tolist()
|
||||
embeddings = np.load(EMBEDDINGS_PATH)
|
||||
labels_array = np.load(LABELS_PATH)
|
||||
|
||||
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
|
||||
st.error("Размеры массивов данных не совпадают!")
|
||||
st.stop()
|
||||
|
||||
return image_files, embeddings, labels_array
|
||||
|
||||
image_files, embeddings, labels_array = load_emoset_data()
|
||||
@@ -72,69 +74,81 @@ image_files, embeddings, labels_array = load_emoset_data()
|
||||
# 3️⃣ Логика приложения
|
||||
# ----------------------------
|
||||
if image_files is None:
|
||||
st.error("Данные EmoSet не найдены. Проверьте папку dataset.")
|
||||
st.error("Данные не найдены.")
|
||||
else:
|
||||
if 'round' not in st.session_state:
|
||||
st.session_state.round = 1
|
||||
st.session_state.chosen_indices = []
|
||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
||||
|
||||
st.title("Выбор эмоциональных образов")
|
||||
st.write(f"Раунд {st.session_state.round} из {TOTAL_ROUNDS}. Выберите изображение, которое больше всего соответствует вашему настроению.")
|
||||
|
||||
st.title("🧪 Отладка эмоционального маппинга")
|
||||
|
||||
if st.session_state.round <= TOTAL_ROUNDS:
|
||||
# Отображение сетки изображений
|
||||
st.write(f"**Раунд {st.session_state.round} из {TOTAL_ROUNDS}**")
|
||||
|
||||
# Сетка выбора с отладочной информацией
|
||||
cols = st.columns(3)
|
||||
for i, idx in enumerate(st.session_state.current_options):
|
||||
with cols[i % 3]:
|
||||
img_path = IMAGES_DIR / image_files[idx]
|
||||
img = Image.open(img_path)
|
||||
st.image(img, use_container_width=True)
|
||||
if st.button(f"Выбрать {i+1}", key=f"btn_{idx}"):
|
||||
|
||||
# --- ИНФОРМАЦИЯ ДЛЯ ОТЛАДКИ ---
|
||||
if matcher:
|
||||
v, a = matcher.predict_va(embeddings[idx])
|
||||
label_id = labels_array[idx]
|
||||
label_name = EMO_NAMES.get(label_id, "Unknown")
|
||||
|
||||
st.caption(f"**Класс:** {label_name}")
|
||||
st.caption(f"**Прогноз:** V: {v:.2f} | A: {a:.2f}")
|
||||
# ------------------------------
|
||||
|
||||
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
|
||||
st.session_state.chosen_indices.append(idx)
|
||||
st.session_state.round += 1
|
||||
if st.session_state.round <= TOTAL_ROUNDS:
|
||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
||||
st.rerun()
|
||||
else:
|
||||
# ФИНАЛЬНЫЙ ЭТАП: Анализ и Музыка
|
||||
st.success("Анализ завершен! Ваш эмоциональный профиль сформирован.")
|
||||
# ФИНАЛЬНЫЙ ЭТАП
|
||||
st.success("✅ Профиль сформирован!")
|
||||
|
||||
# Расчет среднего вектора пользователя
|
||||
chosen_embeddings = embeddings[st.session_state.chosen_indices]
|
||||
user_vector = np.mean(chosen_embeddings, axis=0)
|
||||
|
||||
# РАЗДЕЛ МУЗЫКАЛЬНЫХ РЕКОМЕНДАЦИЙ
|
||||
st.divider()
|
||||
st.header("🎵 Рекомендованный плейлист")
|
||||
# Считаем итог
|
||||
chosen_embs = embeddings[st.session_state.chosen_indices]
|
||||
all_v, all_a = [], []
|
||||
for emb in chosen_embs:
|
||||
v, a = matcher.predict_va(emb)
|
||||
all_v.append(v)
|
||||
all_a.append(a)
|
||||
|
||||
if matcher is None:
|
||||
st.warning("База данных DEAM (music_db.csv) не найдена. Подбор музыки недоступен.")
|
||||
else:
|
||||
with st.spinner("Сопоставляем визуальный профиль с музыкальной базой..."):
|
||||
target_v, target_a, playlist = matcher.get_playlist(user_vector, top_k=5)
|
||||
|
||||
# Визуализация VA-метрик пользователя
|
||||
m1, m2, m3 = st.columns(3)
|
||||
m1.metric("Позитивность (Valence)", f"{target_v:.2f}", help="Шкала 1-9")
|
||||
m2.metric("Энергия (Arousal)", f"{target_a:.2f}", help="Шкала 1-9")
|
||||
m3.metric("Найдено треков", len(playlist))
|
||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||
|
||||
# Вывод результатов
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
with col1:
|
||||
st.header("🎵 Ваш плейлист")
|
||||
distances = np.sqrt((matcher.music_db['valence'] - target_v)**2 + (matcher.music_db['arousal'] - target_a)**2)
|
||||
playlist = matcher.music_db.copy()
|
||||
playlist['distance'] = distances
|
||||
st.table(playlist.sort_values(by='distance').head(5)[['song_id', 'valence', 'arousal', 'distance']])
|
||||
|
||||
# Таблица с результатами
|
||||
st.subheader("Топ-5 подходящих композиций")
|
||||
st.table(playlist[['song_id', 'valence', 'arousal', 'distance']])
|
||||
|
||||
st.info("💡 Вы можете найти эти треки по ID в папке audio датасета DEAM.")
|
||||
|
||||
# Визуализация вектора (графики)
|
||||
with col2:
|
||||
st.header("📊 Профиль")
|
||||
st.metric("Valence (Итог)", f"{target_v:.2f}")
|
||||
st.metric("Arousal (Итог)", f"{target_a:.2f}")
|
||||
|
||||
# Показываем, что именно выбрал пользователь
|
||||
st.divider()
|
||||
st.subheader("Визуализация эмоционального вектора")
|
||||
fig, ax = plt.subplots(figsize=(10, 3))
|
||||
ax.plot(user_vector[:100]) # Показываем первые 100 измерений для наглядности
|
||||
ax.set_title("Эмбеддинг вашего настроения (фрагмент)")
|
||||
st.pyplot(fig)
|
||||
st.subheader("Выбранные вами образы и их веса:")
|
||||
sum_cols = st.columns(5)
|
||||
for i, idx in enumerate(st.session_state.chosen_indices):
|
||||
with sum_cols[i % 5]:
|
||||
v_i, a_i = matcher.predict_va(embeddings[idx])
|
||||
st.image(Image.open(IMAGES_DIR / image_files[idx]), use_container_width=True)
|
||||
st.write(f"V:{v_i:.1f} A:{a_i:.1f}")
|
||||
|
||||
if st.button("Начать заново"):
|
||||
for key in list(st.session_state.keys()):
|
||||
del st.session_state[key]
|
||||
for key in list(st.session_state.keys()): del st.session_state[key]
|
||||
st.rerun()
|
||||
Reference in New Issue
Block a user