Initial Demo
This commit is contained in:
220
src/main.py
Normal file
220
src/main.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import streamlit as st
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Проверяем, запущен ли скрипт через Streamlit
|
||||
import os
|
||||
if "STREAMLIT_RUN" not in os.environ:
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
os.environ["STREAMLIT_RUN"] = "1"
|
||||
|
||||
# Формируем команду запуска
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"streamlit",
|
||||
"run",
|
||||
__file__,
|
||||
"--server.port", "8080",
|
||||
"--server.address", "0.0.0.0"
|
||||
]
|
||||
subprocess.run(cmd)
|
||||
sys.exit()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 1️⃣ Конфигурация
|
||||
# ----------------------------
|
||||
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
|
||||
IMAGES_DIR = DATA_ROOT / "images"
|
||||
LABELS_CSV = DATA_ROOT / "labels.csv"
|
||||
|
||||
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
|
||||
LABELS_PATH = Path("./src/emoset_test_labels.npy")
|
||||
|
||||
NUM_CHOICES = 6 # количество изображений за один раунд
|
||||
TOTAL_ROUNDS = 10 # количество раундов выбора
|
||||
|
||||
st.set_page_config(page_title="EmoSet Demo", layout="wide")
|
||||
|
||||
# ----------------------------
|
||||
# 2️⃣ Загрузка данных
|
||||
# ----------------------------
|
||||
if not IMAGES_DIR.exists():
|
||||
st.error(f"Папка с изображениями не найдена: {IMAGES_DIR.resolve()}")
|
||||
st.stop()
|
||||
|
||||
labels_df = pd.read_csv(LABELS_CSV)
|
||||
embeddings = np.load(EMBEDDINGS_PATH)
|
||||
labels_array = np.load(LABELS_PATH)
|
||||
|
||||
image_files = list(IMAGES_DIR.glob("*.jpg"))
|
||||
|
||||
# Проверка совпадения размеров
|
||||
if len(image_files) != len(embeddings) or len(image_files) != len(labels_array):
|
||||
st.error("Размеры массивов не совпадают!")
|
||||
st.stop()
|
||||
|
||||
# Создадим mapping: filename -> embedding, label
|
||||
filename2embedding = {f.name: emb for f, emb in zip(image_files, embeddings)}
|
||||
filename2label = {f.name: lbl for f, lbl in zip(image_files, labels_array)}
|
||||
|
||||
# ----------------------------
|
||||
# 3️⃣ Сессия Streamlit
|
||||
# ----------------------------
|
||||
if 'round_num' not in st.session_state:
|
||||
st.session_state.round_num = 0
|
||||
if 'chosen_files' not in st.session_state:
|
||||
st.session_state.chosen_files = []
|
||||
|
||||
st.title("EmoSet: Выбор изображений")
|
||||
|
||||
# ----------------------------
|
||||
# 4️⃣ Функция для overlay топ-3 эмоций
|
||||
# ----------------------------
|
||||
def get_font():
|
||||
try:
|
||||
return ImageFont.truetype("arial.ttf", 14)
|
||||
except:
|
||||
try:
|
||||
return ImageFont.truetype("/usr/share/fonts/truetype/arial.ttf", 14)
|
||||
except:
|
||||
return ImageFont.load_default()
|
||||
|
||||
def overlay_top_emotions(img: Image.Image, label: int, top_n=3):
|
||||
draw = ImageDraw.Draw(img)
|
||||
font = get_font()
|
||||
text = f"Label: {label}"
|
||||
draw.rectangle([(0,0),(img.width,20)], fill=(0,0,0,150))
|
||||
draw.text((2,2), text, fill=(255,255,255), font=font)
|
||||
return img
|
||||
|
||||
# ----------------------------
|
||||
# 5️⃣ Рандомный выбор 6 изображений для текущего раунда
|
||||
# ----------------------------
|
||||
# Инициализация состояния Streamlit
|
||||
if "round_num" not in st.session_state:
|
||||
st.session_state.round_num = 0
|
||||
if "chosen_files" not in st.session_state:
|
||||
st.session_state.chosen_files = []
|
||||
if "current_choices" not in st.session_state:
|
||||
st.session_state.current_choices = []
|
||||
|
||||
# Если все раунды уже завершены, блок пропускается
|
||||
if st.session_state.round_num < TOTAL_ROUNDS:
|
||||
st.subheader(f"Раунд {st.session_state.round_num + 1} из {TOTAL_ROUNDS}")
|
||||
|
||||
# Генерируем выбор для текущего раунда только если он ещё не создан
|
||||
if len(st.session_state.current_choices) == 0:
|
||||
already_chosen = set(st.session_state.chosen_files)
|
||||
available_images = [f for f in image_files if f.name not in already_chosen]
|
||||
|
||||
if len(available_images) < NUM_CHOICES:
|
||||
st.warning("Недостаточно изображений для выбора!")
|
||||
st.stop()
|
||||
|
||||
st.session_state.current_choices = random.sample(available_images, NUM_CHOICES)
|
||||
|
||||
# Отображаем изображения и кнопки
|
||||
cols = st.columns(NUM_CHOICES)
|
||||
for col, img_path in zip(cols, st.session_state.current_choices):
|
||||
# Загружаем изображение и накладываем overlay
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_overlay = overlay_top_emotions(img, filename2label[img_path.name])
|
||||
|
||||
# Показываем изображение
|
||||
col.image(img_overlay, width=250)
|
||||
|
||||
# Кнопка выбора
|
||||
if col.button("Выбрать", key=img_path.name):
|
||||
st.session_state.chosen_files.append(img_path.name)
|
||||
st.session_state.round_num += 1
|
||||
st.session_state.current_choices = [] # сброс для следующего раунда
|
||||
st.rerun() # перезапуск для нового раунда
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 6️⃣ После всех выборов: отображение результатов
|
||||
# ----------------------------
|
||||
else:
|
||||
st.subheader("Результаты выбора пользователя")
|
||||
|
||||
if not st.session_state.chosen_files:
|
||||
st.warning("Вы не сделали ни одного выбора!")
|
||||
if st.button("Начать заново"):
|
||||
st.session_state.round_num = 0
|
||||
st.session_state.chosen_files = []
|
||||
st.rerun()
|
||||
else:
|
||||
# Отображение выбранных изображений
|
||||
cols = st.columns(3)
|
||||
for i, filename in enumerate(st.session_state.chosen_files):
|
||||
with cols[i % 3]:
|
||||
img = Image.open(IMAGES_DIR / filename).convert("RGB")
|
||||
img_overlay = overlay_top_emotions(img, filename2label[filename])
|
||||
st.image(img_overlay, caption=f"Выбор {i+1}")
|
||||
|
||||
# Расчет среднего embedding
|
||||
chosen_embeddings = [filename2embedding[f] for f in st.session_state.chosen_files]
|
||||
chosen_embeddings = np.stack(chosen_embeddings)
|
||||
user_emotion_vector = np.mean(chosen_embeddings, axis=0)
|
||||
|
||||
# Отображение информации о векторе эмоций
|
||||
with st.expander("Подробности о векторной модели эмоций"):
|
||||
st.write(f"Размерность embedding: {len(user_emotion_vector)}")
|
||||
st.write("Первые 10 значений:")
|
||||
st.json({f"Dim {i}": float(val) for i, val in enumerate(user_emotion_vector[:10])})
|
||||
|
||||
# Визуализация
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
# Гистограмма первых 30 измерений
|
||||
plt.figure(figsize=(8,4))
|
||||
plt.bar(range(min(30, len(user_emotion_vector))), user_emotion_vector[:30])
|
||||
plt.xlabel("Embedding dimension")
|
||||
plt.ylabel("Value")
|
||||
plt.title("Распределение значений embedding (первые 30 измерений)")
|
||||
st.pyplot(plt)
|
||||
|
||||
with col2:
|
||||
# Круговая диаграмма средних значений по блокам
|
||||
if len(user_emotion_vector) > 16:
|
||||
block_size = len(user_emotion_vector) // 4
|
||||
block_means = [
|
||||
np.mean(user_emotion_vector[i*block_size:(i+1)*block_size])
|
||||
for i in range(4)
|
||||
]
|
||||
plt.figure(figsize=(8,4))
|
||||
plt.pie(block_means, labels=[f"Block {i+1}" for i in range(4)], autopct='%1.1f%%')
|
||||
plt.title("Распределение эмоциональных блоков")
|
||||
st.pyplot(plt)
|
||||
|
||||
# Кнопка для сохранения результатов
|
||||
if st.button("Сохранить результаты"):
|
||||
# Сохранение в файл (например, JSON)
|
||||
results = {
|
||||
"chosen_files": st.session_state.chosen_files,
|
||||
"user_emotion_vector": user_emotion_vector.tolist(),
|
||||
"timestamp": pd.Timestamp.now().isoformat()
|
||||
}
|
||||
|
||||
# Сохраняем в текущую директорию
|
||||
save_path = Path("user_results.json")
|
||||
with open(save_path, 'w') as f:
|
||||
import json
|
||||
json.dump(results, f)
|
||||
|
||||
st.success(f"Результаты сохранены в {save_path}")
|
||||
|
||||
if st.button("Начать заново"):
|
||||
st.session_state.round_num = 0
|
||||
st.session_state.chosen_files = []
|
||||
st.rerun()
|
||||
Reference in New Issue
Block a user