Beta V.1.1

This commit is contained in:
zin
2026-05-06 21:11:15 +00:00
parent 5290554d70
commit 9954603043
4 changed files with 156 additions and 14 deletions
+28 -7
View File
@@ -3,24 +3,45 @@ from pathlib import Path
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from music_engine.matcher import MusicMatcher from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor
# Определяем базовую директорию (папка src)
BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource @st.cache_resource
def load_music_engine(): def load_music_engine():
"""Загрузка базы данных и модели регрессора.""" """Загрузка базы данных и модели регрессора."""
base_dir = Path(__file__).resolve().parent # music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = base_dir / "music_engine" / "va_regressor.pkl" # va_regressor.pkl лежит в src/music_engine/
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists(): if not db_path.exists():
print(f"⚠️ Файл базы {db_path} не найден!")
return None return None
return MusicMatcher(db_path=db_path, model_path=model_path) return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource
def load_image_processor():
"""Загрузка ResNet-50 для извлечения признаков на лету."""
# Файл весов лежит в той же папке src, что и этот скрипт
model_path = BASE_DIR / "emoset_resnet50_best.pth"
if not model_path.exists():
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(model_path=model_path)
@st.cache_data @st.cache_data
def load_emoset_data(): def load_emoset_data():
"""Загрузка тестовой выборки EmoSet для первой вкладки.""" """Загрузка тестовой выборки EmoSet для первой вкладки."""
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") # Пути относительно корня проекта
img_dir = Path("./dataset/EmoSet-118K/test/images") csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
emb_path = Path("./src/emoset_test_embeddings.npy") img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
lbl_path = Path("./src/emoset_test_labels.npy") emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
return None, None, None, None return None, None, None, None
+10 -7
View File
@@ -2,8 +2,10 @@ import streamlit as st
import sys import sys
import os import os
import subprocess import subprocess
from data_loader import load_music_engine, load_emoset_data
from data_loader import load_music_engine, load_emoset_data, load_image_processor
from tabs.tab_dataset import render_dataset_tab from tabs.tab_dataset import render_dataset_tab
from tabs.tab_live import render_live_tab
# ---------------------------- # ----------------------------
# 1️⃣ Запуск приложения # 1️⃣ Запуск приложения
@@ -21,20 +23,21 @@ st.set_page_config(page_title="Thesis Demo", layout="wide")
# 2️⃣ Инициализация движка и данных # 2️⃣ Инициализация движка и данных
# ---------------------------- # ----------------------------
matcher = load_music_engine() matcher = load_music_engine()
image_processor = load_image_processor()
image_files, embeddings, labels_array, images_path = load_emoset_data() image_files, embeddings, labels_array, images_path = load_emoset_data()
# ---------------------------- # ----------------------------
# 3️⃣ Интерфейс и Вкладки # 3️⃣ Интерфейс и Вкладки
# ---------------------------- # ----------------------------
st.title("🖼️ Эмоциональный генератор плейлистов") st.title("🖼️ Генератор саундтреков (Research Demo)")
# Создаем две вкладки tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"])
tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"])
with tab1: with tab1:
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path) render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
with tab2: with tab2:
st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.") if image_processor:
st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.") render_live_tab(matcher, image_processor)
# TODO: render_live_tab(matcher) else:
st.error("Система обработки изображений недоступна (не найдены веса ResNet).")
+45
View File
@@ -0,0 +1,45 @@
import torch
import torchvision.transforms as T
from PIL import Image
import timm
from pathlib import Path
import numpy as np
class ImageProcessor:
def __init__(self, model_path: Path | str):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Загружаем базовую архитектуру, как при обучении EmoSet
self.model = timm.create_model('resnet50', pretrained=False, num_classes=8)
# Подгружаем обученные веса
if Path(model_path).exists():
# map_location позволяет загрузить модель на CPU, если нет видеокарты
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
print(f"✅ Веса ResNet-50 успешно загружены из {model_path}")
else:
print(f"⚠️ ОШИБКА: Файл весов {model_path} не найден! Модель будет выдавать случайный шум.")
# Удаляем последний слой (классификатор на 8 эмоций),
# чтобы на выходе получать сырой вектор (embedding) на 2048 чисел
self.model.fc = torch.nn.Identity()
self.model.to(self.device)
self.model.eval()
# Стандартные трансформации ImageNet (строго как при обучении)
self.transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
@torch.no_grad()
def extract_embedding(self, image: Image.Image) -> np.ndarray:
"""Принимает PIL Image, возвращает numpy-вектор."""
# Переводим в RGB (на случай если загрузят PNG с прозрачностью или ЧБ)
img_rgb = image.convert('RGB')
img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)
embedding = self.model(img_tensor)
return embedding.cpu().numpy().flatten()
+73
View File
@@ -0,0 +1,73 @@
import streamlit as st
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def render_live_tab(matcher, image_processor):
st.write("Загрузите фотографии с вашего устройства (например, снимки с недавней поездки или прогулки). Система проанализирует их эмоциональный фон и подберет подходящий саундтрек.")
# Drag-and-drop интерфейс для загрузки нескольких файлов
uploaded_files = st.file_uploader(
"Перетащите изображения сюда",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True
)
if uploaded_files:
st.subheader("Загруженные образы:")
# Показываем миниатюры загруженных фото
cols = st.columns(min(len(uploaded_files), 5))
images = []
for i, file in enumerate(uploaded_files):
img = Image.open(file)
images.append(img)
with cols[i % 5]:
st.image(img, use_container_width=True)
if st.button("🎵 Сгенерировать саундтрек события", type="primary", use_container_width=True):
with st.spinner("Анализируем визуальные признаки нейросетью..."):
all_v, all_a = [], []
# Прогоняем каждое фото через пайплайн: ResNet -> Ridge Regressor
for img in images:
embedding = image_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
# Late Fusion: усредняем результаты
target_v, target_a = np.mean(all_v), np.mean(all_a)
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
st.success("✅ Саундтрек сформирован!")
# ВЫВОД РЕЗУЛЬТАТОВ (аналогично первой вкладке)
col_left, col_right = st.columns([1, 2])
with col_left:
st.header("📊 Профиль события")
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--')
ax.scatter(target_v, target_a, color='green', s=150, edgecolors='white', zorder=5)
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
st.pyplot(fig)
with col_right:
st.header("🎵 Плейлист")
for _, row in playlist.iterrows():
with st.container(border=True):
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**ID:** {int(row['song_id'])}")
st.caption(f"L2 Dist: {row['distance']:.2f}")
with c2:
audio_path = matcher.get_audio_path(row['song_id'])
if audio_path:
st.audio(str(audio_path))
else:
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")