Refactored main.py
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
import streamlit as st
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
|
||||
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
|
||||
|
||||
def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path):
|
||||
if image_files is None:
|
||||
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
|
||||
return
|
||||
|
||||
# Инициализация состояния именно для этой вкладки
|
||||
if 'ds_round' not in st.session_state:
|
||||
st.session_state.ds_round = 1
|
||||
st.session_state.ds_chosen_indices = []
|
||||
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||
|
||||
st.write("Выберите изображение, соответствующее вашему настроению:")
|
||||
|
||||
if st.session_state.ds_round <= 10:
|
||||
st.subheader(f"Раунд {st.session_state.ds_round} из 10")
|
||||
|
||||
cols = st.columns(3)
|
||||
for i, idx in enumerate(st.session_state.ds_current_options):
|
||||
with cols[i % 3]:
|
||||
img_name = image_files[idx]
|
||||
img = Image.open(images_path / img_name)
|
||||
st.image(img, use_container_width=True)
|
||||
|
||||
if matcher:
|
||||
v_p, a_p = matcher.predict_va(embeddings[idx])
|
||||
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
|
||||
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
|
||||
|
||||
if st.button(f"Выбрать образ {i+1}", key=f"btn_ds_{idx}", use_container_width=True):
|
||||
st.session_state.ds_chosen_indices.append(idx)
|
||||
st.session_state.ds_round += 1
|
||||
if st.session_state.ds_round <= 10:
|
||||
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||
st.rerun()
|
||||
else:
|
||||
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||
|
||||
all_v, all_a = [], []
|
||||
for idx in st.session_state.ds_chosen_indices:
|
||||
v, a = matcher.predict_va(embeddings[idx])
|
||||
all_v.append(v)
|
||||
all_a.append(a)
|
||||
|
||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
|
||||
|
||||
col_left, col_right = st.columns([1, 2])
|
||||
|
||||
with col_left:
|
||||
st.header("📊 Ваш профиль")
|
||||
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(4, 4))
|
||||
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
|
||||
ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--')
|
||||
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
|
||||
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
|
||||
st.pyplot(fig)
|
||||
|
||||
with col_right:
|
||||
st.header("🎵 Рекомендованная музыка")
|
||||
for _, row in playlist.iterrows():
|
||||
with st.container(border=True):
|
||||
c1, c2 = st.columns([1, 3])
|
||||
with c1:
|
||||
st.write(f"**ID:** {int(row['song_id'])}")
|
||||
st.caption(f"L2 Dist: {row['distance']:.2f}")
|
||||
with c2:
|
||||
audio_path = matcher.get_audio_path(row['song_id'])
|
||||
if audio_path:
|
||||
st.audio(str(audio_path))
|
||||
else:
|
||||
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
|
||||
|
||||
if st.button("Начать заново", type="primary"):
|
||||
st.session_state.pop('ds_round', None)
|
||||
st.session_state.pop('ds_chosen_indices', None)
|
||||
st.session_state.pop('ds_current_options', None)
|
||||
st.rerun()
|
||||
Reference in New Issue
Block a user