diff --git a/src/data_loader.py b/src/data_loader.py index 917b162..5f6fab1 100644 --- a/src/data_loader.py +++ b/src/data_loader.py @@ -3,24 +3,45 @@ from pathlib import Path import pandas as pd import numpy as np from music_engine.matcher import MusicMatcher +from music_engine.image_processor import ImageProcessor + +# Определяем базовую директорию (папка src) +BASE_DIR = Path(__file__).resolve().parent @st.cache_resource def load_music_engine(): """Загрузка базы данных и модели регрессора.""" - base_dir = Path(__file__).resolve().parent - db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" - model_path = base_dir / "music_engine" / "va_regressor.pkl" + # music_db.csv лежит в dataset/DEAM/ (на уровень выше от src) + db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv" + # va_regressor.pkl лежит в src/music_engine/ + model_path = BASE_DIR / "music_engine" / "va_regressor.pkl" + if not db_path.exists(): + print(f"⚠️ Файл базы {db_path} не найден!") return None return MusicMatcher(db_path=db_path, model_path=model_path) +@st.cache_resource +def load_image_processor(): + """Загрузка ResNet-50 для извлечения признаков на лету.""" + # Файл весов лежит в той же папке src, что и этот скрипт + model_path = BASE_DIR / "emoset_resnet50_best.pth" + + if not model_path.exists(): + print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}") + # Если не нашли в src, попробуем поискать в корне проекта на всякий случай + model_path = BASE_DIR.parent / "emoset_resnet50_best.pth" + + return ImageProcessor(model_path=model_path) + @st.cache_data def load_emoset_data(): """Загрузка тестовой выборки EmoSet для первой вкладки.""" - csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") - img_dir = Path("./dataset/EmoSet-118K/test/images") - emb_path = Path("./src/emoset_test_embeddings.npy") - lbl_path = Path("./src/emoset_test_labels.npy") + # Пути относительно корня проекта + csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv" + img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images" + emb_path = BASE_DIR / "emoset_test_embeddings.npy" + lbl_path = BASE_DIR / "emoset_test_labels.npy" if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): return None, None, None, None diff --git a/src/main.py b/src/main.py index 5c38129..2b9d04b 100644 --- a/src/main.py +++ b/src/main.py @@ -2,8 +2,10 @@ import streamlit as st import sys import os import subprocess -from data_loader import load_music_engine, load_emoset_data + +from data_loader import load_music_engine, load_emoset_data, load_image_processor from tabs.tab_dataset import render_dataset_tab +from tabs.tab_live import render_live_tab # ---------------------------- # 1️⃣ Запуск приложения @@ -21,20 +23,21 @@ st.set_page_config(page_title="Thesis Demo", layout="wide") # 2️⃣ Инициализация движка и данных # ---------------------------- matcher = load_music_engine() +image_processor = load_image_processor() image_files, embeddings, labels_array, images_path = load_emoset_data() # ---------------------------- # 3️⃣ Интерфейс и Вкладки # ---------------------------- -st.title("🖼️ Эмоциональный генератор плейлистов") +st.title("🖼️ Генератор саундтреков (Research Demo)") -# Создаем две вкладки -tab1, tab2 = st.tabs(["📊 Анализ EmoSet (Отладка)", "📸 Анализ своих фото (Live)"]) +tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"]) with tab1: render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path) with tab2: - st.info("🚀 Модуль загрузки пользовательских фотографий и извлечения признаков 'на лету'.") - st.write("Скоро здесь появится drag-and-drop интерфейс для тестирования ваших собственных изображений.") - # TODO: render_live_tab(matcher) \ No newline at end of file + if image_processor: + render_live_tab(matcher, image_processor) + else: + st.error("Система обработки изображений недоступна (не найдены веса ResNet).") \ No newline at end of file diff --git a/src/music_engine/image_processor.py b/src/music_engine/image_processor.py new file mode 100644 index 0000000..bc47b2c --- /dev/null +++ b/src/music_engine/image_processor.py @@ -0,0 +1,45 @@ +import torch +import torchvision.transforms as T +from PIL import Image +import timm +from pathlib import Path +import numpy as np + +class ImageProcessor: + def __init__(self, model_path: Path | str): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Загружаем базовую архитектуру, как при обучении EmoSet + self.model = timm.create_model('resnet50', pretrained=False, num_classes=8) + + # Подгружаем обученные веса + if Path(model_path).exists(): + # map_location позволяет загрузить модель на CPU, если нет видеокарты + self.model.load_state_dict(torch.load(model_path, map_location=self.device)) + print(f"✅ Веса ResNet-50 успешно загружены из {model_path}") + else: + print(f"⚠️ ОШИБКА: Файл весов {model_path} не найден! Модель будет выдавать случайный шум.") + + # Удаляем последний слой (классификатор на 8 эмоций), + # чтобы на выходе получать сырой вектор (embedding) на 2048 чисел + self.model.fc = torch.nn.Identity() + + self.model.to(self.device) + self.model.eval() + + # Стандартные трансформации ImageNet (строго как при обучении) + self.transform = T.Compose([ + T.Resize((224, 224)), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + @torch.no_grad() + def extract_embedding(self, image: Image.Image) -> np.ndarray: + """Принимает PIL Image, возвращает numpy-вектор.""" + # Переводим в RGB (на случай если загрузят PNG с прозрачностью или ЧБ) + img_rgb = image.convert('RGB') + img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device) + + embedding = self.model(img_tensor) + return embedding.cpu().numpy().flatten() \ No newline at end of file diff --git a/src/tabs/tab_live.py b/src/tabs/tab_live.py new file mode 100644 index 0000000..4e30c34 --- /dev/null +++ b/src/tabs/tab_live.py @@ -0,0 +1,73 @@ +import streamlit as st +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + +def render_live_tab(matcher, image_processor): + st.write("Загрузите фотографии с вашего устройства (например, снимки с недавней поездки или прогулки). Система проанализирует их эмоциональный фон и подберет подходящий саундтрек.") + + # Drag-and-drop интерфейс для загрузки нескольких файлов + uploaded_files = st.file_uploader( + "Перетащите изображения сюда", + type=['png', 'jpg', 'jpeg'], + accept_multiple_files=True + ) + + if uploaded_files: + st.subheader("Загруженные образы:") + + # Показываем миниатюры загруженных фото + cols = st.columns(min(len(uploaded_files), 5)) + images = [] + for i, file in enumerate(uploaded_files): + img = Image.open(file) + images.append(img) + with cols[i % 5]: + st.image(img, use_container_width=True) + + if st.button("🎵 Сгенерировать саундтрек события", type="primary", use_container_width=True): + with st.spinner("Анализируем визуальные признаки нейросетью..."): + all_v, all_a = [], [] + + # Прогоняем каждое фото через пайплайн: ResNet -> Ridge Regressor + for img in images: + embedding = image_processor.extract_embedding(img) + v, a = matcher.predict_va(embedding) + all_v.append(v) + all_a.append(a) + + # Late Fusion: усредняем результаты + target_v, target_a = np.mean(all_v), np.mean(all_a) + playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) + + st.success("✅ Саундтрек сформирован!") + + # ВЫВОД РЕЗУЛЬТАТОВ (аналогично первой вкладке) + col_left, col_right = st.columns([1, 2]) + + with col_left: + st.header("📊 Профиль события") + st.metric("Позитивность (Valence)", f"{target_v:.2f}") + st.metric("Энергия (Arousal)", f"{target_a:.2f}") + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.set_xlim(1, 9); ax.set_ylim(1, 9) + ax.axhline(5, color='gray', lw=1, ls='--'); ax.axvline(5, color='gray', lw=1, ls='--') + ax.scatter(target_v, target_a, color='green', s=150, edgecolors='white', zorder=5) + ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") + st.pyplot(fig) + + with col_right: + st.header("🎵 Плейлист") + for _, row in playlist.iterrows(): + with st.container(border=True): + c1, c2 = st.columns([1, 3]) + with c1: + st.write(f"**ID:** {int(row['song_id'])}") + st.caption(f"L2 Dist: {row['distance']:.2f}") + with c2: + audio_path = matcher.get_audio_path(row['song_id']) + if audio_path: + st.audio(str(audio_path)) + else: + st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") \ No newline at end of file