Compare commits
1 Commits
main
..
c32a2544ff
| Author | SHA1 | Date | |
|---|---|---|---|
| c32a2544ff |
@@ -1,18 +0,0 @@
|
|||||||
bin/
|
|
||||||
lib/
|
|
||||||
share/
|
|
||||||
etc/
|
|
||||||
include/
|
|
||||||
pyvenv.cfg
|
|
||||||
.idea/
|
|
||||||
.vscode/
|
|
||||||
__pycache__/
|
|
||||||
*.pyc
|
|
||||||
.git/
|
|
||||||
runs/
|
|
||||||
dataset/
|
|
||||||
NFS/
|
|
||||||
*.pth
|
|
||||||
*.pkl
|
|
||||||
*.npy
|
|
||||||
.env
|
|
||||||
-25
@@ -1,25 +0,0 @@
|
|||||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime

# Skip .pyc generation and keep stdout/stderr unbuffered so container logs
# appear immediately.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

# System dependencies for OpenCV and image processing.
# --no-install-recommends keeps optional "recommended" packages out of the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python packages before copying the source so this layer is reused
# until requirements.txt itself changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source code
COPY src/ /app/src/

EXPOSE 8080

CMD ["streamlit", "run", "src/main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
.PHONY: up down logs restart status

# Build the images and start the containers in the background
up:
	docker compose up --build -d

# Stop and remove the containers
down:
	docker compose down

# Follow the logs in real time
logs:
	docker compose logs -f

# Quick restart. The two steps run as sub-makes in recipe order, so `down`
# always completes before `up` starts, even when make is invoked with -j.
restart:
	$(MAKE) down
	$(MAKE) up

# Show container status
status:
	docker compose ps
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
# The top-level `version` key is omitted on purpose: `docker compose` (v2)
# ignores it and prints an "obsolete" warning.

networks:
  emom_mesh:
    driver: bridge

services:
  # Streamlit front end, published on host port 8080.
  emom_ui:
    build:
      context: .
      dockerfile: docker/Dockerfile.ui
    container_name: emom_web_ui
    restart: unless-stopped
    ports:
      - "8080:8080"
    networks:
      - emom_mesh
    env_file:
      - .env
    volumes:
      # Source tree is bind-mounted over the copy baked into the image.
      - ./src:/app/src
      # `:?` aborts with a clear message instead of expanding to an empty source path.
      - "${DATA_DEAM_DIR:?DATA_DEAM_DIR must be set}:/app/dataset/DEAM:ro"
    depends_on:
      - emom_inference

  # Inference API; reachable only on the internal network (no published ports).
  emom_inference:
    build:
      context: .
      dockerfile: docker/Dockerfile.api
    container_name: emom_pytorch_api
    restart: unless-stopped
    networks:
      - emom_mesh
    env_file:
      - .env
    volumes:
      # Model artifacts are mounted read-only from the host instead of being baked into the image.
      - "${HOST_ARTIFACTS_DIR:?HOST_ARTIFACTS_DIR must be set}/emoset_resnet50_best.pth:/app/src/emoset_resnet50_best.pth:ro"
      - "${HOST_ARTIFACTS_DIR:?HOST_ARTIFACTS_DIR must be set}/music_engine/va_regressor.pkl:/app/src/music_engine/va_regressor.pkl:ro"
      - "${DATA_DEAM_DIR:?DATA_DEAM_DIR must be set}:/app/dataset/DEAM:ro"
      # Hugging Face cache is shared with the host, so downloads persist after the container is removed.
      - ~/.cache/huggingface:/root/.cache/huggingface
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # Ollama LLM server.
  emom_ollama:
    # NOTE(review): `latest` is a moving tag; consider pinning a specific
    # release for reproducible deployments.
    image: ollama/ollama:latest
    container_name: emom_ollama_engine
    restart: unless-stopped
    networks:
      - emom_mesh
    volumes:
      - ~/.ollama:/root/.ollama
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime

# Skip .pyc generation and keep stdout/stderr unbuffered so container logs
# appear immediately.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# --no-install-recommends keeps optional "recommended" packages out of the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

# joblib, pandas, scikit-learn and timm are pinned to the versions listed in
# the project's requirements.txt so this image resolves the same releases.
# NOTE(review): accelerate, fastapi, python-multipart and uvicorn have no pin
# anywhere in the visible sources; pin them once known-good versions are chosen.
RUN pip install --no-cache-dir \
        accelerate \
        fastapi \
        joblib==1.3.2 \
        pandas==2.2.1 \
        python-multipart \
        scikit-learn==1.4.1.post1 \
        timm==0.9.16 \
        tokenizers==0.15.2 \
        transformers==4.38.2 \
        uvicorn

WORKDIR /app

COPY src/ /app/src/

# Run from the source directory so `api:app` resolves as a top-level module.
WORKDIR /app/src

EXPOSE 8000

CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
FROM python:3.12-slim

# Skip .pyc generation and keep stdout/stderr unbuffered so container logs
# appear immediately.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

# pandas and requests are pinned to the versions listed in the project's
# requirements.txt.
# NOTE(review): pillow has no pin anywhere in the visible sources; pin it once
# a known-good version is chosen.
RUN pip install --no-cache-dir \
        pandas==2.2.1 \
        pillow \
        requests==2.31.0 \
        streamlit==1.32.0

COPY src/ /app/src/

# Run from the source directory so `main.py` is found by its relative path.
WORKDIR /app/src

EXPOSE 8080

CMD ["streamlit", "run", "main.py", "--server.port", "8080", "--server.address", "0.0.0.0"]
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
streamlit==1.32.0
|
|
||||||
torch==2.2.1
|
|
||||||
torchvision==0.17.1
|
|
||||||
timm==0.9.16
|
|
||||||
pandas==2.2.1
|
|
||||||
scikit-learn==1.4.1.post1
|
|
||||||
joblib==1.3.2
|
|
||||||
transformers==4.38.2
|
|
||||||
requests==2.31.0
|
|
||||||
-76
@@ -1,76 +0,0 @@
|
|||||||
import io
|
|
||||||
import traceback
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
from fastapi import FastAPI, UploadFile, File, HTTPException
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from data_loader import load_music_engine, load_image_processor
|
|
||||||
from music_engine.llm_bridge import LLMAcousticBridge
|
|
||||||
|
|
||||||
# FastAPI application object for the inference service.
app = FastAPI(title="EmoM API", version="1.0.0")

# Shared slots for the heavy ML objects. Every slot starts as None and is
# filled in by the startup handler.
ml_context = dict.fromkeys(("image_processor", "music_matcher", "llm_bridge"))
|
|
||||||
|
|
||||||
@app.on_event("startup")
async def startup_event():
    """Fill ``ml_context`` when the application starts.

    The image processor, the music matcher and the LLM bridge are built in
    that order, and each one is stored under its key as soon as it is ready.
    """
    print("Loading ML models...")
    factories = (
        ("image_processor", load_image_processor),
        ("music_matcher", load_music_engine),
        ("llm_bridge", LLMAcousticBridge),
    )
    for slot, build in factories:
        ml_context[slot] = build()
    print("Initialization complete.")
|
|
||||||
|
|
||||||
@app.post("/analyze")
async def analyze_event_endpoint(files: List[UploadFile] = File(...)):
    """Analyze a batch of uploaded images and return a matching playlist.

    Every upload is decoded to an RGB image. For each image an embedding is
    extracted and passed to the matcher's ``predict_va`` to obtain a
    (valence, arousal) pair, and a scene caption is generated. The per-image
    values are averaged, the distinct captions are sent to the LLM bridge to
    obtain an acoustic profile, and the matcher is asked for the 15 nearest
    tracks.

    Args:
        files: Uploaded image files (multipart form field ``files``).

    Returns:
        JSONResponse with ``status``, ``images_processed``, ``target_v``,
        ``target_a``, ``llm_profile``, ``semantics`` and ``tracks``.

    Raises:
        HTTPException: 400 when the batch is empty or an upload cannot be
            decoded as an image; 500 for any other failure.
    """
    # An empty batch would otherwise reach np.mean([]) and produce NaN targets.
    if not files:
        raise HTTPException(status_code=400, detail="No images were uploaded")

    try:
        images = []
        for file in files:
            image_bytes = await file.read()
            try:
                img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except OSError as e:
                # PIL's UnidentifiedImageError is an OSError subclass; a bad
                # upload is a client error, not a server fault.
                raise HTTPException(
                    status_code=400,
                    detail=f"Cannot decode image '{file.filename}': {e}",
                ) from e
            images.append(img)

        print(f"Processing batch: {len(images)} images.")

        img_processor = ml_context["image_processor"]
        matcher = ml_context["music_matcher"]
        llm = ml_context["llm_bridge"]

        all_v, all_a = [], []
        all_objects = []

        for img in images:
            embedding = img_processor.extract_embedding(img)
            v, a = matcher.predict_va(embedding)
            all_v.append(v)
            all_a.append(a)

            caption = img_processor.describe_scene(img)
            all_objects.append(caption)

        target_v = float(np.mean(all_v))
        target_a = float(np.mean(all_a))
        # Order-preserving de-duplication keeps the LLM context and the
        # response stable between identical requests (set() order is arbitrary).
        unique_semantics = list(dict.fromkeys(all_objects))

        llm_profile = llm.get_acoustic_profile(target_v, target_a, unique_semantics)

        playlist_df = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
        tracks_list = playlist_df.to_dict(orient="records")

        return JSONResponse(content={
            "status": "success",
            "images_processed": len(images),
            "target_v": target_v,
            "target_a": target_a,
            "llm_profile": llm_profile,
            "semantics": unique_semantics,
            "tracks": tracks_list
        })

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not be rewrapped as 500.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
+35
-27
@@ -1,46 +1,54 @@
|
|||||||
|
import streamlit as st
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, List, Optional, Any
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from music_engine.matcher import MusicMatcher
|
from music_engine.matcher import MusicMatcher
|
||||||
from music_engine.image_processor import ImageProcessor
|
from music_engine.image_processor import ImageProcessor
|
||||||
|
|
||||||
|
# Определяем базовую директорию (папка src)
|
||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
def load_music_engine() -> MusicMatcher:
|
@st.cache_resource
|
||||||
#Инициализация модуля подбора музыкальных композиций.
|
def load_music_engine():
|
||||||
|
"""Загрузка базы данных и модели регрессора."""
|
||||||
|
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
|
||||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||||
|
# va_regressor.pkl лежит в src/music_engine/
|
||||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||||
|
|
||||||
|
if not db_path.exists():
|
||||||
|
print(f"⚠️ Файл базы {db_path} не найден!")
|
||||||
|
return None
|
||||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||||
|
|
||||||
def load_image_processor() -> ImageProcessor:
|
@st.cache_resource
|
||||||
#Инициализация модуля экстракции визуальных признаков.
|
def load_image_processor():
|
||||||
weights_path = BASE_DIR / "emoset_resnet50_best.pth"
|
"""Загрузка ResNet-50 для извлечения признаков на лету."""
|
||||||
|
# Файл весов лежит в той же папке src, что и этот скрипт
|
||||||
|
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||||
|
|
||||||
return ImageProcessor(weights_path)
|
if not model_path.exists():
|
||||||
|
print(f"❌ КРИТИЧЕСКАЯ ОШИБКА: Веса не найдены по пути: {model_path}")
|
||||||
|
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
|
||||||
|
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||||
|
|
||||||
def load_emoset_data() -> Tuple[Optional[List[str]], Optional[np.ndarray], Optional[np.ndarray], Optional[Path]]:
|
return ImageProcessor(model_path=model_path)
|
||||||
# Загрузка тестовой выборки датасета EmoSet.
|
|
||||||
# Модуль сохранен для обеспечения обратной совместимости в отладочном контуре.
|
|
||||||
try:
|
|
||||||
images_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
|
||||||
labels_path = BASE_DIR / "emoset_test_labels.npy"
|
|
||||||
embeddings_path = BASE_DIR / "emoset_test_embeddings.npy"
|
|
||||||
|
|
||||||
if not all(p.exists() for p in [labels_path, embeddings_path]):
|
|
||||||
return None, None, None, None
|
|
||||||
|
|
||||||
labels = np.load(labels_path)
|
|
||||||
embeddings = np.load(embeddings_path)
|
|
||||||
|
|
||||||
|
@st.cache_data
|
||||||
|
def load_emoset_data():
|
||||||
|
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
|
||||||
|
# Пути относительно корня проекта
|
||||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
||||||
df = pd.read_csv(csv_path)
|
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
||||||
|
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||||
|
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
||||||
|
|
||||||
return df['filename'].tolist(), embeddings, labels, images_path
|
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WARN] Failed to load EmoSet test artifacts: {str(e)}")
|
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
image_list = df['filename'].tolist()
|
||||||
|
embs = np.load(emb_path)
|
||||||
|
lbls = np.load(lbl_path)
|
||||||
|
|
||||||
|
return image_list, embs, lbls, img_dir
|
||||||
Binary file not shown.
+39
-185
@@ -1,189 +1,43 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import streamlit.components.v1 as components
|
import sys
|
||||||
from PIL import Image
|
import os
|
||||||
import base64
|
import subprocess
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
st.set_page_config(page_title="EmoM Playlist Generator", layout="wide", initial_sidebar_state="collapsed")
|
from data_loader import load_music_engine, load_emoset_data, load_image_processor
|
||||||
|
from tabs.tab_dataset import render_dataset_tab
|
||||||
API_URL = os.getenv("BACKEND_API_URL", "http://emom_inference:8000") + "/analyze"
|
from tabs.tab_live import render_live_tab
|
||||||
DEAM_AUDIO_DIR = "/app/dataset/DEAM/DEAM_audio/MEMD_audio"
|
|
||||||
|
|
||||||
def get_thumbnail_html(images, max_display=12):
|
|
||||||
html_images = ""
|
|
||||||
for file in images[:max_display]:
|
|
||||||
img = Image.open(file)
|
|
||||||
img.thumbnail((100, 100))
|
|
||||||
if img.mode != "RGB":
|
|
||||||
img = img.convert("RGB")
|
|
||||||
buffered = BytesIO()
|
|
||||||
img.save(buffered, format="JPEG")
|
|
||||||
b64_str = base64.b64encode(buffered.getvalue()).decode()
|
|
||||||
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
|
|
||||||
|
|
||||||
if len(images) > max_display:
|
|
||||||
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
|
|
||||||
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
|
|
||||||
|
|
||||||
def main():
|
|
||||||
if "live_state" not in st.session_state:
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
if "result_data" not in st.session_state:
|
|
||||||
st.session_state.result_data = None
|
|
||||||
|
|
||||||
viewport = st.query_params.get("viewport", "desktop")
|
|
||||||
|
|
||||||
st.markdown("""
|
|
||||||
<style>
|
|
||||||
[data-testid="stFileUploadDropzone"] { min-height: 250px !important; display: flex; align-items: center; justify-content: center; border-radius: 16px; background-color: rgba(255, 75, 75, 0.03); }
|
|
||||||
.spinner-container { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 40vh; margin-top: 10vh; }
|
|
||||||
.big-spinner { width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1); border-top: 10px solid #ff4b4b; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 2rem; }
|
|
||||||
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
|
|
||||||
#MainMenu {visibility: hidden;} footer {visibility: hidden;}
|
|
||||||
</style>
|
|
||||||
""", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
if st.session_state.live_state == "upload":
|
|
||||||
upload_placeholder = st.empty()
|
|
||||||
with upload_placeholder.container():
|
|
||||||
st.write("Загрузите изображения для визуально-семантического анализа.")
|
|
||||||
if viewport == "mobile":
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
uploaded_files = st.file_uploader(
|
|
||||||
"Загрузка файлов",
|
|
||||||
type=['png', 'jpg', 'jpeg'],
|
|
||||||
accept_multiple_files=True,
|
|
||||||
label_visibility="collapsed" if viewport == "mobile" else "visible"
|
|
||||||
)
|
|
||||||
|
|
||||||
if uploaded_files:
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
if st.button("Выполнить анализ", type="primary", use_container_width=True):
|
|
||||||
st.session_state.uploaded_images = uploaded_files
|
|
||||||
st.session_state.live_state = "processing"
|
|
||||||
upload_placeholder.empty()
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
st.caption("Выбранные файлы:")
|
|
||||||
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
|
|
||||||
|
|
||||||
elif st.session_state.live_state == "processing":
|
|
||||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
|
||||||
files = st.session_state.get("uploaded_images", [])
|
|
||||||
st.markdown('<div class="spinner-container"><div class="big-spinner"></div><h3 style="text-align: center; font-weight: 400;">Обработка данных...</h3></div>', unsafe_allow_html=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
upload_data = [('files', (f.name, f.getvalue(), f.type)) for f in files]
|
|
||||||
response = requests.post(API_URL, files=upload_data, timeout=300)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.session_state.result_data = response.json()
|
|
||||||
st.session_state.live_state = "result"
|
|
||||||
st.rerun()
|
|
||||||
else:
|
|
||||||
st.error(f"Ошибка сервера: {response.status_code}")
|
|
||||||
if st.button("Назад"):
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
st.rerun()
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Ошибка соединения: {str(e)}")
|
|
||||||
if st.button("Назад"):
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
elif st.session_state.live_state == "result":
|
|
||||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
|
||||||
data = st.session_state.result_data
|
|
||||||
|
|
||||||
st.header(f"Сгенерированный плейлист (обработано файлов: {data['images_processed']})")
|
|
||||||
|
|
||||||
for row in data.get("tracks", []):
|
|
||||||
with st.container(border=True):
|
|
||||||
song_id = int(row['song_id'])
|
|
||||||
score = row['final_score']
|
|
||||||
|
|
||||||
audio_path = f"{DEAM_AUDIO_DIR}/{song_id}.mp3"
|
|
||||||
if not os.path.exists(audio_path):
|
|
||||||
audio_path = audio_path.replace('.mp3', '.wav')
|
|
||||||
|
|
||||||
if viewport == "desktop":
|
|
||||||
c1, c2 = st.columns([1, 3])
|
|
||||||
with c1:
|
|
||||||
st.write(f"**Track ID:** {song_id}")
|
|
||||||
st.caption(f"Score: {score:.4f}")
|
|
||||||
with c2:
|
|
||||||
if os.path.exists(audio_path):
|
|
||||||
st.audio(audio_path)
|
|
||||||
else:
|
|
||||||
st.caption("Аудиофайл не найден")
|
|
||||||
else:
|
|
||||||
st.write(f"**Track ID:** {song_id} (Score: {score:.4f})")
|
|
||||||
if os.path.exists(audio_path):
|
|
||||||
st.audio(audio_path)
|
|
||||||
else:
|
|
||||||
st.caption("Аудиофайл не найден")
|
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
with st.expander("Отладочная информация (Метрики)"):
|
|
||||||
st.subheader("Координаты V/A")
|
|
||||||
c_v, c_a = st.columns(2)
|
|
||||||
c_v.metric("Valence", f"{data['target_v']:.2f}")
|
|
||||||
c_a.metric("Arousal", f"{data['target_a']:.2f}")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
st.subheader("Акустические признаки (LLM)")
|
|
||||||
|
|
||||||
feature_titles = {
|
|
||||||
"energy": "RMS Energy",
|
|
||||||
"flux": "Spectral Flux",
|
|
||||||
"centroid": "Spectral Centroid",
|
|
||||||
"pitch": "F0 (Pitch)",
|
|
||||||
"hnr": "HNR",
|
|
||||||
"zcr": "ZCR"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Развернутые описания
|
|
||||||
feature_helps = {
|
|
||||||
"energy": "Среднеквадратичная амплитуда (громкость). Бывает высокой в плотных, интенсивных композициях, отражает общую акустическую энергию сцены.",
|
|
||||||
"flux": "Спектральный поток. Измеряет резкость изменений в спектре. Высок при четком, агрессивном ритме и частой смене нот.",
|
|
||||||
"centroid": "Спектральный центроид («яркость» звука). Высокие значения указывают на преобладание высоких частот (звонкие инструменты, открытые пространства).",
|
|
||||||
"pitch": "Основная частота звука. Высокий pitch характерен для позитивных, легких или, напротив, напряженных мелодий.",
|
|
||||||
"hnr": "Отношение гармоник к шуму. Высокий HNR — чистая мелодия и вокал. Низкий HNR — присутствие дисторшна, шумов или перкуссии.",
|
|
||||||
"zcr": "Частота пересечения нуля. Отражает шумовую составляющую сигнала. Высок в треках с выраженными ударными (hi-hats) или атмосферным шумом."
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_profile = data.get("llm_profile")
|
|
||||||
if llm_profile and isinstance(llm_profile, dict) and len(llm_profile) > 0:
|
|
||||||
cols_per_row = 2 if viewport == "mobile" else 3
|
|
||||||
llm_items = list(llm_profile.items())
|
|
||||||
|
|
||||||
for i in range(0, len(llm_items), cols_per_row):
|
|
||||||
cols = st.columns(cols_per_row)
|
|
||||||
for j in range(cols_per_row):
|
|
||||||
if i + j < len(llm_items):
|
|
||||||
k, v = llm_items[i + j]
|
|
||||||
label = feature_titles.get(k, k)
|
|
||||||
tooltip = feature_helps.get(k, "")
|
|
||||||
cols[j].metric(label, f"{v:.2f}", help=tooltip)
|
|
||||||
else:
|
|
||||||
st.caption("Акустический профиль недоступен. Применен fallback-алгоритм.")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
st.write("**Извлеченные теги (BLIP-2):**")
|
|
||||||
st.write(", ".join([str(c).capitalize() for c in data.get("semantics", [])]))
|
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
if st.button("Новый запрос", use_container_width=True):
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
st.session_state.result_data = None
|
|
||||||
st.session_state.pop("uploaded_images", None)
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 1️⃣ Запуск приложения
|
||||||
|
# ----------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
if "STREAMLIT_RUN" not in os.environ:
|
||||||
|
os.environ["STREAMLIT_RUN"] = "1"
|
||||||
|
cmd = [sys.executable, "-m", "streamlit", "run", __file__, "--server.port", "8080", "--server.address", "0.0.0.0"]
|
||||||
|
subprocess.run(cmd)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
st.set_page_config(page_title="Thesis Demo", layout="wide")
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 2️⃣ Инициализация движка и данных
|
||||||
|
# ----------------------------
|
||||||
|
matcher = load_music_engine()
|
||||||
|
image_processor = load_image_processor()
|
||||||
|
image_files, embeddings, labels_array, images_path = load_emoset_data()
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 3️⃣ Интерфейс и Вкладки
|
||||||
|
# ----------------------------
|
||||||
|
st.title("🖼️ Генератор саундтреков (Research Demo)")
|
||||||
|
|
||||||
|
tab1, tab2 = st.tabs(["📊 Отладка (Датасет EmoSet)", "📸 Анализ событий (Свои фото)"])
|
||||||
|
|
||||||
|
with tab1:
|
||||||
|
render_dataset_tab(matcher, image_files, embeddings, labels_array, images_path)
|
||||||
|
|
||||||
|
with tab2:
|
||||||
|
if image_processor:
|
||||||
|
render_live_tab(matcher, image_processor)
|
||||||
|
else:
|
||||||
|
st.error("Система обработки изображений недоступна (не найдены веса ResNet).")
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 851 KiB |
@@ -1,66 +1,57 @@
|
|||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image
|
||||||
import timm
|
import timm
|
||||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# НОВЫЙ ИМПОРТ ДЛЯ VLM
|
||||||
|
from transformers import BlipProcessor, BlipForConditionalGeneration
|
||||||
|
|
||||||
class ImageProcessor:
|
class ImageProcessor:
|
||||||
def __init__(self, weights_path: str | Path):
|
def __init__(self, model_path: Path | str):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
# Модель извлечения визуальных признаков
|
# --- ПОТОК 1: ЭМОЦИИ (ResNet-50) ---
|
||||||
self.feature_extractor = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
print("⏳ Загрузка эмоционального модуля (ResNet-50)...")
|
||||||
|
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||||
|
if Path(model_path).exists():
|
||||||
|
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||||
|
self.emo_model.fc = torch.nn.Identity()
|
||||||
|
self.emo_model.to(self.device).eval()
|
||||||
|
|
||||||
if Path(weights_path).exists():
|
self.emo_transform = T.Compose([
|
||||||
self.feature_extractor.load_state_dict(torch.load(weights_path, map_location=self.device))
|
|
||||||
else:
|
|
||||||
print(f"Не удалось найти веса ResNet по пути: {weights_path}")
|
|
||||||
|
|
||||||
# Удаление слоя классификации для вывода сырого вектора эмбеддингов
|
|
||||||
self.feature_extractor.fc = torch.nn.Identity()
|
|
||||||
self.feature_extractor.to(self.device).eval()
|
|
||||||
|
|
||||||
# Трансформации для предварительной обработки изображений
|
|
||||||
self.preprocess_image = T.Compose([
|
|
||||||
T.Resize((224, 224)),
|
T.Resize((224, 224)),
|
||||||
T.ToTensor(),
|
T.ToTensor(),
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
# Модуль семантического описания сцены
|
# --- ПОТОК 2: СЕМАНТИКА И КОНТЕКСТ (BLIP Large) ---
|
||||||
print("Инициализация BLIP-2...")
|
print("⏳ Загрузка мощной VLM модели (BLIP) для описания сцен...")
|
||||||
# Обход бага конфигурации Hugging Face (ручная сборка процессора)
|
# Используем версию Large, так как позволяет железо V100
|
||||||
from transformers import BlipImageProcessor, AutoTokenizer
|
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
||||||
img_proc = BlipImageProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
|
||||||
tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False)
|
|
||||||
self.blip_processor = Blip2Processor(image_processor=img_proc, tokenizer=tok)
|
print("✅ Обе нейросети визуального анализа успешно загружены на V100!")
|
||||||
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
|
||||||
"Salesforce/blip2-opt-2.7b",
|
|
||||||
torch_dtype=torch.float16
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||||
# Извлечение эмбеддингов из изображения
|
"""Извлекает 2048-мерный вектор эмоций."""
|
||||||
rgb_image = image.convert('RGB')
|
img_rgb = image.convert('RGB')
|
||||||
img_tensor = self.preprocess_image(rgb_image).unsqueeze(0).to(self.device)
|
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
|
||||||
|
return self.emo_model(img_tensor).cpu().numpy().flatten()
|
||||||
features = self.feature_extractor(img_tensor)
|
|
||||||
features_np = features.cpu().numpy()
|
|
||||||
|
|
||||||
return features_np.flatten()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def describe_scene(self, image: Image.Image) -> str:
|
def describe_scene(self, image: Image.Image) -> str:
|
||||||
# Генерация текстового описания сцены
|
"""Генерирует текстовое описание картинки (Captioning) для LLM."""
|
||||||
rgb_image = image.convert('RGB')
|
img_rgb = image.convert('RGB')
|
||||||
|
|
||||||
inputs = self.blip_processor(images=rgb_image, return_tensors="pt").to(self.device, torch.float16)
|
# Готовим картинку для BLIP
|
||||||
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
|
inputs = self.blip_processor(img_rgb, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
scene_description = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
# Генерируем описание (max_new_tokens ограничим, чтобы было лаконично)
|
||||||
|
out = self.blip_model.generate(**inputs, max_new_tokens=30)
|
||||||
|
|
||||||
return scene_description.strip()
|
# Декодируем тензор в строку
|
||||||
|
caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
||||||
|
return caption
|
||||||
@@ -1,65 +1,60 @@
|
|||||||
import os
|
import requests
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import requests
|
|
||||||
|
|
||||||
class LLMAcousticBridge:
    """Ask a local Ollama model to map a scene to six acoustic features.

    A prompt is built from the valence/arousal scores and the textual scene
    context, posted to Ollama's ``/api/generate`` endpoint in JSON mode, and
    the answer is parsed and validated before being handed to the caller.
    """

    # Feature names requested in the prompt; the model is asked for floats in [0.0, 1.0].
    FEATURE_KEYS = ("energy", "flux", "centroid", "pitch", "hnr", "zcr")

    def __init__(self, model_name="dolphin-llama3:8b", host=None):
        """
        Args:
            model_name: Ollama model tag to query.
            host: Base URL of the Ollama server. When omitted, the
                ``OLLAMA_API_URL`` environment variable is used, falling back
                to ``http://emom_ollama:11434``.
        """
        self.model_name = model_name
        base_url = host or os.getenv("OLLAMA_API_URL", "http://emom_ollama:11434")
        # Strip a trailing slash so the endpoint never becomes ".../​/api/generate".
        self.api_url = f"{base_url.rstrip('/')}/api/generate"

    def _clean_json(self, text):
        """Extract a JSON object from raw model output.

        Tries the text as-is first, then the outermost ``{...}`` span found by
        a regular expression. Returns the parsed dict, or None when no JSON
        object can be recovered.
        """
        if not isinstance(text, str):
            return None
        candidates = [text]
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match:
            candidates.append(match.group(0))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def _normalize_profile(self, profile):
        """Keep only the known features whose values are finite numbers.

        Values are clamped to [0.0, 1.0], the range requested in the prompt.
        Unknown keys and non-numeric values are dropped. Returns an empty dict
        when nothing usable remains.
        """
        if not isinstance(profile, dict):
            return {}
        cleaned = {}
        for key in self.FEATURE_KEYS:
            try:
                number = float(profile[key])
            except (KeyError, TypeError, ValueError):
                continue
            if number != number or number in (float("inf"), float("-inf")):
                continue  # NaN / infinity cannot be used as a target value
            cleaned[key] = min(1.0, max(0.0, number))
        return cleaned

    def get_acoustic_profile(self, valence, arousal, semantics):
        """Request an acoustic profile for the given emotion scores and scene context.

        Args:
            valence: Valence score, formatted into the prompt on a /9.0 scale.
            arousal: Arousal score, formatted into the prompt on a /9.0 scale.
            semantics: Iterable of scene descriptions (a single string is also
                accepted). Empty/None falls back to "abstract scene".

        Returns:
            dict mapping feature name to a float in [0.0, 1.0]; an empty dict
            on any network, HTTP or parsing failure (best effort, never raises
            for those cases).
        """
        if isinstance(semantics, str):
            # A bare string would otherwise be joined character by character.
            semantics = [semantics]
        context_str = ", ".join(str(item) for item in semantics) if semantics else "abstract scene"

        prompt = "\n".join((
            "Analyze the visual context and emotions to determine the ideal background music properties.",
            f"Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).",
            f"Visual Context: {context_str}.",
            "",
            "Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.",
            '1. "energy": (Loudness/Density)',
            '2. "flux": (Rhythmic sharpness/Beat)',
            '3. "centroid": (Brightness)',
            '4. "pitch": (Fundamental frequency)',
            '5. "hnr": (Harmonics-to-Noise)',
            '6. "zcr": (Percussiveness)',
            "",
            "Return ONLY a valid JSON object. No explanations, no markdown blocks.",
            'Example: {"energy": 0.8, "flux": 0.5, "centroid": 0.6, "pitch": 0.4, "hnr": 0.9, "zcr": 0.3}',
        ))

        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
            "format": "json"  # force Ollama's JSON output mode
        }

        try:
            print(f"Запрос акустического профиля к Ollama...")
            response = requests.post(self.api_url, json=payload, timeout=120)
            if response.status_code != 200:
                print(f"Ollama вернула ошибку HTTP: {response.status_code}")
                return {}
            data = response.json()
        except Exception as e:
            # Deliberately broad: a failed LLM call must not take the app down.
            print(f"Ошибка соединения с Ollama: {str(e)}")
            return {}

        response_text = data.get("response", "") if isinstance(data, dict) else ""
        # Parsing is kept outside the network try-block so a malformed answer
        # is reported as a parsing problem, not as a connection error.
        profile = self._normalize_profile(self._clean_json(response_text))
        if not profile:
            print(f"Ошибка парсинга LLM ответа: {response_text}")
        return profile
|
||||||
+25
-41
@@ -1,83 +1,67 @@
|
|||||||
import joblib
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import joblib
|
||||||
|
|
||||||
class MusicMatcher:
|
class MusicMatcher:
|
||||||
def __init__(self, db_path: Path | str, model_path: Path | str):
|
def __init__(self, db_path: Path | str, model_path: Path | str):
|
||||||
# Загрузка базы данных музыкальных произведений
|
# Загружаем твою новую, обогащенную базу
|
||||||
self.music_db = pd.read_csv(db_path)
|
self.music_db = pd.read_csv(db_path)
|
||||||
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
|
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
|
||||||
|
|
||||||
# Удаление записей с пропущенными целевыми или акустическими признаками
|
# Удаляем строки, где нет акустических фич
|
||||||
target_columns = ['valence', 'arousal'] + self.acoustic_features
|
self.music_db = self.music_db.dropna(subset=['valence', 'arousal'] + self.acoustic_features)
|
||||||
self.music_db = self.music_db.dropna(subset=target_columns)
|
|
||||||
|
|
||||||
# Масштабирование акустических параметров к диапазону [0, 1]
|
# Нормализуем акустику от 0 до 1, чтобы сравнивать с ответом LLM
|
||||||
self.norm_db = self.music_db.copy()
|
self.norm_db = self.music_db.copy()
|
||||||
for feat in self.acoustic_features:
|
for feat in self.acoustic_features:
|
||||||
f_min = self.norm_db[feat].min()
|
f_min, f_max = self.norm_db[feat].min(), self.norm_db[feat].max()
|
||||||
f_max = self.norm_db[feat].max()
|
|
||||||
if f_max > f_min:
|
if f_max > f_min:
|
||||||
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
|
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
|
||||||
else:
|
else:
|
||||||
self.norm_db[f"norm_{feat}"] = 0.0
|
self.norm_db[f"norm_{feat}"] = 0.0
|
||||||
|
|
||||||
# Определение путей к аудиофайлам и загрузка модели регрессии
|
|
||||||
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
|
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
|
||||||
|
self.regressor = joblib.load(model_path) if Path(model_path).exists() else None
|
||||||
|
|
||||||
if Path(model_path).exists():
|
def predict_va(self, embedding: np.ndarray):
|
||||||
self.regressor = joblib.load(model_path)
|
if self.regressor:
|
||||||
else:
|
prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
|
||||||
self.regressor = None
|
return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0)
|
||||||
|
|
||||||
def predict_va(self, embedding: np.ndarray) -> tuple[float, float]:
|
|
||||||
# Прогнозирование координат Valence/Arousal по визуальному эмбеддингу
|
|
||||||
if not self.regressor:
|
|
||||||
return 5.0, 5.0
|
return 5.0, 5.0
|
||||||
|
|
||||||
raw_prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
|
def get_audio_path(self, song_id):
|
||||||
valence_pred = np.clip(raw_prediction[0], 1.0, 9.0)
|
if not self.audio_dir.exists(): return None
|
||||||
arousal_pred = np.clip(raw_prediction[1], 1.0, 9.0)
|
|
||||||
|
|
||||||
return float(valence_pred), float(arousal_pred)
|
|
||||||
|
|
||||||
def get_audio_path(self, song_id: int | float | str) -> Path | None:
|
|
||||||
# Поиск физического пути к аудиофайлу в зависимости от расширения
|
|
||||||
if not self.audio_dir.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
clean_id = str(int(float(song_id)))
|
clean_id = str(int(float(song_id)))
|
||||||
for ext in ['.mp3', '.wav']:
|
for ext in ['.mp3', '.wav']:
|
||||||
path = self.audio_dir / f"{clean_id}{ext}"
|
path = self.audio_dir / f"{clean_id}{ext}"
|
||||||
if path.exists():
|
if path.exists(): return path
|
||||||
return path
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5) -> pd.DataFrame:
|
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5):
|
||||||
# Расчет евклидова расстояния в эмоциональном пространстве Рассела
|
# 1. Эмоциональная дистанция (как и раньше)
|
||||||
v_dist = (self.norm_db['valence'] - target_v) ** 2
|
emo_dist = np.sqrt(
|
||||||
a_dist = (self.norm_db['arousal'] - target_a) ** 2
|
1.0 * (self.norm_db['valence'] - target_v)**2 +
|
||||||
|
2.5 * (self.norm_db['arousal'] - target_a)**2
|
||||||
|
)
|
||||||
|
self.norm_db['emo_distance'] = emo_dist
|
||||||
|
|
||||||
# Взвешенное расстояние с приоритетом оси активации (Arousal)
|
# Если LLM не дала ответ, сортируем только по эмоциям
|
||||||
self.norm_db['emo_distance'] = np.sqrt(1.0 * v_dist + 2.5 * a_dist)
|
|
||||||
|
|
||||||
# Ранжирование только по эмоциональному критерию при отсутствии профиля LLM
|
|
||||||
if not llm_profile:
|
if not llm_profile:
|
||||||
self.norm_db['final_score'] = self.norm_db['emo_distance']
|
self.norm_db['final_score'] = self.norm_db['emo_distance']
|
||||||
return self.norm_db.sort_values(by='final_score').head(top_k)
|
return self.norm_db.sort_values(by='final_score').head(top_k)
|
||||||
|
|
||||||
# Расчет отклонений по вектору акустических параметров LLM
|
# 2. Акустическая дистанция (сравниваем треки с запросом LLM)
|
||||||
acoustic_penalty = np.zeros(len(self.norm_db))
|
acoustic_penalty = np.zeros(len(self.norm_db))
|
||||||
for feat in self.acoustic_features:
|
for feat in self.acoustic_features:
|
||||||
if feat in llm_profile:
|
if feat in llm_profile:
|
||||||
target_val = llm_profile[feat]
|
target_val = llm_profile[feat]
|
||||||
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
|
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
|
||||||
|
|
||||||
# Нормирование акустической дистанции
|
# Усредняем штраф
|
||||||
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
|
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
|
||||||
|
|
||||||
# Вычисление интегральной метрики соответствия (мультимодальный скоринг)
|
# 3. Финальный Score (Смесь Эмоций и Акустики). Коэф 4.0 делает акустику важной!
|
||||||
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
|
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
|
||||||
|
|
||||||
return self.norm_db.sort_values(by='final_score').head(top_k)
|
return self.norm_db.sort_values(by='final_score').head(top_k)
|
||||||
Binary file not shown.
@@ -1,73 +0,0 @@
|
|||||||
#!/bin/bash
#
# Environment bootstrap for the EmoM project: installs Docker, the Docker
# Compose plugin and the NVIDIA Container Toolkit when they are missing.
# (This script was AI-generated for quick environment preparation.)

# Abort on any error, on use of an unset variable, and on a failure in any
# stage of a pipeline (without pipefail a failed curl feeding sed/tee would
# go unnoticed).
set -euo pipefail

# Colours for console output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m'

echo -e "${BLUE}[INFO]${NC} Инициализация проверки окружения для проекта EmoM..."

# 1. DOCKER
if ! command -v docker &> /dev/null; then
    echo -e "${YELLOW}[SETUP]${NC} Docker не найден. Начинаем установку..."
    # Official Docker convenience installer, downloaded to a temp file so a
    # truncated download is never executed.
    installer="$(mktemp)"
    curl -fsSL https://get.docker.com -o "$installer"
    sudo sh "$installer"
    rm -f "$installer"

    # Add the current user to the docker group so docker works without sudo.
    # $USER can be unset in non-login shells, hence the id fallback.
    sudo usermod -aG docker "${USER:-$(id -un)}"
    echo -e "${GREEN}[OK]${NC} Docker успешно установлен."
    echo -e "${YELLOW}[ВНИМАНИЕ]${NC} Для применения прав группы docker потребуется перезайти в сессию SSH после завершения скрипта."
else
    echo -e "${GREEN}[OK]${NC} Docker установлен ($(docker --version))."
fi

# 2. DOCKER COMPOSE
if ! docker compose version &> /dev/null; then
    echo -e "${YELLOW}[SETUP]${NC} Плагин Docker Compose не найден. Устанавливаем..."
    sudo apt-get update && sudo apt-get install -y docker-compose-plugin
    echo -e "${GREEN}[OK]${NC} Docker Compose успешно установлен."
else
    echo -e "${GREEN}[OK]${NC} Плагин Docker Compose доступен."
fi

# 3. NVIDIA DRIVER AND GPU PASSTHROUGH INTO DOCKER
if command -v nvidia-smi &> /dev/null; then
    echo -e "${GREEN}[OK]${NC} Драйверы NVIDIA обнаружены."

    # Query the package directly instead of `dpkg -l | grep -q`: with pipefail
    # an early grep exit can make the whole pipeline report failure.
    if ! dpkg -s nvidia-container-toolkit &> /dev/null; then
        echo -e "${YELLOW}[SETUP]${NC} NVIDIA Container Toolkit не найден. Выполняется установка..."

        # Register the NVIDIA apt repository. --yes lets a re-run overwrite a
        # keyring left behind by an earlier, interrupted attempt.
        curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
            | sudo gpg --dearmor --yes -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
        # -f makes curl fail on HTTP errors instead of piping an error page
        # into the sources list.
        curl -fsSL https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
            | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
            | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list > /dev/null

        sudo apt-get update
        sudo apt-get install -y nvidia-container-toolkit

        # Point Docker at the NVIDIA runtime and restart the daemon to apply it.
        echo -e "${YELLOW}[SETUP]${NC} Конфигурация runtime NVIDIA для Docker..."
        sudo nvidia-ctk runtime configure --runtime=docker
        sudo systemctl restart docker

        echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit установлен и настроен."
    else
        echo -e "${GREEN}[OK]${NC} NVIDIA Container Toolkit уже установлен."
    fi
else
    echo -e "${RED}[WARN]${NC} Утилита nvidia-smi не найдена! Убедитесь, что драйверы видеокарты установлены, иначе Docker будет использовать только CPU."
fi

echo -e "\n${BLUE}[INFO]${NC} ========================================="
echo -e "${GREEN}[SUCCESS]${NC} Окружение готово к работе!"
echo -e "Теперь вы можете запустить проект командой: ${YELLOW}make up${NC}"
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
import kagglehub
|
|
||||||
|
|
||||||
# Destination of the DEAM dataset, relative to the current working directory.
dataset_dir = Path("../dataset/DEAM")
dataset_dir.mkdir(exist_ok=True, parents=True)
print("Скачивание датасета DEAM...")

# kagglehub downloads into its own cache directory and returns that location.
cache_path = kagglehub.dataset_download(
    "imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music"
)
for message in (
    f"Загружено в кэш: {cache_path}",
    f"Перенос файлов в {dataset_dir} и очистка временной директории...",
):
    print(message)

# Copy the cached tree into the project directory, then drop the cached copy.
shutil.copytree(cache_path, dataset_dir, dirs_exist_ok=True)
shutil.rmtree(cache_path)
print("Готово. Датасет DEAM загружен, кэш очищен.")
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
import csv
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from datasets import load_dataset
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# Root directory of the local dataset copy
DATASET_DIR = Path("../dataset/EmoSet-118K")


def process_and_save_split(dataset_split, split_name: str, output_dir: Path):
    """Write one dataset split to disk as JPEG files plus a labels.csv index.

    Every example is read through its "image", "emotion" and "image_id" keys;
    the image is stored as ``<output_dir>/<split_name>/images/<image_id>.jpg``
    and a ``filename,label`` row is appended to ``labels.csv`` in the split
    directory.
    """
    target_dir = output_dir / split_name
    images_dir = target_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    print(f"Обработка выборки: {split_name}...")

    # The label file stays open for the whole loop instead of being reopened per row.
    with open(target_dir / "labels.csv", mode="w", newline="", encoding="utf-8") as handle:
        rows = csv.writer(handle)
        rows.writerow(["filename", "label"])

        for record in tqdm(dataset_split, desc=split_name):
            picture = record["image"]
            label = record["emotion"]
            file_name = f"{record['image_id']}.jpg"

            # Convert anything that is not already RGB before encoding as JPEG.
            rgb_picture = picture if picture.mode == "RGB" else picture.convert("RGB")
            rgb_picture.save(images_dir / file_name, format="JPEG")

            rows.writerow([file_name, label])
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    DATASET_DIR.mkdir(parents=True, exist_ok=True)

    # Fetch the dataset from the Hugging Face Hub
    print("Загрузка метаданных EmoSet-118K...")
    raw_dataset = load_dataset("Woleek/EmoSet-118K")

    # Export every split that the downloaded dataset actually provides
    for split_key in ("train", "val", "test"):
        if split_key not in raw_dataset:
            continue
        process_and_save_split(
            dataset_split=raw_dataset[split_key],
            split_name=split_key,
            output_dir=DATASET_DIR,
        )

    print("Экспорт датасета завершен.")
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Default local paths (relative to the current working directory)
SOURCE_CSV = Path("../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv")
OUTPUT_CSV = Path("../../dataset/DEAM/music_db.csv")


def prepare_deam_database(source_csv=SOURCE_CSV, output_csv=OUTPUT_CSV):
    """Build the compact music database from the DEAM per-song annotations.

    Reads ``source_csv``, keeps 'song_id', 'valence_mean' and 'arousal_mean',
    renames the latter two to 'valence' and 'arousal', casts 'song_id' to int
    and writes the result to ``output_csv``.

    Args:
        source_csv: Path of the raw annotation CSV (defaults to SOURCE_CSV).
        output_csv: Path of the CSV to write (defaults to OUTPUT_CSV).

    Returns:
        None. When the source file is missing a message is printed and nothing
        is written.

    Raises:
        KeyError: if one of the expected columns is absent from the source file.
    """
    source_csv = Path(source_csv)
    output_csv = Path(output_csv)

    if not source_csv.exists():
        print(f"Исходный файл аннотаций не найден: {source_csv}")
        return

    print("Обработка разметки датасета DEAM...")

    # skipinitialspace drops the blanks that follow each comma in the raw file;
    # stripping the header as well guards against trailing whitespace.
    raw_df = pd.read_csv(source_csv, skipinitialspace=True)
    raw_df.columns = raw_df.columns.str.strip()

    # Keep the Valence/Arousal coordinates under shorter column names
    # (renamed by name, not by position).
    renamed = {'valence_mean': 'valence', 'arousal_mean': 'arousal'}
    processed_df = raw_df[['song_id', 'valence_mean', 'arousal_mean']].rename(columns=renamed)

    # Rows without an id cannot be matched to an audio file and would make the
    # integer cast fail.
    processed_df = processed_df.dropna(subset=['song_id'])
    processed_df['song_id'] = processed_df['song_id'].astype(int)

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    processed_df.to_csv(output_csv, index=False)

    print(f"База успешно сформирована. Всего записей: {len(processed_df)}")


if __name__ == "__main__":
    prepare_deam_database()
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
# Default parameters of the load test
NUM_SAMPLES = 300_000
DIM_IN = 4096
DIM_OUT = 10
BATCH_SIZE = 16_384
NUM_STEPS = 1000


def run_gpu_benchmark(num_samples=NUM_SAMPLES, dim_in=DIM_IN, dim_out=DIM_OUT,
                      batch_size=BATCH_SIZE, num_steps=NUM_STEPS):
    """Run a synthetic training loop and report how long it took.

    Random data is allocated directly on the selected device (CUDA when
    available, CPU otherwise) and a four-layer MLP is trained on random
    batches with Adam and an MSE loss.

    Args:
        num_samples: Number of synthetic rows to allocate.
        dim_in: Input feature width.
        dim_out: Output width.
        batch_size: Rows sampled (with replacement) per step.
        num_steps: Number of optimisation steps.

    Returns:
        Elapsed wall-clock time of the loop in seconds.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Инициализация стресс-теста на устройстве: {device}")

    def wait_for_device():
        # CUDA kernels are queued asynchronously; without a synchronize the
        # clock would be read before the queued work has finished.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Synthetic dataset, allocated on the target device
    x_data = torch.randn(num_samples, dim_in, device=device, dtype=torch.float32)
    y_data = torch.randn(num_samples, dim_out, device=device, dtype=torch.float32)

    # Fully connected test network
    model = nn.Sequential(
        nn.Linear(dim_in, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, dim_out)
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print("Начало прогрева GPU и симуляции цикла обучения...")
    wait_for_device()  # do not bill the allocation above to the loop
    start_time = time.perf_counter()

    for step in range(num_steps):
        # Sample a random batch
        idx = torch.randint(0, num_samples, (batch_size,), device=device)
        x_batch = x_data[idx]
        y_batch = y_data[idx]

        optimizer.zero_grad()
        predictions = model(x_batch)
        loss = loss_fn(predictions, y_batch)

        loss.backward()
        optimizer.step()

        # Log only every 100 steps to keep printing overhead low
        if step % 100 == 0:
            print(f"Итерация {step}/{num_steps} | Текущий loss: {loss.item():.4f}")

    wait_for_device()
    elapsed = time.perf_counter() - start_time
    print(f"Стресс-тест завершен. Общее время: {elapsed:.2f} сек.")
    return elapsed


if __name__ == "__main__":
    run_gpu_benchmark()
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import timm
|
|
||||||
|
|
||||||
# Silence the colour-profile warning matched by the pattern below
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Environment settings
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/EmoSet-118K")
BATCH_SIZE, EPOCHS = 64, 30
LR = 5e-5
NUM_WORKERS = 62
PATIENCE = 7

# Emotion label -> class index
CLASS_MAPPING = {
    label: index
    for index, label in enumerate((
        "amusement", "anger", "awe", "contentment",
        "disgust", "excitement", "fear", "sadness",
    ))
}

# Prefer CUDA when it is available
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Устройство: {DEVICE}")
|
|
||||||
|
|
||||||
# Seed the pseudo-random number generators
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (plus CUDA when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed()
|
|
||||||
|
|
||||||
# Dataset definition
class EmoSetDataset(Dataset):
    """Image/label dataset backed by ``<root>/<split>/labels.csv``.

    Rows whose label is not a key of CLASS_MAPPING are dropped on load.
    """

    def __init__(self, root: Path | str, split: str, transform=None):
        self.root = Path(root) / split
        self.transform = transform

        labels = pd.read_csv(self.root / "labels.csv")
        # Keep only rows with a known class
        known = labels["label"].isin(CLASS_MAPPING.keys())
        self.df = labels[known].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_index)`` for row ``idx``."""
        record = self.df.iloc[idx]
        location = self.root / "images" / record["filename"]

        try:
            picture = Image.open(location).convert("RGB")
        except Exception:
            # Any failure to open/convert the file yields a black 256x256 placeholder.
            picture = Image.new("RGB", (256, 256), (0, 0, 0))

        # Without a custom transform the image is only converted to a tensor.
        convert = self.transform if self.transform else T.ToTensor()
        return convert(picture), CLASS_MAPPING[record["label"]]
|
|
||||||
|
|
||||||
# Transforms: the tensor conversion + normalisation tail is shared by both pipelines
base_tf = [
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]

train_transform = T.Compose(
    [T.Resize(256, antialias=True), T.RandomCrop(224), T.RandomHorizontalFlip()] + base_tf
)
val_transform = T.Compose(
    [T.Resize(256, antialias=True), T.CenterCrop(224)] + base_tf
)

train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)

# Only the training loader shuffles
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_options)

# Model, loss, optimiser and learning-rate schedule
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
|
||||||
|
|
||||||
# One training epoch
def train_epoch(current_model, loader):
    """Train for one pass over ``loader``; return ``(mean_loss, accuracy)``.

    Relies on the module-level ``optimizer``, ``criterion`` and ``DEVICE``.
    """
    current_model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for imgs, labels in tqdm(loader, desc="Тренировка", leave=False, smoothing=0):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        logits = current_model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the epoch mean is per sample
        loss_sum += loss.item() * imgs.size(0)
        hits += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

    return loss_sum / seen, hits / seen
|
|
||||||
|
|
||||||
# One validation epoch
@torch.no_grad()
def val_epoch(current_model, loader):
    """Evaluate on ``loader`` without gradients; return ``(mean_loss, accuracy)``.

    Relies on the module-level ``criterion`` and ``DEVICE``.
    """
    current_model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = current_model(imgs)
        batch_loss = criterion(logits, labels)

        # Weight the batch loss by the batch size so the epoch mean is per sample
        loss_sum += batch_loss.item() * imgs.size(0)
        hits += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

    return loss_sum / seen, hits / seen
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    checkpoint_path = "./emosetV2_resnet50_best.pth"
    best_val_acc, best_val_loss = 0.0, float('inf')
    epochs_no_improve = 0

    print("Старт обучения.")

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = train_epoch(model, train_loader)
        val_loss, val_acc = val_epoch(model, val_loader)
        scheduler.step()

        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # The checkpoint tracks the best validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")

        # Early stopping tracks the validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            continue

        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
            break

    print("Процесс завершен.")
|
|
||||||
@@ -1,283 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import timm
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
from sklearn.metrics import confusion_matrix
|
|
||||||
|
|
||||||
# Silence the colour-profile warning matched by the pattern below
warnings.filterwarnings("ignore", message=".*Unknown Adobe color transform code.*")

# Environment settings
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
# Output directory for generated media files (created on import)
MEDIA_DIR = Path("./src/scripts/media")
MEDIA_DIR.mkdir(exist_ok=True, parents=True)

BATCH_SIZE, EPOCHS = 64, 30
LR = 5e-5
NUM_WORKERS = 32
PATIENCE = 7

# Emotion label -> class index
CLASS_MAPPING = {
    label: index
    for index, label in enumerate((
        "amusement", "anger", "awe", "contentment",
        "disgust", "excitement", "fear", "sadness",
    ))
}
# Class index -> emotion label, and the labels ordered by index (used for plots)
INV_CLASS_MAPPING = dict((index, label) for label, index in CLASS_MAPPING.items())
CLASS_NAMES = [INV_CLASS_MAPPING[i] for i in range(len(CLASS_MAPPING))]

# Prefer CUDA when it is available
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Устройство: {DEVICE}")
|
|
||||||
|
|
||||||
# Seed the pseudo-random number generators
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (plus CUDA when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed()
|
|
||||||
|
|
||||||
# Dataset definition
class EmoSetDataset(Dataset):
    """Image/label dataset backed by ``<root>/<split>/labels.csv``.

    Rows whose label is not a key of CLASS_MAPPING are dropped on load.
    """

    def __init__(self, root: Path | str, split: str, transform=None):
        self.root = Path(root) / split
        self.transform = transform

        labels = pd.read_csv(self.root / "labels.csv")
        # Keep only rows with a known class
        known = labels["label"].isin(CLASS_MAPPING.keys())
        self.df = labels[known].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_index)`` for row ``idx``."""
        record = self.df.iloc[idx]
        location = self.root / "images" / record["filename"]

        try:
            picture = Image.open(location).convert("RGB")
        except Exception:
            # Any failure to open/convert the file yields a black 256x256 placeholder.
            picture = Image.new("RGB", (256, 256), (0, 0, 0))

        # Without a custom transform the image is only converted to a tensor.
        convert = self.transform if self.transform else T.ToTensor()
        return convert(picture), CLASS_MAPPING[record["label"]]
|
|
||||||
|
|
||||||
# Transforms: the tensor conversion + normalisation tail is shared by both pipelines
base_tf = [
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]

train_transform = T.Compose(
    [T.Resize(256, antialias=True), T.RandomCrop(224), T.RandomHorizontalFlip()] + base_tf
)
val_transform = T.Compose(
    [T.Resize(256, antialias=True), T.CenterCrop(224)] + base_tf
)

train_ds = EmoSetDataset(DATA_ROOT, "train", transform=train_transform)
val_ds = EmoSetDataset(DATA_ROOT, "val", transform=val_transform)

# Only the training loader shuffles
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_options)

# Model, loss, optimiser and learning-rate schedule
model = timm.create_model("resnet50", pretrained=True, num_classes=8, drop_rate=0.3)
model.to(DEVICE)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
|
|
||||||
|
|
||||||
# Функции для отрисовки графиков
|
|
||||||
def plot_learning_curves(history):
    """Draw the loss and accuracy curves side by side and save them as a PNG.

    Args:
        history: dict with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``, each holding one value per epoch.

    The figure is written to ``MEDIA_DIR / "training_history.png"``.
    """
    epoch_axis = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(14, 5))

    # (train key, val key, train legend, val legend, title, y-axis label)
    panels = (
        ('train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
         'График функции потерь (Loss)', 'Loss'),
        ('train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy',
         'График точности (Accuracy)', 'Accuracy'),
    )

    for position, panel in enumerate(panels, start=1):
        train_key, val_key, train_legend, val_legend, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, history[train_key], 'b-', label=train_legend)
        plt.plot(epoch_axis, history[val_key], 'r--', label=val_legend)
        plt.title(title, fontsize=14)
        plt.xlabel('Эпохи', fontsize=12)
        plt.ylabel(y_label, fontsize=12)
        plt.legend()
        plt.grid(True, linestyle=':', alpha=0.7)

    plt.tight_layout()
    plot_path = MEDIA_DIR / "training_history.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] График обучения сохранен в: {plot_path}")
|
|
||||||
|
|
||||||
def plot_confusion_matrix(y_true, y_pred):
    """Render the confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        y_true: ground-truth class indices.
        y_pred: predicted class indices, aligned with ``y_true``.

    The figure is written to ``MEDIA_DIR / "confusion_matrix_emoset.png"``.
    """
    # Pin the matrix to one row/column per known class. Without ``labels`` the
    # matrix only contains classes that occur in the data, so a class absent
    # from this run would shift every tick label below it.
    # NOTE(review): assumes CLASS_NAMES[i] names class index i — confirm.
    class_ids = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                cbar_kws={'label': 'Количество сэмплов'})

    plt.title('Матрица ошибок (Confusion Matrix) - ResNet50', fontsize=16, pad=20)
    plt.ylabel('Истинные классы (Ground Truth)', fontsize=12)
    plt.xlabel('Предсказанные классы (Predicted)', fontsize=12)

    # Slanted x labels keep long class names from overlapping.
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()
    cm_path = MEDIA_DIR / "confusion_matrix_emoset.png"
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] Матрица ошибок сохранена в: {cm_path}")
|
|
||||||
|
|
||||||
# Логика эпохи обучения
|
|
||||||
def train_epoch(current_model, loader):
    """Run one optimisation pass over ``loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Args:
        current_model: the network to train (switched to train mode here).
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged over the
        number of samples seen.
    """
    current_model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Тренировка", leave=False, smoothing=0)
    for imgs, labels in progress:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        logits = current_model(imgs)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so
        # the final division yields a per-sample average.
        loss_sum += loss.item() * imgs.size(0)
        hits = logits.argmax(dim=1) == labels
        n_correct += hits.sum().item()
        n_seen += labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy
|
|
||||||
|
|
||||||
# Логика эпохи валидации с сохранением предсказаний для матрицы ошибок
|
|
||||||
@torch.no_grad()
def val_epoch(current_model, loader, return_preds=False):
    """Evaluate ``current_model`` on ``loader`` with gradients disabled.

    Uses the module-level ``criterion`` and ``DEVICE``.

    Args:
        current_model: the network to evaluate (switched to eval mode here).
        loader: iterable yielding ``(images, labels)`` batches.
        return_preds: when true, also return the per-sample labels and
            predictions (used for the confusion matrix).

    Returns:
        ``(mean_loss, accuracy)``, or
        ``(mean_loss, accuracy, all_labels, all_preds)`` when
        ``return_preds`` is true.
    """
    current_model.eval()

    loss_sum, n_correct, n_seen = 0.0, 0, 0
    collected_preds = []
    collected_labels = []

    for imgs, labels in tqdm(loader, desc="Валидация", leave=False, smoothing=0):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = current_model(imgs)
        batch_loss = criterion(logits, labels)

        # Batch-mean loss weighted by batch size -> per-sample average later.
        loss_sum += batch_loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)

        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

        if return_preds:
            collected_preds.extend(preds.cpu().numpy())
            collected_labels.extend(labels.cpu().numpy())

    metrics = (loss_sum / n_seen, n_correct / n_seen)

    if not return_preds:
        return metrics
    return (*metrics, collected_labels, collected_preds)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Training driver: runs up to EPOCHS epochs, checkpoints on the best
    # validation accuracy, stops early when validation loss stops improving,
    # then renders the learning curves and the confusion matrix.
    best_val_acc = 0.0
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_path = "./emosetV2_resnet50_best.pth"

    # Per-epoch training history.
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    # Labels/predictions of the best epoch, kept for the confusion matrix.
    best_labels, best_preds = [], []

    print("Старт обучения.")

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = train_epoch(model, train_loader)

        # Predictions are collected on every epoch: whether this epoch is the
        # best one is only known after its accuracy has been computed.
        val_loss, val_acc, val_labels, val_preds = val_epoch(model, val_loader, return_preds=True)

        scheduler.step()

        # Record history.
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # Save the best weights by accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_labels = val_labels  # keep the best model's predictions
            best_preds = val_preds
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Сохранен чекпоинт (Acc: {best_val_acc:.4f})")

        # Early stopping is driven by validation loss, not accuracy.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"Ранняя остановка: метрика валидации не улучшается {PATIENCE} эпох.")
                break

    print("Процесс обучения завершен. Генерирую графики для диссертации...")
    plot_learning_curves(history)
    if best_labels:
        plot_confusion_matrix(best_labels, best_preds)
    else:
        # No epoch ever beat the initial accuracy (or EPOCHS is 0): there is
        # nothing to build a confusion matrix from.
        print("[WARN] Нет сохраненных предсказаний: матрица ошибок не построена.")
    print("Все медиафайлы успешно созданы!")
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -1,171 +0,0 @@
|
|||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import timm
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
from sklearn.manifold import TSNE
|
|
||||||
|
|
||||||
# Output directory for generated media (created on import if missing).
MEDIA_DIR = Path("scripts/media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)

# Paths used for inference and for caching the extracted vectors.
# Both are relative to the current working directory.
DATA_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K")
MODEL_PATH = Path("./src/emoset_resnet50_best.pth")

BATCH_SIZE = 128
# At most 32 loader workers, but never more than the host has CPUs;
# ``os.cpu_count()`` may return None, hence the fallback to 1.
NUM_WORKERS = min(32, os.cpu_count() or 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Вычисления перенесены на: {device}")
|
|
||||||
|
|
||||||
class EmoSetFeatureDataset(Dataset):
    """Dataset over one EmoSet split, used for feature extraction.

    Reads ``<root>/<split>/labels.csv`` (columns ``filename`` and ``label``
    are used) and loads images from ``<root>/<split>/images``. Class indices
    are assigned by sorting the distinct label values found in the CSV.
    """

    def __init__(self, root: Path | str, split: str):
        """
        Args:
            root: dataset root directory.
            split: sub-directory name of the split (e.g. ``"test"``).
        """
        self.root = Path(root) / split
        self.df = pd.read_csv(self.root / "labels.csv")

        # Sorted unique labels -> stable label <-> index mapping.
        self.labels = sorted(self.df["label"].unique())
        self.label2idx = {l: i for i, l in enumerate(self.labels)}
        self.idx2label = {i: l for l, i in self.label2idx.items()}

        # No augmentation for feature extraction: deterministic resize and
        # centre crop followed by tensor conversion and normalisation.
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Number of rows in ``labels.csv``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_index)`` for row ``idx``.

        Unreadable images are replaced by a black 224x224 RGB placeholder so
        the extraction run is not interrupted (the row keeps its label).
        """
        row = self.df.iloc[idx]
        img_path = self.root / "images" / row["filename"]

        # Catch broken files in the split.
        try:
            # The context manager releases the file handle right after
            # decoding, including when ``convert`` fails on a damaged file.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
        except Exception:
            img = Image.new("RGB", (224, 224), (0, 0, 0))

        img_tensor = self.transform(img)
        label_idx = self.label2idx[row["label"]]

        return img_tensor, label_idx
|
|
||||||
|
|
||||||
def plot_tsne(embeddings, labels, idx2label, sample_limit=3000):
    """Project embeddings to 2D with t-SNE and save a labelled scatter plot.

    Args:
        embeddings: 2D array of feature vectors, one row per sample.
        labels: integer class indices aligned with ``embeddings``.
        idx2label: mapping from class index to class name (legend text).
        sample_limit: only the first ``sample_limit`` samples are projected.

    The figure is written to ``MEDIA_DIR / "tsne_embeddings.png"``.
    """
    embeddings_subset = embeddings[:sample_limit]
    labels_subset = labels[:sample_limit]
    # Report the number of samples actually used; it is smaller than
    # ``sample_limit`` when fewer embeddings are available.
    print(f"Построение t-SNE проекции для {len(embeddings_subset)} сэмплов...")

    tsne_model = TSNE(n_components=2, perplexity=30, random_state=42)
    embeddings_2d = tsne_model.fit_transform(embeddings_subset)

    plt.figure(figsize=(12, 9))

    # A qualitative palette that stays distinguishable in print.
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels_subset,
        cmap="Set2",
        alpha=0.7,
        s=20,
        edgecolors='w',
        linewidths=0.5
    )

    # Build the legend from the classes that are actually present in the
    # subset. ``num=None`` yields one handle per unique value (ascending), and
    # ``np.unique`` returns the same values in the same order, so handles and
    # names stay paired even when some classes are missing from the subset.
    handles, _ = scatter.legend_elements(num=None)
    present_classes = np.unique(labels_subset)
    legend_labels = [idx2label[int(i)] for i in present_classes]

    # Legend is placed outside the axes so it does not cover data points.
    plt.legend(handles, legend_labels, title="Эмоциональные классы",
               bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.title("2D проекция скрытого пространства признаков (t-SNE)", pad=20, fontsize=14)
    plt.xlabel("Первая главная компонента (t-SNE 1)", fontsize=12)
    plt.ylabel("Вторая главная компонента (t-SNE 2)", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout()
    plot_path = MEDIA_DIR / "tsne_embeddings.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[INFO] График t-SNE сохранен в: {plot_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Feature-extraction driver: loads the trained classifier, strips its
    # classification head, runs the test split through it, stores the
    # resulting vectors/labels as .npy files and renders a t-SNE plot.
    test_ds = EmoSetFeatureDataset(DATA_ROOT, "test")
    test_loader = DataLoader(
        test_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,  # no shuffling so row order matches labels.csv
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    print(f"Подготовлено для извлечения: {len(test_ds)} файлов.")

    # Build the model with a head of the right size so the checkpoint loads.
    feature_extractor = timm.create_model(
        "resnet50",
        pretrained=False,
        num_classes=len(test_ds.labels)
    )

    try:
        # The checkpoint is a plain state_dict, so restrict unpickling to
        # tensors/containers instead of arbitrary Python objects.
        checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
        feature_extractor.load_state_dict(checkpoint)
        print("Веса модели успешно загружены.")
    except Exception as e:
        print(f"Ошибка загрузки весов: {e}. Убедитесь, что модель обучена.")
        # ``exit`` is injected by the ``site`` module and is not guaranteed to
        # exist; raising SystemExit gives the same exit code everywhere.
        raise SystemExit(1)

    # Remove the classification layer so the model outputs pooled features.
    feature_extractor.reset_classifier(0)
    feature_extractor.to(device)
    feature_extractor.eval()

    print("Слой классификации удален. Модель готова к экстракции.")

    extracted_embeddings = []
    extracted_labels = []

    print("Старт пакетной экстракции признаков...")

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Экстракция"):
            imgs = imgs.to(device)

            # One feature vector per image in the batch.
            embeddings_batch = feature_extractor(imgs)

            extracted_embeddings.append(embeddings_batch.cpu().numpy())
            extracted_labels.append(labels.numpy())

    # Stack the per-batch arrays into single arrays.
    np_embeddings = np.concatenate(extracted_embeddings, axis=0)
    np_labels = np.concatenate(extracted_labels, axis=0)

    print(f"Размерность матрицы признаков: {np_embeddings.shape}")

    # Persist the artifacts.
    np.save("./src/emoset_test_embeddings.npy", np_embeddings)
    np.save("./src/emoset_test_labels.npy", np_labels)
    print("Матрицы успешно экспортированы в .npy файлы.")

    # Generate the figure.
    plot_tsne(np_embeddings, np_labels, test_ds.idx2label, sample_limit=3000)

    print("Процесс полностью завершен.")
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
from pathlib import Path
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# Paths and target feature configuration.
# BASE_DIR is relative, so it is resolved against the current working directory.
BASE_DIR = Path("../../dataset/DEAM")
MUSIC_DB_PATH = BASE_DIR.joinpath("music_db.csv")
FEATURES_DIR = BASE_DIR.joinpath("features", "features")
OUTPUT_PATH = BASE_DIR.joinpath("music_db_enriched.csv")

# Mapping from the extractor's low-level feature column names (openSMILE/GeMAPS,
# per the original author's note) to the short descriptor names used downstream.
TARGET_FEATURES = dict([
    ('pcm_RMSenergy_sma_amean', 'energy'),
    ('pcm_fftMag_spectralFlux_sma_amean', 'flux'),
    ('pcm_fftMag_spectralCentroid_sma_amean', 'centroid'),
    ('F0final_sma_amean', 'pitch'),
    ('logHNR_sma_amean', 'hnr'),
    ('pcm_zcr_sma_amean', 'zcr'),
    ('pcm_fftMag_spectralEntropy_sma_amean', 'entropy'),
    ('pcm_fftMag_psySharpness_sma_amean', 'sharpness'),
])
|
|
||||||
|
|
||||||
def aggregate_acoustic_features():
    """Enrich the DEAM annotation table with per-track mean acoustic features.

    For every ``song_id`` in ``MUSIC_DB_PATH`` the frame-level feature file
    ``FEATURES_DIR / "<song_id>.csv"`` (``;``-separated) is read, the columns
    listed in ``TARGET_FEATURES`` are averaged over all frames, and the result
    is inner-joined back onto the annotations. Rows with a NaN in any of the
    aggregated columns are dropped and the table is written to ``OUTPUT_PATH``.

    Tracks without a feature file are skipped silently; files that fail to
    parse are reported and skipped. Returns ``None``; returns early without
    writing anything when the annotation file is missing or when no feature
    file could be aggregated.
    """
    if not MUSIC_DB_PATH.exists():
        print(f"Базовый файл аннотаций не найден: {MUSIC_DB_PATH}")
        return

    print("Загрузка эмоциональной разметки DEAM...")
    df_main = pd.read_csv(MUSIC_DB_PATH)

    print("Агрегация фреймовых акустических признаков...")
    aggregated_data = []
    source_columns = list(TARGET_FEATURES.keys())

    # Only the id column is needed, so iterate it directly instead of
    # materialising a Series per row with ``iterrows``.
    for raw_id in tqdm(df_main['song_id'], total=len(df_main), desc="Обработка аудио-векторов"):
        song_id = int(raw_id)
        feature_file = FEATURES_DIR / f"{song_id}.csv"

        if not feature_file.exists():
            continue

        try:
            # Raw frame-level vectors (CSV with ';' separator).
            df_feat = pd.read_csv(feature_file, sep=';')

            # Average each target feature over the time axis (frames).
            mean_features = df_feat[source_columns].mean()

            # Build the aggregated record under the short descriptor names.
            track_data = {'song_id': song_id}
            for orig_col, new_col in TARGET_FEATURES.items():
                track_data[new_col] = mean_features[orig_col]

            aggregated_data.append(track_data)

        except Exception as e:
            print(f"Ошибка парсинга файла {feature_file.name}: {e}")

    if not aggregated_data:
        # An empty DataFrame would have no 'song_id' column and the merge
        # below would fail with a KeyError; report the situation instead.
        print(f"Не удалось агрегировать ни одного файла признаков из: {FEATURES_DIR}")
        return

    # Merge acoustic descriptors with the emotional coordinates (inner join).
    df_features = pd.DataFrame(aggregated_data)
    df_enriched = pd.merge(df_main, df_features, on='song_id', how='inner')

    # Drop rows where aggregation produced NaN in any descriptor.
    df_enriched = df_enriched.dropna(subset=list(TARGET_FEATURES.values()))

    df_enriched.to_csv(OUTPUT_PATH, index=False)
    print(f"Экспорт завершен. Сформирована обогащенная база: {OUTPUT_PATH.name}")
    print(f"Итоговый размер выборки: {len(df_enriched)} треков.")


if __name__ == "__main__":
    aggregate_acoustic_features()
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import joblib
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from sklearn.linear_model import RidgeCV
|
|
||||||
from sklearn.multioutput import MultiOutputRegressor
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from sklearn.pipeline import Pipeline
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.metrics import mean_squared_error, r2_score
|
|
||||||
|
|
||||||
# Projection of the discrete emotion classes onto the continuous Russell
# (valence, arousal) plane. Values are calibrated to the range [1.0, 9.0].
# The key is the integer class index; each entry is (valence, arousal).
EMOTION_TO_VA_COORDS = dict(enumerate((
    (7.5, 6.5),  # 0: amusement
    (2.0, 8.0),  # 1: anger
    (6.5, 5.0),  # 2: awe
    (7.0, 3.0),  # 3: contentment
    (3.0, 6.0),  # 4: disgust
    (8.0, 8.0),  # 5: excitement
    (2.5, 7.5),  # 6: fear
    (2.0, 2.0),  # 7: sadness
)))
|
|
||||||
|
|
||||||
def train_va_regressor():
    """Fit a ridge regressor from image embeddings to (valence, arousal).

    Loads ``emoset_test_embeddings.npy`` and ``emoset_test_labels.npy`` from
    the directory two levels above this file, converts each class index to its
    ``EMOTION_TO_VA_COORDS`` target, fits a ``StandardScaler`` +
    ``MultiOutputRegressor(RidgeCV)`` pipeline on an 80/20 split, prints MSE,
    R^2 and the predicted value ranges, and stores the fitted pipeline as
    ``music_engine/va_regressor.pkl``. Returns ``None``; returns early when
    either input file is missing.
    """
    # Path setup.
    src_dir = Path(__file__).resolve().parent.parent
    embeddings_file = src_dir / "emoset_test_embeddings.npy"
    labels_file = src_dir / "emoset_test_labels.npy"
    pipeline_file = src_dir / "music_engine" / "va_regressor.pkl"

    if not (embeddings_file.exists() and labels_file.exists()):
        print(f"Артефакты признаков не найдены в директории: {src_dir}")
        return

    print("Загрузка вектора признаков и меток классов...")
    features = np.load(embeddings_file)
    class_ids = np.load(labels_file)

    # Target transformation: class index -> continuous (valence, arousal).
    va_targets = np.array([EMOTION_TO_VA_COORDS[class_id] for class_id in class_ids])

    split = train_test_split(features, va_targets, test_size=0.2, random_state=42)
    feats_train, feats_test, va_train, va_test = split

    # Pipeline: z-score scaling followed by L2-regularised regression.
    # RidgeCV picks alpha (regularisation strength) from the listed candidates.
    print("Инициализация и обучение пайплайна RidgeCV...")
    ridge = RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])
    steps = [
        ('scaler', StandardScaler()),
        ('regressor', MultiOutputRegressor(ridge)),
    ]
    va_pipeline = Pipeline(steps)

    va_pipeline.fit(feats_train, va_train)

    # Evaluate on the held-out 20%.
    va_pred = va_pipeline.predict(feats_test)

    mse_value = mean_squared_error(va_test, va_pred)
    r2_value = r2_score(va_test, va_pred)

    print("Обучение завершено. Метрики качества на тестовой выборке:")
    print(f" - MSE: {mse_value:.4f}")
    print(f" - R^2: {r2_value:.4f}")

    # Diagnostics: range of predicted values per output (column 0 = valence,
    # column 1 = arousal).
    valence_pred = va_pred[:, 0]
    arousal_pred = va_pred[:, 1]
    print(f"Распределение Valence (прогноз): [{valence_pred.min():.2f}, {valence_pred.max():.2f}] (Эталон: 1.0 - 9.0)")
    print(f"Распределение Arousal (прогноз): [{arousal_pred.min():.2f}, {arousal_pred.max():.2f}] (Эталон: 1.0 - 9.0)")

    # Export the fitted pipeline.
    pipeline_file.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(va_pipeline, pipeline_file)
    print(f"Пайплайн сохранен: {pipeline_file.name}")


if __name__ == "__main__":
    train_va_regressor()
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import joblib
|
|
||||||
from pathlib import Path
|
|
||||||
from sklearn.metrics import mean_squared_error, r2_score
|
|
||||||
|
|
||||||
# 1. Path setup (all paths are relative to the current working directory).
_SRC_DIR = Path("./src")
embeddings_path = _SRC_DIR / "emoset_test_embeddings.npy"
csv_path = Path("./NFS/Thesis/Emoset/EmoSet-118K/test/labels.csv")
model_path = _SRC_DIR / "music_engine" / "va_regressor.pkl"

output_dir = _SRC_DIR / "scripts" / "media"
output_file = output_dir.joinpath("metrics_output.txt")

# 2. Mapping of the 8 EmoSet classes onto the DEAM scale [1.0, 9.0].
# Conversion from [-1, 1] to [1, 9]: 5.0 + (X * 4.0).
# Each value is [valence, arousal].
EMO_TO_VA = {
    # positive valence
    "amusement": [8.2, 6.6],    # high positive, medium energy
    "awe": [7.0, 7.4],          # positive, high energy
    "contentment": [7.8, 3.4],  # positive, low energy
    "excitement": [8.2, 8.2],   # max positive, max energy
    # negative valence
    "anger": [2.2, 7.8],        # deep negative, high energy
    "disgust": [2.6, 6.6],      # negative, medium energy
    "fear": [2.6, 8.2],         # negative, maximum energy
    "sadness": [2.2, 2.6],      # deep negative, low energy
}
|
|
||||||
|
|
||||||
def generate_slide_metrics():
    """Score the saved VA regressor on the cached embeddings and write a table.

    Loads the embeddings (``embeddings_path``), the label CSV (``csv_path``)
    and the pickled regressor (``model_path``), maps every label name to its
    ``EMO_TO_VA`` target, predicts valence/arousal and reports MSE and R^2 per
    output and overall. The table is printed and written to ``output_file``.
    Returns ``None``; returns early when any input file is missing.

    NOTE(review): the regressor training script seen alongside this one fits
    on an 80% split of what appears to be the same embeddings file and uses a
    different class-to-VA table, so the numbers here may include training
    samples and differently scaled targets — confirm this is intended.
    """
    print("[INFO] Загрузка тестовых артефактов...")

    if not all(p.exists() for p in [embeddings_path, csv_path, model_path]):
        print("[ERROR] Проверьте наличие файлов данных или модели регрессора.")
        return

    output_dir.mkdir(parents=True, exist_ok=True)

    # 3. Load embeddings and labels.
    X_test = np.load(embeddings_path)
    df = pd.read_csv(csv_path)

    if len(X_test) != len(df):
        # Truncate both to the common length so the arrays can be compared.
        print(f"[WARN] Корректировка размеров выборки: Эмбеддинги ({len(X_test)}) != Метки ({len(df)})")
        min_len = min(len(X_test), len(df))
        X_test = X_test[:min_len]
        df = df.iloc[:min_len]

    # ``str`` keeps a missing/non-string label from raising on ``.lower()``.
    label_names = [str(label).lower().strip() for label in df['label']]
    unknown_names = sorted({name for name in label_names if name not in EMO_TO_VA})
    if unknown_names:
        # Unknown labels still fall back to the neutral point, but no longer silently.
        print(f"[WARN] Метки без координат V/A заменены на [5.0, 5.0]: {unknown_names}")
    y_test = np.array([EMO_TO_VA.get(name, [5.0, 5.0]) for name in label_names])

    # 4. Run inference.
    print("[INFO] Выполнение инференса регрессора на скрытом пространстве признаков...")
    regressor = joblib.load(model_path)
    y_pred = regressor.predict(X_test)

    # === Scale diagnostics ===
    print("\n" + "-"*50)
    print(" ДИАГНОСТИКА ДИАПАЗОНОВ ЗНАЧЕНИЙ ".center(50))
    print("-"*50)
    print(f"Истинные (y_test) -> Мин: {y_test.min():.2f}, Макс: {y_test.max():.2f}, Среднее: {y_test.mean():.2f}")
    print(f"Предсказания (y_pred) -> Мин: {y_pred.min():.2f}, Макс: {y_pred.max():.2f}, Среднее: {y_pred.mean():.2f}")
    print("-"*50 + "\n")

    # 5. Metrics: column 0 is valence, column 1 is arousal.
    mse_v = mean_squared_error(y_test[:, 0], y_pred[:, 0])
    r2_v = r2_score(y_test[:, 0], y_pred[:, 0])

    mse_a = mean_squared_error(y_test[:, 1], y_pred[:, 1])
    r2_a = r2_score(y_test[:, 1], y_pred[:, 1])

    mse_total = mean_squared_error(y_test, y_pred)
    r2_total = r2_score(y_test, y_pred)

    # 6. Print and save the results.
    # The backslash before "cdot" is doubled: "\c" is not a valid escape in a
    # non-raw string literal; the emitted text is unchanged.
    table_content = f"""
==================================================
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
==================================================
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|------------|--------------|--------------|---------------|
| MSE | {mse_v:<12.4f} | {mse_a:<12.4f} | {mse_total:<13.4f} |
| R² | {r2_v:<12.4f} | {r2_a:<12.4f} | {r2_total:<13.4f} |
==================================================

Формула целевой функции для вставки на слайд (LaTeX):
$$Score_{{final}} = D_{{emo}} + 4.0 \\cdot Acoustic_{{penalty}}$$
"""

    print(table_content)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(table_content)

    print(f"[SUCCESS] Метрики успешно сохранены в файл: {output_file.absolute()}")


if __name__ == "__main__":
    generate_slide_metrics()
|
|
||||||
@@ -2,30 +2,30 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"id": "8523d028",
|
"id": "8523d028",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"\n",
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import numpy as np\n",
|
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
"from PIL import Image\n",
|
"from PIL import Image\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import torch\n",
|
|
||||||
"from torch.utils.data import Dataset, DataLoader\n",
|
|
||||||
"import torchvision.transforms as T\n",
|
"import torchvision.transforms as T\n",
|
||||||
"import timm\n",
|
"import timm\n",
|
||||||
|
"import numpy as np\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n",
|
"from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n"
|
||||||
"import matplotlib.pyplot as plt\n",
|
|
||||||
"import seaborn as sns"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"id": "e0781b02",
|
"id": "e0781b02",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -41,26 +41,25 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Конфигурация путей и параметров инференса\n",
|
|
||||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||||
"MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n",
|
"MODEL_PATH = Path(\"./emoset_resnet50_best.pth\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"BATCH_SIZE = 64\n",
|
"BATCH_SIZE = 64\n",
|
||||||
"NUM_WORKERS = 4\n",
|
"NUM_WORKERS = 4\n",
|
||||||
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
"print(f\"Аппаратное ускорение: {DEVICE}\")"
|
"\n",
|
||||||
|
"DEVICE\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"id": "79da9640",
|
"id": "79da9640",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class EmoSetEvaluationDataset(Dataset):\n",
|
"class EmoSetDataset(Dataset):\n",
|
||||||
" # Датасет для строгой валидации с центрированным кропом\n",
|
" def __init__(self, root, split):\n",
|
||||||
" def __init__(self, root: Path | str, split: str):\n",
|
|
||||||
" self.root = Path(root) / split\n",
|
" self.root = Path(root) / split\n",
|
||||||
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -68,12 +67,13 @@
|
|||||||
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
||||||
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Стандартный пайплайн трансформаций для инференса ResNet\n",
|
|
||||||
" self.transform = T.Compose([\n",
|
" self.transform = T.Compose([\n",
|
||||||
" T.Resize(256),\n",
|
" T.Resize((224, 224)),\n",
|
||||||
" T.CenterCrop(224),\n",
|
|
||||||
" T.ToTensor(),\n",
|
" T.ToTensor(),\n",
|
||||||
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
" T.Normalize(\n",
|
||||||
|
" mean=[0.485, 0.456, 0.406],\n",
|
||||||
|
" std=[0.229, 0.224, 0.225]\n",
|
||||||
|
" )\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def __len__(self):\n",
|
" def __len__(self):\n",
|
||||||
@@ -81,23 +81,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def __getitem__(self, idx):\n",
|
" def __getitem__(self, idx):\n",
|
||||||
" row = self.df.iloc[idx]\n",
|
" row = self.df.iloc[idx]\n",
|
||||||
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
" img = Image.open(self.root / \"images\" / row[\"filename\"]).convert(\"RGB\")\n",
|
||||||
" \n",
|
" img = self.transform(img)\n",
|
||||||
" # Перехват битых файлов для непрерывности оценки\n",
|
" label = self.label2idx[row[\"label\"]]\n",
|
||||||
" try:\n",
|
" return img, label\n"
|
||||||
" img = Image.open(img_path).convert(\"RGB\")\n",
|
|
||||||
" except Exception:\n",
|
|
||||||
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
|
|
||||||
" \n",
|
|
||||||
" img_tensor = self.transform(img)\n",
|
|
||||||
" label_idx = self.label2idx[row[\"label\"]]\n",
|
|
||||||
" \n",
|
|
||||||
" return img_tensor, label_idx"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"id": "12201756",
|
"id": "12201756",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -111,8 +103,8 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Инициализация тестовой выборки\n",
|
"test_ds = EmoSetDataset(DATA_ROOT, \"test\")\n",
|
||||||
"test_ds = EmoSetEvaluationDataset(DATA_ROOT, \"test\")\n",
|
"\n",
|
||||||
"test_loader = DataLoader(\n",
|
"test_loader = DataLoader(\n",
|
||||||
" test_ds,\n",
|
" test_ds,\n",
|
||||||
" batch_size=BATCH_SIZE,\n",
|
" batch_size=BATCH_SIZE,\n",
|
||||||
@@ -121,13 +113,13 @@
|
|||||||
" pin_memory=True\n",
|
" pin_memory=True\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(f\"Индексированные классы: {test_ds.labels}\")\n",
|
"print(\"Classes:\", test_ds.labels)\n",
|
||||||
"print(f\"Размер тестовой выборки: {len(test_ds)}\")"
|
"print(\"Test samples:\", len(test_ds))\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 8,
|
||||||
"id": "7e3dc1d5",
|
"id": "7e3dc1d5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -382,17 +374,22 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Инициализация модели в режиме классификации\n",
|
|
||||||
"model = timm.create_model(\n",
|
"model = timm.create_model(\n",
|
||||||
" \"resnet50\",\n",
|
" \"resnet50\",\n",
|
||||||
" pretrained=False,\n",
|
" pretrained=False,\n",
|
||||||
" num_classes=len(test_ds.labels)\n",
|
" num_classes=len(test_ds.labels)\n",
|
||||||
")"
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"state = torch.load(MODEL_PATH, map_location=DEVICE)\n",
|
||||||
|
"model.load_state_dict(state)\n",
|
||||||
|
"\n",
|
||||||
|
"model.to(DEVICE)\n",
|
||||||
|
"model.eval()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"id": "b42a84f1",
|
"id": "b42a84f1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -405,16 +402,27 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Загрузка весов и перевод в режим инференса\n",
|
"all_preds = []\n",
|
||||||
"checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)\n",
|
"all_targets = []\n",
|
||||||
"model.load_state_dict(checkpoint)\n",
|
"\n",
|
||||||
"model.to(DEVICE)\n",
|
"with torch.no_grad():\n",
|
||||||
"model.eval()"
|
" for imgs, labels in tqdm(test_loader):\n",
|
||||||
|
" imgs = imgs.to(DEVICE)\n",
|
||||||
|
" labels = labels.to(DEVICE)\n",
|
||||||
|
"\n",
|
||||||
|
" logits = model(imgs)\n",
|
||||||
|
" preds = logits.argmax(dim=1)\n",
|
||||||
|
"\n",
|
||||||
|
" all_preds.append(preds.cpu().numpy())\n",
|
||||||
|
" all_targets.append(labels.cpu().numpy())\n",
|
||||||
|
"\n",
|
||||||
|
"all_preds = np.concatenate(all_preds)\n",
|
||||||
|
"all_targets = np.concatenate(all_targets)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 10,
|
||||||
"id": "4c1f1377",
|
"id": "4c1f1377",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -427,25 +435,13 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Сбор предсказаний на тестовой выборке\n",
|
"acc = accuracy_score(all_targets, all_preds)\n",
|
||||||
"all_preds = []\n",
|
"print(f\"Test accuracy: {acc:.4f}\")\n"
|
||||||
"all_targets = []\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"Запуск инференса на тестовой выборке...\")\n",
|
|
||||||
"with torch.no_grad():\n",
|
|
||||||
" for imgs, labels in tqdm(test_loader, desc=\"Оценка метрик\"):\n",
|
|
||||||
" imgs = imgs.to(DEVICE)\n",
|
|
||||||
" \n",
|
|
||||||
" logits = model(imgs)\n",
|
|
||||||
" preds = logits.argmax(dim=1)\n",
|
|
||||||
"\n",
|
|
||||||
" all_preds.append(preds.cpu().numpy())\n",
|
|
||||||
" all_targets.append(labels.numpy())"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 11,
|
||||||
"id": "6b022825",
|
"id": "6b022825",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -472,14 +468,19 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Агрегация результатов\n",
|
"print(\n",
|
||||||
"all_preds = np.concatenate(all_preds, axis=0)\n",
|
" classification_report(\n",
|
||||||
"all_targets = np.concatenate(all_targets, axis=0)"
|
" all_targets,\n",
|
||||||
|
" all_preds,\n",
|
||||||
|
" target_names=test_ds.labels,\n",
|
||||||
|
" digits=4\n",
|
||||||
|
" )\n",
|
||||||
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 12,
|
||||||
"id": "2fcb69ac",
|
"id": "2fcb69ac",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -495,70 +496,20 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Расчет интегральных метрик классификации\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"acc = accuracy_score(all_targets, all_preds)\n",
|
|
||||||
"print(f\"\\nОбщая точность (Accuracy): {acc:.4f}\\n\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Детализированный отчет (Classification Report):\")\n",
|
|
||||||
"print(\n",
|
|
||||||
" classification_report(\n",
|
|
||||||
" all_targets,\n",
|
|
||||||
" all_preds,\n",
|
|
||||||
" target_names=test_ds.labels,\n",
|
|
||||||
" digits=4\n",
|
|
||||||
" )\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "2084ab91",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Построение матрицы ошибок (Confusion Matrix)\n",
|
|
||||||
"cm = confusion_matrix(all_targets, all_preds)\n",
|
"cm = confusion_matrix(all_targets, all_preds)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.figure(figsize=(10, 8))"
|
"plt.figure(figsize=(8, 8))\n",
|
||||||
]
|
"plt.imshow(cm)\n",
|
||||||
},
|
"plt.colorbar()\n",
|
||||||
{
|
"plt.xticks(range(len(test_ds.labels)), test_ds.labels, rotation=45)\n",
|
||||||
"cell_type": "code",
|
"plt.yticks(range(len(test_ds.labels)), test_ds.labels)\n",
|
||||||
"execution_count": null,
|
"plt.xlabel(\"Predicted\")\n",
|
||||||
"id": "83a84e14",
|
"plt.ylabel(\"True\")\n",
|
||||||
"metadata": {},
|
"plt.title(\"Confusion Matrix (Test)\")\n",
|
||||||
"outputs": [],
|
"plt.tight_layout()\n",
|
||||||
"source": [
|
"plt.show()\n"
|
||||||
"# Использование seaborn для академичной визуализации с числами\n",
|
|
||||||
"sns.heatmap(\n",
|
|
||||||
" cm, \n",
|
|
||||||
" annot=True, \n",
|
|
||||||
" fmt=\"d\", \n",
|
|
||||||
" cmap=\"Blues\", \n",
|
|
||||||
" xticklabels=test_ds.labels, \n",
|
|
||||||
" yticklabels=test_ds.labels,\n",
|
|
||||||
" cbar=False\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"plt.title(\"Матрица ошибок классификации EmoSet (ResNet-50)\", pad=20)\n",
|
|
||||||
"plt.xlabel(\"Предсказанный класс\", labelpad=15)\n",
|
|
||||||
"plt.ylabel(\"Истинный класс\", labelpad=15)\n",
|
|
||||||
"plt.xticks(rotation=45)\n",
|
|
||||||
"plt.yticks(rotation=0)\n",
|
|
||||||
"plt.tight_layout()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "280d5637",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Экспорт графика\n",
|
|
||||||
"plt.savefig(\"../confusion_matrix_resnet50.png\", dpi=300, bbox_inches='tight')\n",
|
|
||||||
"plt.show()"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "0336fd0c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"✅ База загружена. Треков: 1744\n",
|
||||||
|
"🔍 Собираем акустические признаки...\n",
|
||||||
|
"\n",
|
||||||
|
"🚀 ГОТОВО! Обогащенная база сохранена: ../../dataset/DEAM/music_db_enriched.csv\n",
|
||||||
|
"Собрано фичей для 1744 из 1744 треков.\n",
|
||||||
|
" song_id valence arousal energy flux centroid pitch \\\n",
|
||||||
|
"0 2 3.1 3.0 0.097268 0.846947 483.421751 93.884056 \n",
|
||||||
|
"1 3 3.5 3.3 0.126809 0.959460 173.219616 62.682589 \n",
|
||||||
|
"2 4 5.7 5.5 0.156699 1.333944 466.434797 92.850316 \n",
|
||||||
|
"3 5 4.4 5.3 0.126455 1.009927 546.152506 158.673853 \n",
|
||||||
|
"4 7 5.8 6.4 0.268180 1.589191 175.369162 83.823484 \n",
|
||||||
|
"\n",
|
||||||
|
" hnr zcr entropy sharpness \n",
|
||||||
|
"0 3.615380 0.034270 3.299075 0.426490 \n",
|
||||||
|
"1 -2.600122 0.017893 2.294971 0.165583 \n",
|
||||||
|
"2 -0.579130 0.042936 3.258138 0.395410 \n",
|
||||||
|
"3 1.751148 0.043781 3.514585 0.494367 \n",
|
||||||
|
"4 12.006770 0.014783 2.177862 0.170058 \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from tqdm import tqdm # для красивого прогресс-бара, если не установлен - убери\n",
|
||||||
|
"\n",
|
||||||
|
"# 1. Пути к файлам\n",
|
||||||
|
"base_dir = Path(\"../../dataset/DEAM\") # Поправь, если запускаешь из другого места\n",
|
||||||
|
"music_db_path = base_dir / \"music_db.csv\"\n",
|
||||||
|
"features_dir = base_dir / \"features\" / \"features\"\n",
|
||||||
|
"output_path = base_dir / \"music_db_enriched.csv\"\n",
|
||||||
|
"\n",
|
||||||
|
"# 2. Наш \"Золотой список\" (8 признаков)\n",
|
||||||
|
"target_columns = {\n",
|
||||||
|
" 'pcm_RMSenergy_sma_amean': 'energy',\n",
|
||||||
|
" 'pcm_fftMag_spectralFlux_sma_amean': 'flux',\n",
|
||||||
|
" 'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',\n",
|
||||||
|
" 'F0final_sma_amean': 'pitch',\n",
|
||||||
|
" 'logHNR_sma_amean': 'hnr',\n",
|
||||||
|
" 'pcm_zcr_sma_amean': 'zcr',\n",
|
||||||
|
" 'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',\n",
|
||||||
|
" 'pcm_fftMag_psySharpness_sma_amean': 'sharpness'\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# 3. Загружаем текущую базу с V/A\n",
|
||||||
|
"if not music_db_path.exists():\n",
|
||||||
|
" print(f\"❌ ОШИБКА: Не найден файл {music_db_path}\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" df_main = pd.read_csv(music_db_path)\n",
|
||||||
|
" print(f\"✅ База загружена. Треков: {len(df_main)}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Подготавливаем новые колонки\n",
|
||||||
|
" for col_name in target_columns.values():\n",
|
||||||
|
" df_main[col_name] = np.nan\n",
|
||||||
|
"\n",
|
||||||
|
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
|
||||||
|
" print(\"🔍 Собираем акустические признаки...\")\n",
|
||||||
|
" found_count = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for index, row in df_main.iterrows():\n",
|
||||||
|
" song_id = int(row['song_id'])\n",
|
||||||
|
" feature_file = features_dir / f\"{song_id}.csv\"\n",
|
||||||
|
" \n",
|
||||||
|
" if feature_file.exists():\n",
|
||||||
|
" try:\n",
|
||||||
|
" # Читаем CSV с признаками (разделитель там обычно точка с запятой)\n",
|
||||||
|
" df_feat = pd.read_csv(feature_file, sep=';')\n",
|
||||||
|
" \n",
|
||||||
|
" # Усредняем значения по всем фреймам (одна песня разбита на сотни строк-фреймов)\n",
|
||||||
|
" mean_features = df_feat[list(target_columns.keys())].mean()\n",
|
||||||
|
" \n",
|
||||||
|
" # Записываем в главную базу\n",
|
||||||
|
" for orig_col, new_col in target_columns.items():\n",
|
||||||
|
" df_main.at[index, new_col] = mean_features[orig_col]\n",
|
||||||
|
" \n",
|
||||||
|
" found_count += 1\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Ошибка чтения {feature_file}: {e}\")\n",
|
||||||
|
" \n",
|
||||||
|
" # 5. Сохраняем результат\n",
|
||||||
|
" # Удаляем треки, для которых не нашлось фичей (если такие есть)\n",
|
||||||
|
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
|
||||||
|
" \n",
|
||||||
|
" df_main.to_csv(output_path, index=False)\n",
|
||||||
|
" print(f\"\\n🚀 ГОТОВО! Обогащенная база сохранена: {output_path}\")\n",
|
||||||
|
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
|
||||||
|
" print(df_main.head())"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python (thesis)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "thesis"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -0,0 +1,614 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "09f9237a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
|
||||||
|
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
|
||||||
|
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
|
||||||
|
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
|
||||||
|
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
|
||||||
|
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
|
||||||
|
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
|
||||||
|
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
|
||||||
|
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
|
||||||
|
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
|
||||||
|
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
|
||||||
|
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
|
||||||
|
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
|
||||||
|
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
|
||||||
|
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
|
||||||
|
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
|
||||||
|
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
|
||||||
|
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
|
||||||
|
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
|
||||||
|
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
|
||||||
|
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
|
||||||
|
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
|
||||||
|
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
|
||||||
|
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
|
||||||
|
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
|
||||||
|
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
|
||||||
|
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
|
||||||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
|
||||||
|
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
|
||||||
|
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
|
||||||
|
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
|
||||||
|
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
|
||||||
|
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
|
||||||
|
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
|
||||||
|
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
|
||||||
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
|
||||||
|
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
|
||||||
|
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
|
||||||
|
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
|
||||||
|
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"!pip install datasets tqdm pillow requests\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "6f0b2e2c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "868d872a109d49f9966f2f19985e7048",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "06741794289540849ad179c5966dcab8",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "e47aad5270144913996cb5b226213ab9",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "30d1492a948245e3b6b58e92218cd760",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "931823b458cb4696b459e9011537cf1e",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "846f4245b16d4cc096a43c940590ad11",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "172981d77fc941cfa32c05f5a34bf742",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "2baa935ed3524a73883909752cb15907",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "d9c0baac101b449794155392f07b49c3",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "3a5b966f79314e069251462bff82395f",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "422974f938924910a0712b30a9c2bd84",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "f155a08427094de7ad1a5884e623db2b",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "a94a4621d19f45f690e0064fee83767b",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "50f55b00a27b4213b573b398e5b0d708",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "0c5815040f0a4a31903348a8327811a5",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"DatasetDict({\n",
|
||||||
|
" train: Dataset({\n",
|
||||||
|
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||||
|
" num_rows: 94481\n",
|
||||||
|
" })\n",
|
||||||
|
" val: Dataset({\n",
|
||||||
|
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||||
|
" num_rows: 5905\n",
|
||||||
|
" })\n",
|
||||||
|
" test: Dataset({\n",
|
||||||
|
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
||||||
|
" num_rows: 17716\n",
|
||||||
|
" })\n",
|
||||||
|
"})\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from datasets import load_dataset\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import requests\n",
|
||||||
|
"\n",
|
||||||
|
"# куда сохраняем датасет\n",
|
||||||
|
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
|
||||||
|
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# загружаем через Hugging Face\n",
|
||||||
|
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(ds)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "052ab073",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"def save_split(split):\n",
|
||||||
|
" split_dir = DATA_DIR / split\n",
|
||||||
|
" img_dir = split_dir / \"images\"\n",
|
||||||
|
" img_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
" labels_path = split_dir / \"labels.csv\"\n",
|
||||||
|
"\n",
|
||||||
|
" # перезаписываем labels.csv\n",
|
||||||
|
" with open(labels_path, \"w\") as f:\n",
|
||||||
|
" f.write(\"filename,label\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" for example in tqdm(ds[split]):\n",
|
||||||
|
" img = example[\"image\"] # уже PIL.Image\n",
|
||||||
|
" label = example[\"emotion\"]\n",
|
||||||
|
" image_id = example[\"image_id\"]\n",
|
||||||
|
"\n",
|
||||||
|
" fname = f\"{image_id}.jpg\"\n",
|
||||||
|
" img.save(img_dir / fname)\n",
|
||||||
|
"\n",
|
||||||
|
" with open(labels_path, \"a\") as f:\n",
|
||||||
|
" f.write(f\"{fname},{label}\\n\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "a74ceedf",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
|
||||||
|
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
|
||||||
|
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"save_split(\"train\")\n",
|
||||||
|
"save_split(\"val\")\n",
|
||||||
|
"save_split(\"test\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "thesis-py3.11",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Загрузка датасета DEAM\n",
|
||||||
|
"\n",
|
||||||
|
"Этот ноутбук предназначен для автоматизации процесса скачивания и подготовки музыкального датасета **DEAM** (Database for Emotional Analysis in Music).\n",
|
||||||
|
"Данные будут помещены в папку `dataset/DEAM`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Collecting kagglehub\n",
|
||||||
|
" Downloading kagglehub-1.0.1-py3-none-any.whl.metadata (40 kB)\n",
|
||||||
|
"Collecting kagglesdk<1.0,>=0.1.22 (from kagglehub)\n",
|
||||||
|
" Downloading kagglesdk-0.1.23-py3-none-any.whl.metadata (13 kB)\n",
|
||||||
|
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (25.0)\n",
|
||||||
|
"Requirement already satisfied: pyyaml in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (6.0.3)\n",
|
||||||
|
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (2.32.5)\n",
|
||||||
|
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (4.67.1)\n",
|
||||||
|
"Requirement already satisfied: protobuf in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglesdk<1.0,>=0.1.22->kagglehub) (6.33.4)\n",
|
||||||
|
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.4.4)\n",
|
||||||
|
"Requirement already satisfied: idna<4,>=2.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.11)\n",
|
||||||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2.6.3)\n",
|
||||||
|
"Requirement already satisfied: certifi>=2017.4.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2026.1.4)\n",
|
||||||
|
"Downloading kagglehub-1.0.1-py3-none-any.whl (70 kB)\n",
|
||||||
|
"Downloading kagglesdk-0.1.23-py3-none-any.whl (217 kB)\n",
|
||||||
|
"Installing collected packages: kagglesdk, kagglehub\n",
|
||||||
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [kagglehub]\n",
|
||||||
|
"\u001b[1A\u001b[2KSuccessfully installed kagglehub-1.0.1 kagglesdk-0.1.23\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.1.1\u001b[0m\n",
|
||||||
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"!pip install kagglehub"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Скачиваем датасет DEAM...\n",
|
||||||
|
"Downloading to /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/1.archive...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1.83G/1.83G [01:09<00:00, 28.2MB/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Extracting files...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Датасет скачан во временную директорию: /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/versions/1\n",
|
||||||
|
"Переносим файлы в ../dataset/DEAM...\n",
|
||||||
|
"\n",
|
||||||
|
"[УСПЕХ] Датасет DEAM готов к работе!\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import shutil\n",
|
||||||
|
"import kagglehub\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"# 1. Настройка путей\n",
|
||||||
|
"DATASET_ROOT = Path(\"../dataset\")\n",
|
||||||
|
"DEAM_ROOT = DATASET_ROOT / \"DEAM\"\n",
|
||||||
|
"DEAM_ROOT.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# 2. Загрузка через kagglehub\n",
|
||||||
|
"print(\"Скачиваем датасет DEAM...\")\n",
|
||||||
|
"kaggle_cache_path = kagglehub.dataset_download(\"imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music\")\n",
|
||||||
|
"print(f\"Датасет скачан во временную директорию: {kaggle_cache_path}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# 3. Перемещение файлов в проект\n",
|
||||||
|
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
|
||||||
|
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n[УСПЕХ] Датасет DEAM готов к работе!\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python (my-python-project)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "my-python-project"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class EmoSet(Dataset):
    """Image dataset that yields an image with its emotion and attribute labels.

    The constructor reads two JSON files from ``data_root``: ``info.json``
    (label vocabularies, one ``label2idx`` mapping per attribute) and
    ``{phase}.json`` (the sample index). Every index entry is consumed as
    ``[emotion label, image id, image path, annotation path]``; both paths
    are joined onto ``data_root``.
    """

    # Attributes encoded as a single label index per sample (-1 when absent).
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes encoded as a multi-hot vector per sample.
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Class count per attribute; within this class only 'object' is looked up,
    # to size the multi-hot vector built in __getitem__.
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load the label vocabularies and the sample index for one split.

        Args:
            data_root: Directory holding ``info.json`` and ``{phase}.json``.
            num_emotion_classes: 8 (labels as stored in ``info.json``) or
                2 (emotions collapsed to positive/negative).
            phase: One of ``'train'``, ``'val'``, ``'test'``.

        Raises:
            AssertionError: If ``num_emotion_classes`` or ``phase`` is invalid.
            NotImplementedError: Same condition for ``phase`` when assertions
                are disabled.
        """
        assert num_emotion_classes in (8, 2)
        assert phase in ('train', 'val', 'test')
        self.transforms_dict = self.get_data_transforms()

        self.info = self.get_info(data_root, num_emotion_classes)

        # Kept as an explicit check so an unknown phase still fails the same
        # way when Python runs with assertions stripped (-O).
        if phase not in ('train', 'val', 'test'):
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        # Binary mode lets the json module detect the encoding itself instead
        # of relying on the platform's locale; the context manager closes the
        # handle, which the previous bare open() never did.
        with open(os.path.join(data_root, f'{phase}.json'), 'rb') as index_file:
            data_store = json.load(index_file)

        emotion_label2idx = self.info['emotion']['label2idx']
        self.data_store = [
            [
                emotion_label2idx[item[0]],
                item[1],
                os.path.join(data_root, item[2]),
                os.path.join(data_root, item[3])
            ]
            for item in data_store
        ]

    @classmethod
    def get_data_transforms(cls):
        """Return the torchvision pipelines keyed by ``'train'``/``'val'``/``'test'``.

        Training uses a random resized crop plus horizontal flip; validation
        and test use a resize followed by a centre crop. All three end with
        tensor conversion and the same per-channel normalisation.
        """
        transforms_dict = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'test': transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
        return transforms_dict

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` and, for the 2-class setup, replace its emotion map.

        Args:
            data_root: Directory containing ``info.json``.
            num_emotion_classes: 8 keeps the file's ``'emotion'`` entry as is;
                2 overwrites it with a positive (0) / negative (1) mapping.

        Returns:
            The parsed ``info.json`` dictionary, possibly with ``'emotion'``
            replaced.

        Raises:
            AssertionError: If ``num_emotion_classes`` is neither 8 nor 2.
            NotImplementedError: Same condition when assertions are disabled.
        """
        assert num_emotion_classes in (8, 2)
        with open(os.path.join(data_root, 'info.json'), 'rb') as info_file:
            info = json.load(info_file)
        if num_emotion_classes == 8:
            pass
        elif num_emotion_classes == 2:
            emotion_info = {
                'label2idx': {
                    'amusement': 0,
                    'awe': 0,
                    'contentment': 0,
                    'excitement': 0,
                    'anger': 1,
                    'disgust': 1,
                    'fear': 1,
                    'sadness': 1,
                },
                'idx2label': {
                    '0': 'positive',
                    '1': 'negative',
                }
            }
            info['emotion'] = emotion_info
        else:
            raise NotImplementedError

        return info

    def load_image_by_path(self, path):
        """Open the image at ``path``, convert it to RGB and apply ``self.transform``."""
        image = Image.open(path).convert('RGB')
        image = self.transform(image)
        return image

    def load_annotation_by_path(self, path):
        """Parse and return the JSON annotation stored at ``path``.

        Called once per sample from ``__getitem__``, so the file is closed
        deterministically rather than left to the garbage collector.
        """
        with open(path, 'rb') as annotation_file:
            json_data = json.load(annotation_file)
        return json_data

    def __getitem__(self, item):
        """Return one sample as a dictionary.

        Keys: ``'image_id'``, ``'image'`` (transformed image),
        ``'emotion_label_idx'``, one ``'<attribute>_label_idx'`` integer per
        entry of ``ATTRIBUTES_MULTI_CLASS`` and one
        ``'<attribute>_label_idx'`` multi-hot tensor per entry of
        ``ATTRIBUTES_MULTI_LABEL``.
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            # if empty, set to -1, else set to label index
            attribute_label_idx = -1
            if attribute in annotation_data:
                # Vocabulary keys are strings, so the annotation value is
                # stringified before the lookup.
                attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            data.update({f'{attribute}_label_idx': attribute_label_idx})

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            # if empty, set to 0, else set to 1
            assert attribute == 'object'
            num_classes = self.NUM_CLASSES[attribute]
            attribute_label_idx = torch.zeros(num_classes)
            if attribute in annotation_data:
                for label in annotation_data[attribute]:
                    attribute_label_idx[self.info[attribute]['label2idx'][label]] = 1
            data.update({f'{attribute}_label_idx': attribute_label_idx})

        return data

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data_store)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke test: build the 8-class training split and walk through
    # every batch once so that any loading error surfaces.
    dataset = EmoSet(
        data_root=r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val',
        num_emotion_classes=8,
        phase='train',
    )
    # print(dataset.info)

    dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
    for data in dataloader:
        # Uncomment any of the lines below to inspect a batch:
        # print(data['emotion_label_idx'])
        # print(data['scene_label_idx'])
        # print(data['facial_expression_label_idx'])
        # print(data['human_action_label_idx'])
        # print(data['brightness_label_idx'])
        # print(data['colorfulness_label_idx'])
        # print(data['object_label_idx'])
        # break
        pass
|
||||||
|
|
||||||
File diff suppressed because one or more lines are too long
Binary file not shown.
|
Before Width: | Height: | Size: 313 KiB |
@@ -1,12 +0,0 @@
|
|||||||
|
|
||||||
==================================================
|
|
||||||
ТАБЛИЦА МЕТРИК ДЛЯ СЛАЙДА 10
|
|
||||||
==================================================
|
|
||||||
| Метрика | Valence (V) | Arousal (A) | Общая (Total) |
|
|
||||||
|------------|--------------|--------------|---------------|
|
|
||||||
| MSE | 1.5135 | 2.2743 | 1.8939 |
|
|
||||||
| R² | 0.7927 | 0.4321 | 0.6124 |
|
|
||||||
==================================================
|
|
||||||
|
|
||||||
Формула целевой функции для вставки на слайд (LaTeX):
|
|
||||||
$$Score_{final} = D_{emo} + 4.0 \cdot Acoustic_{penalty}$$
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 243 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB |
@@ -0,0 +1,88 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "b92e0213",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "1763c51e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
|
||||||
|
"Всего треков в базе: 1744\n",
|
||||||
|
"Пример данных:\n",
|
||||||
|
" song_id valence arousal\n",
|
||||||
|
"0 2 3.1 3.0\n",
|
||||||
|
"1 3 3.5 3.3\n",
|
||||||
|
"2 4 5.7 5.5\n",
|
||||||
|
"3 5 4.4 5.3\n",
|
||||||
|
"4 7 5.8 6.4\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Точный путь к оригинальным аннотациям\n",
|
||||||
|
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
|
||||||
|
"# Путь, куда сохраним очищенную базу для движка\n",
|
||||||
|
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
|
||||||
|
"\n",
|
||||||
|
"if not source_path.exists():\n",
|
||||||
|
" print(f\"❌ Исходный файл не найден по пути: {source_path}\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
|
||||||
|
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Берем только нужные колонки (по твоему примеру)\n",
|
||||||
|
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
|
||||||
|
" \n",
|
||||||
|
" # Переименовываем для простоты кода в движке\n",
|
||||||
|
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
|
||||||
|
" \n",
|
||||||
|
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
|
||||||
|
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
|
||||||
|
" \n",
|
||||||
|
" # Сохраняем финальный файл\n",
|
||||||
|
" clean_df.to_csv(output_path, index=False)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"✅ УСПЕХ! База создана: {output_path}\")\n",
|
||||||
|
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
|
||||||
|
" print(\"Пример данных:\")\n",
|
||||||
|
" print(clean_df.head())"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python (thesis)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "thesis"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "d70d8e32",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from concurrent.futures import ProcessPoolExecutor\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torchvision import transforms\n",
|
||||||
|
"from tqdm import tqdm"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "31b0fa82",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||||
|
"TRANSFORM = transforms.Compose([\n",
|
||||||
|
" transforms.Resize((224,224)),\n",
|
||||||
|
" transforms.ToTensor(),\n",
|
||||||
|
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
|
||||||
|
"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "1a17ecf5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "PicklingError",
|
||||||
|
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||||
|
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
|
||||||
|
"\nThe above exception was the direct cause of the following exception:\n",
|
||||||
|
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
|
||||||
|
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def process_row(row, split_dir, tensor_dir):\n",
|
||||||
|
" img_path = split_dir / row.filename\n",
|
||||||
|
" img = Image.open(img_path).convert(\"RGB\")\n",
|
||||||
|
" tensor = TRANSFORM(img)\n",
|
||||||
|
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
|
||||||
|
" torch.save(tensor, tensor_path)\n",
|
||||||
|
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
|
||||||
|
"\n",
|
||||||
|
"for split in [\"train\",\"val\",\"test\"]:\n",
|
||||||
|
" split_dir = DATA_ROOT / split / \"images\"\n",
|
||||||
|
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
|
||||||
|
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
|
||||||
|
"\n",
|
||||||
|
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
|
||||||
|
"\n",
|
||||||
|
" results = []\n",
|
||||||
|
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
|
||||||
|
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
|
||||||
|
" for f in tqdm(futures):\n",
|
||||||
|
" results.append(f.result())\n",
|
||||||
|
"\n",
|
||||||
|
" new_df = pd.DataFrame(results)\n",
|
||||||
|
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "thesis-py3.11",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -1,319 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image, ImageFile
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
from torch.amp import autocast, GradScaler
|
|
||||||
import timm
|
|
||||||
|
|
||||||
# Silence all warnings and let PIL decode JPEGs whose tail is truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Устройство: {DEVICE}")

# --- PATHS ---
# Training images of the 2.41M set.
TRAIN_ROOT = Path("./dataset/Original-2.41M")
# Anchor set: clean EmoSet-118K training data.
ANCHOR_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/train")
VAL_118K_ROOT = Path("./NFS/Thesis/Emoset/EmoSet-118K/val")

PRETRAINED_PATH = Path("./src/emosetV2_resnet50_best.pth")
RESUME_CHECKPOINT = Path("./src/finetuneV2_resume.pth")
SAVE_MODEL_PATH = Path("./src/emosetV2_resnet50_finetuned_2_41M.pth")

# Emotion name -> class index; the names are listed in index order.
CLASS_MAPPING = {
    emotion: index
    for index, emotion in enumerate((
        "amusement", "anger", "awe", "contentment",
        "disgust", "excitement", "fear", "sadness",
    ))
}

# --- SETTINGS ---
BATCH_NOISY = 48   # 75% of a batch: new 2.41M data
BATCH_ANCHOR = 16  # 25% of a batch: clean 118K anchor data
TOTAL_BATCH_SIZE = 64

EPOCHS_PER_FOLDER = 15
PATIENCE = 5
LR = 1e-6
NUM_TRAIN_WORKERS = 32
NUM_VAL_WORKERS = 32
|
|
||||||
|
|
||||||
def worker_init_fn(worker_id):
    """Reseed NumPy's global RNG with a value offset by ``worker_id``.

    The first word of the global generator's state is taken as the base
    seed and ``worker_id`` is added to it, so different ids end up with
    different seeds. Presumably passed to ``DataLoader`` as
    ``worker_init_fn`` -- the call site is outside this view.

    Args:
        worker_id: Integer offset added to the base seed.
    """
    # Cast to a Python int first: adding to the NumPy uint32 scalar either
    # overflows or promotes, depending on the NumPy version.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap instead of
    # failing when the sum crosses that bound.
    np.random.seed((base_seed + worker_id) % 2**32)
|
|
||||||
|
|
||||||
# --- 1. ТРАНСФОРМАЦИИ ---
|
|
||||||
def _to_normalized_tensor():
    """Return the two closing steps shared by both pipelines."""
    return [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training: resize, then a random crop covering 80-100% of the area and a
# random horizontal flip.
train_transform = T.Compose(
    [
        T.Resize(256),
        T.RandomResizedCrop(224, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
    ]
    + _to_normalized_tensor()
)

# Validation: resize and centre crop, with no random steps.
val_transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
    ]
    + _to_normalized_tensor()
)
|
|
||||||
|
|
||||||
# --- 2. ДАТАСЕТЫ ---
|
|
||||||
class ChunkTrainDataset(Dataset):
    """Dataset over an explicit sequence of image paths.

    The label of a sample is derived from its path: the third-from-last
    path component, lower-cased, is looked up in ``CLASS_MAPPING``
    (names that are not in the mapping fall back to class 0).

    Args:
        paths: Indexable collection of path objects exposing ``.parts``.
        transform: Callable applied to the RGB ``PIL`` image.
    """

    def __init__(self, paths, transform):
        self.transform = transform
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for the sample at ``idx``."""
        path = self.paths[idx]
        try:
            image = Image.open(path)
            sample = self.transform(image.convert('RGB'))
            emotion = path.parts[-3].lower()
            return sample, CLASS_MAPPING.get(emotion, 0)
        except Exception:
            # NOTE(review): any failure (unreadable file, transform error,
            # malformed path) is replaced by an all-zero image with label 0,
            # i.e. the same index as "amusement". This keeps the loader
            # running but feeds mislabeled samples to the model.
            return torch.zeros((3, 224, 224)), 0
|
|
||||||
|
|
||||||
class CsvDataset(Dataset):
    """Dataset described by a ``labels.csv`` file inside ``root``.

    Each CSV row supplies a ``filename`` (resolved under ``root/images``)
    and a ``label`` whose lower-cased value is looked up in
    ``CLASS_MAPPING`` (names that are not in the mapping fall back to 0).

    Args:
        root: Directory that contains ``labels.csv`` and ``images/``.
        transform: Callable applied to the RGB ``PIL`` image.
    """

    def __init__(self, root, transform):
        self.root = Path(root)
        self.transform = transform
        csv_path = self.root / "labels.csv"
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for the CSV row at position ``idx``."""
        row = self.df.iloc[idx]
        path = self.root / "images" / row["filename"]
        try:
            image = Image.open(path)
            sample = self.transform(image.convert('RGB'))
            emotion = row["label"].lower()
            return sample, CLASS_MAPPING.get(emotion, 0)
        except Exception:
            # NOTE(review): any failure is replaced by an all-zero image with
            # label 0 (the index of "amusement"); for the validation split
            # this silently affects the reported loss and accuracy.
            return torch.zeros((3, 224, 224)), 0
|
|
||||||
|
|
||||||
# --- 3. СБОР ДАННЫХ ---
|
|
||||||
def prepare_chunks():
    """Group the ``*.jpg`` files under ``TRAIN_ROOT`` by numbered folder.

    A file is kept when its grandparent directory name (lower-cased) is a
    key of ``CLASS_MAPPING`` and its parent directory name consists only
    of digits; the integer value of that parent name is the chunk id.

    Returns:
        tuple: ``(chunk_dict, sorted_chunks)`` where ``chunk_dict`` maps a
        chunk id to the list of matching paths and ``sorted_chunks`` is
        the ascending list of chunk ids.
    """
    print("\nСканирование датасета 2.41M...")
    chunk_dict = defaultdict(list)

    for image_path in TRAIN_ROOT.rglob('*.jpg'):
        parts = image_path.parts
        is_known_emotion = parts[-3].lower() in CLASS_MAPPING
        if is_known_emotion and parts[-2].isdigit():
            chunk_dict[int(parts[-2])].append(image_path)

    sorted_chunks = sorted(chunk_dict)
    print(f"Найдено пронумерованных папок (чанков): {len(sorted_chunks)}")
    return chunk_dict, sorted_chunks
|
|
||||||
# --- 4. ОСНОВНОЙ ЦИКЛ ОБУЧЕНИЯ ---
|
|
||||||
def save_atomically(obj, path):
    """Write ``obj`` with ``torch.save`` without ever exposing a partial file.

    The data is written to ``<path>.tmp``, flushed and fsynced, and only
    then renamed over ``path``. If the process dies mid-write, the previous
    file at ``path`` is still intact.

    Args:
        obj: Any object accepted by ``torch.save``.
        path: Destination file (``str`` or ``Path``).
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "wb") as fh:
        torch.save(obj, fh)
        fh.flush()
        os.fsync(fh.fileno())
    os.replace(tmp_path, path)


# --- Training script -------------------------------------------------------
# Fine-tunes a ResNet-50 stage by stage: every stage adds one numbered folder
# of the 2.41M set to the training pool, and every batch is a concatenation of
# BATCH_NOISY samples from that pool and BATCH_ANCHOR samples from the clean
# anchor set. Validation uses the clean 118K validation split only.
if __name__ == "__main__":
    chunk_dict, sorted_chunks = prepare_chunks()

    # Validation loader (clean data only).
    val_loader = DataLoader(
        CsvDataset(VAL_118K_ROOT, val_transform),
        batch_size=TOTAL_BATCH_SIZE, shuffle=False,
        num_workers=NUM_VAL_WORKERS, pin_memory=True
    )

    # Anchor loader (clean data that is mixed into every batch).
    # prefetch_factor is set; persistent_workers is left disabled here.
    anchor_dataset = CsvDataset(ANCHOR_118K_ROOT, train_transform)
    anchor_loader = DataLoader(
        anchor_dataset, batch_size=BATCH_ANCHOR, shuffle=True,
        num_workers=16, pin_memory=True, drop_last=True,
        prefetch_factor=2, persistent_workers=False
    )

    # Model initialisation. The weight files loaded below are produced
    # locally; torch.load unpickles them, so never point these paths at
    # untrusted files.
    model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
    if PRETRAINED_PATH.exists():
        model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))
        print(f"Базовые веса загружены из {PRETRAINED_PATH.name}")

    # Unfreeze the whole model.
    for param in model.parameters():
        param.requires_grad = True

    # Differential learning rates: parameters whose name contains "fc"
    # get a 10x larger step than everything else.
    backbone_params = [p for n, p in model.named_parameters() if "fc" not in n]
    fc_params = [p for n, p in model.named_parameters() if "fc" in n]

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': LR},   # 1e-6: micro step for the backbone
        {'params': fc_params, 'lr': LR * 10}     # 1e-5: step for the classifier
    ], weight_decay=1e-3)

    # Label smoothing is used to soften the effect of noisy 2.41M labels.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
    scaler = GradScaler()

    # --- RESUME PARAMETERS ---
    start_stage = 0
    start_epoch = 1
    best_val_loss = float('inf')
    resumed_no_improve = 0  # early-stopping counter restored from a checkpoint

    if RESUME_CHECKPOINT.exists():
        print(f"\nОбнаружен чекпоинт: {RESUME_CHECKPOINT.name}. Восстановление...")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])

        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            print(f"Оптимизатор сброшен: {e}")

        # Checkpoints written by older versions of this script have no
        # scaler state or early-stopping counter, hence the defaults.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state:
            scaler.load_state_dict(scaler_state)
        resumed_no_improve = checkpoint.get('epochs_no_improve', 0)

        best_val_loss = checkpoint['best_val_loss']
        start_stage = checkpoint['stage']
        start_epoch = checkpoint['epoch'] + 1
        print(f"Успешный запуск с ЭТАПА {start_stage + 1}, Эпохи {start_epoch}. Best Val Loss: {best_val_loss:.4f}\n")
    else:
        # --- EPOCH 0 MEASUREMENT (BASELINE ACCURACY) ---
        # Only executed when training starts from scratch; the result
        # becomes the initial best_val_loss that later epochs must beat.
        print("\n[Проверка базовых весов перед обучением (Epoch 0)]")
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Baseline Eval", smoothing=0):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                with autocast(device_type="cuda"):
                    outputs = model(inputs)
                    v_loss = criterion(outputs, labels)
                val_loss += v_loss.item() * inputs.size(0)
                _, pred = outputs.max(1)
                val_total += labels.size(0)
                val_correct += pred.eq(labels).sum().item()

        best_val_loss = val_loss / val_total
        baseline_acc = val_correct / val_total
        print(f"Стартовая точка -> Val Loss: {best_val_loss:.4f} | Val Acc: {baseline_acc:.4f}\n")

    # Rebuild the pool accumulated by the stages that were already finished.
    current_train_paths = []
    for s in range(start_stage):
        current_train_paths.extend(chunk_dict[sorted_chunks[s]])

    print("Старт Anchor Curriculum Learning (Смешивание чистых и шумных данных).")

    # MAIN LOOP OVER FOLDERS
    for stage in range(start_stage, len(sorted_chunks)):
        chunk_id = sorted_chunks[stage]
        print(f"\n{'='*50}")
        print(f"ЭТАП {stage+1}/{len(sorted_chunks)}: Добавляем папку '{chunk_id}'")

        # Accumulate and shuffle.
        current_train_paths.extend(chunk_dict[chunk_id])
        random.shuffle(current_train_paths)
        print(f"Всего файлов (грязных) в текущем пуле: {len(current_train_paths)}")

        # Only the stage being resumed inherits the saved epoch number and
        # early-stopping counter; every later stage starts fresh.
        is_resumed_stage = stage == start_stage
        epochs_no_improve = resumed_no_improve if is_resumed_stage else 0
        first_epoch = start_epoch if is_resumed_stage else 1

        if epochs_no_improve >= PATIENCE:
            # The checkpoint was written in the very epoch that triggered
            # early stopping for this stage, so do not train it again.
            print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
            continue

        # Main loader (noisy data) with prefetching.
        train_loader = DataLoader(
            ChunkTrainDataset(current_train_paths, train_transform),
            batch_size=BATCH_NOISY, shuffle=True,
            num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
            worker_init_fn=worker_init_fn, drop_last=True,
            prefetch_factor=4, persistent_workers=True
        )

        # Anchor iterator; restarted whenever the anchor loader is exhausted.
        anchor_iter = iter(anchor_loader)

        # EPOCH LOOP FOR THE CURRENT STAGE
        for epoch in range(first_epoch, EPOCHS_PER_FOLDER + 1):
            model.train()
            train_loss, train_correct, train_total = 0.0, 0, 0

            for noisy_inputs, noisy_labels in tqdm(train_loader, desc=f"S{stage+1}-Ep{epoch}/{EPOCHS_PER_FOLDER} [Train]", smoothing=0):

                # Fetch a clean anchor batch.
                try:
                    anc_inputs, anc_labels = next(anchor_iter)
                except StopIteration:
                    anchor_iter = iter(anchor_loader)
                    anc_inputs, anc_labels = next(anchor_iter)

                # Mix the batches (noisy + clean).
                inputs = torch.cat([noisy_inputs, anc_inputs]).to(DEVICE)
                labels = torch.cat([noisy_labels, anc_labels]).to(DEVICE)

                optimizer.zero_grad(set_to_none=True)
                with autocast(device_type="cuda"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item() * inputs.size(0)
                _, pred = outputs.max(1)
                train_total += labels.size(0)
                train_correct += pred.eq(labels).sum().item()

            # VALIDATION
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for inputs, labels in tqdm(val_loader, desc="[Val]", leave=False, smoothing=0):
                    inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        v_loss = criterion(outputs, labels)
                    val_loss += v_loss.item() * inputs.size(0)
                    _, pred = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += pred.eq(labels).sum().item()

            avg_train_loss = train_loss / train_total
            avg_train_acc = train_correct / train_total
            avg_val_loss = val_loss / val_total
            avg_val_acc = val_correct / val_total

            print(f"S{stage+1}-E{epoch} | Train L: {avg_train_loss:.4f}, Acc: {avg_train_acc:.4f} | Val L: {avg_val_loss:.4f}, Acc: {avg_val_acc:.4f}")

            # SAVE THE BEST WEIGHTS
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                save_atomically(model.state_dict(), SAVE_MODEL_PATH)
                print("--> Обновлены лучшие веса")
            else:
                epochs_no_improve += 1

            # EMERGENCY CHECKPOINT AT THE END OF EVERY EPOCH
            checkpoint_state = {
                'stage': stage,
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss,
                'epochs_no_improve': epochs_no_improve
            }
            # Written to a temp file and renamed, so a power loss during
            # the write cannot corrupt the previous checkpoint.
            save_atomically(checkpoint_state, RESUME_CHECKPOINT)
            os.sync()  # Protection against a power outage
            print(f"--> Чекпоинт (Этап {stage+1}, Эпоха {epoch}) зафиксирован на диске.")

            # EARLY STOPPING FOR THE CURRENT STAGE
            if epochs_no_improve >= PATIENCE:
                print(f"Ранняя остановка для ЭТАПА {stage+1}. Переход к следующей папке...")
                break

        # Reset the starting epoch once the resumed stage has been processed.
        start_epoch = 1
|
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "ca08df84",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Using device: cuda\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Step 0/1000, Loss: 1.0013\n",
|
||||||
|
"Step 10/1000, Loss: 1.0088\n",
|
||||||
|
"Step 20/1000, Loss: 0.9956\n",
|
||||||
|
"Step 30/1000, Loss: 0.9781\n",
|
||||||
|
"Step 40/1000, Loss: 0.9613\n",
|
||||||
|
"Step 50/1000, Loss: 0.9313\n",
|
||||||
|
"Step 60/1000, Loss: 0.8927\n",
|
||||||
|
"Step 70/1000, Loss: 0.8503\n",
|
||||||
|
"Step 80/1000, Loss: 0.7537\n",
|
||||||
|
"Step 90/1000, Loss: 0.6689\n",
|
||||||
|
"Step 100/1000, Loss: 0.6063\n",
|
||||||
|
"Step 110/1000, Loss: 0.5172\n",
|
||||||
|
"Step 120/1000, Loss: 0.4592\n",
|
||||||
|
"Step 130/1000, Loss: 0.4044\n",
|
||||||
|
"Step 140/1000, Loss: 0.3610\n",
|
||||||
|
"Step 150/1000, Loss: 0.3175\n",
|
||||||
|
"Step 160/1000, Loss: 0.2825\n",
|
||||||
|
"Step 170/1000, Loss: 0.2560\n",
|
||||||
|
"Step 180/1000, Loss: 0.2360\n",
|
||||||
|
"Step 190/1000, Loss: 0.2203\n",
|
||||||
|
"Step 200/1000, Loss: 0.1930\n",
|
||||||
|
"Step 210/1000, Loss: 0.1854\n",
|
||||||
|
"Step 220/1000, Loss: 0.1723\n",
|
||||||
|
"Step 230/1000, Loss: 0.1546\n",
|
||||||
|
"Step 240/1000, Loss: 0.1386\n",
|
||||||
|
"Step 250/1000, Loss: 0.1271\n",
|
||||||
|
"Step 260/1000, Loss: 0.1109\n",
|
||||||
|
"Step 270/1000, Loss: 0.1032\n",
|
||||||
|
"Step 280/1000, Loss: 0.0899\n",
|
||||||
|
"Step 290/1000, Loss: 0.0807\n",
|
||||||
|
"Step 300/1000, Loss: 0.0750\n",
|
||||||
|
"Step 310/1000, Loss: 0.0813\n",
|
||||||
|
"Step 320/1000, Loss: 0.0612\n",
|
||||||
|
"Step 330/1000, Loss: 0.0544\n",
|
||||||
|
"Step 340/1000, Loss: 0.0552\n",
|
||||||
|
"Step 350/1000, Loss: 0.0446\n",
|
||||||
|
"Step 360/1000, Loss: 0.0403\n",
|
||||||
|
"Step 370/1000, Loss: 0.0350\n",
|
||||||
|
"Step 380/1000, Loss: 0.0612\n",
|
||||||
|
"Step 390/1000, Loss: 0.0364\n",
|
||||||
|
"Step 400/1000, Loss: 0.0322\n",
|
||||||
|
"Step 410/1000, Loss: 0.0302\n",
|
||||||
|
"Step 420/1000, Loss: 0.0519\n",
|
||||||
|
"Step 430/1000, Loss: 0.0319\n",
|
||||||
|
"Step 440/1000, Loss: 0.0260\n",
|
||||||
|
"Step 450/1000, Loss: 0.0208\n",
|
||||||
|
"Step 460/1000, Loss: 0.0409\n",
|
||||||
|
"Step 470/1000, Loss: 0.0291\n",
|
||||||
|
"Step 480/1000, Loss: 0.0234\n",
|
||||||
|
"Step 490/1000, Loss: 0.0194\n",
|
||||||
|
"Step 500/1000, Loss: 0.0274\n",
|
||||||
|
"Step 510/1000, Loss: 0.0231\n",
|
||||||
|
"Step 520/1000, Loss: 0.0199\n",
|
||||||
|
"Step 530/1000, Loss: 0.0154\n",
|
||||||
|
"Step 540/1000, Loss: 0.0278\n",
|
||||||
|
"Step 550/1000, Loss: 0.0185\n",
|
||||||
|
"Step 560/1000, Loss: 0.0180\n",
|
||||||
|
"Step 570/1000, Loss: 0.0152\n",
|
||||||
|
"Step 580/1000, Loss: 0.0132\n",
|
||||||
|
"Step 590/1000, Loss: 0.0111\n",
|
||||||
|
"Step 600/1000, Loss: 0.0396\n",
|
||||||
|
"Step 610/1000, Loss: 0.0179\n",
|
||||||
|
"Step 620/1000, Loss: 0.0148\n",
|
||||||
|
"Step 630/1000, Loss: 0.0123\n",
|
||||||
|
"Step 640/1000, Loss: 0.0265\n",
|
||||||
|
"Step 650/1000, Loss: 0.0133\n",
|
||||||
|
"Step 660/1000, Loss: 0.0128\n",
|
||||||
|
"Step 670/1000, Loss: 0.0107\n",
|
||||||
|
"Step 680/1000, Loss: 0.0142\n",
|
||||||
|
"Step 690/1000, Loss: 0.0202\n",
|
||||||
|
"Step 700/1000, Loss: 0.0125\n",
|
||||||
|
"Step 710/1000, Loss: 0.0107\n",
|
||||||
|
"Step 720/1000, Loss: 0.0140\n",
|
||||||
|
"Step 730/1000, Loss: 0.0195\n",
|
||||||
|
"Step 740/1000, Loss: 0.0148\n",
|
||||||
|
"Step 750/1000, Loss: 0.0109\n",
|
||||||
|
"Step 760/1000, Loss: 0.0094\n",
|
||||||
|
"Step 770/1000, Loss: 0.0121\n",
|
||||||
|
"Step 780/1000, Loss: 0.0233\n",
|
||||||
|
"Step 790/1000, Loss: 0.0151\n",
|
||||||
|
"Step 800/1000, Loss: 0.0134\n",
|
||||||
|
"Step 810/1000, Loss: 0.0117\n",
|
||||||
|
"Step 820/1000, Loss: 0.0124\n",
|
||||||
|
"Step 830/1000, Loss: 0.0221\n",
|
||||||
|
"Step 840/1000, Loss: 0.0161\n",
|
||||||
|
"Step 850/1000, Loss: 0.0136\n",
|
||||||
|
"Step 860/1000, Loss: 0.0161\n",
|
||||||
|
"Step 870/1000, Loss: 0.0194\n",
|
||||||
|
"Step 880/1000, Loss: 0.0145\n",
|
||||||
|
"Step 890/1000, Loss: 0.0149\n",
|
||||||
|
"Step 900/1000, Loss: 0.0232\n",
|
||||||
|
"Step 910/1000, Loss: 0.0166\n",
|
||||||
|
"Step 920/1000, Loss: 0.0156\n",
|
||||||
|
"Step 930/1000, Loss: 0.0276\n",
|
||||||
|
"Step 940/1000, Loss: 0.0176\n",
|
||||||
|
"Step 950/1000, Loss: 0.0152\n",
|
||||||
|
"Step 960/1000, Loss: 0.0162\n",
|
||||||
|
"Step 970/1000, Loss: 0.0143\n",
|
||||||
|
"Step 980/1000, Loss: 0.0136\n",
|
||||||
|
"Step 990/1000, Loss: 0.0117\n",
|
||||||
|
"Total time: 67.25 s\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"import time\n",
|
||||||
|
"\n",
|
||||||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
|
"print(\"Using device:\", device)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Огромные параметры\n",
|
||||||
|
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
|
||||||
|
"batch_size = 16_384 # большой батч\n",
|
||||||
|
"steps = 1000 # много итераций для длительной нагрузки\n",
|
||||||
|
"\n",
|
||||||
|
"# Случайные данные на GPU\n",
|
||||||
|
"x = torch.randn(N, D_in, device=device, dtype=torch.float32)\n",
|
||||||
|
"y = torch.randn(N, D_out, device=device, dtype=torch.float32)\n",
|
||||||
|
"\n",
|
||||||
|
"model = nn.Sequential(\n",
|
||||||
|
" nn.Linear(D_in, H1),\n",
|
||||||
|
" nn.ReLU(),\n",
|
||||||
|
" nn.Linear(H1, H2),\n",
|
||||||
|
" nn.ReLU(),\n",
|
||||||
|
" nn.Linear(H2, H3),\n",
|
||||||
|
" nn.ReLU(),\n",
|
||||||
|
" nn.Linear(H3, D_out)\n",
|
||||||
|
").to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"loss_fn = nn.MSELoss()\n",
|
||||||
|
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
|
||||||
|
"\n",
|
||||||
|
"start = time.time()\n",
|
||||||
|
"for t in range(steps):\n",
|
||||||
|
" idx = torch.randint(0, N, (batch_size,), device=device)\n",
|
||||||
|
" x_batch = x[idx]\n",
|
||||||
|
" y_batch = y[idx]\n",
|
||||||
|
"\n",
|
||||||
|
" y_pred = model(x_batch)\n",
|
||||||
|
" loss = loss_fn(y_pred, y_batch)\n",
|
||||||
|
"\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" if t % 10 == 0:\n",
|
||||||
|
" # замедляем вывод, чтобы можно было наблюдать\n",
|
||||||
|
" print(f\"Step {t}/{steps}, Loss: {loss.item():.4f}\")\n",
|
||||||
|
"\n",
|
||||||
|
"end = time.time()\n",
|
||||||
|
"print(f\"Total time: {end-start:.2f} s\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -0,0 +1,759 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9336560f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "0c00b67b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"import torchvision.transforms as T\n",
|
||||||
|
"\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"import timm\n",
|
||||||
|
"import numpy as np\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "84c3657f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'cuda'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# === CONFIG ===\n",
|
||||||
|
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||||
|
"BATCH_SIZE = 64 # V100 спокойно тянет\n",
|
||||||
|
"EPOCHS = 15\n",
|
||||||
|
"LR = 3e-4\n",
|
||||||
|
"NUM_WORKERS = 24\n",
|
||||||
|
"\n",
|
||||||
|
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
|
"DEVICE\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "9f749add",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class EmoSetDataset(Dataset):\n",
|
||||||
|
" def __init__(self, root, split):\n",
|
||||||
|
" self.root = Path(root) / split\n",
|
||||||
|
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
||||||
|
"\n",
|
||||||
|
" self.labels = sorted(self.df[\"label\"].unique())\n",
|
||||||
|
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
||||||
|
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
||||||
|
"\n",
|
||||||
|
" self.transform = T.Compose([\n",
|
||||||
|
" T.Resize((224, 224)),\n",
|
||||||
|
" T.ToTensor(),\n",
|
||||||
|
" T.Normalize(\n",
|
||||||
|
" mean=[0.485, 0.456, 0.406],\n",
|
||||||
|
" std=[0.229, 0.224, 0.225]\n",
|
||||||
|
" )\n",
|
||||||
|
" ])\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.df)\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" row = self.df.iloc[idx]\n",
|
||||||
|
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
||||||
|
"\n",
|
||||||
|
" img = Image.open(img_path).convert(\"RGB\")\n",
|
||||||
|
" img = self.transform(img)\n",
|
||||||
|
"\n",
|
||||||
|
" label = self.label2idx[row[\"label\"]]\n",
|
||||||
|
" return img, label\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "c8805341",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Classes: ['amusement', 'anger', 'awe', 'contentment', 'disgust', 'excitement', 'fear', 'sadness']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
|
||||||
|
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(\n",
|
||||||
|
" train_ds,\n",
|
||||||
|
" batch_size=BATCH_SIZE,\n",
|
||||||
|
" shuffle=True,\n",
|
||||||
|
" num_workers=NUM_WORKERS,\n",
|
||||||
|
" pin_memory=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"val_loader = DataLoader(\n",
|
||||||
|
" val_ds,\n",
|
||||||
|
" batch_size=BATCH_SIZE,\n",
|
||||||
|
" shuffle=False,\n",
|
||||||
|
" num_workers=NUM_WORKERS,\n",
|
||||||
|
" pin_memory=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Classes:\", train_ds.labels)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "dffce582",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"ResNet(\n",
|
||||||
|
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||||||
|
" (layer1): Sequential(\n",
|
||||||
|
" (0): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" (downsample): Sequential(\n",
|
||||||
|
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (1): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (2): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (layer2): Sequential(\n",
|
||||||
|
" (0): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" (downsample): Sequential(\n",
|
||||||
|
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||||
|
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (1): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (2): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (3): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (layer3): Sequential(\n",
|
||||||
|
" (0): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" (downsample): Sequential(\n",
|
||||||
|
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||||
|
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (1): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (2): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (3): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (4): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (5): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (layer4): Sequential(\n",
|
||||||
|
" (0): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" (downsample): Sequential(\n",
|
||||||
|
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||||
|
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (1): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" (2): Bottleneck(\n",
|
||||||
|
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act1): ReLU(inplace=True)\n",
|
||||||
|
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||||
|
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (drop_block): Identity()\n",
|
||||||
|
" (act2): ReLU(inplace=True)\n",
|
||||||
|
" (aa): Identity()\n",
|
||||||
|
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||||
|
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||||
|
" (act3): ReLU(inplace=True)\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
|
||||||
|
" (fc): Linear(in_features=2048, out_features=8, bias=True)\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model = timm.create_model(\n",
|
||||||
|
" \"resnet50\",\n",
|
||||||
|
" pretrained=True,\n",
|
||||||
|
" num_classes=len(train_ds.labels)\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"model.to(DEVICE)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "81a457ef",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"\n",
|
||||||
|
"optimizer = torch.optim.AdamW(\n",
|
||||||
|
" model.parameters(),\n",
|
||||||
|
" lr=LR,\n",
|
||||||
|
" weight_decay=1e-4\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
|
||||||
|
" optimizer,\n",
|
||||||
|
" T_max=EPOCHS\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "951aa9e3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def train_epoch(model, loader):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" correct = 0\n",
|
||||||
|
" total = 0\n",
|
||||||
|
"\n",
|
||||||
|
" for imgs, labels in tqdm(loader, leave=False):\n",
|
||||||
|
" imgs = imgs.to(DEVICE)\n",
|
||||||
|
" labels = labels.to(DEVICE)\n",
|
||||||
|
"\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" logits = model(imgs)\n",
|
||||||
|
" loss = criterion(logits, labels)\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" total_loss += loss.item() * imgs.size(0)\n",
|
||||||
|
" preds = logits.argmax(dim=1)\n",
|
||||||
|
" correct += (preds == labels).sum().item()\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
" return total_loss / total, correct / total\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "fb7e9398",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"@torch.no_grad()\n",
|
||||||
|
"def val_epoch(model, loader):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" correct = 0\n",
|
||||||
|
" total = 0\n",
|
||||||
|
"\n",
|
||||||
|
" for imgs, labels in loader:\n",
|
||||||
|
" imgs = imgs.to(DEVICE)\n",
|
||||||
|
" labels = labels.to(DEVICE)\n",
|
||||||
|
"\n",
|
||||||
|
" logits = model(imgs)\n",
|
||||||
|
" loss = criterion(logits, labels)\n",
|
||||||
|
"\n",
|
||||||
|
" total_loss += loss.item() * imgs.size(0)\n",
|
||||||
|
" preds = logits.argmax(dim=1)\n",
|
||||||
|
" correct += (preds == labels).sum().item()\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
" return total_loss / total, correct / total\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "9e870e5d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/1477 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" \r"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"best_val_acc = 0.0\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(1, EPOCHS + 1):\n",
|
||||||
|
" train_loss, train_acc = train_epoch(model, train_loader)\n",
|
||||||
|
" val_loss, val_acc = val_epoch(model, val_loader)\n",
|
||||||
|
"\n",
|
||||||
|
" scheduler.step()\n",
|
||||||
|
"\n",
|
||||||
|
" print(\n",
|
||||||
|
" f\"Epoch {epoch:02d} | \"\n",
|
||||||
|
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
|
||||||
|
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if val_acc > best_val_acc:\n",
|
||||||
|
" best_val_acc = val_acc\n",
|
||||||
|
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7796ef11",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "thesis-py3.11",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from sklearn.linear_model import RidgeCV
|
||||||
|
from sklearn.multioutput import MultiOutputRegressor
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import mean_squared_error, r2_score
|
||||||
|
import joblib
|
||||||
|
|
||||||
|
# 1. EmoSet class index (classes in alphabetical order) -> (valence, arousal) target pair
|
||||||
|
EMO_VA_MAP = {
|
||||||
|
0: (7.5, 6.5), # amusement
|
||||||
|
1: (2.0, 8.0), # anger
|
||||||
|
2: (6.5, 5.0), # awe
|
||||||
|
3: (7.0, 3.0), # contentment
|
||||||
|
4: (3.0, 6.0), # disgust
|
||||||
|
5: (8.0, 8.0), # excitement
|
||||||
|
6: (2.5, 7.5), # fear
|
||||||
|
7: (2.0, 2.0), # sadness
|
||||||
|
}
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||||
|
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
|
||||||
|
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"
|
||||||
|
|
||||||
|
print("Загрузка данных...")
|
||||||
|
X = np.load(EMBEDDINGS_PATH)
|
||||||
|
y_labels = np.load(LABELS_PATH)
|
||||||
|
|
||||||
|
y_va = np.array([EMO_VA_MAP[label] for label in y_labels])
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y_va, test_size=0.2, random_state=42)
|
||||||
|
|
||||||
|
# 2. Model: a single Pipeline (StandardScaler followed by RidgeCV wrapped in MultiOutputRegressor)
|
||||||
|
print("Обучение масштабатора и RidgeCV регрессора...")
|
||||||
|
# The Pipeline ensures that new vectors are scaled the same way when main.py runs predictions
|
||||||
|
model = Pipeline([
|
||||||
|
('scaler', StandardScaler()),
|
||||||
|
('regressor', MultiOutputRegressor(RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])))
|
||||||
|
])
|
||||||
|
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
# 3. Diagnostics and evaluation on the held-out 20% split
|
||||||
|
y_pred = model.predict(X_test)
|
||||||
|
|
||||||
|
mse = mean_squared_error(y_test, y_pred)
|
||||||
|
r2 = r2_score(y_test, y_pred)
|
||||||
|
|
||||||
|
print(f"\n[УСПЕХ] Обучение завершено!")
|
||||||
|
print(f"MSE: {mse:.4f}")
|
||||||
|
print(f"R^2 Score: {r2:.4f}")
|
||||||
|
|
||||||
|
# === Collapse check: print the min/max of the predictions next to the 2.0 - 8.0 range of the targets ===
|
||||||
|
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
|
||||||
|
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||||
|
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")
|
||||||
|
# ===============================================
|
||||||
|
|
||||||
|
# 4. Persistence (the whole Pipeline, including the StandardScaler, is dumped as one object)
|
||||||
|
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||||
|
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
joblib.dump(model, output_model_path)
|
||||||
|
print(f"\nМодель сохранена в: {output_model_path}")
|
||||||
@@ -42,7 +42,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
|||||||
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
st.session_state.ds_current_options = random.sample(range(len(image_files)), 6)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.success("Анализ завершен! Ваш эмоциональный профиль готов.")
|
st.success("✅ Анализ завершен! Ваш эмоциональный профиль готов.")
|
||||||
|
|
||||||
all_v, all_a = [], []
|
all_v, all_a = [], []
|
||||||
for idx in st.session_state.ds_chosen_indices:
|
for idx in st.session_state.ds_chosen_indices:
|
||||||
@@ -56,7 +56,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
|||||||
col_left, col_right = st.columns([1, 2])
|
col_left, col_right = st.columns([1, 2])
|
||||||
|
|
||||||
with col_left:
|
with col_left:
|
||||||
st.header("Ваш профиль")
|
st.header("📊 Ваш профиль")
|
||||||
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
st.metric("Позитивность (Valence)", f"{target_v:.2f}")
|
||||||
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
st.metric("Энергия (Arousal)", f"{target_a:.2f}")
|
||||||
|
|
||||||
@@ -74,8 +74,7 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
|||||||
c1, c2 = st.columns([1, 3])
|
c1, c2 = st.columns([1, 3])
|
||||||
with c1:
|
with c1:
|
||||||
st.write(f"**ID:** {int(row['song_id'])}")
|
st.write(f"**ID:** {int(row['song_id'])}")
|
||||||
score_val = row.get('final_score', row.get('emo_distance', 0))
|
st.caption(f"L2 Dist: {row['distance']:.2f}")
|
||||||
st.caption(f"Dist Score: {score_val:.2f}")
|
|
||||||
with c2:
|
with c2:
|
||||||
audio_path = matcher.get_audio_path(row['song_id'])
|
audio_path = matcher.get_audio_path(row['song_id'])
|
||||||
if audio_path:
|
if audio_path:
|
||||||
|
|||||||
+36
-162
@@ -1,162 +1,75 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import streamlit.components.v1 as components
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import base64
|
import matplotlib.pyplot as plt
|
||||||
from io import BytesIO
|
from music_engine.llm_bridge import LLMAcousticBridge # ИМПОРТИРУЕМ МОСТ
|
||||||
from music_engine.llm_bridge import LLMAcousticBridge
|
|
||||||
|
|
||||||
# Helper for a tiny preview: builds an HTML <div> of inline base64 JPEG thumbnails for the first max_display items of images, plus a "+N" badge when more remain; returns the HTML string
|
|
||||||
def get_thumbnail_html(images, max_display=12):
|
|
||||||
html_images = ""
|
|
||||||
for file in images[:max_display]:
|
|
||||||
img = Image.open(file)
|
|
||||||
img.thumbnail((100, 100)) # Shrink the image
|
|
||||||
if img.mode != "RGB": # presumably because the JPEG save below needs RGB input - confirm
|
|
||||||
img = img.convert("RGB")
|
|
||||||
|
|
||||||
buffered = BytesIO()
|
|
||||||
img.save(buffered, format="JPEG")
|
|
||||||
b64_str = base64.b64encode(buffered.getvalue()).decode()
|
|
||||||
|
|
||||||
# Fixed 60x60 box with object-fit: cover so every thumbnail is displayed as a square
|
|
||||||
html_images += f'<img src="data:image/jpeg;base64,{b64_str}" style="width: 60px; height: 60px; object-fit: cover; border-radius: 8px; margin-right: 8px; margin-bottom: 8px; border: 1px solid rgba(255, 255, 255, 0.2);">'
|
|
||||||
|
|
||||||
# Badge with the count of remaining photos when there are more than max_display
|
|
||||||
if len(images) > max_display:
|
|
||||||
html_images += f'<span style="display: inline-block; width: 60px; height: 60px; line-height: 60px; text-align: center; background: rgba(150, 150, 150, 0.2); border-radius: 8px; vertical-align: top; font-size: 14px;">+{len(images) - max_display}</span>'
|
|
||||||
|
|
||||||
return f'<div style="display: flex; flex-wrap: wrap;">{html_images}</div>'
|
|
||||||
|
|
||||||
def render_live_tab(matcher, image_processor):
|
def render_live_tab(matcher, image_processor):
|
||||||
if "live_state" not in st.session_state:
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
if "result_data" not in st.session_state:
|
|
||||||
st.session_state.result_data = None
|
|
||||||
|
|
||||||
viewport = st.query_params.get("viewport", "desktop")
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# CSS ИНЪЕКЦИИ
|
|
||||||
# ==========================================
|
|
||||||
st.markdown("""
|
|
||||||
<style>
|
|
||||||
[data-testid="stFileUploadDropzone"] {
|
|
||||||
min-height: 250px !important;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
border-radius: 16px;
|
|
||||||
background-color: rgba(255, 75, 75, 0.03);
|
|
||||||
}
|
|
||||||
.spinner-container {
|
|
||||||
display: flex; flex-direction: column; align-items: center;
|
|
||||||
justify-content: center; min-height: 40vh; margin-top: 10vh;
|
|
||||||
}
|
|
||||||
.big-spinner {
|
|
||||||
width: 120px; height: 120px; border: 10px solid rgba(255, 75, 75, 0.1);
|
|
||||||
border-top: 10px solid #ff4b4b; border-radius: 50%;
|
|
||||||
animation: spin 1s linear infinite; margin-bottom: 2rem;
|
|
||||||
}
|
|
||||||
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
|
|
||||||
</style>
|
|
||||||
""", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# ЭКРАН 1: ЗАГРУЗКА
|
|
||||||
# ==========================================
|
|
||||||
if st.session_state.live_state == "upload":
|
|
||||||
|
|
||||||
upload_placeholder = st.empty()
|
|
||||||
with upload_placeholder.container():
|
|
||||||
st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.")
|
st.write("Загрузите фотографии с вашего устройства. Система проанализирует эмоции и семантику кадра.")
|
||||||
|
|
||||||
if viewport == "mobile":
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
uploaded_files = st.file_uploader(
|
uploaded_files = st.file_uploader(
|
||||||
"Перетащите изображения сюда",
|
"Перетащите изображения сюда",
|
||||||
type=['png', 'jpg', 'jpeg'],
|
type=['png', 'jpg', 'jpeg'],
|
||||||
accept_multiple_files=True,
|
accept_multiple_files=True
|
||||||
label_visibility="collapsed" if viewport == "mobile" else "visible"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if uploaded_files:
|
if uploaded_files:
|
||||||
# 1. КНОПКА СРАЗУ ПОСЛЕ ЗАГРУЗКИ (Не нужно скроллить вниз)
|
st.subheader("Анализ визуальных признаков:")
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
if st.button("Сгенерировать саундтрек", type="primary", use_container_width=True):
|
|
||||||
st.session_state.uploaded_images = uploaded_files
|
|
||||||
st.session_state.live_state = "processing"
|
|
||||||
upload_placeholder.empty()
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 2. МИНИАТЮРЫ ПОД КНОПКОЙ
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
st.caption("Выбранные кадры:")
|
|
||||||
# Генерируем компактный блок миниатюр
|
|
||||||
st.markdown(get_thumbnail_html(uploaded_files), unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# ЭКРАН 2: АНАЛИЗ (СПИННЕР)
|
|
||||||
# ==========================================
|
|
||||||
elif st.session_state.live_state == "processing":
|
|
||||||
|
|
||||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
|
||||||
|
|
||||||
files = st.session_state.get("uploaded_images", [])
|
|
||||||
st.markdown('<div class="spinner-container"><div class="big-spinner"></div></div>', unsafe_allow_html=True)
|
|
||||||
status_text = st.empty()
|
|
||||||
|
|
||||||
|
cols = st.columns(min(len(uploaded_files), 5))
|
||||||
images = []
|
images = []
|
||||||
all_objects = []
|
all_objects = []
|
||||||
all_v, all_a = [], []
|
|
||||||
|
|
||||||
for i, file in enumerate(files):
|
|
||||||
status_text.markdown(f"<h3 style='text-align: center; font-weight: 400;'>Анализ кадра {i + 1} из {len(files)}...</h3>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
|
for i, file in enumerate(uploaded_files):
|
||||||
img = Image.open(file)
|
img = Image.open(file)
|
||||||
images.append(img)
|
images.append(img)
|
||||||
|
with cols[i % 5]:
|
||||||
|
st.image(img, use_container_width=True)
|
||||||
|
with st.spinner("VLM Анализ..."):
|
||||||
|
caption = image_processor.describe_scene(img)
|
||||||
|
st.caption(f"👁️ *{caption.capitalize()}*")
|
||||||
|
all_objects.append(caption)
|
||||||
|
|
||||||
|
if st.button("🎵 Сгенерировать саундтрек", type="primary", use_container_width=True):
|
||||||
|
|
||||||
|
# 1. Извлекаем эмоции
|
||||||
|
all_v, all_a = [], []
|
||||||
|
for img in images:
|
||||||
embedding = image_processor.extract_embedding(img)
|
embedding = image_processor.extract_embedding(img)
|
||||||
v, a = matcher.predict_va(embedding)
|
v, a = matcher.predict_va(embedding)
|
||||||
all_v.append(v)
|
all_v.append(v)
|
||||||
all_a.append(a)
|
all_a.append(a)
|
||||||
|
|
||||||
caption = image_processor.describe_scene(img)
|
|
||||||
all_objects.append(caption)
|
|
||||||
|
|
||||||
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
target_v, target_a = np.mean(all_v), np.mean(all_a)
|
||||||
|
|
||||||
status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Трансляция семантики в аудиопрофиль...</h3>", unsafe_allow_html=True)
|
# 2. Переводим Объекты -> Акустику через LLM
|
||||||
|
with st.spinner("Phi-3 генерирует акустический профиль..."):
|
||||||
llm = LLMAcousticBridge()
|
llm = LLMAcousticBridge()
|
||||||
llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects)))
|
llm_profile = llm.get_acoustic_profile(target_v, target_a, list(set(all_objects)))
|
||||||
|
|
||||||
status_text.markdown("<h3 style='text-align: center; font-weight: 400;'>Поиск идеальных композиций...</h3>", unsafe_allow_html=True)
|
# 3. Ищем треки
|
||||||
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=15)
|
with st.spinner("Поиск треков в базе DEAM..."):
|
||||||
|
playlist = matcher.find_nearest_tracks(target_v, target_a, llm_profile=llm_profile, top_k=5)
|
||||||
|
|
||||||
st.session_state.result_data = {
|
st.success("✅ Кросс-модальный анализ завершен!")
|
||||||
"target_v": target_v,
|
|
||||||
"target_a": target_a,
|
|
||||||
"llm_profile": llm_profile,
|
|
||||||
"playlist": playlist,
|
|
||||||
"semantics": list(set(all_objects))
|
|
||||||
}
|
|
||||||
st.session_state.live_state = "result"
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# ==========================================
|
# ВЫВОД РЕЗУЛЬТАТОВ
|
||||||
# ЭКРАН 3: РЕЗУЛЬТАТЫ
|
col_left, col_right = st.columns([1, 2])
|
||||||
# ==========================================
|
|
||||||
elif st.session_state.live_state == "result":
|
|
||||||
|
|
||||||
components.html("<script>window.parent.scrollTo(0, 0);</script>", height=0, width=0)
|
with col_left:
|
||||||
|
st.header("📊 Профиль")
|
||||||
|
st.metric("Valence (Настроение)", f"{target_v:.2f}")
|
||||||
|
st.metric("Arousal (Энергия)", f"{target_a:.2f}")
|
||||||
|
|
||||||
data = st.session_state.result_data
|
if llm_profile:
|
||||||
st.header("Рекомендованный плейлист")
|
st.write("**Требования LLM к звуку:**")
|
||||||
|
for k, v in llm_profile.items():
|
||||||
|
st.caption(f"- {k}: {v:.2f}")
|
||||||
|
|
||||||
for _, row in data["playlist"].iterrows():
|
with col_right:
|
||||||
|
st.header("🎵 Плейлист")
|
||||||
|
for _, row in playlist.iterrows():
|
||||||
with st.container(border=True):
|
with st.container(border=True):
|
||||||
if viewport == "desktop":
|
|
||||||
c1, c2 = st.columns([1, 3])
|
c1, c2 = st.columns([1, 3])
|
||||||
with c1:
|
with c1:
|
||||||
st.write(f"**Track:** {int(row['song_id'])}")
|
st.write(f"**Track:** {int(row['song_id'])}")
|
||||||
@@ -167,42 +80,3 @@ def render_live_tab(matcher, image_processor):
|
|||||||
st.audio(str(audio_path))
|
st.audio(str(audio_path))
|
||||||
else:
|
else:
|
||||||
st.warning("Файл не найден")
|
st.warning("Файл не найден")
|
||||||
else:
|
|
||||||
st.write(f"**Track:** {int(row['song_id'])} (Score: {row['final_score']:.2f})")
|
|
||||||
audio_path = matcher.get_audio_path(row['song_id'])
|
|
||||||
if audio_path:
|
|
||||||
st.audio(str(audio_path))
|
|
||||||
else:
|
|
||||||
st.warning("Файл не найден")
|
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
with st.expander("Технические параметры анализа"):
|
|
||||||
c_v, c_a = st.columns(2)
|
|
||||||
c_v.metric("Valence (Настроение)", f"{data['target_v']:.2f}")
|
|
||||||
c_a.metric("Arousal (Энергия)", f"{data['target_a']:.2f}")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
st.write("**Акустические таргеты (LLM):**")
|
|
||||||
if data["llm_profile"]:
|
|
||||||
cols_per_row = 2 if viewport == "mobile" else 3
|
|
||||||
llm_items = list(data["llm_profile"].items())
|
|
||||||
|
|
||||||
for i in range(0, len(llm_items), cols_per_row):
|
|
||||||
cols = st.columns(cols_per_row)
|
|
||||||
for j in range(cols_per_row):
|
|
||||||
if i + j < len(llm_items):
|
|
||||||
k, v = llm_items[i + j]
|
|
||||||
cols[j].metric(k, f"{v:.2f}")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
st.write("**Обнаруженная семантика:**")
|
|
||||||
st.write(", ".join([str(c).capitalize() for c in data["semantics"]]))
|
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
if st.button("Новый анализ", use_container_width=True):
|
|
||||||
st.session_state.live_state = "upload"
|
|
||||||
st.session_state.result_data = None
|
|
||||||
st.session_state.pop("uploaded_images", None)
|
|
||||||
st.rerun()
|
|
||||||
@@ -1,461 +0,0 @@
|
|||||||
.
|
|
||||||
├── bin
|
|
||||||
│ ├── activate
|
|
||||||
│ ├── activate.csh
|
|
||||||
│ ├── activate.fish
|
|
||||||
│ ├── activate.nu
|
|
||||||
│ ├── activate.ps1
|
|
||||||
│ ├── activate_this.py
|
|
||||||
│ ├── debugpy
|
|
||||||
│ ├── debugpy-adapter
|
|
||||||
│ ├── f2py
|
|
||||||
│ ├── fonttools
|
|
||||||
│ ├── httpx
|
|
||||||
│ ├── ipython
|
|
||||||
│ ├── ipython3
|
|
||||||
│ ├── isympy
|
|
||||||
│ ├── jlpm
|
|
||||||
│ ├── jsonpointer
|
|
||||||
│ ├── jsonschema
|
|
||||||
│ ├── jupyter
|
|
||||||
│ ├── jupyter-dejavu
|
|
||||||
│ ├── jupyter-events
|
|
||||||
│ ├── jupyter-execute
|
|
||||||
│ ├── jupyter-kernel
|
|
||||||
│ ├── jupyter-kernelspec
|
|
||||||
│ ├── jupyter-lab
|
|
||||||
│ ├── jupyter-labextension
|
|
||||||
│ ├── jupyter-labhub
|
|
||||||
│ ├── jupyter-migrate
|
|
||||||
│ ├── jupyter-nbconvert
|
|
||||||
│ ├── jupyter-run
|
|
||||||
│ ├── jupyter-server
|
|
||||||
│ ├── jupyter-troubleshoot
|
|
||||||
│ ├── jupyter-trust
|
|
||||||
│ ├── normalizer
|
|
||||||
│ ├── numpy-config
|
|
||||||
│ ├── pip
|
|
||||||
│ ├── pip3
|
|
||||||
│ ├── pip3.12
|
|
||||||
│ ├── proton
|
|
||||||
│ ├── proton-viewer
|
|
||||||
│ ├── pybabel
|
|
||||||
│ ├── pyftmerge
|
|
||||||
│ ├── pyftsubset
|
|
||||||
│ ├── pygmentize
|
|
||||||
│ ├── pyjson5
|
|
||||||
│ ├── python -> /usr/bin/python3
|
|
||||||
│ ├── python3 -> python
|
|
||||||
│ ├── python3.12 -> python
|
|
||||||
│ ├── send2trash
|
|
||||||
│ ├── streamlit
|
|
||||||
│ ├── streamlit.cmd
|
|
||||||
│ ├── torchfrtrace
|
|
||||||
│ ├── torchrun
|
|
||||||
│ ├── tqdm
|
|
||||||
│ ├── ttx
|
|
||||||
│ ├── watchmedo
|
|
||||||
│ └── wsdump
|
|
||||||
├── CACHEDIR.TAG
|
|
||||||
├── docker
|
|
||||||
│ ├── Dockerfile.api
|
|
||||||
│ └── Dockerfile.ui
|
|
||||||
├── docker-compose.yml
|
|
||||||
├── Dockerfile
|
|
||||||
├── .dockerignore
|
|
||||||
├── .env
|
|
||||||
├── etc
|
|
||||||
│ └── jupyter
|
|
||||||
│ ├── jupyter_notebook_config.d
|
|
||||||
│ │ └── jupyterlab.json
|
|
||||||
│ ├── jupyter_server_config.d
|
|
||||||
│ │ ├── jupyterlab.json
|
|
||||||
│ │ ├── jupyter-lsp-jupyter-server.json
|
|
||||||
│ │ ├── jupyter_server_terminals.json
|
|
||||||
│ │ └── notebook_shim.json
|
|
||||||
│ └── nbconfig
|
|
||||||
│ └── notebook.d
|
|
||||||
├── .gitignore
|
|
||||||
├── .idea
|
|
||||||
│ ├── .gitignore
|
|
||||||
│ ├── inspectionProfiles
|
|
||||||
│ │ └── profiles_settings.xml
|
|
||||||
│ ├── misc.xml
|
|
||||||
│ ├── modules.xml
|
|
||||||
│ ├── Thesis.iml
|
|
||||||
│ ├── vcs.xml
|
|
||||||
│ └── workspace.xml
|
|
||||||
├── lib
|
|
||||||
│ └── python3.12
|
|
||||||
│ └── site-packages
|
|
||||||
│ ├── altair
|
|
||||||
│ ├── altair-6.0.0.dist-info
|
|
||||||
│ ├── anyio
|
|
||||||
│ ├── anyio-4.12.1.dist-info
|
|
||||||
│ ├── argon2
|
|
||||||
│ ├── argon2_cffi-25.1.0.dist-info
|
|
||||||
│ ├── _argon2_cffi_bindings
|
|
||||||
│ ├── argon2_cffi_bindings-25.1.0.dist-info
|
|
||||||
│ ├── arrow
|
|
||||||
│ ├── arrow-1.4.0.dist-info
|
|
||||||
│ ├── asttokens
|
|
||||||
│ ├── asttokens-3.0.1.dist-info
|
|
||||||
│ ├── async_lru
|
|
||||||
│ ├── async_lru-2.0.5.dist-info
|
|
||||||
│ ├── attr
|
|
||||||
│ ├── attrs
|
|
||||||
│ ├── attrs-25.4.0.dist-info
|
|
||||||
│ ├── babel
|
|
||||||
│ ├── babel-2.17.0.dist-info
|
|
||||||
│ ├── beautifulsoup4-4.14.3.dist-info
|
|
||||||
│ ├── bleach
|
|
||||||
│ ├── bleach-6.3.0.dist-info
|
|
||||||
│ ├── blinker
|
|
||||||
│ ├── blinker-1.9.0.dist-info
|
|
||||||
│ ├── bs4
|
|
||||||
│ ├── cachetools
|
|
||||||
│ ├── cachetools-6.2.4.dist-info
|
|
||||||
│ ├── certifi
|
|
||||||
│ ├── certifi-2026.1.4.dist-info
|
|
||||||
│ ├── cffi
|
|
||||||
│ ├── cffi-2.0.0.dist-info
|
|
||||||
│ ├── _cffi_backend.cpython-312-x86_64-linux-gnu.so
|
|
||||||
│ ├── charset_normalizer
|
|
||||||
│ ├── charset_normalizer-3.4.4.dist-info
|
|
||||||
│ ├── click
|
|
||||||
│ ├── click-8.3.1.dist-info
|
|
||||||
│ ├── comm
|
|
||||||
│ ├── comm-0.2.3.dist-info
|
|
||||||
│ ├── contourpy
|
|
||||||
│ ├── contourpy-1.3.3.dist-info
|
|
||||||
│ ├── cycler
|
|
||||||
│ ├── cycler-0.12.1.dist-info
|
|
||||||
│ ├── dateutil
|
|
||||||
│ ├── debugpy
|
|
||||||
│ ├── debugpy-1.8.19.dist-info
|
|
||||||
│ ├── decorator-5.2.1.dist-info
|
|
||||||
│ ├── decorator.py
|
|
||||||
│ ├── defusedxml
|
|
||||||
│ ├── defusedxml-0.7.1.dist-info
|
|
||||||
│ ├── _distutils_hack
|
|
||||||
│ ├── distutils-precedence.pth
|
|
||||||
│ ├── .DS_Store
|
|
||||||
│ ├── executing
|
|
||||||
│ ├── executing-2.2.1.dist-info
|
|
||||||
│ ├── fastjsonschema
|
|
||||||
│ ├── fastjsonschema-2.21.2.dist-info
|
|
||||||
│ ├── filelock
|
|
||||||
│ ├── filelock-3.20.3.dist-info
|
|
||||||
│ ├── fontTools
|
|
||||||
│ ├── fonttools-4.61.1.dist-info
|
|
||||||
│ ├── fqdn
|
|
||||||
│ ├── fqdn-1.5.1.dist-info
|
|
||||||
│ ├── fsspec
|
|
||||||
│ ├── fsspec-2026.1.0.dist-info
|
|
||||||
│ ├── functorch
|
|
||||||
│ ├── git
|
|
||||||
│ ├── gitdb
|
|
||||||
│ ├── gitdb-4.0.12.dist-info
|
|
||||||
│ ├── gitpython-3.1.46.dist-info
|
|
||||||
│ ├── google
|
|
||||||
│ ├── h11
|
|
||||||
│ ├── h11-0.16.0.dist-info
|
|
||||||
│ ├── httpcore
|
|
||||||
│ ├── httpcore-1.0.9.dist-info
|
|
||||||
│ ├── httpx
|
|
||||||
│ ├── httpx-0.28.1.dist-info
|
|
||||||
│ ├── idna
|
|
||||||
│ ├── idna-3.11.dist-info
|
|
||||||
│ ├── ipykernel
|
|
||||||
│ ├── ipykernel-7.1.0.dist-info
|
|
||||||
│ ├── ipykernel_launcher.py
|
|
||||||
│ ├── IPython
|
|
||||||
│ ├── ipython-9.9.0.dist-info
|
|
||||||
│ ├── ipython_pygments_lexers-1.1.1.dist-info
|
|
||||||
│ ├── ipython_pygments_lexers.py
|
|
||||||
│ ├── isoduration
|
|
||||||
│ ├── isoduration-20.11.0.dist-info
|
|
||||||
│ ├── isympy.py
|
|
||||||
│ ├── jedi
|
|
||||||
│ ├── jedi-0.19.2.dist-info
|
|
||||||
│ ├── jinja2
|
|
||||||
│ ├── jinja2-3.1.6.dist-info
|
|
||||||
│ ├── joblib
|
|
||||||
│ ├── joblib-1.5.3.dist-info
|
|
||||||
│ ├── json5
|
|
||||||
│ ├── json5-0.13.0.dist-info
|
|
||||||
│ ├── jsonpointer-3.0.0.dist-info
|
|
||||||
│ ├── jsonpointer.py
|
|
||||||
│ ├── jsonschema
|
|
||||||
│ ├── jsonschema-4.26.0.dist-info
|
|
||||||
│ ├── jsonschema_specifications
|
|
||||||
│ ├── jsonschema_specifications-2025.9.1.dist-info
|
|
||||||
│ ├── jupyter_client
|
|
||||||
│ ├── jupyter_client-8.8.0.dist-info
|
|
||||||
│ ├── jupyter_core
|
|
||||||
│ ├── jupyter_core-5.9.1.dist-info
|
|
||||||
│ ├── jupyter_events
|
|
||||||
│ ├── jupyter_events-0.12.0.dist-info
|
|
||||||
│ ├── jupyterlab
|
|
||||||
│ ├── jupyterlab-4.5.1.dist-info
|
|
||||||
│ ├── jupyterlab_pygments
|
|
||||||
│ ├── jupyterlab_pygments-0.3.0.dist-info
|
|
||||||
│ ├── jupyterlab_server
|
|
||||||
│ ├── jupyterlab_server-2.28.0.dist-info
|
|
||||||
│ ├── jupyter_lsp
|
|
||||||
│ ├── jupyter_lsp-2.3.0.dist-info
|
|
||||||
│ ├── jupyter.py
|
|
||||||
│ ├── jupyter_server
|
|
||||||
│ ├── jupyter_server-2.17.0.dist-info
|
|
||||||
│ ├── jupyter_server_terminals
|
|
||||||
│ ├── jupyter_server_terminals-0.5.3.dist-info
|
|
||||||
│ ├── kiwisolver
|
|
||||||
│ ├── kiwisolver-1.4.9.dist-info
|
|
||||||
│ ├── lark
|
|
||||||
│ ├── lark-1.3.1.dist-info
|
|
||||||
│ ├── markupsafe
|
|
||||||
│ ├── markupsafe-3.0.3.dist-info
|
|
||||||
│ ├── matplotlib
|
|
||||||
│ ├── matplotlib-3.10.8.dist-info
|
|
||||||
│ ├── matplotlib_inline
|
|
||||||
│ ├── matplotlib_inline-0.2.1.dist-info
|
|
||||||
│ ├── mistune
|
|
||||||
│ ├── mistune-3.2.0.dist-info
|
|
||||||
│ ├── mpl_toolkits
|
|
||||||
│ ├── mpmath
|
|
||||||
│ ├── mpmath-1.3.0.dist-info
|
|
||||||
│ ├── narwhals
|
|
||||||
│ ├── narwhals-2.15.0.dist-info
|
|
||||||
│ ├── nbclient
|
|
||||||
│ ├── nbclient-0.10.4.dist-info
|
|
||||||
│ ├── nbconvert
|
|
||||||
│ ├── nbconvert-7.16.6.dist-info
|
|
||||||
│ ├── nbformat
|
|
||||||
│ ├── nbformat-5.10.4.dist-info
|
|
||||||
│ ├── nest_asyncio-1.6.0.dist-info
|
|
||||||
│ ├── nest_asyncio.py
|
|
||||||
│ ├── networkx
|
|
||||||
│ ├── networkx-3.6.1.dist-info
|
|
||||||
│ ├── notebook_shim
|
|
||||||
│ ├── notebook_shim-0.2.4.dist-info
|
|
||||||
│ ├── numpy
|
|
||||||
│ ├── numpy-2.4.1.dist-info
|
|
||||||
│ ├── numpy.libs
|
|
||||||
│ ├── nvidia
|
|
||||||
│ ├── nvidia_cublas_cu12-12.8.4.1.dist-info
|
|
||||||
│ ├── nvidia_cuda_cupti_cu12-12.8.90.dist-info
|
|
||||||
│ ├── nvidia_cuda_nvrtc_cu12-12.8.93.dist-info
|
|
||||||
│ ├── nvidia_cuda_runtime_cu12-12.8.90.dist-info
|
|
||||||
│ ├── nvidia_cudnn_cu12-9.10.2.21.dist-info
|
|
||||||
│ ├── nvidia_cufft_cu12-11.3.3.83.dist-info
|
|
||||||
│ ├── nvidia_cufile_cu12-1.13.1.3.dist-info
|
|
||||||
│ ├── nvidia_curand_cu12-10.3.9.90.dist-info
|
|
||||||
│ ├── nvidia_cusolver_cu12-11.7.3.90.dist-info
|
|
||||||
│ ├── nvidia_cusparse_cu12-12.5.8.93.dist-info
|
|
||||||
│ ├── nvidia_cusparselt_cu12-0.7.1.dist-info
|
|
||||||
│ ├── nvidia_nccl_cu12-2.27.5.dist-info
|
|
||||||
│ ├── nvidia_nvjitlink_cu12-12.8.93.dist-info
|
|
||||||
│ ├── nvidia_nvshmem_cu12-3.3.20.dist-info
|
|
||||||
│ ├── nvidia_nvtx_cu12-12.8.90.dist-info
|
|
||||||
│ ├── packaging
|
|
||||||
│ ├── packaging-25.0.dist-info
|
|
||||||
│ ├── pandas
|
|
||||||
│ ├── pandas-2.3.3.dist-info
|
|
||||||
│ ├── pandocfilters-1.5.1.dist-info
|
|
||||||
│ ├── pandocfilters.py
|
|
||||||
│ ├── parso
|
|
||||||
│ ├── parso-0.8.5.dist-info
|
|
||||||
│ ├── pexpect
|
|
||||||
│ ├── pexpect-4.9.0.dist-info
|
|
||||||
│ ├── PIL
|
|
||||||
│ ├── pillow-12.1.0.dist-info
|
|
||||||
│ ├── pillow.libs
|
|
||||||
│ ├── pip
|
|
||||||
│ ├── pip-25.3.dist-info
|
|
||||||
│ ├── pkg_resources
|
|
||||||
│ ├── platformdirs
|
|
||||||
│ ├── platformdirs-4.5.1.dist-info
|
|
||||||
│ ├── prometheus_client
|
|
||||||
│ ├── prometheus_client-0.23.1.dist-info
|
|
||||||
│ ├── prompt_toolkit
|
|
||||||
│ ├── prompt_toolkit-3.0.52.dist-info
|
|
||||||
│ ├── protobuf-6.33.4.dist-info
|
|
||||||
│ ├── psutil
|
|
||||||
│ ├── psutil-7.2.1.dist-info
|
|
||||||
│ ├── ptyprocess
|
|
||||||
│ ├── ptyprocess-0.7.0.dist-info
|
|
||||||
│ ├── pure_eval
|
|
||||||
│ ├── pure_eval-0.2.3.dist-info
|
|
||||||
│ ├── pyarrow
|
|
||||||
│ ├── pyarrow-22.0.0.dist-info
|
|
||||||
│ ├── pycparser
|
|
||||||
│ ├── pycparser-2.23.dist-info
|
|
||||||
│ ├── pydeck
|
|
||||||
│ ├── pydeck-0.9.1.dist-info
|
|
||||||
│ ├── pygments
|
|
||||||
│ ├── pygments-2.19.2.dist-info
|
|
||||||
│ ├── pylab.py
|
|
||||||
│ ├── pyparsing
|
|
||||||
│ ├── pyparsing-3.3.1.dist-info
|
|
||||||
│ ├── python_dateutil-2.9.0.post0.dist-info
|
|
||||||
│ ├── pythonjsonlogger
|
|
||||||
│ ├── python_json_logger-4.0.0.dist-info
|
|
||||||
│ ├── pytz
|
|
||||||
│ ├── pytz-2025.2.dist-info
|
|
||||||
│ ├── pyyaml-6.0.3.dist-info
|
|
||||||
│ ├── pyzmq-27.1.0.dist-info
|
|
||||||
│ ├── pyzmq.libs
|
|
||||||
│ ├── referencing
|
|
||||||
│ ├── referencing-0.37.0.dist-info
|
|
||||||
│ ├── requests
|
|
||||||
│ ├── requests-2.32.5.dist-info
|
|
||||||
│ ├── rfc3339_validator-0.1.4.dist-info
|
|
||||||
│ ├── rfc3339_validator.py
|
|
||||||
│ ├── rfc3986_validator-0.1.1.dist-info
|
|
||||||
│ ├── rfc3986_validator.py
|
|
||||||
│ ├── rfc3987_syntax
|
|
||||||
│ ├── rfc3987_syntax-1.1.0.dist-info
|
|
||||||
│ ├── rpds
|
|
||||||
│ ├── rpds_py-0.30.0.dist-info
|
|
||||||
│ ├── scikit_learn-1.8.0.dist-info
|
|
||||||
│ ├── scikit_learn.libs
|
|
||||||
│ ├── scipy
|
|
||||||
│ ├── scipy-1.17.0.dist-info
|
|
||||||
│ ├── scipy.libs
|
|
||||||
│ ├── send2trash
|
|
||||||
│ ├── send2trash-2.0.0.dist-info
|
|
||||||
│ ├── setuptools
|
|
||||||
│ ├── setuptools-80.9.0.dist-info
|
|
||||||
│ ├── six-1.17.0.dist-info
|
|
||||||
│ ├── six.py
|
|
||||||
│ ├── sklearn
|
|
||||||
│ ├── smmap
|
|
||||||
│ ├── smmap-5.0.2.dist-info
|
|
||||||
│ ├── soupsieve
|
|
||||||
│ ├── soupsieve-2.8.1.dist-info
|
|
||||||
│ ├── stack_data
|
|
||||||
│ ├── stack_data-0.6.3.dist-info
|
|
||||||
│ ├── streamlit
|
|
||||||
│ ├── streamlit-1.53.0.dist-info
|
|
||||||
│ ├── sympy
|
|
||||||
│ ├── sympy-1.14.0.dist-info
|
|
||||||
│ ├── tenacity
|
|
||||||
│ ├── tenacity-9.1.2.dist-info
|
|
||||||
│ ├── terminado
|
|
||||||
│ ├── terminado-0.18.1.dist-info
|
|
||||||
│ ├── threadpoolctl-3.6.0.dist-info
|
|
||||||
│ ├── threadpoolctl.py
|
|
||||||
│ ├── tinycss2
|
|
||||||
│ ├── tinycss2-1.4.0.dist-info
|
|
||||||
│ ├── toml
|
|
||||||
│ ├── toml-0.10.2.dist-info
|
|
||||||
│ ├── torch
|
|
||||||
│ ├── torch-2.9.1.dist-info
|
|
||||||
│ ├── torchaudio
|
|
||||||
│ ├── torchaudio-2.9.1.dist-info
|
|
||||||
│ ├── torchgen
|
|
||||||
│ ├── torchvision
|
|
||||||
│ ├── torchvision-0.24.1.dist-info
|
|
||||||
│ ├── torchvision.libs
|
|
||||||
│ ├── tornado
|
|
||||||
│ ├── tornado-6.5.4.dist-info
|
|
||||||
│ ├── tqdm
|
|
||||||
│ ├── tqdm-4.67.1.dist-info
|
|
||||||
│ ├── traitlets
|
|
||||||
│ ├── traitlets-5.14.3.dist-info
|
|
||||||
│ ├── triton
|
|
||||||
│ ├── triton-3.5.1.dist-info
|
|
||||||
│ ├── typing_extensions-4.15.0.dist-info
|
|
||||||
│ ├── typing_extensions.py
|
|
||||||
│ ├── tzdata
|
|
||||||
│ ├── tzdata-2025.3.dist-info
|
|
||||||
│ ├── uri_template
|
|
||||||
│ ├── uri_template-1.3.0.dist-info
|
|
||||||
│ ├── urllib3
|
|
||||||
│ ├── urllib3-2.6.3.dist-info
|
|
||||||
│ ├── _virtualenv.pth
|
|
||||||
│ ├── _virtualenv.py
|
|
||||||
│ ├── watchdog
|
|
||||||
│ ├── watchdog-6.0.0.dist-info
|
|
||||||
│ ├── wcwidth
|
|
||||||
│ ├── wcwidth-0.2.14.dist-info
|
|
||||||
│ ├── webcolors
|
|
||||||
│ ├── webcolors-25.10.0.dist-info
|
|
||||||
│ ├── webencodings
|
|
||||||
│ ├── webencodings-0.5.1.dist-info
|
|
||||||
│ ├── websocket
|
|
||||||
│ ├── websocket_client-1.9.0.dist-info
|
|
||||||
│ ├── _yaml
|
|
||||||
│ ├── yaml
|
|
||||||
│ └── zmq
|
|
||||||
├── Makefile
|
|
||||||
├── NFS
|
|
||||||
├── poetry.lock
|
|
||||||
├── pyproject.toml
|
|
||||||
├── pyvenv.cfg
|
|
||||||
├── README.md
|
|
||||||
├── requirements.txt
|
|
||||||
├── runs
|
|
||||||
├── share
|
|
||||||
│ ├── applications
|
|
||||||
│ │ └── jupyterlab.desktop
|
|
||||||
│ ├── icons
|
|
||||||
│ │ └── hicolor
|
|
||||||
│ │ └── scalable
|
|
||||||
│ ├── jupyter
|
|
||||||
│ │ ├── kernels
|
|
||||||
│ │ │ └── python3
|
|
||||||
│ │ ├── lab
|
|
||||||
│ │ │ ├── schemas
|
|
||||||
│ │ │ ├── static
|
|
||||||
│ │ │ └── themes
|
|
||||||
│ │ ├── labextensions
|
|
||||||
│ │ │ └── jupyterlab_pygments
|
|
||||||
│ │ ├── nbconvert
|
|
||||||
│ │ │ └── templates
|
|
||||||
│ │ └── nbextensions
|
|
||||||
│ │ └── pydeck
|
|
||||||
│ └── man
|
|
||||||
│ └── man1
|
|
||||||
│ ├── ipython.1
|
|
||||||
│ ├── isympy.1
|
|
||||||
│ └── ttx.1
|
|
||||||
├── src
|
|
||||||
│ ├── 5_epoch_emoset_resnet50_finetuned_2.41M.pth
|
|
||||||
│ ├── api.py
|
|
||||||
│ ├── data_loader.py
|
|
||||||
│ ├── dataset_paths_cache.pkl
|
|
||||||
│ ├── emoset_resnet50_best.pth
|
|
||||||
│ ├── emoset_resnet50_finetuned_2_41M_best.pth
|
|
||||||
│ ├── emoset_resnet50_resume.pth
|
|
||||||
│ ├── emoset_test_embeddings.npy
|
|
||||||
│ ├── emoset_test_labels.npy
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── music_engine
|
|
||||||
│ │ ├── image_processor.py
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── llm_bridge.py
|
|
||||||
│ │ ├── matcher.py
|
|
||||||
│ │ └── va_regressor.pkl
|
|
||||||
│ ├── scripts
|
|
||||||
│ │ ├── 00_setup_env.sh
|
|
||||||
│ │ ├── 01_download_DEAM.py
|
|
||||||
│ │ ├── 02_download_EmoSet.py
|
|
||||||
│ │ ├── 11_prerp_DEAM.py
|
|
||||||
│ │ ├── 20_bench_GPU.py
|
|
||||||
│ │ ├── 21_train_images.ipynb
|
|
||||||
│ │ ├── 22_extract_embeddings.ipynb
|
|
||||||
│ │ ├── 23_aggregate_DEAM_timeline.py
|
|
||||||
│ │ ├── 24_train_regressor.py
|
|
||||||
│ │ ├── 31_finetune_2.41M.py
|
|
||||||
│ │ ├── 90_acc_images_model.ipynb
|
|
||||||
│ │ └── 91_generate_metrics.py
|
|
||||||
│ └── tabs
|
|
||||||
│ ├── tab_dataset.py
|
|
||||||
│ └── tab_live.py
|
|
||||||
├── tree.txt
|
|
||||||
└── .vscode
|
|
||||||
├── launch.json
|
|
||||||
└── tasks.json
|
|
||||||
|
|
||||||
322 directories, 137 files
|
|
||||||
Reference in New Issue
Block a user