ref: refactor before chekout
This commit is contained in:
+20
-19
@@ -1,54 +1,55 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
|
||||
from music_engine.matcher import MusicMatcher
|
||||
from music_engine.image_processor import ImageProcessor
|
||||
|
||||
# Определяем базовую директорию (папка src)
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
@st.cache_resource
|
||||
def load_music_engine():
|
||||
"""Загрузка базы данных и модели регрессора."""
|
||||
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
|
||||
# Инициализация базы данных и регрессора для музыкального мэтчинга
|
||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
# va_regressor.pkl лежит в src/music_engine/
|
||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not db_path.exists():
|
||||
print(f"⚠️ Файл базы {db_path} не найден!")
|
||||
print(f"Музыкальная БД не найдена: {db_path}")
|
||||
return None
|
||||
|
||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||
|
||||
@st.cache_resource
|
||||
def load_image_processor():
|
||||
"""Загрузка ResNet-50 для извлечения признаков на лету."""
|
||||
# Файл весов лежит в той же папке src, что и этот скрипт
|
||||
# Модуль обработки визуальных признаков
|
||||
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
|
||||
# Обработка пути при вызове из корневой директории
|
||||
if not model_path.exists():
|
||||
print(f"Ошибка: Веса не найдены по пути: {model_path}")
|
||||
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
|
||||
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||
|
||||
return ImageProcessor(model_path=model_path)
|
||||
|
||||
@st.cache_data
|
||||
def load_emoset_data():
|
||||
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
|
||||
# Пути относительно корня проекта
|
||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
||||
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
||||
# Выборка данных датасета для вкладки отладки
|
||||
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
|
||||
|
||||
csv_path = dataset_root / "labels.csv"
|
||||
img_dir = dataset_root / "images"
|
||||
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
||||
|
||||
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
|
||||
return None, None, None, None
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
image_list = df['filename'].tolist()
|
||||
embs = np.load(emb_path)
|
||||
lbls = np.load(lbl_path)
|
||||
labels_df = pd.read_csv(csv_path)
|
||||
|
||||
return image_list, embs, lbls, img_dir
|
||||
test_filenames = labels_df['filename'].tolist()
|
||||
test_embeddings = np.load(emb_path)
|
||||
test_labels = np.load(lbl_path)
|
||||
|
||||
return test_filenames, test_embeddings, test_labels, img_dir
|
||||
Reference in New Issue
Block a user