Beta V.1.1

This commit is contained in:
zin
2026-05-06 21:11:15 +00:00
parent 5290554d70
commit 9954603043
4 changed files with 156 additions and 14 deletions
+28 -7
View File
@@ -3,24 +3,45 @@ from pathlib import Path
import pandas as pd
import numpy as np
from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor
# Определяем базовую директорию (папка src)
BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource
def load_music_engine():
"""Загрузка базы данных и модели регрессора."""
base_dir = Path(__file__).resolve().parent
db_path = base_dir.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = base_dir / "music_engine" / "va_regressor.pkl"
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
# va_regressor.pkl лежит в src/music_engine/
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"⚠️ Файл базы {db_path} не найден!")
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource
def load_image_processor():
"""Загрузка ResNet-50 для извлечения признаков на лету."""
# Файл весов лежит в той же папке src, что и этот скрипт
model_path = BASE_DIR / "emoset_resnet50_best.pth"
if not model_path.exists():
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(model_path=model_path)
@st.cache_data
def load_emoset_data():
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
csv_path = Path("./dataset/EmoSet-118K/test/labels.csv")
img_dir = Path("./dataset/EmoSet-118K/test/images")
emb_path = Path("./src/emoset_test_embeddings.npy")
lbl_path = Path("./src/emoset_test_labels.npy")
# Пути относительно корня проекта
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
return None, None, None, None