import streamlit as st import random import numpy as np from PIL import Image import matplotlib.pyplot as plt EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment", 4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"} def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path): if image_files is None: st.error("Ошибка загрузки данных EmoSet. Проверьте пути.") return # Инициализация состояния именно для этой вкладки if 'ds_round' not in st.session_state: st.session_state.ds_round = 1 st.session_state.ds_chosen_indices = [] st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) st.write("Выберите изображение, соответствующее вашему настроению:") if st.session_state.ds_round <= 10: st.subheader(f"Раунд {st.session_state.ds_round} из 10") cols = st.columns(3) for i, idx in enumerate(st.session_state.ds_current_options): with cols[i % 3]: img_name = image_files[idx] img = Image.open(images_path / img_name) st.image(img, use_container_width=True) if matcher: v_p, a_p = matcher.predict_va(embeddings[idx]) gt_label = EMO_NAMES.get(labels_array[idx], "unknown") st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}") if st.button(f"Выбрать образ {i+1}", key=f"btn_ds_{idx}", use_container_width=True): st.session_state.ds_chosen_indices.append(idx) st.session_state.ds_round += 1 if st.session_state.ds_round <= 10: st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) st.rerun() else: st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.") all_v, all_a = [], [] for idx in st.session_state.ds_chosen_indices: v, a = matcher.predict_va(embeddings[idx]) all_v.append(v) all_a.append(a) target_v, target_a = np.mean(all_v), np.mean(all_a) playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) col_left, col_right = st.columns([1, 2]) with col_left: st.header("📊 Ваш профиль") st.metric("Позитивность (Valence)", f"{target_v:.2f}") st.metric("Энергия (Arousal)", f"{target_a:.2f}") fig, ax = plt.subplots(figsize=(4, 4)) ax.set_xlim(1, 9); ax.set_ylim(1, 9) ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--') ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5) ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") st.pyplot(fig) with col_right: st.header("🎵 Рекомендованная музыка") for _, row in playlist.iterrows(): with st.container(border=True): c1, c2 = st.columns([1, 3]) with c1: st.write(f"**ID:** {int(row['song_id'])}") score_val = row.get('final_score', row.get('emo_distance', 0)) st.caption(f"Dist Score: {score_val:.2f}") with c2: audio_path = matcher.get_audio_path(row['song_id']) if audio_path: st.audio(str(audio_path)) else: st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") if st.button("Начать заново", type="primary"): st.session_state.pop('ds_round', None) st.session_state.pop('ds_chosen_indices', None) st.session_state.pop('ds_current_options', None) st.rerun()