Compare commits

...

14 Commits

Author SHA1 Message Date
zin 934a4cbff4 feat: add metrics 2026-06-16 04:59:51 +00:00
zin 14968dd4d4 feat: commiting model 2026-06-08 14:49:45 +00:00
zin daba573b2c feat: improving finetuning 2026-06-06 21:06:21 +00:00
zin 8648e52106 feat: refactor code and finetune OOM fix 2026-06-03 09:34:34 +00:00
zin e3a2eb3289 Merge pull request 'Docker integration into project' (#1) from Dockered into main
Reviewed-on: #1
2026-06-03 12:18:10 +03:00
zin a57addcbb1 feat: finale 2026-06-03 09:16:12 +00:00
zin 3850b15053 feat: Init 2026-06-02 22:39:11 +00:00
zin 875616730b feat: docker integration 2026-06-02 19:18:37 +00:00
zin 4c86d30657 ref: refactor before checkout 2026-06-02 19:02:49 +00:00
zin f04cd7359b ref: refactor before chekout 2026-06-02 17:27:05 +00:00
zin 9ce92b70a9 feat: finetune 2.7M 2026-06-02 14:42:39 +00:00
zin c631c5649a feach: add mobile UI 2026-05-28 21:57:41 +00:00
zin fde8dbf2e7 chore: fix tab_live 2026-05-28 20:24:49 +00:00
zin af3c5a953e chore: change text output 2026-05-28 17:15:33 +00:00
47 changed files with 3113 additions and 3100 deletions
+18
View File
@@ -0,0 +1,18 @@
bin/
lib/
share/
etc/
include/
pyvenv.cfg
.idea/
.vscode/
__pycache__/
*.pyc
.git/
runs/
dataset/
NFS/
*.pth
*.pkl
*.npy
.env
+25
View File
@@ -0,0 +1,25 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
WORKDIR /app
# System dependencies for OpenCV and image processing
RUN apt-get update && apt-get install -y \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
# Install python packages
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy source code
COPY src/ /app/src/
EXPOSE 8080
CMD ["streamlit", "run", "src/main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
+20
View File
@@ -0,0 +1,20 @@
.PHONY: up down logs restart status
# Сборка и запуск контейнеров в фоновом режиме
up:
docker compose up --build -d
# Остановка и удаление контейнеров
down:
docker compose down
# Просмотр логов в реальном времени
logs:
docker compose logs -f
# Быстрый перезапуск
restart: down up
# Проверка статуса
status:
docker compose ps
+63
View File
@@ -0,0 +1,63 @@
version: '3.8'
networks:
emom_mesh:
driver: bridge
services:
emom_ui:
build:
context: .
dockerfile: docker/Dockerfile.ui
container_name: emom_web_ui
restart: unless-stopped
ports:
- "8080:8080"
networks:
- emom_mesh
env_file:
- .env
volumes:
- ./src:/app/src
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
depends_on:
- emom_inference
emom_inference:
build:
context: .
dockerfile: docker/Dockerfile.api
container_name: emom_pytorch_api
restart: unless-stopped
networks:
- emom_mesh
env_file:
- .env
volumes:
- ${HOST_ARTIFACTS_DIR}/emoset_resnet50_best.pth:/app/src/emoset_resnet50_best.pth:ro
- ${HOST_ARTIFACTS_DIR}/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
- ~/.cache/huggingface:/root/.cache/huggingface
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
emom_ollama:
image: ollama/ollama:latest
container_name: emom_ollama_engine
restart: unless-stopped
networks:
- emom_mesh
volumes:
- ~/.ollama:/root/.ollama
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
+19
View File
@@ -0,0 +1,19 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
RUN apt-get update && apt-get install -y \
libglib2.0-0 libsm6 libxext6 libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir fastapi uvicorn timm scikit-learn pandas joblib python-multipart transformers==4.38.2 tokenizers==0.15.2 accelerate
WORKDIR /app
COPY src/ /app/src/
WORKDIR /app/src
EXPOSE 8000
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
+15
View File
@@ -0,0 +1,15 @@
FROM python:3.12-slim
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
WORKDIR /app
RUN pip install --no-cache-dir streamlit==1.32.0 requests pandas pillow
COPY src/ /app/src/
WORKDIR /app/src
EXPOSE 8080
CMD ["streamlit", "run", "main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
+9
View File
@@ -0,0 +1,9 @@
streamlit==1.32.0
torch==2.2.1
torchvision==0.17.1
timm==0.9.16
pandas==2.2.1
scikit-learn==1.4.1.post1
joblib==1.3.2
transformers==4.38.2
requests==2.31.0
+76
View File
@@ -0,0 +1,76 @@
import io
import traceback
import numpy as np
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from data_loader import load_music_engine, load_image_processor
from music_engine.llm_bridge import LLMAcousticBridge
app = FastAPI(title="EmoM API", version="1.0.0")
ml_context = {
"image_processor": None,
"music_matcher": None,
"llm_bridge": None
}
@app.on_event("startup")
async def startup_event():
print("Loading ML models...")
ml_context["image_processor"] = load_image_processor()
ml_context["music_matcher"] = load_music_engine()
ml_context["llm_bridge"] = LLMAcousticBridge()
print("Initialization complete.")
@app.post("/analyze")
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
try:
images = []
for file in files:
image_bytes = await file.read()
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
images.append(img)
print(f"Processing batch: {len(images)} images.")
img_processor = ml_context["image_processor"]
matcher = ml_context["music_matcher"]
llm = ml_context["llm_bridge"]
all_v, all_a = [], []
all_objects = []
for img in images:
embedding = img_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding)
all_v.append(v)
all_a.append(a)
caption = img_processor.describe_scene(img)
all_objects.append(caption)
target_v = float(np.mean(all_v))
target_a = float(np.mean(all_a))
unique_semantics = list(set(all_objects))
llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
tracks_list = playlist_df.to_dict(orient="records")
return JSONResponse(content={
"status": "success",
"images_processed": len(images),
"target_v": target_v,
"target_a": target_a,
"llm_profile": llm_profile,
"semantics": unique_semantics,
"tracks": tracks_list
})
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
+27 -35
View File
@@ -1,54 +1,46 @@
import streamlit as st
from pathlib import Path from pathlib import Path
from typing import Tuple, List, Optional, Any
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from music_engine.matcher import MusicMatcher from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor from music_engine.image_processor import ImageProcessor
# Определяем базовую директорию (папка src)
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource def load_music_engine() -> MusicMatcher:
def load_music_engine(): #Инициализация модуля подбора музыкальных композиций.
"""Загрузка базы данных и модели регрессора."""
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv" db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
# va_regressor.pkl лежит в src/music_engine/
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl" model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"⚠️ Файл базы {db_path} не найден!")
return None
return MusicMatcher(db_path=db_path, model_path=model_path) return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource def load_image_processor() -> ImageProcessor:
def load_image_processor(): #Инициализация модуля экстракции визуальных признаков.
"""Загрузка ResNet-50 для извлечения признаков на лету.""" weights_path = BASE_DIR / "emoset_resnet50_best.pth"
# Файл весов лежит в той же папке src, что и этот скрипт
model_path = BASE_DIR / "emoset_resnet50_best.pth"
if not model_path.exists(): return ImageProcessor(weights_path)
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(model_path=model_path) def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]:
# Загрузка тестовой выборки датасета EmoSet.
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
try:
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
labels_path = BASE_DIR / "emoset_test_labels.npy"
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
@st.cache_data if not all(p.exists() for p in [labels_path, embeddings_path]):
def load_emoset_data():
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
# Пути относительно корня проекта
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
return None, None, None, None return None, None, None, None
df = pd.read_csv(csv_path) labels = np.load(labels_path)
image_list = df['filename'].tolist() embeddings = np.load(embeddings_path)
embs = np.load(emb_path)
lbls = np.load(lbl_path)
return image_list, embs, lbls, img_dir csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
df = pd.read_csv(csv_path)
return df['filename'].tolist(), embeddings, labels, images_path
except Exception as e:
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
return None, None, None, None
Binary file not shown.
+179 -33
View File
@@ -1,43 +1,189 @@
import streamlit as st
import sys
import os import os
import subprocess import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import base64
from io import BytesIO
from data_loader import load_music_engine, load_emoset_data, load_image_processor st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed")
from tabs.tab_dataset import render_dataset_tab
from tabs.tab_live import render_live_tab
# ---------------------------- API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze"
# 1️⃣ Запуск приложения DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
# ----------------------------
if __name__ == "__main__":
if "STREAMLIT_RUN" not in os.environ:
os.environ["STREAMLIT_RUN"] = "1"
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
subprocess.run(cmd)
sys.exit()
st.set_page_config(page_title="Thesis Demo", layout="wide") def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100))
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
# ---------------------------- if len(images) > max_display:
# 2️⃣ Инициализация движка и данных html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
# ---------------------------- return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
matcher = load_music_engine()
image_processor = load_image_processor()
image_files, embeddings, labels_array, images_path = load_emoset_data()
# ---------------------------- def main():
# 3️⃣ Интерфейс и Вкладки if "live_state" not in st.session_state:
# ---------------------------- st.session_state.live_state = "upload"
st.title("🖼️ Генератор саундтреков (Research Demo)") if "result_data" not in st.session_state:
st.session_state.result_data = None
tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"]) viewport = st.query_params.get("viewport", "desktop")
with tab1: st.markdown("""
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path) <style>
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
with tab2: if st.session_state.live_state == "upload":
if image_processor: upload_placeholder = st.empty()
render_live_tab(matcher, image_processor) with upload_placeholder.container():
st.write("Загрузите изображения для визуально-семантического анализа.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
"Загрузка файлов",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
)
if uploaded_files:
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Выполнить анализ", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные файлы:")
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
try:
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
response = requests.post(API_URL, files=upload_data, timeout=300)
if response.status_code == 200:
st.session_state.result_data = response.json()
st.session_state.live_state = "result"
st.rerun()
else: else:
st.error("Система обработки изображений недоступна (не найдены веса ResNet).") st.error(f"Ошибка сервера: {response.status_code}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
except Exception as e:
st.error(f"Ошибка соединения: {str(e)}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
elif st.session_state.live_state == "result":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
data = st.session_state.result_data
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
for row in data.get("tracks", []):
with st.container(border=True):
song_id = int(row['song_id'])
score = row['final_score']
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
if not os.path.exists(audio_path):
audio_path = audio_path.replace('.mp3', '.wav')
if viewport == "desktop":
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**Track ID:** {song_id}")
st.caption(f"Score: {score:.4f}")
with c2:
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
else:
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Отладочная информация (Метрики)"):
st.subheader("Координаты V/A")
c_v, c_a = st.columns(2)
c_v.metric("Valence", f"{data['target_v']:.2f}")
c_a.metric("Arousal", f"{data['target_a']:.2f}")
st.markdown("---")
st.subheader("Акустические признаки (LLM)")
feature_titles = {
"energy": "RMS Energy",
"flux": "Spectral Flux",
"centroid": "Spectral Centroid",
"pitch": "F0 (Pitch)",
"hnr": "HNR",
"zcr": "ZCR"
}
# Развернутые описания
feature_helps = {
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
}
llm_profile = data.get("llm_profile")
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(llm_profile.items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
label = feature_titles.get(k, k)
tooltip = feature_helps.get(k, "")
cols[j].metric(label, f"{v:.2f}", help=tooltip)
else:
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
st.markdown("---")
st.write("**Извлеченные теги (BLIP-2):**")
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый запрос", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
if __name__ == "__main__":
main()
Binary file not shown.

After

Width:  |  Height:  |  Size: 851 KiB

+40 -30
View File
@@ -1,56 +1,66 @@
import numpy as np
from pathlib import Path
from PIL import Image
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from PIL import Image
import timm import timm
from pathlib import Path
import numpy as np
# ТЕПЕРЬ BLIP-2
from transformers import Blip2Processor, Blip2ForConditionalGeneration from transformers import Blip2Processor, Blip2ForConditionalGeneration
class ImageProcessor: class ImageProcessor:
def __init__(self, model_path: Path | str): def __init__(self, weights_path: str | Path):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# --- ПОТОК 1: ТВОЙ АВТОРСКИЙ ЭМБЕДДИНГ (Core) --- # Модель извлечения визуальных признаков
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8) self.feature_extractor = timm.create_model('resnet50', pretrained=False, num_classes=8)
if Path(model_path).exists():
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
self.emo_model.fc = torch.nn.Identity() if Path(weights_path).exists():
self.emo_model.to(self.device).eval() self.feature_extractor.load_state_dict(torch.load(weights_path, map_location=self.device))
else:
print(f"Не удалось найти веса ResNet по пути: {weights_path}")
self.emo_transform = T.Compose([ # Удаление слоя классификации для вывода сырого вектора эмбеддингов
self.feature_extractor.fc = torch.nn.Identity()
self.feature_extractor.to(self.device).eval()
# Трансформации для предварительной обработки изображений
self.preprocess_image = T.Compose([
T.Resize((224, 224)), T.Resize((224, 224)),
T.ToTensor(), T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) ])
# --- ПОТОК 2: СЕМАНТИЧЕСКИЙ ЭКСПЕРТ BLIP-2 --- # Модуль семантического описания сцены
print("⏳ Загрузка тяжелой артиллерии: BLIP-2...") print("Инициализация BLIP-2...")
# Используем версию opt-2.7b — она идеально сбалансирована для V100 # Обход бага конфигурации Hugging Face (ручная сборка процессора)
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") from transformers import BlipImageProcessor, AutoTokenizer
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok)
self.blip_model = Blip2ForConditionalGeneration.from_pretrained( self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", "Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16 # Обязательно для скорости на V100 torch_dtype=torch.float16
).to(self.device) ).to(self.device)
print("✅ BLIP-2 и ResNet-50 готовы.")
@torch.no_grad() @torch.no_grad()
def extract_embedding(self, image: Image.Image) -> np.ndarray: def extract_embedding(self, image: Image.Image) -> np.ndarray:
img_rgb = image.convert('RGB') # Извлечение эмбеддингов из изображения
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device) rgb_image = image.convert('RGB')
return self.emo_model(img_tensor).cpu().numpy().flatten() img_tensor = self.preprocess_image(rgb_image).unsqueeze(0).to(self.device)
features = self.feature_extractor(img_tensor)
features_np = features.cpu().numpy()
return features_np.flatten()
@torch.no_grad() @torch.no_grad()
def describe_scene(self, image: Image.Image) -> str: def describe_scene(self, image: Image.Image) -> str:
"""Генерирует описание через BLIP-2.""" # Генерация текстового описания сцены
img_rgb = image.convert('RGB') rgb_image = image.convert('RGB')
# Инференс BLIP-2 требует float16 для V100 inputs = self.blip_processor(images=rgb_image, return_tensors="pt").to(self.device, torch.float16)
inputs = self.blip_processor(images=img_rgb, return_tensors="pt").to(self.device, torch.float16)
# Генерируем описание
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40) generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
caption = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return caption scene_description = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return scene_description.strip()
+43 -38
View File
@@ -1,60 +1,65 @@
import requests import os
import json import json
import re import re
import requests
class LLMAcousticBridge: class LLMAcousticBridge:
def __init__(self, model_name="dolphin-llama3:8b"): def __init__(self, model_name="dolphin-llama3:8b"):
self.model_name = model_name self.model_name = model_name
self.api_url = "http://localhost:11434/api/generate" base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
self.api_url = f"{base_url}/api/generate"
def _clean_json(self, text): def get_acoustic_profile(self, valence, arousal, semantics):
"""Вытаскивает чистый JSON из ответа нейросети.""" context_str = ", ".join(semantics) if semantics else "abstract scene"
try:
match = re.search(r'\{.*\}', text, re.DOTALL)
if match:
return json.loads(match.group(0))
return json.loads(text)
except:
return None
def get_acoustic_profile(self, valence, arousal, scene_descriptions): prompt = f"""
"""Просит LLM сгенерировать идеальный звук под описание."""
# Объединяем описания, если загружено несколько фото
context_str = " | ".join(scene_descriptions) if scene_descriptions else "abstract scene"
prompt = f"""You are an expert music producer and acoustic engineer.
Analyze the visual context and emotions to determine the ideal background music properties. Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy). Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
Visual Context: {context_str}. Visual Context: {context_str}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0. Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
2. "flux": (Rhythmic sharpness/Beat. High for action/people/cars, Low for static nature)
3. "centroid": (Brightness: 0=Dark/Bass/Massive, 1=Bright/Treble/Light)
4. "pitch": (Fundamental frequency: 0=Low pitch/Huge objects, 1=High pitch/Small objects)
5. "hnr": (Harmonics-to-Noise: 0=Noisy/Distorted textures, 1=Clear/Melodic/Smooth textures)
6. "zcr": (Percussiveness. High for detailed noise like leaves/rain, Low for solid blocks)
Return ONLY a valid JSON object. Do not add any text or explanation. 1. "energy": (Loudness/Density)
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}""" 2. "flux": (Rhythmic sharpness/Beat)
3. "centroid": (Brightness)
4. "pitch": (Fundamental frequency)
5. "hnr": (Harmonics-to-Noise)
6. "zcr": (Percussiveness)
Return ONLY a valid JSON object. No explanations, no markdown blocks.
Example: {{"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}}
"""
try: try:
response = requests.post(self.api_url, json={ payload = {
"model": self.model_name, "model": self.model_name,
"prompt": prompt, "prompt": prompt,
"stream": False, "stream": False,
"format": "json" "format": "json" # Принудительный JSON-режим Ollama
}, timeout=30) }
response.raise_for_status()
result_text = response.json().get("response", "") print(f"Запрос акустического профиля к Ollama...")
profile = self._clean_json(result_text) response = requests.post(self.api_url, json=payload, timeout=120)
# Проверяем, что все нужные ключи есть if response.status_code == 200:
required_keys = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'] data = response.json()
if profile and all(k in profile for k in required_keys): response_text = data.get("response", "")
try:
# 1. Попытка прямой десериализации
profile = json.loads(response_text)
return profile return profile
return None except json.JSONDecodeError:
# 2. Аварийное извлечение JSON из текста с помощью регулярного выражения
match = re.search(r'\{.*\}', response_text, re.DOTALL)
if match:
return json.loads(match.group(0))
print(f"Ошибка парсинга LLM ответа: {response_text}")
return {}
else:
print(f"Ollama вернула ошибку HTTP: {response.status_code}")
return {}
except Exception as e: except Exception as e:
print(f"⚠️ Ошибка связи с локальной LLM: {e}") print(f"Ошибка соединения с Ollama: {str(e)}")
return None return {}
+41 -25
View File
@@ -1,67 +1,83 @@
import joblib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from pathlib import Path from pathlib import Path
import joblib
class MusicMatcher: class MusicMatcher:
def __init__(self, db_path: Path | str, model_path: Path | str): def __init__(self, db_path: Path | str, model_path: Path | str):
# Загружаем твою новую, обогащенную базу # Загрузка базы данных музыкальных произведений
self.music_db = pd.read_csv(db_path) self.music_db = pd.read_csv(db_path)
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'] self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
# Удаляем строки, где нет акустических фич # Удаление записей с пропущенными целевыми или акустическими признаками
self.music_db = self.music_db.dropna(subset=['valence', 'arousal'] + self.acoustic_features) target_columns = ['valence', 'arousal'] + self.acoustic_features
self.music_db = self.music_db.dropna(subset=target_columns)
# Нормализуем акустику от 0 до 1, чтобы сравнивать с ответом LLM # Масштабирование акустических параметров к диапазону [0, 1]
self.norm_db = self.music_db.copy() self.norm_db = self.music_db.copy()
for feat in self.acoustic_features: for feat in self.acoustic_features:
f_min, f_max = self.norm_db[feat].min(), self.norm_db[feat].max() f_min = self.norm_db[feat].min()
f_max = self.norm_db[feat].max()
if f_max > f_min: if f_max > f_min:
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min) self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
else: else:
self.norm_db[f"norm_{feat}"] = 0.0 self.norm_db[f"norm_{feat}"] = 0.0
# Определение путей к аудиофайлам и загрузка модели регрессии
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio" self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
self.regressor = joblib.load(model_path) if Path(model_path).exists() else None
def predict_va(self, embedding: np.ndarray): if Path(model_path).exists():
if self.regressor: self.regressor = joblib.load(model_path)
prediction = self.regressor.predict(embedding.reshape(1, -1))[0] else:
return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0) self.regressor = None
def predict_va(self, embedding: np.ndarray) -> tuple[float, float]:
# Прогнозирование координат Valence/Arousal по визуальному эмбеддингу
if not self.regressor:
return 5.0, 5.0 return 5.0, 5.0
def get_audio_path(self, song_id): raw_prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
if not self.audio_dir.exists(): return None valence_pred = np.clip(raw_prediction[0], 1.0, 9.0)
arousal_pred = np.clip(raw_prediction[1], 1.0, 9.0)
return float(valence_pred), float(arousal_pred)
def get_audio_path(self, song_id: int | float | str) -> Path | None:
# Поиск физического пути к аудиофайлу в зависимости от расширения
if not self.audio_dir.exists():
return None
clean_id = str(int(float(song_id))) clean_id = str(int(float(song_id)))
for ext in ['.mp3', '.wav']: for ext in ['.mp3', '.wav']:
path = self.audio_dir / f"{clean_id}{ext}" path = self.audio_dir / f"{clean_id}{ext}"
if path.exists(): return path if path.exists():
return path
return None return None
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5): def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5) -> pd.DataFrame:
# 1. Эмоциональная дистанция (как и раньше) # Расчет евклидова расстояния в эмоциональном пространстве Рассела
emo_dist = np.sqrt( v_dist = (self.norm_db['valence'] - target_v) ** 2
1.0 * (self.norm_db['valence'] - target_v)**2 + a_dist = (self.norm_db['arousal'] - target_a) ** 2
2.5 * (self.norm_db['arousal'] - target_a)**2
)
self.norm_db['emo_distance'] = emo_dist
# Если LLM не дала ответ, сортируем только по эмоциям # Взвешенное расстояние с приоритетом оси активации (Arousal)
self.norm_db['emo_distance'] = np.sqrt(1.0 * v_dist + 2.5 * a_dist)
# Ранжирование только по эмоциональному критерию при отсутствии профиля LLM
if not llm_profile: if not llm_profile:
self.norm_db['final_score'] = self.norm_db['emo_distance'] self.norm_db['final_score'] = self.norm_db['emo_distance']
return self.norm_db.sort_values(by='final_score').head(top_k) return self.norm_db.sort_values(by='final_score').head(top_k)
# 2. Акустическая дистанция (сравниваем треки с запросом LLM) # Расчет отклонений по вектору акустических параметров LLM
acoustic_penalty = np.zeros(len(self.norm_db)) acoustic_penalty = np.zeros(len(self.norm_db))
for feat in self.acoustic_features: for feat in self.acoustic_features:
if feat in llm_profile: if feat in llm_profile:
target_val = llm_profile[feat] target_val = llm_profile[feat]
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val) acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
# Усредняем штраф # Нормирование акустической дистанции
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features) self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
# 3. Финальный Score (Смесь Эмоций и Акустики). Коэф 4.0 делает акустику важной! # Вычисление интегральной метрики соответствия (мультимодальный скоринг)
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0) self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
return self.norm_db.sort_values(by='final_score').head(top_k) return self.norm_db.sort_values(by='final_score').head(top_k)
Binary file not shown.
+73
View File
@@ -0,0 +1,73 @@
#!/bin/bash
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
# Остановка скрипта при возникновении любой ошибки
set -e
# Цвета для красивого вывода в консоль
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m'
echo -e "${BLUE}[INFO]${NC} Инициализация проверки окружения для проекта EmoM..."
# 1. ПРОВЕРКА DOCKER
if ! command -v docker &> /dev/null; then
echo -e "${YELLOW}[SETUP]${NC} Docker не найден. Начинаем установку..."
# Использование официального скрипта установки Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
rm get-docker.sh
# Добавляем текущего пользователя в группу docker, чтобы не писать sudo docker
sudo usermod -aG docker $USER
echo -e "${GREEN}[OK]${NC} Docker успешно установлен."
echo -e "${YELLOW}[ВНИМАНИЕ]${NC} Для применения прав группы docker потребуется перезайти в сессию SSH после завершения скрипта."
else
echo -e "${GREEN}[OK]${NC} Docker установлен ($(docker --version))."
fi
# 2. ПРОВЕРКА DOCKER COMPOSE
if ! docker compose version &> /dev/null; then
echo -e "${YELLOW}[SETUP]${NC} Плагин Docker Compose не найден. Устанавливаем..."
sudo apt-get update && sudo apt-get install -y docker-compose-plugin
echo -e "${GREEN}[OK]${NC} Docker Compose успешно установлен."
else
echo -e "${GREEN}[OK]${NC} Плагин Docker Compose доступен."
fi
# 3. ПРОВЕРКА NVIDIA И ПРОБРОСА GPU В DOCKER
if command -v nvidia-smi &> /dev/null; then
echo -e "${GREEN}[OK]${NC} Драйверы NVIDIA обнаружены."
# Проверяем наличие NVIDIA Container Toolkit
if ! dpkg -l | grep -q nvidia-container-toolkit; then
echo -e "${YELLOW}[SETUP]${NC} NVIDIA Container Toolkit не найден. Выполняется установка..."
# Настройка репозиториев NVIDIA
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
# Конфигурация Docker для работы с NVIDIA
echo -e "${YELLOW}[SETUP]${NC} Конфигурация runtime NVIDIA для Docker..."
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit установлен и настроен."
else
echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit уже установлен."
fi
else
echo -e "${RED}[WARN]${NC} Утилита nvidia-smi не найдена! Убедитесь, что драйверы видеокарты установлены, иначе Docker будет использовать только CPU."
fi
echo -e "\n${BLUE}[INFO]${NC} ========================================="
echo -e "${GREEN}[SUCCESS]${NC} Окружение готово к работе!"
echo -e "Теперь вы можете запустить проект командой: ${YELLOW}make up${NC}"
+20
View File
@@ -0,0 +1,20 @@
import shutil
from pathlib import Path
import kagglehub
dataset_dir = Path("../dataset/DEAM")
dataset_dir.mkdir(parents=True, exist_ok=True)
print("Скачивание датасета DEAM...")
# kagglehub по умолчанию тянет данные в системный кэш (~/.cache)
cache_path = kagglehub.dataset_download("imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music")
print(f"Загружено в кэш: {cache_path}")
print(f"Перенос файлов в {dataset_dir} и очистка временной директории...")
# Перемещаем данные
shutil.copytree(cache_path, dataset_dir, dirs_exist_ok=True)
shutil.rmtree(cache_path)
print("Готово. Датасет DEAM загружен, кэш очищен.")
+56
View File
@@ -0,0 +1,56 @@
import csv
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
# Конфигурация корневой директории локального датасета
DATASET_DIR = Path("../dataset/EmoSet-118K")
def process_and_save_split(dataset_split, split_name: str, output_dir: Path):
# Подготовка структуры директорий для текущей выборки
split_dir = output_dir / split_name
img_dir = split_dir / "images"
img_dir.mkdir(parents=True, exist_ok=True)
labels_path = split_dir / "labels.csv"
print(f"Обработка выборки: {split_name}...")
# Открытие файла разметки перед циклом для минимизации I/O операций диска
with open(labels_path, mode="w", newline="", encoding="utf-8") as csv_file:
writer = csv.writer(csv_file)
writer.writerow(["filename", "label"])
for example in tqdm(dataset_split, desc=split_name):
img = example["image"]
emotion_label = example["emotion"]
img_id = example["image_id"]
file_name = f"{img_id}.jpg"
# Принудительная конвертация в RGB для безопасного сохранения в JPEG-формате
if img.mode != "RGB":
img = img.convert("RGB")
img.save(img_dir / file_name, format="JPEG")
writer.writerow([file_name, emotion_label])
if __name__ == "__main__":
DATASET_DIR.mkdir(exist_ok=True, parents=True)
# Инициализация подключения к Hugging Face Hub
print("Загрузка метаданных EmoSet-118K...")
raw_dataset = load_dataset("Woleek/EmoSet-118K")
# Итеративная выгрузка размеченных данных
for split_key in ["train", "val", "test"]:
if split_key in raw_dataset:
process_and_save_split(
dataset_split=raw_dataset[split_key],
split_name=split_key,
output_dir=DATASET_DIR
)
print("Экспорт датасета завершен.")
+30
View File
@@ -0,0 +1,30 @@
import pandas as pd
from pathlib import Path
# Конфигурация локальных путей
SOURCE_CSV = Path("../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv")
OUTPUT_CSV = Path("../../dataset/DEAM/music_db.csv")
def prepare_deam_database():
if not SOURCE_CSV.exists():
print(f"Исходный файл аннотаций не найден: {SOURCE_CSV}")
return
print("Обработка разметки датасета DEAM...")
# Загрузка сырых данных с очисткой артефактов форматирования
raw_df = pd.read_csv(SOURCE_CSV, skipinitialspace=True)
# Экстракция координат пространства Рассела (Valence/Arousal)
processed_df = raw_df[['song_id', 'valence_mean', 'arousal_mean']].copy()
processed_df.columns = ['song_id', 'valence', 'arousal']
# Приведение идентификаторов к формату файловой системы (int)
processed_df['song_id'] = processed_df['song_id'].astype(int)
processed_df.to_csv(OUTPUT_CSV, index=False)
print(f"База успешно сформирована. Всего записей: {len(processed_df)}")
if __name__ == "__main__":
prepare_deam_database()
+60
View File
@@ -0,0 +1,60 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
# Конфигурация параметров нагрузочного тестирования
NUM_SAMPLES = 300_000
DIM_IN = 4096
DIM_OUT = 10
BATCH_SIZE = 16_384
NUM_STEPS = 1000
def run_gpu_benchmark():
# Проверка доступности аппаратного ускорения
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Инициализация стресс-теста на устройстве: {device}")
# Генерация синтетического датасета для аллокации VRAM
x_data = torch.randn(NUM_SAMPLES, DIM_IN, device=device, dtype=torch.float32)
y_data = torch.randn(NUM_SAMPLES, DIM_OUT, device=device, dtype=torch.float32)
# Архитектура тестовой полносвязной сети
model = nn.Sequential(
nn.Linear(DIM_IN, 2048),
nn.ReLU(),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, DIM_OUT)
).to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print("Начало прогрева GPU и симуляции цикла обучения...")
start_time = time.time()
for step in range(NUM_STEPS):
# Сэмплирование случайного батча
idx = torch.randint(0, NUM_SAMPLES, (BATCH_SIZE,), device=device)
x_batch = x_data[idx]
y_batch = y_data[idx]
optimizer.zero_grad()
predictions = model(x_batch)
loss = loss_fn(predictions, y_batch)
loss.backward()
optimizer.step()
# Логирование статуса (каждые 100 итераций для снижения I/O overhead)
if step % 100 == 0:
print(f"Итерация {step}/{NUM_STEPS} | Текущий loss: {loss.item():.4f}")
end_time = time.time()
print(f"Стресс-тест завершен. Общее время: {end_time - start_time:.2f} сек.")
if __name__ == "__main__":
run_gpu_benchmark()
+184
View File
@@ -0,0 +1,184 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
# Подавление предупреждений цветовых профилей
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройки окружения
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
NUM_WORKERS = 62
PATIENCE = 7
# Маппинг классов
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Фиксация генераторов псевдослучайных чисел
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Трансформации
base_tf = [
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
train_transform = T.Compose([
T.Resize(256, antialias=True),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
*base_tf
])
val_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(224),
*base_tf
])
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# Инициализация модели и оптимизатора
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Логика эпохи обучения
def train_epoch(current_model, loader):
current_model.train()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
logits = current_model(imgs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
# Логика эпохи валидации
@torch.no_grad()
def val_epoch(current_model, loader):
current_model.eval()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
logits = current_model(imgs)
loss = criterion(logits, labels)
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
if __name__ == "__main__":
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_path = "./emosetV2_resnet50_best.pth"
print("Старт обучения.")
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = train_epoch(model, train_loader)
val_loss, val_acc = val_epoch(model, val_loader)
scheduler.step()
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
# Сохранение лучших весов по Accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), checkpoint_path)
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
# Оценка переобучения по Loss (Early Stopping)
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
break
print("Процесс завершен.")
+283
View File
@@ -0,0 +1,283 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Подавление предупреждений цветовых профилей
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
# Настройки окружения
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
# ВАЖНО: Добавили путь для медиа файлов
MEDIA_DIR = Path("./src/scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
NUM_WORKERS = 32
PATIENCE = 7
# Маппинг классов
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# Инвертированный маппинг для графиков
INV_CLASS_MAPPING = {v: k for k, v in CLASS_MAPPING.items()}
CLASS_NAMES = [INV_CLASS_MAPPING[i] for i in range(len(CLASS_MAPPING))]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Фиксация генераторов псевдослучайных чисел
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Трансформации
base_tf = [
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
train_transform = T.Compose([
T.Resize(256, antialias=True),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
*base_tf
])
val_transform = T.Compose([
T.Resize(256, antialias=True),
T.CenterCrop(224),
*base_tf
])
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# Инициализация модели и оптимизатора
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Функции для отрисовки графиков
def plot_learning_curves(history):
"""Отрисовка графиков функции потерь и точности"""
epochs = range(1, len(history['train_loss']) + 1)
plt.figure(figsize=(14, 5))
# График Loss
plt.subplot(1, 2, 1)
plt.plot(epochs, history['train_loss'], 'b-', label='Train Loss')
plt.plot(epochs, history['val_loss'], 'r--', label='Validation Loss')
plt.title('График функции потерь (Loss)', fontsize=14)
plt.xlabel('Эпохи', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':', alpha=0.7)
# График Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy')
plt.plot(epochs, history['val_acc'], 'r--', label='Validation Accuracy')
plt.title('График точности (Accuracy)', fontsize=14)
plt.xlabel('Эпохи', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':', alpha=0.7)
plt.tight_layout()
plot_path = MEDIA_DIR / "training_history.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] График обучения сохранен в: {plot_path}")
def plot_confusion_matrix(y_true, y_pred):
"""Отрисовка тепловой матрицы ошибок"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
cbar_kws={'label': 'Количество сэмплов'})
plt.title('Матрица ошибок (Confusion Matrix) - ResNet50', fontsize=16, pad=20)
plt.ylabel('Истинные классы (Ground Truth)', fontsize=12)
plt.xlabel('Предсказанные классы (Predicted)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
cm_path = MEDIA_DIR / "confusion_matrix_emoset.png"
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] Матрица ошибок сохранена в: {cm_path}")
# Логика эпохи обучения
def train_epoch(current_model, loader):
current_model.train()
total_loss, correct_preds, total_samples = 0.0, 0, 0
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
logits = current_model(imgs)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
return total_loss / total_samples, correct_preds / total_samples
# Логика эпохи валидации с сохранением предсказаний для матрицы ошибок
@torch.no_grad()
def val_epoch(current_model, loader, return_preds=False):
current_model.eval()
total_loss, correct_preds, total_samples = 0.0, 0, 0
all_preds, all_labels = [], []
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
logits = current_model(imgs)
loss = criterion(logits, labels)
total_loss += loss.item() * imgs.size(0)
preds = logits.argmax(dim=1)
correct_preds += (preds == labels).sum().item()
total_samples += labels.size(0)
if return_preds:
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
avg_loss = total_loss / total_samples
avg_acc = correct_preds / total_samples
if return_preds:
return avg_loss, avg_acc, all_labels, all_preds
return avg_loss, avg_acc
if __name__ == "__main__":
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_path = "./emosetV2_resnet50_best.pth"
# Словарь для хранения истории обучения
history = {
'train_loss': [], 'train_acc': [],
'val_loss': [], 'val_acc': []
}
# Переменные для хранения лучших предсказаний для матрицы
best_labels, best_preds = [], []
print("Старт обучения.")
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = train_epoch(model, train_loader)
# Получаем предсказания только если это может быть лучшая эпоха
val_loss, val_acc, val_labels, val_preds = val_epoch(model, val_loader, return_preds=True)
scheduler.step()
# Запись в историю
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
# Сохранение лучших весов по Accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
best_labels = val_labels # Сохраняем предсказания лучшей модели
best_preds = val_preds
torch.save(model.state_dict(), checkpoint_path)
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
# Оценка переобучения по Loss (Early Stopping)
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
break
print("Процесс обучения завершен. Генерирую графики для диссертации...")
plot_learning_curves(history)
plot_confusion_matrix(best_labels, best_preds)
print("Все медиафайлы успешно созданы!")
File diff suppressed because one or more lines are too long
+171
View File
@@ -0,0 +1,171 @@
import os
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
# Настройки путей для медиа
MEDIA_DIR = Path("scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
# Конфигурация путей для инференса и кэширования векторов
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
MODEL_PATH = Path("./src/emoset_resnet50_best.pth")
BATCH_SIZE = 128
NUM_WORKERS = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Вычисления перенесены на: {device}")
class EmoSetFeatureDataset(Dataset):
def __init__(self, root: Path | str, split: str):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.labels = sorted(self.df["label"].unique())
self.label2idx = {l: i for i, l in enumerate(self.labels)}
self.idx2label = {i: l for l, i in self.label2idx.items()}
# Для экстракции признаков аугментация отключена, используется строгий CenterCrop
self.transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
# Перехват битых файлов выборки
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (224, 224), (0, 0, 0))
img_tensor = self.transform(img)
label_idx = self.label2idx[row["label"]]
return img_tensor, label_idx
def plot_tsne(embeddings, labels, idx2label, sample_limit=3000):
"""Генерация t-SNE графика для диссертации"""
print(f"Построение t-SNE проекции для {sample_limit} сэмплов...")
tsne_model = TSNE(n_components=2, perplexity=30, random_state=42)
embeddings_2d = tsne_model.fit_transform(embeddings[:sample_limit])
labels_subset = labels[:sample_limit]
plt.figure(figsize=(12, 9))
# Используем более академическую палитру
scatter = plt.scatter(
embeddings_2d[:, 0],
embeddings_2d[:, 1],
c=labels_subset,
cmap="Set2", # Set2 лучше различается при печати
alpha=0.7,
s=20,
edgecolors='w',
linewidths=0.5
)
# Формирование легенды
handles, _ = scatter.legend_elements()
legend_labels = [idx2label[i] for i in range(len(idx2label))]
# Размещение легенды снаружи графика, чтобы не перекрывать данные
plt.legend(handles, legend_labels, title="Эмоциональные классы",
bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title("2D проекция скрытого пространства признаков (t-SNE)", pad=20, fontsize=14)
plt.xlabel("Первая главная компонента (t-SNE 1)", fontsize=12)
plt.ylabel("Вторая главная компонента (t-SNE 2)", fontsize=12)
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plot_path = MEDIA_DIR / "tsne_embeddings.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] График t-SNE сохранен в: {plot_path}")
if __name__ == "__main__":
test_ds = EmoSetFeatureDataset(DATA_ROOT, "test")
test_loader = DataLoader(
test_ds,
batch_size=BATCH_SIZE,
shuffle=False, # Отключение шаффла для строгого соответствия индексов
num_workers=NUM_WORKERS,
pin_memory=True
)
print(f"Подготовлено для извлечения: {len(test_ds)} файлов.")
# Инициализация модели и загрузка лучших весов
feature_extractor = timm.create_model(
"resnet50",
pretrained=False,
num_classes=len(test_ds.labels)
)
try:
checkpoint = torch.load(MODEL_PATH, map_location=device)
feature_extractor.load_state_dict(checkpoint)
print("Веса модели успешно загружены.")
except Exception as e:
print(f"Ошибка загрузки весов: {e}. Убедитесь, что модель обучена.")
exit(1)
# Удаление классификационного слоя (fc)
feature_extractor.reset_classifier(0)
feature_extractor.to(device)
feature_extractor.eval()
print("Слой классификации удален. Модель готова к экстракции.")
extracted_embeddings = []
extracted_labels = []
print("Старт пакетной экстракции признаков...")
with torch.no_grad():
for imgs, labels in tqdm(test_loader, desc="Экстракция"):
imgs = imgs.to(device)
# Получение вектора [BATCH_SIZE, 2048]
embeddings_batch = feature_extractor(imgs)
extracted_embeddings.append(embeddings_batch.cpu().numpy())
extracted_labels.append(labels.numpy())
# Агрегация батчей в единые массивы
np_embeddings = np.concatenate(extracted_embeddings, axis=0)
np_labels = np.concatenate(extracted_labels, axis=0)
print(f"Размерность матрицы признаков: {np_embeddings.shape}")
# Сохранение артефактов
np.save("./src/emoset_test_embeddings.npy", np_embeddings)
np.save("./src/emoset_test_labels.npy", np_labels)
print("Матрицы успешно экспортированы в .npy файлы.")
# Генерация медиа для диссертации
plot_tsne(np_embeddings, np_labels, test_ds.idx2label, sample_limit=3000)
print("Процесс полностью завершен.")
+69
View File
@@ -0,0 +1,69 @@
import pandas as pd
from pathlib import Path
from tqdm import tqdm
# Конфигурация путей и целевых признаков
BASE_DIR = Path("../../dataset/DEAM")
MUSIC_DB_PATH = BASE_DIR / "music_db.csv"
FEATURES_DIR = BASE_DIR / "features" / "features"
OUTPUT_PATH = BASE_DIR / "music_db_enriched.csv"
# Маппинг низкоуровневых признаков экстрактора (openSMILE/GeMAPS) в дескрипторы системы
TARGET_FEATURES = {
'pcm_RMSenergy_sma_amean': 'energy',
'pcm_fftMag_spectralFlux_sma_amean': 'flux',
'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',
'F0final_sma_amean': 'pitch',
'logHNR_sma_amean': 'hnr',
'pcm_zcr_sma_amean': 'zcr',
'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',
'pcm_fftMag_psySharpness_sma_amean': 'sharpness'
}
def aggregate_acoustic_features():
if not MUSIC_DB_PATH.exists():
print(f"Базовый файл аннотаций не найден: {MUSIC_DB_PATH}")
return
print("Загрузка эмоциональной разметки DEAM...")
df_main = pd.read_csv(MUSIC_DB_PATH)
print("Агрегация фреймовых акустических признаков...")
aggregated_data = []
# Итерация по трекам для сбора покадровых характеристик
for _, row in tqdm(df_main.iterrows(), total=len(df_main), desc="Обработка аудио-векторов"):
song_id = int(row['song_id'])
feature_file = FEATURES_DIR / f"{song_id}.csv"
if feature_file.exists():
try:
# Чтение сырых векторов (формат csv с разделителем ';')
df_feat = pd.read_csv(feature_file, sep=';')
# Усреднение характеристик по временной оси (time frames)
mean_features = df_feat[list(TARGET_FEATURES.keys())].mean()
# Формирование агрегированной записи
track_data = {'song_id': song_id}
for orig_col, new_col in TARGET_FEATURES.items():
track_data[new_col] = mean_features[orig_col]
aggregated_data.append(track_data)
except Exception as e:
print(f"Ошибка парсинга файла {feature_file.name}: {e}")
# Слияние акустических дескрипторов с эмоциональными координатами (Inner Join)
df_features = pd.DataFrame(aggregated_data)
df_enriched = pd.merge(df_main, df_features, on='song_id', how='inner')
# Очистка возможных артефактов NaN после агрегации
df_enriched = df_enriched.dropna(subset=list(TARGET_FEATURES.values()))
df_enriched.to_csv(OUTPUT_PATH, index=False)
print(f"Экспорт завершен. Сформирована обогащенная база: {OUTPUT_PATH.name}")
print(f"Итоговый размер выборки: {len(df_enriched)} треков.")
if __name__ == "__main__":
aggregate_acoustic_features()
+80
View File
@@ -0,0 +1,80 @@
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# Проекция дискретных классов эмоций на непрерывное пространство Рассела (Valence, Arousal)
# Значения откалиброваны в диапазоне [1.0, 9.0]
EMOTION_TO_VA_COORDS = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
def train_va_regressor():
# Настройка путей
base_dir = Path(__file__).resolve().parent.parent
embeddings_path = base_dir / "emoset_test_embeddings.npy"
labels_path = base_dir / "emoset_test_labels.npy"
model_output_path = base_dir / "music_engine" / "va_regressor.pkl"
if not embeddings_path.exists() or not labels_path.exists():
print(f"Артефакты признаков не найдены в директории: {base_dir}")
return
print("Загрузка вектора признаков и меток классов...")
x_features = np.load(embeddings_path)
y_discrete = np.load(labels_path)
# Трансформация целевой переменной: классы -> непрерывные координаты V/A
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
x_train, x_test, y_train, y_test = train_test_split(
x_features, y_continuous, test_size=0.2, random_state=42
)
# Построение пайплайна: Z-масштабирование и L2-регуляризованная регрессия
# RidgeCV автоматически подбирает оптимальный гиперпараметр alpha (силу регуляризации)
print("Инициализация и обучение пайплайна RidgeCV...")
regression_pipeline = Pipeline([
('scaler', StandardScaler()),
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
])
regression_pipeline.fit(x_train, y_train)
# Оценка обобщающей способности модели
y_pred = regression_pipeline.predict(x_test)
mse_score = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print("Обучение завершено. Метрики качества на тестовой выборке:")
print(f" - MSE: {mse_score:.4f}")
print(f" - R^2: {r2:.4f}")
# Диагностика дисперсии предсказаний
v_min, v_max = y_pred[:, 0].min(), y_pred[:, 0].max()
a_min, a_max = y_pred[:, 1].min(), y_pred[:, 1].max()
print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")
# Экспорт обученного пайплайна
model_output_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(regression_pipeline, model_output_path)
print(f"Пайплайн сохранен: {model_output_path.name}")
if __name__ == "__main__":
train_va_regressor()
@@ -2,30 +2,30 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"id": "8523d028", "id": "8523d028",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import pandas as pd\n", "import pandas as pd\n",
"import numpy as np\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from PIL import Image\n", "from PIL import Image\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n", "import torchvision.transforms as T\n",
"import timm\n", "import timm\n",
"import numpy as np\n",
"\n", "\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n" "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"id": "e0781b02", "id": "e0781b02",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -41,25 +41,26 @@
} }
], ],
"source": [ "source": [
"# Конфигурация путей и параметров инференса\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n", "DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n", "MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n",
"\n", "\n",
"BATCH_SIZE = 64\n", "BATCH_SIZE = 64\n",
"NUM_WORKERS = 4\n", "NUM_WORKERS = 4\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n", "print(f\"Аппаратное ускорение: {DEVICE}\")"
"DEVICE\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"id": "79da9640", "id": "79da9640",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"class EmoSetDataset(Dataset):\n", "class EmoSetEvaluationDataset(Dataset):\n",
" def __init__(self, root, split):\n", " # Датасет для строгой валидации с центрированным кропом\n",
" def __init__(self, root: Path | str, split: str):\n",
" self.root = Path(root) / split\n", " self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n", " self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n", "\n",
@@ -67,13 +68,12 @@
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n", " self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n", " self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n", "\n",
" # Стандартный пайплайн трансформаций для инференса ResNet\n",
" self.transform = T.Compose([\n", " self.transform = T.Compose([\n",
" T.Resize((224, 224)),\n", " T.Resize(256),\n",
" T.CenterCrop(224),\n",
" T.ToTensor(),\n", " T.ToTensor(),\n",
" T.Normalize(\n", " T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]\n",
" )\n",
" ])\n", " ])\n",
"\n", "\n",
" def __len__(self):\n", " def __len__(self):\n",
@@ -81,15 +81,23 @@
"\n", "\n",
" def __getitem__(self, idx):\n", " def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n", " row = self.df.iloc[idx]\n",
" img = Image.open(self.root / \"images\" / row[\"filename\"]).convert(\"RGB\")\n", " img_path = self.root / \"images\" / row[\"filename\"]\n",
" img = self.transform(img)\n", " \n",
" label = self.label2idx[row[\"label\"]]\n", " # Перехват битых файлов для непрерывности оценки\n",
" return img, label\n" " try:\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" except Exception:\n",
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
" \n",
" img_tensor = self.transform(img)\n",
" label_idx = self.label2idx[row[\"label\"]]\n",
" \n",
" return img_tensor, label_idx"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"id": "12201756", "id": "12201756",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -103,8 +111,8 @@
} }
], ],
"source": [ "source": [
"test_ds = EmoSetDataset(DATA_ROOT, \"test\")\n", "# Инициализация тестовой выборки\n",
"\n", "test_ds = EmoSetEvaluationDataset(DATA_ROOT, \"test\")\n",
"test_loader = DataLoader(\n", "test_loader = DataLoader(\n",
" test_ds,\n", " test_ds,\n",
" batch_size=BATCH_SIZE,\n", " batch_size=BATCH_SIZE,\n",
@@ -113,13 +121,13 @@
" pin_memory=True\n", " pin_memory=True\n",
")\n", ")\n",
"\n", "\n",
"print(\"Classes:\", test_ds.labels)\n", "print(f\"Индексированные классы: {test_ds.labels}\")\n",
"print(\"Test samples:\", len(test_ds))\n" "print(f\"Размер тестовой выборки: {len(test_ds)}\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"id": "7e3dc1d5", "id": "7e3dc1d5",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -374,22 +382,17 @@
} }
], ],
"source": [ "source": [
"# Инициализация модели в режиме классификации\n",
"model = timm.create_model(\n", "model = timm.create_model(\n",
" \"resnet50\",\n", " \"resnet50\",\n",
" pretrained=False,\n", " pretrained=False,\n",
" num_classes=len(test_ds.labels)\n", " num_classes=len(test_ds.labels)\n",
")\n", ")"
"\n",
"state = torch.load(MODEL_PATH, map_location=DEVICE)\n",
"model.load_state_dict(state)\n",
"\n",
"model.to(DEVICE)\n",
"model.eval()\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"id": "b42a84f1", "id": "b42a84f1",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -402,27 +405,16 @@
} }
], ],
"source": [ "source": [
"all_preds = []\n", "# Загрузка весов и перевод в режим инференса\n",
"all_targets = []\n", "checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)\n",
"\n", "model.load_state_dict(checkpoint)\n",
"with torch.no_grad():\n", "model.to(DEVICE)\n",
" for imgs, labels in tqdm(test_loader):\n", "model.eval()"
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" logits = model(imgs)\n",
" preds = logits.argmax(dim=1)\n",
"\n",
" all_preds.append(preds.cpu().numpy())\n",
" all_targets.append(labels.cpu().numpy())\n",
"\n",
"all_preds = np.concatenate(all_preds)\n",
"all_targets = np.concatenate(all_targets)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": null,
"id": "4c1f1377", "id": "4c1f1377",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -435,13 +427,25 @@
} }
], ],
"source": [ "source": [
"acc = accuracy_score(all_targets, all_preds)\n", "# Сбор предсказаний на тестовой выборке\n",
"print(f\"Test accuracy: {acc:.4f}\")\n" "all_preds = []\n",
"all_targets = []\n",
"\n",
"print(\"Запуск инференса на тестовой выборке...\")\n",
"with torch.no_grad():\n",
" for imgs, labels in tqdm(test_loader, desc=\"Оценка метрик\"):\n",
" imgs = imgs.to(DEVICE)\n",
" \n",
" logits = model(imgs)\n",
" preds = logits.argmax(dim=1)\n",
"\n",
" all_preds.append(preds.cpu().numpy())\n",
" all_targets.append(labels.numpy())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"id": "6b022825", "id": "6b022825",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -468,19 +472,14 @@
} }
], ],
"source": [ "source": [
"print(\n", "# Агрегация результатов\n",
" classification_report(\n", "all_preds = np.concatenate(all_preds, axis=0)\n",
" all_targets,\n", "all_targets = np.concatenate(all_targets, axis=0)"
" all_preds,\n",
" target_names=test_ds.labels,\n",
" digits=4\n",
" )\n",
")\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"id": "2fcb69ac", "id": "2fcb69ac",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -496,20 +495,70 @@
} }
], ],
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "# Расчет интегральных метрик классификации\n",
"acc = accuracy_score(all_targets, all_preds)\n",
"print(f\"\\nОбщая точность (Accuracy): {acc:.4f}\\n\")\n",
"\n", "\n",
"print(\"Детализированный отчет (Classification Report):\")\n",
"print(\n",
" classification_report(\n",
" all_targets,\n",
" all_preds,\n",
" target_names=test_ds.labels,\n",
" digits=4\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2084ab91",
"metadata": {},
"outputs": [],
"source": [
"# Построение матрицы ошибок (Confusion Matrix)\n",
"cm = confusion_matrix(all_targets, all_preds)\n", "cm = confusion_matrix(all_targets, all_preds)\n",
"\n", "\n",
"plt.figure(figsize=(8, 8))\n", "plt.figure(figsize=(10, 8))"
"plt.imshow(cm)\n", ]
"plt.colorbar()\n", },
"plt.xticks(range(len(test_ds.labels)), test_ds.labels, rotation=45)\n", {
"plt.yticks(range(len(test_ds.labels)), test_ds.labels)\n", "cell_type": "code",
"plt.xlabel(\"Predicted\")\n", "execution_count": null,
"plt.ylabel(\"True\")\n", "id": "83a84e14",
"plt.title(\"Confusion Matrix (Test)\")\n", "metadata": {},
"plt.tight_layout()\n", "outputs": [],
"plt.show()\n" "source": [
"# Использование seaborn для академичной визуализации с числами\n",
"sns.heatmap(\n",
" cm, \n",
" annot=True, \n",
" fmt=\"d\", \n",
" cmap=\"Blues\", \n",
" xticklabels=test_ds.labels, \n",
" yticklabels=test_ds.labels,\n",
" cbar=False\n",
")\n",
"\n",
"plt.title(\"Матрица ошибок классификации EmoSet (ResNet-50)\", pad=20)\n",
"plt.xlabel(\"Предсказанный класс\", labelpad=15)\n",
"plt.ylabel(\"Истинный класс\", labelpad=15)\n",
"plt.xticks(rotation=45)\n",
"plt.yticks(rotation=0)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "280d5637",
"metadata": {},
"outputs": [],
"source": [
"# Экспорт графика\n",
"plt.savefig(\"../confusion_matrix_resnet50.png\", dpi=300, bbox_inches='tight')\n",
"plt.show()"
] ]
} }
], ],
+97
View File
@@ -0,0 +1,97 @@
import numpy as np
import pandas as pd
import joblib
from pathlib import Path
from sklearn.metrics import mean_squared_error, r2_score
# 1. Настройка путей
embeddings_path = Path("./src/emoset_test_embeddings.npy")
csv_path = Path("./NFS/Thesis/Emoset/EmoSet-118K/test/labels.csv")
model_path = Path("./src/music_engine/va_regressor.pkl")
output_dir = Path("./src/scripts/media")
output_file = output_dir / "metrics_output.txt"
# 2. Корректный маппинг 8 классов EmoSet в шкалу DEAM [1.0, 9.0]
# Формула перевода из [-1, 1] в [1, 9]: 5.0 + (X * 4.0)
EMO_TO_VA = {
"amusement": [8.2, 6.6], # Веселье (Высокий позитив, средняя энергия)
"awe": [7.0, 7.4], # Восхищение (Позитив, высокая энергия)
"contentment": [7.8, 3.4], # Умиротворение (Позитив, низкая энергия)
"excitement": [8.2, 8.2], # Возбуждение (Макс. позитив, макс. энергия)
"anger": [2.2, 7.8], # Гнев (Глубокий негатив, высокая энергия)
"disgust": [2.6, 6.6], # Отвращение (Негатив, средняя энергия)
"fear": [2.6, 8.2], # Страх (Негатив, максимальная энергия)
"sadness": [2.2, 2.6] # Грусть (Глубокий негатив, низкая энергия)
}
def generate_slide_metrics():
print("[INFO] Загрузка тестовых артефактов...")
if not all(p.exists() for p in [embeddings_path, csv_path, model_path]):
print("[ERROR] Проверьте наличие файлов данных или модели регрессора.")
return
output_dir.mkdir(parents=True, exist_ok=True)
# 3. Загрузка эмбеддингов и меток
X_test = np.load(embeddings_path)
df = pd.read_csv(csv_path)
if len(X_test) != len(df):
print(f"[WARN] Корректировка размеров выборки: Эмбеддинги ({len(X_test)}) != Метки ({len(df)})")
min_len = min(len(X_test), len(df))
X_test = X_test[:min_len]
df = df.iloc[:min_len]
y_test_list = [EMO_TO_VA.get(label.lower().strip(), [5.0, 5.0]) for label in df['label']]
y_test = np.array(y_test_list)
# 4. Выполнение инференса
print("[INFO] Выполнение инференса регрессора на скрытом пространстве признаков...")
regressor = joblib.load(model_path)
y_pred = regressor.predict(X_test)
# === БЛОК ДИАГНОСТИКИ ШКАЛЫ ===
print("\n" + "-"*50)
print(" ДИАГНОСТИКА ДИАПАЗОНОВ ЗНАЧЕНИЙ ".center(50))
print("-"*50)
print(f"Истинные (y_test) -> Мин: {y_test.min():.2f}, Макс: {y_test.max():.2f}, Среднее: {y_test.mean():.2f}")
print(f"Предсказания (y_pred) -> Мин: {y_pred.min():.2f}, Макс: {y_pred.max():.2f}, Среднее: {y_pred.mean():.2f}")
print("-"*50 + "\n")
# ==============================
# 5. Расчет метрик
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
mse_total = mean_squared_error(y_test, y_pred)
r2_total = r2_score(y_test, y_pred)
# 6. Вывод и сохранение результатов
table_content = f"""
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | {mse_v:<12.4f} | {mse_a:<12.4f} | {mse_total:<13.4f} |
| R² | {r2_v:<12.4f} | {r2_a:<12.4f} | {r2_total:<13.4f} |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{{final}} = D_{{emo}} + 4.0 \cdot Acoustic_{{penalty}}$$
"""
print(table_content)
with open(output_file, 'w', encoding='utf-8') as f:
f.write(table_content)
print(f"[SUCCESS] Метрики успешно сохранены в файл: {output_file.absolute()}")
if __name__ == "__main__":
generate_slide_metrics()
-125
View File
@@ -1,125 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "0336fd0c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ База загружена. Треков: 1744\n",
"🔍 Собираем акустические признаки...\n",
"\n",
"🚀 ГОТОВО! Обогащенная база сохранена: ../../dataset/DEAM/music_db_enriched.csv\n",
"Собрано фичей для 1744 из 1744 треков.\n",
" song_id valence arousal energy flux centroid pitch \\\n",
"0 2 3.1 3.0 0.097268 0.846947 483.421751 93.884056 \n",
"1 3 3.5 3.3 0.126809 0.959460 173.219616 62.682589 \n",
"2 4 5.7 5.5 0.156699 1.333944 466.434797 92.850316 \n",
"3 5 4.4 5.3 0.126455 1.009927 546.152506 158.673853 \n",
"4 7 5.8 6.4 0.268180 1.589191 175.369162 83.823484 \n",
"\n",
" hnr zcr entropy sharpness \n",
"0 3.615380 0.034270 3.299075 0.426490 \n",
"1 -2.600122 0.017893 2.294971 0.165583 \n",
"2 -0.579130 0.042936 3.258138 0.395410 \n",
"3 1.751148 0.043781 3.514585 0.494367 \n",
"4 12.006770 0.014783 2.177862 0.170058 \n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from tqdm import tqdm # для красивого прогресс-бара, если не установлен - убери\n",
"\n",
"# 1. Пути к файлам\n",
"base_dir = Path(\"../../dataset/DEAM\") # Поправь, если запускаешь из другого места\n",
"music_db_path = base_dir / \"music_db.csv\"\n",
"features_dir = base_dir / \"features\" / \"features\"\n",
"output_path = base_dir / \"music_db_enriched.csv\"\n",
"\n",
"# 2. Наш \"Золотой список\" (8 признаков)\n",
"target_columns = {\n",
" 'pcm_RMSenergy_sma_amean': 'energy',\n",
" 'pcm_fftMag_spectralFlux_sma_amean': 'flux',\n",
" 'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',\n",
" 'F0final_sma_amean': 'pitch',\n",
" 'logHNR_sma_amean': 'hnr',\n",
" 'pcm_zcr_sma_amean': 'zcr',\n",
" 'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',\n",
" 'pcm_fftMag_psySharpness_sma_amean': 'sharpness'\n",
"}\n",
"\n",
"# 3. Загружаем текущую базу с V/A\n",
"if not music_db_path.exists():\n",
" print(f\"❌ ОШИБКА: Не найден файл {music_db_path}\")\n",
"else:\n",
" df_main = pd.read_csv(music_db_path)\n",
" print(f\"✅ База загружена. Треков: {len(df_main)}\")\n",
"\n",
" # Подготавливаем новые колонки\n",
" for col_name in target_columns.values():\n",
" df_main[col_name] = np.nan\n",
"\n",
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
" print(\"🔍 Собираем акустические признаки...\")\n",
" found_count = 0\n",
" \n",
" for index, row in df_main.iterrows():\n",
" song_id = int(row['song_id'])\n",
" feature_file = features_dir / f\"{song_id}.csv\"\n",
" \n",
" if feature_file.exists():\n",
" try:\n",
" # Читаем CSV с признаками (разделитель там обычно точка с запятой)\n",
" df_feat = pd.read_csv(feature_file, sep=';')\n",
" \n",
" # Усредняем значения по всем фреймам (одна песня разбита на сотни строк-фреймов)\n",
" mean_features = df_feat[list(target_columns.keys())].mean()\n",
" \n",
" # Записываем в главную базу\n",
" for orig_col, new_col in target_columns.items():\n",
" df_main.at[index, new_col] = mean_features[orig_col]\n",
" \n",
" found_count += 1\n",
" except Exception as e:\n",
" print(f\"Ошибка чтения {feature_file}: {e}\")\n",
" \n",
" # 5. Сохраняем результат\n",
" # Удаляем треки, для которых не нашлось фичей (если такие есть)\n",
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
" \n",
" df_main.to_csv(output_path, index=False)\n",
" print(f\"\\n🚀 ГОТОВО! Обогащенная база сохранена: {output_path}\")\n",
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
" print(df_main.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-614
View File
@@ -1,614 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09f9237a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
]
}
],
"source": [
"!pip install datasets tqdm pillow requests\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f0b2e2c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "868d872a109d49f9966f2f19985e7048",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06741794289540849ad179c5966dcab8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e47aad5270144913996cb5b226213ab9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30d1492a948245e3b6b58e92218cd760",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "931823b458cb4696b459e9011537cf1e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "846f4245b16d4cc096a43c940590ad11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172981d77fc941cfa32c05f5a34bf742",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2baa935ed3524a73883909752cb15907",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9c0baac101b449794155392f07b49c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a5b966f79314e069251462bff82395f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "422974f938924910a0712b30a9c2bd84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f155a08427094de7ad1a5884e623db2b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a94a4621d19f45f690e0064fee83767b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50f55b00a27b4213b573b398e5b0d708",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c5815040f0a4a31903348a8327811a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 94481\n",
" })\n",
" val: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 5905\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 17716\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"# куда сохраняем датасет\n",
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
"\n",
"# загружаем через Hugging Face\n",
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
"\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "052ab073",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"from pathlib import Path\n",
"\n",
"def save_split(split):\n",
" split_dir = DATA_DIR / split\n",
" img_dir = split_dir / \"images\"\n",
" img_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" labels_path = split_dir / \"labels.csv\"\n",
"\n",
" # перезаписываем labels.csv\n",
" with open(labels_path, \"w\") as f:\n",
" f.write(\"filename,label\\n\")\n",
"\n",
" for example in tqdm(ds[split]):\n",
" img = example[\"image\"] # уже PIL.Image\n",
" label = example[\"emotion\"]\n",
" image_id = example[\"image_id\"]\n",
"\n",
" fname = f\"{image_id}.jpg\"\n",
" img.save(img_dir / fname)\n",
"\n",
" with open(labels_path, \"a\") as f:\n",
" f.write(f\"{fname},{label}\\n\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a74ceedf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
]
}
],
"source": [
"save_split(\"train\")\n",
"save_split(\"val\")\n",
"save_split(\"test\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-140
View File
@@ -1,140 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Загрузка датасета DEAM\n",
"\n",
"Этот ноутбук предназначен для автоматизации процесса скачивания и подготовки музыкального датасета **DEAM** (Database for Emotional Analysis in Music).\n",
"Данные будут помещены в папку `dataset/DEAM`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting kagglehub\n",
" Downloading kagglehub-1.0.1-py3-none-any.whl.metadata (40 kB)\n",
"Collecting kagglesdk<1.0,>=0.1.22 (from kagglehub)\n",
" Downloading kagglesdk-0.1.23-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (25.0)\n",
"Requirement already satisfied: pyyaml in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (6.0.3)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (2.32.5)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (4.67.1)\n",
"Requirement already satisfied: protobuf in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglesdk<1.0,>=0.1.22->kagglehub) (6.33.4)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.4.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.11)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2.6.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2026.1.4)\n",
"Downloading kagglehub-1.0.1-py3-none-any.whl (70 kB)\n",
"Downloading kagglesdk-0.1.23-py3-none-any.whl (217 kB)\n",
"Installing collected packages: kagglesdk, kagglehub\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [kagglehub]\n",
"\u001b[1A\u001b[2KSuccessfully installed kagglehub-1.0.1 kagglesdk-0.1.23\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.1.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install kagglehub"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Скачиваем датасет DEAM...\n",
"Downloading to /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/1.archive...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.83G/1.83G [01:09<00:00, 28.2MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting files...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Датасет скачан во временную директорию: /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/versions/1\n",
"Переносим файлы в ../dataset/DEAM...\n",
"\n",
"[УСПЕХ] Датасет DEAM готов к работе!\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import kagglehub\n",
"from pathlib import Path\n",
"\n",
"# 1. Настройка путей\n",
"DATASET_ROOT = Path(\"../dataset\")\n",
"DEAM_ROOT = DATASET_ROOT / \"DEAM\"\n",
"DEAM_ROOT.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# 2. Загрузка через kagglehub\n",
"print(\"Скачиваем датасет DEAM...\")\n",
"kaggle_cache_path = kagglehub.dataset_download(\"imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music\")\n",
"print(f\"Датасет скачан во временную директорию: {kaggle_cache_path}\")\n",
"\n",
"# 3. Перемещение файлов в проект\n",
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
"\n",
"print(\"\\n[УСПЕХ] Датасет DEAM готов к работе!\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (my-python-project)",
"language": "python",
"name": "my-python-project"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
-171
View File
@@ -1,171 +0,0 @@
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import json
from PIL import Image
class EmoSet(Dataset):
ATTRIBUTES_MULTI_CLASS = [
'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
]
ATTRIBUTES_MULTI_LABEL = [
'object'
]
NUM_CLASSES = {
'brightness': 11,
'colorfulness': 11,
'scene': 254,
'object': 409,
'facial_expression': 6,
'human_action': 264,
}
def __init__(self,
data_root,
num_emotion_classes,
phase,
):
assert num_emotion_classes in (8, 2)
assert phase in ('train', 'val', 'test')
self.transforms_dict = self.get_data_transforms()
self.info = self.get_info(data_root, num_emotion_classes)
if phase == 'train':
self.transform = self.transforms_dict['train']
elif phase == 'val':
self.transform = self.transforms_dict['val']
elif phase == 'test':
self.transform = self.transforms_dict['test']
else:
raise NotImplementedError
data_store = json.load(open(os.path.join(data_root, f'{phase}.json')))
self.data_store = [
[
self.info['emotion']['label2idx'][item[0]],
item[1],
os.path.join(data_root, item[2]),
os.path.join(data_root, item[3])
]
for item in data_store
]
@classmethod
def get_data_transforms(cls):
transforms_dict = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
return transforms_dict
def get_info(self, data_root, num_emotion_classes):
assert num_emotion_classes in (8, 2)
info = json.load(open(os.path.join(data_root, 'info.json')))
if num_emotion_classes == 8:
pass
elif num_emotion_classes == 2:
emotion_info = {
'label2idx': {
'amusement': 0,
'awe': 0,
'contentment': 0,
'excitement': 0,
'anger': 1,
'disgust': 1,
'fear': 1,
'sadness': 1,
},
'idx2label': {
'0': 'positive',
'1': 'negative',
}
}
info['emotion'] = emotion_info
else:
raise NotImplementedError
return info
def load_image_by_path(self, path):
image = Image.open(path).convert('RGB')
image = self.transform(image)
return image
def load_annotation_by_path(self, path):
json_data = json.load(open(path))
return json_data
def __getitem__(self, item):
emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
image = self.load_image_by_path(image_path)
annotation_data = self.load_annotation_by_path(annotation_path)
data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}
for attribute in self.ATTRIBUTES_MULTI_CLASS:
# if empty, set to -1, else set to label index
attribute_label_idx = -1
if attribute in annotation_data:
attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
data.update({f'{attribute}_label_idx': attribute_label_idx})
for attribute in self.ATTRIBUTES_MULTI_LABEL:
# if empty, set to 0, else set to 1
assert attribute == 'object'
num_classes = self.NUM_CLASSES[attribute]
attribute_label_idx = torch.zeros(num_classes)
if attribute in annotation_data:
for label in annotation_data[attribute]:
attribute_label_idx[self.info[attribute]['label2idx'][label]] = 1
data.update({f'{attribute}_label_idx': attribute_label_idx})
return data
def __len__(self):
return len(self.data_store)
if __name__ == '__main__':
data_root = r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val'
num_emotion_classes = 8
phase = 'train'
dataset = EmoSet(
data_root=data_root,
num_emotion_classes=num_emotion_classes,
phase=phase,
)
# print(dataset.info)
dataloader = DataLoader(dataset, batch_size = 16, shuffle = True)
for i, data in enumerate(dataloader):
pass
# print(data['emotion_label_idx'])
# print(data['scene_label_idx'])
# print(data['facial_expression_label_idx'])
# print(data['human_action_label_idx'])
# print(data['brightness_label_idx'])
# print(data['colorfulness_label_idx'])
# print(data['object_label_idx'])
# break
File diff suppressed because one or more lines are too long
Binary file not shown.

After

Width:  |  Height:  |  Size: 313 KiB

+12
View File
@@ -0,0 +1,12 @@
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | 1.5135 | 2.2743 | 1.8939 |
| R² | 0.7927 | 0.4321 | 0.6124 |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{final} = D_{emo} + 4.0 \cdot Acoustic_{penalty}$$
Binary file not shown.

After

Width:  |  Height:  |  Size: 243 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

-88
View File
@@ -1,88 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "b92e0213",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1763c51e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
"Всего треков в базе: 1744\n",
"Пример данных:\n",
" song_id valence arousal\n",
"0 2 3.1 3.0\n",
"1 3 3.5 3.3\n",
"2 4 5.7 5.5\n",
"3 5 4.4 5.3\n",
"4 7 5.8 6.4\n"
]
}
],
"source": [
"# Точный путь к оригинальным аннотациям\n",
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
"# Путь, куда сохраним очищенную базу для движка\n",
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
"\n",
"if not source_path.exists():\n",
" print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n",
"else:\n",
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
" \n",
" # Берем только нужные колонки (по твоему примеру)\n",
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
" \n",
" # Переименовываем для простоты кода в движке\n",
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
" \n",
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
" \n",
" # Сохраняем финальный файл\n",
" clean_df.to_csv(output_path, index=False)\n",
" \n",
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
" print(\"Пример данных:\")\n",
" print(clean_df.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-114
View File
@@ -1,114 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d70d8e32",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "31b0fa82",
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"TRANSFORM = transforms.Compose([\n",
" transforms.Resize((224,224)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a17ecf5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
]
},
{
"ename": "PicklingError",
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
]
}
],
"source": [
"def process_row(row, split_dir, tensor_dir):\n",
" img_path = split_dir / row.filename\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" tensor = TRANSFORM(img)\n",
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
" torch.save(tensor, tensor_path)\n",
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
"\n",
"for split in [\"train\",\"val\",\"test\"]:\n",
" split_dir = DATA_ROOT / split / \"images\"\n",
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
"\n",
" results = []\n",
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
" for f in tqdm(futures):\n",
" results.append(f.result())\n",
"\n",
" new_df = pd.DataFrame(results)\n",
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+319
View File
@@ -0,0 +1,319 @@
import os
import random
import warnings
from collections import defaultdict
from pathlib import Path
from PIL import Image, ImageFile
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.amp import autocast, GradScaler
import timm
# Подавление предупреждений и защита от битых "хвостов" JPEG
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# --- ПУТИ ---
TRAIN_ROOT = Path("./dataset/Original-2.41M")
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения)
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# --- НАСТРОЙКИ ---
TOTAL_BATCH_SIZE = 64
BATCH_NOISY = 48 # 75% батча - новые данные 2.41M
BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K
EPOCHS_PER_FOLDER = 15
PATIENCE = 5
LR = 1e-6
NUM_TRAIN_WORKERS = 32
NUM_VAL_WORKERS = 32
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
# --- 1. ТРАНСФОРМАЦИИ ---
train_transform = T.Compose([
T.Resize(256),
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# --- 2. ДАТАСЕТЫ ---
class ChunkTrainDataset(Dataset):
def __init__(self, paths, transform):
self.paths = paths
self.transform = transform
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
path = self.paths[idx]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(path.parts[-3].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
class CsvDataset(Dataset):
def __init__(self, root, transform):
self.root = Path(root)
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
path = self.root / "images" / row["filename"]
try:
img = Image.open(path).convert('RGB')
tensor = self.transform(img)
label = CLASS_MAPPING.get(row["label"].lower(), 0)
return tensor, label
except Exception:
return torch.zeros((3, 224, 224)), 0
# --- 3. СБОР ДАННЫХ ---
def prepare_chunks():
print("\nСканирование датасета 2.41M...")
chunk_dict = defaultdict(list)
for path in TRAIN_ROOT.rglob('*.jpg'):
emotion = path.parts[-3].lower()
if emotion not in CLASS_MAPPING:
continue
folder_str = path.parts[-2]
if folder_str.isdigit():
chunk_dict[int(folder_str)].append(path)
sorted_chunks = sorted(chunk_dict.keys())
print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
return chunk_dict, sorted_chunks
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
if __name__ == "__main__":
chunk_dict, sorted_chunks = prepare_chunks()
# Валидационный датасет (только чистые данные)
val_loader = DataLoader(
CsvDataset(VAL_118K_ROOT, val_transform),
batch_size=TOTAL_BATCH_SIZE, shuffle=False,
num_workers=NUM_VAL_WORKERS, pin_memory=True
)
# ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания)
# Используем prefetch_factor и persistent_workers для устранения рывков CPU
anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
anchor_loader = DataLoader(
anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
num_workers=16, pin_memory=True, drop_last=True,
prefetch_factor=2, persistent_workers=False
)
# Инициализация модели
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
if PRETRAINED_PATH.exists():
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")
# Размораживаем всю модель
for param in model.parameters():
param.requires_grad = True
# Дифференцированный оптимизатор
backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
fc_params = [p for n, p in model.named_parameters() if "fc" in n]
optimizer = torch.optim.AdamW([
{'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы
{'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора
], weight_decay=1e-3)
# Label Smoothing помогает игнорировать мусор в разметке 2.41M
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
scaler = GradScaler()
# --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ ---
start_stage = 0
start_epoch = 1
best_val_loss = float('inf')
if RESUME_CHECKPOINT.exists():
print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
try:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
except Exception as e:
print(f"Оптимизатор сброшен: {e}")
best_val_loss = checkpoint['best_val_loss']
start_stage = checkpoint['stage']
start_epoch = checkpoint['epoch'] + 1
print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
else:
# --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) ---
# Выполняется только если мы начинаем с нуля
print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
best_val_loss = val_loss / val_total
baseline_acc = val_correct / val_total
print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")
# ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ
current_train_paths = []
for s in range(start_stage):
current_train_paths.extend(chunk_dict[sorted_chunks[s]])
print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")
# ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ
for stage in range(start_stage, len(sorted_chunks)):
chunk_id = sorted_chunks[stage]
print(f"\n{'='*50}")
print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")
# Накопление и перемешивание
current_train_paths.extend(chunk_dict[chunk_id])
random.shuffle(current_train_paths)
print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")
# ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH
train_loader = DataLoader(
ChunkTrainDataset(current_train_paths, train_transform),
batch_size=BATCH_NOISY, shuffle=True,
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
worker_init_fn=worker_init_fn, drop_last=True,
prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU
)
epochs_no_improve = 0
first_epoch = start_epoch if stage == start_stage else 1
# Инициализация итератора якорей
anchor_iter = iter(anchor_loader)
# ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА
for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
model.train()
train_loss, train_correct, train_total = 0.0, 0, 0
for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
# Достаем якорный чистый батч
try:
anc_inputs, anc_labels = next(anchor_iter)
except StopIteration:
anchor_iter = iter(anchor_loader)
anc_inputs, anc_labels = next(anchor_iter)
# СМЕШИВАЕМ БАТЧИ (Грязные + Чистые)
inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)
optimizer.zero_grad(set_to_none=True)
with autocast(device_type="cuda"):
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item() * inputs.size(0)
_, pred = outputs.max(1)
train_total += labels.size(0)
train_correct += pred.eq(labels).sum().item()
# ВАЛИДАЦИЯ
model.eval()
val_loss, val_correct, val_total = 0.0, 0, 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
with autocast(device_type="cuda"):
outputs = model(inputs)
v_loss = criterion(outputs, labels)
val_loss += v_loss.item() * inputs.size(0)
_, pred = outputs.max(1)
val_total += labels.size(0)
val_correct += pred.eq(labels).sum().item()
avg_train_loss = train_loss / train_total
avg_train_acc = train_correct / train_total
avg_val_loss = val_loss / val_total
avg_val_acc = val_correct / val_total
print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")
# СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
epochs_no_improve = 0
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print("--> Обновлены лучшие веса")
else:
epochs_no_improve += 1
# АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ
checkpoint_state = {
'stage': stage,
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss
}
torch.save(checkpoint_state, RESUME_CHECKPOINT)
os.sync() # Защита от отключения электричества
print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")
# РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА
if epochs_no_improve >= PATIENCE:
print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
break
# Сброс счетчика стартовой эпохи после прохождения восстановительного этапа
start_epoch = 1
-199
View File
@@ -1,199 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ca08df84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0/1000, Loss: 1.0013\n",
"Step 10/1000, Loss: 1.0088\n",
"Step 20/1000, Loss: 0.9956\n",
"Step 30/1000, Loss: 0.9781\n",
"Step 40/1000, Loss: 0.9613\n",
"Step 50/1000, Loss: 0.9313\n",
"Step 60/1000, Loss: 0.8927\n",
"Step 70/1000, Loss: 0.8503\n",
"Step 80/1000, Loss: 0.7537\n",
"Step 90/1000, Loss: 0.6689\n",
"Step 100/1000, Loss: 0.6063\n",
"Step 110/1000, Loss: 0.5172\n",
"Step 120/1000, Loss: 0.4592\n",
"Step 130/1000, Loss: 0.4044\n",
"Step 140/1000, Loss: 0.3610\n",
"Step 150/1000, Loss: 0.3175\n",
"Step 160/1000, Loss: 0.2825\n",
"Step 170/1000, Loss: 0.2560\n",
"Step 180/1000, Loss: 0.2360\n",
"Step 190/1000, Loss: 0.2203\n",
"Step 200/1000, Loss: 0.1930\n",
"Step 210/1000, Loss: 0.1854\n",
"Step 220/1000, Loss: 0.1723\n",
"Step 230/1000, Loss: 0.1546\n",
"Step 240/1000, Loss: 0.1386\n",
"Step 250/1000, Loss: 0.1271\n",
"Step 260/1000, Loss: 0.1109\n",
"Step 270/1000, Loss: 0.1032\n",
"Step 280/1000, Loss: 0.0899\n",
"Step 290/1000, Loss: 0.0807\n",
"Step 300/1000, Loss: 0.0750\n",
"Step 310/1000, Loss: 0.0813\n",
"Step 320/1000, Loss: 0.0612\n",
"Step 330/1000, Loss: 0.0544\n",
"Step 340/1000, Loss: 0.0552\n",
"Step 350/1000, Loss: 0.0446\n",
"Step 360/1000, Loss: 0.0403\n",
"Step 370/1000, Loss: 0.0350\n",
"Step 380/1000, Loss: 0.0612\n",
"Step 390/1000, Loss: 0.0364\n",
"Step 400/1000, Loss: 0.0322\n",
"Step 410/1000, Loss: 0.0302\n",
"Step 420/1000, Loss: 0.0519\n",
"Step 430/1000, Loss: 0.0319\n",
"Step 440/1000, Loss: 0.0260\n",
"Step 450/1000, Loss: 0.0208\n",
"Step 460/1000, Loss: 0.0409\n",
"Step 470/1000, Loss: 0.0291\n",
"Step 480/1000, Loss: 0.0234\n",
"Step 490/1000, Loss: 0.0194\n",
"Step 500/1000, Loss: 0.0274\n",
"Step 510/1000, Loss: 0.0231\n",
"Step 520/1000, Loss: 0.0199\n",
"Step 530/1000, Loss: 0.0154\n",
"Step 540/1000, Loss: 0.0278\n",
"Step 550/1000, Loss: 0.0185\n",
"Step 560/1000, Loss: 0.0180\n",
"Step 570/1000, Loss: 0.0152\n",
"Step 580/1000, Loss: 0.0132\n",
"Step 590/1000, Loss: 0.0111\n",
"Step 600/1000, Loss: 0.0396\n",
"Step 610/1000, Loss: 0.0179\n",
"Step 620/1000, Loss: 0.0148\n",
"Step 630/1000, Loss: 0.0123\n",
"Step 640/1000, Loss: 0.0265\n",
"Step 650/1000, Loss: 0.0133\n",
"Step 660/1000, Loss: 0.0128\n",
"Step 670/1000, Loss: 0.0107\n",
"Step 680/1000, Loss: 0.0142\n",
"Step 690/1000, Loss: 0.0202\n",
"Step 700/1000, Loss: 0.0125\n",
"Step 710/1000, Loss: 0.0107\n",
"Step 720/1000, Loss: 0.0140\n",
"Step 730/1000, Loss: 0.0195\n",
"Step 740/1000, Loss: 0.0148\n",
"Step 750/1000, Loss: 0.0109\n",
"Step 760/1000, Loss: 0.0094\n",
"Step 770/1000, Loss: 0.0121\n",
"Step 780/1000, Loss: 0.0233\n",
"Step 790/1000, Loss: 0.0151\n",
"Step 800/1000, Loss: 0.0134\n",
"Step 810/1000, Loss: 0.0117\n",
"Step 820/1000, Loss: 0.0124\n",
"Step 830/1000, Loss: 0.0221\n",
"Step 840/1000, Loss: 0.0161\n",
"Step 850/1000, Loss: 0.0136\n",
"Step 860/1000, Loss: 0.0161\n",
"Step 870/1000, Loss: 0.0194\n",
"Step 880/1000, Loss: 0.0145\n",
"Step 890/1000, Loss: 0.0149\n",
"Step 900/1000, Loss: 0.0232\n",
"Step 910/1000, Loss: 0.0166\n",
"Step 920/1000, Loss: 0.0156\n",
"Step 930/1000, Loss: 0.0276\n",
"Step 940/1000, Loss: 0.0176\n",
"Step 950/1000, Loss: 0.0152\n",
"Step 960/1000, Loss: 0.0162\n",
"Step 970/1000, Loss: 0.0143\n",
"Step 980/1000, Loss: 0.0136\n",
"Step 990/1000, Loss: 0.0117\n",
"Total time: 67.25 s\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import time\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Using device:\", device)\n",
"\n",
"\n",
"# Огромные параметры\n",
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
"batch_size = 16_384 # большой батч\n",
"steps = 1000 # много итераций для длительной нагрузки\n",
"\n",
"# Случайные данные на GPU\n",
"x = torch.randn(N, D_in, device=device, dtype=torch.float32)\n",
"y = torch.randn(N, D_out, device=device, dtype=torch.float32)\n",
"\n",
"model = nn.Sequential(\n",
" nn.Linear(D_in, H1),\n",
" nn.ReLU(),\n",
" nn.Linear(H1, H2),\n",
" nn.ReLU(),\n",
" nn.Linear(H2, H3),\n",
" nn.ReLU(),\n",
" nn.Linear(H3, D_out)\n",
").to(device)\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"start = time.time()\n",
"for t in range(steps):\n",
" idx = torch.randint(0, N, (batch_size,), device=device)\n",
" x_batch = x[idx]\n",
" y_batch = y[idx]\n",
"\n",
" y_pred = model(x_batch)\n",
" loss = loss_fn(y_pred, y_batch)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if t % 10 == 0:\n",
" # замедляем вывод, чтобы можно было наблюдать\n",
" print(f\"Step {t}/{steps}, Loss: {loss.item():.4f}\")\n",
"\n",
"end = time.time()\n",
"print(f\"Total time: {end-start:.2f} s\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-759
View File
@@ -1,759 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9336560f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"import timm\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "84c3657f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# === CONFIG ===\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 24\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"DEVICE\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
" def __init__(self, root, split):\n",
" self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
" self.labels = sorted(self.df[\"label\"].unique())\n",
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
" self.transform = T.Compose([\n",
" T.Resize((224, 224)),\n",
" T.ToTensor(),\n",
" T.Normalize(\n",
" mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]\n",
" )\n",
" ])\n",
"\n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
" def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" img = self.transform(img)\n",
"\n",
" label = self.label2idx[row[\"label\"]]\n",
" return img, label\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8805341",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
]
}
],
"source": [
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" val_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=False,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"print(\"Classes:\", train_ds.labels)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dffce582",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = timm.create_model(\n",
" \"resnet50\",\n",
" pretrained=True,\n",
" num_classes=len(train_ds.labels)\n",
")\n",
"\n",
"model.to(DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"optimizer = torch.optim.AdamW(\n",
" model.parameters(),\n",
" lr=LR,\n",
" weight_decay=1e-4\n",
")\n",
"\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer,\n",
" T_max=EPOCHS\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, loader):\n",
" model.train()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"\n",
" for imgs, labels in tqdm(loader, leave=False):\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb7e9398",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def val_epoch(model, loader):\n",
" model.eval()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"\n",
" for imgs, labels in loader:\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
"\n",
" logits = model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9e870e5d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1477 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
]
}
],
"source": [
"best_val_acc = 0.0\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Epoch {epoch:02d} | \"\n",
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7796ef11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-65
View File
@@ -1,65 +0,0 @@
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
# 1. Алфавитный маппинг EmoSet
EMO_VA_MAP = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
BASE_DIR = Path(__file__).resolve().parent.parent
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"
print("Загрузка данных...")
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
# 2. НОВАЯ, ПРАВИЛЬНАЯ АРХИТЕКТУРА (Pipeline)
print("Обучение масштабатора и RidgeCV регрессора...")
# Pipeline гарантирует, что при предсказании в main.py новые векторы тоже будут масштабированы
model = Pipeline([
('scaler', StandardScaler()),
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
])
model.fit(X_train, y_train)
# 3. Диагностика и Оценка
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"\n[УСПЕХ] Обучение завершено!")
print(f"MSE: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")
# === ТОТ САМЫЙ ТЕСТ НА КОЛЛАПС ===
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
# ===============================================
# 4. Сохранение (Pipeline сохраняется целиком со StandardScaler)
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
print(f"\nМодель сохранена в: {output_model_path}")
+2 -2
View File
@@ -42,7 +42,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6) st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
st.rerun() st.rerun()
else: else:
st.success("Анализ завершен! Ваш эмоциональный профиль готов.") st.success("Анализ завершен! Ваш эмоциональный профиль готов.")
all_v, all_a = [], [] all_v, all_a = [], []
for idx in st.session_state.ds_chosen_indices: for idx in st.session_state.ds_chosen_indices:
@@ -56,7 +56,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
col_left, col_right = st.columns([1, 2]) col_left, col_right = st.columns([1, 2])
with col_left: with col_left:
st.header("📊 Ваш профиль") st.header("Ваш профиль")
st.metric("Позитивность (Valence)", f"{target_v:.2f}") st.metric("Позитивность (Valence)", f"{target_v:.2f}")
st.metric("Энергия (Arousal)", f"{target_a:.2f}") st.metric("Энергия (Arousal)", f"{target_a:.2f}")
+162 -36
View File
@@ -1,75 +1,162 @@
import streamlit as st import streamlit as st
import streamlit.components.v1 as components
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt import base64
from music_engine.llm_bridge import LLMAcousticBridge # ИМПОРТИРУЕМ МОСТ from io import BytesIO
from music_engine.llm_bridge import LLMAcousticBridge
# Вспомогательная функция для крохотного предпросмотра
def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100)) # Сжимаем картинку
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
# Строгие стили для квадратных миниатюр
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
# Индикатор оставшихся фото, если их много
if len(images) > max_display:
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
def render_live_tab(matcher, image_processor): def render_live_tab(matcher, image_processor):
if "live_state" not in st.session_state:
st.session_state.live_state = "upload"
if "result_data" not in st.session_state:
st.session_state.result_data = None
viewport = st.query_params.get("viewport", "desktop")
# ==========================================
# CSS ИНЪЕКЦИИ
# ==========================================
st.markdown("""
<style>
[data-testid="stFileUploadDropzone"] {
min-height: 250px !important;
display: flex;
align-items: center;
justify-content: center;
border-radius: 16px;
background-color: rgba(255, 75, 75, 0.03);
}
.spinner-container {
display: flex; flex-direction: column; align-items: center;
justify-content: center; min-height: 40vh; margin-top: 10vh;
}
.big-spinner {
width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1);
border-top: 10px solid #ff4b4b; border-radius: 50%;
animation: spin 1s linear infinite; margin-bottom: 2rem;
}
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
</style>
""", unsafe_allow_html=True)
# ==========================================
# ЭКРАН 1: ЗАГРУЗКА
# ==========================================
if st.session_state.live_state == "upload":
upload_placeholder = st.empty()
with upload_placeholder.container():
st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.") st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader( uploaded_files = st.file_uploader(
"Перетащите изображения сюда", "Перетащите изображения сюда",
type=['png', 'jpg', 'jpeg'], type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
) )
if uploaded_files: if uploaded_files:
st.subheader("Анализ визуальных признаков:") # 1. КНОПКА СРАЗУ ПОСЛЕ ЗАГРУЗКИ (Не нужно скроллить вниз)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Сгенерировать саундтрек", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
# 2. МИНИАТЮРЫ ПОД КНОПКОЙ
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные кадры:")
# Генерируем компактный блок миниатюр
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
# ==========================================
# ЭКРАН 2: АНАЛИЗ (СПИННЕР)
# ==========================================
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div></div>', unsafe_allow_html=True)
status_text = st.empty()
cols = st.columns(min(len(uploaded_files), 5))
images = [] images = []
all_objects = [] all_objects = []
all_v, all_a = [], []
for i, file in enumerate(files):
status_text.markdown(f"<h3 style='text-align: center; font-weight: 400;'>Анализ кадра {i + 1} из {len(files)}...</h3>", unsafe_allow_html=True)
for i, file in enumerate(uploaded_files):
img = Image.open(file) img = Image.open(file)
images.append(img) images.append(img)
with cols[i % 5]:
st.image(img, use_container_width=True)
with st.spinner("VLM Анализ..."):
caption = image_processor.describe_scene(img)
st.caption(f"👁️ *{caption.capitalize()}*")
all_objects.append(caption)
if st.button("🎵 Сгенерировать саундтрек", type="primary", use_container_width=True):
# 1. Извлекаем эмоции
all_v, all_a = [], []
for img in images:
embedding = image_processor.extract_embedding(img) embedding = image_processor.extract_embedding(img)
v, a = matcher.predict_va(embedding) v, a = matcher.predict_va(embedding)
all_v.append(v) all_v.append(v)
all_a.append(a) all_a.append(a)
caption = image_processor.describe_scene(img)
all_objects.append(caption)
target_v, target_a = np.mean(all_v), np.mean(all_a) target_v, target_a = np.mean(all_v), np.mean(all_a)
# 2. Переводим Объекты -> Акустику через LLM status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Трансляция семантики в аудиопрофиль...</h3>", unsafe_allow_html=True)
with st.spinner("Phi-3 генерирует акустический профиль..."):
llm = LLMAcousticBridge() llm = LLMAcousticBridge()
llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects))) llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects)))
# 3. Ищем треки status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Поиск идеальных композиций...</h3>", unsafe_allow_html=True)
with st.spinner("Поиск треков в базе DEAM..."): playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=5)
st.success("✅ Кросс-модальный анализ завершен!") st.session_state.result_data = {
"target_v": target_v,
"target_a": target_a,
"llm_profile": llm_profile,
"playlist": playlist,
"semantics": list(set(all_objects))
}
st.session_state.live_state = "result"
st.rerun()
# ВЫВОД РЕЗУЛЬТАТОВ # ==========================================
col_left, col_right = st.columns([1, 2]) # ЭКРАН 3: РЕЗУЛЬТАТЫ
# ==========================================
elif st.session_state.live_state == "result":
with col_left: components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
st.header("📊 Профиль")
st.metric("Valence (Настроение)", f"{target_v:.2f}")
st.metric("Arousal (Энергия)", f"{target_a:.2f}")
if llm_profile: data = st.session_state.result_data
st.write("**Требования LLM к звуку:**") st.header("Рекомендованный плейлист")
for k, v in llm_profile.items():
st.caption(f"- {k}: {v:.2f}")
with col_right: for _, row in data["playlist"].iterrows():
st.header("🎵 Плейлист")
for _, row in playlist.iterrows():
with st.container(border=True): with st.container(border=True):
if viewport == "desktop":
c1, c2 = st.columns([1, 3]) c1, c2 = st.columns([1, 3])
with c1: with c1:
st.write(f"**Track:** {int(row['song_id'])}") st.write(f"**Track:** {int(row['song_id'])}")
@@ -80,3 +167,42 @@ def render_live_tab(matcher, image_processor):
st.audio(str(audio_path)) st.audio(str(audio_path))
else: else:
st.warning("Файл не найден") st.warning("Файл не найден")
else:
st.write(f"**Track:** {int(row['song_id'])} (Score: {row['final_score']:.2f})")
audio_path = matcher.get_audio_path(row['song_id'])
if audio_path:
st.audio(str(audio_path))
else:
st.warning("Файл не найден")
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Технические параметры анализа"):
c_v, c_a = st.columns(2)
c_v.metric("Valence (Настроение)", f"{data['target_v']:.2f}")
c_a.metric("Arousal (Энергия)", f"{data['target_a']:.2f}")
st.markdown("---")
st.write("**Акустические таргеты (LLM):**")
if data["llm_profile"]:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(data["llm_profile"].items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
cols[j].metric(k, f"{v:.2f}")
st.markdown("---")
st.write("**Обнаруженная семантика:**")
st.write(", ".join([str(c).capitalize() for c in data["semantics"]]))
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый анализ", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
+461
View File
@@ -0,0 +1,461 @@
.
├── bin
│   ├── activate
│   ├── activate.csh
│   ├── activate.fish
│   ├── activate.nu
│   ├── activate.ps1
│   ├── activate_this.py
│   ├── debugpy
│   ├── debugpy-adapter
│   ├── f2py
│   ├── fonttools
│   ├── httpx
│   ├── ipython
│   ├── ipython3
│   ├── isympy
│   ├── jlpm
│   ├── jsonpointer
│   ├── jsonschema
│   ├── jupyter
│   ├── jupyter-dejavu
│   ├── jupyter-events
│   ├── jupyter-execute
│   ├── jupyter-kernel
│   ├── jupyter-kernelspec
│   ├── jupyter-lab
│   ├── jupyter-labextension
│   ├── jupyter-labhub
│   ├── jupyter-migrate
│   ├── jupyter-nbconvert
│   ├── jupyter-run
│   ├── jupyter-server
│   ├── jupyter-troubleshoot
│   ├── jupyter-trust
│   ├── normalizer
│   ├── numpy-config
│   ├── pip
│   ├── pip3
│   ├── pip3.12
│   ├── proton
│   ├── proton-viewer
│   ├── pybabel
│   ├── pyftmerge
│   ├── pyftsubset
│   ├── pygmentize
│   ├── pyjson5
│   ├── python -> /usr/bin/python3
│   ├── python3 -> python
│   ├── python3.12 -> python
│   ├── send2trash
│   ├── streamlit
│   ├── streamlit.cmd
│   ├── torchfrtrace
│   ├── torchrun
│   ├── tqdm
│   ├── ttx
│   ├── watchmedo
│   └── wsdump
├── CACHEDIR.TAG
├── docker
│   ├── Dockerfile.api
│   └── Dockerfile.ui
├── docker-compose.yml
├── Dockerfile
├── .dockerignore
├── .env
├── etc
│   └── jupyter
│   ├── jupyter_notebook_config.d
│   │   └── jupyterlab.json
│   ├── jupyter_server_config.d
│   │   ├── jupyterlab.json
│   │   ├── jupyter-lsp-jupyter-server.json
│   │   ├── jupyter_server_terminals.json
│   │   └── notebook_shim.json
│   └── nbconfig
│   └── notebook.d
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── Thesis.iml
│   ├── vcs.xml
│   └── workspace.xml
├── lib
│   └── python3.12
│   └── site-packages
│   ├── altair
│   ├── altair-6.0.0.dist-info
│   ├── anyio
│   ├── anyio-4.12.1.dist-info
│   ├── argon2
│   ├── argon2_cffi-25.1.0.dist-info
│   ├── _argon2_cffi_bindings
│   ├── argon2_cffi_bindings-25.1.0.dist-info
│   ├── arrow
│   ├── arrow-1.4.0.dist-info
│   ├── asttokens
│   ├── asttokens-3.0.1.dist-info
│   ├── async_lru
│   ├── async_lru-2.0.5.dist-info
│   ├── attr
│   ├── attrs
│   ├── attrs-25.4.0.dist-info
│   ├── babel
│   ├── babel-2.17.0.dist-info
│   ├── beautifulsoup4-4.14.3.dist-info
│   ├── bleach
│   ├── bleach-6.3.0.dist-info
│   ├── blinker
│   ├── blinker-1.9.0.dist-info
│   ├── bs4
│   ├── cachetools
│   ├── cachetools-6.2.4.dist-info
│   ├── certifi
│   ├── certifi-2026.1.4.dist-info
│   ├── cffi
│   ├── cffi-2.0.0.dist-info
│   ├── _cffi_backend.cpython-312-x86_64-linux-gnu.so
│   ├── charset_normalizer
│   ├── charset_normalizer-3.4.4.dist-info
│   ├── click
│   ├── click-8.3.1.dist-info
│   ├── comm
│   ├── comm-0.2.3.dist-info
│   ├── contourpy
│   ├── contourpy-1.3.3.dist-info
│   ├── cycler
│   ├── cycler-0.12.1.dist-info
│   ├── dateutil
│   ├── debugpy
│   ├── debugpy-1.8.19.dist-info
│   ├── decorator-5.2.1.dist-info
│   ├── decorator.py
│   ├── defusedxml
│   ├── defusedxml-0.7.1.dist-info
│   ├── _distutils_hack
│   ├── distutils-precedence.pth
│   ├── .DS_Store
│   ├── executing
│   ├── executing-2.2.1.dist-info
│   ├── fastjsonschema
│   ├── fastjsonschema-2.21.2.dist-info
│   ├── filelock
│   ├── filelock-3.20.3.dist-info
│   ├── fontTools
│   ├── fonttools-4.61.1.dist-info
│   ├── fqdn
│   ├── fqdn-1.5.1.dist-info
│   ├── fsspec
│   ├── fsspec-2026.1.0.dist-info
│   ├── functorch
│   ├── git
│   ├── gitdb
│   ├── gitdb-4.0.12.dist-info
│   ├── gitpython-3.1.46.dist-info
│   ├── google
│   ├── h11
│   ├── h11-0.16.0.dist-info
│   ├── httpcore
│   ├── httpcore-1.0.9.dist-info
│   ├── httpx
│   ├── httpx-0.28.1.dist-info
│   ├── idna
│   ├── idna-3.11.dist-info
│   ├── ipykernel
│   ├── ipykernel-7.1.0.dist-info
│   ├── ipykernel_launcher.py
│   ├── IPython
│   ├── ipython-9.9.0.dist-info
│   ├── ipython_pygments_lexers-1.1.1.dist-info
│   ├── ipython_pygments_lexers.py
│   ├── isoduration
│   ├── isoduration-20.11.0.dist-info
│   ├── isympy.py
│   ├── jedi
│   ├── jedi-0.19.2.dist-info
│   ├── jinja2
│   ├── jinja2-3.1.6.dist-info
│   ├── joblib
│   ├── joblib-1.5.3.dist-info
│   ├── json5
│   ├── json5-0.13.0.dist-info
│   ├── jsonpointer-3.0.0.dist-info
│   ├── jsonpointer.py
│   ├── jsonschema
│   ├── jsonschema-4.26.0.dist-info
│   ├── jsonschema_specifications
│   ├── jsonschema_specifications-2025.9.1.dist-info
│   ├── jupyter_client
│   ├── jupyter_client-8.8.0.dist-info
│   ├── jupyter_core
│   ├── jupyter_core-5.9.1.dist-info
│   ├── jupyter_events
│   ├── jupyter_events-0.12.0.dist-info
│   ├── jupyterlab
│   ├── jupyterlab-4.5.1.dist-info
│   ├── jupyterlab_pygments
│   ├── jupyterlab_pygments-0.3.0.dist-info
│   ├── jupyterlab_server
│   ├── jupyterlab_server-2.28.0.dist-info
│   ├── jupyter_lsp
│   ├── jupyter_lsp-2.3.0.dist-info
│   ├── jupyter.py
│   ├── jupyter_server
│   ├── jupyter_server-2.17.0.dist-info
│   ├── jupyter_server_terminals
│   ├── jupyter_server_terminals-0.5.3.dist-info
│   ├── kiwisolver
│   ├── kiwisolver-1.4.9.dist-info
│   ├── lark
│   ├── lark-1.3.1.dist-info
│   ├── markupsafe
│   ├── markupsafe-3.0.3.dist-info
│   ├── matplotlib
│   ├── matplotlib-3.10.8.dist-info
│   ├── matplotlib_inline
│   ├── matplotlib_inline-0.2.1.dist-info
│   ├── mistune
│   ├── mistune-3.2.0.dist-info
│   ├── mpl_toolkits
│   ├── mpmath
│   ├── mpmath-1.3.0.dist-info
│   ├── narwhals
│   ├── narwhals-2.15.0.dist-info
│   ├── nbclient
│   ├── nbclient-0.10.4.dist-info
│   ├── nbconvert
│   ├── nbconvert-7.16.6.dist-info
│   ├── nbformat
│   ├── nbformat-5.10.4.dist-info
│   ├── nest_asyncio-1.6.0.dist-info
│   ├── nest_asyncio.py
│   ├── networkx
│   ├── networkx-3.6.1.dist-info
│   ├── notebook_shim
│   ├── notebook_shim-0.2.4.dist-info
│   ├── numpy
│   ├── numpy-2.4.1.dist-info
│   ├── numpy.libs
│   ├── nvidia
│   ├── nvidia_cublas_cu12-12.8.4.1.dist-info
│   ├── nvidia_cuda_cupti_cu12-12.8.90.dist-info
│   ├── nvidia_cuda_nvrtc_cu12-12.8.93.dist-info
│   ├── nvidia_cuda_runtime_cu12-12.8.90.dist-info
│   ├── nvidia_cudnn_cu12-9.10.2.21.dist-info
│   ├── nvidia_cufft_cu12-11.3.3.83.dist-info
│   ├── nvidia_cufile_cu12-1.13.1.3.dist-info
│   ├── nvidia_curand_cu12-10.3.9.90.dist-info
│   ├── nvidia_cusolver_cu12-11.7.3.90.dist-info
│   ├── nvidia_cusparse_cu12-12.5.8.93.dist-info
│   ├── nvidia_cusparselt_cu12-0.7.1.dist-info
│   ├── nvidia_nccl_cu12-2.27.5.dist-info
│   ├── nvidia_nvjitlink_cu12-12.8.93.dist-info
│   ├── nvidia_nvshmem_cu12-3.3.20.dist-info
│   ├── nvidia_nvtx_cu12-12.8.90.dist-info
│   ├── packaging
│   ├── packaging-25.0.dist-info
│   ├── pandas
│   ├── pandas-2.3.3.dist-info
│   ├── pandocfilters-1.5.1.dist-info
│   ├── pandocfilters.py
│   ├── parso
│   ├── parso-0.8.5.dist-info
│   ├── pexpect
│   ├── pexpect-4.9.0.dist-info
│   ├── PIL
│   ├── pillow-12.1.0.dist-info
│   ├── pillow.libs
│   ├── pip
│   ├── pip-25.3.dist-info
│   ├── pkg_resources
│   ├── platformdirs
│   ├── platformdirs-4.5.1.dist-info
│   ├── prometheus_client
│   ├── prometheus_client-0.23.1.dist-info
│   ├── prompt_toolkit
│   ├── prompt_toolkit-3.0.52.dist-info
│   ├── protobuf-6.33.4.dist-info
│   ├── psutil
│   ├── psutil-7.2.1.dist-info
│   ├── ptyprocess
│   ├── ptyprocess-0.7.0.dist-info
│   ├── pure_eval
│   ├── pure_eval-0.2.3.dist-info
│   ├── pyarrow
│   ├── pyarrow-22.0.0.dist-info
│   ├── pycparser
│   ├── pycparser-2.23.dist-info
│   ├── pydeck
│   ├── pydeck-0.9.1.dist-info
│   ├── pygments
│   ├── pygments-2.19.2.dist-info
│   ├── pylab.py
│   ├── pyparsing
│   ├── pyparsing-3.3.1.dist-info
│   ├── python_dateutil-2.9.0.post0.dist-info
│   ├── pythonjsonlogger
│   ├── python_json_logger-4.0.0.dist-info
│   ├── pytz
│   ├── pytz-2025.2.dist-info
│   ├── pyyaml-6.0.3.dist-info
│   ├── pyzmq-27.1.0.dist-info
│   ├── pyzmq.libs
│   ├── referencing
│   ├── referencing-0.37.0.dist-info
│   ├── requests
│   ├── requests-2.32.5.dist-info
│   ├── rfc3339_validator-0.1.4.dist-info
│   ├── rfc3339_validator.py
│   ├── rfc3986_validator-0.1.1.dist-info
│   ├── rfc3986_validator.py
│   ├── rfc3987_syntax
│   ├── rfc3987_syntax-1.1.0.dist-info
│   ├── rpds
│   ├── rpds_py-0.30.0.dist-info
│   ├── scikit_learn-1.8.0.dist-info
│   ├── scikit_learn.libs
│   ├── scipy
│   ├── scipy-1.17.0.dist-info
│   ├── scipy.libs
│   ├── send2trash
│   ├── send2trash-2.0.0.dist-info
│   ├── setuptools
│   ├── setuptools-80.9.0.dist-info
│   ├── six-1.17.0.dist-info
│   ├── six.py
│   ├── sklearn
│   ├── smmap
│   ├── smmap-5.0.2.dist-info
│   ├── soupsieve
│   ├── soupsieve-2.8.1.dist-info
│   ├── stack_data
│   ├── stack_data-0.6.3.dist-info
│   ├── streamlit
│   ├── streamlit-1.53.0.dist-info
│   ├── sympy
│   ├── sympy-1.14.0.dist-info
│   ├── tenacity
│   ├── tenacity-9.1.2.dist-info
│   ├── terminado
│   ├── terminado-0.18.1.dist-info
│   ├── threadpoolctl-3.6.0.dist-info
│   ├── threadpoolctl.py
│   ├── tinycss2
│   ├── tinycss2-1.4.0.dist-info
│   ├── toml
│   ├── toml-0.10.2.dist-info
│   ├── torch
│   ├── torch-2.9.1.dist-info
│   ├── torchaudio
│   ├── torchaudio-2.9.1.dist-info
│   ├── torchgen
│   ├── torchvision
│   ├── torchvision-0.24.1.dist-info
│   ├── torchvision.libs
│   ├── tornado
│   ├── tornado-6.5.4.dist-info
│   ├── tqdm
│   ├── tqdm-4.67.1.dist-info
│   ├── traitlets
│   ├── traitlets-5.14.3.dist-info
│   ├── triton
│   ├── triton-3.5.1.dist-info
│   ├── typing_extensions-4.15.0.dist-info
│   ├── typing_extensions.py
│   ├── tzdata
│   ├── tzdata-2025.3.dist-info
│   ├── uri_template
│   ├── uri_template-1.3.0.dist-info
│   ├── urllib3
│   ├── urllib3-2.6.3.dist-info
│   ├── _virtualenv.pth
│   ├── _virtualenv.py
│   ├── watchdog
│   ├── watchdog-6.0.0.dist-info
│   ├── wcwidth
│   ├── wcwidth-0.2.14.dist-info
│   ├── webcolors
│   ├── webcolors-25.10.0.dist-info
│   ├── webencodings
│   ├── webencodings-0.5.1.dist-info
│   ├── websocket
│   ├── websocket_client-1.9.0.dist-info
│   ├── _yaml
│   ├── yaml
│   └── zmq
├── Makefile
├── NFS
├── poetry.lock
├── pyproject.toml
├── pyvenv.cfg
├── README.md
├── requirements.txt
├── runs
├── share
│   ├── applications
│   │   └── jupyterlab.desktop
│   ├── icons
│   │   └── hicolor
│   │   └── scalable
│   ├── jupyter
│   │   ├── kernels
│   │   │   └── python3
│   │   ├── lab
│   │   │   ├── schemas
│   │   │   ├── static
│   │   │   └── themes
│   │   ├── labextensions
│   │   │   └── jupyterlab_pygments
│   │   ├── nbconvert
│   │   │   └── templates
│   │   └── nbextensions
│   │   └── pydeck
│   └── man
│   └── man1
│   ├── ipython.1
│   ├── isympy.1
│   └── ttx.1
├── src
│   ├── 5_epoch_emoset_resnet50_finetuned_2.41M.pth
│   ├── api.py
│   ├── data_loader.py
│   ├── dataset_paths_cache.pkl
│   ├── emoset_resnet50_best.pth
│   ├── emoset_resnet50_finetuned_2_41M_best.pth
│   ├── emoset_resnet50_resume.pth
│   ├── emoset_test_embeddings.npy
│   ├── emoset_test_labels.npy
│   ├── main.py
│   ├── music_engine
│   │   ├── image_processor.py
│   │   ├── __init__.py
│   │   ├── llm_bridge.py
│   │   ├── matcher.py
│   │   └── va_regressor.pkl
│   ├── scripts
│   │   ├── 00_setup_env.sh
│   │   ├── 01_download_DEAM.py
│   │   ├── 02_download_EmoSet.py
│   │   ├── 11_prerp_DEAM.py
│   │   ├── 20_bench_GPU.py
│   │   ├── 21_train_images.ipynb
│   │   ├── 22_extract_embeddings.ipynb
│   │   ├── 23_aggregate_DEAM_timeline.py
│   │   ├── 24_train_regressor.py
│   │   ├── 31_finetune_2.41M.py
│   │   ├── 90_acc_images_model.ipynb
│   │   └── 91_generate_metrics.py
│   └── tabs
│   ├── tab_dataset.py
│   └── tab_live.py
├── tree.txt
└── .vscode
├── launch.json
└── tasks.json
322 directories, 137 files