From 95595a5a5e944f806574f427fb6b1048b0b09c39 Mon Sep 17 00:00:00 2001 From: zin Date: Wed, 6 May 2026 19:48:18 +0000 Subject: [PATCH] Beta v.1.0 --- src/main.py | 157 +++++++++---------- src/music_engine/matcher.py | 68 ++++---- src/{ => scripts}/download emo_dataset.ipynb | 0 src/{ => scripts}/download_dataset.py | 0 src/scripts/prep_deam.ipynb | 138 +++++----------- 5 files changed, 155 insertions(+), 208 deletions(-) rename src/{ => scripts}/download emo_dataset.ipynb (100%) rename src/{ => scripts}/download_dataset.py (100%) diff --git a/src/main.py b/src/main.py index b3e606a..2940e5a 100644 --- a/src/main.py +++ b/src/main.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt from music_engine.matcher import MusicMatcher # ---------------------------- -# 1️⃣ Конфигурация и запуск +# 1️⃣ Запуск Streamlit # ---------------------------- if __name__ == "__main__": import os @@ -23,132 +23,123 @@ if __name__ == "__main__": subprocess.run(cmd) sys.exit() -# Конфигурация путей -DATA_ROOT = Path("./dataset/EmoSet-118K/test") -IMAGES_DIR = DATA_ROOT / "images" -LABELS_CSV = DATA_ROOT / "labels.csv" +# Словарь для отладки +EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment", + 4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"} -EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy") -LABELS_PATH = Path("./src/emoset_test_labels.npy") - -NUM_CHOICES = 6 -TOTAL_ROUNDS = 10 - -# Словарь для расшифровки меток (алфавитный порядок EmoSet) -EMO_NAMES = { - 0: "amusement (веселье)", - 1: "anger (гнев)", - 2: "awe (трепет)", - 3: "contentment (удовлетворение)", - 4: "disgust (отвращение)", - 5: "excitement (возбуждение)", - 6: "fear (страх)", - 7: "sadness (грусть)" -} - -st.set_page_config(page_title="Debug Mode: EmoSet & Music", layout="wide") +st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide") @st.cache_resource def load_music_engine(): base_dir = Path(__file__).resolve().parent db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv" model_path = base_dir / "music_engine" / "va_regressor.pkl" - if not db_path.exists(): return None + if not db_path.exists(): + return None return MusicMatcher(db_path=db_path, model_path=model_path) matcher = load_music_engine() @st.cache_data def load_emoset_data(): - if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists() or not LABELS_CSV.exists(): - return None, None, None - df = pd.read_csv(LABELS_CSV) - image_files = df['filename'].tolist() - embeddings = np.load(EMBEDDINGS_PATH) - labels_array = np.load(LABELS_PATH) - return image_files, embeddings, labels_array + csv_path = Path("./dataset/EmoSet-118K/test/labels.csv") + img_dir = Path("./dataset/EmoSet-118K/test/images") + emb_path = Path("./src/emoset_test_embeddings.npy") + lbl_path = Path("./src/emoset_test_labels.npy") -image_files, embeddings, labels_array = load_emoset_data() + if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]): + return None, None, None, None + + df = pd.read_csv(csv_path) + image_list = df['filename'].tolist() + embs = np.load(emb_path) + lbls = np.load(lbl_path) + + return image_list, embs, lbls, img_dir + +image_files, embeddings, labels_array, images_path = load_emoset_data() # ---------------------------- -# 3️⃣ Логика приложения +# 2️⃣ Основной интерфейс # ---------------------------- if image_files is None: - st.error("Данные не найдены.") + st.error("Ошибка загрузки данных EmoSet. Проверьте пути.") else: if 'round' not in st.session_state: st.session_state.round = 1 st.session_state.chosen_indices = [] - st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES) + st.session_state.current_options = random.sample(range(len(image_files)), 6) - st.title("🧪 Отладка эмоционального маппинга") - - if st.session_state.round <= TOTAL_ROUNDS: - st.write(f"**Раунд {st.session_state.round} из {TOTAL_ROUNDS}**") + st.title("🖼️ Эмоциональный подбор музыки") + + if st.session_state.round <= 10: + st.subheader(f"Раунд {st.session_state.round} из 10") + st.write("Выберите изображение, соответствующее вашему настроению:") - # Сетка выбора с отладочной информацией cols = st.columns(3) for i, idx in enumerate(st.session_state.current_options): with cols[i % 3]: - img_path = IMAGES_DIR / image_files[idx] - img = Image.open(img_path) + img_name = image_files[idx] + img = Image.open(images_path / img_name) st.image(img, use_container_width=True) - # --- ИНФОРМАЦИЯ ДЛЯ ОТЛАДКИ --- + # Информация для отладки if matcher: - v, a = matcher.predict_va(embeddings[idx]) - label_id = labels_array[idx] - label_name = EMO_NAMES.get(label_id, "Unknown") - - st.caption(f"**Класс:** {label_name}") - st.caption(f"**Прогноз:** V: {v:.2f} | A: {a:.2f}") - # ------------------------------ + v_p, a_p = matcher.predict_va(embeddings[idx]) + gt_label = EMO_NAMES.get(labels_array[idx], "unknown") + st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}") if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True): st.session_state.chosen_indices.append(idx) st.session_state.round += 1 - if st.session_state.round <= TOTAL_ROUNDS: - st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES) + if st.session_state.round <= 10: + st.session_state.current_options = random.sample(range(len(image_files)), 6) st.rerun() else: - # ФИНАЛЬНЫЙ ЭТАП - st.success("✅ Профиль сформирован!") + # РЕЗУЛЬТАТЫ + st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.") - # Считаем итог - chosen_embs = embeddings[st.session_state.chosen_indices] all_v, all_a = [], [] - for emb in chosen_embs: - v, a = matcher.predict_va(emb) + for idx in st.session_state.chosen_indices: + v, a = matcher.predict_va(embeddings[idx]) all_v.append(v) all_a.append(a) target_v, target_a = np.mean(all_v), np.mean(all_a) - - # Вывод результатов - col1, col2 = st.columns([2, 1]) - - with col1: - st.header("🎵 Ваш плейлист") - distances = np.sqrt((matcher.music_db['valence'] - target_v)**2 + (matcher.music_db['arousal'] - target_a)**2) - playlist = matcher.music_db.copy() - playlist['distance'] = distances - st.table(playlist.sort_values(by='distance').head(5)[['song_id', 'valence', 'arousal', 'distance']]) + playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5) - with col2: - st.header("📊 Профиль") - st.metric("Valence (Итог)", f"{target_v:.2f}") - st.metric("Arousal (Итог)", f"{target_a:.2f}") + col_left, col_right = st.columns([1, 2]) + + with col_left: + st.header("📊 Ваш профиль") + st.metric("Позитивность (Valence)", f"{target_v:.2f}") + st.metric("Энергия (Arousal)", f"{target_a:.2f}") - # Показываем, что именно выбрал пользователь - st.divider() - st.subheader("Выбранные вами образы и их веса:") - sum_cols = st.columns(5) - for i, idx in enumerate(st.session_state.chosen_indices): - with sum_cols[i % 5]: - v_i, a_i = matcher.predict_va(embeddings[idx]) - st.image(Image.open(IMAGES_DIR / image_files[idx]), use_container_width=True) - st.write(f"V:{v_i:.1f} A:{a_i:.1f}") + # График Рассела + fig, ax = plt.subplots(figsize=(4, 4)) + ax.set_xlim(1, 9); ax.set_ylim(1, 9) + ax.axhline(5, color='gray', lw=1, ls='--') + ax.axvline(5, color='gray', lw=1, ls='--') + ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5) + ax.set_xlabel("Valence"); ax.set_ylabel("Arousal") + ax.set_title("Карта эмоций (Модель Рассела)") + st.pyplot(fig) - if st.button("Начать заново"): + with col_right: + st.header("🎵 Рекомендованная музыка") + for _, row in playlist.iterrows(): + with st.container(border=True): + c1, c2 = st.columns([1, 3]) + with c1: + st.write(f"**ID:** {int(row['song_id'])}") + st.caption(f"L2: {row['distance']:.2f}") + with c2: + audio_path = matcher.get_audio_path(row['song_id']) + if audio_path: + st.audio(str(audio_path)) + else: + st.warning(f"Файл {int(row['song_id'])}.mp3 не найден") + + if st.button("Начать заново", type="primary"): for key in list(st.session_state.keys()): del st.session_state[key] st.rerun() \ No newline at end of file diff --git a/src/music_engine/matcher.py b/src/music_engine/matcher.py index cefa65c..9a6e6ea 100644 --- a/src/music_engine/matcher.py +++ b/src/music_engine/matcher.py @@ -5,50 +5,64 @@ import joblib class MusicMatcher: def __init__(self, db_path: Path | str, model_path: Path | str): - # 1. Загрузка базы музыки + """ + Инициализация движка сопоставления музыки. + """ self.music_db = pd.read_csv(db_path) self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce') self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce') self.music_db = self.music_db.dropna() - # 2. Загрузка обученного регрессора + self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio" + + if self.audio_dir.exists(): + print(f"✅ Музыкальный архив найден: {self.audio_dir}") + else: + print(f"⚠️ ПРЕДУПРЕЖДЕНИЕ: Папка {self.audio_dir} не найдена!") + if Path(model_path).exists(): self.regressor = joblib.load(model_path) - print("✅ Регрессионная модель успешно загружена.") + print("✅ ML-регрессор загружен.") else: self.regressor = None - print("⚠️ ВНИМАНИЕ: Модель va_regressor.pkl не найдена!") + print("⚠️ Файл модели .pkl не найден.") def predict_va(self, embedding: np.ndarray): - """ - Использование обученной ML-модели (Ridge) для маппинга - эмбеддингов в пространство Valence-Arousal. - """ + """Честный прогноз координат Valence-Arousal.""" if self.regressor is not None: - # Модель ожидает двумерный массив (batch_size, features) emb_2d = embedding.reshape(1, -1) - prediction = self.regressor.predict(emb_2d)[0] # Получаем [Valence, Arousal] - - v, a = prediction[0], prediction[1] - else: - # Fallback на случай, если файл модели потеряется - v, a = 5.0, 5.0 - - return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0) + prediction = self.regressor.predict(emb_2d)[0] + return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0) + return 5.0, 5.0 - def get_playlist(self, user_vector: np.ndarray, top_k: int = 5): - # 1. Предсказываем координаты через ML-модель - target_v, target_a = self.predict_va(user_vector) - - # 2. Считаем Евклидово расстояние (L2-норма) до треков в базе + def get_audio_path(self, song_id): + """Поиск mp3 файла по его номеру.""" + if not self.audio_dir.exists(): + return None + + clean_id = str(int(float(song_id))) + for ext in ['.mp3', '.wav']: + file_path = self.audio_dir / f"{clean_id}{ext}" + if file_path.exists(): + return file_path + return None + + def find_nearest_tracks(self, target_v: float, target_a: float, top_k: int = 5): + """ + Поиск с использованием Взвешенного Евклидова расстояния (Weighted KNN). + Энергия (Arousal) получает больший вес, так как она сильнее + определяет жанр и ритм композиции. + """ + # Вес для Arousal = 2.0, для Valence = 1.0 + # Это не позволит спокойным трекам (A < 4) попадать в выдачу + # для энергичных запросов (A > 6). distances = np.sqrt( - (self.music_db['valence'] - target_v)**2 + - (self.music_db['arousal'] - target_a)**2 + 1.0 * (self.music_db['valence'] - target_v)**2 + + 2.5 * (self.music_db['arousal'] - target_a)**2 # Жесткий штраф за разницу в энергии ) - # 3. Формируем финальную выдачу df_result = self.music_db.copy() df_result['distance'] = distances - playlist = df_result.sort_values(by='distance').head(top_k) - return target_v, target_a, playlist \ No newline at end of file + # Сортируем по расстоянию и берем топ-K + return df_result.sort_values(by='distance').head(top_k) \ No newline at end of file diff --git a/src/download emo_dataset.ipynb b/src/scripts/download emo_dataset.ipynb similarity index 100% rename from src/download emo_dataset.ipynb rename to src/scripts/download emo_dataset.ipynb diff --git a/src/download_dataset.py b/src/scripts/download_dataset.py similarity index 100% rename from src/download_dataset.py rename to src/scripts/download_dataset.py diff --git a/src/scripts/prep_deam.ipynb b/src/scripts/prep_deam.ipynb index 0beafba..7398a75 100644 --- a/src/scripts/prep_deam.ipynb +++ b/src/scripts/prep_deam.ipynb @@ -2,8 +2,8 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, - "id": "83693ad7", + "execution_count": 5, + "id": "b92e0213", "metadata": {}, "outputs": [], "source": [ @@ -14,119 +14,61 @@ { "cell_type": "code", "execution_count": 7, - "id": "99850a99", + "id": "1763c51e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Читаем файл аннотаций: ../dataset/DEAM/DEAM_Annotations/annotations/annotations per each rater/song_level/static_annotations_songs_1_2000.csv\n" + "✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n", + "Всего треков в базе: 1744\n", + "Пример данных:\n", + " song_id valence arousal\n", + "0 2 3.1 3.0\n", + "1 3 3.5 3.3\n", + "2 4 5.7 5.5\n", + "3 5 4.4 5.3\n", + "4 7 5.8 6.4\n" ] } ], "source": [ - "# 1. Ищем файл (поднимаемся из src на уровень выше)\n", - "deam_root = Path(\"../dataset/DEAM\")\n", + "# Точный путь к оригинальным аннотациям\n", + "source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n", + "# Путь, куда сохраним очищенную базу для движка\n", + "output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n", "\n", - "# Ищем файл статичных аннотаций. Берем первый попавшийся.\n", - "csv_files = list(deam_root.rglob(\"*static_annotations*.csv\"))\n", - "if not csv_files:\n", - " # Если не нашел static, берем вообще любой csv с аннотациями\n", - " csv_files = list(deam_root.rglob(\"*.csv\"))\n", - "\n", - "if not csv_files:\n", - " # Если путь неверный или файлов нет, скрипт сразу скажет об этом и покажет полный путь\n", - " raise FileNotFoundError(f\"В папке {deam_root.resolve()} не найдено ни одного CSV файла! Проверьте пути.\")\n", - "\n", - "anno_path = csv_files[0]\n", - "print(f\"Читаем файл аннотаций: {anno_path}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "5fbc493f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Оригинальные колонки в файле: ['workerID', ' SongId', ' Valence', ' Arousal']\n" - ] - } - ], - "source": [ - "# 2. Загружаем и чистим колонки\n", - "df = pd.read_csv(anno_path)\n", - "print(\"Оригинальные колонки в файле:\", df.columns.tolist())\n", - "\n", - "# Сносим пробелы по краям и переводим в нижний регистр\n", - "df.columns = [str(c).strip().lower() for c in df.columns]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "1e28fece", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Успешно найдены колонки -> ID: 'workerid', Valence: 'valence', Arousal: 'arousal'\n" - ] - } - ], - "source": [ - "# 3. Умный поиск колонок\n", - "# Ищем первую колонку, где есть 'id' или 'song'\n", - "song_col = next((c for c in df.columns if 'song' in c or 'id' in c), df.columns[0])\n", - "# Ищем valence (желательно mean, но сойдет любой)\n", - "v_col = next((c for c in df.columns if 'valence' in c and 'mean' in c), \n", - " next((c for c in df.columns if 'valence' in c), None))\n", - "# Ищем arousal\n", - "a_col = next((c for c in df.columns if 'arousal' in c and 'mean' in c), \n", - " next((c for c in df.columns if 'arousal' in c), None))\n", - "\n", - "if not v_col or not a_col:\n", - " raise ValueError(f\"Не смог найти Valence или Arousal! Доступные колонки: {df.columns.tolist()}\")\n", - "\n", - "print(f\"Успешно найдены колонки -> ID: '{song_col}', Valence: '{v_col}', Arousal: '{a_col}'\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "469f651c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Готово! Музыкальная база сохранена: ../dataset/DEAM/music_db.csv\n" - ] - } - ], - "source": [ - "# 4. Сохраняем результат\n", - "clean_df = df[[song_col, v_col, a_col]].copy()\n", - "clean_df.columns = ['song_id', 'valence', 'arousal']\n", - "\n", - "output_path = deam_root / \"music_db.csv\"\n", - "clean_df.to_csv(output_path, index=False)\n", - "print(f\"Готово! Музыкальная база сохранена: {output_path}\")" + "if not source_path.exists():\n", + " print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n", + "else:\n", + " # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n", + " df = pd.read_csv(source_path, skipinitialspace=True)\n", + " \n", + " # Берем только нужные колонки (по твоему примеру)\n", + " clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n", + " \n", + " # Переименовываем для простоты кода в движке\n", + " clean_df.columns = ['song_id', 'valence', 'arousal']\n", + " \n", + " # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n", + " clean_df['song_id'] = clean_df['song_id'].astype(int)\n", + " \n", + " # Сохраняем финальный файл\n", + " clean_df.to_csv(output_path, index=False)\n", + " \n", + " print(f\"✅ УСПЕХ! База создана: {output_path}\")\n", + " print(f\"Всего треков в базе: {len(clean_df)}\")\n", + " print(\"Пример данных:\")\n", + " print(clean_df.head())" ] } ], "metadata": { "kernelspec": { - "display_name": "Python (my-python-project)", + "display_name": "Python (thesis)", "language": "python", - "name": "my-python-project" + "name": "thesis" }, "language_info": { "codemirror_mode": {