Compare commits

..

7 Commits

Author SHA1 Message Date
zin 934a4cbff4 feat: add metrics 2026-06-16 04:59:51 +00:00
zin 14968dd4d4 feat: commiting model 2026-06-08 14:49:45 +00:00
zin daba573b2c feat: improving finetuning 2026-06-06 21:06:21 +00:00
zin 8648e52106 feat: refactor code and finetune OOM fix 2026-06-03 09:34:34 +00:00
zin e3a2eb3289 Merge pull request 'Docker integration into project' (#1) from Dockered into main
Reviewed-on: #1
2026-06-03 12:18:10 +03:00
zin a57addcbb1 feat: finale 2026-06-03 09:16:12 +00:00
zin 3850b15053 feat: Init 2026-06-02 22:39:11 +00:00
28 changed files with 1972 additions and 1111 deletions
+18
View File
@@ -0,0 +1,18 @@
# Ignore patterns added together with the Docker setup.
# NOTE(review): presumably this is .dockerignore (it excludes .git/) — confirm the file name.

# Directories and marker file that look like a Python virtual environment
# created in the project root — TODO confirm.
bin/
lib/
share/
etc/
include/
pyvenv.cfg

# Editor / IDE settings
.idea/
.vscode/

# Python bytecode caches
__pycache__/
*.pyc

# Version-control metadata
.git/

# Local run output, datasets and the NFS directory
runs/
dataset/
NFS/

# Serialized model weights and array dumps
*.pth
*.pkl
*.npy

# Environment file (the compose file loads it through env_file)
.env
+3 -7
View File
@@ -1,13 +1,11 @@
# Базовый образ среды выполнения PyTorch
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
# Конфигурация интерпретатора Python (отключение генерации байткода и буферизации вывода)
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
WORKDIR /app
# Системные библиотеки для низкоуровневой обработки изображений
# System dependencies for OpenCV and image processing
RUN apt-get update && apt-get install -y \
libglib2.0-0 \
libsm6 \
@@ -15,15 +13,13 @@ RUN apt-get update && apt-get install -y \
libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
# Интеграция Python-зависимостей
# Install python packages
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Модули программного комплекса
# Copy source code
COPY src/ /app/src/
# Сетевой интерфейс UI
EXPOSE 8080
# Точка входа контейнера
CMD ["streamlit", "run", "src/main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
+63
View File
@@ -0,0 +1,63 @@
version: '3.8'

# Private bridge network shared by every EmoM service. Containers address each
# other by service name (the UI calls http://emom_inference:8000, the inference
# API calls http://emom_ollama:11434).
networks:
  emom_mesh:
    driver: bridge

services:
  # Streamlit front end. The only service with a port published to the host.
  emom_ui:
    build:
      context: .
      dockerfile: docker/Dockerfile.ui
    container_name: emom_web_ui
    restart: unless-stopped
    ports:
      - "8080:8080"
    networks:
      - emom_mesh
    env_file:
      - .env
    volumes:
      # Source tree is bind-mounted over the copy baked into the image.
      - ./src:/app/src
      # DEAM dataset, read-only; the UI reads the audio files from it.
      - ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
    depends_on:
      - emom_inference

  # FastAPI inference service. Not published to the host; reachable only on
  # the emom_mesh network.
  emom_inference:
    build:
      context: .
      dockerfile: docker/Dockerfile.api
    container_name: emom_pytorch_api
    restart: unless-stopped
    networks:
      - emom_mesh
    env_file:
      - .env
    volumes:
      # Model artifacts are mounted read-only from the host instead of being
      # baked into the image.
      - ${HOST_ARTIFACTS_DIR}/emoset_resnet50_best.pth:/app/src/emoset_resnet50_best.pth:ro
      - ${HOST_ARTIFACTS_DIR}/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
      - ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
      # Host Hugging Face cache, writable — presumably so model downloads are
      # reused between container runs.
      - ~/.cache/huggingface:/root/.cache/huggingface
    # The API sends its LLM requests to emom_ollama, so that container must be
    # started with it. depends_on only orders startup; it does not wait for
    # Ollama to be ready to serve.
    depends_on:
      - emom_ollama
    deploy:
      resources:
        reservations:
          devices:
            # Reserve one NVIDIA GPU for this container.
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # Ollama LLM server. Not published to the host.
  emom_ollama:
    image: ollama/ollama:latest
    container_name: emom_ollama_engine
    restart: unless-stopped
    networks:
      - emom_mesh
    volumes:
      # Host Ollama directory mounted into the container's Ollama home.
      - ~/.ollama:/root/.ollama
    deploy:
      resources:
        reservations:
          devices:
            # Reserve one NVIDIA GPU for this container.
            - driver: nvidia
              count: 1
              capabilities: [gpu]
+19
View File
@@ -0,0 +1,19 @@
# Image for the inference API: PyTorch CUDA runtime base plus the FastAPI stack.
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime

# Do not write .pyc files and keep Python output unbuffered.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# System shared libraries. --no-install-recommends keeps optional packages out
# of the image; the apt package lists are removed in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libglib2.0-0 libsm6 libxext6 libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

# NOTE(review): only transformers and tokenizers are pinned here; the other
# packages resolve to whatever is current at build time.
RUN pip install --no-cache-dir fastapi uvicorn timm scikit-learn pandas joblib python-multipart transformers==4.38.2 tokenizers==0.15.2 accelerate

WORKDIR /app
COPY src/ /app/src/

# Run from the source directory so "api:app" resolves to /app/src/api.py.
WORKDIR /app/src

EXPOSE 8000
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
+15
View File
@@ -0,0 +1,15 @@
# Image for the Streamlit front end (plain Python base image).
FROM python:3.12-slim

# Do not write .pyc files and keep Python output unbuffered.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

# UI-side Python packages.
RUN pip install --no-cache-dir streamlit==1.32.0 requests pandas pillow

COPY src/ /app/src/

# Start from the source directory so "main.py" resolves to /app/src/main.py.
WORKDIR /app/src

EXPOSE 8080
CMD ["streamlit", "run", "main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
-64
View File
@@ -1,64 +0,0 @@
version: '3.8'
# Shared networks used to isolate traffic
networks:
ai_mesh:
driver: bridge
services:
# ----------------------------------------------------
# SERVICE 1: Frontend (user interface)
# Does not need a GPU; can be moved to a separate server
# ----------------------------------------------------
web_ui:
build:
context: .
dockerfile: Dockerfile
container_name: emom_frontend
restart: always
ports:
- "8080:8080"
networks:
- ai_mesh
environment:
- STREAMLIT_RUN=1
# Tell the UI where to find the LLM backend (inside the Docker network)
- OLLAMA_HOST=http://llm_backend:11434
volumes:
- ./src:/app/src
# The models stay here for now because the code is monolithic,
# but architecturally the service is already isolated
- /home/zin/projects/Thesis/src/emoset_resnet50_best.pth:/app/emoset_resnet50_best.pth:ro
- /home/zin/projects/Thesis/src/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
- /home/zin/projects/Thesis/dataset/DEAM:/app/dataset/DEAM:ro
# Temporarily keep the GPU for PyTorch (until it is moved into an API)
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# ----------------------------------------------------
# SERVICE 2: LLM Inference Backend (Ollama)
# Isolated service for the language model on the GPU
# ----------------------------------------------------
llm_backend:
image: ollama/ollama:latest
container_name: ollama_gpu_inference
restart: always
networks:
- ai_mesh
ports:
- "11434:11434"
volumes:
# Mount the local Ollama models so they are not re-downloaded inside Docker
- ~/.ollama:/root/.ollama
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
+76
View File
@@ -0,0 +1,76 @@
import io
import traceback
import numpy as np
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from data_loader import load_music_engine, load_image_processor
from music_engine.llm_bridge import LLMAcousticBridge
app = FastAPI(title="EmoM API", version="1.0.0")
# Module-level registry of the loaded ML components. Every slot starts as
# None; the startup handler assigns the real objects.
ml_context = dict(
    image_processor=None,
    music_matcher=None,
    llm_bridge=None,
)
@app.on_event("startup")
async def startup_event():
    """Fill ``ml_context`` when the application starts.

    Builds the image processor, the music matcher and the LLM bridge, storing
    each under its key in ``ml_context``. Progress is reported with ``print``.
    """
    print("Loading ML models...")
    factories = (
        ("image_processor", load_image_processor),
        ("music_matcher", load_music_engine),
        ("llm_bridge", LLMAcousticBridge),
    )
    # Components are built one at a time, in the order listed above.
    for slot, build in factories:
        ml_context[slot] = build()
    print("Initialization complete.")
@app.post("/analyze")
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
    """Build a playlist for a batch of uploaded images.

    Every upload is decoded to RGB, embedded, mapped to a (valence, arousal)
    pair and captioned. The per-image values are averaged into a single target
    point, the distinct captions are handed to the LLM bridge to obtain an
    acoustic profile, and the matcher is asked for the 15 nearest tracks.

    Args:
        files: Uploaded image files (multipart field ``files``).

    Returns:
        JSONResponse with the keys ``status``, ``images_processed``,
        ``target_v``, ``target_a``, ``llm_profile``, ``semantics`` and
        ``tracks``.

    Raises:
        HTTPException: 503 when a model slot in ``ml_context`` is still None,
            400 when an upload cannot be decoded as an image, 500 for any
            other failure (the traceback is printed first).
    """
    try:
        img_processor = ml_context["image_processor"]
        matcher = ml_context["music_matcher"]
        llm = ml_context["llm_bridge"]
        # Fail with a clear status instead of an AttributeError on None.
        if img_processor is None or matcher is None or llm is None:
            raise HTTPException(status_code=503, detail="ML models are not loaded")

        images = []
        for file in files:
            image_bytes = await file.read()
            try:
                img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except (OSError, ValueError, SyntaxError) as decode_err:
                # A corrupt or non-image upload is a client error, not a server fault.
                raise HTTPException(
                    status_code=400,
                    detail=f"Cannot decode image '{file.filename}': {decode_err}",
                ) from decode_err
            images.append(img)

        print(f"Processing batch: {len(images)} images.")

        all_v, all_a = [], []
        all_objects = []

        for img in images:
            embedding = img_processor.extract_embedding(img)
            v, a = matcher.predict_va(embedding)
            all_v.append(v)
            all_a.append(a)

            caption = img_processor.describe_scene(img)
            all_objects.append(caption)

        target_v = float(np.mean(all_v))
        target_a = float(np.mean(all_a))

        # De-duplicate while keeping first-seen order. A set would yield an
        # order that changes between processes, making the LLM prompt (and the
        # returned "semantics" list) non-reproducible.
        unique_semantics = list(dict.fromkeys(all_objects))

        llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
        playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
        tracks_list = playlist_df.to_dict(orient="records")

        return JSONResponse(content={
            "status": "success",
            "images_processed": len(images),
            "target_v": target_v,
            "target_a": target_a,
            "llm_profile": llm_profile,
            "semantics": unique_semantics,
            "tracks": tracks_list
        })
    except HTTPException:
        # Keep the status chosen above; do not let the generic handler turn it into a 500.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
+30 -39
View File
@@ -1,55 +1,46 @@
import os
from pathlib import Path
from typing import Tuple, List, Optional, Any
import pandas as pd
import numpy as np
import streamlit as st
from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor
BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource
def load_music_engine():
# Инициализация базы данных и регрессора для музыкального мэтчинга
def load_music_engine() -> MusicMatcher:
#Инициализация модуля подбора музыкальных композиций.
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"Музыкальная БД не найдена: {db_path}")
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource
def load_image_processor():
# Модуль обработки визуальных признаков
model_path = BASE_DIR / "emoset_resnet50_best.pth"
def load_image_processor() -> ImageProcessor:
#Инициализация модуля экстракции визуальных признаков.
weights_path = BASE_DIR / "emoset_resnet50_best.pth"
# Обработка пути при вызове из корневой директории
if not model_path.exists():
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(weights_path)
def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]:
# Загрузка тестовой выборки датасета EmoSet.
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
try:
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
labels_path = BASE_DIR / "emoset_test_labels.npy"
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
return ImageProcessor(model_path=model_path)
@st.cache_data
def load_emoset_data():
# Выборка данных датасета для вкладки отладки
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
csv_path = dataset_root / "labels.csv"
img_dir = dataset_root / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
return None, None, None, None
labels_df = pd.read_csv(csv_path)
test_filenames = labels_df['filename'].tolist()
test_embeddings = np.load(emb_path)
test_labels = np.load(lbl_path)
return test_filenames, test_embeddings, test_labels, img_dir
if not all(p.exists() for p in [labels_path, embeddings_path]):
return None, None, None, None
labels = np.load(labels_path)
embeddings = np.load(embeddings_path)
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
df = pd.read_csv(csv_path)
return df['filename'].tolist(), embeddings, labels, images_path
except Exception as e:
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
return None, None, None, None
Binary file not shown.
Binary file not shown.
+174 -58
View File
@@ -1,73 +1,189 @@
import sys
import os
import subprocess
import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import base64
from io import BytesIO
from data_loader import load_music_engine, load_emoset_data, load_image_processor
from tabs.tab_dataset import render_dataset_tab
from tabs.tab_live import render_live_tab
st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed")
# Костыль для прямого запуска
if __name__ == "__main__":
if "STREAMLIT_RUN" not in os.environ:
os.environ["STREAMLIT_RUN"] = "1"
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
subprocess.run(cmd)
sys.exit()
API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze"
DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
viewport_mode = st.query_params.get("viewport", "desktop")
page_layout = "centered" if viewport_mode == "mobile" else "wide"
st.set_page_config(page_title="Thesis Demo", layout=page_layout)
# Определения ширины экрана и смены верстки
components.html(
"""
<script>
const w = window.parent.innerWidth;
const h = window.parent.innerHeight;
const url = new URL(window.parent.location.href);
def get_thumbnail_html(images, max_display=12):
html_images = ""
for file in images[:max_display]:
img = Image.open(file)
img.thumbnail((100, 100))
if img.mode != "RGB":
img = img.convert("RGB")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode()
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
// Считаем мобилкой, если ушли в портретный режим или экран уже 768px
const isMobile = (h > w) || (w < 768);
const target = isMobile ? "mobile" : "desktop";
if (url.searchParams.get("viewport") !== target) {
url.searchParams.set("viewport", target);
window.parent.location.href = url.href;
}
</script>
""",
height=0,
width=0,
)
if len(images) > max_display:
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
st.markdown(
"""
def main():
if "live_state" not in st.session_state:
st.session_state.live_state = "upload"
if "result_data" not in st.session_state:
st.session_state.result_data = None
viewport = st.query_params.get("viewport", "desktop")
st.markdown("""
<style>
img { max-width: 100%; height: auto; object-fit: contain; }
[data-testid="stMetricValue"] { font-size: 1.8rem; }
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
</style>
""",
unsafe_allow_html=True
)
""", unsafe_allow_html=True)
# Подгрузка ML-моделей и датасета
music_matcher = load_music_engine()
img_processor = load_image_processor()
emoset_files, emoset_embeddings, emoset_labels, emoset_path = load_emoset_data()
if st.session_state.live_state == "upload":
upload_placeholder = st.empty()
with upload_placeholder.container():
st.write("Загрузите изображения для визуально-семантического анализа.")
if viewport == "mobile":
st.markdown("<br>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
"Загрузка файлов",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
label_visibility="collapsed" if viewport == "mobile" else "visible"
)
if uploaded_files:
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Выполнить анализ", type="primary", use_container_width=True):
st.session_state.uploaded_images = uploaded_files
st.session_state.live_state = "processing"
upload_placeholder.empty()
st.rerun()
st.markdown("<br>", unsafe_allow_html=True)
st.caption("Выбранные файлы:")
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
st.title("Генератор саундтреков (Research Demo)")
elif st.session_state.live_state == "processing":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
files = st.session_state.get("uploaded_images", [])
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
try:
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
response = requests.post(API_URL, files=upload_data, timeout=300)
if response.status_code == 200:
st.session_state.result_data = response.json()
st.session_state.live_state = "result"
st.rerun()
else:
st.error(f"Ошибка сервера: {response.status_code}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
except Exception as e:
st.error(f"Ошибка соединения: {str(e)}")
if st.button("Назад"):
st.session_state.live_state = "upload"
st.rerun()
elif st.session_state.live_state == "result":
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
data = st.session_state.result_data
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
for row in data.get("tracks", []):
with st.container(border=True):
song_id = int(row['song_id'])
score = row['final_score']
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
if not os.path.exists(audio_path):
audio_path = audio_path.replace('.mp3', '.wav')
if viewport == "desktop":
c1, c2 = st.columns([1, 3])
with c1:
st.write(f"**Track ID:** {song_id}")
st.caption(f"Score: {score:.4f}")
with c2:
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
else:
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
if os.path.exists(audio_path):
st.audio(audio_path)
else:
st.caption("Аудиофайл не найден")
tab_live, tab_debug = st.tabs(["Анализ событий (Свои фото)", "Отладка (Датасет EmoSet)"])
st.markdown("<br>", unsafe_allow_html=True)
with st.expander("Отладочная информация (Метрики)"):
st.subheader("Координаты V/A")
c_v, c_a = st.columns(2)
c_v.metric("Valence", f"{data['target_v']:.2f}")
c_a.metric("Arousal", f"{data['target_a']:.2f}")
st.markdown("---")
st.subheader("Акустические признаки (LLM)")
feature_titles = {
"energy": "RMS Energy",
"flux": "Spectral Flux",
"centroid": "Spectral Centroid",
"pitch": "F0 (Pitch)",
"hnr": "HNR",
"zcr": "ZCR"
}
# Развернутые описания
feature_helps = {
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
}
llm_profile = data.get("llm_profile")
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
cols_per_row = 2 if viewport == "mobile" else 3
llm_items = list(llm_profile.items())
for i in range(0, len(llm_items), cols_per_row):
cols = st.columns(cols_per_row)
for j in range(cols_per_row):
if i + j < len(llm_items):
k, v = llm_items[i + j]
label = feature_titles.get(k, k)
tooltip = feature_helps.get(k, "")
cols[j].metric(label, f"{v:.2f}", help=tooltip)
else:
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
st.markdown("---")
st.write("**Извлеченные теги (BLIP-2):**")
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
with tab_live:
if img_processor:
render_live_tab(music_matcher, img_processor)
else:
st.error("Ошибка загрузки: не найдены веса ResNet для image_processor.")
st.markdown("<br>", unsafe_allow_html=True)
if st.button("Новый запрос", use_container_width=True):
st.session_state.live_state = "upload"
st.session_state.result_data = None
st.session_state.pop("uploaded_images", None)
st.rerun()
with tab_debug:
render_dataset_tab(music_matcher, emoset_files, emoset_embeddings, emoset_labels, emoset_path)
if __name__ == "__main__":
main()
+5 -1
View File
@@ -32,7 +32,11 @@ class ImageProcessor:
# Модуль семантического описания сцены
print("Инициализация BLIP-2...")
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# Обход бага конфигурации Hugging Face (ручная сборка процессора)
from transformers import BlipImageProcessor, AutoTokenizer
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok)
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16
+54 -54
View File
@@ -1,65 +1,65 @@
import re
import os
import json
import re
import requests
class LLMAcousticBridge:
def __init__(self, target_model="dolphin-llama3:8b"):
self.api_url = "http://localhost:11434/api/generate"
self.model = target_model
def __init__(self, model_name="dolphin-llama3:8b"):
self.model_name = model_name
base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
self.api_url = f"{base_url}/api/generate"
def _extract_json(self, raw_text: str):
# Проверка на ИИдиота, LLM иногда игнорирует format="json" и оборачивает ответ в маркдаун
try:
match = re.search(r'\{.*\}', raw_text, re.DOTALL)
if match:
return json.loads(match.group(0))
return json.loads(raw_text)
except json.JSONDecodeError:
# Если ИИдиот
return None
def get_acoustic_profile(self, v_score: float, a_score: float, scene_context: list) -> dict | None:
# Агрегация контекста для обработки серии снимков (события)
context_merged = " | ".join(scene_context) if scene_context else "abstract scene"
def get_acoustic_profile(self, valence, arousal, semantics):
context_str = ", ".join(semantics) if semantics else "abstract scene"
prompt = f"""
Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
Visual Context: {context_str}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density)
2. "flux": (Rhythmic sharpness/Beat)
3. "centroid": (Brightness)
4. "pitch": (Fundamental frequency)
5. "hnr": (Harmonics-to-Noise)
6. "zcr": (Percussiveness)
Return ONLY a valid JSON object. No explanations, no markdown blocks.
Example: {{"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}}
"""
system_prompt = f"""You are an expert music producer and acoustic engineer.
Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {v_score:.1f}/9.0 (Positivity), Arousal {a_score:.1f}/9.0 (Energy).
Visual Context: {context_merged}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
2. "flux": (Rhythmic sharpness/Beat. High for action/people/cars, Low for static nature)
3. "centroid": (Brightness: 0=Dark/Bass/Massive, 1=Bright/Treble/Light)
4. "pitch": (Fundamental frequency: 0=Low pitch/Huge objects, 1=High pitch/Small objects)
5. "hnr": (Harmonics-to-Noise: 0=Noisy/Distorted textures, 1=Clear/Melodic/Smooth textures)
6. "zcr": (Percussiveness. High for detailed noise like leaves/rain, Low for solid blocks)
Return ONLY a valid JSON object. Do not add any text or explanation.
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
try:
# Отправка промпта локальной Ollama
response = requests.post(self.api_url, json={
"model": self.model,
"prompt": system_prompt,
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False,
"format": "json"
}, timeout=45)
response.raise_for_status()
"format": "json" # Принудительный JSON-режим Ollama
}
raw_response = response.json().get("response", "")
profile_data = self._extract_json(raw_response)
print(f"Запрос акустического профиля к Ollama...")
response = requests.post(self.api_url, json=payload, timeout=120)
# Валидация структуры ответа
expected_features = {'energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'}
if profile_data and expected_features.issubset(profile_data.keys()):
return profile_data
if response.status_code == 200:
data = response.json()
response_text = data.get("response", "")
print("LLM вернула неполный или некорректный набор акустических признаков")
return None
except requests.exceptions.RequestException as req_err:
print(f"Не удалось подключиться к Ollama: {req_err}")
return None
try:
# 1. Попытка прямой десериализации
profile = json.loads(response_text)
return profile
except json.JSONDecodeError:
# 2. Аварийное извлечение JSON из текста с помощью регулярного выражения
match = re.search(r'\{.*\}', response_text, re.DOTALL)
if match:
return json.loads(match.group(0))
print(f"Ошибка парсинга LLM ответа: {response_text}")
return {}
else:
print(f"Ollama вернула ошибку HTTP: {response.status_code}")
return {}
except Exception as e:
print(f"Ошибка соединения с Ollama: {str(e)}")
return {}
Binary file not shown.
+1
View File
@@ -1,5 +1,6 @@
#!/bin/bash
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
# Остановка скрипта при возникновении любой ошибки
set -e
-541
View File
@@ -1,541 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"import timm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84c3657f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Конфигурация параметров обучения и путей файловой системы\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 40\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Аппаратное ускорение: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
" def __init__(self, root: Path | str, split: str):\n",
" self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
" # Формирование словарей маппинга классов\n",
" self.labels = sorted(self.df[\"label\"].unique())\n",
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
" # Базовые трансформации для валидации и теста\n",
" base_tf = [\n",
" T.ToTensor(),\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" ]\n",
"\n",
" # Внедрение аугментации исключительно для обучающей выборки (предотвращение переобучения)\n",
" if split == \"train\":\n",
" self.transform = T.Compose([\n",
" T.RandomResizedCrop(224),\n",
" T.RandomHorizontalFlip(),\n",
" *base_tf\n",
" ])\n",
" else:\n",
" self.transform = T.Compose([\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" *base_tf\n",
" ])\n",
"\n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
" def __getitem__(self, idx):\n",
" row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
" # Обработка возможных исключений ввода-вывода (поврежденные JPEG-файлы в датасете)\n",
" try:\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" except Exception:\n",
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
"\n",
" img_tensor = self.transform(img)\n",
" label_idx = self.label2idx[row[\"label\"]]\n",
" \n",
" return img_tensor, label_idx"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8805341",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
]
}
],
"source": [
"# Подготовка объектов выборки\n",
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"# Инициализация итераторов с закреплением памяти (pin_memory) для ускорения передачи на GPU\n",
"train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" val_ds,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=False,\n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True\n",
")\n",
"\n",
"print(f\"Индексированные классы: {train_ds.labels}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dffce582",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TODO перед защитой, повторить оптимизаторы\n",
"# Загрузка предобученной архитектуры ResNet-50 с заменой классификационного слоя\n",
"model = timm.create_model(\n",
" \"resnet50\",\n",
" pretrained=True,\n",
" num_classes=len(train_ds.labels)\n",
")\n",
"model.to(device)\n",
"\n",
"# Функция потерь для многоклассовой классификации\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# Оптимизатор AdamW с L2-регуляризацией (weight_decay) для повышения обобщающей способности\n",
"optimizer = torch.optim.AdamW(\n",
" model.parameters(),\n",
" lr=LR,\n",
" weight_decay=1e-4\n",
")\n",
"\n",
"# Планировщик скорости обучения: косинусный отжиг\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer,\n",
" T_max=EPOCHS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(current_model, loader):\n",
" current_model.train()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in tqdm(loader, desc=\"Тренировка\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total_samples, correct_preds / total_samples\n",
"\n",
"@torch.no_grad()\n",
"def val_epoch(current_model, loader):\n",
" # Перевод модели в режим инференса (отключение Dropout и фиксация BatchNorm)\n",
" current_model.eval()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in tqdm(loader, desc=\"Валидация\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total_samples, correct_preds / total_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": [
"best_val_acc = 0.0\n",
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
"\n",
"print(\"Старт процесса обучения...\")\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" # Обновление шага планировщика\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
" f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
" f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" # Экспорт весов при улучшении целевой метрики\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), checkpoint_path)\n",
" print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
"\n",
"print(\"Обучение завершено.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+184
View File
@@ -0,0 +1,184 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
# Suppress colour-profile warnings ("Unknown Adobe color transform code").
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Environment settings.
# The dataset location may be overridden with the EMOSET_ROOT environment
# variable so the script is not tied to one developer machine; when the
# variable is unset the original hard-coded location is used.
DATA_ROOT = Path(
    os.environ.get(
        "EMOSET_ROOT",
        "/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K",
    )
)
BATCH_SIZE = 64
EPOCHS = 30
LR = 5e-5
# Up to 62 DataLoader workers, but never more than the machine has CPUs
# (os.cpu_count() may return None, hence the fallback to 1).
NUM_WORKERS = min(62, os.cpu_count() or 1)
PATIENCE = 7

# Emotion label -> class index.
CLASS_MAPPING = {
    "amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
    "disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# Seed the pseudo-random number generators used by the script.
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: Value passed to every generator (default 42).

    The CUDA generators are seeded only when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Transforms.
# Steps shared by both pipelines: tensor conversion followed by per-channel
# normalisation.
base_tf = [
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
# Training: resize, random 224 crop and random horizontal flip.
train_transform = T.Compose([
    T.Resize(256, antialias=True),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    *base_tf
])
# Validation: deterministic resize and centre crop.
val_transform = T.Compose([
    T.Resize(256, antialias=True),
    T.CenterCrop(224),
    *base_tf
])

train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Model, loss, optimizer and learning-rate schedule.
# The classifier width follows CLASS_MAPPING so the two cannot drift apart.
model = timm.create_model(
    "resnet50",
    pretrained=True,
    num_classes=len(CLASS_MAPPING),
    drop_rate=0.3,
)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# One training epoch.
def train_epoch(current_model, loader):
    """Train ``current_model`` for one pass over ``loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    current_model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0
    progress = tqdm(loader, desc="Тренировка", leave=False, smoothing=0)
    for batch_imgs, batch_labels in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        outputs = current_model(batch_imgs)
        batch_loss = criterion(outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
        hits += (outputs.argmax(dim=1) == batch_labels).sum().item()
        seen += batch_labels.size(0)
    return loss_sum / seen, hits / seen
# One validation epoch.
@torch.no_grad()
def val_epoch(current_model, loader):
    """Evaluate ``current_model`` on ``loader`` without gradient tracking.

    Uses the module-level ``criterion`` and ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    current_model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    progress = tqdm(loader, desc="Валидация", leave=False, smoothing=0)
    for batch_imgs, batch_labels in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        outputs = current_model(batch_imgs)
        batch_loss = criterion(outputs, batch_labels)

        # Weight the batch loss by batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
        hits += (outputs.argmax(dim=1) == batch_labels).sum().item()
        seen += batch_labels.size(0)
    return loss_sum / seen, hits / seen
if __name__ == "__main__":
    # Training driver: the checkpoint tracks the best validation accuracy,
    # while early stopping watches the validation loss.
    checkpoint_path = "./emosetV2_resnet50_best.pth"
    best_val_acc = 0.0
    best_val_loss = float('inf')
    epochs_no_improve = 0

    print("Старт обучения.")
    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = train_epoch(model, train_loader)
        val_loss, val_acc = val_epoch(model, val_loader)
        scheduler.step()
        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # Save the weights whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")

        # Early stopping: count consecutive epochs without a lower val loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            continue
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
            break

    print("Процесс завершен.")
+283
View File
@@ -0,0 +1,283 @@
import os
import random
import warnings
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Suppress colour-profile warnings ("Unknown Adobe color transform code").
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Environment settings.
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
# Directory that receives the generated figures; created on import.
MEDIA_DIR = Path("./src/scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)

BATCH_SIZE, EPOCHS, LR = 64, 30, 5e-5
NUM_WORKERS, PATIENCE = 32, 7

# Class names in index order; both lookup tables are derived from this list.
CLASS_NAMES = [
    "amusement", "anger", "awe", "contentment",
    "disgust", "excitement", "fear", "sadness",
]
# Label -> class index.
CLASS_MAPPING = {label: index for index, label in enumerate(CLASS_NAMES)}
# Class index -> label, used for plot tick labels.
INV_CLASS_MAPPING = dict(enumerate(CLASS_NAMES))

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Устройство: {DEVICE}")
# Seed the pseudo-random number generators used by the script.
def set_seed(seed=42):
    """Seed the ``random``, NumPy and PyTorch generators with ``seed``.

    Args:
        seed: Value passed to every generator (default 42).

    CUDA generators are seeded as well when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_seeders = (
        (torch.cuda.manual_seed, torch.cuda.manual_seed_all)
        if torch.cuda.is_available()
        else ()
    )
    for seeder in cuda_seeders:
        seeder(seed)


set_seed()
# Инициализация структур данных
class EmoSetDataset(Dataset):
def __init__(self, root: Path | str, split: str, transform=None):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.transform = transform
# Фильтрация датафрейма
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (256, 256), (0, 0, 0))
if self.transform:
img_tensor = self.transform(img)
else:
img_tensor = T.ToTensor()(img)
label_idx = CLASS_MAPPING[row["label"]]
return img_tensor, label_idx
# Transforms.
# Steps shared by both pipelines: tensor conversion followed by per-channel
# normalisation.
base_tf = [
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
# Training: resize, random 224 crop and random horizontal flip.
train_transform = T.Compose([
    T.Resize(256, antialias=True),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    *base_tf
])
# Validation: deterministic resize and centre crop.
val_transform = T.Compose([
    T.Resize(256, antialias=True),
    T.CenterCrop(224),
    *base_tf
])

train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Model, loss, optimizer and learning-rate schedule.
# The classifier width follows CLASS_MAPPING so the two cannot drift apart.
model = timm.create_model(
    "resnet50",
    pretrained=True,
    num_classes=len(CLASS_MAPPING),
    drop_rate=0.3,
)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Plotting helpers.
def plot_learning_curves(history):
    """Save a two-panel figure with the loss and accuracy curves.

    Args:
        history: Mapping with the per-epoch lists ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.

    The figure is written to ``MEDIA_DIR / "training_history.png"``.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    # (train key, val key, train legend, val legend, title, y-axis label)
    panels = (
        ('train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
         'График функции потерь (Loss)', 'Loss'),
        ('train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy',
         'График точности (Accuracy)', 'Accuracy'),
    )

    plt.figure(figsize=(14, 5))
    for position, panel in enumerate(panels, start=1):
        train_key, val_key, train_legend, val_legend, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(epochs, history[train_key], 'b-', label=train_legend)
        plt.plot(epochs, history[val_key], 'r--', label=val_legend)
        plt.title(title, fontsize=14)
        plt.xlabel('Эпохи', fontsize=12)
        plt.ylabel(y_label, fontsize=12)
        plt.legend()
        plt.grid(True, linestyle=':', alpha=0.7)
    plt.tight_layout()

    plot_path = MEDIA_DIR / "training_history.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] График обучения сохранен в: {plot_path}")
def plot_confusion_matrix(y_true, y_pred):
    """Save a heat map of the confusion matrix.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices, aligned with ``y_true``.

    The figure is written to ``MEDIA_DIR / "confusion_matrix_emoset.png"``.
    """
    # Pin the label set: without ``labels`` the matrix only covers classes
    # that occur in y_true/y_pred, so a missing class would shrink it and
    # misalign it with the CLASS_NAMES tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                cbar_kws={'label': 'Количество сэмплов'})
    plt.title('Матрица ошибок (Confusion Matrix) - ResNet50', fontsize=16, pad=20)
    plt.ylabel('Истинные классы (Ground Truth)', fontsize=12)
    plt.xlabel('Предсказанные классы (Predicted)', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    cm_path = MEDIA_DIR / "confusion_matrix_emoset.png"
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] Матрица ошибок сохранена в: {cm_path}")
# One training epoch.
def train_epoch(current_model, loader):
    """Train ``current_model`` for one pass over ``loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    current_model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    batches = tqdm(loader, desc="Тренировка", leave=False, smoothing=0)
    for inputs, targets in batches:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        scores = current_model(inputs)
        step_loss = criterion(scores, targets)
        step_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per sample.
        running_loss += step_loss.item() * inputs.size(0)
        predicted = scores.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)
    return running_loss / n_seen, n_correct / n_seen
# One validation epoch; can also return the predictions for the confusion matrix.
@torch.no_grad()
def val_epoch(current_model, loader, return_preds=False):
    """Evaluate ``current_model`` on ``loader`` without gradient tracking.

    Uses the module-level ``criterion`` and ``DEVICE``.

    Args:
        current_model: Model to evaluate.
        loader: Iterable of ``(images, labels)`` batches.
        return_preds: When true, also collect labels and predictions.

    Returns:
        ``(mean_loss, accuracy)``, or
        ``(mean_loss, accuracy, all_labels, all_preds)`` when
        ``return_preds`` is true.
    """
    current_model.eval()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    collected_preds = []
    collected_labels = []
    batches = tqdm(loader, desc="Валидация", leave=False, smoothing=0)
    for inputs, targets in batches:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        scores = current_model(inputs)
        step_loss = criterion(scores, targets)

        # Weight the batch loss by batch size so the epoch mean is per sample.
        running_loss += step_loss.item() * inputs.size(0)
        predicted = scores.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

        if return_preds:
            collected_preds.extend(predicted.cpu().numpy())
            collected_labels.extend(targets.cpu().numpy())

    mean_loss = running_loss / n_seen
    accuracy = n_correct / n_seen
    if not return_preds:
        return mean_loss, accuracy
    return mean_loss, accuracy, collected_labels, collected_preds
if __name__ == "__main__":
    # Training driver: the checkpoint tracks the best validation accuracy,
    # early stopping watches the validation loss, and figures are generated
    # after the loop ends.
    best_val_acc = 0.0
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_path = "./emosetV2_resnet50_best.pth"

    # Per-epoch training history for the learning curves.
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    # Validation labels/predictions of the best-accuracy epoch.
    best_labels, best_preds = [], []

    print("Старт обучения.")
    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = train_epoch(model, train_loader)
        # Predictions are collected on every epoch: whether this epoch is the
        # best one is only known after validation has finished.
        val_loss, val_acc, val_labels, val_preds = val_epoch(model, val_loader, return_preds=True)
        scheduler.step()

        # Record the epoch in the history.
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # Save the weights whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_labels = val_labels  # keep the best model's predictions
            best_preds = val_preds
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")

        # Early stopping: count consecutive epochs without a lower val loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
                break

    print("Процесс обучения завершен. Генерирую графики для диссертации...")
    plot_learning_curves(history)
    if best_labels:
        plot_confusion_matrix(best_labels, best_preds)
        print("Все медиафайлы успешно созданы!")
    else:
        # No epoch improved on the initial accuracy (or no epoch ran), so
        # there are no predictions to build a confusion matrix from.
        print("[WARN] Матрица ошибок не построена: нет предсказаний лучшей эпохи.")
+171
View File
@@ -0,0 +1,171 @@
import os
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
# Output directory for generated figures; created on import.
# Rooted at ./src like the other paths in this script and matching the media
# directory used by the training script.
MEDIA_DIR = Path("./src/scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)

# Paths used for inference and for caching the extracted vectors.
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
MODEL_PATH = Path("./src/emoset_resnet50_best.pth")
BATCH_SIZE = 128
NUM_WORKERS = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Вычисления перенесены на: {device}")
class EmoSetFeatureDataset(Dataset):
def __init__(self, root: Path | str, split: str):
self.root = Path(root) / split
self.df = pd.read_csv(self.root / "labels.csv")
self.labels = sorted(self.df["label"].unique())
self.label2idx = {l: i for i, l in enumerate(self.labels)}
self.idx2label = {i: l for l, i in self.label2idx.items()}
# Для экстракции признаков аугментация отключена, используется строгий CenterCrop
self.transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
img_path = self.root / "images" / row["filename"]
# Перехват битых файлов выборки
try:
img = Image.open(img_path).convert("RGB")
except Exception:
img = Image.new("RGB", (224, 224), (0, 0, 0))
img_tensor = self.transform(img)
label_idx = self.label2idx[row["label"]]
return img_tensor, label_idx
def plot_tsne(embeddings, labels, idx2label, sample_limit=3000):
    """Save a 2-D t-SNE scatter plot of the embeddings, coloured by class.

    Args:
        embeddings: Array-like of feature vectors, one row per sample.
        labels: Array-like of integer class indices aligned with
            ``embeddings``.
        idx2label: Mapping from class index to class name.
        sample_limit: Maximum number of samples to project; ``None`` uses
            all of them.

    The figure is written to ``MEDIA_DIR / "tsne_embeddings.png"``.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n_total = len(embeddings)
    n_samples = n_total if sample_limit is None else min(sample_limit, n_total)

    # Draw a seeded random subset instead of the first rows: the inputs come
    # from an unshuffled loader, so a head slice follows the CSV order and
    # may over-represent some classes.
    if n_samples < n_total:
        rng = np.random.default_rng(42)
        chosen = np.sort(rng.choice(n_total, size=n_samples, replace=False))
    else:
        chosen = np.arange(n_total)
    labels_subset = labels[chosen]

    print(f"Построение t-SNE проекции для {n_samples} сэмплов...")
    # t-SNE needs perplexity < number of samples; 30 is kept for normal runs.
    perplexity = min(30, max(n_samples - 1, 1))
    tsne_model = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embeddings_2d = tsne_model.fit_transform(embeddings[chosen])

    plt.figure(figsize=(12, 9))
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels_subset,
        cmap="Set2",  # Set2 stays distinguishable in print
        alpha=0.7,
        s=20,
        edgecolors='w',
        linewidths=0.5
    )

    # Legend: one handle per class actually present in the subset. Names are
    # taken for exactly those classes so handles and labels cannot get out of
    # step when a class is missing.
    handles, _ = scatter.legend_elements(num=None)
    present_classes = np.unique(labels_subset)
    legend_labels = [idx2label[int(i)] for i in present_classes]
    # Place the legend outside the axes so it does not cover the data.
    plt.legend(handles, legend_labels, title="Эмоциональные классы",
               bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title("2D проекция скрытого пространства признаков (t-SNE)", pad=20, fontsize=14)
    plt.xlabel("Первая главная компонента (t-SNE 1)", fontsize=12)
    plt.ylabel("Вторая главная компонента (t-SNE 2)", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()

    plot_path = MEDIA_DIR / "tsne_embeddings.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] График t-SNE сохранен в: {plot_path}")
if __name__ == "__main__":
    # Feature-extraction driver: load the trained classifier, strip its head,
    # embed the test split, save the arrays and draw the t-SNE figure.
    test_ds = EmoSetFeatureDataset(DATA_ROOT, "test")
    test_loader = DataLoader(
        test_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,  # keep batch order identical to dataset row order
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    print(f"Подготовлено для извлечения: {len(test_ds)} файлов.")

    # Build the model and load the best weights.
    feature_extractor = timm.create_model(
        "resnet50",
        pretrained=False,
        num_classes=len(test_ds.labels)
    )
    try:
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict checkpoint needs.
        checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
        feature_extractor.load_state_dict(checkpoint)
        print("Веса модели успешно загружены.")
    except Exception as e:
        print(f"Ошибка загрузки весов: {e}. Убедитесь, что модель обучена.")
        # SystemExit instead of exit(): the exit() helper is injected by the
        # site module and is not guaranteed to exist.
        raise SystemExit(1)

    # Drop the classification layer (fc) so the model returns pooled features.
    feature_extractor.reset_classifier(0)
    feature_extractor.to(device)
    feature_extractor.eval()
    print("Слой классификации удален. Модель готова к экстракции.")

    extracted_embeddings = []
    extracted_labels = []
    print("Старт пакетной экстракции признаков...")
    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Экстракция"):
            imgs = imgs.to(device)
            # One feature vector per image (2048 wide for ResNet-50).
            embeddings_batch = feature_extractor(imgs)
            extracted_embeddings.append(embeddings_batch.cpu().numpy())
            extracted_labels.append(labels.numpy())

    # Merge the per-batch arrays into single arrays.
    np_embeddings = np.concatenate(extracted_embeddings, axis=0)
    np_labels = np.concatenate(extracted_labels, axis=0)
    print(f"Размерность матрицы признаков: {np_embeddings.shape}")

    # Save the artifacts.
    np.save("./src/emoset_test_embeddings.npy", np_embeddings)
    np.save("./src/emoset_test_labels.npy", np_labels)
    print("Матрицы успешно экспортированы в .npy файлы.")

    # Generate the figure.
    plot_tsne(np_embeddings, np_labels, test_ds.idx2label, sample_limit=3000)
    print("Процесс полностью завершен.")
-264
View File
@@ -1,264 +0,0 @@
import os
import gc
import pickle
import random
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm
# Runtime device and filesystem locations (absolute paths on the training host)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
# Pickled {'image_paths', 'labels'} index written by prepare_dataset_index()
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
# Plain state_dict used as the starting point when no resume checkpoint exists
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
# Full training-state checkpoint (model, optimizer, scheduler, scaler, epoch)
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")
# Emotion folder name -> class index; "sad" and "sadness" both map to class 7
CLASS_MAPPING = {
    "amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
    "disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
}
# Training pipeline hyperparameters
BATCH_SIZE = 82
EPOCHS = 15
LR = 5e-5
NUM_TRAIN_WORKERS = 48
NUM_VAL_WORKERS = 18
# Number of non-improving validation epochs tolerated by the early-stopping check
PATIENCE = 4
def prepare_dataset_index():
    """Return ``(image_paths, labels)`` for every labelled ``.jpg`` under ``DATA_ROOT``.

    If ``CACHE_PATH`` exists, the index is read from it and the directory tree is
    not touched. Otherwise the tree is scanned recursively, the class is taken
    from the directory two levels above each file (``parts[-3]``), files whose
    directory name is not a key of ``CLASS_MAPPING`` are skipped, and the result
    is pickled to ``CACHE_PATH``.

    The freshly scanned index is sorted by path before it is cached. Directory
    iteration order is filesystem-dependent, and the caller derives its
    train/val split from a seeded shuffle of this list, so an unsorted index
    would yield a different split whenever the cache is rebuilt.

    Returns:
        tuple: ``(paths, labels)`` where ``paths`` holds path strings and
        ``labels`` the matching integer class indices.
    """
    if CACHE_PATH.exists():
        print(f"Загрузка карты файловой системы из кэша: {CACHE_PATH.name}")
        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # cache must only ever be a file produced by this function.
        with open(CACHE_PATH, 'rb') as f:
            cache_data = pickle.load(f)
        return cache_data['image_paths'], cache_data['labels']

    print(f"Сканирование сетевой директории {DATA_ROOT} (первичная индексация)...")
    indexed = []
    for img_path in DATA_ROOT.rglob('*.jpg'):
        emotion_folder = img_path.parts[-3].lower()
        if emotion_folder in CLASS_MAPPING:
            indexed.append((str(img_path), CLASS_MAPPING[emotion_folder]))

    # Canonical order so that the downstream seeded shuffle is reproducible
    indexed.sort()
    paths = [path for path, _ in indexed]
    labels = [label for _, label in indexed]

    with open(CACHE_PATH, 'wb') as f:
        pickle.dump({'image_paths': paths, 'labels': labels}, f)
    return paths, labels
class EmoSetDirectDataset(Dataset):
    """Map-style dataset that decodes an image and resizes it to 256x256.

    Only decoding, scaling to ``[0, 1]`` float32 and the resize happen here;
    cropping, flipping and normalisation are left to the caller.
    """

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        self.base_transform = T.Resize((256, 256), antialias=True)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; any failure while loading yields an all-zero image."""
        try:
            decoded = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
            as_float = decoded.to(torch.float32) / 255.0
            sample = self.base_transform(as_float)
        except Exception:
            # Unreadable or corrupt file: substitute a black frame so the batch
            # can still be assembled instead of killing the worker.
            sample = torch.zeros((3, 256, 256), dtype=torch.float32)
        return sample, self.labels[idx]
def build_gpu_transforms():
    """Build the batch-level train and validation transform modules on ``DEVICE``.

    Returns:
        tuple: ``(train_tf, val_tf)``. The training module applies a random
        224x224 crop, a horizontal flip with p=0.5, a mild colour jitter and
        normalisation; the validation module applies a 224x224 centre crop
        and the same normalisation.
    """
    def _normalize():
        # Fresh instance per pipeline, same constants as the original code
        return T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_stages = [
        T.RandomCrop((224, 224)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        _normalize(),
    ]
    train_tf = torch.nn.Sequential(*train_stages).to(DEVICE)

    val_stages = [
        T.CenterCrop((224, 224)),
        _normalize(),
    ]
    val_tf = torch.nn.Sequential(*val_stages).to(DEVICE)

    return train_tf, val_tf
if __name__ == "__main__":
    # Fine-tuning entry point: build the train/val loaders, restore weights or a
    # full training state, run the mixed-precision epoch loop, and write a
    # resume checkpoint after every epoch.
    print(f"Инициализация конвейера обучения. Устройство: {DEVICE}")
    all_paths, all_labels = prepare_dataset_index()

    # Fixed seed so the shuffle, and therefore the train/val split, repeats across restarts
    random.seed(42)
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)
    all_paths, all_labels = zip(*combined)
    split_idx = int(len(all_paths) * 0.95)

    train_loader = DataLoader(
        EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    val_loader = DataLoader(
        EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_VAL_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    gpu_train_tf, gpu_val_tf = build_gpu_transforms()

    # Decide where the initial weights come from BEFORE the optimizer exists.
    # Previously the ImageNet fallback re-created `model` after the optimizer
    # and scheduler had been bound to the first model's parameters, so the
    # optimizer stepped a discarded network and the live one never trained.
    resume_available = RESUME_CHECKPOINT.exists()
    previous_available = PREVIOUS_WEIGHTS.exists()
    use_imagenet_init = not (resume_available or previous_available)
    if use_imagenet_init:
        print("Веса не найдены. Инициализация с ImageNet.")
    model = timm.create_model('resnet50', pretrained=use_imagenet_init, num_classes=8).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler()
    best_val_loss = float('inf')
    epochs_no_improve = 0
    start_epoch = 1

    # Restore the full training state if available, otherwise load the previous weights
    if resume_available:
        print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        start_epoch = checkpoint['epoch'] + 1
    elif previous_available:
        print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))

    try:
        for epoch in range(start_epoch, EPOCHS + 1):
            # Training pass
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]")
            for inputs, labels in pbar:
                try:
                    inputs = inputs.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    inputs = gpu_train_tf(inputs)
                    optimizer.zero_grad()
                    # Mixed precision to reduce VRAM usage
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    pbar.set_postfix({'loss': f"{loss.item():.4f}"})
                except RuntimeError as memory_err:
                    # Skip the batch on a CUDA out-of-memory spike; re-raise anything else
                    if "out of memory" in str(memory_err).lower():
                        if 'outputs' in locals(): del outputs
                        if 'loss' in locals(): del loss
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    raise
            train_loss = running_loss / total if total > 0 else 0
            train_acc = correct / total if total > 0 else 0
            gc.collect()
            torch.cuda.empty_cache()

            # Validation pass
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False):
                    val_inputs = val_inputs.to(DEVICE, non_blocking=True)
                    val_labels = val_labels.to(DEVICE, non_blocking=True)
                    val_inputs = gpu_val_tf(val_inputs)
                    with autocast(device_type="cuda"):
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)
                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = val_outputs.max(1)
                    val_total += val_labels.size(0)
                    val_correct += val_predicted.eq(val_labels).sum().item()
            epoch_val_loss = val_loss / val_total if val_total > 0 else 0
            epoch_val_acc = val_correct / val_total if val_total > 0 else 0
            scheduler.step()
            print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")

            # Track the best validation loss and evaluate early stopping
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
            else:
                epochs_no_improve += 1
                # NOTE(review): with EPOCHS = 15 the `epoch >= 15` guard lets this
                # fire only on the final epoch — confirm that is intended.
                if epochs_no_improve >= PATIENCE and epoch >= 15:
                    print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.")
                    break

            # End-of-epoch snapshot of the full training state for resuming
            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }
            torch.save(checkpoint_state, RESUME_CHECKPOINT)
            gc.collect()
    except KeyboardInterrupt:
        # Ctrl+C: persist the current state so the run can be resumed later
        print("\nВыполнение прервано пользователем (SIGINT).")
        print(f"Дамп памяти конвейера зафиксирован на эпохе {epoch}.")
        checkpoint_state = {
            'epoch': epoch, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(), 'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss
        }
        torch.save(checkpoint_state, RESUME_CHECKPOINT)
    else:
        # Loop ended without an interrupt (all epochs done or early stop):
        # export the final weights and drop the now-obsolete resume checkpoint.
        if SAVE_MODEL_PATH.parent.exists():
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}")
        if RESUME_CHECKPOINT.exists():
            RESUME_CHECKPOINT.unlink()
+84 -83
View File
@@ -1,96 +1,97 @@
import joblib
import numpy as np
import pandas as pd
import joblib
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# Калибровочные координаты центров эмоциональных классов в пространстве Рассела [1.0 - 9.0]
EMOTION_TO_VA_COORDS = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
# 1. Настройка путей
embeddings_path = Path("./src/emoset_test_embeddings.npy")
csv_path = Path("./NFS/Thesis/Emoset/EmoSet-118K/test/labels.csv")
model_path = Path("./src/music_engine/va_regressor.pkl")
output_dir = Path("./src/scripts/media")
output_file = output_dir / "metrics_output.txt"
# 2. Корректный маппинг 8 классов EmoSet в шкалу DEAM [1.0, 9.0]
# Формула перевода из [-1, 1] в [1, 9]: 5.0 + (X * 4.0)
EMO_TO_VA = {
"amusement": [8.2, 6.6], # Веселье (Высокий позитив, средняя энергия)
"awe": [7.0, 7.4], # Восхищение (Позитив, высокая энергия)
"contentment": [7.8, 3.4], # Умиротворение (Позитив, низкая энергия)
"excitement": [8.2, 8.2], # Возбуждение (Макс. позитив, макс. энергия)
"anger": [2.2, 7.8], # Гнев (Глубокий негатив, высокая энергия)
"disgust": [2.6, 6.6], # Отвращение (Негатив, средняя энергия)
"fear": [2.6, 8.2], # Страх (Негатив, максимальная энергия)
"sadness": [2.2, 2.6] # Грусть (Глубокий негатив, низкая энергия)
}
def evaluate_regression_model():
# Инициализация путей к артефактам пайплайна
base_dir = Path(__file__).resolve().parent.parent.parent
embeddings_path = base_dir / "src" / "emoset_test_embeddings.npy"
labels_path = base_dir / "src" / "emoset_test_labels.npy"
model_path = base_dir / "src" / "music_engine" / "va_regressor.pkl"
if not all(p.exists() for p in [embeddings_path, labels_path, model_path]):
print("Отсутствуют необходимые артефакты для расчета метрик.")
def generate_slide_metrics():
print("[INFO] Загрузка тестовых артефактов...")
if not all(p.exists() for p in [embeddings_path, csv_path, model_path]):
print("[ERROR] Проверьте наличие файлов данных или модели регрессора.")
return
# Загрузка скрытых представлений и инициализация регрессора
x_features = np.load(embeddings_path)
y_discrete = np.load(labels_path)
regression_pipeline = joblib.load(model_path)
# Маппинг дискретных меток в непрерывные координаты
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
# Изоляция тестовой выборки (сохранение детерминированности через random_state)
_, x_test, _, y_test = train_test_split(x_features, y_continuous, test_size=0.2, random_state=42)
# Генерация предсказаний на отложенной выборке
y_pred = regression_pipeline.predict(x_test)
# Расчет метрик качества регрессии (Mean Squared Error, R-squared)
mse_valence = mean_squared_error(y_test[:, 0], y_pred[:, 0])
r2_valence = r2_score(y_test[:, 0], y_pred[:, 0])
mse_arousal = mean_squared_error(y_test[:, 1], y_pred[:, 1])
r2_arousal = r2_score(y_test[:, 1], y_pred[:, 1])
print("Метрики качества регрессионной модели на тестовой выборке:")
print(f"Valence -> MSE: {mse_valence:.4f} | R^2: {r2_valence:.4f}")
print(f"Arousal -> MSE: {mse_arousal:.4f} | R^2: {r2_arousal:.4f}")
# Построение диагностических диаграмм рассеяния (Scatter Plots)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
# Конфигурация подграфика: Ось Валентности
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Прогноз регрессора')
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
ax1.set_title('Диаграмма рассеяния: Valence (Позитивность)', fontsize=14, fontweight='bold')
ax1.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
ax1.set_xlim(1, 9)
ax1.set_ylim(1, 9)
ax1.grid(True, linestyle='--', alpha=0.6)
ax1.legend(loc='upper left', fontsize=10)
# Научное обоснование распределения данных для комиссии
ax1.text(1.2, 8.2,
'Формирование вертикальных кластеров\n'
'обусловлено проекцией 8 дискретных\n'
'базовых эмоций на непрерывную\n'
'координатную плоскость.',
fontsize=10, bbox=dict(facecolor='white', alpha=0.9, edgecolor='gray'))
output_dir.mkdir(parents=True, exist_ok=True)
# Конфигурация подграфика: Ось Активности
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Прогноз регрессора')
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
ax2.set_title('Диаграмма рассеяния: Arousal (Активность)', fontsize=14, fontweight='bold')
ax2.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
ax2.set_xlim(1, 9)
ax2.set_ylim(1, 9)
ax2.grid(True, linestyle='--', alpha=0.6)
ax2.legend(loc='upper left', fontsize=10)
# 3. Загрузка эмбеддингов и меток
X_test = np.load(embeddings_path)
df = pd.read_csv(csv_path)
plt.tight_layout()
plt.savefig('regression_metrics_plot.png', dpi=300, bbox_inches='tight')
print("Диагностические графики экспортированы в regression_metrics_plot.png")
if len(X_test) != len(df):
print(f"[WARN] Корректировка размеров выборки: Эмбеддинги ({len(X_test)}) != Метки ({len(df)})")
min_len = min(len(X_test), len(df))
X_test = X_test[:min_len]
df = df.iloc[:min_len]
y_test_list = [EMO_TO_VA.get(label.lower().strip(), [5.0, 5.0]) for label in df['label']]
y_test = np.array(y_test_list)
# 4. Выполнение инференса
print("[INFO] Выполнение инференса регрессора на скрытом пространстве признаков...")
regressor = joblib.load(model_path)
y_pred = regressor.predict(X_test)
# === БЛОК ДИАГНОСТИКИ ШКАЛЫ ===
print("\n" + "-"*50)
print(" ДИАГНОСТИКА ДИАПАЗОНОВ ЗНАЧЕНИЙ ".center(50))
print("-"*50)
print(f"Истинные (y_test) -> Мин: {y_test.min():.2f}, Макс: {y_test.max():.2f}, Среднее: {y_test.mean():.2f}")
print(f"Предсказания (y_pred) -> Мин: {y_pred.min():.2f}, Макс: {y_pred.max():.2f}, Среднее: {y_pred.mean():.2f}")
print("-"*50 + "\n")
# ==============================
# 5. Расчет метрик
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
mse_total = mean_squared_error(y_test, y_pred)
r2_total = r2_score(y_test, y_pred)
# 6. Вывод и сохранение результатов
table_content = f"""
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | {mse_v:<12.4f} | {mse_a:<12.4f} | {mse_total:<13.4f} |
| R² | {r2_v:<12.4f} | {r2_a:<12.4f} | {r2_total:<13.4f} |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{{final}} = D_{{emo}} + 4.0 \cdot Acoustic_{{penalty}}$$
"""
print(table_content)
with open(output_file, 'w', encoding='utf-8') as f:
f.write(table_content)
print(f"[SUCCESS] Метрики успешно сохранены в файл: {output_file.absolute()}")
if __name__ == "__main__":
evaluate_regression_model()
generate_slide_metrics()
Binary file not shown.

After

Width:  |  Height:  |  Size: 313 KiB

+12
View File
@@ -0,0 +1,12 @@
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | 1.5135 | 2.2743 | 1.8939 |
| R² | 0.7927 | 0.4321 | 0.6124 |
==================================================
Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{final} = D_{emo} + 4.0 \cdot Acoustic_{penalty}$$
Binary file not shown.

After

Width:  |  Height:  |  Size: 243 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

+319
View File
@@ -0,0 +1,319 @@
import os
import random
import warnings
from collections import defaultdict
from pathlib import Path
from PIL import Image, ImageFile
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.amp import autocast, GradScaler
import timm
# Silence all Python warnings and let PIL load JPEGs whose data is truncated
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {DEVICE}")
# --- PATHS (relative to the working directory the script is started from) ---
TRAIN_ROOT = Path("./dataset/Original-2.41M")
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train")  # ANCHOR (clean data mixed into every training batch)
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
# Emotion name -> class index
CLASS_MAPPING = {
    "amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
    "disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
}
# --- SETTINGS ---
# Batch size of the validation loader; equals BATCH_NOISY + BATCH_ANCHOR
TOTAL_BATCH_SIZE = 64
BATCH_NOISY = 48  # 75% of a training batch: new 2.41M data
BATCH_ANCHOR = 16  # 25% of a training batch: clean 118K anchor data
EPOCHS_PER_FOLDER = 15
PATIENCE = 5
# Backbone learning rate; the classifier head is trained with LR * 10
LR = 1e-6
NUM_TRAIN_WORKERS = 32
NUM_VAL_WORKERS = 32
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG inside a DataLoader worker.

    The new seed is the first word of the RNG state the worker starts with,
    plus ``worker_id``, so workers that start from the same state end up with
    different seeds.

    The sum is computed with Python integers and wrapped modulo ``2**32``.
    Adding ``worker_id`` directly to the uint32 state word can leave the
    ``[0, 2**32 - 1]`` range that ``np.random.seed`` accepts, which raises
    ``ValueError`` or overflows depending on the NumPy version.

    Args:
        worker_id: Index of the worker process, supplied by ``DataLoader``.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)
# --- 1. TRANSFORMS ---
# Training pipeline: resize the shorter side to 256, take a random 224x224 crop
# covering 80-100% of the image area, random horizontal flip, then convert to a
# tensor and normalise with the mean/std constants below.
train_transform = T.Compose([
    T.Resize(256),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Validation pipeline: deterministic resize + centre crop with the same normalisation
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# --- 2. ДАТАСЕТЫ ---
class ChunkTrainDataset(Dataset):
    """Dataset over a list of image paths whose label is encoded in the directory layout.

    The class is looked up in ``CLASS_MAPPING`` from the directory two levels
    above the file (``path.parts[-3]``); unknown names fall back to class 0.
    """

    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(tensor, label)``; any failure yields a zero image with label 0."""
        sample_path = self.paths[idx]
        try:
            rgb = Image.open(sample_path).convert('RGB')
            # Tuple is evaluated left to right: transform first, then label lookup
            return self.transform(rgb), CLASS_MAPPING.get(sample_path.parts[-3].lower(), 0)
        except Exception:
            return torch.zeros((3, 224, 224)), 0
class CsvDataset(Dataset):
    """Dataset described by ``<root>/labels.csv`` with images under ``<root>/images``.

    The CSV must provide a ``filename`` and a ``label`` column. Labels are
    lower-cased and looked up in ``CLASS_MAPPING``; unknown names map to 0.
    """

    def __init__(self, root, transform):
        self.root = Path(root)
        self.df = pd.read_csv(self.root / "labels.csv")
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(tensor, label)``; any failure yields a zero image with label 0."""
        record = self.df.iloc[idx]
        image_file = self.root / "images" / record["filename"]
        try:
            rgb = Image.open(image_file).convert('RGB')
            # Tuple is evaluated left to right: transform first, then label lookup
            return self.transform(rgb), CLASS_MAPPING.get(record["label"].lower(), 0)
        except Exception:
            return torch.zeros((3, 224, 224)), 0
# --- 3. СБОР ДАННЫХ ---
def prepare_chunks():
    """Group the ``.jpg`` files under ``TRAIN_ROOT`` by their numbered parent folder.

    A file is kept when the directory two levels up (``parts[-3]``, lower-cased)
    is a key of ``CLASS_MAPPING`` and its immediate parent directory name
    consists of digits only.

    Returns:
        tuple: ``(chunk_dict, sorted_chunks)`` where ``chunk_dict`` maps the
        integer folder number to its list of paths and ``sorted_chunks`` is
        the ascending list of those numbers.
    """
    print("\nСканирование датасета 2.41M...")
    chunk_dict = defaultdict(list)
    for image_path in TRAIN_ROOT.rglob('*.jpg'):
        # The parent folder is inspected only for files in a known emotion directory
        if image_path.parts[-3].lower() in CLASS_MAPPING and image_path.parts[-2].isdigit():
            chunk_dict[int(image_path.parts[-2])].append(image_path)
    sorted_chunks = sorted(chunk_dict)
    print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
    return chunk_dict, sorted_chunks
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
if __name__ == "__main__":
    # Anchor curriculum fine-tuning: numbered folders of the 2.41M set are added
    # to the training pool one stage at a time, and every training batch is the
    # concatenation of a pool batch and a batch from the clean 118K anchor set.
    chunk_dict, sorted_chunks = prepare_chunks()

    # Validation loader (clean data only)
    val_loader = DataLoader(
        CsvDataset(VAL_118K_ROOT, val_transform),
        batch_size=TOTAL_BATCH_SIZE, shuffle=False,
        num_workers=NUM_VAL_WORKERS, pin_memory=True
    )

    # Anchor loader (clean data to mix in). Uses prefetching; workers are NOT
    # persistent, so they are restarted each time the iterator is re-created.
    anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
    anchor_loader = DataLoader(
        anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
        num_workers=16, pin_memory=True, drop_last=True,
        prefetch_factor=2, persistent_workers=False
    )

    # Model initialisation
    model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
    if PRETRAINED_PATH.exists():
        model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
        print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")

    # Unfreeze the whole model
    for param in model.parameters():
        param.requires_grad = True

    # Two parameter groups: everything except the "fc" head, and the head itself
    backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
    fc_params = [p for n, p in model.named_parameters() if "fc" in n]
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': LR},  # 1e-6: tiny step for the backbone
        {'params': fc_params, 'lr': LR * 10}  # 1e-5: step for the classifier
    ], weight_decay=1e-3)

    # Label smoothing, intended to soften the effect of noisy 2.41M labels
    criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
    scaler = GradScaler()

    # --- RESUME PARAMETERS ---
    start_stage = 0
    start_epoch = 1
    best_val_loss = float('inf')
    if RESUME_CHECKPOINT.exists():
        print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            print(f"Оптимизатор сброшен: {e}")
        # Restore the loss-scale state when present and non-empty; older
        # checkpoints lack the key and a disabled scaler saves an empty dict.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_val_loss = checkpoint['best_val_loss']
        start_stage = checkpoint['stage']
        start_epoch = checkpoint['epoch'] + 1
        print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
    else:
        # --- EPOCH 0 MEASUREMENT (baseline accuracy) ---
        # Runs only when starting from scratch; the result seeds best_val_loss
        print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                with autocast(device_type="cuda"):
                    outputs = model(inputs)
                    v_loss = criterion(outputs, labels)
                val_loss += v_loss.item() * inputs.size(0)
                _, pred = outputs.max(1)
                val_total += labels.size(0)
                val_correct += pred.eq(labels).sum().item()
        best_val_loss = val_loss / val_total
        baseline_acc = val_correct / val_total
        print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")

    # Rebuild the pool accumulated by the stages completed before the resume point
    current_train_paths = []
    for s in range(start_stage):
        current_train_paths.extend(chunk_dict[sorted_chunks[s]])

    print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")

    # MAIN LOOP OVER FOLDERS
    for stage in range(start_stage, len(sorted_chunks)):
        chunk_id = sorted_chunks[stage]
        print(f"\n{'='*50}")
        print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")

        # Accumulate and shuffle
        current_train_paths.extend(chunk_dict[chunk_id])
        random.shuffle(current_train_paths)
        print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")

        # Main loader (noisy data) with prefetching and persistent workers
        train_loader = DataLoader(
            ChunkTrainDataset(current_train_paths, train_transform),
            batch_size=BATCH_NOISY, shuffle=True,
            num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
            worker_init_fn=worker_init_fn, drop_last=True,
            prefetch_factor=4, persistent_workers=True
        )

        epochs_no_improve = 0
        # Only the stage being resumed starts mid-way; later stages start at epoch 1
        first_epoch = start_epoch if stage == start_stage else 1

        # Anchor iterator, re-created whenever the anchor set is exhausted
        anchor_iter = iter(anchor_loader)

        # EPOCH LOOP FOR THE CURRENT STAGE
        for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
            model.train()
            train_loss, train_correct, train_total = 0.0, 0, 0
            for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
                # Fetch a clean anchor batch
                try:
                    anc_inputs, anc_labels = next(anchor_iter)
                except StopIteration:
                    anchor_iter = iter(anchor_loader)
                    anc_inputs, anc_labels = next(anchor_iter)

                # Mix the batches (noisy + clean)
                inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
                labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)

                optimizer.zero_grad(set_to_none=True)
                with autocast(device_type="cuda"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item() * inputs.size(0)
                _, pred = outputs.max(1)
                train_total += labels.size(0)
                train_correct += pred.eq(labels).sum().item()

            # VALIDATION
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
                    inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        v_loss = criterion(outputs, labels)
                    val_loss += v_loss.item() * inputs.size(0)
                    _, pred = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += pred.eq(labels).sum().item()

            avg_train_loss = train_loss / train_total
            avg_train_acc = train_correct / train_total
            avg_val_loss = val_loss / val_total
            avg_val_acc = val_correct / val_total
            print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")

            # SAVE BEST WEIGHTS
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), SAVE_MODEL_PATH)
                print("--> Обновлены лучшие веса")
            else:
                epochs_no_improve += 1

            # END-OF-EPOCH RESUME CHECKPOINT
            checkpoint_state = {
                'stage': stage,
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }
            # Write to a temporary file, flush it to disk, then atomically swap
            # it in: a crash during the write can no longer corrupt the only
            # existing checkpoint.
            tmp_checkpoint = RESUME_CHECKPOINT.with_name(RESUME_CHECKPOINT.name + ".tmp")
            with open(tmp_checkpoint, 'wb') as ckpt_file:
                torch.save(checkpoint_state, ckpt_file)
                ckpt_file.flush()
                os.fsync(ckpt_file.fileno())
            os.replace(tmp_checkpoint, RESUME_CHECKPOINT)
            os.sync()  # Flush remaining buffers (protection against power loss)
            print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")

            # EARLY STOPPING FOR THE CURRENT STAGE
            if epochs_no_improve >= PATIENCE:
                print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
                break

        # Reset the starting epoch once the resumed stage has been processed
        start_epoch = 1
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 MiB

+461
View File
@@ -0,0 +1,461 @@
.
├── bin
│   ├── activate
│   ├── activate.csh
│   ├── activate.fish
│   ├── activate.nu
│   ├── activate.ps1
│   ├── activate_this.py
│   ├── debugpy
│   ├── debugpy-adapter
│   ├── f2py
│   ├── fonttools
│   ├── httpx
│   ├── ipython
│   ├── ipython3
│   ├── isympy
│   ├── jlpm
│   ├── jsonpointer
│   ├── jsonschema
│   ├── jupyter
│   ├── jupyter-dejavu
│   ├── jupyter-events
│   ├── jupyter-execute
│   ├── jupyter-kernel
│   ├── jupyter-kernelspec
│   ├── jupyter-lab
│   ├── jupyter-labextension
│   ├── jupyter-labhub
│   ├── jupyter-migrate
│   ├── jupyter-nbconvert
│   ├── jupyter-run
│   ├── jupyter-server
│   ├── jupyter-troubleshoot
│   ├── jupyter-trust
│   ├── normalizer
│   ├── numpy-config
│   ├── pip
│   ├── pip3
│   ├── pip3.12
│   ├── proton
│   ├── proton-viewer
│   ├── pybabel
│   ├── pyftmerge
│   ├── pyftsubset
│   ├── pygmentize
│   ├── pyjson5
│   ├── python -> /usr/bin/python3
│   ├── python3 -> python
│   ├── python3.12 -> python
│   ├── send2trash
│   ├── streamlit
│   ├── streamlit.cmd
│   ├── torchfrtrace
│   ├── torchrun
│   ├── tqdm
│   ├── ttx
│   ├── watchmedo
│   └── wsdump
├── CACHEDIR.TAG
├── docker
│   ├── Dockerfile.api
│   └── Dockerfile.ui
├── docker-compose.yml
├── Dockerfile
├── .dockerignore
├── .env
├── etc
│   └── jupyter
│   ├── jupyter_notebook_config.d
│   │   └── jupyterlab.json
│   ├── jupyter_server_config.d
│   │   ├── jupyterlab.json
│   │   ├── jupyter-lsp-jupyter-server.json
│   │   ├── jupyter_server_terminals.json
│   │   └── notebook_shim.json
│   └── nbconfig
│   └── notebook.d
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── Thesis.iml
│   ├── vcs.xml
│   └── workspace.xml
├── lib
│   └── python3.12
│   └── site-packages
│   ├── altair
│   ├── altair-6.0.0.dist-info
│   ├── anyio
│   ├── anyio-4.12.1.dist-info
│   ├── argon2
│   ├── argon2_cffi-25.1.0.dist-info
│   ├── _argon2_cffi_bindings
│   ├── argon2_cffi_bindings-25.1.0.dist-info
│   ├── arrow
│   ├── arrow-1.4.0.dist-info
│   ├── asttokens
│   ├── asttokens-3.0.1.dist-info
│   ├── async_lru
│   ├── async_lru-2.0.5.dist-info
│   ├── attr
│   ├── attrs
│   ├── attrs-25.4.0.dist-info
│   ├── babel
│   ├── babel-2.17.0.dist-info
│   ├── beautifulsoup4-4.14.3.dist-info
│   ├── bleach
│   ├── bleach-6.3.0.dist-info
│   ├── blinker
│   ├── blinker-1.9.0.dist-info
│   ├── bs4
│   ├── cachetools
│   ├── cachetools-6.2.4.dist-info
│   ├── certifi
│   ├── certifi-2026.1.4.dist-info
│   ├── cffi
│   ├── cffi-2.0.0.dist-info
│   ├── _cffi_backend.cpython-312-x86_64-linux-gnu.so
│   ├── charset_normalizer
│   ├── charset_normalizer-3.4.4.dist-info
│   ├── click
│   ├── click-8.3.1.dist-info
│   ├── comm
│   ├── comm-0.2.3.dist-info
│   ├── contourpy
│   ├── contourpy-1.3.3.dist-info
│   ├── cycler
│   ├── cycler-0.12.1.dist-info
│   ├── dateutil
│   ├── debugpy
│   ├── debugpy-1.8.19.dist-info
│   ├── decorator-5.2.1.dist-info
│   ├── decorator.py
│   ├── defusedxml
│   ├── defusedxml-0.7.1.dist-info
│   ├── _distutils_hack
│   ├── distutils-precedence.pth
│   ├── .DS_Store
│   ├── executing
│   ├── executing-2.2.1.dist-info
│   ├── fastjsonschema
│   ├── fastjsonschema-2.21.2.dist-info
│   ├── filelock
│   ├── filelock-3.20.3.dist-info
│   ├── fontTools
│   ├── fonttools-4.61.1.dist-info
│   ├── fqdn
│   ├── fqdn-1.5.1.dist-info
│   ├── fsspec
│   ├── fsspec-2026.1.0.dist-info
│   ├── functorch
│   ├── git
│   ├── gitdb
│   ├── gitdb-4.0.12.dist-info
│   ├── gitpython-3.1.46.dist-info
│   ├── google
│   ├── h11
│   ├── h11-0.16.0.dist-info
│   ├── httpcore
│   ├── httpcore-1.0.9.dist-info
│   ├── httpx
│   ├── httpx-0.28.1.dist-info
│   ├── idna
│   ├── idna-3.11.dist-info
│   ├── ipykernel
│   ├── ipykernel-7.1.0.dist-info
│   ├── ipykernel_launcher.py
│   ├── IPython
│   ├── ipython-9.9.0.dist-info
│   ├── ipython_pygments_lexers-1.1.1.dist-info
│   ├── ipython_pygments_lexers.py
│   ├── isoduration
│   ├── isoduration-20.11.0.dist-info
│   ├── isympy.py
│   ├── jedi
│   ├── jedi-0.19.2.dist-info
│   ├── jinja2
│   ├── jinja2-3.1.6.dist-info
│   ├── joblib
│   ├── joblib-1.5.3.dist-info
│   ├── json5
│   ├── json5-0.13.0.dist-info
│   ├── jsonpointer-3.0.0.dist-info
│   ├── jsonpointer.py
│   ├── jsonschema
│   ├── jsonschema-4.26.0.dist-info
│   ├── jsonschema_specifications
│   ├── jsonschema_specifications-2025.9.1.dist-info
│   ├── jupyter_client
│   ├── jupyter_client-8.8.0.dist-info
│   ├── jupyter_core
│   ├── jupyter_core-5.9.1.dist-info
│   ├── jupyter_events
│   ├── jupyter_events-0.12.0.dist-info
│   ├── jupyterlab
│   ├── jupyterlab-4.5.1.dist-info
│   ├── jupyterlab_pygments
│   ├── jupyterlab_pygments-0.3.0.dist-info
│   ├── jupyterlab_server
│   ├── jupyterlab_server-2.28.0.dist-info
│   ├── jupyter_lsp
│   ├── jupyter_lsp-2.3.0.dist-info
│   ├── jupyter.py
│   ├── jupyter_server
│   ├── jupyter_server-2.17.0.dist-info
│   ├── jupyter_server_terminals
│   ├── jupyter_server_terminals-0.5.3.dist-info
│   ├── kiwisolver
│   ├── kiwisolver-1.4.9.dist-info
│   ├── lark
│   ├── lark-1.3.1.dist-info
│   ├── markupsafe
│   ├── markupsafe-3.0.3.dist-info
│   ├── matplotlib
│   ├── matplotlib-3.10.8.dist-info
│   ├── matplotlib_inline
│   ├── matplotlib_inline-0.2.1.dist-info
│   ├── mistune
│   ├── mistune-3.2.0.dist-info
│   ├── mpl_toolkits
│   ├── mpmath
│   ├── mpmath-1.3.0.dist-info
│   ├── narwhals
│   ├── narwhals-2.15.0.dist-info
│   ├── nbclient
│   ├── nbclient-0.10.4.dist-info
│   ├── nbconvert
│   ├── nbconvert-7.16.6.dist-info
│   ├── nbformat
│   ├── nbformat-5.10.4.dist-info
│   ├── nest_asyncio-1.6.0.dist-info
│   ├── nest_asyncio.py
│   ├── networkx
│   ├── networkx-3.6.1.dist-info
│   ├── notebook_shim
│   ├── notebook_shim-0.2.4.dist-info
│   ├── numpy
│   ├── numpy-2.4.1.dist-info
│   ├── numpy.libs
│   ├── nvidia
│   ├── nvidia_cublas_cu12-12.8.4.1.dist-info
│   ├── nvidia_cuda_cupti_cu12-12.8.90.dist-info
│   ├── nvidia_cuda_nvrtc_cu12-12.8.93.dist-info
│   ├── nvidia_cuda_runtime_cu12-12.8.90.dist-info
│   ├── nvidia_cudnn_cu12-9.10.2.21.dist-info
│   ├── nvidia_cufft_cu12-11.3.3.83.dist-info
│   ├── nvidia_cufile_cu12-1.13.1.3.dist-info
│   ├── nvidia_curand_cu12-10.3.9.90.dist-info
│   ├── nvidia_cusolver_cu12-11.7.3.90.dist-info
│   ├── nvidia_cusparse_cu12-12.5.8.93.dist-info
│   ├── nvidia_cusparselt_cu12-0.7.1.dist-info
│   ├── nvidia_nccl_cu12-2.27.5.dist-info
│   ├── nvidia_nvjitlink_cu12-12.8.93.dist-info
│   ├── nvidia_nvshmem_cu12-3.3.20.dist-info
│   ├── nvidia_nvtx_cu12-12.8.90.dist-info
│   ├── packaging
│   ├── packaging-25.0.dist-info
│   ├── pandas
│   ├── pandas-2.3.3.dist-info
│   ├── pandocfilters-1.5.1.dist-info
│   ├── pandocfilters.py
│   ├── parso
│   ├── parso-0.8.5.dist-info
│   ├── pexpect
│   ├── pexpect-4.9.0.dist-info
│   ├── PIL
│   ├── pillow-12.1.0.dist-info
│   ├── pillow.libs
│   ├── pip
│   ├── pip-25.3.dist-info
│   ├── pkg_resources
│   ├── platformdirs
│   ├── platformdirs-4.5.1.dist-info
│   ├── prometheus_client
│   ├── prometheus_client-0.23.1.dist-info
│   ├── prompt_toolkit
│   ├── prompt_toolkit-3.0.52.dist-info
│   ├── protobuf-6.33.4.dist-info
│   ├── psutil
│   ├── psutil-7.2.1.dist-info
│   ├── ptyprocess
│   ├── ptyprocess-0.7.0.dist-info
│   ├── pure_eval
│   ├── pure_eval-0.2.3.dist-info
│   ├── pyarrow
│   ├── pyarrow-22.0.0.dist-info
│   ├── pycparser
│   ├── pycparser-2.23.dist-info
│   ├── pydeck
│   ├── pydeck-0.9.1.dist-info
│   ├── pygments
│   ├── pygments-2.19.2.dist-info
│   ├── pylab.py
│   ├── pyparsing
│   ├── pyparsing-3.3.1.dist-info
│   ├── python_dateutil-2.9.0.post0.dist-info
│   ├── pythonjsonlogger
│   ├── python_json_logger-4.0.0.dist-info
│   ├── pytz
│   ├── pytz-2025.2.dist-info
│   ├── pyyaml-6.0.3.dist-info
│   ├── pyzmq-27.1.0.dist-info
│   ├── pyzmq.libs
│   ├── referencing
│   ├── referencing-0.37.0.dist-info
│   ├── requests
│   ├── requests-2.32.5.dist-info
│   ├── rfc3339_validator-0.1.4.dist-info
│   ├── rfc3339_validator.py
│   ├── rfc3986_validator-0.1.1.dist-info
│   ├── rfc3986_validator.py
│   ├── rfc3987_syntax
│   ├── rfc3987_syntax-1.1.0.dist-info
│   ├── rpds
│   ├── rpds_py-0.30.0.dist-info
│   ├── scikit_learn-1.8.0.dist-info
│   ├── scikit_learn.libs
│   ├── scipy
│   ├── scipy-1.17.0.dist-info
│   ├── scipy.libs
│   ├── send2trash
│   ├── send2trash-2.0.0.dist-info
│   ├── setuptools
│   ├── setuptools-80.9.0.dist-info
│   ├── six-1.17.0.dist-info
│   ├── six.py
│   ├── sklearn
│   ├── smmap
│   ├── smmap-5.0.2.dist-info
│   ├── soupsieve
│   ├── soupsieve-2.8.1.dist-info
│   ├── stack_data
│   ├── stack_data-0.6.3.dist-info
│   ├── streamlit
│   ├── streamlit-1.53.0.dist-info
│   ├── sympy
│   ├── sympy-1.14.0.dist-info
│   ├── tenacity
│   ├── tenacity-9.1.2.dist-info
│   ├── terminado
│   ├── terminado-0.18.1.dist-info
│   ├── threadpoolctl-3.6.0.dist-info
│   ├── threadpoolctl.py
│   ├── tinycss2
│   ├── tinycss2-1.4.0.dist-info
│   ├── toml
│   ├── toml-0.10.2.dist-info
│   ├── torch
│   ├── torch-2.9.1.dist-info
│   ├── torchaudio
│   ├── torchaudio-2.9.1.dist-info
│   ├── torchgen
│   ├── torchvision
│   ├── torchvision-0.24.1.dist-info
│   ├── torchvision.libs
│   ├── tornado
│   ├── tornado-6.5.4.dist-info
│   ├── tqdm
│   ├── tqdm-4.67.1.dist-info
│   ├── traitlets
│   ├── traitlets-5.14.3.dist-info
│   ├── triton
│   ├── triton-3.5.1.dist-info
│   ├── typing_extensions-4.15.0.dist-info
│   ├── typing_extensions.py
│   ├── tzdata
│   ├── tzdata-2025.3.dist-info
│   ├── uri_template
│   ├── uri_template-1.3.0.dist-info
│   ├── urllib3
│   ├── urllib3-2.6.3.dist-info
│   ├── _virtualenv.pth
│   ├── _virtualenv.py
│   ├── watchdog
│   ├── watchdog-6.0.0.dist-info
│   ├── wcwidth
│   ├── wcwidth-0.2.14.dist-info
│   ├── webcolors
│   ├── webcolors-25.10.0.dist-info
│   ├── webencodings
│   ├── webencodings-0.5.1.dist-info
│   ├── websocket
│   ├── websocket_client-1.9.0.dist-info
│   ├── _yaml
│   ├── yaml
│   └── zmq
├── Makefile
├── NFS
├── poetry.lock
├── pyproject.toml
├── pyvenv.cfg
├── README.md
├── requirements.txt
├── runs
├── share
│   ├── applications
│   │   └── jupyterlab.desktop
│   ├── icons
│   │   └── hicolor
│   │   └── scalable
│   ├── jupyter
│   │   ├── kernels
│   │   │   └── python3
│   │   ├── lab
│   │   │   ├── schemas
│   │   │   ├── static
│   │   │   └── themes
│   │   ├── labextensions
│   │   │   └── jupyterlab_pygments
│   │   ├── nbconvert
│   │   │   └── templates
│   │   └── nbextensions
│   │   └── pydeck
│   └── man
│   └── man1
│   ├── ipython.1
│   ├── isympy.1
│   └── ttx.1
├── src
│   ├── 5_epoch_emoset_resnet50_finetuned_2.41M.pth
│   ├── api.py
│   ├── data_loader.py
│   ├── dataset_paths_cache.pkl
│   ├── emoset_resnet50_best.pth
│   ├── emoset_resnet50_finetuned_2_41M_best.pth
│   ├── emoset_resnet50_resume.pth
│   ├── emoset_test_embeddings.npy
│   ├── emoset_test_labels.npy
│   ├── main.py
│   ├── music_engine
│   │   ├── image_processor.py
│   │   ├── __init__.py
│   │   ├── llm_bridge.py
│   │   ├── matcher.py
│   │   └── va_regressor.pkl
│   ├── scripts
│   │   ├── 00_setup_env.sh
│   │   ├── 01_download_DEAM.py
│   │   ├── 02_download_EmoSet.py
│   │   ├── 11_prerp_DEAM.py
│   │   ├── 20_bench_GPU.py
│   │   ├── 21_train_images.ipynb
│   │   ├── 22_extract_embeddings.ipynb
│   │   ├── 23_aggregate_DEAM_timeline.py
│   │   ├── 24_train_regressor.py
│   │   ├── 31_finetune_2.41M.py
│   │   ├── 90_acc_images_model.ipynb
│   │   └── 91_generate_metrics.py
│   └── tabs
│   ├── tab_dataset.py
│   └── tab_live.py
├── tree.txt
└── .vscode
├── launch.json
└── tasks.json
322 directories, 137 files