Beta v.1.0
This commit is contained in:
+71
-80
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
|||||||
from music_engine.matcher import MusicMatcher
|
from music_engine.matcher import MusicMatcher
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# 1️⃣ Конфигурация и запуск
|
# 1️⃣ Запуск Streamlit
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import os
|
import os
|
||||||
@@ -23,132 +23,123 @@ if __name__ == "__main__":
|
|||||||
subprocess.run(cmd)
|
subprocess.run(cmd)
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# Конфигурация путей
|
# Словарь для отладки
|
||||||
DATA_ROOT = Path("./dataset/EmoSet-118K/test")
|
EMO_NAMES = {0: "amusement", 1: "anger", 2: "awe", 3: "contentment",
|
||||||
IMAGES_DIR = DATA_ROOT / "images"
|
4: "disgust", 5: "excitement", 6: "fear", 7: "sadness"}
|
||||||
LABELS_CSV = DATA_ROOT / "labels.csv"
|
|
||||||
|
|
||||||
EMBEDDINGS_PATH = Path("./src/emoset_test_embeddings.npy")
|
st.set_page_config(page_title="Thesis Demo: Image-Music", layout="wide")
|
||||||
LABELS_PATH = Path("./src/emoset_test_labels.npy")
|
|
||||||
|
|
||||||
NUM_CHOICES = 6
|
|
||||||
TOTAL_ROUNDS = 10
|
|
||||||
|
|
||||||
# Словарь для расшифровки меток (алфавитный порядок EmoSet)
|
|
||||||
EMO_NAMES = {
|
|
||||||
0: "amusement (веселье)",
|
|
||||||
1: "anger (гнев)",
|
|
||||||
2: "awe (трепет)",
|
|
||||||
3: "contentment (удовлетворение)",
|
|
||||||
4: "disgust (отвращение)",
|
|
||||||
5: "excitement (возбуждение)",
|
|
||||||
6: "fear (страх)",
|
|
||||||
7: "sadness (грусть)"
|
|
||||||
}
|
|
||||||
|
|
||||||
st.set_page_config(page_title="Debug Mode: EmoSet & Music", layout="wide")
|
|
||||||
|
|
||||||
@st.cache_resource
|
@st.cache_resource
|
||||||
def load_music_engine():
|
def load_music_engine():
|
||||||
base_dir = Path(__file__).resolve().parent
|
base_dir = Path(__file__).resolve().parent
|
||||||
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||||
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
model_path = base_dir / "music_engine" / "va_regressor.pkl"
|
||||||
if not db_path.exists(): return None
|
if not db_path.exists():
|
||||||
|
return None
|
||||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||||
|
|
||||||
matcher = load_music_engine()
|
matcher = load_music_engine()
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def load_emoset_data():
|
def load_emoset_data():
|
||||||
if not IMAGES_DIR.exists() or not EMBEDDINGS_PATH.exists() or not LABELS_CSV.exists():
|
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
|
||||||
return None, None, None
|
img_dir = Path("./dataset/EmoSet-118K/test/images")
|
||||||
df = pd.read_csv(LABELS_CSV)
|
emb_path = Path("./src/emoset_test_embeddings.npy")
|
||||||
image_files = df['filename'].tolist()
|
lbl_path = Path("./src/emoset_test_labels.npy")
|
||||||
embeddings = np.load(EMBEDDINGS_PATH)
|
|
||||||
labels_array = np.load(LABELS_PATH)
|
|
||||||
return image_files, embeddings, labels_array
|
|
||||||
|
|
||||||
image_files, embeddings, labels_array = load_emoset_data()
|
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
image_list = df['filename'].tolist()
|
||||||
|
embs = np.load(emb_path)
|
||||||
|
lbls = np.load(lbl_path)
|
||||||
|
|
||||||
|
return image_list, embs, lbls, img_dir
|
||||||
|
|
||||||
|
image_files, embeddings, labels_array, images_path = load_emoset_data()
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# 3️⃣ Логика приложения
|
# 2️⃣ Основной интерфейс
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
if image_files is None:
|
if image_files is None:
|
||||||
st.error("Данные не найдены.")
|
st.error("Ошибка загрузки данных EmoSet. Проверьте пути.")
|
||||||
else:
|
else:
|
||||||
if 'round' not in st.session_state:
|
if 'round' not in st.session_state:
|
||||||
st.session_state.round = 1
|
st.session_state.round = 1
|
||||||
st.session_state.chosen_indices = []
|
st.session_state.chosen_indices = []
|
||||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
st.session_state.current_options = random.sample(range(len(image_files)), 6)
|
||||||
|
|
||||||
st.title("🧪 Отладка эмоционального маппинга")
|
st.title("🖼️ Эмоциональный подбор музыки")
|
||||||
|
|
||||||
if st.session_state.round <= TOTAL_ROUNDS:
|
if st.session_state.round <= 10:
|
||||||
st.write(f"**Раунд {st.session_state.round} из {TOTAL_ROUNDS}**")
|
st.subheader(f"Раунд {st.session_state.round} из 10")
|
||||||
|
st.write("Выберите изображение, соответствующее вашему настроению:")
|
||||||
|
|
||||||
# Сетка выбора с отладочной информацией
|
|
||||||
cols = st.columns(3)
|
cols = st.columns(3)
|
||||||
for i, idx in enumerate(st.session_state.current_options):
|
for i, idx in enumerate(st.session_state.current_options):
|
||||||
with cols[i % 3]:
|
with cols[i % 3]:
|
||||||
img_path = IMAGES_DIR / image_files[idx]
|
img_name = image_files[idx]
|
||||||
img = Image.open(img_path)
|
img = Image.open(images_path / img_name)
|
||||||
st.image(img, use_container_width=True)
|
st.image(img, use_container_width=True)
|
||||||
|
|
||||||
# --- ИНФОРМАЦИЯ ДЛЯ ОТЛАДКИ ---
|
# Информация для отладки
|
||||||
if matcher:
|
if matcher:
|
||||||
v, a = matcher.predict_va(embeddings[idx])
|
v_p, a_p = matcher.predict_va(embeddings[idx])
|
||||||
label_id = labels_array[idx]
|
gt_label = EMO_NAMES.get(labels_array[idx], "unknown")
|
||||||
label_name = EMO_NAMES.get(label_id, "Unknown")
|
st.caption(f"GT: {gt_label} | Pred: V:{v_p:.1f} A:{a_p:.1f}")
|
||||||
|
|
||||||
st.caption(f"**Класс:** {label_name}")
|
|
||||||
st.caption(f"**Прогноз:** V: {v:.2f} | A: {a:.2f}")
|
|
||||||
# ------------------------------
|
|
||||||
|
|
||||||
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
|
if st.button(f"Выбрать образ {i+1}", key=f"btn_{idx}", use_container_width=True):
|
||||||
st.session_state.chosen_indices.append(idx)
|
st.session_state.chosen_indices.append(idx)
|
||||||
st.session_state.round += 1
|
st.session_state.round += 1
|
||||||
if st.session_state.round <= TOTAL_ROUNDS:
|
if st.session_state.round <= 10:
|
||||||
st.session_state.current_options = random.sample(range(len(image_files)), NUM_CHOICES)
|
st.session_state.current_options = random.sample(range(len(image_files)), 6)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
# ФИНАЛЬНЫЙ ЭТАП
|
# РЕЗУЛЬТАТЫ
|
||||||
st.success("✅ Профиль сформирован!")
|
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||||
|
|
||||||
# Считаем итог
|
|
||||||
chosen_embs = embeddings[st.session_state.chosen_indices]
|
|
||||||
all_v, all_a = [], []
|
all_v, all_a = [], []
|
||||||
for emb in chosen_embs:
|
for idx in st.session_state.chosen_indices:
|
||||||
v, a = matcher.predict_va(emb)
|
v, a = matcher.predict_va(embeddings[idx])
|
||||||
all_v.append(v)
|
all_v.append(v)
|
||||||
all_a.append(a)
|
all_a.append(a)
|
||||||
|
|
||||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||||
|
playlist = matcher.find_nearest_tracks(target_v, target_a, top_k=5)
|
||||||
|
|
||||||
# Вывод результатов
|
col_left, col_right = st.columns([1, 2])
|
||||||
col1, col2 = st.columns([2, 1])
|
|
||||||
|
|
||||||
with col1:
|
with col_left:
|
||||||
st.header("🎵 Ваш плейлист")
|
st.header("📊 Ваш профиль")
|
||||||
distances = np.sqrt((matcher.music_db['valence'] - target_v)**2 + (matcher.music_db['arousal'] - target_a)**2)
|
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||||
playlist = matcher.music_db.copy()
|
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||||
playlist['distance'] = distances
|
|
||||||
st.table(playlist.sort_values(by='distance').head(5)[['song_id', 'valence', 'arousal', 'distance']])
|
|
||||||
|
|
||||||
with col2:
|
# График Рассела
|
||||||
st.header("📊 Профиль")
|
fig, ax = plt.subplots(figsize=(4, 4))
|
||||||
st.metric("Valence (Итог)", f"{target_v:.2f}")
|
ax.set_xlim(1, 9); ax.set_ylim(1, 9)
|
||||||
st.metric("Arousal (Итог)", f"{target_a:.2f}")
|
ax.axhline(5, color='gray', lw=1, ls='--')
|
||||||
|
ax.axvline(5, color='gray', lw=1, ls='--')
|
||||||
|
ax.scatter(target_v, target_a, color='red', s=150, edgecolors='white', zorder=5)
|
||||||
|
ax.set_xlabel("Valence"); ax.set_ylabel("Arousal")
|
||||||
|
ax.set_title("Карта эмоций (Модель Рассела)")
|
||||||
|
st.pyplot(fig)
|
||||||
|
|
||||||
# Показываем, что именно выбрал пользователь
|
with col_right:
|
||||||
st.divider()
|
st.header("🎵 Рекомендованная музыка")
|
||||||
st.subheader("Выбранные вами образы и их веса:")
|
for _, row in playlist.iterrows():
|
||||||
sum_cols = st.columns(5)
|
with st.container(border=True):
|
||||||
for i, idx in enumerate(st.session_state.chosen_indices):
|
c1, c2 = st.columns([1, 3])
|
||||||
with sum_cols[i % 5]:
|
with c1:
|
||||||
v_i, a_i = matcher.predict_va(embeddings[idx])
|
st.write(f"**ID:** {int(row['song_id'])}")
|
||||||
st.image(Image.open(IMAGES_DIR / image_files[idx]), use_container_width=True)
|
st.caption(f"L2: {row['distance']:.2f}")
|
||||||
st.write(f"V:{v_i:.1f} A:{a_i:.1f}")
|
with c2:
|
||||||
|
audio_path = matcher.get_audio_path(row['song_id'])
|
||||||
|
if audio_path:
|
||||||
|
st.audio(str(audio_path))
|
||||||
|
else:
|
||||||
|
st.warning(f"Файл {int(row['song_id'])}.mp3 не найден")
|
||||||
|
|
||||||
if st.button("Начать заново"):
|
if st.button("Начать заново", type="primary"):
|
||||||
for key in list(st.session_state.keys()): del st.session_state[key]
|
for key in list(st.session_state.keys()): del st.session_state[key]
|
||||||
st.rerun()
|
st.rerun()
|
||||||
+39
-25
@@ -5,50 +5,64 @@ import joblib
|
|||||||
|
|
||||||
class MusicMatcher:
|
class MusicMatcher:
|
||||||
def __init__(self, db_path: Path | str, model_path: Path | str):
|
def __init__(self, db_path: Path | str, model_path: Path | str):
|
||||||
# 1. Загрузка базы музыки
|
"""
|
||||||
|
Инициализация движка сопоставления музыки.
|
||||||
|
"""
|
||||||
self.music_db = pd.read_csv(db_path)
|
self.music_db = pd.read_csv(db_path)
|
||||||
self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce')
|
self.music_db['valence'] = pd.to_numeric(self.music_db['valence'], errors='coerce')
|
||||||
self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce')
|
self.music_db['arousal'] = pd.to_numeric(self.music_db['arousal'], errors='coerce')
|
||||||
self.music_db = self.music_db.dropna()
|
self.music_db = self.music_db.dropna()
|
||||||
|
|
||||||
# 2. Загрузка обученного регрессора
|
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
|
||||||
|
|
||||||
|
if self.audio_dir.exists():
|
||||||
|
print(f"✅ Музыкальный архив найден: {self.audio_dir}")
|
||||||
|
else:
|
||||||
|
print(f"⚠️ ПРЕДУПРЕЖДЕНИЕ: Папка {self.audio_dir} не найдена!")
|
||||||
|
|
||||||
if Path(model_path).exists():
|
if Path(model_path).exists():
|
||||||
self.regressor = joblib.load(model_path)
|
self.regressor = joblib.load(model_path)
|
||||||
print("✅ Регрессионная модель успешно загружена.")
|
print("✅ ML-регрессор загружен.")
|
||||||
else:
|
else:
|
||||||
self.regressor = None
|
self.regressor = None
|
||||||
print("⚠️ ВНИМАНИЕ: Модель va_regressor.pkl не найдена!")
|
print("⚠️ Файл модели .pkl не найден.")
|
||||||
|
|
||||||
def predict_va(self, embedding: np.ndarray):
|
def predict_va(self, embedding: np.ndarray):
|
||||||
"""
|
"""Честный прогноз координат Valence-Arousal."""
|
||||||
Использование обученной ML-модели (Ridge) для маппинга
|
|
||||||
эмбеддингов в пространство Valence-Arousal.
|
|
||||||
"""
|
|
||||||
if self.regressor is not None:
|
if self.regressor is not None:
|
||||||
# Модель ожидает двумерный массив (batch_size, features)
|
|
||||||
emb_2d = embedding.reshape(1, -1)
|
emb_2d = embedding.reshape(1, -1)
|
||||||
prediction = self.regressor.predict(emb_2d)[0] # Получаем [Valence, Arousal]
|
prediction = self.regressor.predict(emb_2d)[0]
|
||||||
|
return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0)
|
||||||
|
return 5.0, 5.0
|
||||||
|
|
||||||
v, a = prediction[0], prediction[1]
|
def get_audio_path(self, song_id):
|
||||||
else:
|
"""Поиск mp3 файла по его номеру."""
|
||||||
# Fallback на случай, если файл модели потеряется
|
if not self.audio_dir.exists():
|
||||||
v, a = 5.0, 5.0
|
return None
|
||||||
|
|
||||||
return np.clip(v, 1.0, 9.0), np.clip(a, 1.0, 9.0)
|
clean_id = str(int(float(song_id)))
|
||||||
|
for ext in ['.mp3', '.wav']:
|
||||||
|
file_path = self.audio_dir / f"{clean_id}{ext}"
|
||||||
|
if file_path.exists():
|
||||||
|
return file_path
|
||||||
|
return None
|
||||||
|
|
||||||
def get_playlist(self, user_vector: np.ndarray, top_k: int = 5):
|
def find_nearest_tracks(self, target_v: float, target_a: float, top_k: int = 5):
|
||||||
# 1. Предсказываем координаты через ML-модель
|
"""
|
||||||
target_v, target_a = self.predict_va(user_vector)
|
Поиск с использованием Взвешенного Евклидова расстояния (Weighted KNN).
|
||||||
|
Энергия (Arousal) получает больший вес, так как она сильнее
|
||||||
# 2. Считаем Евклидово расстояние (L2-норма) до треков в базе
|
определяет жанр и ритм композиции.
|
||||||
|
"""
|
||||||
|
# Вес для Arousal = 2.0, для Valence = 1.0
|
||||||
|
# Это не позволит спокойным трекам (A < 4) попадать в выдачу
|
||||||
|
# для энергичных запросов (A > 6).
|
||||||
distances = np.sqrt(
|
distances = np.sqrt(
|
||||||
(self.music_db['valence'] - target_v)**2 +
|
1.0 * (self.music_db['valence'] - target_v)**2 +
|
||||||
(self.music_db['arousal'] - target_a)**2
|
2.5 * (self.music_db['arousal'] - target_a)**2 # Жесткий штраф за разницу в энергии
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Формируем финальную выдачу
|
|
||||||
df_result = self.music_db.copy()
|
df_result = self.music_db.copy()
|
||||||
df_result['distance'] = distances
|
df_result['distance'] = distances
|
||||||
playlist = df_result.sort_values(by='distance').head(top_k)
|
|
||||||
|
|
||||||
return target_v, target_a, playlist
|
# Сортируем по расстоянию и берем топ-K
|
||||||
|
return df_result.sort_values(by='distance').head(top_k)
|
||||||
+35
-93
@@ -2,8 +2,8 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"id": "83693ad7",
|
"id": "b92e0213",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -14,119 +14,61 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 7,
|
||||||
"id": "99850a99",
|
"id": "1763c51e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Читаем файл аннотаций: ../dataset/DEAM/DEAM_Annotations/annotations/annotations per each rater/song_level/static_annotations_songs_1_2000.csv\n"
|
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
|
||||||
|
"Всего треков в базе: 1744\n",
|
||||||
|
"Пример данных:\n",
|
||||||
|
" song_id valence arousal\n",
|
||||||
|
"0 2 3.1 3.0\n",
|
||||||
|
"1 3 3.5 3.3\n",
|
||||||
|
"2 4 5.7 5.5\n",
|
||||||
|
"3 5 4.4 5.3\n",
|
||||||
|
"4 7 5.8 6.4\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# 1. Ищем файл (поднимаемся из src на уровень выше)\n",
|
"# Точный путь к оригинальным аннотациям\n",
|
||||||
"deam_root = Path(\"../dataset/DEAM\")\n",
|
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
|
||||||
|
"# Путь, куда сохраним очищенную базу для движка\n",
|
||||||
|
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Ищем файл статичных аннотаций. Берем первый попавшийся.\n",
|
"if not source_path.exists():\n",
|
||||||
"csv_files = list(deam_root.rglob(\"*static_annotations*.csv\"))\n",
|
" print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n",
|
||||||
"if not csv_files:\n",
|
"else:\n",
|
||||||
" # Если не нашел static, берем вообще любой csv с аннотациями\n",
|
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
|
||||||
" csv_files = list(deam_root.rglob(\"*.csv\"))\n",
|
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
|
||||||
" \n",
|
" \n",
|
||||||
"if not csv_files:\n",
|
" # Берем только нужные колонки (по твоему примеру)\n",
|
||||||
" # Если путь неверный или файлов нет, скрипт сразу скажет об этом и покажет полный путь\n",
|
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
|
||||||
" raise FileNotFoundError(f\"В папке {deam_root.resolve()} не найдено ни одного CSV файла! Проверьте пути.\")\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
"anno_path = csv_files[0]\n",
|
" # Переименовываем для простоты кода в движке\n",
|
||||||
"print(f\"Читаем файл аннотаций: {anno_path}\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "5fbc493f",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Оригинальные колонки в файле: ['workerID', ' SongId', ' Valence', ' Arousal']\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# 2. Загружаем и чистим колонки\n",
|
|
||||||
"df = pd.read_csv(anno_path)\n",
|
|
||||||
"print(\"Оригинальные колонки в файле:\", df.columns.tolist())\n",
|
|
||||||
"\n",
|
|
||||||
"# Сносим пробелы по краям и переводим в нижний регистр\n",
|
|
||||||
"df.columns = [str(c).strip().lower() for c in df.columns]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "1e28fece",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Успешно найдены колонки -> ID: 'workerid', Valence: 'valence', Arousal: 'arousal'\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# 3. Умный поиск колонок\n",
|
|
||||||
"# Ищем первую колонку, где есть 'id' или 'song'\n",
|
|
||||||
"song_col = next((c for c in df.columns if 'song' in c or 'id' in c), df.columns[0])\n",
|
|
||||||
"# Ищем valence (желательно mean, но сойдет любой)\n",
|
|
||||||
"v_col = next((c for c in df.columns if 'valence' in c and 'mean' in c), \n",
|
|
||||||
" next((c for c in df.columns if 'valence' in c), None))\n",
|
|
||||||
"# Ищем arousal\n",
|
|
||||||
"a_col = next((c for c in df.columns if 'arousal' in c and 'mean' in c), \n",
|
|
||||||
" next((c for c in df.columns if 'arousal' in c), None))\n",
|
|
||||||
"\n",
|
|
||||||
"if not v_col or not a_col:\n",
|
|
||||||
" raise ValueError(f\"Не смог найти Valence или Arousal! Доступные колонки: {df.columns.tolist()}\")\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"Успешно найдены колонки -> ID: '{song_col}', Valence: '{v_col}', Arousal: '{a_col}'\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "469f651c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Готово! Музыкальная база сохранена: ../dataset/DEAM/music_db.csv\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# 4. Сохраняем результат\n",
|
|
||||||
"clean_df = df[[song_col, v_col, a_col]].copy()\n",
|
|
||||||
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
|
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
|
||||||
" \n",
|
" \n",
|
||||||
"output_path = deam_root / \"music_db.csv\"\n",
|
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
|
||||||
|
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
|
||||||
|
" \n",
|
||||||
|
" # Сохраняем финальный файл\n",
|
||||||
" clean_df.to_csv(output_path, index=False)\n",
|
" clean_df.to_csv(output_path, index=False)\n",
|
||||||
"print(f\"Готово! Музыкальная база сохранена: {output_path}\")"
|
" \n",
|
||||||
|
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
|
||||||
|
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
|
||||||
|
" print(\"Пример данных:\")\n",
|
||||||
|
" print(clean_df.head())"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python (my-python-project)",
|
"display_name": "Python (thesis)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "my-python-project"
|
"name": "thesis"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|||||||
Reference in New Issue
Block a user