Beta V.1.1
This commit is contained in:
+28
-7
@@ -3,24 +3,45 @@ from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from music_engine.matcher import MusicMatcher
|
||||
from music_engine.image_processor import ImageProcessor
|
||||
|
||||
# Определяем базовую директорию (папка src)
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
@st.cache_resource
|
||||
def load_music_engine():
|
||||
"""Загрузка базы данных и модели регрессора."""
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
||||
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
|
||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
# va_regressor.pkl лежит в src/music_engine/
|
||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not db_path.exists():
|
||||
print(f"⚠️ Файл базы {db_path} не найден!")
|
||||
return None
|
||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||
|
||||
@st.cache_resource
|
||||
def load_image_processor():
|
||||
"""Загрузка ResNet-50 для извлечения признаков на лету."""
|
||||
# Файл весов лежит в той же папке src, что и этот скрипт
|
||||
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
|
||||
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
|
||||
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||
|
||||
return ImageProcessor(model_path=model_path)
|
||||
|
||||
@st.cache_data
|
||||
def load_emoset_data():
|
||||
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
|
||||
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
|
||||
img_dir = Path("./dataset/EmoSet-118K/test/images")
|
||||
emb_path = Path("./src/emoset_test_embeddings.npy")
|
||||
lbl_path = Path("./src/emoset_test_labels.npy")
|
||||
# Пути относительно корня проекта
|
||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
||||
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
||||
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
||||
|
||||
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||
return None, None, None, None
|
||||
|
||||
+10
-7
@@ -2,8 +2,10 @@ import streamlit as st
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
from data_loader import load_music_engine, load_emoset_data
|
||||
|
||||
from data_loader import load_music_engine, load_emoset_data, load_image_processor
|
||||
from tabs.tab_dataset import render_dataset_tab
|
||||
from tabs.tab_live import render_live_tab
|
||||
|
||||
# ----------------------------
|
||||
# 1️⃣ Запуск приложения
|
||||
@@ -21,20 +23,21 @@ st.set_page_config(page_title="Thesis Demo", layout="wide")
|
||||
# 2️⃣ Инициализация движка и данных
|
||||
# ----------------------------
|
||||
matcher = load_music_engine()
|
||||
image_processor = load_image_processor()
|
||||
image_files, embeddings, labels_array, images_path = load_emoset_data()
|
||||
|
||||
# ----------------------------
|
||||
# 3️⃣ Интерфейс и Вкладки
|
||||
# ----------------------------
|
||||
st.title("🖼️ Эмоциональный генератор плейлистов")
|
||||
st.title("🖼️ Генератор саундтреков (Research Demo)")
|
||||
|
||||
# Создаем две вкладки
|
||||
tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"])
|
||||
tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"])
|
||||
|
||||
with tab1:
|
||||
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
|
||||
|
||||
with tab2:
|
||||
st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.")
|
||||
st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.")
|
||||
# TODO: render_live_tab(matcher)
|
||||
if image_processor:
|
||||
render_live_tab(matcher, image_processor)
|
||||
else:
|
||||
st.error("Система обработки изображений недоступна (не найдены веса ResNet).")
|
||||
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
import timm
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
class ImageProcessor:
|
||||
def __init__(self, model_path: Path | str):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Загружаем базовую архитектуру, как при обучении EmoSet
|
||||
self.model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||
|
||||
# Подгружаем обученные веса
|
||||
if Path(model_path).exists():
|
||||
# map_location позволяет загрузить модель на CPU, если нет видеокарты
|
||||
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
print(f"✅ Веса ResNet-50 успешно загружены из {model_path}")
|
||||
else:
|
||||
print(f"⚠️ ОШИБКА: Файл весов {model_path} не найден! Модель будет выдавать случайный шум.")
|
||||
|
||||
# Удаляем последний слой (классификатор на 8 эмоций),
|
||||
# чтобы на выходе получать сырой вектор (embedding) на 2048 чисел
|
||||
self.model.fc = torch.nn.Identity()
|
||||
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# Стандартные трансформации ImageNet (строго как при обучении)
|
||||
self.transform = T.Compose([
|
||||
T.Resize((224, 224)),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||
"""Принимает PIL Image, возвращает numpy-вектор."""
|
||||
# Переводим в RGB (на случай если загрузят PNG с прозрачностью или ЧБ)
|
||||
img_rgb = image.convert('RGB')
|
||||
img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)
|
||||
|
||||
embedding = self.model(img_tensor)
|
||||
return embedding.cpu().numpy().flatten()
|
||||
@@ -0,0 +1,73 @@
|
||||
import streamlit as st
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def render_live_tab(matcher, image_processor):
|
||||
st.write("Загрузите фотографии с вашего устройства (например, снимки с недавней поездки или прогулки). Система проанализирует их эмоциональный фон и подберет подходящий саундтрек.")
|
||||
|
||||
# Drag-and-drop интерфейс для загрузки нескольких файлов
|
||||
uploaded_files = st.file_uploader(
|
||||
"Перетащите изображения сюда",
|
||||
type=['png', 'jpg', 'jpeg'],
|
||||
accept_multiple_files=True
|
||||
)
|
||||
|
||||
if uploaded_files:
|
||||
st.subheader("Загруженные образы:")
|
||||
|
||||
# Показываем миниатюры загруженных фото
|
||||
cols = st.columns(min(len(uploaded_files), 5))
|
||||
images = []
|
||||
for i, file in enumerate(uploaded_files):
|
||||
img = Image.open(file)
|
||||
images.append(img)
|
||||
with cols[i % 5]:
|
||||
st.image(img, use_container_width=True)
|
||||
|
||||
if st.button("🎵 Сгенерировать саундтрек события", type="primary", use_container_width=True):
|
||||
with st.spinner("Анализируем визуальные признаки нейросетью..."):
|
||||
all_v, all_a = [], []
|
||||
|
||||
# Прогоняем каждое фото через пайплайн: ResNet -> Ridge Regressor
|
||||
for img in images:
|
||||
embedding = image_processor.extract_embedding(img)
|
||||
v, a = matcher.predict_va(embedding)
|
||||
all_v.append(v)
|
||||
all_a.append(a)
|
||||
|
||||
# Late Fusion: усредняем результаты
|
||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
|
||||
|
||||
st.success("✅ Саундтрек сформирован!")
|
||||
|
||||
# ВЫВОД РЕЗУЛЬТАТОВ (аналогично первой вкладке)
|
||||
col_left, col_right = st.columns([1, 2])
|
||||
|
||||
with col_left:
|
||||
st.header("📊 Профиль события")
|
||||
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(4, 4))
|
||||
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
|
||||
ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--')
|
||||
ax.scatter(target_v, target_a, color='green', s=150, edgecolors='white', zorder=5)
|
||||
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
|
||||
st.pyplot(fig)
|
||||
|
||||
with col_right:
|
||||
st.header("🎵 Плейлист")
|
||||
for _, row in playlist.iterrows():
|
||||
with st.container(border=True):
|
||||
c1, c2 = st.columns([1, 3])
|
||||
with c1:
|
||||
st.write(f"**ID:** {int(row['song_id'])}")
|
||||
st.caption(f"L2 Dist: {row['distance']:.2f}")
|
||||
with c2:
|
||||
audio_path = matcher.get_audio_path(row['song_id'])
|
||||
if audio_path:
|
||||
st.audio(str(audio_path))
|
||||
else:
|
||||
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
|
||||
Reference in New Issue
Block a user