feat: refactor code and finetune OOM fix
This commit is contained in:
+5
-14
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from data_loader import load_music_engine, load_image_processor
|
||||
from music_engine.llm_bridge import LLMAcousticBridge
|
||||
|
||||
app = FastAPI(title="EmoM Inference API", version="1.0.0")
|
||||
app = FastAPI(title="EmoM API", version="1.0.0")
|
||||
|
||||
ml_context = {
|
||||
"image_processor": None,
|
||||
@@ -19,23 +19,22 @@ ml_context = {
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
print("Инициализация нейросетевого ядра EmoM...")
|
||||
print("Loading ML models...")
|
||||
ml_context["image_processor"] = load_image_processor()
|
||||
ml_context["music_matcher"] = load_music_engine()
|
||||
ml_context["llm_bridge"] = LLMAcousticBridge()
|
||||
print("Вычислительный конвейер готов к работе.")
|
||||
print("Initialization complete.")
|
||||
|
||||
@app.post("/analyze")
|
||||
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
# 1. Читаем все загруженные картинки
|
||||
images = []
|
||||
for file in files:
|
||||
image_bytes = await file.read()
|
||||
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||
images.append(img)
|
||||
|
||||
print(f"Начата обработка события из {len(images)} фотографий...")
|
||||
print(f"Processing batch: {len(images)} images.")
|
||||
|
||||
img_processor = ml_context["image_processor"]
|
||||
matcher = ml_context["music_matcher"]
|
||||
@@ -44,7 +43,6 @@ async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
|
||||
all_v, all_a = [], []
|
||||
all_objects = []
|
||||
|
||||
# 2. Прогоняем каждую картинку через нейросети
|
||||
for img in images:
|
||||
embedding = img_processor.extract_embedding(img)
|
||||
v, a = matcher.predict_va(embedding)
|
||||
@@ -54,20 +52,13 @@ async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
|
||||
caption = img_processor.describe_scene(img)
|
||||
all_objects.append(caption)
|
||||
|
||||
# 3. Усредняем эмоции события
|
||||
target_v = float(np.mean(all_v))
|
||||
target_a = float(np.mean(all_a))
|
||||
unique_semantics = list(set(all_objects))
|
||||
|
||||
# 4. Запрашиваем акустический профиль у Ollama
|
||||
print(f"Запрос к Ollama. V={target_v:.2f}, A={target_a:.2f}")
|
||||
llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
|
||||
|
||||
# 5. Ищем треки в базе
|
||||
print("Поиск подходящих композиций...")
|
||||
playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
|
||||
|
||||
# Переводим таблицу в JSON-формат
|
||||
tracks_list = playlist_df.to_dict(orient="records")
|
||||
|
||||
return JSONResponse(content={
|
||||
@@ -82,4 +73,4 @@ async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
|
||||
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=f"Ошибка инференса: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
+15
-18
@@ -1,49 +1,46 @@
|
||||
from pathlib import Path
|
||||
from typing import Tuple, List, Optional, Any
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Импорты твоих движков
|
||||
from music_engine.matcher import MusicMatcher
|
||||
from music_engine.image_processor import ImageProcessor
|
||||
|
||||
# Базовая директория (папка src)
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
def load_music_engine():
|
||||
"""Загрузка базы данных и модели регрессора для бэкенда."""
|
||||
# Пути соответствуют тем, что мы примонтировали в Docker
|
||||
def load_music_engine() -> MusicMatcher:
|
||||
#Инициализация модуля подбора музыкальных композиций.
|
||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||
|
||||
def load_image_processor():
|
||||
"""Инициализация нейросетевого экстрактора (ResNet-50)."""
|
||||
def load_image_processor() -> ImageProcessor:
|
||||
#Инициализация модуля экстракции визуальных признаков.
|
||||
weights_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
|
||||
return ImageProcessor(weights_path)
|
||||
|
||||
def load_emoset_data():
|
||||
"""
|
||||
Загрузка эталонного датасета EmoSet.
|
||||
(Оставлено для обратной совместимости, если понадобится локальная отладка)
|
||||
"""
|
||||
def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]:
|
||||
# Загрузка тестовой выборки датасета EmoSet.
|
||||
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
|
||||
try:
|
||||
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
||||
labels_path = BASE_DIR / "emoset_test_labels.npy"
|
||||
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
|
||||
# Если файлов нет (например, на проде), возвращаем None
|
||||
if not all(p.exists() for p in [labels_path, embeddings_path]):
|
||||
return None, None, None, None
|
||||
|
||||
labels = np.load(labels_path)
|
||||
embeddings = np.load(embeddings_path)
|
||||
|
||||
# Читаем CSV с метками
|
||||
df = pd.read_csv(BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv")
|
||||
image_files = df['filename'].tolist()
|
||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
return df['filename'].tolist(), embeddings, labels, images_path
|
||||
|
||||
return image_files, embeddings, labels, images_path
|
||||
except Exception as e:
|
||||
print(f"Предупреждение: Тестовые артефакты EmoSet не найдены ({e})")
|
||||
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
|
||||
return None, None, None, None
|
||||
+1
-2
@@ -147,7 +147,7 @@ def main():
|
||||
"zcr": "ZCR"
|
||||
}
|
||||
|
||||
# Развернутые описания для комиссии (передаются в аргумент help)
|
||||
# Развернутые описания
|
||||
feature_helps = {
|
||||
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
|
||||
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
|
||||
@@ -169,7 +169,6 @@ def main():
|
||||
k, v = llm_items[i + j]
|
||||
label = feature_titles.get(k, k)
|
||||
tooltip = feature_helps.get(k, "")
|
||||
# Форматируем до 2 знаков после запятой (например, 0.64)
|
||||
cols[j].metric(label, f"{v:.2f}", help=tooltip)
|
||||
else:
|
||||
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
|
||||
|
||||
@@ -6,14 +6,12 @@ import requests
|
||||
class LLMAcousticBridge:
|
||||
def __init__(self, model_name="dolphin-llama3:8b"):
|
||||
self.model_name = model_name
|
||||
# Динамический выбор URL (внутри Docker используется emom_ollama)
|
||||
base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
|
||||
self.api_url = f"{base_url}/api/generate"
|
||||
|
||||
def get_acoustic_profile(self, valence, arousal, semantics):
|
||||
context_str = ", ".join(semantics) if semantics else "abstract scene"
|
||||
|
||||
# Строгий промпт с примером вывода
|
||||
prompt = f"""
|
||||
Analyze the visual context and emotions to determine the ideal background music properties.
|
||||
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
|
||||
|
||||
@@ -110,14 +110,14 @@ if __name__ == "__main__":
|
||||
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
|
||||
batch_size=BATCH_SIZE, shuffle=True,
|
||||
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
||||
prefetch_factor=2, persistent_workers=True
|
||||
prefetch_factor=2, persistent_workers=False
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
|
||||
batch_size=BATCH_SIZE, shuffle=False,
|
||||
num_workers=NUM_VAL_WORKERS, pin_memory=True,
|
||||
prefetch_factor=2, persistent_workers=True
|
||||
prefetch_factor=2, persistent_workers=False
|
||||
)
|
||||
|
||||
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
|
||||
|
||||
Reference in New Issue
Block a user