Beta v.1.0

This commit is contained in:
zin
2026-05-06 19:48:18 +00:00
parent e6cd11b615
commit 95595a5a5e
5 changed files with 155 additions and 208 deletions
+74 -83
View File
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
from music_engine.matcher import MusicMatcher
# ----------------------------
# 1️⃣ Конфигурация и запуск
# 1️⃣ Запуск Streamlit
# ----------------------------
if __name__ == "__main__":
import os
@@ -23,132 +23,123 @@ if __name__ == "__main__":
subprocess.run(cmd)
sys.exit()
# Конфигурация путей
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
IMAGES_DIR = DATA_ROOT / "images"
LABELS_CSV = DATA_ROOT / "labels.csv"
# Словарь для отладки
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
LABELS_PATH = Path("./src/emoset_test_labels.npy")
NUM_CHOICES = 6
TOTAL_ROUNDS = 10
# Словарь для расшифровки меток (алфавитный порядок EmoSet)
EMO_NAMES = {
0: "amusement (веселье)",
1: "anger (гнев)",
2: "awe (трепет)",
3: "contentment (удовлетворение)",
4: "disgust (отвращение)",
5: "excitement (возбуждение)",
6: "fear (страх)",
7: "sadness (грусть)"
}
st.set_page_config(page_title="Debug Mode: EmoSet & Music", layout="wide")
st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide")
@st.cache_resource
def load_music_engine():
base_dir = Path(__file__).resolve().parent
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = base_dir / "music_engine" / "va_regressor.pkl"
if not db_path.exists(): return None
if not db_path.exists():
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
matcher = load_music_engine()
@st.cache_data
def load_emoset_data():
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists() or not LABELS_CSV.exists():
return None, None, None
df = pd.read_csv(LABELS_CSV)
image_files = df['filename'].tolist()
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
return image_files, embeddings, labels_array
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
img_dir = Path("./dataset/EmoSet-118K/test/images")
emb_path = Path("./src/emoset_test_embeddings.npy")
lbl_path = Path("./src/emoset_test_labels.npy")
image_files, embeddings, labels_array = load_emoset_data()
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
return None, None, None, None
df = pd.read_csv(csv_path)
image_list = df['filename'].tolist()
embs = np.load(emb_path)
lbls = np.load(lbl_path)
return image_list, embs, lbls, img_dir
image_files, embeddings, labels_array, images_path = load_emoset_data()
# ----------------------------
# 3️⃣ Логика приложения
# 2️⃣ Основной интерфейс
# ----------------------------
if image_files is None:
st.error("Данные не найдены.")
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
else:
if 'round' not in st.session_state:
st.session_state.round = 1
st.session_state.chosen_indices = []
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
st.session_state.current_options = random.sample(range(len(image_files)), 6)
st.title("🧪 Отладка эмоционального маппинга")
if st.session_state.round <= TOTAL_ROUNDS:
st.write(f"**Раунд {st.session_state.round} из {TOTAL_ROUNDS}**")
st.title("🖼️ Эмоциональный подбор музыки")
if st.session_state.round <= 10:
st.subheader(f"Раунд {st.session_state.round} из 10")
st.write("Выберите изображение, соответствующее вашему настроению:")
# Сетка выбора с отладочной информацией
cols = st.columns(3)
for i, idx in enumerate(st.session_state.current_options):
with cols[i % 3]:
img_path = IMAGES_DIR / image_files[idx]
img = Image.open(img_path)
img_name = image_files[idx]
img = Image.open(images_path / img_name)
st.image(img, use_container_width=True)
# --- ИНФОРМАЦИЯ ДЛЯ ОТЛАДКИ ---
# Информация для отладки
if matcher:
v, a = matcher.predict_va(embeddings[idx])
label_id = labels_array[idx]
label_name = EMO_NAMES.get(label_id, "Unknown")
st.caption(f"**Класс:** {label_name}")
st.caption(f"**Прогноз:** V: {v:.2f} | A: {a:.2f}")
# ------------------------------
v_p, a_p = matcher.predict_va(embeddings[idx])
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
st.session_state.chosen_indices.append(idx)
st.session_state.round += 1
if st.session_state.round <= TOTAL_ROUNDS:
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
if st.session_state.round <= 10:
st.session_state.current_options = random.sample(range(len(image_files)), 6)
st.rerun()
else:
# ФИНАЛЬНЫЙ ЭТАП
st.success("Профиль сформирован!")
# РЕЗУЛЬТАТЫ
st.success("Анализ завершен! Ваш эмоциональный профиль готов.")
# Считаем итог
chosen_embs = embeddings[st.session_state.chosen_indices]
all_v, all_a = [], []
for emb in chosen_embs:
v, a = matcher.predict_va(emb)
for idx in st.session_state.chosen_indices:
v, a = matcher.predict_va(embeddings[idx])
all_v.append(v)
all_a.append(a)
target_v, target_a = np.mean(all_v), np.mean(all_a)
# Вывод результатов
col1, col2 = st.columns([2, 1])
with col1:
st.header("🎵 Ваш плейлист")
distances = np.sqrt((matcher.music_db['valence'] - target_v)**2 + (matcher.music_db['arousal'] - target_a)**2)
playlist = matcher.music_db.copy()
playlist['distance'] = distances
st.table(playlist.sort_values(by='distance').head(5)[['song_id', 'valence', 'arousal', 'distance']])
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
with col2:
st.header("📊 Профиль")
st.metric("Valence (Итог)", f"{target_v:.2f}")
st.metric("Arousal (Итог)", f"{target_a:.2f}")
col_left, col_right = st.columns([1, 2])
with col_left:
st.header("📊 Ваш профиль")
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
# Показываем, что именно выбрал пользователь
st.divider()
st.subheader("Выбранные вами образы и их веса:")
sum_cols = st.columns(5)
for i, idx in enumerate(st.session_state.chosen_indices):
with sum_cols[i % 5]:
v_i, a_i = matcher.predict_va(embeddings[idx])
st.image(Image.open(IMAGES_DIR / image_files[idx]), use_container_width=True)
st.write(f"V:{v_i:.1f} A:{a_i:.1f}")
# График Рассела
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
ax.axhline(5, color='gray', lw=1, ls='--')
ax.axvline(5, color='gray', lw=1, ls='--')
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
ax.set_title("Карта эмоций (Модель Рассела)")
st.pyplot(fig)
if st.button("Начать заново"):
with col_right:
st.header("🎵 Рекомендованная музыка")
for _, row in playlist.iterrows():
with st.container(border=True):
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**ID:** {int(row['song_id'])}")
st.caption(f"L2: {row['distance']:.2f}")
with c2:
audio_path = matcher.get_audio_path(row['song_id'])
if audio_path:
st.audio(str(audio_path))
else:
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
if st.button("Начать заново", type="primary"):
for key in list(st.session_state.keys()): del st.session_state[key]
st.rerun()