Compare commits
7 Commits
875616730b
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 934a4cbff4 | |||
| 14968dd4d4 | |||
| daba573b2c | |||
| 8648e52106 | |||
| e3a2eb3289 | |||
| a57addcbb1 | |||
| 3850b15053 |
@@ -0,0 +1,18 @@
|
||||
bin/
|
||||
lib/
|
||||
share/
|
||||
etc/
|
||||
include/
|
||||
pyvenv.cfg
|
||||
.idea/
|
||||
.vscode/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.git/
|
||||
runs/
|
||||
dataset/
|
||||
NFS/
|
||||
*.pth
|
||||
*.pkl
|
||||
*.npy
|
||||
.env
|
||||
+3
-7
@@ -1,13 +1,11 @@
|
||||
# Базовый образ среды выполнения PyTorch
|
||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
|
||||
|
||||
# Конфигурация интерпретатора Python (отключение генерации байткода и буферизации вывода)
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Системные библиотеки для низкоуровневой обработки изображений
|
||||
# System dependencies for OpenCV and image processing
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
@@ -15,15 +13,13 @@ RUN apt-get update && apt-get install -y \
|
||||
libxrender-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Интеграция Python-зависимостей
|
||||
# Install python packages
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Модули программного комплекса
|
||||
# Copy source code
|
||||
COPY src/ /app/src/
|
||||
|
||||
# Сетевой интерфейс UI
|
||||
EXPOSE 8080
|
||||
|
||||
# Точка входа контейнера
|
||||
CMD ["streamlit", "run", "src/main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
|
||||
@@ -0,0 +1,63 @@
|
||||
version: '3.8'
|
||||
|
||||
networks:
|
||||
emom_mesh:
|
||||
driver: bridge
|
||||
|
||||
services:
|
||||
emom_ui:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.ui
|
||||
container_name: emom_web_ui
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8080:8080"
|
||||
networks:
|
||||
- emom_mesh
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./src:/app/src
|
||||
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
|
||||
depends_on:
|
||||
- emom_inference
|
||||
|
||||
emom_inference:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.api
|
||||
container_name: emom_pytorch_api
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- emom_mesh
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ${HOST_ARTIFACTS_DIR}/emoset_resnet50_best.pth:/app/src/emoset_resnet50_best.pth:ro
|
||||
- ${HOST_ARTIFACTS_DIR}/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
|
||||
- ${DATA_DEAM_DIR}:/app/dataset/DEAM:ro
|
||||
- ~/.cache/huggingface:/root/.cache/huggingface
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
emom_ollama:
|
||||
image: ollama/ollama:latest
|
||||
container_name: emom_ollama_engine
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- emom_mesh
|
||||
volumes:
|
||||
- ~/.ollama:/root/.ollama
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
@@ -0,0 +1,19 @@
|
||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libglib2.0-0 libsm6 libxext6 libxrender-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --no-cache-dir fastapi uvicorn timm scikit-learn pandas joblib python-multipart transformers==4.38.2 tokenizers==0.15.2 accelerate
|
||||
|
||||
WORKDIR /app
|
||||
COPY src/ /app/src/
|
||||
|
||||
WORKDIR /app/src
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,15 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
RUN pip install --no-cache-dir streamlit==1.32.0 requests pandas pillow
|
||||
|
||||
COPY src/ /app/src/
|
||||
|
||||
WORKDIR /app/src
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["streamlit", "run", "main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
|
||||
@@ -1,64 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
# Определение общих сетей для изоляции трафика
|
||||
networks:
|
||||
ai_mesh:
|
||||
driver: bridge
|
||||
|
||||
services:
|
||||
# ----------------------------------------------------
|
||||
# SERVICE 1: Frontend (Пользовательский интерфейс)
|
||||
# Не требует GPU, может быть вынесен на отдельный сервер
|
||||
# ----------------------------------------------------
|
||||
web_ui:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: emom_frontend
|
||||
restart: always
|
||||
ports:
|
||||
- "8080:8080"
|
||||
networks:
|
||||
- ai_mesh
|
||||
environment:
|
||||
- STREAMLIT_RUN=1
|
||||
# Указываем UI, где искать LLM-бэкенд (внутри Docker-сети)
|
||||
- OLLAMA_HOST=http://llm_backend:11434
|
||||
volumes:
|
||||
- ./src:/app/src
|
||||
# Модели пока остаются здесь, так как код монолитный,
|
||||
# но архитектурно сервис уже изолирован
|
||||
- /home/zin/projects/Thesis/src/emoset_resnet50_best.pth:/app/emoset_resnet50_best.pth:ro
|
||||
- /home/zin/projects/Thesis/src/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro
|
||||
- /home/zin/projects/Thesis/dataset/DEAM:/app/dataset/DEAM:ro
|
||||
# Временно оставляем GPU для PyTorch (пока он не вынесен в API)
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
# ----------------------------------------------------
|
||||
# SERVICE 2: LLM Inference Backend (Ollama)
|
||||
# Изолированный сервис для языковой модели на GPU
|
||||
# ----------------------------------------------------
|
||||
llm_backend:
|
||||
image: ollama/ollama:latest
|
||||
container_name: ollama_gpu_inference
|
||||
restart: always
|
||||
networks:
|
||||
- ai_mesh
|
||||
ports:
|
||||
- "11434:11434"
|
||||
volumes:
|
||||
# Проброс локальных моделей Ollama, чтобы не качать их заново внутри докера
|
||||
- ~/.ollama:/root/.ollama
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
import io
|
||||
import traceback
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from PIL import Image
|
||||
|
||||
from data_loader import load_music_engine, load_image_processor
|
||||
from music_engine.llm_bridge import LLMAcousticBridge
|
||||
|
||||
app = FastAPI(title="EmoM API", version="1.0.0")
|
||||
|
||||
ml_context = {
|
||||
"image_processor": None,
|
||||
"music_matcher": None,
|
||||
"llm_bridge": None
|
||||
}
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
print("Loading ML models...")
|
||||
ml_context["image_processor"] = load_image_processor()
|
||||
ml_context["music_matcher"] = load_music_engine()
|
||||
ml_context["llm_bridge"] = LLMAcousticBridge()
|
||||
print("Initialization complete.")
|
||||
|
||||
@app.post("/analyze")
|
||||
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
|
||||
try:
|
||||
images = []
|
||||
for file in files:
|
||||
image_bytes = await file.read()
|
||||
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||
images.append(img)
|
||||
|
||||
print(f"Processing batch: {len(images)} images.")
|
||||
|
||||
img_processor = ml_context["image_processor"]
|
||||
matcher = ml_context["music_matcher"]
|
||||
llm = ml_context["llm_bridge"]
|
||||
|
||||
all_v, all_a = [], []
|
||||
all_objects = []
|
||||
|
||||
for img in images:
|
||||
embedding = img_processor.extract_embedding(img)
|
||||
v, a = matcher.predict_va(embedding)
|
||||
all_v.append(v)
|
||||
all_a.append(a)
|
||||
|
||||
caption = img_processor.describe_scene(img)
|
||||
all_objects.append(caption)
|
||||
|
||||
target_v = float(np.mean(all_v))
|
||||
target_a = float(np.mean(all_a))
|
||||
unique_semantics = list(set(all_objects))
|
||||
|
||||
llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)
|
||||
|
||||
playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
|
||||
tracks_list = playlist_df.to_dict(orient="records")
|
||||
|
||||
return JSONResponse(content={
|
||||
"status": "success",
|
||||
"images_processed": len(images),
|
||||
"target_v": target_v,
|
||||
"target_a": target_a,
|
||||
"llm_profile": llm_profile,
|
||||
"semantics": unique_semantics,
|
||||
"tracks": tracks_list
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
+30
-39
@@ -1,55 +1,46 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, List, Optional, Any
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
|
||||
from music_engine.matcher import MusicMatcher
|
||||
from music_engine.image_processor import ImageProcessor
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
@st.cache_resource
|
||||
def load_music_engine():
|
||||
# Инициализация базы данных и регрессора для музыкального мэтчинга
|
||||
def load_music_engine() -> MusicMatcher:
|
||||
#Инициализация модуля подбора музыкальных композиций.
|
||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not db_path.exists():
|
||||
print(f"Музыкальная БД не найдена: {db_path}")
|
||||
return None
|
||||
|
||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||
|
||||
@st.cache_resource
|
||||
def load_image_processor():
|
||||
# Модуль обработки визуальных признаков
|
||||
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
def load_image_processor() -> ImageProcessor:
|
||||
#Инициализация модуля экстракции визуальных признаков.
|
||||
weights_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||
|
||||
# Обработка пути при вызове из корневой директории
|
||||
if not model_path.exists():
|
||||
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||
return ImageProcessor(weights_path)
|
||||
|
||||
def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]:
|
||||
# Загрузка тестовой выборки датасета EmoSet.
|
||||
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
|
||||
try:
|
||||
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
||||
labels_path = BASE_DIR / "emoset_test_labels.npy"
|
||||
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
|
||||
return ImageProcessor(model_path=model_path)
|
||||
|
||||
@st.cache_data
|
||||
def load_emoset_data():
|
||||
# Выборка данных датасета для вкладки отладки
|
||||
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
|
||||
|
||||
csv_path = dataset_root / "labels.csv"
|
||||
img_dir = dataset_root / "images"
|
||||
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
||||
|
||||
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
|
||||
return None, None, None, None
|
||||
|
||||
labels_df = pd.read_csv(csv_path)
|
||||
|
||||
test_filenames = labels_df['filename'].tolist()
|
||||
test_embeddings = np.load(emb_path)
|
||||
test_labels = np.load(lbl_path)
|
||||
|
||||
return test_filenames, test_embeddings, test_labels, img_dir
|
||||
if not all(p.exists() for p in [labels_path, embeddings_path]):
|
||||
return None, None, None, None
|
||||
|
||||
labels = np.load(labels_path)
|
||||
embeddings = np.load(embeddings_path)
|
||||
|
||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
return df['filename'].tolist(), embeddings, labels, images_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
|
||||
return None, None, None, None
|
||||
Binary file not shown.
Binary file not shown.
+174
-58
@@ -1,73 +1,189 @@
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import requests
|
||||
import streamlit as st
|
||||
import streamlit.components.v1 as components
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from data_loader import load_music_engine, load_emoset_data, load_image_processor
|
||||
from tabs.tab_dataset import render_dataset_tab
|
||||
from tabs.tab_live import render_live_tab
|
||||
st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed")
|
||||
|
||||
# Костыль для прямого запуска
|
||||
if __name__ == "__main__":
|
||||
if "STREAMLIT_RUN" not in os.environ:
|
||||
os.environ["STREAMLIT_RUN"] = "1"
|
||||
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
|
||||
subprocess.run(cmd)
|
||||
sys.exit()
|
||||
API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze"
|
||||
DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
|
||||
|
||||
viewport_mode = st.query_params.get("viewport", "desktop")
|
||||
page_layout = "centered" if viewport_mode == "mobile" else "wide"
|
||||
|
||||
st.set_page_config(page_title="Thesis Demo", layout=page_layout)
|
||||
|
||||
# Определения ширины экрана и смены верстки
|
||||
components.html(
|
||||
"""
|
||||
<script>
|
||||
const w = window.parent.innerWidth;
|
||||
const h = window.parent.innerHeight;
|
||||
const url = new URL(window.parent.location.href);
|
||||
def get_thumbnail_html(images, max_display=12):
|
||||
html_images = ""
|
||||
for file in images[:max_display]:
|
||||
img = Image.open(file)
|
||||
img.thumbnail((100, 100))
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
b64_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
|
||||
|
||||
// Считаем мобилкой, если ушли в портретный режим или экран уже 768px
|
||||
const isMobile = (h > w) || (w < 768);
|
||||
const target = isMobile ? "mobile" : "desktop";
|
||||
|
||||
if (url.searchParams.get("viewport") !== target) {
|
||||
url.searchParams.set("viewport", target);
|
||||
window.parent.location.href = url.href;
|
||||
}
|
||||
</script>
|
||||
""",
|
||||
height=0,
|
||||
width=0,
|
||||
)
|
||||
if len(images) > max_display:
|
||||
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
|
||||
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
def main():
|
||||
if "live_state" not in st.session_state:
|
||||
st.session_state.live_state = "upload"
|
||||
if "result_data" not in st.session_state:
|
||||
st.session_state.result_data = None
|
||||
|
||||
viewport = st.query_params.get("viewport", "desktop")
|
||||
|
||||
st.markdown("""
|
||||
<style>
|
||||
img { max-width: 100%; height: auto; object-fit: contain; }
|
||||
[data-testid="stMetricValue"] { font-size: 1.8rem; }
|
||||
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
|
||||
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
|
||||
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
|
||||
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
|
||||
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
|
||||
</style>
|
||||
""",
|
||||
unsafe_allow_html=True
|
||||
)
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# Подгрузка ML-моделей и датасета
|
||||
music_matcher = load_music_engine()
|
||||
img_processor = load_image_processor()
|
||||
emoset_files, emoset_embeddings, emoset_labels, emoset_path = load_emoset_data()
|
||||
if st.session_state.live_state == "upload":
|
||||
upload_placeholder = st.empty()
|
||||
with upload_placeholder.container():
|
||||
st.write("Загрузите изображения для визуально-семантического анализа.")
|
||||
if viewport == "mobile":
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
uploaded_files = st.file_uploader(
|
||||
"Загрузка файлов",
|
||||
type=['png', 'jpg', 'jpeg'],
|
||||
accept_multiple_files=True,
|
||||
label_visibility="collapsed" if viewport == "mobile" else "visible"
|
||||
)
|
||||
|
||||
if uploaded_files:
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
if st.button("Выполнить анализ", type="primary", use_container_width=True):
|
||||
st.session_state.uploaded_images = uploaded_files
|
||||
st.session_state.live_state = "processing"
|
||||
upload_placeholder.empty()
|
||||
st.rerun()
|
||||
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
st.caption("Выбранные файлы:")
|
||||
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
|
||||
|
||||
st.title("Генератор саундтреков (Research Demo)")
|
||||
elif st.session_state.live_state == "processing":
|
||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
||||
files = st.session_state.get("uploaded_images", [])
|
||||
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
|
||||
|
||||
try:
|
||||
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
|
||||
response = requests.post(API_URL, files=upload_data, timeout=300)
|
||||
|
||||
if response.status_code == 200:
|
||||
st.session_state.result_data = response.json()
|
||||
st.session_state.live_state = "result"
|
||||
st.rerun()
|
||||
else:
|
||||
st.error(f"Ошибка сервера: {response.status_code}")
|
||||
if st.button("Назад"):
|
||||
st.session_state.live_state = "upload"
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.error(f"Ошибка соединения: {str(e)}")
|
||||
if st.button("Назад"):
|
||||
st.session_state.live_state = "upload"
|
||||
st.rerun()
|
||||
|
||||
elif st.session_state.live_state == "result":
|
||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
||||
data = st.session_state.result_data
|
||||
|
||||
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
|
||||
|
||||
for row in data.get("tracks", []):
|
||||
with st.container(border=True):
|
||||
song_id = int(row['song_id'])
|
||||
score = row['final_score']
|
||||
|
||||
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
|
||||
if not os.path.exists(audio_path):
|
||||
audio_path = audio_path.replace('.mp3', '.wav')
|
||||
|
||||
if viewport == "desktop":
|
||||
c1, c2 = st.columns([1, 3])
|
||||
with c1:
|
||||
st.write(f"**Track ID:** {song_id}")
|
||||
st.caption(f"Score: {score:.4f}")
|
||||
with c2:
|
||||
if os.path.exists(audio_path):
|
||||
st.audio(audio_path)
|
||||
else:
|
||||
st.caption("Аудиофайл не найден")
|
||||
else:
|
||||
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
|
||||
if os.path.exists(audio_path):
|
||||
st.audio(audio_path)
|
||||
else:
|
||||
st.caption("Аудиофайл не найден")
|
||||
|
||||
tab_live, tab_debug = st.tabs(["Анализ событий (Свои фото)", "Отладка (Датасет EmoSet)"])
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
with st.expander("Отладочная информация (Метрики)"):
|
||||
st.subheader("Координаты V/A")
|
||||
c_v, c_a = st.columns(2)
|
||||
c_v.metric("Valence", f"{data['target_v']:.2f}")
|
||||
c_a.metric("Arousal", f"{data['target_a']:.2f}")
|
||||
|
||||
st.markdown("---")
|
||||
st.subheader("Акустические признаки (LLM)")
|
||||
|
||||
feature_titles = {
|
||||
"energy": "RMS Energy",
|
||||
"flux": "Spectral Flux",
|
||||
"centroid": "Spectral Centroid",
|
||||
"pitch": "F0 (Pitch)",
|
||||
"hnr": "HNR",
|
||||
"zcr": "ZCR"
|
||||
}
|
||||
|
||||
# Развернутые описания
|
||||
feature_helps = {
|
||||
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
|
||||
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
|
||||
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
|
||||
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
|
||||
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
|
||||
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
|
||||
}
|
||||
|
||||
llm_profile = data.get("llm_profile")
|
||||
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
|
||||
cols_per_row = 2 if viewport == "mobile" else 3
|
||||
llm_items = list(llm_profile.items())
|
||||
|
||||
for i in range(0, len(llm_items), cols_per_row):
|
||||
cols = st.columns(cols_per_row)
|
||||
for j in range(cols_per_row):
|
||||
if i + j < len(llm_items):
|
||||
k, v = llm_items[i + j]
|
||||
label = feature_titles.get(k, k)
|
||||
tooltip = feature_helps.get(k, "")
|
||||
cols[j].metric(label, f"{v:.2f}", help=tooltip)
|
||||
else:
|
||||
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
|
||||
|
||||
st.markdown("---")
|
||||
st.write("**Извлеченные теги (BLIP-2):**")
|
||||
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
|
||||
|
||||
with tab_live:
|
||||
if img_processor:
|
||||
render_live_tab(music_matcher, img_processor)
|
||||
else:
|
||||
st.error("Ошибка загрузки: не найдены веса ResNet для image_processor.")
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
if st.button("Новый запрос", use_container_width=True):
|
||||
st.session_state.live_state = "upload"
|
||||
st.session_state.result_data = None
|
||||
st.session_state.pop("uploaded_images", None)
|
||||
st.rerun()
|
||||
|
||||
with tab_debug:
|
||||
render_dataset_tab(music_matcher, emoset_files, emoset_embeddings, emoset_labels, emoset_path)
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -32,7 +32,11 @@ class ImageProcessor:
|
||||
|
||||
# Модуль семантического описания сцены
|
||||
print("Инициализация BLIP-2...")
|
||||
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
# Обход бага конфигурации Hugging Face (ручная сборка процессора)
|
||||
from transformers import BlipImageProcessor, AutoTokenizer
|
||||
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
|
||||
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok)
|
||||
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b",
|
||||
torch_dtype=torch.float16
|
||||
|
||||
@@ -1,65 +1,65 @@
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import requests
|
||||
|
||||
class LLMAcousticBridge:
|
||||
def __init__(self, target_model="dolphin-llama3:8b"):
|
||||
self.api_url = "http://localhost:11434/api/generate"
|
||||
self.model = target_model
|
||||
def __init__(self, model_name="dolphin-llama3:8b"):
|
||||
self.model_name = model_name
|
||||
base_url = os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
|
||||
self.api_url = f"{base_url}/api/generate"
|
||||
|
||||
def _extract_json(self, raw_text: str):
|
||||
# Проверка на ИИдиота, LLM иногда игнорирует format="json" и оборачивает ответ в маркдаун
|
||||
try:
|
||||
match = re.search(r'\{.*\}', raw_text, re.DOTALL)
|
||||
if match:
|
||||
return json.loads(match.group(0))
|
||||
return json.loads(raw_text)
|
||||
except json.JSONDecodeError:
|
||||
# Если ИИдиот
|
||||
return None
|
||||
|
||||
def get_acoustic_profile(self, v_score: float, a_score: float, scene_context: list) -> dict | None:
|
||||
# Агрегация контекста для обработки серии снимков (события)
|
||||
context_merged = " | ".join(scene_context) if scene_context else "abstract scene"
|
||||
def get_acoustic_profile(self, valence, arousal, semantics):
|
||||
context_str = ", ".join(semantics) if semantics else "abstract scene"
|
||||
|
||||
prompt = f"""
|
||||
Analyze the visual context and emotions to determine the ideal background music properties.
|
||||
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
|
||||
Visual Context: {context_str}.
|
||||
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
|
||||
|
||||
1. "energy": (Loudness/Density)
|
||||
2. "flux": (Rhythmic sharpness/Beat)
|
||||
3. "centroid": (Brightness)
|
||||
4. "pitch": (Fundamental frequency)
|
||||
5. "hnr": (Harmonics-to-Noise)
|
||||
6. "zcr": (Percussiveness)
|
||||
|
||||
Return ONLY a valid JSON object. No explanations, no markdown blocks.
|
||||
Example: {{"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}}
|
||||
"""
|
||||
|
||||
system_prompt = f"""You are an expert music producer and acoustic engineer.
|
||||
Analyze the visual context and emotions to determine the ideal background music properties.
|
||||
Emotions: Valence {v_score:.1f}/9.0 (Positivity), Arousal {a_score:.1f}/9.0 (Energy).
|
||||
Visual Context: {context_merged}.
|
||||
|
||||
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
|
||||
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
|
||||
2. "flux": (Rhythmic sharpness/Beat. High for action/people/cars, Low for static nature)
|
||||
3. "centroid": (Brightness: 0=Dark/Bass/Massive, 1=Bright/Treble/Light)
|
||||
4. "pitch": (Fundamental frequency: 0=Low pitch/Huge objects, 1=High pitch/Small objects)
|
||||
5. "hnr": (Harmonics-to-Noise: 0=Noisy/Distorted textures, 1=Clear/Melodic/Smooth textures)
|
||||
6. "zcr": (Percussiveness. High for detailed noise like leaves/rain, Low for solid blocks)
|
||||
|
||||
Return ONLY a valid JSON object. Do not add any text or explanation.
|
||||
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
|
||||
|
||||
try:
|
||||
# Отправка промпта локальной Ollama
|
||||
response = requests.post(self.api_url, json={
|
||||
"model": self.model,
|
||||
"prompt": system_prompt,
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"format": "json"
|
||||
}, timeout=45)
|
||||
response.raise_for_status()
|
||||
"format": "json" # Принудительный JSON-режим Ollama
|
||||
}
|
||||
|
||||
raw_response = response.json().get("response", "")
|
||||
profile_data = self._extract_json(raw_response)
|
||||
print(f"Запрос акустического профиля к Ollama...")
|
||||
response = requests.post(self.api_url, json=payload, timeout=120)
|
||||
|
||||
# Валидация структуры ответа
|
||||
expected_features = {'energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'}
|
||||
|
||||
if profile_data and expected_features.issubset(profile_data.keys()):
|
||||
return profile_data
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
response_text = data.get("response", "")
|
||||
|
||||
print("LLM вернула неполный или некорректный набор акустических признаков")
|
||||
return None
|
||||
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
print(f"Не удалось подключиться к Ollama: {req_err}")
|
||||
return None
|
||||
try:
|
||||
# 1. Попытка прямой десериализации
|
||||
profile = json.loads(response_text)
|
||||
return profile
|
||||
except json.JSONDecodeError:
|
||||
# 2. Аварийное извлечение JSON из текста с помощью регулярного выражения
|
||||
match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if match:
|
||||
return json.loads(match.group(0))
|
||||
|
||||
print(f"Ошибка парсинга LLM ответа: {response_text}")
|
||||
return {}
|
||||
else:
|
||||
print(f"Ollama вернула ошибку HTTP: {response.status_code}")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Ошибка соединения с Ollama: {str(e)}")
|
||||
return {}
|
||||
Binary file not shown.
@@ -1,5 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Данный скрипт написан ИИ для быстрой подготовки окружения, установка драйверов и докера
|
||||
# Остановка скрипта при возникновении любой ошибки
|
||||
set -e
|
||||
|
||||
|
||||
@@ -1,541 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0c00b67b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from pathlib import Path\n",
|
||||
"from PIL import Image\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"import torchvision.transforms as T\n",
|
||||
"import timm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "84c3657f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'cuda'"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Конфигурация параметров обучения и путей файловой системы\n",
|
||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||
"BATCH_SIZE = 64\n",
|
||||
"EPOCHS = 15\n",
|
||||
"LR = 3e-4\n",
|
||||
"NUM_WORKERS = 40\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Аппаратное ускорение: {device}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f749add",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class EmoSetDataset(Dataset):\n",
|
||||
" def __init__(self, root: Path | str, split: str):\n",
|
||||
" self.root = Path(root) / split\n",
|
||||
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
||||
"\n",
|
||||
" # Формирование словарей маппинга классов\n",
|
||||
" self.labels = sorted(self.df[\"label\"].unique())\n",
|
||||
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
||||
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
||||
"\n",
|
||||
" # Базовые трансформации для валидации и теста\n",
|
||||
" base_tf = [\n",
|
||||
" T.ToTensor(),\n",
|
||||
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # Внедрение аугментации исключительно для обучающей выборки (предотвращение переобучения)\n",
|
||||
" if split == \"train\":\n",
|
||||
" self.transform = T.Compose([\n",
|
||||
" T.RandomResizedCrop(224),\n",
|
||||
" T.RandomHorizontalFlip(),\n",
|
||||
" *base_tf\n",
|
||||
" ])\n",
|
||||
" else:\n",
|
||||
" self.transform = T.Compose([\n",
|
||||
" T.Resize(256),\n",
|
||||
" T.CenterCrop(224),\n",
|
||||
" *base_tf\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.df)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" row = self.df.iloc[idx]\n",
|
||||
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
||||
"\n",
|
||||
" # Обработка возможных исключений ввода-вывода (поврежденные JPEG-файлы в датасете)\n",
|
||||
" try:\n",
|
||||
" img = Image.open(img_path).convert(\"RGB\")\n",
|
||||
" except Exception:\n",
|
||||
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
|
||||
"\n",
|
||||
" img_tensor = self.transform(img)\n",
|
||||
" label_idx = self.label2idx[row[\"label\"]]\n",
|
||||
" \n",
|
||||
" return img_tensor, label_idx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c8805341",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Подготовка объектов выборки\n",
|
||||
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
|
||||
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
|
||||
"\n",
|
||||
"# Инициализация итераторов с закреплением памяти (pin_memory) для ускорения передачи на GPU\n",
|
||||
"train_loader = DataLoader(\n",
|
||||
" train_ds,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=NUM_WORKERS,\n",
|
||||
" pin_memory=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = DataLoader(\n",
|
||||
" val_ds,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=NUM_WORKERS,\n",
|
||||
" pin_memory=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Индексированные классы: {train_ds.labels}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dffce582",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ResNet(\n",
|
||||
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||||
" (layer1): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer2): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer3): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (4): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (5): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer4): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act1): ReLU(inplace=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (drop_block): Identity()\n",
|
||||
" (act2): ReLU(inplace=True)\n",
|
||||
" (aa): Identity()\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (act3): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
|
||||
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# TODO перед защитой, повторить оптимизаторы\n",
|
||||
"# Загрузка предобученной архитектуры ResNet-50 с заменой классификационного слоя\n",
|
||||
"model = timm.create_model(\n",
|
||||
" \"resnet50\",\n",
|
||||
" pretrained=True,\n",
|
||||
" num_classes=len(train_ds.labels)\n",
|
||||
")\n",
|
||||
"model.to(device)\n",
|
||||
"\n",
|
||||
"# Функция потерь для многоклассовой классификации\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"# Оптимизатор AdamW с L2-регуляризацией (weight_decay) для повышения обобщающей способности\n",
|
||||
"optimizer = torch.optim.AdamW(\n",
|
||||
" model.parameters(),\n",
|
||||
" lr=LR,\n",
|
||||
" weight_decay=1e-4\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Планировщик скорости обучения: косинусный отжиг\n",
|
||||
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
|
||||
" optimizer,\n",
|
||||
" T_max=EPOCHS\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "81a457ef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_epoch(current_model, loader):\n",
|
||||
" current_model.train()\n",
|
||||
" total_loss = 0.0\n",
|
||||
" correct_preds = 0\n",
|
||||
" total_samples = 0\n",
|
||||
"\n",
|
||||
" for imgs, labels in tqdm(loader, desc=\"Тренировка\", leave=False):\n",
|
||||
" imgs = imgs.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" logits = current_model(imgs)\n",
|
||||
" loss = criterion(logits, labels)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * imgs.size(0)\n",
|
||||
" preds = logits.argmax(dim=1)\n",
|
||||
" correct_preds += (preds == labels).sum().item()\n",
|
||||
" total_samples += labels.size(0)\n",
|
||||
"\n",
|
||||
" return total_loss / total_samples, correct_preds / total_samples\n",
|
||||
"\n",
|
||||
"@torch.no_grad()\n",
|
||||
"def val_epoch(current_model, loader):\n",
|
||||
" # Перевод модели в режим инференса (отключение Dropout и фиксация BatchNorm)\n",
|
||||
" current_model.eval()\n",
|
||||
" total_loss = 0.0\n",
|
||||
" correct_preds = 0\n",
|
||||
" total_samples = 0\n",
|
||||
"\n",
|
||||
" for imgs, labels in tqdm(loader, desc=\"Валидация\", leave=False):\n",
|
||||
" imgs = imgs.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
"\n",
|
||||
" logits = current_model(imgs)\n",
|
||||
" loss = criterion(logits, labels)\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * imgs.size(0)\n",
|
||||
" preds = logits.argmax(dim=1)\n",
|
||||
" correct_preds += (preds == labels).sum().item()\n",
|
||||
" total_samples += labels.size(0)\n",
|
||||
"\n",
|
||||
" return total_loss / total_samples, correct_preds / total_samples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "951aa9e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"best_val_acc = 0.0\n",
|
||||
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
|
||||
"\n",
|
||||
"print(\"Старт процесса обучения...\")\n",
|
||||
"\n",
|
||||
"for epoch in range(1, EPOCHS + 1):\n",
|
||||
" train_loss, train_acc = train_epoch(model, train_loader)\n",
|
||||
" val_loss, val_acc = val_epoch(model, val_loader)\n",
|
||||
"\n",
|
||||
" # Обновление шага планировщика\n",
|
||||
" scheduler.step()\n",
|
||||
"\n",
|
||||
" print(\n",
|
||||
" f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
|
||||
" f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
|
||||
" f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Экспорт весов при улучшении целевой метрики\n",
|
||||
" if val_acc > best_val_acc:\n",
|
||||
" best_val_acc = val_acc\n",
|
||||
" torch.save(model.state_dict(), checkpoint_path)\n",
|
||||
" print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
|
||||
"\n",
|
||||
"print(\"Обучение завершено.\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "thesis-py3.11",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as T
|
||||
import timm
|
||||
|
||||
# Подавление предупреждений цветовых профилей
|
||||
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
|
||||
|
||||
# Настройки окружения
|
||||
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
|
||||
BATCH_SIZE = 64
|
||||
EPOCHS = 30
|
||||
LR = 5e-5
|
||||
NUM_WORKERS = 62
|
||||
PATIENCE = 7
|
||||
|
||||
# Маппинг классов
|
||||
CLASS_MAPPING = {
|
||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
|
||||
}
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Устройство: {DEVICE}")
|
||||
|
||||
# Фиксация генераторов псевдослучайных чисел
|
||||
def set_seed(seed=42):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
set_seed()
|
||||
|
||||
# Инициализация структур данных
|
||||
class EmoSetDataset(Dataset):
|
||||
def __init__(self, root: Path | str, split: str, transform=None):
|
||||
self.root = Path(root) / split
|
||||
self.df = pd.read_csv(self.root / "labels.csv")
|
||||
self.transform = transform
|
||||
|
||||
# Фильтрация датафрейма
|
||||
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
row = self.df.iloc[idx]
|
||||
img_path = self.root / "images" / row["filename"]
|
||||
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
except Exception:
|
||||
img = Image.new("RGB", (256, 256), (0, 0, 0))
|
||||
|
||||
if self.transform:
|
||||
img_tensor = self.transform(img)
|
||||
else:
|
||||
img_tensor = T.ToTensor()(img)
|
||||
|
||||
label_idx = CLASS_MAPPING[row["label"]]
|
||||
return img_tensor, label_idx
|
||||
|
||||
# Трансформации
|
||||
base_tf = [
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
]
|
||||
|
||||
train_transform = T.Compose([
|
||||
T.Resize(256, antialias=True),
|
||||
T.RandomCrop(224),
|
||||
T.RandomHorizontalFlip(),
|
||||
*base_tf
|
||||
])
|
||||
|
||||
val_transform = T.Compose([
|
||||
T.Resize(256, antialias=True),
|
||||
T.CenterCrop(224),
|
||||
*base_tf
|
||||
])
|
||||
|
||||
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
|
||||
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
|
||||
|
||||
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
|
||||
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
|
||||
|
||||
# Инициализация модели и оптимизатора
|
||||
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
|
||||
model.to(DEVICE)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
||||
|
||||
# Логика эпохи обучения
|
||||
def train_epoch(current_model, loader):
|
||||
current_model.train()
|
||||
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||
|
||||
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
|
||||
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logits = current_model(imgs)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item() * imgs.size(0)
|
||||
preds = logits.argmax(dim=1)
|
||||
correct_preds += (preds == labels).sum().item()
|
||||
total_samples += labels.size(0)
|
||||
|
||||
return total_loss / total_samples, correct_preds / total_samples
|
||||
|
||||
# Логика эпохи валидации
|
||||
@torch.no_grad()
|
||||
def val_epoch(current_model, loader):
|
||||
current_model.eval()
|
||||
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||
|
||||
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
|
||||
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
logits = current_model(imgs)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
total_loss += loss.item() * imgs.size(0)
|
||||
preds = logits.argmax(dim=1)
|
||||
correct_preds += (preds == labels).sum().item()
|
||||
total_samples += labels.size(0)
|
||||
|
||||
return total_loss / total_samples, correct_preds / total_samples
|
||||
|
||||
if __name__ == "__main__":
|
||||
best_val_acc = 0.0
|
||||
best_val_loss = float('inf')
|
||||
epochs_no_improve = 0
|
||||
checkpoint_path = "./emosetV2_resnet50_best.pth"
|
||||
|
||||
print("Старт обучения.")
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
train_loss, train_acc = train_epoch(model, train_loader)
|
||||
val_loss, val_acc = val_epoch(model, val_loader)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
|
||||
|
||||
# Сохранение лучших весов по Accuracy
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
torch.save(model.state_dict(), checkpoint_path)
|
||||
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
|
||||
|
||||
# Оценка переобучения по Loss (Early Stopping)
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
epochs_no_improve = 0
|
||||
else:
|
||||
epochs_no_improve += 1
|
||||
if epochs_no_improve >= PATIENCE:
|
||||
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
|
||||
break
|
||||
|
||||
print("Процесс завершен.")
|
||||
@@ -0,0 +1,283 @@
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as T
|
||||
import timm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
# Подавление предупреждений цветовых профилей
|
||||
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")
|
||||
|
||||
# Настройки окружения
|
||||
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
|
||||
# ВАЖНО: Добавили путь для медиа файлов
|
||||
MEDIA_DIR = Path("./src/scripts/media")
|
||||
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
BATCH_SIZE = 64
|
||||
EPOCHS = 30
|
||||
LR = 5e-5
|
||||
NUM_WORKERS = 32
|
||||
PATIENCE = 7
|
||||
|
||||
# Маппинг классов
|
||||
CLASS_MAPPING = {
|
||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
|
||||
}
|
||||
# Инвертированный маппинг для графиков
|
||||
INV_CLASS_MAPPING = {v: k for k, v in CLASS_MAPPING.items()}
|
||||
CLASS_NAMES = [INV_CLASS_MAPPING[i] for i in range(len(CLASS_MAPPING))]
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Устройство: {DEVICE}")
|
||||
|
||||
# Фиксация генераторов псевдослучайных чисел
|
||||
def set_seed(seed=42):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
set_seed()
|
||||
|
||||
# Инициализация структур данных
|
||||
class EmoSetDataset(Dataset):
|
||||
def __init__(self, root: Path | str, split: str, transform=None):
|
||||
self.root = Path(root) / split
|
||||
self.df = pd.read_csv(self.root / "labels.csv")
|
||||
self.transform = transform
|
||||
|
||||
# Фильтрация датафрейма
|
||||
self.df = self.df[self.df["label"].isin(CLASS_MAPPING.keys())].reset_index(drop=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
row = self.df.iloc[idx]
|
||||
img_path = self.root / "images" / row["filename"]
|
||||
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
except Exception:
|
||||
img = Image.new("RGB", (256, 256), (0, 0, 0))
|
||||
|
||||
if self.transform:
|
||||
img_tensor = self.transform(img)
|
||||
else:
|
||||
img_tensor = T.ToTensor()(img)
|
||||
|
||||
label_idx = CLASS_MAPPING[row["label"]]
|
||||
return img_tensor, label_idx
|
||||
|
||||
# Трансформации
|
||||
base_tf = [
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
]
|
||||
|
||||
train_transform = T.Compose([
|
||||
T.Resize(256, antialias=True),
|
||||
T.RandomCrop(224),
|
||||
T.RandomHorizontalFlip(),
|
||||
*base_tf
|
||||
])
|
||||
|
||||
val_transform = T.Compose([
|
||||
T.Resize(256, antialias=True),
|
||||
T.CenterCrop(224),
|
||||
*base_tf
|
||||
])
|
||||
|
||||
train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
|
||||
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)
|
||||
|
||||
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
|
||||
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
|
||||
|
||||
# Инициализация модели и оптимизатора
|
||||
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
|
||||
model.to(DEVICE)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
||||
|
||||
# Функции для отрисовки графиков
|
||||
def plot_learning_curves(history):
|
||||
"""Отрисовка графиков функции потерь и точности"""
|
||||
epochs = range(1, len(history['train_loss']) + 1)
|
||||
|
||||
plt.figure(figsize=(14, 5))
|
||||
|
||||
# График Loss
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.plot(epochs, history['train_loss'], 'b-', label='Train Loss')
|
||||
plt.plot(epochs, history['val_loss'], 'r--', label='Validation Loss')
|
||||
plt.title('График функции потерь (Loss)', fontsize=14)
|
||||
plt.xlabel('Эпохи', fontsize=12)
|
||||
plt.ylabel('Loss', fontsize=12)
|
||||
plt.legend()
|
||||
plt.grid(True, linestyle=':', alpha=0.7)
|
||||
|
||||
# График Accuracy
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy')
|
||||
plt.plot(epochs, history['val_acc'], 'r--', label='Validation Accuracy')
|
||||
plt.title('График точности (Accuracy)', fontsize=14)
|
||||
plt.xlabel('Эпохи', fontsize=12)
|
||||
plt.ylabel('Accuracy', fontsize=12)
|
||||
plt.legend()
|
||||
plt.grid(True, linestyle=':', alpha=0.7)
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = MEDIA_DIR / "training_history.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"[INFO] График обучения сохранен в: {plot_path}")
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred):
|
||||
"""Отрисовка тепловой матрицы ошибок"""
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
|
||||
plt.figure(figsize=(10, 8))
|
||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
||||
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
|
||||
cbar_kws={'label': 'Количество сэмплов'})
|
||||
|
||||
plt.title('Матрица ошибок (Confusion Matrix) - ResNet50', fontsize=16, pad=20)
|
||||
plt.ylabel('Истинные классы (Ground Truth)', fontsize=12)
|
||||
plt.xlabel('Предсказанные классы (Predicted)', fontsize=12)
|
||||
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.yticks(rotation=0)
|
||||
|
||||
plt.tight_layout()
|
||||
cm_path = MEDIA_DIR / "confusion_matrix_emoset.png"
|
||||
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"[INFO] Матрица ошибок сохранена в: {cm_path}")
|
||||
|
||||
# Логика эпохи обучения
|
||||
def train_epoch(current_model, loader):
|
||||
current_model.train()
|
||||
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||
|
||||
for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
|
||||
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logits = current_model(imgs)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item() * imgs.size(0)
|
||||
preds = logits.argmax(dim=1)
|
||||
correct_preds += (preds == labels).sum().item()
|
||||
total_samples += labels.size(0)
|
||||
|
||||
return total_loss / total_samples, correct_preds / total_samples
|
||||
|
||||
# Логика эпохи валидации с сохранением предсказаний для матрицы ошибок
|
||||
@torch.no_grad()
|
||||
def val_epoch(current_model, loader, return_preds=False):
|
||||
current_model.eval()
|
||||
total_loss, correct_preds, total_samples = 0.0, 0, 0
|
||||
all_preds, all_labels = [], []
|
||||
|
||||
for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
|
||||
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
logits = current_model(imgs)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
total_loss += loss.item() * imgs.size(0)
|
||||
preds = logits.argmax(dim=1)
|
||||
|
||||
correct_preds += (preds == labels).sum().item()
|
||||
total_samples += labels.size(0)
|
||||
|
||||
if return_preds:
|
||||
all_preds.extend(preds.cpu().numpy())
|
||||
all_labels.extend(labels.cpu().numpy())
|
||||
|
||||
avg_loss = total_loss / total_samples
|
||||
avg_acc = correct_preds / total_samples
|
||||
|
||||
if return_preds:
|
||||
return avg_loss, avg_acc, all_labels, all_preds
|
||||
return avg_loss, avg_acc
|
||||
|
||||
if __name__ == "__main__":
|
||||
best_val_acc = 0.0
|
||||
best_val_loss = float('inf')
|
||||
epochs_no_improve = 0
|
||||
checkpoint_path = "./emosetV2_resnet50_best.pth"
|
||||
|
||||
# Словарь для хранения истории обучения
|
||||
history = {
|
||||
'train_loss': [], 'train_acc': [],
|
||||
'val_loss': [], 'val_acc': []
|
||||
}
|
||||
|
||||
# Переменные для хранения лучших предсказаний для матрицы
|
||||
best_labels, best_preds = [], []
|
||||
|
||||
print("Старт обучения.")
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
train_loss, train_acc = train_epoch(model, train_loader)
|
||||
|
||||
# Получаем предсказания только если это может быть лучшая эпоха
|
||||
val_loss, val_acc, val_labels, val_preds = val_epoch(model, val_loader, return_preds=True)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
# Запись в историю
|
||||
history['train_loss'].append(train_loss)
|
||||
history['train_acc'].append(train_acc)
|
||||
history['val_loss'].append(val_loss)
|
||||
history['val_acc'].append(val_acc)
|
||||
|
||||
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
|
||||
|
||||
# Сохранение лучших весов по Accuracy
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
best_labels = val_labels # Сохраняем предсказания лучшей модели
|
||||
best_preds = val_preds
|
||||
torch.save(model.state_dict(), checkpoint_path)
|
||||
print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")
|
||||
|
||||
# Оценка переобучения по Loss (Early Stopping)
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
epochs_no_improve = 0
|
||||
else:
|
||||
epochs_no_improve += 1
|
||||
if epochs_no_improve >= PATIENCE:
|
||||
print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
|
||||
break
|
||||
|
||||
print("Процесс обучения завершен. Генерирую графики для диссертации...")
|
||||
plot_learning_curves(history)
|
||||
plot_confusion_matrix(best_labels, best_preds)
|
||||
print("Все медиафайлы успешно созданы!")
|
||||
@@ -0,0 +1,171 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as T
|
||||
import timm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Настройки путей для медиа
|
||||
MEDIA_DIR = Path("scripts/media")
|
||||
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Конфигурация путей для инференса и кэширования векторов
|
||||
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
|
||||
MODEL_PATH = Path("./src/emoset_resnet50_best.pth")
|
||||
|
||||
BATCH_SIZE = 128
|
||||
NUM_WORKERS = 32
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Вычисления перенесены на: {device}")
|
||||
|
||||
class EmoSetFeatureDataset(Dataset):
|
||||
def __init__(self, root: Path | str, split: str):
|
||||
self.root = Path(root) / split
|
||||
self.df = pd.read_csv(self.root / "labels.csv")
|
||||
|
||||
self.labels = sorted(self.df["label"].unique())
|
||||
self.label2idx = {l: i for i, l in enumerate(self.labels)}
|
||||
self.idx2label = {i: l for l, i in self.label2idx.items()}
|
||||
|
||||
# Для экстракции признаков аугментация отключена, используется строгий CenterCrop
|
||||
self.transform = T.Compose([
|
||||
T.Resize(256),
|
||||
T.CenterCrop(224),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
row = self.df.iloc[idx]
|
||||
img_path = self.root / "images" / row["filename"]
|
||||
|
||||
# Перехват битых файлов выборки
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
except Exception:
|
||||
img = Image.new("RGB", (224, 224), (0, 0, 0))
|
||||
|
||||
img_tensor = self.transform(img)
|
||||
label_idx = self.label2idx[row["label"]]
|
||||
|
||||
return img_tensor, label_idx
|
||||
|
||||
def plot_tsne(embeddings, labels, idx2label, sample_limit=3000):
|
||||
"""Генерация t-SNE графика для диссертации"""
|
||||
print(f"Построение t-SNE проекции для {sample_limit} сэмплов...")
|
||||
|
||||
tsne_model = TSNE(n_components=2, perplexity=30, random_state=42)
|
||||
embeddings_2d = tsne_model.fit_transform(embeddings[:sample_limit])
|
||||
labels_subset = labels[:sample_limit]
|
||||
|
||||
plt.figure(figsize=(12, 9))
|
||||
|
||||
# Используем более академическую палитру
|
||||
scatter = plt.scatter(
|
||||
embeddings_2d[:, 0],
|
||||
embeddings_2d[:, 1],
|
||||
c=labels_subset,
|
||||
cmap="Set2", # Set2 лучше различается при печати
|
||||
alpha=0.7,
|
||||
s=20,
|
||||
edgecolors='w',
|
||||
linewidths=0.5
|
||||
)
|
||||
|
||||
# Формирование легенды
|
||||
handles, _ = scatter.legend_elements()
|
||||
legend_labels = [idx2label[i] for i in range(len(idx2label))]
|
||||
|
||||
# Размещение легенды снаружи графика, чтобы не перекрывать данные
|
||||
plt.legend(handles, legend_labels, title="Эмоциональные классы",
|
||||
bbox_to_anchor=(1.05, 1), loc='upper left')
|
||||
|
||||
plt.title("2D проекция скрытого пространства признаков (t-SNE)", pad=20, fontsize=14)
|
||||
plt.xlabel("Первая главная компонента (t-SNE 1)", fontsize=12)
|
||||
plt.ylabel("Вторая главная компонента (t-SNE 2)", fontsize=12)
|
||||
plt.grid(True, linestyle='--', alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = MEDIA_DIR / "tsne_embeddings.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"[INFO] График t-SNE сохранен в: {plot_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ds = EmoSetFeatureDataset(DATA_ROOT, "test")
|
||||
test_loader = DataLoader(
|
||||
test_ds,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False, # Отключение шаффла для строгого соответствия индексов
|
||||
num_workers=NUM_WORKERS,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
print(f"Подготовлено для извлечения: {len(test_ds)} файлов.")
|
||||
|
||||
# Инициализация модели и загрузка лучших весов
|
||||
feature_extractor = timm.create_model(
|
||||
"resnet50",
|
||||
pretrained=False,
|
||||
num_classes=len(test_ds.labels)
|
||||
)
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(MODEL_PATH, map_location=device)
|
||||
feature_extractor.load_state_dict(checkpoint)
|
||||
print("Веса модели успешно загружены.")
|
||||
except Exception as e:
|
||||
print(f"Ошибка загрузки весов: {e}. Убедитесь, что модель обучена.")
|
||||
exit(1)
|
||||
|
||||
# Удаление классификационного слоя (fc)
|
||||
feature_extractor.reset_classifier(0)
|
||||
feature_extractor.to(device)
|
||||
feature_extractor.eval()
|
||||
|
||||
print("Слой классификации удален. Модель готова к экстракции.")
|
||||
|
||||
extracted_embeddings = []
|
||||
extracted_labels = []
|
||||
|
||||
print("Старт пакетной экстракции признаков...")
|
||||
|
||||
with torch.no_grad():
|
||||
for imgs, labels in tqdm(test_loader, desc="Экстракция"):
|
||||
imgs = imgs.to(device)
|
||||
|
||||
# Получение вектора [BATCH_SIZE, 2048]
|
||||
embeddings_batch = feature_extractor(imgs)
|
||||
|
||||
extracted_embeddings.append(embeddings_batch.cpu().numpy())
|
||||
extracted_labels.append(labels.numpy())
|
||||
|
||||
# Агрегация батчей в единые массивы
|
||||
np_embeddings = np.concatenate(extracted_embeddings, axis=0)
|
||||
np_labels = np.concatenate(extracted_labels, axis=0)
|
||||
|
||||
print(f"Размерность матрицы признаков: {np_embeddings.shape}")
|
||||
|
||||
# Сохранение артефактов
|
||||
np.save("./src/emoset_test_embeddings.npy", np_embeddings)
|
||||
np.save("./src/emoset_test_labels.npy", np_labels)
|
||||
print("Матрицы успешно экспортированы в .npy файлы.")
|
||||
|
||||
# Генерация медиа для диссертации
|
||||
plot_tsne(np_embeddings, np_labels, test_ds.idx2label, sample_limit=3000)
|
||||
|
||||
print("Процесс полностью завершен.")
|
||||
@@ -1,264 +0,0 @@
|
||||
import os
|
||||
import gc
|
||||
import pickle
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as T
|
||||
import torchvision.io as tv_io
|
||||
from torch.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
import timm
|
||||
|
||||
# Конфигурация стенда и путей файловой системы
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
|
||||
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
|
||||
|
||||
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
|
||||
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
|
||||
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")
|
||||
|
||||
CLASS_MAPPING = {
|
||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
|
||||
}
|
||||
|
||||
# Гиперпараметры конвейера обучения
|
||||
BATCH_SIZE = 82
|
||||
EPOCHS = 15
|
||||
LR = 5e-5
|
||||
NUM_TRAIN_WORKERS = 48
|
||||
NUM_VAL_WORKERS = 18
|
||||
PATIENCE = 4
|
||||
|
||||
def prepare_dataset_index():
|
||||
# Построение или загрузка индекса файлов для минимизации I/O операций по сети (NFS)
|
||||
if CACHE_PATH.exists():
|
||||
print(f"Загрузка карты файловой системы из кэша: {CACHE_PATH.name}")
|
||||
with open(CACHE_PATH, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
return cache_data['image_paths'], cache_data['labels']
|
||||
|
||||
print(f"Сканирование сетевой директории {DATA_ROOT} (первичная индексация)...")
|
||||
paths, labels = [], []
|
||||
for img_path in DATA_ROOT.rglob('*.jpg'):
|
||||
emotion_folder = img_path.parts[-3].lower()
|
||||
if emotion_folder in CLASS_MAPPING:
|
||||
paths.append(str(img_path))
|
||||
labels.append(CLASS_MAPPING[emotion_folder])
|
||||
|
||||
with open(CACHE_PATH, 'wb') as f:
|
||||
pickle.dump({'image_paths': paths, 'labels': labels}, f)
|
||||
|
||||
return paths, labels
|
||||
|
||||
class EmoSetDirectDataset(Dataset):
|
||||
# Датасет с отложенной аугментацией: декодирование на CPU, трансформации на GPU
|
||||
def __init__(self, image_paths, labels):
|
||||
self.image_paths = image_paths
|
||||
self.labels = labels
|
||||
self.base_transform = T.Resize((256, 256), antialias=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
image = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
|
||||
image = image.to(torch.float32) / 255.0
|
||||
image = self.base_transform(image)
|
||||
except Exception:
|
||||
# Изолирование сбоев ввода-вывода (поврежденные файлы на сетевом диске)
|
||||
image = torch.zeros((3, 256, 256), dtype=torch.float32)
|
||||
return image, self.labels[idx]
|
||||
|
||||
def build_gpu_transforms():
|
||||
# Перенос матричных операций аугментации на тензорные ядра видеокарты
|
||||
train_tf = torch.nn.Sequential(
|
||||
T.RandomCrop((224, 224)),
|
||||
T.RandomHorizontalFlip(p=0.5),
|
||||
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
).to(DEVICE)
|
||||
|
||||
val_tf = torch.nn.Sequential(
|
||||
T.CenterCrop((224, 224)),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
).to(DEVICE)
|
||||
|
||||
return train_tf, val_tf
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Инициализация конвейера обучения. Устройство: {DEVICE}")
|
||||
|
||||
all_paths, all_labels = prepare_dataset_index()
|
||||
|
||||
# Фиксация сида для детерминированного разделения выборок при перезапусках скрипта
|
||||
random.seed(42)
|
||||
combined = list(zip(all_paths, all_labels))
|
||||
random.shuffle(combined)
|
||||
all_paths, all_labels = zip(*combined)
|
||||
|
||||
split_idx = int(len(all_paths) * 0.95)
|
||||
|
||||
train_loader = DataLoader(
|
||||
EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
|
||||
batch_size=BATCH_SIZE, shuffle=True,
|
||||
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
||||
prefetch_factor=2, persistent_workers=True
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
|
||||
batch_size=BATCH_SIZE, shuffle=False,
|
||||
num_workers=NUM_VAL_WORKERS, pin_memory=True,
|
||||
prefetch_factor=2, persistent_workers=True
|
||||
)
|
||||
|
||||
gpu_train_tf, gpu_val_tf = build_gpu_transforms()
|
||||
|
||||
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
||||
scaler = GradScaler()
|
||||
|
||||
best_val_loss = float('inf')
|
||||
epochs_no_improve = 0
|
||||
start_epoch = 1
|
||||
|
||||
# Инициализация механизма отказоустойчивости и интеграция весов
|
||||
if RESUME_CHECKPOINT.exists():
|
||||
print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}")
|
||||
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
if 'scaler_state_dict' in checkpoint: scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
||||
if 'best_val_loss' in checkpoint: best_val_loss = checkpoint['best_val_loss']
|
||||
start_epoch = checkpoint['epoch'] + 1
|
||||
elif PREVIOUS_WEIGHTS.exists():
|
||||
print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}")
|
||||
model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
|
||||
else:
|
||||
print("Веса не найдены. Инициализация с ImageNet.")
|
||||
model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)
|
||||
|
||||
try:
|
||||
for epoch in range(start_epoch, EPOCHS + 1):
|
||||
|
||||
# Проход по обучающей выборке
|
||||
model.train()
|
||||
running_loss, correct, total = 0.0, 0, 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]")
|
||||
for inputs, labels in pbar:
|
||||
try:
|
||||
inputs = inputs.to(DEVICE, non_blocking=True)
|
||||
labels = labels.to(DEVICE, non_blocking=True)
|
||||
inputs = gpu_train_tf(inputs)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Смешанная точность для экономии VRAM
|
||||
with autocast(device_type="cuda"):
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
running_loss += loss.item() * inputs.size(0)
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
||||
|
||||
except RuntimeError as memory_err:
|
||||
# Подавление пиковых скачков потребления VRAM
|
||||
if "out of memory" in str(memory_err).lower():
|
||||
if 'outputs' in locals(): del outputs
|
||||
if 'loss' in locals(): del loss
|
||||
torch.cuda.empty_cache()
|
||||
optimizer.zero_grad()
|
||||
continue
|
||||
raise memory_err
|
||||
|
||||
train_loss = running_loss / total if total > 0 else 0
|
||||
train_acc = correct / total if total > 0 else 0
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Проход по валидационной выборке
|
||||
model.eval()
|
||||
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||
|
||||
with torch.no_grad():
|
||||
for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False):
|
||||
val_inputs = val_inputs.to(DEVICE, non_blocking=True)
|
||||
val_labels = val_labels.to(DEVICE, non_blocking=True)
|
||||
val_inputs = gpu_val_tf(val_inputs)
|
||||
|
||||
with autocast(device_type="cuda"):
|
||||
val_outputs = model(val_inputs)
|
||||
v_loss = criterion(val_outputs, val_labels)
|
||||
|
||||
val_loss += v_loss.item() * val_inputs.size(0)
|
||||
_, val_predicted = val_outputs.max(1)
|
||||
val_total += val_labels.size(0)
|
||||
val_correct += val_predicted.eq(val_labels).sum().item()
|
||||
|
||||
epoch_val_loss = val_loss / val_total if val_total > 0 else 0
|
||||
epoch_val_acc = val_correct / val_total if val_total > 0 else 0
|
||||
|
||||
scheduler.step()
|
||||
print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
|
||||
|
||||
# Оценка критериев ранней остановки и сохранение состояния сессии
|
||||
if epoch_val_loss < best_val_loss:
|
||||
best_val_loss = epoch_val_loss
|
||||
epochs_no_improve = 0
|
||||
torch.save(model.state_dict(), str(SAVE_MODEL_PATH).replace(".pth", "_best.pth"))
|
||||
else:
|
||||
epochs_no_improve += 1
|
||||
if epochs_no_improve >= PATIENCE and epoch >= 15:
|
||||
print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.")
|
||||
break
|
||||
|
||||
# Атомарное сохранение контекста
|
||||
checkpoint_state = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'scaler_state_dict': scaler.state_dict(),
|
||||
'best_val_loss': best_val_loss
|
||||
}
|
||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
||||
gc.collect()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nВыполнение прервано пользователем (SIGINT).")
|
||||
print(f"Дамп памяти конвейера зафиксирован на эпохе {epoch}.")
|
||||
checkpoint_state = {
|
||||
'epoch': epoch, 'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(), 'scaler_state_dict': scaler.state_dict(),
|
||||
'best_val_loss': best_val_loss
|
||||
}
|
||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
||||
|
||||
else:
|
||||
if SAVE_MODEL_PATH.parent.exists():
|
||||
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
||||
print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}")
|
||||
if RESUME_CHECKPOINT.exists():
|
||||
RESUME_CHECKPOINT.unlink()
|
||||
@@ -1,96 +1,97 @@
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import joblib
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
# Калибровочные координаты центров эмоциональных классов в пространстве Рассела [1.0 - 9.0]
|
||||
EMOTION_TO_VA_COORDS = {
|
||||
0: (7.5, 6.5), # amusement
|
||||
1: (2.0, 8.0), # anger
|
||||
2: (6.5, 5.0), # awe
|
||||
3: (7.0, 3.0), # contentment
|
||||
4: (3.0, 6.0), # disgust
|
||||
5: (8.0, 8.0), # excitement
|
||||
6: (2.5, 7.5), # fear
|
||||
7: (2.0, 2.0), # sadness
|
||||
# 1. Настройка путей
|
||||
embeddings_path = Path("./src/emoset_test_embeddings.npy")
|
||||
csv_path = Path("./NFS/Thesis/Emoset/EmoSet-118K/test/labels.csv")
|
||||
model_path = Path("./src/music_engine/va_regressor.pkl")
|
||||
|
||||
output_dir = Path("./src/scripts/media")
|
||||
output_file = output_dir / "metrics_output.txt"
|
||||
|
||||
# 2. Корректный маппинг 8 классов EmoSet в шкалу DEAM [1.0, 9.0]
|
||||
# Формула перевода из [-1, 1] в [1, 9]: 5.0 + (X * 4.0)
|
||||
EMO_TO_VA = {
|
||||
"amusement": [8.2, 6.6], # Веселье (Высокий позитив, средняя энергия)
|
||||
"awe": [7.0, 7.4], # Восхищение (Позитив, высокая энергия)
|
||||
"contentment": [7.8, 3.4], # Умиротворение (Позитив, низкая энергия)
|
||||
"excitement": [8.2, 8.2], # Возбуждение (Макс. позитив, макс. энергия)
|
||||
"anger": [2.2, 7.8], # Гнев (Глубокий негатив, высокая энергия)
|
||||
"disgust": [2.6, 6.6], # Отвращение (Негатив, средняя энергия)
|
||||
"fear": [2.6, 8.2], # Страх (Негатив, максимальная энергия)
|
||||
"sadness": [2.2, 2.6] # Грусть (Глубокий негатив, низкая энергия)
|
||||
}
|
||||
|
||||
def evaluate_regression_model():
|
||||
# Инициализация путей к артефактам пайплайна
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent
|
||||
embeddings_path = base_dir / "src" / "emoset_test_embeddings.npy"
|
||||
labels_path = base_dir / "src" / "emoset_test_labels.npy"
|
||||
model_path = base_dir / "src" / "music_engine" / "va_regressor.pkl"
|
||||
|
||||
if not all(p.exists() for p in [embeddings_path, labels_path, model_path]):
|
||||
print("Отсутствуют необходимые артефакты для расчета метрик.")
|
||||
def generate_slide_metrics():
|
||||
print("[INFO] Загрузка тестовых артефактов...")
|
||||
|
||||
if not all(p.exists() for p in [embeddings_path, csv_path, model_path]):
|
||||
print("[ERROR] Проверьте наличие файлов данных или модели регрессора.")
|
||||
return
|
||||
|
||||
# Загрузка скрытых представлений и инициализация регрессора
|
||||
x_features = np.load(embeddings_path)
|
||||
y_discrete = np.load(labels_path)
|
||||
regression_pipeline = joblib.load(model_path)
|
||||
|
||||
# Маппинг дискретных меток в непрерывные координаты
|
||||
y_continuous = np.array([EMOTION_TO_VA_COORDS[label] for label in y_discrete])
|
||||
|
||||
# Изоляция тестовой выборки (сохранение детерминированности через random_state)
|
||||
_, x_test, _, y_test = train_test_split(x_features, y_continuous, test_size=0.2, random_state=42)
|
||||
|
||||
# Генерация предсказаний на отложенной выборке
|
||||
y_pred = regression_pipeline.predict(x_test)
|
||||
|
||||
# Расчет метрик качества регрессии (Mean Squared Error, R-squared)
|
||||
mse_valence = mean_squared_error(y_test[:, 0], y_pred[:, 0])
|
||||
r2_valence = r2_score(y_test[:, 0], y_pred[:, 0])
|
||||
|
||||
mse_arousal = mean_squared_error(y_test[:, 1], y_pred[:, 1])
|
||||
r2_arousal = r2_score(y_test[:, 1], y_pred[:, 1])
|
||||
|
||||
print("Метрики качества регрессионной модели на тестовой выборке:")
|
||||
print(f"Valence -> MSE: {mse_valence:.4f} | R^2: {r2_valence:.4f}")
|
||||
print(f"Arousal -> MSE: {mse_arousal:.4f} | R^2: {r2_arousal:.4f}")
|
||||
|
||||
# Построение диагностических диаграмм рассеяния (Scatter Plots)
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
||||
|
||||
# Конфигурация подграфика: Ось Валентности
|
||||
ax1.scatter(y_test[:, 0], y_pred[:, 0], alpha=0.3, color='#1f77b4', edgecolors='none', label='Прогноз регрессора')
|
||||
ax1.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
|
||||
ax1.set_title('Диаграмма рассеяния: Valence (Позитивность)', fontsize=14, fontweight='bold')
|
||||
ax1.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
|
||||
ax1.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax1.set_xlim(1, 9)
|
||||
ax1.set_ylim(1, 9)
|
||||
ax1.grid(True, linestyle='--', alpha=0.6)
|
||||
ax1.legend(loc='upper left', fontsize=10)
|
||||
|
||||
# Научное обоснование распределения данных для комиссии
|
||||
ax1.text(1.2, 8.2,
|
||||
'Формирование вертикальных кластеров\n'
|
||||
'обусловлено проекцией 8 дискретных\n'
|
||||
'базовых эмоций на непрерывную\n'
|
||||
'координатную плоскость.',
|
||||
fontsize=10, bbox=dict(facecolor='white', alpha=0.9, edgecolor='gray'))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Конфигурация подграфика: Ось Активности
|
||||
ax2.scatter(y_test[:, 1], y_pred[:, 1], alpha=0.3, color='#ff7f0e', edgecolors='none', label='Прогноз регрессора')
|
||||
ax2.plot([1, 9], [1, 9], 'r--', lw=2, label='Идеальное совпадение (x=y)')
|
||||
ax2.set_title('Диаграмма рассеяния: Arousal (Активность)', fontsize=14, fontweight='bold')
|
||||
ax2.set_xlabel('Эталонные значения (центры классов)', fontsize=12)
|
||||
ax2.set_ylabel('Непрерывные предсказания модели', fontsize=12)
|
||||
ax2.set_xlim(1, 9)
|
||||
ax2.set_ylim(1, 9)
|
||||
ax2.grid(True, linestyle='--', alpha=0.6)
|
||||
ax2.legend(loc='upper left', fontsize=10)
|
||||
# 3. Загрузка эмбеддингов и меток
|
||||
X_test = np.load(embeddings_path)
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('regression_metrics_plot.png', dpi=300, bbox_inches='tight')
|
||||
print("Диагностические графики экспортированы в regression_metrics_plot.png")
|
||||
if len(X_test) != len(df):
|
||||
print(f"[WARN] Корректировка размеров выборки: Эмбеддинги ({len(X_test)}) != Метки ({len(df)})")
|
||||
min_len = min(len(X_test), len(df))
|
||||
X_test = X_test[:min_len]
|
||||
df = df.iloc[:min_len]
|
||||
|
||||
y_test_list = [EMO_TO_VA.get(label.lower().strip(), [5.0, 5.0]) for label in df['label']]
|
||||
y_test = np.array(y_test_list)
|
||||
|
||||
# 4. Выполнение инференса
|
||||
print("[INFO] Выполнение инференса регрессора на скрытом пространстве признаков...")
|
||||
regressor = joblib.load(model_path)
|
||||
y_pred = regressor.predict(X_test)
|
||||
|
||||
# === БЛОК ДИАГНОСТИКИ ШКАЛЫ ===
|
||||
print("\n" + "-"*50)
|
||||
print(" ДИАГНОСТИКА ДИАПАЗОНОВ ЗНАЧЕНИЙ ".center(50))
|
||||
print("-"*50)
|
||||
print(f"Истинные (y_test) -> Мин: {y_test.min():.2f}, Макс: {y_test.max():.2f}, Среднее: {y_test.mean():.2f}")
|
||||
print(f"Предсказания (y_pred) -> Мин: {y_pred.min():.2f}, Макс: {y_pred.max():.2f}, Среднее: {y_pred.mean():.2f}")
|
||||
print("-"*50 + "\n")
|
||||
# ==============================
|
||||
|
||||
# 5. Расчет метрик
|
||||
mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
|
||||
r2_v = r2_score(y_test[:, 0], y_pred[:, 0])
|
||||
|
||||
mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
|
||||
r2_a = r2_score(y_test[:, 1], y_pred[:, 1])
|
||||
|
||||
mse_total = mean_squared_error(y_test, y_pred)
|
||||
r2_total = r2_score(y_test, y_pred)
|
||||
|
||||
# 6. Вывод и сохранение результатов
|
||||
table_content = f"""
|
||||
==================================================
|
||||
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
|
||||
==================================================
|
||||
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|
||||
|------------|--------------|--------------|---------------|
|
||||
| MSE | {mse_v:<12.4f} | {mse_a:<12.4f} | {mse_total:<13.4f} |
|
||||
| R² | {r2_v:<12.4f} | {r2_a:<12.4f} | {r2_total:<13.4f} |
|
||||
==================================================
|
||||
|
||||
Формула целевой функции для вставки на слайд (LaTeX):
|
||||
$$Score_{{final}} = D_{{emo}} + 4.0 \cdot Acoustic_{{penalty}}$$
|
||||
"""
|
||||
|
||||
print(table_content)
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(table_content)
|
||||
|
||||
print(f"[SUCCESS] Метрики успешно сохранены в файл: {output_file.absolute()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate_regression_model()
|
||||
generate_slide_metrics()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 313 KiB |
@@ -0,0 +1,12 @@
|
||||
|
||||
==================================================
|
||||
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
|
||||
==================================================
|
||||
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|
||||
|------------|--------------|--------------|---------------|
|
||||
| MSE | 1.5135 | 2.2743 | 1.8939 |
|
||||
| R² | 0.7927 | 0.4321 | 0.6124 |
|
||||
==================================================
|
||||
|
||||
Формула целевой функции для вставки на слайд (LaTeX):
|
||||
$$Score_{final} = D_{emo} + 4.0 \cdot Acoustic_{penalty}$$
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 243 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
@@ -0,0 +1,319 @@
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageFile
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as T
|
||||
from torch.amp import autocast, GradScaler
|
||||
import timm
|
||||
|
||||
# Подавление предупреждений и защита от битых "хвостов" JPEG
|
||||
warnings.filterwarnings("ignore")
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Устройство: {DEVICE}")
|
||||
|
||||
# --- ПУТИ ---
|
||||
TRAIN_ROOT = Path("./dataset/Original-2.41M")
|
||||
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train") # ЯКОРЬ (Чистые данные для обучения)
|
||||
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")
|
||||
|
||||
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")
|
||||
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
|
||||
PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
|
||||
|
||||
CLASS_MAPPING = {
|
||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
||||
"disgust": 4, "excitement": 5, "fear": 6, "sadness": 7
|
||||
}
|
||||
|
||||
# --- НАСТРОЙКИ ---
|
||||
TOTAL_BATCH_SIZE = 64
|
||||
BATCH_NOISY = 48 # 75% батча - новые данные 2.41M
|
||||
BATCH_ANCHOR = 16 # 25% батча - чистые якорные данные 118K
|
||||
|
||||
EPOCHS_PER_FOLDER = 15
|
||||
PATIENCE = 5
|
||||
LR = 1e-6
|
||||
NUM_TRAIN_WORKERS = 32
|
||||
NUM_VAL_WORKERS = 32
|
||||
|
||||
def worker_init_fn(worker_id):
|
||||
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
# --- 1. ТРАНСФОРМАЦИИ ---
|
||||
train_transform = T.Compose([
|
||||
T.Resize(256),
|
||||
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
val_transform = T.Compose([
|
||||
T.Resize(256),
|
||||
T.CenterCrop(224),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# --- 2. ДАТАСЕТЫ ---
|
||||
class ChunkTrainDataset(Dataset):
|
||||
def __init__(self, paths, transform):
|
||||
self.paths = paths
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
path = self.paths[idx]
|
||||
try:
|
||||
img = Image.open(path).convert('RGB')
|
||||
tensor = self.transform(img)
|
||||
label = CLASS_MAPPING.get(path.parts[-3].lower(), 0)
|
||||
return tensor, label
|
||||
except Exception:
|
||||
return torch.zeros((3, 224, 224)), 0
|
||||
|
||||
class CsvDataset(Dataset):
|
||||
def __init__(self, root, transform):
|
||||
self.root = Path(root)
|
||||
self.df = pd.read_csv(self.root / "labels.csv")
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
row = self.df.iloc[idx]
|
||||
path = self.root / "images" / row["filename"]
|
||||
try:
|
||||
img = Image.open(path).convert('RGB')
|
||||
tensor = self.transform(img)
|
||||
label = CLASS_MAPPING.get(row["label"].lower(), 0)
|
||||
return tensor, label
|
||||
except Exception:
|
||||
return torch.zeros((3, 224, 224)), 0
|
||||
|
||||
# --- 3. СБОР ДАННЫХ ---
|
||||
def prepare_chunks():
|
||||
print("\nСканирование датасета 2.41M...")
|
||||
chunk_dict = defaultdict(list)
|
||||
for path in TRAIN_ROOT.rglob('*.jpg'):
|
||||
emotion = path.parts[-3].lower()
|
||||
if emotion not in CLASS_MAPPING:
|
||||
continue
|
||||
folder_str = path.parts[-2]
|
||||
if folder_str.isdigit():
|
||||
chunk_dict[int(folder_str)].append(path)
|
||||
|
||||
sorted_chunks = sorted(chunk_dict.keys())
|
||||
print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
|
||||
return chunk_dict, sorted_chunks
|
||||
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
|
||||
if __name__ == "__main__":
|
||||
chunk_dict, sorted_chunks = prepare_chunks()
|
||||
|
||||
# Валидационный датасет (только чистые данные)
|
||||
val_loader = DataLoader(
|
||||
CsvDataset(VAL_118K_ROOT, val_transform),
|
||||
batch_size=TOTAL_BATCH_SIZE, shuffle=False,
|
||||
num_workers=NUM_VAL_WORKERS, pin_memory=True
|
||||
)
|
||||
|
||||
# ЯКОРНЫЙ ЗАГРУЗЧИК (Чистые данные для подмешивания)
|
||||
# Используем prefetch_factor и persistent_workers для устранения рывков CPU
|
||||
anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
|
||||
anchor_loader = DataLoader(
|
||||
anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
|
||||
num_workers=16, pin_memory=True, drop_last=True,
|
||||
prefetch_factor=2, persistent_workers=False
|
||||
)
|
||||
|
||||
# Инициализация модели
|
||||
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
|
||||
if PRETRAINED_PATH.exists():
|
||||
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
|
||||
print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")
|
||||
|
||||
# Размораживаем всю модель
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
# Дифференцированный оптимизатор
|
||||
backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
|
||||
fc_params = [p for n, p in model.named_parameters() if "fc" in n]
|
||||
|
||||
optimizer = torch.optim.AdamW([
|
||||
{'params': backbone_params, 'lr': LR}, # 1e-6: микро-шаг для основы
|
||||
{'params': fc_params, 'lr': LR * 10} # 1e-5: шаг для классификатора
|
||||
], weight_decay=1e-3)
|
||||
|
||||
# Label Smoothing помогает игнорировать мусор в разметке 2.41M
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
|
||||
scaler = GradScaler()
|
||||
|
||||
# --- ПАРАМЕТРЫ ВОССТАНОВЛЕНИЯ ---
|
||||
start_stage = 0
|
||||
start_epoch = 1
|
||||
best_val_loss = float('inf')
|
||||
|
||||
if RESUME_CHECKPOINT.exists():
|
||||
print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
|
||||
checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
try:
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
except Exception as e:
|
||||
print(f"Оптимизатор сброшен: {e}")
|
||||
|
||||
best_val_loss = checkpoint['best_val_loss']
|
||||
start_stage = checkpoint['stage']
|
||||
start_epoch = checkpoint['epoch'] + 1
|
||||
print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
|
||||
else:
|
||||
# --- ЗАМЕР EPOCH 0 (БАЗОВАЯ ТОЧНОСТЬ) ---
|
||||
# Выполняется только если мы начинаем с нуля
|
||||
print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
|
||||
model.eval()
|
||||
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||
with torch.no_grad():
|
||||
for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
with autocast(device_type="cuda"):
|
||||
outputs = model(inputs)
|
||||
v_loss = criterion(outputs, labels)
|
||||
val_loss += v_loss.item() * inputs.size(0)
|
||||
_, pred = outputs.max(1)
|
||||
val_total += labels.size(0)
|
||||
val_correct += pred.eq(labels).sum().item()
|
||||
|
||||
best_val_loss = val_loss / val_total
|
||||
baseline_acc = val_correct / val_total
|
||||
print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")
|
||||
|
||||
# ВОССТАНОВЛЕНИЕ НАКОПЛЕННЫХ ДАННЫХ
|
||||
current_train_paths = []
|
||||
for s in range(start_stage):
|
||||
current_train_paths.extend(chunk_dict[sorted_chunks[s]])
|
||||
|
||||
print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")
|
||||
|
||||
# ГЛАВНЫЙ ЦИКЛ ПО ПАПКАМ
|
||||
for stage in range(start_stage, len(sorted_chunks)):
|
||||
chunk_id = sorted_chunks[stage]
|
||||
print(f"\n{'='*50}")
|
||||
print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")
|
||||
|
||||
# Накопление и перемешивание
|
||||
current_train_paths.extend(chunk_dict[chunk_id])
|
||||
random.shuffle(current_train_paths)
|
||||
print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")
|
||||
|
||||
# ОСНОВНОЙ ЗАГРУЗЧИК (Грязные данные) с PREFETCH
|
||||
train_loader = DataLoader(
|
||||
ChunkTrainDataset(current_train_paths, train_transform),
|
||||
batch_size=BATCH_NOISY, shuffle=True,
|
||||
num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
|
||||
worker_init_fn=worker_init_fn, drop_last=True,
|
||||
prefetch_factor=4, persistent_workers=True # Устраняет рывки CPU
|
||||
)
|
||||
|
||||
epochs_no_improve = 0
|
||||
first_epoch = start_epoch if stage == start_stage else 1
|
||||
|
||||
# Инициализация итератора якорей
|
||||
anchor_iter = iter(anchor_loader)
|
||||
|
||||
# ЦИКЛ ЭПОХ ДЛЯ ТЕКУЩЕГО ЭТАПА
|
||||
for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
|
||||
model.train()
|
||||
train_loss, train_correct, train_total = 0.0, 0, 0
|
||||
|
||||
for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):
|
||||
|
||||
# Достаем якорный чистый батч
|
||||
try:
|
||||
anc_inputs, anc_labels = next(anchor_iter)
|
||||
except StopIteration:
|
||||
anchor_iter = iter(anchor_loader)
|
||||
anc_inputs, anc_labels = next(anchor_iter)
|
||||
|
||||
# СМЕШИВАЕМ БАТЧИ (Грязные + Чистые)
|
||||
inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
|
||||
labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
with autocast(device_type="cuda"):
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
train_loss += loss.item() * inputs.size(0)
|
||||
_, pred = outputs.max(1)
|
||||
train_total += labels.size(0)
|
||||
train_correct += pred.eq(labels).sum().item()
|
||||
|
||||
# ВАЛИДАЦИЯ
|
||||
model.eval()
|
||||
val_loss, val_correct, val_total = 0.0, 0, 0
|
||||
with torch.no_grad():
|
||||
for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
with autocast(device_type="cuda"):
|
||||
outputs = model(inputs)
|
||||
v_loss = criterion(outputs, labels)
|
||||
val_loss += v_loss.item() * inputs.size(0)
|
||||
_, pred = outputs.max(1)
|
||||
val_total += labels.size(0)
|
||||
val_correct += pred.eq(labels).sum().item()
|
||||
|
||||
avg_train_loss = train_loss / train_total
|
||||
avg_train_acc = train_correct / train_total
|
||||
avg_val_loss = val_loss / val_total
|
||||
avg_val_acc = val_correct / val_total
|
||||
|
||||
print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")
|
||||
|
||||
# СОХРАНЕНИЕ ЛУЧШИХ ВЕСОВ
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
epochs_no_improve = 0
|
||||
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
||||
print("--> Обновлены лучшие веса")
|
||||
else:
|
||||
epochs_no_improve += 1
|
||||
|
||||
# АВАРИЙНОЕ СОХРАНЕНИЕ В КОНЦЕ ЭПОХИ
|
||||
checkpoint_state = {
|
||||
'stage': stage,
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'best_val_loss': best_val_loss
|
||||
}
|
||||
torch.save(checkpoint_state, RESUME_CHECKPOINT)
|
||||
os.sync() # Защита от отключения электричества
|
||||
print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")
|
||||
|
||||
# РАННЯЯ ОСТАНОВКА ДЛЯ ТЕКУЩЕГО ЭТАПА
|
||||
if epochs_no_improve >= PATIENCE:
|
||||
print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
|
||||
break
|
||||
|
||||
# Сброс счетчика стартовой эпохи после прохождения восстановительного этапа
|
||||
start_epoch = 1
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.8 MiB |
@@ -0,0 +1,461 @@
|
||||
.
|
||||
├── bin
|
||||
│ ├── activate
|
||||
│ ├── activate.csh
|
||||
│ ├── activate.fish
|
||||
│ ├── activate.nu
|
||||
│ ├── activate.ps1
|
||||
│ ├── activate_this.py
|
||||
│ ├── debugpy
|
||||
│ ├── debugpy-adapter
|
||||
│ ├── f2py
|
||||
│ ├── fonttools
|
||||
│ ├── httpx
|
||||
│ ├── ipython
|
||||
│ ├── ipython3
|
||||
│ ├── isympy
|
||||
│ ├── jlpm
|
||||
│ ├── jsonpointer
|
||||
│ ├── jsonschema
|
||||
│ ├── jupyter
|
||||
│ ├── jupyter-dejavu
|
||||
│ ├── jupyter-events
|
||||
│ ├── jupyter-execute
|
||||
│ ├── jupyter-kernel
|
||||
│ ├── jupyter-kernelspec
|
||||
│ ├── jupyter-lab
|
||||
│ ├── jupyter-labextension
|
||||
│ ├── jupyter-labhub
|
||||
│ ├── jupyter-migrate
|
||||
│ ├── jupyter-nbconvert
|
||||
│ ├── jupyter-run
|
||||
│ ├── jupyter-server
|
||||
│ ├── jupyter-troubleshoot
|
||||
│ ├── jupyter-trust
|
||||
│ ├── normalizer
|
||||
│ ├── numpy-config
|
||||
│ ├── pip
|
||||
│ ├── pip3
|
||||
│ ├── pip3.12
|
||||
│ ├── proton
|
||||
│ ├── proton-viewer
|
||||
│ ├── pybabel
|
||||
│ ├── pyftmerge
|
||||
│ ├── pyftsubset
|
||||
│ ├── pygmentize
|
||||
│ ├── pyjson5
|
||||
│ ├── python -> /usr/bin/python3
|
||||
│ ├── python3 -> python
|
||||
│ ├── python3.12 -> python
|
||||
│ ├── send2trash
|
||||
│ ├── streamlit
|
||||
│ ├── streamlit.cmd
|
||||
│ ├── torchfrtrace
|
||||
│ ├── torchrun
|
||||
│ ├── tqdm
|
||||
│ ├── ttx
|
||||
│ ├── watchmedo
|
||||
│ └── wsdump
|
||||
├── CACHEDIR.TAG
|
||||
├── docker
|
||||
│ ├── Dockerfile.api
|
||||
│ └── Dockerfile.ui
|
||||
├── docker-compose.yml
|
||||
├── Dockerfile
|
||||
├── .dockerignore
|
||||
├── .env
|
||||
├── etc
|
||||
│ └── jupyter
|
||||
│ ├── jupyter_notebook_config.d
|
||||
│ │ └── jupyterlab.json
|
||||
│ ├── jupyter_server_config.d
|
||||
│ │ ├── jupyterlab.json
|
||||
│ │ ├── jupyter-lsp-jupyter-server.json
|
||||
│ │ ├── jupyter_server_terminals.json
|
||||
│ │ └── notebook_shim.json
|
||||
│ └── nbconfig
|
||||
│ └── notebook.d
|
||||
├── .gitignore
|
||||
├── .idea
|
||||
│ ├── .gitignore
|
||||
│ ├── inspectionProfiles
|
||||
│ │ └── profiles_settings.xml
|
||||
│ ├── misc.xml
|
||||
│ ├── modules.xml
|
||||
│ ├── Thesis.iml
|
||||
│ ├── vcs.xml
|
||||
│ └── workspace.xml
|
||||
├── lib
|
||||
│ └── python3.12
|
||||
│ └── site-packages
|
||||
│ ├── altair
|
||||
│ ├── altair-6.0.0.dist-info
|
||||
│ ├── anyio
|
||||
│ ├── anyio-4.12.1.dist-info
|
||||
│ ├── argon2
|
||||
│ ├── argon2_cffi-25.1.0.dist-info
|
||||
│ ├── _argon2_cffi_bindings
|
||||
│ ├── argon2_cffi_bindings-25.1.0.dist-info
|
||||
│ ├── arrow
|
||||
│ ├── arrow-1.4.0.dist-info
|
||||
│ ├── asttokens
|
||||
│ ├── asttokens-3.0.1.dist-info
|
||||
│ ├── async_lru
|
||||
│ ├── async_lru-2.0.5.dist-info
|
||||
│ ├── attr
|
||||
│ ├── attrs
|
||||
│ ├── attrs-25.4.0.dist-info
|
||||
│ ├── babel
|
||||
│ ├── babel-2.17.0.dist-info
|
||||
│ ├── beautifulsoup4-4.14.3.dist-info
|
||||
│ ├── bleach
|
||||
│ ├── bleach-6.3.0.dist-info
|
||||
│ ├── blinker
|
||||
│ ├── blinker-1.9.0.dist-info
|
||||
│ ├── bs4
|
||||
│ ├── cachetools
|
||||
│ ├── cachetools-6.2.4.dist-info
|
||||
│ ├── certifi
|
||||
│ ├── certifi-2026.1.4.dist-info
|
||||
│ ├── cffi
|
||||
│ ├── cffi-2.0.0.dist-info
|
||||
│ ├── _cffi_backend.cpython-312-x86_64-linux-gnu.so
|
||||
│ ├── charset_normalizer
|
||||
│ ├── charset_normalizer-3.4.4.dist-info
|
||||
│ ├── click
|
||||
│ ├── click-8.3.1.dist-info
|
||||
│ ├── comm
|
||||
│ ├── comm-0.2.3.dist-info
|
||||
│ ├── contourpy
|
||||
│ ├── contourpy-1.3.3.dist-info
|
||||
│ ├── cycler
|
||||
│ ├── cycler-0.12.1.dist-info
|
||||
│ ├── dateutil
|
||||
│ ├── debugpy
|
||||
│ ├── debugpy-1.8.19.dist-info
|
||||
│ ├── decorator-5.2.1.dist-info
|
||||
│ ├── decorator.py
|
||||
│ ├── defusedxml
|
||||
│ ├── defusedxml-0.7.1.dist-info
|
||||
│ ├── _distutils_hack
|
||||
│ ├── distutils-precedence.pth
|
||||
│ ├── .DS_Store
|
||||
│ ├── executing
|
||||
│ ├── executing-2.2.1.dist-info
|
||||
│ ├── fastjsonschema
|
||||
│ ├── fastjsonschema-2.21.2.dist-info
|
||||
│ ├── filelock
|
||||
│ ├── filelock-3.20.3.dist-info
|
||||
│ ├── fontTools
|
||||
│ ├── fonttools-4.61.1.dist-info
|
||||
│ ├── fqdn
|
||||
│ ├── fqdn-1.5.1.dist-info
|
||||
│ ├── fsspec
|
||||
│ ├── fsspec-2026.1.0.dist-info
|
||||
│ ├── functorch
|
||||
│ ├── git
|
||||
│ ├── gitdb
|
||||
│ ├── gitdb-4.0.12.dist-info
|
||||
│ ├── gitpython-3.1.46.dist-info
|
||||
│ ├── google
|
||||
│ ├── h11
|
||||
│ ├── h11-0.16.0.dist-info
|
||||
│ ├── httpcore
|
||||
│ ├── httpcore-1.0.9.dist-info
|
||||
│ ├── httpx
|
||||
│ ├── httpx-0.28.1.dist-info
|
||||
│ ├── idna
|
||||
│ ├── idna-3.11.dist-info
|
||||
│ ├── ipykernel
|
||||
│ ├── ipykernel-7.1.0.dist-info
|
||||
│ ├── ipykernel_launcher.py
|
||||
│ ├── IPython
|
||||
│ ├── ipython-9.9.0.dist-info
|
||||
│ ├── ipython_pygments_lexers-1.1.1.dist-info
|
||||
│ ├── ipython_pygments_lexers.py
|
||||
│ ├── isoduration
|
||||
│ ├── isoduration-20.11.0.dist-info
|
||||
│ ├── isympy.py
|
||||
│ ├── jedi
|
||||
│ ├── jedi-0.19.2.dist-info
|
||||
│ ├── jinja2
|
||||
│ ├── jinja2-3.1.6.dist-info
|
||||
│ ├── joblib
|
||||
│ ├── joblib-1.5.3.dist-info
|
||||
│ ├── json5
|
||||
│ ├── json5-0.13.0.dist-info
|
||||
│ ├── jsonpointer-3.0.0.dist-info
|
||||
│ ├── jsonpointer.py
|
||||
│ ├── jsonschema
|
||||
│ ├── jsonschema-4.26.0.dist-info
|
||||
│ ├── jsonschema_specifications
|
||||
│ ├── jsonschema_specifications-2025.9.1.dist-info
|
||||
│ ├── jupyter_client
|
||||
│ ├── jupyter_client-8.8.0.dist-info
|
||||
│ ├── jupyter_core
|
||||
│ ├── jupyter_core-5.9.1.dist-info
|
||||
│ ├── jupyter_events
|
||||
│ ├── jupyter_events-0.12.0.dist-info
|
||||
│ ├── jupyterlab
|
||||
│ ├── jupyterlab-4.5.1.dist-info
|
||||
│ ├── jupyterlab_pygments
|
||||
│ ├── jupyterlab_pygments-0.3.0.dist-info
|
||||
│ ├── jupyterlab_server
|
||||
│ ├── jupyterlab_server-2.28.0.dist-info
|
||||
│ ├── jupyter_lsp
|
||||
│ ├── jupyter_lsp-2.3.0.dist-info
|
||||
│ ├── jupyter.py
|
||||
│ ├── jupyter_server
|
||||
│ ├── jupyter_server-2.17.0.dist-info
|
||||
│ ├── jupyter_server_terminals
|
||||
│ ├── jupyter_server_terminals-0.5.3.dist-info
|
||||
│ ├── kiwisolver
|
||||
│ ├── kiwisolver-1.4.9.dist-info
|
||||
│ ├── lark
|
||||
│ ├── lark-1.3.1.dist-info
|
||||
│ ├── markupsafe
|
||||
│ ├── markupsafe-3.0.3.dist-info
|
||||
│ ├── matplotlib
|
||||
│ ├── matplotlib-3.10.8.dist-info
|
||||
│ ├── matplotlib_inline
|
||||
│ ├── matplotlib_inline-0.2.1.dist-info
|
||||
│ ├── mistune
|
||||
│ ├── mistune-3.2.0.dist-info
|
||||
│ ├── mpl_toolkits
|
||||
│ ├── mpmath
|
||||
│ ├── mpmath-1.3.0.dist-info
|
||||
│ ├── narwhals
|
||||
│ ├── narwhals-2.15.0.dist-info
|
||||
│ ├── nbclient
|
||||
│ ├── nbclient-0.10.4.dist-info
|
||||
│ ├── nbconvert
|
||||
│ ├── nbconvert-7.16.6.dist-info
|
||||
│ ├── nbformat
|
||||
│ ├── nbformat-5.10.4.dist-info
|
||||
│ ├── nest_asyncio-1.6.0.dist-info
|
||||
│ ├── nest_asyncio.py
|
||||
│ ├── networkx
|
||||
│ ├── networkx-3.6.1.dist-info
|
||||
│ ├── notebook_shim
|
||||
│ ├── notebook_shim-0.2.4.dist-info
|
||||
│ ├── numpy
|
||||
│ ├── numpy-2.4.1.dist-info
|
||||
│ ├── numpy.libs
|
||||
│ ├── nvidia
|
||||
│ ├── nvidia_cublas_cu12-12.8.4.1.dist-info
|
||||
│ ├── nvidia_cuda_cupti_cu12-12.8.90.dist-info
|
||||
│ ├── nvidia_cuda_nvrtc_cu12-12.8.93.dist-info
|
||||
│ ├── nvidia_cuda_runtime_cu12-12.8.90.dist-info
|
||||
│ ├── nvidia_cudnn_cu12-9.10.2.21.dist-info
|
||||
│ ├── nvidia_cufft_cu12-11.3.3.83.dist-info
|
||||
│ ├── nvidia_cufile_cu12-1.13.1.3.dist-info
|
||||
│ ├── nvidia_curand_cu12-10.3.9.90.dist-info
|
||||
│ ├── nvidia_cusolver_cu12-11.7.3.90.dist-info
|
||||
│ ├── nvidia_cusparse_cu12-12.5.8.93.dist-info
|
||||
│ ├── nvidia_cusparselt_cu12-0.7.1.dist-info
|
||||
│ ├── nvidia_nccl_cu12-2.27.5.dist-info
|
||||
│ ├── nvidia_nvjitlink_cu12-12.8.93.dist-info
|
||||
│ ├── nvidia_nvshmem_cu12-3.3.20.dist-info
|
||||
│ ├── nvidia_nvtx_cu12-12.8.90.dist-info
|
||||
│ ├── packaging
|
||||
│ ├── packaging-25.0.dist-info
|
||||
│ ├── pandas
|
||||
│ ├── pandas-2.3.3.dist-info
|
||||
│ ├── pandocfilters-1.5.1.dist-info
|
||||
│ ├── pandocfilters.py
|
||||
│ ├── parso
|
||||
│ ├── parso-0.8.5.dist-info
|
||||
│ ├── pexpect
|
||||
│ ├── pexpect-4.9.0.dist-info
|
||||
│ ├── PIL
|
||||
│ ├── pillow-12.1.0.dist-info
|
||||
│ ├── pillow.libs
|
||||
│ ├── pip
|
||||
│ ├── pip-25.3.dist-info
|
||||
│ ├── pkg_resources
|
||||
│ ├── platformdirs
|
||||
│ ├── platformdirs-4.5.1.dist-info
|
||||
│ ├── prometheus_client
|
||||
│ ├── prometheus_client-0.23.1.dist-info
|
||||
│ ├── prompt_toolkit
|
||||
│ ├── prompt_toolkit-3.0.52.dist-info
|
||||
│ ├── protobuf-6.33.4.dist-info
|
||||
│ ├── psutil
|
||||
│ ├── psutil-7.2.1.dist-info
|
||||
│ ├── ptyprocess
|
||||
│ ├── ptyprocess-0.7.0.dist-info
|
||||
│ ├── pure_eval
|
||||
│ ├── pure_eval-0.2.3.dist-info
|
||||
│ ├── pyarrow
|
||||
│ ├── pyarrow-22.0.0.dist-info
|
||||
│ ├── pycparser
|
||||
│ ├── pycparser-2.23.dist-info
|
||||
│ ├── pydeck
|
||||
│ ├── pydeck-0.9.1.dist-info
|
||||
│ ├── pygments
|
||||
│ ├── pygments-2.19.2.dist-info
|
||||
│ ├── pylab.py
|
||||
│ ├── pyparsing
|
||||
│ ├── pyparsing-3.3.1.dist-info
|
||||
│ ├── python_dateutil-2.9.0.post0.dist-info
|
||||
│ ├── pythonjsonlogger
|
||||
│ ├── python_json_logger-4.0.0.dist-info
|
||||
│ ├── pytz
|
||||
│ ├── pytz-2025.2.dist-info
|
||||
│ ├── pyyaml-6.0.3.dist-info
|
||||
│ ├── pyzmq-27.1.0.dist-info
|
||||
│ ├── pyzmq.libs
|
||||
│ ├── referencing
|
||||
│ ├── referencing-0.37.0.dist-info
|
||||
│ ├── requests
|
||||
│ ├── requests-2.32.5.dist-info
|
||||
│ ├── rfc3339_validator-0.1.4.dist-info
|
||||
│ ├── rfc3339_validator.py
|
||||
│ ├── rfc3986_validator-0.1.1.dist-info
|
||||
│ ├── rfc3986_validator.py
|
||||
│ ├── rfc3987_syntax
|
||||
│ ├── rfc3987_syntax-1.1.0.dist-info
|
||||
│ ├── rpds
|
||||
│ ├── rpds_py-0.30.0.dist-info
|
||||
│ ├── scikit_learn-1.8.0.dist-info
|
||||
│ ├── scikit_learn.libs
|
||||
│ ├── scipy
|
||||
│ ├── scipy-1.17.0.dist-info
|
||||
│ ├── scipy.libs
|
||||
│ ├── send2trash
|
||||
│ ├── send2trash-2.0.0.dist-info
|
||||
│ ├── setuptools
|
||||
│ ├── setuptools-80.9.0.dist-info
|
||||
│ ├── six-1.17.0.dist-info
|
||||
│ ├── six.py
|
||||
│ ├── sklearn
|
||||
│ ├── smmap
|
||||
│ ├── smmap-5.0.2.dist-info
|
||||
│ ├── soupsieve
|
||||
│ ├── soupsieve-2.8.1.dist-info
|
||||
│ ├── stack_data
|
||||
│ ├── stack_data-0.6.3.dist-info
|
||||
│ ├── streamlit
|
||||
│ ├── streamlit-1.53.0.dist-info
|
||||
│ ├── sympy
|
||||
│ ├── sympy-1.14.0.dist-info
|
||||
│ ├── tenacity
|
||||
│ ├── tenacity-9.1.2.dist-info
|
||||
│ ├── terminado
|
||||
│ ├── terminado-0.18.1.dist-info
|
||||
│ ├── threadpoolctl-3.6.0.dist-info
|
||||
│ ├── threadpoolctl.py
|
||||
│ ├── tinycss2
|
||||
│ ├── tinycss2-1.4.0.dist-info
|
||||
│ ├── toml
|
||||
│ ├── toml-0.10.2.dist-info
|
||||
│ ├── torch
|
||||
│ ├── torch-2.9.1.dist-info
|
||||
│ ├── torchaudio
|
||||
│ ├── torchaudio-2.9.1.dist-info
|
||||
│ ├── torchgen
|
||||
│ ├── torchvision
|
||||
│ ├── torchvision-0.24.1.dist-info
|
||||
│ ├── torchvision.libs
|
||||
│ ├── tornado
|
||||
│ ├── tornado-6.5.4.dist-info
|
||||
│ ├── tqdm
|
||||
│ ├── tqdm-4.67.1.dist-info
|
||||
│ ├── traitlets
|
||||
│ ├── traitlets-5.14.3.dist-info
|
||||
│ ├── triton
|
||||
│ ├── triton-3.5.1.dist-info
|
||||
│ ├── typing_extensions-4.15.0.dist-info
|
||||
│ ├── typing_extensions.py
|
||||
│ ├── tzdata
|
||||
│ ├── tzdata-2025.3.dist-info
|
||||
│ ├── uri_template
|
||||
│ ├── uri_template-1.3.0.dist-info
|
||||
│ ├── urllib3
|
||||
│ ├── urllib3-2.6.3.dist-info
|
||||
│ ├── _virtualenv.pth
|
||||
│ ├── _virtualenv.py
|
||||
│ ├── watchdog
|
||||
│ ├── watchdog-6.0.0.dist-info
|
||||
│ ├── wcwidth
|
||||
│ ├── wcwidth-0.2.14.dist-info
|
||||
│ ├── webcolors
|
||||
│ ├── webcolors-25.10.0.dist-info
|
||||
│ ├── webencodings
|
||||
│ ├── webencodings-0.5.1.dist-info
|
||||
│ ├── websocket
|
||||
│ ├── websocket_client-1.9.0.dist-info
|
||||
│ ├── _yaml
|
||||
│ ├── yaml
|
||||
│ └── zmq
|
||||
├── Makefile
|
||||
├── NFS
|
||||
├── poetry.lock
|
||||
├── pyproject.toml
|
||||
├── pyvenv.cfg
|
||||
├── README.md
|
||||
├── requirements.txt
|
||||
├── runs
|
||||
├── share
|
||||
│ ├── applications
|
||||
│ │ └── jupyterlab.desktop
|
||||
│ ├── icons
|
||||
│ │ └── hicolor
|
||||
│ │ └── scalable
|
||||
│ ├── jupyter
|
||||
│ │ ├── kernels
|
||||
│ │ │ └── python3
|
||||
│ │ ├── lab
|
||||
│ │ │ ├── schemas
|
||||
│ │ │ ├── static
|
||||
│ │ │ └── themes
|
||||
│ │ ├── labextensions
|
||||
│ │ │ └── jupyterlab_pygments
|
||||
│ │ ├── nbconvert
|
||||
│ │ │ └── templates
|
||||
│ │ └── nbextensions
|
||||
│ │ └── pydeck
|
||||
│ └── man
|
||||
│ └── man1
|
||||
│ ├── ipython.1
|
||||
│ ├── isympy.1
|
||||
│ └── ttx.1
|
||||
├── src
|
||||
│ ├── 5_epoch_emoset_resnet50_finetuned_2.41M.pth
|
||||
│ ├── api.py
|
||||
│ ├── data_loader.py
|
||||
│ ├── dataset_paths_cache.pkl
|
||||
│ ├── emoset_resnet50_best.pth
|
||||
│ ├── emoset_resnet50_finetuned_2_41M_best.pth
|
||||
│ ├── emoset_resnet50_resume.pth
|
||||
│ ├── emoset_test_embeddings.npy
|
||||
│ ├── emoset_test_labels.npy
|
||||
│ ├── main.py
|
||||
│ ├── music_engine
|
||||
│ │ ├── image_processor.py
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── llm_bridge.py
|
||||
│ │ ├── matcher.py
|
||||
│ │ └── va_regressor.pkl
|
||||
│ ├── scripts
|
||||
│ │ ├── 00_setup_env.sh
|
||||
│ │ ├── 01_download_DEAM.py
|
||||
│ │ ├── 02_download_EmoSet.py
|
||||
│ │ ├── 11_prerp_DEAM.py
|
||||
│ │ ├── 20_bench_GPU.py
|
||||
│ │ ├── 21_train_images.ipynb
|
||||
│ │ ├── 22_extract_embeddings.ipynb
|
||||
│ │ ├── 23_aggregate_DEAM_timeline.py
|
||||
│ │ ├── 24_train_regressor.py
|
||||
│ │ ├── 31_finetune_2.41M.py
|
||||
│ │ ├── 90_acc_images_model.ipynb
|
||||
│ │ └── 91_generate_metrics.py
|
||||
│ └── tabs
|
||||
│ ├── tab_dataset.py
|
||||
│ └── tab_live.py
|
||||
├── tree.txt
|
||||
└── .vscode
|
||||
├── launch.json
|
||||
└── tasks.json
|
||||
|
||||
322 directories, 137 files
|
||||
Reference in New Issue
Block a user