Initial Demo

This commit is contained in:
zin
2026-01-15 02:05:34 +00:00
parent 2acb29a598
commit 025cc061f7
10 changed files with 4012 additions and 1 deletions

220
src/main.py Normal file
View File

@@ -0,0 +1,220 @@
import streamlit as st
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import matplotlib.pyplot as plt
if __name__ == "__main__":
# Проверяем, запущен ли скрипт через Streamlit
import os
if "STREAMLIT_RUN" not in os.environ:
import sys
import subprocess
os.environ["STREAMLIT_RUN"] = "1"
# Формируем команду запуска
cmd = [
sys.executable,
"-m",
"streamlit",
"run",
__file__,
"--server.port", "8080",
"--server.address", "0.0.0.0"
]
subprocess.run(cmd)
sys.exit()
# ----------------------------
# 1⃣ Конфигурация
# ----------------------------
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
IMAGES_DIR = DATA_ROOT / "images"
LABELS_CSV = DATA_ROOT / "labels.csv"
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
LABELS_PATH = Path("./src/emoset_test_labels.npy")
NUM_CHOICES = 6 # количество изображений за один раунд
TOTAL_ROUNDS = 10 # количество раундов выбора
st.set_page_config(page_title="EmoSet Demo", layout="wide")
# ----------------------------
# 2⃣ Загрузка данных
# ----------------------------
if not IMAGES_DIR.exists():
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}")
st.stop()
labels_df = pd.read_csv(LABELS_CSV)
embeddings = np.load(EMBEDDINGS_PATH)
labels_array = np.load(LABELS_PATH)
image_files = list(IMAGES_DIR.glob("*.jpg"))
# Проверка совпадения размеров
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
st.error("Размеры массивов не совпадают!")
st.stop()
# Создадим mapping: filename -> embedding, label
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
# ----------------------------
# 3⃣ Сессия Streamlit
# ----------------------------
if 'round_num' not in st.session_state:
st.session_state.round_num = 0
if 'chosen_files' not in st.session_state:
st.session_state.chosen_files = []
st.title("EmoSet: Выбор изображений")
# ----------------------------
# 4⃣ Функция для overlay топ-3 эмоций
# ----------------------------
def get_font():
try:
return ImageFont.truetype("arial.ttf", 14)
except:
try:
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
except:
return ImageFont.load_default()
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
draw = ImageDraw.Draw(img)
font = get_font()
text = f"Label: {label}"
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
draw.text((2,2), text, fill=(255,255,255), font=font)
return img
# ----------------------------
# 5⃣ Рандомный выбор 6 изображений для текущего раунда
# ----------------------------
# Инициализация состояния Streamlit
if "round_num" not in st.session_state:
st.session_state.round_num = 0
if "chosen_files" not in st.session_state:
st.session_state.chosen_files = []
if "current_choices" not in st.session_state:
st.session_state.current_choices = []
# Если все раунды уже завершены, блок пропускается
if st.session_state.round_num < TOTAL_ROUNDS:
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
# Генерируем выбор для текущего раунда только если он ещё не создан
if len(st.session_state.current_choices) == 0:
already_chosen = set(st.session_state.chosen_files)
available_images = [f for f in image_files if f.name not in already_chosen]
if len(available_images) < NUM_CHOICES:
st.warning("Недостаточно изображений для выбора!")
st.stop()
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
# Отображаем изображения и кнопки
cols = st.columns(NUM_CHOICES)
for col, img_path in zip(cols, st.session_state.current_choices):
# Загружаем изображение и накладываем overlay
img = Image.open(img_path).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
# Показываем изображение
col.image(img_overlay, width=250)
# Кнопка выбора
if col.button("Выбрать", key=img_path.name):
st.session_state.chosen_files.append(img_path.name)
st.session_state.round_num += 1
st.session_state.current_choices = [] # сброс для следующего раунда
st.rerun() # перезапуск для нового раунда
# ----------------------------
# 6⃣ После всех выборов: отображение результатов
# ----------------------------
else:
st.subheader("Результаты выбора пользователя")
if not st.session_state.chosen_files:
st.warning("Вы не сделали ни одного выбора!")
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
st.rerun()
else:
# Отображение выбранных изображений
cols = st.columns(3)
for i, filename in enumerate(st.session_state.chosen_files):
with cols[i % 3]:
img = Image.open(IMAGES_DIR / filename).convert("RGB")
img_overlay = overlay_top_emotions(img, filename2label[filename])
st.image(img_overlay, caption=f"Выбор {i+1}")
# Расчет среднего embedding
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files]
chosen_embeddings = np.stack(chosen_embeddings)
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
# Отображение информации о векторе эмоций
with st.expander("Подробности о векторной модели эмоций"):
st.write(f"Размерность embedding: {len(user_emotion_vector)}")
st.write("Первые 10 значений:")
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
# Визуализация
col1, col2 = st.columns(2)
with col1:
# Гистограмма первых 30 измерений
plt.figure(figsize=(8,4))
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30])
plt.xlabel("Embedding dimension")
plt.ylabel("Value")
plt.title("Распределение значений embedding (первые 30 измерений)")
st.pyplot(plt)
with col2:
# Круговая диаграмма средних значений по блокам
if len(user_emotion_vector) > 16:
block_size = len(user_emotion_vector) // 4
block_means = [
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
for i in range(4)
]
plt.figure(figsize=(8,4))
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
plt.title("Распределение эмоциональных блоков")
st.pyplot(plt)
# Кнопка для сохранения результатов
if st.button("Сохранить результаты"):
# Сохранение в файл (например, JSON)
results = {
"chosen_files": st.session_state.chosen_files,
"user_emotion_vector": user_emotion_vector.tolist(),
"timestamp": pd.Timestamp.now().isoformat()
}
# Сохраняем в текущую директорию
save_path = Path("user_results.json")
with open(save_path, 'w') as f:
import json
json.dump(results, f)
st.success(f"Результаты сохранены в {save_path}")
if st.button("Начать заново"):
st.session_state.round_num = 0
st.session_state.chosen_files = []
st.rerun()