feat: finale

This commit is contained in:
zin
2026-06-03 09:16:12 +00:00
parent 3850b15053
commit a57addcbb1
9 changed files with 807 additions and 176 deletions
+55 -23
View File
@@ -1,18 +1,20 @@
import io
import os
import traceback
import numpy as np
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
# Импортируем твои существующие загрузчики (они теперь работают только на бэкенде)
from data_loader import load_music_engine, load_image_processor
from music_engine.llm_bridge import LLMAcousticBridge
app = FastAPI(title="EmoM Inference API", version="1.0.0")
# Глобальный кэш для удержания моделей в памяти
ml_context = {
"image_processor": None,
"music_matcher": None
"music_matcher": None,
"llm_bridge": None
}
@app.on_event("startup")
@@ -20,34 +22,64 @@ async def startup_event():
print("Инициализация нейросетевого ядра EmoM...")
ml_context["image_processor"] = load_image_processor()
ml_context["music_matcher"] = load_music_engine()
if not ml_context["image_processor"] or not ml_context["music_matcher"]:
raise RuntimeError("Отказ системы: Артефакты моделей не найдены.")
ml_context["llm_bridge"] = LLMAcousticBridge()
print("Вычислительный конвейер готов к работе.")
@app.post("/analyze")
async def analyze_image_endpoint(file: UploadFile = File(...)):
"""
Принимает изображение, прогоняет через ResNet и возвращает треки из DEAM.
"""
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
try:
# 1. Чтение бинарных данных из запроса
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# 1. Читаем все загруженные картинки
images = []
for file in files:
image_bytes = await file.read()
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
images.append(img)
print(f"Начата обработка события из {len(images)} фотографий...")
# 2. Инференс (ВНИМАНИЕ: здесь используй реальные названия методов из своих классов!)
# Предположим, твой процессор выдает координаты V/A
v_a_coords = ml_context["image_processor"].extract_va(image)
img_processor = ml_context["image_processor"]
matcher = ml_context["music_matcher"]
llm = ml_context["llm_bridge"]
all_v, all_a = [], []
all_objects = []
# 2. Прогоняем каждую картинку через нейросети
for img in images:
embedding = img_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
caption = img_processor.describe_scene(img)
all_objects.append(caption)
# 3. Усредняем эмоции события
target_v = float(np.mean(all_v))
target_a = float(np.mean(all_a))
unique_semantics = list(set(all_objects))
# 4. Запрашиваем акустический профиль у Ollama
print(f"Запрос к Ollama. V={target_v:.2f}, A={target_a:.2f}")
llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
# 5. Ищем треки в базе
print("Поиск подходящих композиций...")
playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
# 3. Поиск треков в базе
matched_tracks = ml_context["music_matcher"].find_tracks(v_a_coords)
# 4. Формирование ответа
# Переводим таблицу в JSON-формат
tracks_list = playlist_df.to_dict(orient="records")
return JSONResponse(content={
"status": "success",
"valence_arousal": v_a_coords,
"tracks": matched_tracks
"images_processed": len(images),
"target_v": target_v,
"target_a": target_a,
"llm_profile": llm_profile,
"semantics": unique_semantics,
"tracks": tracks_list
})
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Ошибка инференса: {str(e)}")
+31 -37
View File
@@ -1,55 +1,49 @@
import os
from pathlib import Path
import pandas as pd
import numpy as np
import streamlit as st
# Импорты твоих движков
from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor
# Базовая директория (папка src)
BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource
def load_music_engine():
# Инициализация базы данных и регрессора для музыкального мэтчинга
"""Загрузка базы данных и модели регрессора для бэкенда."""
# Пути соответствуют тем, что мы примонтировали в Docker
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"Музыкальная БД не найдена: {db_path}")
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource
def load_image_processor():
# Модуль обработки визуальных признаков
model_path = BASE_DIR / "emoset_resnet50_best.pth"
# Обработка пути при вызове из корневой директории
if not model_path.exists():
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(model_path=model_path)
"""Инициализация нейросетевого экстрактора (ResNet-50)."""
weights_path = BASE_DIR / "emoset_resnet50_best.pth"
return ImageProcessor(weights_path)
@st.cache_data
def load_emoset_data():
# Выборка данных датасета для вкладки отладки
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
csv_path = dataset_root / "labels.csv"
img_dir = dataset_root / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
return None, None, None, None
labels_df = pd.read_csv(csv_path)
test_filenames = labels_df['filename'].tolist()
test_embeddings = np.load(emb_path)
test_labels = np.load(lbl_path)
return test_filenames, test_embeddings, test_labels, img_dir
"""
Загрузка эталонного датасета EmoSet.
(Оставлено для обратной совместимости, если понадобится локальная отладка)
"""
try:
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
labels_path = BASE_DIR / "emoset_test_labels.npy"
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
# Если файлов нет (например, на проде), возвращаем None
if not all(p.exists() for p in [labels_path, embeddings_path]):
return None, None, None, None
labels = np.load(labels_path)
embeddings = np.load(embeddings_path)
# Читаем CSV с метками
df = pd.read_csv(BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv")
image_files = df['filename'].tolist()
return image_files, embeddings, labels, images_path
except Exception as e:
print(f"Предупреждение: Тестовые артефакты EmoSet не найдены ({e})")
return None, None, None, None
+178 -50
View File
@@ -1,62 +1,190 @@
import os
import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import base64
from io import BytesIO
# Конфигурация UI
st.set_page_config(
page_title="EmoM | EmotionMusic",
layout="wide",
initial_sidebar_state="collapsed"
)
st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed")
st.markdown(
"""
<style>
img { max-width: 100%; height: auto; object-fit: contain; border-radius: 4px; }
[data-testid="stMetricValue"] { font-size: 1.8rem; font-weight: 600; }
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
""",
unsafe_allow_html=True
)
API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze"
DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
# Маршрутизация к нашему новому микросервису (берется из .env, либо локалхост)
API_URL = os.getenv("BACKEND_API_URL", "http://localhost:8000") + "/analyze"
def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100))
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
if len(images) > max_display:
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
def main():
st.title("Система генерации саундтреков (EmoM)")
st.caption("Микросервисная архитектура: Frontend (Streamlit) -> REST API -> PyTorch/DEAM")
uploaded_file = st.file_uploader("Загрузите изображение для анализа", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
st.image(uploaded_file, caption="Входной визуальный контент")
if "live_state" not in st.session_state:
st.session_state.live_state = "upload"
if "result_data" not in st.session_state:
st.session_state.result_data = None
if st.button("Анализировать"):
with st.spinner("Отправка данных в вычислительный кластер..."):
try:
# Отправляем POST-запрос в наш FastAPI микросервис
files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
response = requests.post(API_URL, files=files, timeout=30)
if response.status_code == 200:
data = response.json()
st.success("Анализ успешно завершен!")
# Вывод результатов
st.subheader("Результаты анализа")
st.write(f"Координаты Valence/Arousal: {data.get('valence_arousal')}")
st.write("Подобранные треки:")
st.json(data.get('tracks'))
# Здесь в будущем можно добавить обращение к Ollama для генерации красивого описания
else:
st.error(f"Ошибка сервера: {response.text}")
except requests.exceptions.ConnectionError:
st.error("Ошибка сети: Микросервис инференса недоступен. Проверьте статус Docker-контейнера emom_inference.")
viewport = st.query_params.get("viewport", "desktop")
st.markdown("""
<style>
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
if st.session_state.live_state == "upload":
upload_placeholder = st.empty()
with upload_placeholder.container():
st.write("Загрузите изображения для визуально-семантического анализа.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
"Загрузка файлов",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
)
if uploaded_files:
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Выполнить анализ", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные файлы:")
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
try:
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
response = requests.post(API_URL, files=upload_data, timeout=300)
if response.status_code == 200:
st.session_state.result_data = response.json()
st.session_state.live_state = "result"
st.rerun()
else:
st.error(f"Ошибка сервера: {response.status_code}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
except Exception as e:
st.error(f"Ошибка соединения: {str(e)}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
elif st.session_state.live_state == "result":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
data = st.session_state.result_data
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
for row in data.get("tracks", []):
with st.container(border=True):
song_id = int(row['song_id'])
score = row['final_score']
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
if not os.path.exists(audio_path):
audio_path = audio_path.replace('.mp3', '.wav')
if viewport == "desktop":
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**Track ID:** {song_id}")
st.caption(f"Score: {score:.4f}")
with c2:
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
else:
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Отладочная информация (Метрики)"):
st.subheader("Координаты V/A")
c_v, c_a = st.columns(2)
c_v.metric("Valence", f"{data['target_v']:.2f}")
c_a.metric("Arousal", f"{data['target_a']:.2f}")
st.markdown("---")
st.subheader("Акустические признаки (LLM)")
feature_titles = {
"energy": "RMS Energy",
"flux": "Spectral Flux",
"centroid": "Spectral Centroid",
"pitch": "F0 (Pitch)",
"hnr": "HNR",
"zcr": "ZCR"
}
# Развернутые описания для комиссии (передаются в аргумент help)
feature_helps = {
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
}
llm_profile = data.get("llm_profile")
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(llm_profile.items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
label = feature_titles.get(k, k)
tooltip = feature_helps.get(k, "")
# Форматируем до 2 знаков после запятой (например, 0.64)
cols[j].metric(label, f"{v:.2f}", help=tooltip)
else:
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
st.markdown("---")
st.write("**Извлеченные теги (BLIP-2):**")
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый запрос", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
if __name__ == "__main__":
main()
+5 -1
View File
@@ -32,7 +32,11 @@ class ImageProcessor:
# Модуль семантического описания сцены
print("Инициализация BLIP-2...")
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# Обход бага конфигурации Hugging Face (ручная сборка процессора)
from transformers import BlipImageProcessor, AutoTokenizer
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok)
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16
+56 -54
View File
@@ -1,65 +1,67 @@
import re
import os
import json
import re
import requests
class LLMAcousticBridge:
def __init__(self, target_model="dolphin-llama3:8b"):
self.api_url = "http://localhost:11434/api/generate"
self.model = target_model
def __init__(self, model_name="dolphin-llama3:8b"):
self.model_name = model_name
# Динамический выбор URL (внутри Docker используется emom_ollama)
base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
self.api_url = f"{base_url}/api/generate"
def _extract_json(self, raw_text: str):
# Проверка на ИИдиота, LLM иногда игнорирует format="json" и оборачивает ответ в маркдаун
try:
match = re.search(r'\{.*\}', raw_text, re.DOTALL)
if match:
return json.loads(match.group(0))
return json.loads(raw_text)
except json.JSONDecodeError:
# Если ИИдиот
return None
def get_acoustic_profile(self, v_score: float, a_score: float, scene_context: list) -> dict | None:
# Агрегация контекста для обработки серии снимков (события)
context_merged = " | ".join(scene_context) if scene_context else "abstract scene"
def get_acoustic_profile(self, valence, arousal, semantics):
context_str = ", ".join(semantics) if semantics else "abstract scene"
# Строгий промпт с примером вывода
prompt = f"""
Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
Visual Context: {context_str}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density)
2. "flux": (Rhythmic sharpness/Beat)
3. "centroid": (Brightness)
4. "pitch": (Fundamental frequency)
5. "hnr": (Harmonics-to-Noise)
6. "zcr": (Percussiveness)
Return ONLY a valid JSON object. No explanations, no markdown blocks.
Example: {{"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}}
"""
system_prompt = f"""You are an expert music producer and acoustic engineer.
Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {v_score:.1f}/9.0 (Positivity), Arousal {a_score:.1f}/9.0 (Energy).
Visual Context: {context_merged}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
2. "flux": (Rhythmic sharpness/Beat. High for action/people/cars, Low for static nature)
3. "centroid": (Brightness: 0=Dark/Bass/Massive, 1=Bright/Treble/Light)
4. "pitch": (Fundamental frequency: 0=Low pitch/Huge objects, 1=High pitch/Small objects)
5. "hnr": (Harmonics-to-Noise: 0=Noisy/Distorted textures, 1=Clear/Melodic/Smooth textures)
6. "zcr": (Percussiveness. High for detailed noise like leaves/rain, Low for solid blocks)
Return ONLY a valid JSON object. Do not add any text or explanation.
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
try:
# Отправка промпта локальной Ollama
response = requests.post(self.api_url, json={
"model": self.model,
"prompt": system_prompt,
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False,
"format": "json"
}, timeout=45)
response.raise_for_status()
"format": "json" # Принудительный JSON-режим Ollama
}
raw_response = response.json().get("response", "")
profile_data = self._extract_json(raw_response)
print(f"Запрос акустического профиля к Ollama...")
response = requests.post(self.api_url, json=payload, timeout=120)
# Валидация структуры ответа
expected_features = {'energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'}
if profile_data and expected_features.issubset(profile_data.keys()):
return profile_data
if response.status_code == 200:
data = response.json()
response_text = data.get("response", "")
print("LLM вернула неполный или некорректный набор акустических признаков")
return None
except requests.exceptions.RequestException as req_err:
print(f"Не удалось подключиться к Ollama: {req_err}")
return None
try:
# 1. Попытка прямой десериализации
profile = json.loads(response_text)
return profile
except json.JSONDecodeError:
# 2. Аварийное извлечение JSON из текста с помощью регулярного выражения
match = re.search(r'\{.*\}', response_text, re.DOTALL)
if match:
return json.loads(match.group(0))
print(f"Ошибка парсинга LLM ответа: {response_text}")
return {}
else:
print(f"Ollama вернула ошибку HTTP: {response.status_code}")
return {}
except Exception as e:
print(f"Ошибка соединения с Ollama: {str(e)}")
return {}