Compare commits

..

1 Commits

Author SHA1 Message Date
zin c32a2544ff Release 2026-05-06 22:31:52 +00:00
47 changed files with 3112 additions and 3125 deletions
-18
View File
@@ -1,18 +0,0 @@
bin/
lib/
share/
etc/
include/
pyvenv.cfg
.idea/
.vscode/
__pycache__/
*.pyc
.git/
runs/
dataset/
NFS/
*.pth
*.pkl
*.npy
.env
-25
View File
@@ -1,25 +0,0 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
WORKDIR /app
# System dependencies for OpenCV and image processing
RUN apt-get update && apt-get install -y \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
# Install python packages
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy source code
COPY src/ /app/src/
EXPOSE 8080
CMD ["streamlit", "run", "src/main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
-20
View File
@@ -1,20 +0,0 @@
.PHONY: up down logs restart status
# Сборка и запуск контейнеров в фоновом режиме
up:
docker compose up --build -d
# Остановка и удаление контейнеров
down:
docker compose down
# Просмотр логов в реальном времени
logs:
docker compose logs -f
# Быстрый перезапуск
restart: down up
# Проверка статуса
status:
docker compose ps
-63
View File
@@ -1,63 +0,0 @@
version: '3.8'
networks:
emom_mesh:
driver: bridge
services:
emom_ui:
build:
context: .
dockerfile: docker/Dockerfile.ui
container_name: emom_web_ui
restart: unless-stopped
ports:
- "8080:8080"
networks:
- emom_mesh
env_file:
- .env
volumes:
- ./src:/app/src
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
depends_on:
- emom_inference
emom_inference:
build:
context: .
dockerfile: docker/Dockerfile.api
container_name: emom_pytorch_api
restart: unless-stopped
networks:
- emom_mesh
env_file:
- .env
volumes:
- ${HOST_ARTIFACTS_DIR}/emoset_resnet50_best.pth:/app/src/emoset_resnet50_best.pth:ro
- ${HOST_ARTIFACTS_DIR}/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
- ~/.cache/huggingface:/root/.cache/huggingface
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
emom_ollama:
image: ollama/ollama:latest
container_name: emom_ollama_engine
restart: unless-stopped
networks:
- emom_mesh
volumes:
- ~/.ollama:/root/.ollama
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
-19
View File
@@ -1,19 +0,0 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
RUN apt-get update && apt-get install -y \
libglib2.0-0 libsm6 libxext6 libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir fastapi uvicorn timm scikit-learn pandas joblib python-multipart transformers==4.38.2 tokenizers==0.15.2 accelerate
WORKDIR /app
COPY src/ /app/src/
WORKDIR /app/src
EXPOSE 8000
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
-15
View File
@@ -1,15 +0,0 @@
FROM python:3.12-slim
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
WORKDIR /app
RUN pip install --no-cache-dir streamlit==1.32.0 requests pandas pillow
COPY src/ /app/src/
WORKDIR /app/src
EXPOSE 8080
CMD ["streamlit", "run", "main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
-9
View File
@@ -1,9 +0,0 @@
streamlit==1.32.0
torch==2.2.1
torchvision==0.17.1
timm==0.9.16
pandas==2.2.1
scikit-learn==1.4.1.post1
joblib==1.3.2
transformers==4.38.2
requests==2.31.0
-76
View File
@@ -1,76 +0,0 @@
import io
import traceback
import numpy as np
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from data_loader import load_music_engine, load_image_processor
from music_engine.llm_bridge import LLMAcousticBridge
app = FastAPI(title="EmoM API", version="1.0.0")
ml_context = {
"image_processor": None,
"music_matcher": None,
"llm_bridge": None
}
@app.on_event("startup")
async def startup_event():
print("Loading ML models...")
ml_context["image_processor"] = load_image_processor()
ml_context["music_matcher"] = load_music_engine()
ml_context["llm_bridge"] = LLMAcousticBridge()
print("Initialization complete.")
@app.post("/analyze")
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
try:
images = []
for file in files:
image_bytes = await file.read()
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
images.append(img)
print(f"Processing batch: {len(images)} images.")
img_processor = ml_context["image_processor"]
matcher = ml_context["music_matcher"]
llm = ml_context["llm_bridge"]
all_v, all_a = [], []
all_objects = []
for img in images:
embedding = img_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
caption = img_processor.describe_scene(img)
all_objects.append(caption)
target_v = float(np.mean(all_v))
target_a = float(np.mean(all_a))
unique_semantics = list(set(all_objects))
llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
tracks_list = playlist_df.to_dict(orient="records")
return JSONResponse(content={
"status": "success",
"images_processed": len(images),
"target_v": target_v,
"target_a": target_a,
"llm_profile": llm_profile,
"semantics": unique_semantics,
"tracks": tracks_list
})
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
+36 -28
View File
@@ -1,46 +1,54 @@
import streamlit as st
from pathlib import Path from pathlib import Path
from typing import Tuple, List, Optional, Any
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from music_engine.matcher import MusicMatcher from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor from music_engine.image_processor import ImageProcessor
# Определяем базовую директорию (папка src)
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
def load_music_engine() -> MusicMatcher: @st.cache_resource
#Инициализация модуля подбора музыкальных композиций. def load_music_engine():
"""Загрузка базы данных и модели регрессора."""
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv" db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
# va_regressor.pkl лежит в src/music_engine/
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl" model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"⚠️ Файл базы {db_path} не найден!")
return None
return MusicMatcher(db_path=db_path, model_path=model_path) return MusicMatcher(db_path=db_path, model_path=model_path)
def load_image_processor() -> ImageProcessor: @st.cache_resource
#Инициализация модуля экстракции визуальных признаков. def load_image_processor():
weights_path = BASE_DIR / "emoset_resnet50_best.pth" """Загрузка ResNet-50 для извлечения признаков на лету."""
# Файл весов лежит в той же папке src, что и этот скрипт
model_path = BASE_DIR / "emoset_resnet50_best.pth"
return ImageProcessor(weights_path) if not model_path.exists():
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]: return ImageProcessor(model_path=model_path)
# Загрузка тестовой выборки датасета EmoSet.
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
try:
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
labels_path = BASE_DIR / "emoset_test_labels.npy"
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
if not all(p.exists() for p in [labels_path, embeddings_path]): @st.cache_data
return None, None, None, None def load_emoset_data():
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
# Пути относительно корня проекта
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
labels = np.load(labels_path) if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
embeddings = np.load(embeddings_path)
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
df = pd.read_csv(csv_path)
return df['filename'].tolist(), embeddings, labels, images_path
except Exception as e:
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
return None, None, None, None return None, None, None, None
df = pd.read_csv(csv_path)
image_list = df['filename'].tolist()
embs = np.load(emb_path)
lbls = np.load(lbl_path)
return image_list, embs, lbls, img_dir
Binary file not shown.
+39 -185
View File
@@ -1,189 +1,43 @@
import os
import requests
import streamlit as st import streamlit as st
import streamlit.components.v1 as components import sys
from PIL import Image import os
import base64 import subprocess
from io import BytesIO
st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed") from data_loader import load_music_engine, load_emoset_data, load_image_processor
from tabs.tab_dataset import render_dataset_tab
API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze" from tabs.tab_live import render_live_tab
DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100))
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
if len(images) > max_display:
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
def main():
if "live_state" not in st.session_state:
st.session_state.live_state = "upload"
if "result_data" not in st.session_state:
st.session_state.result_data = None
viewport = st.query_params.get("viewport", "desktop")
st.markdown("""
<style>
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
if st.session_state.live_state == "upload":
upload_placeholder = st.empty()
with upload_placeholder.container():
st.write("Загрузите изображения для визуально-семантического анализа.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
"Загрузка файлов",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
)
if uploaded_files:
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Выполнить анализ", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные файлы:")
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
try:
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
response = requests.post(API_URL, files=upload_data, timeout=300)
if response.status_code == 200:
st.session_state.result_data = response.json()
st.session_state.live_state = "result"
st.rerun()
else:
st.error(f"Ошибка сервера: {response.status_code}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
except Exception as e:
st.error(f"Ошибка соединения: {str(e)}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
elif st.session_state.live_state == "result":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
data = st.session_state.result_data
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
for row in data.get("tracks", []):
with st.container(border=True):
song_id = int(row['song_id'])
score = row['final_score']
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
if not os.path.exists(audio_path):
audio_path = audio_path.replace('.mp3', '.wav')
if viewport == "desktop":
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**Track ID:** {song_id}")
st.caption(f"Score: {score:.4f}")
with c2:
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
else:
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Отладочная информация (Метрики)"):
st.subheader("Координаты V/A")
c_v, c_a = st.columns(2)
c_v.metric("Valence", f"{data['target_v']:.2f}")
c_a.metric("Arousal", f"{data['target_a']:.2f}")
st.markdown("---")
st.subheader("Акустические признаки (LLM)")
feature_titles = {
"energy": "RMS Energy",
"flux": "Spectral Flux",
"centroid": "Spectral Centroid",
"pitch": "F0 (Pitch)",
"hnr": "HNR",
"zcr": "ZCR"
}
# Развернутые описания
feature_helps = {
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
}
llm_profile = data.get("llm_profile")
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(llm_profile.items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
label = feature_titles.get(k, k)
tooltip = feature_helps.get(k, "")
cols[j].metric(label, f"{v:.2f}", help=tooltip)
else:
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
st.markdown("---")
st.write("**Извлеченные теги (BLIP-2):**")
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый запрос", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
# ----------------------------
# 1️⃣ Запуск приложения
# ----------------------------
if __name__ == "__main__": if __name__ == "__main__":
main() if "STREAMLIT_RUN" not in os.environ:
os.environ["STREAMLIT_RUN"] = "1"
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
subprocess.run(cmd)
sys.exit()
st.set_page_config(page_title="Thesis Demo", layout="wide")
# ----------------------------
# 2️⃣ Инициализация движка и данных
# ----------------------------
matcher = load_music_engine()
image_processor = load_image_processor()
image_files, embeddings, labels_array, images_path = load_emoset_data()
# ----------------------------
# 3️⃣ Интерфейс и Вкладки
# ----------------------------
st.title("🖼️ Генератор саундтреков (Research Demo)")
tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"])
with tab1:
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
with tab2:
if image_processor:
render_live_tab(matcher, image_processor)
else:
st.error("Система обработки изображений недоступна (не найдены веса ResNet).")
Binary file not shown.

Before

Width:  |  Height:  |  Size: 851 KiB

+35 -44
View File
@@ -1,66 +1,57 @@
import numpy as np
from pathlib import Path
from PIL import Image
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from PIL import Image
import timm import timm
from transformers import Blip2Processor, Blip2ForConditionalGeneration from pathlib import Path
import numpy as np
# НОВЫЙ ИМПОРТ ДЛЯ VLM
from transformers import BlipProcessor, BlipForConditionalGeneration
class ImageProcessor: class ImageProcessor:
def __init__(self, weights_path: str | Path): def __init__(self, model_path: Path | str):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Модель извлечения визуальных признаков # --- ПОТОК 1: ЭМОЦИИ (ResNet-50) ---
self.feature_extractor = timm.create_model('resnet50', pretrained=False, num_classes=8) print("⏳ Загрузка эмоционального модуля (ResNet-50)...")
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
if Path(model_path).exists():
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
self.emo_model.fc = torch.nn.Identity()
self.emo_model.to(self.device).eval()
if Path(weights_path).exists(): self.emo_transform = T.Compose([
self.feature_extractor.load_state_dict(torch.load(weights_path, map_location=self.device))
else:
print(f"Не удалось найти веса ResNet по пути: {weights_path}")
# Удаление слоя классификации для вывода сырого вектора эмбеддингов
self.feature_extractor.fc = torch.nn.Identity()
self.feature_extractor.to(self.device).eval()
# Трансформации для предварительной обработки изображений
self.preprocess_image = T.Compose([
T.Resize((224, 224)), T.Resize((224, 224)),
T.ToTensor(), T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) ])
# Модуль семантического описания сцены # --- ПОТОК 2: СЕМАНТИКА И КОНТЕКСТ (BLIP Large) ---
print("Инициализация BLIP-2...") print("⏳ Загрузка мощной VLM модели (BLIP) для описания сцен...")
# Обход бага конфигурации Hugging Face (ручная сборка процессора) # Используем версию Large, так как позволяет железо V100
from transformers import BlipImageProcessor, AutoTokenizer self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok) print("✅ Обе нейросети визуального анализа успешно загружены на V100!")
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16
).to(self.device)
@torch.no_grad() @torch.no_grad()
def extract_embedding(self, image: Image.Image) -> np.ndarray: def extract_embedding(self, image: Image.Image) -> np.ndarray:
# Извлечение эмбеддингов из изображения """Извлекает 2048-мерный вектор эмоций."""
rgb_image = image.convert('RGB') img_rgb = image.convert('RGB')
img_tensor = self.preprocess_image(rgb_image).unsqueeze(0).to(self.device) img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
return self.emo_model(img_tensor).cpu().numpy().flatten()
features = self.feature_extractor(img_tensor)
features_np = features.cpu().numpy()
return features_np.flatten()
@torch.no_grad() @torch.no_grad()
def describe_scene(self, image: Image.Image) -> str: def describe_scene(self, image: Image.Image) -> str:
# Генерация текстового описания сцены """Генерирует текстовое описание картинки (Captioning) для LLM."""
rgb_image = image.convert('RGB') img_rgb = image.convert('RGB')
inputs = self.blip_processor(images=rgb_image, return_tensors="pt").to(self.device, torch.float16) # Готовим картинку для BLIP
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40) inputs = self.blip_processor(img_rgb, return_tensors="pt").to(self.device)
scene_description = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] # Генерируем описание (max_new_tokens ограничим, чтобы было лаконично)
out = self.blip_model.generate(**inputs, max_new_tokens=30)
return scene_description.strip() # Декодируем тензор в строку
caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
return caption
+43 -48
View File
@@ -1,65 +1,60 @@
import os import requests
import json import json
import re import re
import requests
class LLMAcousticBridge: class LLMAcousticBridge:
def __init__(self, model_name="dolphin-llama3:8b"): def __init__(self, model_name="phi3", host="http://localhost:11434"):
self.model_name = model_name self.model_name = model_name
base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434") self.api_url = f"{host}/api/generate"
self.api_url = f"{base_url}/api/generate"
def get_acoustic_profile(self, valence, arousal, semantics): def _clean_json(self, text):
context_str = ", ".join(semantics) if semantics else "abstract scene" """Вытаскивает чистый JSON из ответа нейросети."""
try:
match = re.search(r'\{.*\}', text, re.DOTALL)
if match:
return json.loads(match.group(0))
return json.loads(text)
except:
return None
prompt = f""" def get_acoustic_profile(self, valence, arousal, scene_descriptions):
Analyze the visual context and emotions to determine the ideal background music properties. """Просит LLM сгенерировать идеальный звук под описание."""
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy). # Объединяем описания, если загружено несколько фото
Visual Context: {context_str}. context_str = " | ".join(scene_descriptions) if scene_descriptions else "abstract scene"
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density) prompt = f"""You are an expert music producer and acoustic engineer.
2. "flux": (Rhythmic sharpness/Beat) Analyze the visual context and emotions to determine the ideal background music properties.
3. "centroid": (Brightness) Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
4. "pitch": (Fundamental frequency) Visual Context: {context_str}.
5. "hnr": (Harmonics-to-Noise)
6. "zcr": (Percussiveness)
Return ONLY a valid JSON object. No explanations, no markdown blocks. Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
Example: {{"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}} 1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
""" 2. "flux": (Rhythmic sharpness/Beat. High for action/people/cars, Low for static nature)
3. "centroid": (Brightness: 0=Dark/Bass/Massive, 1=Bright/Treble/Light)
4. "pitch": (Fundamental frequency: 0=Low pitch/Huge objects, 1=High pitch/Small objects)
5. "hnr": (Harmonics-to-Noise: 0=Noisy/Distorted textures, 1=Clear/Melodic/Smooth textures)
6. "zcr": (Percussiveness. High for detailed noise like leaves/rain, Low for solid blocks)
Return ONLY a valid JSON object. Do not add any text or explanation.
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
try: try:
payload = { response = requests.post(self.api_url, json={
"model": self.model_name, "model": self.model_name,
"prompt": prompt, "prompt": prompt,
"stream": False, "stream": False,
"format": "json" # Принудительный JSON-режим Ollama "format": "json"
} }, timeout=30)
response.raise_for_status()
print(f"Запрос акустического профиля к Ollama...") result_text = response.json().get("response", "")
response = requests.post(self.api_url, json=payload, timeout=120) profile = self._clean_json(result_text)
if response.status_code == 200:
data = response.json()
response_text = data.get("response", "")
try:
# 1. Попытка прямой десериализации
profile = json.loads(response_text)
return profile
except json.JSONDecodeError:
# 2. Аварийное извлечение JSON из текста с помощью регулярного выражения
match = re.search(r'\{.*\}', response_text, re.DOTALL)
if match:
return json.loads(match.group(0))
print(f"Ошибка парсинга LLM ответа: {response_text}")
return {}
else:
print(f"Ollama вернула ошибку HTTP: {response.status_code}")
return {}
# Проверяем, что все нужные ключи есть
required_keys = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
if profile and all(k in profile for k in required_keys):
return profile
return None
except Exception as e: except Exception as e:
print(f"Ошибка соединения с Ollama: {str(e)}") print(f"⚠️ Ошибка связи с локальной LLM: {e}")
return {} return None
+26 -42
View File
@@ -1,83 +1,67 @@
import joblib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from pathlib import Path from pathlib import Path
import joblib
class MusicMatcher: class MusicMatcher:
def __init__(self, db_path: Path | str, model_path: Path | str): def __init__(self, db_path: Path | str, model_path: Path | str):
# Загрузка базы данных музыкальных произведений # Загружаем твою новую, обогащенную базу
self.music_db = pd.read_csv(db_path) self.music_db = pd.read_csv(db_path)
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'] self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
# Удаление записей с пропущенными целевыми или акустическими признаками # Удаляем строки, где нет акустических фич
target_columns = ['valence', 'arousal'] + self.acoustic_features self.music_db = self.music_db.dropna(subset=['valence', 'arousal'] + self.acoustic_features)
self.music_db = self.music_db.dropna(subset=target_columns)
# Масштабирование акустических параметров к диапазону [0, 1] # Нормализуем акустику от 0 до 1, чтобы сравнивать с ответом LLM
self.norm_db = self.music_db.copy() self.norm_db = self.music_db.copy()
for feat in self.acoustic_features: for feat in self.acoustic_features:
f_min = self.norm_db[feat].min() f_min, f_max = self.norm_db[feat].min(), self.norm_db[feat].max()
f_max = self.norm_db[feat].max()
if f_max > f_min: if f_max > f_min:
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min) self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
else: else:
self.norm_db[f"norm_{feat}"] = 0.0 self.norm_db[f"norm_{feat}"] = 0.0
# Определение путей к аудиофайлам и загрузка модели регрессии
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio" self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
self.regressor = joblib.load(model_path) if Path(model_path).exists() else None
if Path(model_path).exists(): def predict_va(self, embedding: np.ndarray):
self.regressor = joblib.load(model_path) if self.regressor:
else: prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
self.regressor = None return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0)
return 5.0, 5.0
def predict_va(self, embedding: np.ndarray) -> tuple[float, float]:
# Прогнозирование координат Valence/Arousal по визуальному эмбеддингу
if not self.regressor:
return 5.0, 5.0
raw_prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
valence_pred = np.clip(raw_prediction[0], 1.0, 9.0)
arousal_pred = np.clip(raw_prediction[1], 1.0, 9.0)
return float(valence_pred), float(arousal_pred)
def get_audio_path(self, song_id: int | float | str) -> Path | None:
# Поиск физического пути к аудиофайлу в зависимости от расширения
if not self.audio_dir.exists():
return None
def get_audio_path(self, song_id):
if not self.audio_dir.exists(): return None
clean_id = str(int(float(song_id))) clean_id = str(int(float(song_id)))
for ext in ['.mp3', '.wav']: for ext in ['.mp3', '.wav']:
path = self.audio_dir / f"{clean_id}{ext}" path = self.audio_dir / f"{clean_id}{ext}"
if path.exists(): if path.exists(): return path
return path
return None return None
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5) -> pd.DataFrame: def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5):
# Расчет евклидова расстояния в эмоциональном пространстве Рассела # 1. Эмоциональная дистанция (как и раньше)
v_dist = (self.norm_db['valence'] - target_v) ** 2 emo_dist = np.sqrt(
a_dist = (self.norm_db['arousal'] - target_a) ** 2 1.0 * (self.norm_db['valence'] - target_v)**2 +
2.5 * (self.norm_db['arousal'] - target_a)**2
)
self.norm_db['emo_distance'] = emo_dist
# Взвешенное расстояние с приоритетом оси активации (Arousal) # Если LLM не дала ответ, сортируем только по эмоциям
self.norm_db['emo_distance'] = np.sqrt(1.0 * v_dist + 2.5 * a_dist)
# Ранжирование только по эмоциональному критерию при отсутствии профиля LLM
if not llm_profile: if not llm_profile:
self.norm_db['final_score'] = self.norm_db['emo_distance'] self.norm_db['final_score'] = self.norm_db['emo_distance']
return self.norm_db.sort_values(by='final_score').head(top_k) return self.norm_db.sort_values(by='final_score').head(top_k)
# Расчет отклонений по вектору акустических параметров LLM # 2. Акустическая дистанция (сравниваем треки с запросом LLM)
acoustic_penalty = np.zeros(len(self.norm_db)) acoustic_penalty = np.zeros(len(self.norm_db))
for feat in self.acoustic_features: for feat in self.acoustic_features:
if feat in llm_profile: if feat in llm_profile:
target_val = llm_profile[feat] target_val = llm_profile[feat]
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val) acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
# Нормирование акустической дистанции # Усредняем штраф
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features) self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
# Вычисление интегральной метрики соответствия (мультимодальный скоринг) # 3. Финальный Score (Смесь Эмоций и Акустики). Коэф 4.0 делает акустику важной!
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0) self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
return self.norm_db.sort_values(by='final_score').head(top_k) return self.norm_db.sort_values(by='final_score').head(top_k)
Binary file not shown.
-73
View File
@@ -1,73 +0,0 @@
#!/bin/bash
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
# Остановка скрипта при возникновении любой ошибки
set -e
# Цвета для красивого вывода в консоль
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m'
echo -e "${BLUE}[INFO]${NC} Инициализация проверки окружения для проекта EmoM..."
# 1. ПРОВЕРКА DOCKER
if ! command -v docker &> /dev/null; then
echo -e "${YELLOW}[SETUP]${NC} Docker не найден. Начинаем установку..."
# Использование официального скрипта установки Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
rm get-docker.sh
# Добавляем текущего пользователя в группу docker, чтобы не писать sudo docker
sudo usermod -aG docker $USER
echo -e "${GREEN}[OK]${NC} Docker успешно установлен."
echo -e "${YELLOW}[ВНИМАНИЕ]${NC} Для применения прав группы docker потребуется перезайти в сессию SSH после завершения скрипта."
else
echo -e "${GREEN}[OK]${NC} Docker установлен ($(docker --version))."
fi
# 2. ПРОВЕРКА DOCKER COMPOSE
if ! docker compose version &> /dev/null; then
echo -e "${YELLOW}[SETUP]${NC} Плагин Docker Compose не найден. Устанавливаем..."
sudo apt-get update && sudo apt-get install -y docker-compose-plugin
echo -e "${GREEN}[OK]${NC} Docker Compose успешно установлен."
else
echo -e "${GREEN}[OK]${NC} Плагин Docker Compose доступен."
fi
# 3. ПРОВЕРКА NVIDIA И ПРОБРОСА GPU В DOCKER
if command -v nvidia-smi &> /dev/null; then
echo -e "${GREEN}[OK]${NC} Драйверы NVIDIA обнаружены."
# Проверяем наличие NVIDIA Container Toolkit
if ! dpkg -l | grep -q nvidia-container-toolkit; then
echo -e "${YELLOW}[SETUP]${NC} NVIDIA Container Toolkit не найден. Выполняется установка..."
# Настройка репозиториев NVIDIA
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
# Конфигурация Docker для работы с NVIDIA
echo -e "${YELLOW}[SETUP]${NC} Конфигурация runtime NVIDIA для Docker..."
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit установлен и настроен."
else
echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit уже установлен."
fi
else
echo -e "${RED}[WARN]${NC} Утилита nvidia-smi не найдена! Убедитесь, что драйверы видеокарты установлены, иначе Docker будет использовать только CPU."
fi
echo -e "\n${BLUE}[INFO]${NC} ========================================="
echo -e "${GREEN}[SUCCESS]${NC} Окружение готово к работе!"
echo -e "Теперь вы можете запустить проект командой: ${YELLOW}make up${NC}"
-20
View File
@@ -1,20 +0,0 @@
import shutil
from pathlib import Path
import kagglehub
dataset_dir = Path("../dataset/DEAM")
dataset_dir.mkdir(parents=True, exist_ok=True)
print("Скачивание датасета DEAM...")
# kagglehub по умолчанию тянет данные в системный кэш (~/.cache)
cache_path = kagglehub.dataset_download("imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music")
print(f"Загружено в кэш: {cache_path}")
print(f"Перенос файлов в {dataset_dir} и очистка временной директории...")
# Перемещаем данные
shutil.copytree(cache_path, dataset_dir, dirs_exist_ok=True)
shutil.rmtree(cache_path)
print("Готово. Датасет DEAM загружен, кэш очищен.")
-56
View File
@@ -1,56 +0,0 @@
import csv
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
# Конфигурация корневой директории локального датасета
DATASET_DIR = Path("../dataset/EmoSet-118K")
def process_and_save_split(dataset_split, split_name: str, output_dir: Path):
# Подготовка структуры директорий для текущей выборки
split_dir = output_dir / split_name
img_dir = split_dir / "images"
img_dir.mkdir(parents=True, exist_ok=True)
labels_path = split_dir / "labels.csv"
print(f"Обработка выборки: {split_name}...")
# Открытие файла разметки перед циклом для минимизации I/O операций диска
with open(labels_path, mode="w", newline="", encoding="utf-8") as csv_file:
writer = csv.writer(csv_file)
writer.writerow(["filename", "label"])
for example in tqdm(dataset_split, desc=split_name):
img = example["image"]
emotion_label = example["emotion"]
img_id = example["image_id"]
file_name = f"{img_id}.jpg"
# Принудительная конвертация в RGB для безопасного сохранения в JPEG-формате
if img.mode != "RGB":
img = img.convert("RGB")
img.save(img_dir / file_name, format="JPEG")
writer.writerow([file_name, emotion_label])
if __name__ == "__main__":
DATASET_DIR.mkdir(exist_ok=True, parents=True)
# Инициализация подключения к Hugging Face Hub
print("Загрузка метаданных EmoSet-118K...")
raw_dataset = load_dataset("Woleek/EmoSet-118K")
# Итеративная выгрузка размеченных данных
for split_key in ["train", "val", "test"]:
if split_key in raw_dataset:
process_and_save_split(
dataset_split=raw_dataset[split_key],
split_name=split_key,
output_dir=DATASET_DIR
)
print("Экспорт датасета завершен.")
-30
View File
@@ -1,30 +0,0 @@
import pandas as pd
from pathlib import Path
# Конфигурация локальных путей
SOURCE_CSV = Path("../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv")
OUTPUT_CSV = Path("../../dataset/DEAM/music_db.csv")
def prepare_deam_database():
if not SOURCE_CSV.exists():
print(f"Исходный файл аннотаций не найден: {SOURCE_CSV}")
return
print("Обработка разметки датасета DEAM...")
# Загрузка сырых данных с очисткой артефактов форматирования
raw_df = pd.read_csv(SOURCE_CSV, skipinitialspace=True)
# Экстракция координат пространства Рассела (Valence/Arousal)
processed_df = raw_df[['song_id', 'valence_mean', 'arousal_mean']].copy()
processed_df.columns = ['song_id', 'valence', 'arousal']
# Приведение идентификаторов к формату файловой системы (int)
processed_df['song_id'] = processed_df['song_id'].astype(int)
processed_df.to_csv(OUTPUT_CSV, index=False)
print(f"База успешно сформирована. Всего записей: {len(processed_df)}")
if __name__ == "__main__":
prepare_deam_database()
-60
View File
@@ -1,60 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
# Конфигурация параметров нагрузочного тестирования
NUM_SAMPLES = 300_000
DIM_IN = 4096
DIM_OUT = 10
BATCH_SIZE = 16_384
NUM_STEPS = 1000
def run_gpu_benchmark():
# Проверка доступности аппаратного ускорения
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Инициализация стресс-теста на устройстве: {device}")
# Генерация синтетического датасета для аллокации VRAM
x_data = torch.randn(NUM_SAMPLES, DIM_IN, device=device, dtype=torch.float32)
y_data = torch.randn(NUM_SAMPLES, DIM_OUT, device=device, dtype=torch.float32)
# Архитектура тестовой полносвязной сети
model = nn.Sequential(
nn.Linear(DIM_IN, 2048),
nn.ReLU(),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, DIM_OUT)
).to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print("Начало прогрева GPU и симуляции цикла обучения...")
start_time = time.time()
for step in range(NUM_STEPS):
# Сэмплирование случайного батча
idx = torch.randint(0, NUM_SAMPLES, (BATCH_SIZE,), device=device)
x_batch = x_data[idx]
y_batch = y_data[idx]
optimizer.zero_grad()
predictions = model(x_batch)
loss = loss_fn(predictions, y_batch)
loss.backward()
optimizer.step()
# Логирование статуса (каждые 100 итераций для снижения I/O overhead)
if step % 100 == 0:
print(f"Итерация {step}/{NUM_STEPS} | Текущий loss: {loss.item():.4f}")
end_time = time.time()
print(f"Стресс-тест завершен. Общее время: {end_time - start_time:.2f} сек.")
if __name__ == "__main__":
run_gpu_benchmark()
-184
View File
@@ -1,184 +0,0 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
# Подавление предупреждений цветовых профилей
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройки окружения
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
NUM_WORKERS = 62
PATIENCE = 7
# Маппинг классов
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Фиксация генераторов псевдослучайных чисел
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Трансформации
base_tf = [
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
train_transform = T.Compose([
T.Resize(256, antialias=True),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
*base_tf
])
val_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(224),
*base_tf
])
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# Инициализация модели и оптимизатора
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Логика эпохи обучения
def train_epoch(current_model, loader):
current_model.train()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
logits = current_model(imgs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
# Логика эпохи валидации
@torch.no_grad()
def val_epoch(current_model, loader):
current_model.eval()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
logits = current_model(imgs)
loss = criterion(logits, labels)
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
if __name__ == "__main__":
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_path = "./emosetV2_resnet50_best.pth"
print("Старт обучения.")
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = train_epoch(model, train_loader)
val_loss, val_acc = val_epoch(model, val_loader)
scheduler.step()
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
# Сохранение лучших весов по Accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), checkpoint_path)
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
# Оценка переобучения по Loss (Early Stopping)
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
break
print("Процесс завершен.")
-283
View File
@@ -1,283 +0,0 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Подавление предупреждений цветовых профилей
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройки окружения
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
# ВАЖНО: Добавили путь для медиа файлов
MEDIA_DIR = Path("./src/scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
NUM_WORKERS = 32
PATIENCE = 7
# Маппинг классов
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# Инвертированный маппинг для графиков
INV_CLASS_MAPPING = {v: k for k, v in CLASS_MAPPING.items()}
CLASS_NAMES = [INV_CLASS_MAPPING[i] for i in range(len(CLASS_MAPPING))]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Фиксация генераторов псевдослучайных чисел
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Трансформации
base_tf = [
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
train_transform = T.Compose([
T.Resize(256, antialias=True),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
*base_tf
])
val_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(224),
*base_tf
])
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# Инициализация модели и оптимизатора
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Функции для отрисовки графиков
def plot_learning_curves(history):
"""Отрисовка графиков функции потерь и точности"""
epochs = range(1, len(history['train_loss']) + 1)
plt.figure(figsize=(14, 5))
# График Loss
plt.subplot(1, 2, 1)
plt.plot(epochs, history['train_loss'], 'b-', label='Train Loss')
plt.plot(epochs, history['val_loss'], 'r--', label='Validation Loss')
plt.title('График функции потерь (Loss)', fontsize=14)
plt.xlabel('Эпохи', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':', alpha=0.7)
# График Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy')
plt.plot(epochs, history['val_acc'], 'r--', label='Validation Accuracy')
plt.title('График точности (Accuracy)', fontsize=14)
plt.xlabel('Эпохи', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':', alpha=0.7)
plt.tight_layout()
plot_path = MEDIA_DIR / "training_history.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] График обучения сохранен в: {plot_path}")
def plot_confusion_matrix(y_true, y_pred):
"""Отрисовка тепловой матрицы ошибок"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
cbar_kws={'label': 'Количество сэмплов'})
plt.title('Матрица ошибок (Confusion Matrix) - ResNet50', fontsize=16, pad=20)
plt.ylabel('Истинные классы (Ground Truth)', fontsize=12)
plt.xlabel('Предсказанные классы (Predicted)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
cm_path = MEDIA_DIR / "confusion_matrix_emoset.png"
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] Матрица ошибок сохранена в: {cm_path}")
# Логика эпохи обучения
def train_epoch(current_model, loader):
current_model.train()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
logits = current_model(imgs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
# Логика эпохи валидации с сохранением предсказаний для матрицы ошибок
@torch.no_grad()
def val_epoch(current_model, loader, return_preds=False):
current_model.eval()
total_loss, correct_preds, total_samples = 0.0, 0, 0
all_preds, all_labels = [], []
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
logits = current_model(imgs)
loss = criterion(logits, labels)
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
if return_preds:
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
avg_loss = total_loss / total_samples
avg_acc = correct_preds / total_samples
if return_preds:
return avg_loss, avg_acc, all_labels, all_preds
return avg_loss, avg_acc
if __name__ == "__main__":
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_path = "./emosetV2_resnet50_best.pth"
# Словарь для хранения истории обучения
history = {
'train_loss': [], 'train_acc': [],
'val_loss': [], 'val_acc': []
}
# Переменные для хранения лучших предсказаний для матрицы
best_labels, best_preds = [], []
print("Старт обучения.")
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = train_epoch(model, train_loader)
# Получаем предсказания только если это может быть лучшая эпоха
val_loss, val_acc, val_labels, val_preds = val_epoch(model, val_loader, return_preds=True)
scheduler.step()
# Запись в историю
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
# Сохранение лучших весов по Accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
best_labels = val_labels # Сохраняем предсказания лучшей модели
best_preds = val_preds
torch.save(model.state_dict(), checkpoint_path)
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
# Оценка переобучения по Loss (Early Stopping)
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
break
print("Процесс обучения завершен. Генерирую графики для диссертации...")
plot_learning_curves(history)
plot_confusion_matrix(best_labels, best_preds)
print("Все медиафайлы успешно созданы!")
File diff suppressed because one or more lines are too long
-171
View File
@@ -1,171 +0,0 @@
import os
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
# Настройки путей для медиа
MEDIA_DIR = Path("scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
# Конфигурация путей для инференса и кэширования векторов
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
MODEL_PATH = Path("./src/emoset_resnet50_best.pth")
BATCH_SIZE = 128
NUM_WORKERS = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Вычисления перенесены на: {device}")
class EmoSetFeatureDataset(Dataset):
def __init__(self, root: Path | str, split: str):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.labels = sorted(self.df["label"].unique())
self.label2idx = {l: i for i, l in enumerate(self.labels)}
self.idx2label = {i: l for l, i in self.label2idx.items()}
# Для экстракции признаков аугментация отключена, используется строгий CenterCrop
self.transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
# Перехват битых файлов выборки
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (224, 224), (0, 0, 0))
img_tensor = self.transform(img)
label_idx = self.label2idx[row["label"]]
return img_tensor, label_idx
def plot_tsne(embeddings, labels, idx2label, sample_limit=3000):
"""Генерация t-SNE графика для диссертации"""
print(f"Построение t-SNE проекции для {sample_limit} сэмплов...")
tsne_model = TSNE(n_components=2, perplexity=30, random_state=42)
embeddings_2d = tsne_model.fit_transform(embeddings[:sample_limit])
labels_subset = labels[:sample_limit]
plt.figure(figsize=(12, 9))
# Используем более академическую палитру
scatter = plt.scatter(
embeddings_2d[:, 0],
embeddings_2d[:, 1],
c=labels_subset,
cmap="Set2", # Set2 лучше различается при печати
alpha=0.7,
s=20,
edgecolors='w',
linewidths=0.5
)
# Формирование легенды
handles, _ = scatter.legend_elements()
legend_labels = [idx2label[i] for i in range(len(idx2label))]
# Размещение легенды снаружи графика, чтобы не перекрывать данные
plt.legend(handles, legend_labels, title="Эмоциональные классы",
bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title("2D проекция скрытого пространства признаков (t-SNE)", pad=20, fontsize=14)
plt.xlabel("Первая главная компонента (t-SNE 1)", fontsize=12)
plt.ylabel("Вторая главная компонента (t-SNE 2)", fontsize=12)
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plot_path = MEDIA_DIR / "tsne_embeddings.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] График t-SNE сохранен в: {plot_path}")
if __name__ == "__main__":
test_ds = EmoSetFeatureDataset(DATA_ROOT, "test")
test_loader = DataLoader(
test_ds,
batch_size=BATCH_SIZE,
shuffle=False, # Отключение шаффла для строгого соответствия индексов
num_workers=NUM_WORKERS,
pin_memory=True
)
print(f"Подготовлено для извлечения: {len(test_ds)} файлов.")
# Инициализация модели и загрузка лучших весов
feature_extractor = timm.create_model(
"resnet50",
pretrained=False,
num_classes=len(test_ds.labels)
)
try:
checkpoint = torch.load(MODEL_PATH, map_location=device)
feature_extractor.load_state_dict(checkpoint)
print("Веса модели успешно загружены.")
except Exception as e:
print(f"Ошибка загрузки весов: {e}. Убедитесь, что модель обучена.")
exit(1)
# Удаление классификационного слоя (fc)
feature_extractor.reset_classifier(0)
feature_extractor.to(device)
feature_extractor.eval()
print("Слой классификации удален. Модель готова к экстракции.")
extracted_embeddings = []
extracted_labels = []
print("Старт пакетной экстракции признаков...")
with torch.no_grad():
for imgs, labels in tqdm(test_loader, desc="Экстракция"):
imgs = imgs.to(device)
# Получение вектора [BATCH_SIZE, 2048]
embeddings_batch = feature_extractor(imgs)
extracted_embeddings.append(embeddings_batch.cpu().numpy())
extracted_labels.append(labels.numpy())
# Агрегация батчей в единые массивы
np_embeddings = np.concatenate(extracted_embeddings, axis=0)
np_labels = np.concatenate(extracted_labels, axis=0)
print(f"Размерность матрицы признаков: {np_embeddings.shape}")
# Сохранение артефактов
np.save("./src/emoset_test_embeddings.npy", np_embeddings)
np.save("./src/emoset_test_labels.npy", np_labels)
print("Матрицы успешно экспортированы в .npy файлы.")
# Генерация медиа для диссертации
plot_tsne(np_embeddings, np_labels, test_ds.idx2label, sample_limit=3000)
print("Процесс полностью завершен.")
-69
View File
@@ -1,69 +0,0 @@
import pandas as pd
from pathlib import Path
from tqdm import tqdm
# Конфигурация путей и целевых признаков
BASE_DIR = Path("../../dataset/DEAM")
MUSIC_DB_PATH = BASE_DIR / "music_db.csv"
FEATURES_DIR = BASE_DIR / "features" / "features"
OUTPUT_PATH = BASE_DIR / "music_db_enriched.csv"
# Маппинг низкоуровневых признаков экстрактора (openSMILE/GeMAPS) в дескрипторы системы
TARGET_FEATURES = {
'pcm_RMSenergy_sma_amean': 'energy',
'pcm_fftMag_spectralFlux_sma_amean': 'flux',
'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',
'F0final_sma_amean': 'pitch',
'logHNR_sma_amean': 'hnr',
'pcm_zcr_sma_amean': 'zcr',
'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',
'pcm_fftMag_psySharpness_sma_amean': 'sharpness'
}
def aggregate_acoustic_features():
if not MUSIC_DB_PATH.exists():
print(f"Базовый файл аннотаций не найден: {MUSIC_DB_PATH}")
return
print("Загрузка эмоциональной разметки DEAM...")
df_main = pd.read_csv(MUSIC_DB_PATH)
print("Агрегация фреймовых акустических признаков...")
aggregated_data = []
# Итерация по трекам для сбора покадровых характеристик
for _, row in tqdm(df_main.iterrows(), total=len(df_main), desc="Обработка аудио-векторов"):
song_id = int(row['song_id'])
feature_file = FEATURES_DIR / f"{song_id}.csv"
if feature_file.exists():
try:
# Чтение сырых векторов (формат csv с разделителем ';')
df_feat = pd.read_csv(feature_file, sep=';')
# Усреднение характеристик по временной оси (time frames)
mean_features = df_feat[list(TARGET_FEATURES.keys())].mean()
# Формирование агрегированной записи
track_data = {'song_id': song_id}
for orig_col, new_col in TARGET_FEATURES.items():
track_data[new_col] = mean_features[orig_col]
aggregated_data.append(track_data)
except Exception as e:
print(f"Ошибка парсинга файла {feature_file.name}: {e}")
# Слияние акустических дескрипторов с эмоциональными координатами (Inner Join)
df_features = pd.DataFrame(aggregated_data)
df_enriched = pd.merge(df_main, df_features, on='song_id', how='inner')
# Очистка возможных артефактов NaN после агрегации
df_enriched = df_enriched.dropna(subset=list(TARGET_FEATURES.values()))
df_enriched.to_csv(OUTPUT_PATH, index=False)
print(f"Экспорт завершен. Сформирована обогащенная база: {OUTPUT_PATH.name}")
print(f"Итоговый размер выборки: {len(df_enriched)} треков.")
if __name__ == "__main__":
aggregate_acoustic_features()
-80
View File
@@ -1,80 +0,0 @@
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# Проекция дискретных классов эмоций на непрерывное пространство Рассела (Valence, Arousal)
# Значения откалиброваны в диапазоне [1.0, 9.0]
EMOTION_TO_VA_COORDS = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
def train_va_regressor():
# Настройка путей
base_dir = Path(__file__).resolve().parent.parent
embeddings_path = base_dir / "emoset_test_embeddings.npy"
labels_path = base_dir / "emoset_test_labels.npy"
model_output_path = base_dir / "music_engine" / "va_regressor.pkl"
if not embeddings_path.exists() or not labels_path.exists():
print(f"Артефакты признаков не найдены в директории: {base_dir}")
return
print("Загрузка вектора признаков и меток классов...")
x_features = np.load(embeddings_path)
y_discrete = np.load(labels_path)
# Трансформация целевой переменной: классы -> непрерывные координаты V/A
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
x_train, x_test, y_train, y_test = train_test_split(
x_features, y_continuous, test_size=0.2, random_state=42
)
# Построение пайплайна: Z-масштабирование и L2-регуляризованная регрессия
# RidgeCV автоматически подбирает оптимальный гиперпараметр alpha (силу регуляризации)
print("Инициализация и обучение пайплайна RidgeCV...")
regression_pipeline = Pipeline([
('scaler', StandardScaler()),
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
])
regression_pipeline.fit(x_train, y_train)
# Оценка обобщающей способности модели
y_pred = regression_pipeline.predict(x_test)
mse_score = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print("Обучение завершено. Метрики качества на тестовой выборке:")
print(f" - MSE: {mse_score:.4f}")
print(f" - R^2: {r2:.4f}")
# Диагностика дисперсии предсказаний
v_min, v_max = y_pred[:, 0].min(), y_pred[:, 0].max()
a_min, a_max = y_pred[:, 1].min(), y_pred[:, 1].max()
print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")
# Экспорт обученного пайплайна
model_output_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(regression_pipeline, model_output_path)
print(f"Пайплайн сохранен: {model_output_path.name}")
if __name__ == "__main__":
train_va_regressor()
-97
View File
@@ -1,97 +0,0 @@
import numpy as np
import pandas as pd
import joblib
from pathlib import Path
from sklearn.metrics import mean_squared_error, r2_score
# 1. Настройка путей
embeddings_path = Path("./src/emoset_test_embeddings.npy")
csv_path = Path("./NFS/Thesis/Emoset/EmoSet-118K/test/labels.csv")
model_path = Path("./src/music_engine/va_regressor.pkl")
output_dir = Path("./src/scripts/media")
output_file = output_dir / "metrics_output.txt"
# 2. Корректный маппинг 8 классов EmoSet в шкалу DEAM [1.0, 9.0]
# Формула перевода из [-1, 1] в [1, 9]: 5.0 + (X * 4.0)
EMO_TO_VA = {
"amusement": [8.2, 6.6], # Веселье (Высокий позитив, средняя энергия)
"awe": [7.0, 7.4], # Восхищение (Позитив, высокая энергия)
"contentment": [7.8, 3.4], # Умиротворение (Позитив, низкая энергия)
"excitement": [8.2, 8.2], # Возбуждение (Макс. позитив, макс. энергия)
"anger": [2.2, 7.8], # Гнев (Глубокий негатив, высокая энергия)
"disgust": [2.6, 6.6], # Отвращение (Негатив, средняя энергия)
"fear": [2.6, 8.2], # Страх (Негатив, максимальная энергия)
"sadness": [2.2, 2.6] # Грусть (Глубокий негатив, низкая энергия)
}
def generate_slide_metrics():
print("[INFO] Загрузка тестовых артефактов...")
if not all(p.exists() for p in [embeddings_path, csv_path, model_path]):
print("[ERROR] Проверьте наличие файлов данных или модели регрессора.")
return
output_dir.mkdir(parents=True, exist_ok=True)
# 3. Загрузка эмбеддингов и меток
X_test = np.load(embeddings_path)
df = pd.read_csv(csv_path)
if len(X_test) != len(df):
print(f"[WARN] Корректировка размеров выборки: Эмбеддинги ({len(X_test)}) != Метки ({len(df)})")
min_len = min(len(X_test), len(df))
X_test = X_test[:min_len]
df = df.iloc[:min_len]
y_test_list = [EMO_TO_VA.get(label.lower().strip(), [5.0, 5.0]) for label in df['label']]
y_test = np.array(y_test_list)
# 4. Выполнение инференса
print("[INFO] Выполнение инференса регрессора на скрытом пространстве признаков...")
regressor = joblib.load(model_path)
y_pred = regressor.predict(X_test)
# === БЛОК ДИАГНОСТИКИ ШКАЛЫ ===
print("\n" + "-"*50)
print(" ДИАГНОСТИКА ДИАПАЗОНОВ ЗНАЧЕНИЙ ".center(50))
print("-"*50)
print(f"Истинные (y_test) -> Мин: {y_test.min():.2f}, Макс: {y_test.max():.2f}, Среднее: {y_test.mean():.2f}")
print(f"Предсказания (y_pred) -> Мин: {y_pred.min():.2f}, Макс: {y_pred.max():.2f}, Среднее: {y_pred.mean():.2f}")
print("-"*50 + "\n")
# ==============================
# 5. Расчет метрик
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
mse_total = mean_squared_error(y_test, y_pred)
r2_total = r2_score(y_test, y_pred)
# 6. Вывод и сохранение результатов
table_content = f"""
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | {mse_v:<12.4f} | {mse_a:<12.4f} | {mse_total:<13.4f} |
| R² | {r2_v:<12.4f} | {r2_a:<12.4f} | {r2_total:<13.4f} |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{{final}} = D_{{emo}} + 4.0 \cdot Acoustic_{{penalty}}$$
"""
print(table_content)
with open(output_file, 'w', encoding='utf-8') as f:
f.write(table_content)
print(f"[SUCCESS] Метрики успешно сохранены в файл: {output_file.absolute()}")
if __name__ == "__main__":
generate_slide_metrics()
@@ -2,30 +2,30 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "8523d028", "id": "8523d028",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import pandas as pd\n", "import pandas as pd\n",
"import numpy as np\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from PIL import Image\n", "from PIL import Image\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n", "import torchvision.transforms as T\n",
"import timm\n", "import timm\n",
"import numpy as np\n",
"\n", "\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n", "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n"
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"id": "e0781b02", "id": "e0781b02",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -41,26 +41,25 @@
} }
], ],
"source": [ "source": [
"# Конфигурация путей и параметров инференса\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n", "DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n", "MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n",
"\n", "\n",
"BATCH_SIZE = 64\n", "BATCH_SIZE = 64\n",
"NUM_WORKERS = 4\n", "NUM_WORKERS = 4\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Аппаратное ускорение: {DEVICE}\")" "\n",
"DEVICE\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "79da9640", "id": "79da9640",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"class EmoSetEvaluationDataset(Dataset):\n", "class EmoSetDataset(Dataset):\n",
" # Датасет для строгой валидации с центрированным кропом\n", " def __init__(self, root, split):\n",
" def __init__(self, root: Path | str, split: str):\n",
" self.root = Path(root) / split\n", " self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n", " self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n", "\n",
@@ -68,12 +67,13 @@
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n", " self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n", " self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n", "\n",
" # Стандартный пайплайн трансформаций для инференса ResNet\n",
" self.transform = T.Compose([\n", " self.transform = T.Compose([\n",
" T.Resize(256),\n", " T.Resize((224, 224)),\n",
" T.CenterCrop(224),\n",
" T.ToTensor(),\n", " T.ToTensor(),\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", " T.Normalize(\n",
" mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]\n",
" )\n",
" ])\n", " ])\n",
"\n", "\n",
" def __len__(self):\n", " def __len__(self):\n",
@@ -81,23 +81,15 @@
"\n", "\n",
" def __getitem__(self, idx):\n", " def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n", " row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n", " img = Image.open(self.root / \"images\" / row[\"filename\"]).convert(\"RGB\")\n",
" \n", " img = self.transform(img)\n",
" # Перехват битых файлов для непрерывности оценки\n", " label = self.label2idx[row[\"label\"]]\n",
" try:\n", " return img, label\n"
" img = Image.open(img_path).convert(\"RGB\")\n",
" except Exception:\n",
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
" \n",
" img_tensor = self.transform(img)\n",
" label_idx = self.label2idx[row[\"label\"]]\n",
" \n",
" return img_tensor, label_idx"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"id": "12201756", "id": "12201756",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -111,8 +103,8 @@
} }
], ],
"source": [ "source": [
"# Инициализация тестовой выборки\n", "test_ds = EmoSetDataset(DATA_ROOT, \"test\")\n",
"test_ds = EmoSetEvaluationDataset(DATA_ROOT, \"test\")\n", "\n",
"test_loader = DataLoader(\n", "test_loader = DataLoader(\n",
" test_ds,\n", " test_ds,\n",
" batch_size=BATCH_SIZE,\n", " batch_size=BATCH_SIZE,\n",
@@ -121,13 +113,13 @@
" pin_memory=True\n", " pin_memory=True\n",
")\n", ")\n",
"\n", "\n",
"print(f\"Индексированные классы: {test_ds.labels}\")\n", "print(\"Classes:\", test_ds.labels)\n",
"print(f\"Размер тестовой выборки: {len(test_ds)}\")" "print(\"Test samples:\", len(test_ds))\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 8,
"id": "7e3dc1d5", "id": "7e3dc1d5",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -382,17 +374,22 @@
} }
], ],
"source": [ "source": [
"# Инициализация модели в режиме классификации\n",
"model = timm.create_model(\n", "model = timm.create_model(\n",
" \"resnet50\",\n", " \"resnet50\",\n",
" pretrained=False,\n", " pretrained=False,\n",
" num_classes=len(test_ds.labels)\n", " num_classes=len(test_ds.labels)\n",
")" ")\n",
"\n",
"state = torch.load(MODEL_PATH, map_location=DEVICE)\n",
"model.load_state_dict(state)\n",
"\n",
"model.to(DEVICE)\n",
"model.eval()\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"id": "b42a84f1", "id": "b42a84f1",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -405,16 +402,27 @@
} }
], ],
"source": [ "source": [
"# Загрузка весов и перевод в режим инференса\n", "all_preds = []\n",
"checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)\n", "all_targets = []\n",
"model.load_state_dict(checkpoint)\n", "\n",
"model.to(DEVICE)\n", "with torch.no_grad():\n",
"model.eval()" " for imgs, labels in tqdm(test_loader):\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" logits = model(imgs)\n",
" preds = logits.argmax(dim=1)\n",
"\n",
" all_preds.append(preds.cpu().numpy())\n",
" all_targets.append(labels.cpu().numpy())\n",
"\n",
"all_preds = np.concatenate(all_preds)\n",
"all_targets = np.concatenate(all_targets)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"id": "4c1f1377", "id": "4c1f1377",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -427,25 +435,13 @@
} }
], ],
"source": [ "source": [
"# Сбор предсказаний на тестовой выборке\n", "acc = accuracy_score(all_targets, all_preds)\n",
"all_preds = []\n", "print(f\"Test accuracy: {acc:.4f}\")\n"
"all_targets = []\n",
"\n",
"print(\"Запуск инференса на тестовой выборке...\")\n",
"with torch.no_grad():\n",
" for imgs, labels in tqdm(test_loader, desc=\"Оценка метрик\"):\n",
" imgs = imgs.to(DEVICE)\n",
" \n",
" logits = model(imgs)\n",
" preds = logits.argmax(dim=1)\n",
"\n",
" all_preds.append(preds.cpu().numpy())\n",
" all_targets.append(labels.numpy())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 11,
"id": "6b022825", "id": "6b022825",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -472,14 +468,19 @@
} }
], ],
"source": [ "source": [
"# Агрегация результатов\n", "print(\n",
"all_preds = np.concatenate(all_preds, axis=0)\n", " classification_report(\n",
"all_targets = np.concatenate(all_targets, axis=0)" " all_targets,\n",
" all_preds,\n",
" target_names=test_ds.labels,\n",
" digits=4\n",
" )\n",
")\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"id": "2fcb69ac", "id": "2fcb69ac",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -495,70 +496,20 @@
} }
], ],
"source": [ "source": [
"# Расчет интегральных метрик классификации\n", "import matplotlib.pyplot as plt\n",
"acc = accuracy_score(all_targets, all_preds)\n",
"print(f\"\\nОбщая точность (Accuracy): {acc:.4f}\\n\")\n",
"\n", "\n",
"print(\"Детализированный отчет (Classification Report):\")\n",
"print(\n",
" classification_report(\n",
" all_targets,\n",
" all_preds,\n",
" target_names=test_ds.labels,\n",
" digits=4\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2084ab91",
"metadata": {},
"outputs": [],
"source": [
"# Построение матрицы ошибок (Confusion Matrix)\n",
"cm = confusion_matrix(all_targets, all_preds)\n", "cm = confusion_matrix(all_targets, all_preds)\n",
"\n", "\n",
"plt.figure(figsize=(10, 8))" "plt.figure(figsize=(8, 8))\n",
] "plt.imshow(cm)\n",
}, "plt.colorbar()\n",
{ "plt.xticks(range(len(test_ds.labels)), test_ds.labels, rotation=45)\n",
"cell_type": "code", "plt.yticks(range(len(test_ds.labels)), test_ds.labels)\n",
"execution_count": null, "plt.xlabel(\"Predicted\")\n",
"id": "83a84e14", "plt.ylabel(\"True\")\n",
"metadata": {}, "plt.title(\"Confusion Matrix (Test)\")\n",
"outputs": [], "plt.tight_layout()\n",
"source": [ "plt.show()\n"
"# Использование seaborn для академичной визуализации с числами\n",
"sns.heatmap(\n",
" cm, \n",
" annot=True, \n",
" fmt=\"d\", \n",
" cmap=\"Blues\", \n",
" xticklabels=test_ds.labels, \n",
" yticklabels=test_ds.labels,\n",
" cbar=False\n",
")\n",
"\n",
"plt.title(\"Матрица ошибок классификации EmoSet (ResNet-50)\", pad=20)\n",
"plt.xlabel(\"Предсказанный класс\", labelpad=15)\n",
"plt.ylabel(\"Истинный класс\", labelpad=15)\n",
"plt.xticks(rotation=45)\n",
"plt.yticks(rotation=0)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "280d5637",
"metadata": {},
"outputs": [],
"source": [
"# Экспорт графика\n",
"plt.savefig(\"../confusion_matrix_resnet50.png\", dpi=300, bbox_inches='tight')\n",
"plt.show()"
] ]
} }
], ],
+125
View File
@@ -0,0 +1,125 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "0336fd0c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ База загружена. Треков: 1744\n",
"🔍 Собираем акустические признаки...\n",
"\n",
"🚀 ГОТОВО! Обогащенная база сохранена: ../../dataset/DEAM/music_db_enriched.csv\n",
"Собрано фичей для 1744 из 1744 треков.\n",
" song_id valence arousal energy flux centroid pitch \\\n",
"0 2 3.1 3.0 0.097268 0.846947 483.421751 93.884056 \n",
"1 3 3.5 3.3 0.126809 0.959460 173.219616 62.682589 \n",
"2 4 5.7 5.5 0.156699 1.333944 466.434797 92.850316 \n",
"3 5 4.4 5.3 0.126455 1.009927 546.152506 158.673853 \n",
"4 7 5.8 6.4 0.268180 1.589191 175.369162 83.823484 \n",
"\n",
" hnr zcr entropy sharpness \n",
"0 3.615380 0.034270 3.299075 0.426490 \n",
"1 -2.600122 0.017893 2.294971 0.165583 \n",
"2 -0.579130 0.042936 3.258138 0.395410 \n",
"3 1.751148 0.043781 3.514585 0.494367 \n",
"4 12.006770 0.014783 2.177862 0.170058 \n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from tqdm import tqdm # для красивого прогресс-бара, если не установлен - убери\n",
"\n",
"# 1. Пути к файлам\n",
"base_dir = Path(\"../../dataset/DEAM\") # Поправь, если запускаешь из другого места\n",
"music_db_path = base_dir / \"music_db.csv\"\n",
"features_dir = base_dir / \"features\" / \"features\"\n",
"output_path = base_dir / \"music_db_enriched.csv\"\n",
"\n",
"# 2. Наш \"Золотой список\" (8 признаков)\n",
"target_columns = {\n",
" 'pcm_RMSenergy_sma_amean': 'energy',\n",
" 'pcm_fftMag_spectralFlux_sma_amean': 'flux',\n",
" 'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',\n",
" 'F0final_sma_amean': 'pitch',\n",
" 'logHNR_sma_amean': 'hnr',\n",
" 'pcm_zcr_sma_amean': 'zcr',\n",
" 'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',\n",
" 'pcm_fftMag_psySharpness_sma_amean': 'sharpness'\n",
"}\n",
"\n",
"# 3. Загружаем текущую базу с V/A\n",
"if not music_db_path.exists():\n",
" print(f\"❌ ОШИБКА: Не найден файл {music_db_path}\")\n",
"else:\n",
" df_main = pd.read_csv(music_db_path)\n",
" print(f\"✅ База загружена. Треков: {len(df_main)}\")\n",
"\n",
" # Подготавливаем новые колонки\n",
" for col_name in target_columns.values():\n",
" df_main[col_name] = np.nan\n",
"\n",
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
" print(\"🔍 Собираем акустические признаки...\")\n",
" found_count = 0\n",
" \n",
" for index, row in df_main.iterrows():\n",
" song_id = int(row['song_id'])\n",
" feature_file = features_dir / f\"{song_id}.csv\"\n",
" \n",
" if feature_file.exists():\n",
" try:\n",
" # Читаем CSV с признаками (разделитель там обычно точка с запятой)\n",
" df_feat = pd.read_csv(feature_file, sep=';')\n",
" \n",
" # Усредняем значения по всем фреймам (одна песня разбита на сотни строк-фреймов)\n",
" mean_features = df_feat[list(target_columns.keys())].mean()\n",
" \n",
" # Записываем в главную базу\n",
" for orig_col, new_col in target_columns.items():\n",
" df_main.at[index, new_col] = mean_features[orig_col]\n",
" \n",
" found_count += 1\n",
" except Exception as e:\n",
" print(f\"Ошибка чтения {feature_file}: {e}\")\n",
" \n",
" # 5. Сохраняем результат\n",
" # Удаляем треки, для которых не нашлось фичей (если такие есть)\n",
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
" \n",
" df_main.to_csv(output_path, index=False)\n",
" print(f\"\\n🚀 ГОТОВО! Обогащенная база сохранена: {output_path}\")\n",
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
" print(df_main.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+614
View File
@@ -0,0 +1,614 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09f9237a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
]
}
],
"source": [
"!pip install datasets tqdm pillow requests\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f0b2e2c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "868d872a109d49f9966f2f19985e7048",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06741794289540849ad179c5966dcab8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e47aad5270144913996cb5b226213ab9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30d1492a948245e3b6b58e92218cd760",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "931823b458cb4696b459e9011537cf1e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "846f4245b16d4cc096a43c940590ad11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172981d77fc941cfa32c05f5a34bf742",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2baa935ed3524a73883909752cb15907",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9c0baac101b449794155392f07b49c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a5b966f79314e069251462bff82395f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "422974f938924910a0712b30a9c2bd84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f155a08427094de7ad1a5884e623db2b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a94a4621d19f45f690e0064fee83767b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50f55b00a27b4213b573b398e5b0d708",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c5815040f0a4a31903348a8327811a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 94481\n",
" })\n",
" val: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 5905\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 17716\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"# куда сохраняем датасет\n",
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
"\n",
"# загружаем через Hugging Face\n",
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
"\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "052ab073",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"from pathlib import Path\n",
"\n",
"def save_split(split):\n",
" split_dir = DATA_DIR / split\n",
" img_dir = split_dir / \"images\"\n",
" img_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" labels_path = split_dir / \"labels.csv\"\n",
"\n",
" # перезаписываем labels.csv\n",
" with open(labels_path, \"w\") as f:\n",
" f.write(\"filename,label\\n\")\n",
"\n",
" for example in tqdm(ds[split]):\n",
" img = example[\"image\"] # уже PIL.Image\n",
" label = example[\"emotion\"]\n",
" image_id = example[\"image_id\"]\n",
"\n",
" fname = f\"{image_id}.jpg\"\n",
" img.save(img_dir / fname)\n",
"\n",
" with open(labels_path, \"a\") as f:\n",
" f.write(f\"{fname},{label}\\n\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a74ceedf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
]
}
],
"source": [
"save_split(\"train\")\n",
"save_split(\"val\")\n",
"save_split(\"test\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+140
View File
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Загрузка датасета DEAM\n",
"\n",
"Этот ноутбук предназначен для автоматизации процесса скачивания и подготовки музыкального датасета **DEAM** (Database for Emotional Analysis in Music).\n",
"Данные будут помещены в папку `dataset/DEAM`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting kagglehub\n",
" Downloading kagglehub-1.0.1-py3-none-any.whl.metadata (40 kB)\n",
"Collecting kagglesdk<1.0,>=0.1.22 (from kagglehub)\n",
" Downloading kagglesdk-0.1.23-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (25.0)\n",
"Requirement already satisfied: pyyaml in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (6.0.3)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (2.32.5)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (4.67.1)\n",
"Requirement already satisfied: protobuf in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglesdk<1.0,>=0.1.22->kagglehub) (6.33.4)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.4.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.11)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2.6.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2026.1.4)\n",
"Downloading kagglehub-1.0.1-py3-none-any.whl (70 kB)\n",
"Downloading kagglesdk-0.1.23-py3-none-any.whl (217 kB)\n",
"Installing collected packages: kagglesdk, kagglehub\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [kagglehub]\n",
"\u001b[1A\u001b[2KSuccessfully installed kagglehub-1.0.1 kagglesdk-0.1.23\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.1.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install kagglehub"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Скачиваем датасет DEAM...\n",
"Downloading to /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/1.archive...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.83G/1.83G [01:09<00:00, 28.2MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting files...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Датасет скачан во временную директорию: /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/versions/1\n",
"Переносим файлы в ../dataset/DEAM...\n",
"\n",
"[УСПЕХ] Датасет DEAM готов к работе!\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import kagglehub\n",
"from pathlib import Path\n",
"\n",
"# 1. Настройка путей\n",
"DATASET_ROOT = Path(\"../dataset\")\n",
"DEAM_ROOT = DATASET_ROOT / \"DEAM\"\n",
"DEAM_ROOT.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# 2. Загрузка через kagglehub\n",
"print(\"Скачиваем датасет DEAM...\")\n",
"kaggle_cache_path = kagglehub.dataset_download(\"imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music\")\n",
"print(f\"Датасет скачан во временную директорию: {kaggle_cache_path}\")\n",
"\n",
"# 3. Перемещение файлов в проект\n",
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
"\n",
"print(\"\\n[УСПЕХ] Датасет DEAM готов к работе!\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (my-python-project)",
"language": "python",
"name": "my-python-project"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
+171
View File
@@ -0,0 +1,171 @@
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import json
from PIL import Image
class EmoSet(Dataset):
ATTRIBUTES_MULTI_CLASS = [
'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
]
ATTRIBUTES_MULTI_LABEL = [
'object'
]
NUM_CLASSES = {
'brightness': 11,
'colorfulness': 11,
'scene': 254,
'object': 409,
'facial_expression': 6,
'human_action': 264,
}
def __init__(self,
data_root,
num_emotion_classes,
phase,
):
assert num_emotion_classes in (8, 2)
assert phase in ('train', 'val', 'test')
self.transforms_dict = self.get_data_transforms()
self.info = self.get_info(data_root, num_emotion_classes)
if phase == 'train':
self.transform = self.transforms_dict['train']
elif phase == 'val':
self.transform = self.transforms_dict['val']
elif phase == 'test':
self.transform = self.transforms_dict['test']
else:
raise NotImplementedError
data_store = json.load(open(os.path.join(data_root, f'{phase}.json')))
self.data_store = [
[
self.info['emotion']['label2idx'][item[0]],
item[1],
os.path.join(data_root, item[2]),
os.path.join(data_root, item[3])
]
for item in data_store
]
@classmethod
def get_data_transforms(cls):
transforms_dict = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
return transforms_dict
def get_info(self, data_root, num_emotion_classes):
assert num_emotion_classes in (8, 2)
info = json.load(open(os.path.join(data_root, 'info.json')))
if num_emotion_classes == 8:
pass
elif num_emotion_classes == 2:
emotion_info = {
'label2idx': {
'amusement': 0,
'awe': 0,
'contentment': 0,
'excitement': 0,
'anger': 1,
'disgust': 1,
'fear': 1,
'sadness': 1,
},
'idx2label': {
'0': 'positive',
'1': 'negative',
}
}
info['emotion'] = emotion_info
else:
raise NotImplementedError
return info
def load_image_by_path(self, path):
image = Image.open(path).convert('RGB')
image = self.transform(image)
return image
def load_annotation_by_path(self, path):
json_data = json.load(open(path))
return json_data
def __getitem__(self, item):
emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
image = self.load_image_by_path(image_path)
annotation_data = self.load_annotation_by_path(annotation_path)
data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}
for attribute in self.ATTRIBUTES_MULTI_CLASS:
# if empty, set to -1, else set to label index
attribute_label_idx = -1
if attribute in annotation_data:
attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
data.update({f'{attribute}_label_idx': attribute_label_idx})
for attribute in self.ATTRIBUTES_MULTI_LABEL:
# if empty, set to 0, else set to 1
assert attribute == 'object'
num_classes = self.NUM_CLASSES[attribute]
attribute_label_idx = torch.zeros(num_classes)
if attribute in annotation_data:
for label in annotation_data[attribute]:
attribute_label_idx[self.info[attribute]['label2idx'][label]] = 1
data.update({f'{attribute}_label_idx': attribute_label_idx})
return data
def __len__(self):
return len(self.data_store)
if __name__ == '__main__':
data_root = r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val'
num_emotion_classes = 8
phase = 'train'
dataset = EmoSet(
data_root=data_root,
num_emotion_classes=num_emotion_classes,
phase=phase,
)
# print(dataset.info)
dataloader = DataLoader(dataset, batch_size = 16, shuffle = True)
for i, data in enumerate(dataloader):
pass
# print(data['emotion_label_idx'])
# print(data['scene_label_idx'])
# print(data['facial_expression_label_idx'])
# print(data['human_action_label_idx'])
# print(data['brightness_label_idx'])
# print(data['colorfulness_label_idx'])
# print(data['object_label_idx'])
# break
File diff suppressed because one or more lines are too long
Binary file not shown.

Before

Width:  |  Height:  |  Size: 313 KiB

-12
View File
@@ -1,12 +0,0 @@
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | 1.5135 | 2.2743 | 1.8939 |
| R² | 0.7927 | 0.4321 | 0.6124 |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{final} = D_{emo} + 4.0 \cdot Acoustic_{penalty}$$
Binary file not shown.

Before

Width:  |  Height:  |  Size: 243 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

+88
View File
@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "b92e0213",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1763c51e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
"Всего треков в базе: 1744\n",
"Пример данных:\n",
" song_id valence arousal\n",
"0 2 3.1 3.0\n",
"1 3 3.5 3.3\n",
"2 4 5.7 5.5\n",
"3 5 4.4 5.3\n",
"4 7 5.8 6.4\n"
]
}
],
"source": [
"# Точный путь к оригинальным аннотациям\n",
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
"# Путь, куда сохраним очищенную базу для движка\n",
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
"\n",
"if not source_path.exists():\n",
" print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n",
"else:\n",
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
" \n",
" # Берем только нужные колонки (по твоему примеру)\n",
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
" \n",
" # Переименовываем для простоты кода в движке\n",
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
" \n",
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
" \n",
" # Сохраняем финальный файл\n",
" clean_df.to_csv(output_path, index=False)\n",
" \n",
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
" print(\"Пример данных:\")\n",
" print(clean_df.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+114
View File
@@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d70d8e32",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "31b0fa82",
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"TRANSFORM = transforms.Compose([\n",
" transforms.Resize((224,224)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a17ecf5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
]
},
{
"ename": "PicklingError",
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
]
}
],
"source": [
"def process_row(row, split_dir, tensor_dir):\n",
" img_path = split_dir / row.filename\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" tensor = TRANSFORM(img)\n",
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
" torch.save(tensor, tensor_path)\n",
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
"\n",
"for split in [\"train\",\"val\",\"test\"]:\n",
" split_dir = DATA_ROOT / split / \"images\"\n",
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
"\n",
" results = []\n",
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
" for f in tqdm(futures):\n",
" results.append(f.result())\n",
"\n",
" new_df = pd.DataFrame(results)\n",
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-319
View File
@@ -1,319 +0,0 @@
import os
import random
import warnings
from collections import defaultdict
from pathlib import Path
from PIL import Image, ImageFile
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.amp import autocast, GradScaler
import timm
# Подавление предупреждений и защита от битых "хвостов" JPEG
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# --- ПУТИ ---
TRAIN_ROOT = Path("./dataset/Original-2.41M")
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения)
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# --- НАСТРОЙКИ ---
TOTAL_BATCH_SIZE = 64
BATCH_NOISY = 48 # 75% батча - новые данные 2.41M
BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K
EPOCHS_PER_FOLDER = 15
PATIENCE = 5
LR = 1e-6
NUM_TRAIN_WORKERS = 32
NUM_VAL_WORKERS = 32
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
# --- 1. ТРАНСФОРМАЦИИ ---
train_transform = T.Compose([
T.Resize(256),
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# --- 2. ДАТАСЕТЫ ---
class ChunkTrainDataset(Dataset):
def __init__(self, paths, transform):
self.paths = paths
self.transform = transform
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
path = self.paths[idx]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(path.parts[-3].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
class CsvDataset(Dataset):
def __init__(self, root, transform):
self.root = Path(root)
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
path = self.root / "images" / row["filename"]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(row["label"].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
# --- 3. СБОР ДАННЫХ ---
def prepare_chunks():
print("\nСканирование датасета 2.41M...")
chunk_dict = defaultdict(list)
for path in TRAIN_ROOT.rglob('*.jpg'):
emotion = path.parts[-3].lower()
if emotion not in CLASS_MAPPING:
continue
folder_str = path.parts[-2]
if folder_str.isdigit():
chunk_dict[int(folder_str)].append(path)
sorted_chunks = sorted(chunk_dict.keys())
print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
return chunk_dict, sorted_chunks
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
if __name__ == "__main__":
chunk_dict, sorted_chunks = prepare_chunks()
# Валидационный датасет (только чистые данные)
val_loader = DataLoader(
CsvDataset(VAL_118K_ROOT, val_transform),
batch_size=TOTAL_BATCH_SIZE, shuffle=False,
num_workers=NUM_VAL_WORKERS, pin_memory=True
)
# ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания)
# Используем prefetch_factor и persistent_workers для устранения рывков CPU
anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
anchor_loader = DataLoader(
anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
num_workers=16, pin_memory=True, drop_last=True,
prefetch_factor=2, persistent_workers=False
)
# Инициализация модели
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
if PRETRAINED_PATH.exists():
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")
# Размораживаем всю модель
for param in model.parameters():
param.requires_grad = True
# Дифференцированный оптимизатор
backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
fc_params = [p for n, p in model.named_parameters() if "fc" in n]
optimizer = torch.optim.AdamW([
{'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы
{'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора
], weight_decay=1e-3)
# Label Smoothing помогает игнорировать мусор в разметке 2.41M
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
scaler = GradScaler()
# --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ ---
start_stage = 0
start_epoch = 1
best_val_loss = float('inf')
if RESUME_CHECKPOINT.exists():
print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
try:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
except Exception as e:
print(f"Оптимизатор сброшен: {e}")
best_val_loss = checkpoint['best_val_loss']
start_stage = checkpoint['stage']
start_epoch = checkpoint['epoch'] + 1
print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
else:
# --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) ---
# Выполняется только если мы начинаем с нуля
print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
best_val_loss = val_loss / val_total
baseline_acc = val_correct / val_total
print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")
# ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ
current_train_paths = []
for s in range(start_stage):
current_train_paths.extend(chunk_dict[sorted_chunks[s]])
print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")
# ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ
for stage in range(start_stage, len(sorted_chunks)):
chunk_id = sorted_chunks[stage]
print(f"\n{'='*50}")
print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")
# Накопление и перемешивание
current_train_paths.extend(chunk_dict[chunk_id])
random.shuffle(current_train_paths)
print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")
# ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH
train_loader = DataLoader(
ChunkTrainDataset(current_train_paths, train_transform),
batch_size=BATCH_NOISY, shuffle=True,
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
worker_init_fn=worker_init_fn, drop_last=True,
prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU
)
epochs_no_improve = 0
first_epoch = start_epoch if stage == start_stage else 1
# Инициализация итератора якорей
anchor_iter = iter(anchor_loader)
# ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА
for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
model.train()
train_loss, train_correct, train_total = 0.0, 0, 0
for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
# Достаем якорный чистый батч
try:
anc_inputs, anc_labels = next(anchor_iter)
except StopIteration:
anchor_iter = iter(anchor_loader)
anc_inputs, anc_labels = next(anchor_iter)
# СМЕШИВАЕМ БАТЧИ (Грязные + Чистые)
inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)
optimizer.zero_grad(set_to_none=True)
with autocast(device_type="cuda"):
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item() * inputs.size(0)
_, pred = outputs.max(1)
train_total += labels.size(0)
train_correct += pred.eq(labels).sum().item()
# ВАЛИДАЦИЯ
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
avg_train_loss = train_loss / train_total
avg_train_acc = train_correct / train_total
avg_val_loss = val_loss / val_total
avg_val_acc = val_correct / val_total
print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")
# СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
epochs_no_improve = 0
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print("--> Обновлены лучшие веса")
else:
epochs_no_improve += 1
# АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ
checkpoint_state = {
'stage': stage,
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss
}
torch.save(checkpoint_state, RESUME_CHECKPOINT)
os.sync() # Защита от отключения электричества
print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")
# РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
break
# Сброс счетчика стартовой эпохи после прохождения восстановительного этапа
start_epoch = 1
+199
View File
@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ca08df84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0/1000, Loss: 1.0013\n",
"Step 10/1000, Loss: 1.0088\n",
"Step 20/1000, Loss: 0.9956\n",
"Step 30/1000, Loss: 0.9781\n",
"Step 40/1000, Loss: 0.9613\n",
"Step 50/1000, Loss: 0.9313\n",
"Step 60/1000, Loss: 0.8927\n",
"Step 70/1000, Loss: 0.8503\n",
"Step 80/1000, Loss: 0.7537\n",
"Step 90/1000, Loss: 0.6689\n",
"Step 100/1000, Loss: 0.6063\n",
"Step 110/1000, Loss: 0.5172\n",
"Step 120/1000, Loss: 0.4592\n",
"Step 130/1000, Loss: 0.4044\n",
"Step 140/1000, Loss: 0.3610\n",
"Step 150/1000, Loss: 0.3175\n",
"Step 160/1000, Loss: 0.2825\n",
"Step 170/1000, Loss: 0.2560\n",
"Step 180/1000, Loss: 0.2360\n",
"Step 190/1000, Loss: 0.2203\n",
"Step 200/1000, Loss: 0.1930\n",
"Step 210/1000, Loss: 0.1854\n",
"Step 220/1000, Loss: 0.1723\n",
"Step 230/1000, Loss: 0.1546\n",
"Step 240/1000, Loss: 0.1386\n",
"Step 250/1000, Loss: 0.1271\n",
"Step 260/1000, Loss: 0.1109\n",
"Step 270/1000, Loss: 0.1032\n",
"Step 280/1000, Loss: 0.0899\n",
"Step 290/1000, Loss: 0.0807\n",
"Step 300/1000, Loss: 0.0750\n",
"Step 310/1000, Loss: 0.0813\n",
"Step 320/1000, Loss: 0.0612\n",
"Step 330/1000, Loss: 0.0544\n",
"Step 340/1000, Loss: 0.0552\n",
"Step 350/1000, Loss: 0.0446\n",
"Step 360/1000, Loss: 0.0403\n",
"Step 370/1000, Loss: 0.0350\n",
"Step 380/1000, Loss: 0.0612\n",
"Step 390/1000, Loss: 0.0364\n",
"Step 400/1000, Loss: 0.0322\n",
"Step 410/1000, Loss: 0.0302\n",
"Step 420/1000, Loss: 0.0519\n",
"Step 430/1000, Loss: 0.0319\n",
"Step 440/1000, Loss: 0.0260\n",
"Step 450/1000, Loss: 0.0208\n",
"Step 460/1000, Loss: 0.0409\n",
"Step 470/1000, Loss: 0.0291\n",
"Step 480/1000, Loss: 0.0234\n",
"Step 490/1000, Loss: 0.0194\n",
"Step 500/1000, Loss: 0.0274\n",
"Step 510/1000, Loss: 0.0231\n",
"Step 520/1000, Loss: 0.0199\n",
"Step 530/1000, Loss: 0.0154\n",
"Step 540/1000, Loss: 0.0278\n",
"Step 550/1000, Loss: 0.0185\n",
"Step 560/1000, Loss: 0.0180\n",
"Step 570/1000, Loss: 0.0152\n",
"Step 580/1000, Loss: 0.0132\n",
"Step 590/1000, Loss: 0.0111\n",
"Step 600/1000, Loss: 0.0396\n",
"Step 610/1000, Loss: 0.0179\n",
"Step 620/1000, Loss: 0.0148\n",
"Step 630/1000, Loss: 0.0123\n",
"Step 640/1000, Loss: 0.0265\n",
"Step 650/1000, Loss: 0.0133\n",
"Step 660/1000, Loss: 0.0128\n",
"Step 670/1000, Loss: 0.0107\n",
"Step 680/1000, Loss: 0.0142\n",
"Step 690/1000, Loss: 0.0202\n",
"Step 700/1000, Loss: 0.0125\n",
"Step 710/1000, Loss: 0.0107\n",
"Step 720/1000, Loss: 0.0140\n",
"Step 730/1000, Loss: 0.0195\n",
"Step 740/1000, Loss: 0.0148\n",
"Step 750/1000, Loss: 0.0109\n",
"Step 760/1000, Loss: 0.0094\n",
"Step 770/1000, Loss: 0.0121\n",
"Step 780/1000, Loss: 0.0233\n",
"Step 790/1000, Loss: 0.0151\n",
"Step 800/1000, Loss: 0.0134\n",
"Step 810/1000, Loss: 0.0117\n",
"Step 820/1000, Loss: 0.0124\n",
"Step 830/1000, Loss: 0.0221\n",
"Step 840/1000, Loss: 0.0161\n",
"Step 850/1000, Loss: 0.0136\n",
"Step 860/1000, Loss: 0.0161\n",
"Step 870/1000, Loss: 0.0194\n",
"Step 880/1000, Loss: 0.0145\n",
"Step 890/1000, Loss: 0.0149\n",
"Step 900/1000, Loss: 0.0232\n",
"Step 910/1000, Loss: 0.0166\n",
"Step 920/1000, Loss: 0.0156\n",
"Step 930/1000, Loss: 0.0276\n",
"Step 940/1000, Loss: 0.0176\n",
"Step 950/1000, Loss: 0.0152\n",
"Step 960/1000, Loss: 0.0162\n",
"Step 970/1000, Loss: 0.0143\n",
"Step 980/1000, Loss: 0.0136\n",
"Step 990/1000, Loss: 0.0117\n",
"Total time: 67.25 s\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import time\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Using device:\", device)\n",
"\n",
"\n",
"# Огромные параметры\n",
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
"batch_size = 16_384 # большой батч\n",
"steps = 1000 # много итераций для длительной нагрузки\n",
"\n",
"# Случайные данные на GPU\n",
"x = torch.randn(N, D_in, device=device, dtype=torch.float32)\n",
"y = torch.randn(N, D_out, device=device, dtype=torch.float32)\n",
"\n",
"model = nn.Sequential(\n",
" nn.Linear(D_in, H1),\n",
" nn.ReLU(),\n",
" nn.Linear(H1, H2),\n",
" nn.ReLU(),\n",
" nn.Linear(H2, H3),\n",
" nn.ReLU(),\n",
" nn.Linear(H3, D_out)\n",
").to(device)\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"start = time.time()\n",
"for t in range(steps):\n",
" idx = torch.randint(0, N, (batch_size,), device=device)\n",
" x_batch = x[idx]\n",
" y_batch = y[idx]\n",
"\n",
" y_pred = model(x_batch)\n",
" loss = loss_fn(y_pred, y_batch)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if t % 10 == 0:\n",
" # замедляем вывод, чтобы можно было наблюдать\n",
" print(f\"Step {t}/{steps}, Loss: {loss.item():.4f}\")\n",
"\n",
"end = time.time()\n",
"print(f\"Total time: {end-start:.2f} s\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+759
View File
@@ -0,0 +1,759 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9336560f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"import timm\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "84c3657f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# === CONFIG ===\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 24\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"DEVICE\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
" def __init__(self, root, split):\n",
" self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
" self.labels = sorted(self.df[\"label\"].unique())\n",
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
" self.transform = T.Compose([\n",
" T.Resize((224, 224)),\n",
" T.ToTensor(),\n",
" T.Normalize(\n",
" mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]\n",
" )\n",
" ])\n",
"\n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
" def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" img = self.transform(img)\n",
"\n",
" label = self.label2idx[row[\"label\"]]\n",
" return img, label\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8805341",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
]
}
],
"source": [
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" val_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=False,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"print(\"Classes:\", train_ds.labels)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dffce582",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = timm.create_model(\n",
" \"resnet50\",\n",
" pretrained=True,\n",
" num_classes=len(train_ds.labels)\n",
")\n",
"\n",
"model.to(DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"optimizer = torch.optim.AdamW(\n",
" model.parameters(),\n",
" lr=LR,\n",
" weight_decay=1e-4\n",
")\n",
"\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer,\n",
" T_max=EPOCHS\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, loader):\n",
" model.train()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"\n",
" for imgs, labels in tqdm(loader, leave=False):\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb7e9398",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def val_epoch(model, loader):\n",
" model.eval()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"\n",
" for imgs, labels in loader:\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" logits = model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9e870e5d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1477 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
]
}
],
"source": [
"best_val_acc = 0.0\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Epoch {epoch:02d} | \"\n",
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7796ef11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+65
View File
@@ -0,0 +1,65 @@
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
# 1. Алфавитный маппинг EmoSet
EMO_VA_MAP = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
BASE_DIR = Path(__file__).resolve().parent.parent
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"
print("Загрузка данных...")
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
# 2. НОВАЯ, ПРАВИЛЬНАЯ АРХИТЕКТУРА (Pipeline)
print("Обучение масштабатора и RidgeCV регрессора...")
# Pipeline гарантирует, что при предсказании в main.py новые векторы тоже будут масштабированы
model = Pipeline([
('scaler', StandardScaler()),
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
])
model.fit(X_train, y_train)
# 3. Диагностика и Оценка
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"\n[УСПЕХ] Обучение завершено!")
print(f"MSE: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")
# === ТОТ САМЫЙ ТЕСТ НА КОЛЛАПС ===
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
# ===============================================
# 4. Сохранение (Pipeline сохраняется целиком со StandardScaler)
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
print(f"\nМодель сохранена в: {output_model_path}")
+3 -4
View File
@@ -42,7 +42,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
st.rerun() st.rerun()
else: else:
st.success("Анализ завершен! Ваш эмоциональный профиль готов.") st.success("Анализ завершен! Ваш эмоциональный профиль готов.")
all_v, all_a = [], [] all_v, all_a = [], []
for idx in st.session_state.ds_chosen_indices: for idx in st.session_state.ds_chosen_indices:
@@ -56,7 +56,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
col_left, col_right = st.columns([1, 2]) col_left, col_right = st.columns([1, 2])
with col_left: with col_left:
st.header("Ваш профиль") st.header("📊 Ваш профиль")
st.metric("Позитивность (Valence)", f"{target_v:.2f}") st.metric("Позитивность (Valence)", f"{target_v:.2f}")
st.metric("Энергия (Arousal)", f"{target_a:.2f}") st.metric("Энергия (Arousal)", f"{target_a:.2f}")
@@ -74,8 +74,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
c1, c2 = st.columns([1, 3]) c1, c2 = st.columns([1, 3])
with c1: with c1:
st.write(f"**ID:** {int(row['song_id'])}") st.write(f"**ID:** {int(row['song_id'])}")
score_val = row.get('final_score', row.get('emo_distance', 0)) st.caption(f"L2 Dist: {row['distance']:.2f}")
st.caption(f"Dist Score: {score_val:.2f}")
with c2: with c2:
audio_path = matcher.get_audio_path(row['song_id']) audio_path = matcher.get_audio_path(row['song_id'])
if audio_path: if audio_path:
+59 -185
View File
@@ -1,208 +1,82 @@
import streamlit as st import streamlit as st
import streamlit.components.v1 as components
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import base64 import matplotlib.pyplot as plt
from io import BytesIO from music_engine.llm_bridge import LLMAcousticBridge # ИМПОРТИРУЕМ МОСТ
from music_engine.llm_bridge import LLMAcousticBridge
# Вспомогательная функция для крохотного предпросмотра
def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100)) # Сжимаем картинку
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
# Строгие стили для квадратных миниатюр
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
# Индикатор оставшихся фото, если их много
if len(images) > max_display:
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
def render_live_tab(matcher, image_processor): def render_live_tab(matcher, image_processor):
if "live_state" not in st.session_state: st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.")
st.session_state.live_state = "upload"
if "result_data" not in st.session_state:
st.session_state.result_data = None
viewport = st.query_params.get("viewport", "desktop") uploaded_files = st.file_uploader(
"Перетащите изображения сюда",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True
)
# ========================================== if uploaded_files:
# CSS ИНЪЕКЦИИ st.subheader("Анализ визуальных признаков:")
# ==========================================
st.markdown("""
<style>
[data-testid="stFileUploadDropzone"] {
min-height: 250px !important;
display: flex;
align-items: center;
justify-content: center;
border-radius: 16px;
background-color: rgba(255, 75, 75, 0.03);
}
.spinner-container {
display: flex; flex-direction: column; align-items: center;
justify-content: center; min-height: 40vh; margin-top: 10vh;
}
.big-spinner {
width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1);
border-top: 10px solid #ff4b4b; border-radius: 50%;
animation: spin 1s linear infinite; margin-bottom: 2rem;
}
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
</style>
""", unsafe_allow_html=True)
# ==========================================
# ЭКРАН 1: ЗАГРУЗКА
# ==========================================
if st.session_state.live_state == "upload":
upload_placeholder = st.empty()
with upload_placeholder.container():
st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
"Перетащите изображения сюда",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
)
if uploaded_files:
# 1. КНОПКА СРАЗУ ПОСЛЕ ЗАГРУЗКИ (Не нужно скроллить вниз)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Сгенерировать саундтрек", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
# 2. МИНИАТЮРЫ ПОД КНОПКОЙ
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные кадры:")
# Генерируем компактный блок миниатюр
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
# ==========================================
# ЭКРАН 2: АНАЛИЗ (СПИННЕР)
# ==========================================
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div></div>', unsafe_allow_html=True)
status_text = st.empty()
cols = st.columns(min(len(uploaded_files), 5))
images = [] images = []
all_objects = [] all_objects = []
all_v, all_a = [], []
for i, file in enumerate(files):
status_text.markdown(f"<h3 style='text-align: center; font-weight: 400;'>Анализ кадра {i + 1} из {len(files)}...</h3>", unsafe_allow_html=True)
for i, file in enumerate(uploaded_files):
img = Image.open(file) img = Image.open(file)
images.append(img) images.append(img)
with cols[i % 5]:
st.image(img, use_container_width=True)
with st.spinner("VLM Анализ..."):
caption = image_processor.describe_scene(img)
st.caption(f"👁️ *{caption.capitalize()}*")
all_objects.append(caption)
embedding = image_processor.extract_embedding(img) if st.button("🎵 Сгенерировать саундтрек", type="primary", use_container_width=True):
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
caption = image_processor.describe_scene(img) # 1. Извлекаем эмоции
all_objects.append(caption) all_v, all_a = [], []
for img in images:
embedding = image_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
target_v, target_a = np.mean(all_v), np.mean(all_a) target_v, target_a = np.mean(all_v), np.mean(all_a)
status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Трансляция семантики в аудиопрофиль...</h3>", unsafe_allow_html=True) # 2. Переводим Объекты -> Акустику через LLM
llm = LLMAcousticBridge() with st.spinner("Phi-3 генерирует акустический профиль..."):
llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects))) llm = LLMAcousticBridge()
llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects)))
status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Поиск идеальных композиций...</h3>", unsafe_allow_html=True) # 3. Ищем треки
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15) with st.spinner("Поиск треков в базе DEAM..."):
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=5)
st.session_state.result_data = { st.success("✅ Кросс-модальный анализ завершен!")
"target_v": target_v,
"target_a": target_a,
"llm_profile": llm_profile,
"playlist": playlist,
"semantics": list(set(all_objects))
}
st.session_state.live_state = "result"
st.rerun()
# ========================================== # ВЫВОД РЕЗУЛЬТАТОВ
# ЭКРАН 3: РЕЗУЛЬТАТЫ col_left, col_right = st.columns([1, 2])
# ==========================================
elif st.session_state.live_state == "result":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0) with col_left:
st.header("📊 Профиль")
st.metric("Valence (Настроение)", f"{target_v:.2f}")
st.metric("Arousal (Энергия)", f"{target_a:.2f}")
data = st.session_state.result_data if llm_profile:
st.header("Рекомендованный плейлист") st.write("**Требования LLM к звуку:**")
for k, v in llm_profile.items():
st.caption(f"- {k}: {v:.2f}")
for _, row in data["playlist"].iterrows(): with col_right:
with st.container(border=True): st.header("🎵 Плейлист")
if viewport == "desktop": for _, row in playlist.iterrows():
c1, c2 = st.columns([1, 3]) with st.container(border=True):
with c1: c1, c2 = st.columns([1, 3])
st.write(f"**Track:** {int(row['song_id'])}") with c1:
st.caption(f"Score: {row['final_score']:.2f}") st.write(f"**Track:** {int(row['song_id'])}")
with c2: st.caption(f"Score: {row['final_score']:.2f}")
audio_path = matcher.get_audio_path(row['song_id']) with c2:
if audio_path: audio_path = matcher.get_audio_path(row['song_id'])
st.audio(str(audio_path)) if audio_path:
else: st.audio(str(audio_path))
st.warning("Файл не найден") else:
else: st.warning("Файл не найден")
st.write(f"**Track:** {int(row['song_id'])} (Score: {row['final_score']:.2f})")
audio_path = matcher.get_audio_path(row['song_id'])
if audio_path:
st.audio(str(audio_path))
else:
st.warning("Файл не найден")
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Технические параметры анализа"):
c_v, c_a = st.columns(2)
c_v.metric("Valence (Настроение)", f"{data['target_v']:.2f}")
c_a.metric("Arousal (Энергия)", f"{data['target_a']:.2f}")
st.markdown("---")
st.write("**Акустические таргеты (LLM):**")
if data["llm_profile"]:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(data["llm_profile"].items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
cols[j].metric(k, f"{v:.2f}")
st.markdown("---")
st.write("**Обнаруженная семантика:**")
st.write(", ".join([str(c).capitalize() for c in data["semantics"]]))
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый анализ", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
-461
View File
@@ -1,461 +0,0 @@
.
├── bin
│   ├── activate
│   ├── activate.csh
│   ├── activate.fish
│   ├── activate.nu
│   ├── activate.ps1
│   ├── activate_this.py
│   ├── debugpy
│   ├── debugpy-adapter
│   ├── f2py
│   ├── fonttools
│   ├── httpx
│   ├── ipython
│   ├── ipython3
│   ├── isympy
│   ├── jlpm
│   ├── jsonpointer
│   ├── jsonschema
│   ├── jupyter
│   ├── jupyter-dejavu
│   ├── jupyter-events
│   ├── jupyter-execute
│   ├── jupyter-kernel
│   ├── jupyter-kernelspec
│   ├── jupyter-lab
│   ├── jupyter-labextension
│   ├── jupyter-labhub
│   ├── jupyter-migrate
│   ├── jupyter-nbconvert
│   ├── jupyter-run
│   ├── jupyter-server
│   ├── jupyter-troubleshoot
│   ├── jupyter-trust
│   ├── normalizer
│   ├── numpy-config
│   ├── pip
│   ├── pip3
│   ├── pip3.12
│   ├── proton
│   ├── proton-viewer
│   ├── pybabel
│   ├── pyftmerge
│   ├── pyftsubset
│   ├── pygmentize
│   ├── pyjson5
│   ├── python -> /usr/bin/python3
│   ├── python3 -> python
│   ├── python3.12 -> python
│   ├── send2trash
│   ├── streamlit
│   ├── streamlit.cmd
│   ├── torchfrtrace
│   ├── torchrun
│   ├── tqdm
│   ├── ttx
│   ├── watchmedo
│   └── wsdump
├── CACHEDIR.TAG
├── docker
│   ├── Dockerfile.api
│   └── Dockerfile.ui
├── docker-compose.yml
├── Dockerfile
├── .dockerignore
├── .env
├── etc
│   └── jupyter
│   ├── jupyter_notebook_config.d
│   │   └── jupyterlab.json
│   ├── jupyter_server_config.d
│   │   ├── jupyterlab.json
│   │   ├── jupyter-lsp-jupyter-server.json
│   │   ├── jupyter_server_terminals.json
│   │   └── notebook_shim.json
│   └── nbconfig
│   └── notebook.d
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── Thesis.iml
│   ├── vcs.xml
│   └── workspace.xml
├── lib
│   └── python3.12
│   └── site-packages
│   ├── altair
│   ├── altair-6.0.0.dist-info
│   ├── anyio
│   ├── anyio-4.12.1.dist-info
│   ├── argon2
│   ├── argon2_cffi-25.1.0.dist-info
│   ├── _argon2_cffi_bindings
│   ├── argon2_cffi_bindings-25.1.0.dist-info
│   ├── arrow
│   ├── arrow-1.4.0.dist-info
│   ├── asttokens
│   ├── asttokens-3.0.1.dist-info
│   ├── async_lru
│   ├── async_lru-2.0.5.dist-info
│   ├── attr
│   ├── attrs
│   ├── attrs-25.4.0.dist-info
│   ├── babel
│   ├── babel-2.17.0.dist-info
│   ├── beautifulsoup4-4.14.3.dist-info
│   ├── bleach
│   ├── bleach-6.3.0.dist-info
│   ├── blinker
│   ├── blinker-1.9.0.dist-info
│   ├── bs4
│   ├── cachetools
│   ├── cachetools-6.2.4.dist-info
│   ├── certifi
│   ├── certifi-2026.1.4.dist-info
│   ├── cffi
│   ├── cffi-2.0.0.dist-info
│   ├── _cffi_backend.cpython-312-x86_64-linux-gnu.so
│   ├── charset_normalizer
│   ├── charset_normalizer-3.4.4.dist-info
│   ├── click
│   ├── click-8.3.1.dist-info
│   ├── comm
│   ├── comm-0.2.3.dist-info
│   ├── contourpy
│   ├── contourpy-1.3.3.dist-info
│   ├── cycler
│   ├── cycler-0.12.1.dist-info
│   ├── dateutil
│   ├── debugpy
│   ├── debugpy-1.8.19.dist-info
│   ├── decorator-5.2.1.dist-info
│   ├── decorator.py
│   ├── defusedxml
│   ├── defusedxml-0.7.1.dist-info
│   ├── _distutils_hack
│   ├── distutils-precedence.pth
│   ├── .DS_Store
│   ├── executing
│   ├── executing-2.2.1.dist-info
│   ├── fastjsonschema
│   ├── fastjsonschema-2.21.2.dist-info
│   ├── filelock
│   ├── filelock-3.20.3.dist-info
│   ├── fontTools
│   ├── fonttools-4.61.1.dist-info
│   ├── fqdn
│   ├── fqdn-1.5.1.dist-info
│   ├── fsspec
│   ├── fsspec-2026.1.0.dist-info
│   ├── functorch
│   ├── git
│   ├── gitdb
│   ├── gitdb-4.0.12.dist-info
│   ├── gitpython-3.1.46.dist-info
│   ├── google
│   ├── h11
│   ├── h11-0.16.0.dist-info
│   ├── httpcore
│   ├── httpcore-1.0.9.dist-info
│   ├── httpx
│   ├── httpx-0.28.1.dist-info
│   ├── idna
│   ├── idna-3.11.dist-info
│   ├── ipykernel
│   ├── ipykernel-7.1.0.dist-info
│   ├── ipykernel_launcher.py
│   ├── IPython
│   ├── ipython-9.9.0.dist-info
│   ├── ipython_pygments_lexers-1.1.1.dist-info
│   ├── ipython_pygments_lexers.py
│   ├── isoduration
│   ├── isoduration-20.11.0.dist-info
│   ├── isympy.py
│   ├── jedi
│   ├── jedi-0.19.2.dist-info
│   ├── jinja2
│   ├── jinja2-3.1.6.dist-info
│   ├── joblib
│   ├── joblib-1.5.3.dist-info
│   ├── json5
│   ├── json5-0.13.0.dist-info
│   ├── jsonpointer-3.0.0.dist-info
│   ├── jsonpointer.py
│   ├── jsonschema
│   ├── jsonschema-4.26.0.dist-info
│   ├── jsonschema_specifications
│   ├── jsonschema_specifications-2025.9.1.dist-info
│   ├── jupyter_client
│   ├── jupyter_client-8.8.0.dist-info
│   ├── jupyter_core
│   ├── jupyter_core-5.9.1.dist-info
│   ├── jupyter_events
│   ├── jupyter_events-0.12.0.dist-info
│   ├── jupyterlab
│   ├── jupyterlab-4.5.1.dist-info
│   ├── jupyterlab_pygments
│   ├── jupyterlab_pygments-0.3.0.dist-info
│   ├── jupyterlab_server
│   ├── jupyterlab_server-2.28.0.dist-info
│   ├── jupyter_lsp
│   ├── jupyter_lsp-2.3.0.dist-info
│   ├── jupyter.py
│   ├── jupyter_server
│   ├── jupyter_server-2.17.0.dist-info
│   ├── jupyter_server_terminals
│   ├── jupyter_server_terminals-0.5.3.dist-info
│   ├── kiwisolver
│   ├── kiwisolver-1.4.9.dist-info
│   ├── lark
│   ├── lark-1.3.1.dist-info
│   ├── markupsafe
│   ├── markupsafe-3.0.3.dist-info
│   ├── matplotlib
│   ├── matplotlib-3.10.8.dist-info
│   ├── matplotlib_inline
│   ├── matplotlib_inline-0.2.1.dist-info
│   ├── mistune
│   ├── mistune-3.2.0.dist-info
│   ├── mpl_toolkits
│   ├── mpmath
│   ├── mpmath-1.3.0.dist-info
│   ├── narwhals
│   ├── narwhals-2.15.0.dist-info
│   ├── nbclient
│   ├── nbclient-0.10.4.dist-info
│   ├── nbconvert
│   ├── nbconvert-7.16.6.dist-info
│   ├── nbformat
│   ├── nbformat-5.10.4.dist-info
│   ├── nest_asyncio-1.6.0.dist-info
│   ├── nest_asyncio.py
│   ├── networkx
│   ├── networkx-3.6.1.dist-info
│   ├── notebook_shim
│   ├── notebook_shim-0.2.4.dist-info
│   ├── numpy
│   ├── numpy-2.4.1.dist-info
│   ├── numpy.libs
│   ├── nvidia
│   ├── nvidia_cublas_cu12-12.8.4.1.dist-info
│   ├── nvidia_cuda_cupti_cu12-12.8.90.dist-info
│   ├── nvidia_cuda_nvrtc_cu12-12.8.93.dist-info
│   ├── nvidia_cuda_runtime_cu12-12.8.90.dist-info
│   ├── nvidia_cudnn_cu12-9.10.2.21.dist-info
│   ├── nvidia_cufft_cu12-11.3.3.83.dist-info
│   ├── nvidia_cufile_cu12-1.13.1.3.dist-info
│   ├── nvidia_curand_cu12-10.3.9.90.dist-info
│   ├── nvidia_cusolver_cu12-11.7.3.90.dist-info
│   ├── nvidia_cusparse_cu12-12.5.8.93.dist-info
│   ├── nvidia_cusparselt_cu12-0.7.1.dist-info
│   ├── nvidia_nccl_cu12-2.27.5.dist-info
│   ├── nvidia_nvjitlink_cu12-12.8.93.dist-info
│   ├── nvidia_nvshmem_cu12-3.3.20.dist-info
│   ├── nvidia_nvtx_cu12-12.8.90.dist-info
│   ├── packaging
│   ├── packaging-25.0.dist-info
│   ├── pandas
│   ├── pandas-2.3.3.dist-info
│   ├── pandocfilters-1.5.1.dist-info
│   ├── pandocfilters.py
│   ├── parso
│   ├── parso-0.8.5.dist-info
│   ├── pexpect
│   ├── pexpect-4.9.0.dist-info
│   ├── PIL
│   ├── pillow-12.1.0.dist-info
│   ├── pillow.libs
│   ├── pip
│   ├── pip-25.3.dist-info
│   ├── pkg_resources
│   ├── platformdirs
│   ├── platformdirs-4.5.1.dist-info
│   ├── prometheus_client
│   ├── prometheus_client-0.23.1.dist-info
│   ├── prompt_toolkit
│   ├── prompt_toolkit-3.0.52.dist-info
│   ├── protobuf-6.33.4.dist-info
│   ├── psutil
│   ├── psutil-7.2.1.dist-info
│   ├── ptyprocess
│   ├── ptyprocess-0.7.0.dist-info
│   ├── pure_eval
│   ├── pure_eval-0.2.3.dist-info
│   ├── pyarrow
│   ├── pyarrow-22.0.0.dist-info
│   ├── pycparser
│   ├── pycparser-2.23.dist-info
│   ├── pydeck
│   ├── pydeck-0.9.1.dist-info
│   ├── pygments
│   ├── pygments-2.19.2.dist-info
│   ├── pylab.py
│   ├── pyparsing
│   ├── pyparsing-3.3.1.dist-info
│   ├── python_dateutil-2.9.0.post0.dist-info
│   ├── pythonjsonlogger
│   ├── python_json_logger-4.0.0.dist-info
│   ├── pytz
│   ├── pytz-2025.2.dist-info
│   ├── pyyaml-6.0.3.dist-info
│   ├── pyzmq-27.1.0.dist-info
│   ├── pyzmq.libs
│   ├── referencing
│   ├── referencing-0.37.0.dist-info
│   ├── requests
│   ├── requests-2.32.5.dist-info
│   ├── rfc3339_validator-0.1.4.dist-info
│   ├── rfc3339_validator.py
│   ├── rfc3986_validator-0.1.1.dist-info
│   ├── rfc3986_validator.py
│   ├── rfc3987_syntax
│   ├── rfc3987_syntax-1.1.0.dist-info
│   ├── rpds
│   ├── rpds_py-0.30.0.dist-info
│   ├── scikit_learn-1.8.0.dist-info
│   ├── scikit_learn.libs
│   ├── scipy
│   ├── scipy-1.17.0.dist-info
│   ├── scipy.libs
│   ├── send2trash
│   ├── send2trash-2.0.0.dist-info
│   ├── setuptools
│   ├── setuptools-80.9.0.dist-info
│   ├── six-1.17.0.dist-info
│   ├── six.py
│   ├── sklearn
│   ├── smmap
│   ├── smmap-5.0.2.dist-info
│   ├── soupsieve
│   ├── soupsieve-2.8.1.dist-info
│   ├── stack_data
│   ├── stack_data-0.6.3.dist-info
│   ├── streamlit
│   ├── streamlit-1.53.0.dist-info
│   ├── sympy
│   ├── sympy-1.14.0.dist-info
│   ├── tenacity
│   ├── tenacity-9.1.2.dist-info
│   ├── terminado
│   ├── terminado-0.18.1.dist-info
│   ├── threadpoolctl-3.6.0.dist-info
│   ├── threadpoolctl.py
│   ├── tinycss2
│   ├── tinycss2-1.4.0.dist-info
│   ├── toml
│   ├── toml-0.10.2.dist-info
│   ├── torch
│   ├── torch-2.9.1.dist-info
│   ├── torchaudio
│   ├── torchaudio-2.9.1.dist-info
│   ├── torchgen
│   ├── torchvision
│   ├── torchvision-0.24.1.dist-info
│   ├── torchvision.libs
│   ├── tornado
│   ├── tornado-6.5.4.dist-info
│   ├── tqdm
│   ├── tqdm-4.67.1.dist-info
│   ├── traitlets
│   ├── traitlets-5.14.3.dist-info
│   ├── triton
│   ├── triton-3.5.1.dist-info
│   ├── typing_extensions-4.15.0.dist-info
│   ├── typing_extensions.py
│   ├── tzdata
│   ├── tzdata-2025.3.dist-info
│   ├── uri_template
│   ├── uri_template-1.3.0.dist-info
│   ├── urllib3
│   ├── urllib3-2.6.3.dist-info
│   ├── _virtualenv.pth
│   ├── _virtualenv.py
│   ├── watchdog
│   ├── watchdog-6.0.0.dist-info
│   ├── wcwidth
│   ├── wcwidth-0.2.14.dist-info
│   ├── webcolors
│   ├── webcolors-25.10.0.dist-info
│   ├── webencodings
│   ├── webencodings-0.5.1.dist-info
│   ├── websocket
│   ├── websocket_client-1.9.0.dist-info
│   ├── _yaml
│   ├── yaml
│   └── zmq
├── Makefile
├── NFS
├── poetry.lock
├── pyproject.toml
├── pyvenv.cfg
├── README.md
├── requirements.txt
├── runs
├── share
│   ├── applications
│   │   └── jupyterlab.desktop
│   ├── icons
│   │   └── hicolor
│   │   └── scalable
│   ├── jupyter
│   │   ├── kernels
│   │   │   └── python3
│   │   ├── lab
│   │   │   ├── schemas
│   │   │   ├── static
│   │   │   └── themes
│   │   ├── labextensions
│   │   │   └── jupyterlab_pygments
│   │   ├── nbconvert
│   │   │   └── templates
│   │   └── nbextensions
│   │   └── pydeck
│   └── man
│   └── man1
│   ├── ipython.1
│   ├── isympy.1
│   └── ttx.1
├── src
│   ├── 5_epoch_emoset_resnet50_finetuned_2.41M.pth
│   ├── api.py
│   ├── data_loader.py
│   ├── dataset_paths_cache.pkl
│   ├── emoset_resnet50_best.pth
│   ├── emoset_resnet50_finetuned_2_41M_best.pth
│   ├── emoset_resnet50_resume.pth
│   ├── emoset_test_embeddings.npy
│   ├── emoset_test_labels.npy
│   ├── main.py
│   ├── music_engine
│   │   ├── image_processor.py
│   │   ├── __init__.py
│   │   ├── llm_bridge.py
│   │   ├── matcher.py
│   │   └── va_regressor.pkl
│   ├── scripts
│   │   ├── 00_setup_env.sh
│   │   ├── 01_download_DEAM.py
│   │   ├── 02_download_EmoSet.py
│   │   ├── 11_prerp_DEAM.py
│   │   ├── 20_bench_GPU.py
│   │   ├── 21_train_images.ipynb
│   │   ├── 22_extract_embeddings.ipynb
│   │   ├── 23_aggregate_DEAM_timeline.py
│   │   ├── 24_train_regressor.py
│   │   ├── 31_finetune_2.41M.py
│   │   ├── 90_acc_images_model.ipynb
│   │   └── 91_generate_metrics.py
│   └── tabs
│   ├── tab_dataset.py
│   └── tab_live.py
├── tree.txt
└── .vscode
├── launch.json
└── tasks.json
322 directories, 137 files