ref: refactor before checkout

This commit is contained in:
zin
2026-06-02 17:27:05 +00:00
parent 9ce92b70a9
commit f04cd7359b
27 changed files with 1103 additions and 4266 deletions
+20 -19
View File
@@ -1,54 +1,55 @@
import streamlit as st
import os
from pathlib import Path
import pandas as pd
import numpy as np
import streamlit as st
from music_engine.matcher import MusicMatcher
from music_engine.image_processor import ImageProcessor
# Определяем базовую директорию (папка src)
BASE_DIR = Path(__file__).resolve().parent
@st.cache_resource
def load_music_engine():
"""Загрузка базы данных и модели регрессора."""
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
# Инициализация базы данных и регрессора для музыкального мэтчинга
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
# va_regressor.pkl лежит в src/music_engine/
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
if not db_path.exists():
print(f"⚠️ Файл базы {db_path} не найден!")
print(f"Музыкальная БД не найдена: {db_path}")
return None
return MusicMatcher(db_path=db_path, model_path=model_path)
@st.cache_resource
def load_image_processor():
"""Загрузка ResNet-50 для извлечения признаков на лету."""
# Файл весов лежит в той же папке src, что и этот скрипт
# Модуль обработки визуальных признаков
model_path = BASE_DIR / "emoset_resnet50_best.pth"
# Обработка пути при вызове из корневой директории
if not model_path.exists():
print(f"Ошибка: Веса не найдены по пути: {model_path}")
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
return ImageProcessor(model_path=model_path)
@st.cache_data
def load_emoset_data():
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
# Пути относительно корня проекта
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
# Выборка данных датасета для вкладки отладки
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
csv_path = dataset_root / "labels.csv"
img_dir = dataset_root / "images"
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
lbl_path = BASE_DIR / "emoset_test_labels.npy"
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
return None, None, None, None
df = pd.read_csv(csv_path)
image_list = df['filename'].tolist()
embs = np.load(emb_path)
lbls = np.load(lbl_path)
labels_df = pd.read_csv(csv_path)
return image_list, embs, lbls, img_dir
test_filenames = labels_df['filename'].tolist()
test_embeddings = np.load(emb_path)
test_labels = np.load(lbl_path)
return test_filenames, test_embeddings, test_labels, img_dir
+34 -21
View File
@@ -1,49 +1,62 @@
import numpy as np
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T
from PIL import Image
import timm
from pathlib import Path
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration
class ImageProcessor:
def __init__(self, model_path: Path | str):
def __init__(self, weights_path: str | Path):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
if Path(model_path).exists():
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
# Модель извлечения визуальных признаков
self.feature_extractor = timm.create_model('resnet50', pretrained=False, num_classes=8)
self.emo_model.fc = torch.nn.Identity()
self.emo_model.to(self.device).eval()
if Path(weights_path).exists():
self.feature_extractor.load_state_dict(torch.load(weights_path, map_location=self.device))
else:
print(f"Не удалось найти веса ResNet по пути: {weights_path}")
self.emo_transform = T.Compose([
# Удаление слоя классификации для вывода сырого вектора эмбеддингов
self.feature_extractor.fc = torch.nn.Identity()
self.feature_extractor.to(self.device).eval()
# Трансформации для предварительной обработки изображений
self.preprocess_image = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print("Загрузка BLIP-2...")
# Модуль семантического описания сцены
print("Инициализация BLIP-2...")
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16
).to(self.device)
print("BLIP-2 и ResNet-50 готовы.")
@torch.no_grad()
def extract_embedding(self, image: Image.Image) -> np.ndarray:
img_rgb = image.convert('RGB')
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
return self.emo_model(img_tensor).cpu().numpy().flatten()
# Извлечение эмбеддингов из изображения
rgb_image = image.convert('RGB')
img_tensor = self.preprocess_image(rgb_image).unsqueeze(0).to(self.device)
features = self.feature_extractor(img_tensor)
features_np = features.cpu().numpy()
return features_np.flatten()
@torch.no_grad()
def describe_scene(self, image: Image.Image) -> str:
"""Генерирует описание через BLIP-2."""
img_rgb = image.convert('RGB')
inputs = self.blip_processor(images=img_rgb, return_tensors="pt").to(self.device, torch.float16)
# Генерация текстового описания сцены
rgb_image = image.convert('RGB')
inputs = self.blip_processor(images=rgb_image, return_tensors="pt").to(self.device, torch.float16)
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
caption = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return caption
scene_description = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return scene_description.strip()
+32 -27
View File
@@ -1,31 +1,31 @@
import requests
import json
import re
import json
import requests
class LLMAcousticBridge:
def __init__(self, model_name="dolphin-llama3:8b"):
self.model_name = model_name
def __init__(self, target_model="dolphin-llama3:8b"):
self.api_url = "http://localhost:11434/api/generate"
self.model = target_model
def _clean_json(self, text):
"""Вытаскивает чистый JSON из ответа нейросети."""
def _extract_json(self, raw_text: str):
# Проверка на ИИдиота, LLM иногда игнорирует format="json" и оборачивает ответ в маркдаун
try:
match = re.search(r'\{.*\}', text, re.DOTALL)
match = re.search(r'\{.*\}', raw_text, re.DOTALL)
if match:
return json.loads(match.group(0))
return json.loads(text)
except:
return json.loads(raw_text)
except json.JSONDecodeError:
# Если ИИдиот
return None
def get_acoustic_profile(self, valence, arousal, scene_descriptions):
"""Просит LLM сгенерировать идеальный звук под описание."""
# Объединяем описания, если загружено несколько фото
context_str = " | ".join(scene_descriptions) if scene_descriptions else "abstract scene"
def get_acoustic_profile(self, v_score: float, a_score: float, scene_context: list) -> dict | None:
# Агрегация контекста для обработки серии снимков (события)
context_merged = " | ".join(scene_context) if scene_context else "abstract scene"
prompt = f"""You are an expert music producer and acoustic engineer.
system_prompt = f"""You are an expert music producer and acoustic engineer.
Analyze the visual context and emotions to determine the ideal background music properties.
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
Visual Context: {context_str}.
Emotions: Valence {v_score:.1f}/9.0 (Positivity), Arousal {a_score:.1f}/9.0 (Energy).
Visual Context: {context_merged}.
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
@@ -39,22 +39,27 @@ Return ONLY a valid JSON object. Do not add any text or explanation.
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
try:
# Отправка промпта локальной Ollama
response = requests.post(self.api_url, json={
"model": self.model_name,
"prompt": prompt,
"model": self.model,
"prompt": system_prompt,
"stream": False,
"format": "json"
}, timeout=30)
}, timeout=45)
response.raise_for_status()
result_text = response.json().get("response", "")
profile = self._clean_json(result_text)
raw_response = response.json().get("response", "")
profile_data = self._extract_json(raw_response)
# Проверяем, что все нужные ключи есть
required_keys = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
if profile and all(k in profile for k in required_keys):
return profile
# Валидация структуры ответа
expected_features = {'energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'}
if profile_data and expected_features.issubset(profile_data.keys()):
return profile_data
print("LLM вернула неполный или некорректный набор акустических признаков")
return None
except Exception as e:
print(f"Ошибка связи с локальной LLM: {e}")
except requests.exceptions.RequestException as req_err:
print(f"Не удалось подключиться к Ollama: {req_err}")
return None
+41 -25
View File
@@ -1,67 +1,83 @@
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
import joblib
class MusicMatcher:
def __init__(self, db_path: Path | str, model_path: Path | str):
# Загружаем твою новую, обогащенную базу
# Загрузка базы данных музыкальных произведений
self.music_db = pd.read_csv(db_path)
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
# Удаляем строки, где нет акустических фич
self.music_db = self.music_db.dropna(subset=['valence', 'arousal'] + self.acoustic_features)
# Удаление записей с пропущенными целевыми или акустическими признаками
target_columns = ['valence', 'arousal'] + self.acoustic_features
self.music_db = self.music_db.dropna(subset=target_columns)
# Нормализуем акустику от 0 до 1, чтобы сравнивать с ответом LLM
# Масштабирование акустических параметров к диапазону [0, 1]
self.norm_db = self.music_db.copy()
for feat in self.acoustic_features:
f_min, f_max = self.norm_db[feat].min(), self.norm_db[feat].max()
f_min = self.norm_db[feat].min()
f_max = self.norm_db[feat].max()
if f_max > f_min:
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
else:
self.norm_db[f"norm_{feat}"] = 0.0
# Определение путей к аудиофайлам и загрузка модели регрессии
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
self.regressor = joblib.load(model_path) if Path(model_path).exists() else None
def predict_va(self, embedding: np.ndarray):
if self.regressor:
prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0)
if Path(model_path).exists():
self.regressor = joblib.load(model_path)
else:
self.regressor = None
def predict_va(self, embedding: np.ndarray) -> tuple[float, float]:
# Прогнозирование координат Valence/Arousal по визуальному эмбеддингу
if not self.regressor:
return 5.0, 5.0
def get_audio_path(self, song_id):
if not self.audio_dir.exists(): return None
raw_prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
valence_pred = np.clip(raw_prediction[0], 1.0, 9.0)
arousal_pred = np.clip(raw_prediction[1], 1.0, 9.0)
return float(valence_pred), float(arousal_pred)
def get_audio_path(self, song_id: int | float | str) -> Path | None:
# Поиск физического пути к аудиофайлу в зависимости от расширения
if not self.audio_dir.exists():
return None
clean_id = str(int(float(song_id)))
for ext in ['.mp3', '.wav']:
path = self.audio_dir / f"{clean_id}{ext}"
if path.exists(): return path
if path.exists():
return path
return None
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5):
# 1. Эмоциональная дистанция (как и раньше)
emo_dist = np.sqrt(
1.0 * (self.norm_db['valence'] - target_v)**2 +
2.5 * (self.norm_db['arousal'] - target_a)**2
)
self.norm_db['emo_distance'] = emo_dist
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5) -> pd.DataFrame:
# Расчет евклидова расстояния в эмоциональном пространстве Рассела
v_dist = (self.norm_db['valence'] - target_v) ** 2
a_dist = (self.norm_db['arousal'] - target_a) ** 2
# Если LLM не дала ответ, сортируем только по эмоциям
# Взвешенное расстояние с приоритетом оси активации (Arousal)
self.norm_db['emo_distance'] = np.sqrt(1.0 * v_dist + 2.5 * a_dist)
# Ранжирование только по эмоциональному критерию при отсутствии профиля LLM
if not llm_profile:
self.norm_db['final_score'] = self.norm_db['emo_distance']
return self.norm_db.sort_values(by='final_score').head(top_k)
# 2. Акустическая дистанция (сравниваем треки с запросом LLM)
# Расчет отклонений по вектору акустических параметров LLM
acoustic_penalty = np.zeros(len(self.norm_db))
for feat in self.acoustic_features:
if feat in llm_profile:
target_val = llm_profile[feat]
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
# Усредняем штраф
# Нормирование акустической дистанции
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
# 3. Финальный Score (Смесь Эмоций и Акустики). Коэф 4.0 делает акустику важной!
# Вычисление интегральной метрики соответствия (мультимодальный скоринг)
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
return self.norm_db.sort_values(by='final_score').head(top_k)
+20
View File
@@ -0,0 +1,20 @@
import shutil
from pathlib import Path
import kagglehub
# Target directory for the DEAM dataset. The path is relative, so it is
# resolved against the current working directory at run time.
# NOTE(review): presumably the script is meant to be launched from its own folder — confirm.
dataset_dir = Path("../dataset/DEAM")
dataset_dir.mkdir(parents=True, exist_ok=True)
print("Скачивание датасета DEAM...")
# By default kagglehub downloads into the system cache (~/.cache).
cache_path = kagglehub.dataset_download("imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music")
print(f"Загружено в кэш: {cache_path}")
print(f"Перенос файлов в {dataset_dir} и очистка временной директории...")
# "Move" the data: copy the cached tree into dataset_dir (merging into an
# existing directory), then delete the cached copy to free disk space.
shutil.copytree(cache_path, dataset_dir, dirs_exist_ok=True)
shutil.rmtree(cache_path)
print("Готово. Датасет DEAM загружен, кэш очищен.")
+56
View File
@@ -0,0 +1,56 @@
import csv
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
# Root directory of the local EmoSet-118K export (relative to the current working directory).
DATASET_DIR = Path("../dataset/EmoSet-118K")
def process_and_save_split(dataset_split, split_name: str, output_dir: Path):
    """Write one dataset split to disk as JPEG images plus a ``labels.csv`` index.

    Creates ``<output_dir>/<split_name>/images/`` and stores every example's
    image there as ``<image_id>.jpg``; a row ``(filename, label)`` is appended
    to ``<output_dir>/<split_name>/labels.csv`` for each image.

    Args:
        dataset_split: Iterable of mapping-like examples exposing the keys
            ``"image"``, ``"emotion"`` and ``"image_id"``.
        split_name: Name of the split; used as the sub-directory name and as
            the progress-bar label.
        output_dir: Root directory of the exported dataset.
    """
    target_root = output_dir / split_name
    images_root = target_root / "images"
    images_root.mkdir(parents=True, exist_ok=True)
    index_file = target_root / "labels.csv"

    print(f"Обработка выборки: {split_name}...")

    # The CSV stays open for the whole loop so it is opened only once.
    with open(index_file, mode="w", newline="", encoding="utf-8") as handle:
        label_writer = csv.writer(handle)
        label_writer.writerow(["filename", "label"])

        for record in tqdm(dataset_split, desc=split_name):
            picture = record["image"]
            emotion = record["emotion"]
            jpeg_name = f"{record['image_id']}.jpg"

            # JPEG cannot store every PIL mode, so anything non-RGB is converted first.
            rgb_picture = picture if picture.mode == "RGB" else picture.convert("RGB")
            rgb_picture.save(images_root / jpeg_name, format="JPEG")

            label_writer.writerow([jpeg_name, emotion])
# Script entry point: download EmoSet-118K and export every available split.
if __name__ == "__main__":
DATASET_DIR.mkdir(exist_ok=True, parents=True)
# Open the dataset from the Hugging Face Hub.
print("Загрузка метаданных EmoSet-118K...")
raw_dataset = load_dataset("Woleek/EmoSet-118K")
# Export the labelled data split by split; splits missing from the dataset are skipped.
for split_key in ["train", "val", "test"]:
if split_key in raw_dataset:
process_and_save_split(
dataset_split=raw_dataset[split_key],
split_name=split_key,
output_dir=DATASET_DIR
)
print("Экспорт датасета завершен.")
+30
View File
@@ -0,0 +1,30 @@
import pandas as pd
from pathlib import Path
# Local path configuration (both paths are relative to the current working directory).
# SOURCE_CSV: DEAM song-level static annotations (songs 1-2000).
# OUTPUT_CSV: compact song_id/valence/arousal table produced by this script.
SOURCE_CSV = Path("../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv")
OUTPUT_CSV = Path("../../dataset/DEAM/music_db.csv")
def prepare_deam_database(source_csv=None, output_csv=None):
    """Build the compact valence/arousal table from DEAM song-level annotations.

    Reads the annotation CSV, keeps ``song_id``, ``valence_mean`` and
    ``arousal_mean``, renames the last two to ``valence`` / ``arousal``,
    casts ``song_id`` to ``int`` and writes the result as CSV without an
    index column.

    Args:
        source_csv: Annotation CSV to read. Defaults to the module-level
            ``SOURCE_CSV`` when omitted.
        output_csv: Destination CSV. Defaults to the module-level
            ``OUTPUT_CSV`` when omitted. Its parent directory is created
            if it does not exist yet.

    Returns:
        None. If the source file is missing, a message is printed and
        nothing is written.

    Raises:
        KeyError: If one of the three required columns is absent.
    """
    # Defaults are resolved at call time so the function stays usable with
    # explicit paths (e.g. from another working directory or in tests).
    source_csv = SOURCE_CSV if source_csv is None else Path(source_csv)
    output_csv = OUTPUT_CSV if output_csv is None else Path(output_csv)

    if not source_csv.exists():
        print(f"Исходный файл аннотаций не найден: {source_csv}")
        return

    print("Обработка разметки датасета DEAM...")

    # skipinitialspace drops the blanks that follow each comma, so padded
    # header names such as " valence_mean" are read as "valence_mean".
    raw_df = pd.read_csv(source_csv, skipinitialspace=True)

    # Select the Valence/Arousal coordinates and rename by explicit mapping
    # (robust to column order), then cast ids to int so they match the
    # integer-named files on disk.
    processed_df = (
        raw_df[['song_id', 'valence_mean', 'arousal_mean']]
        .rename(columns={'valence_mean': 'valence', 'arousal_mean': 'arousal'})
        .astype({'song_id': int})
    )

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    processed_df.to_csv(output_csv, index=False)
    print(f"База успешно сформирована. Всего записей: {len(processed_df)}")
# Script entry point: build music_db.csv from the DEAM annotations.
if __name__ == "__main__":
prepare_deam_database()
+60
View File
@@ -0,0 +1,60 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
# Stress-test configuration.
# The synthetic input tensor holds NUM_SAMPLES x DIM_IN float32 values
# (300_000 * 4096 * 4 bytes, roughly 4.9 GB) and is allocated on the selected device.
NUM_SAMPLES = 300_000
DIM_IN = 4096
DIM_OUT = 10
BATCH_SIZE = 16_384
NUM_STEPS = 1000
def run_gpu_benchmark(*, num_samples=None, dim_in=None, dim_out=None,
                      batch_size=None, num_steps=None):
    """Run a synthetic training loop as a device stress test and time it.

    Allocates random input/target tensors directly on the device, trains a
    four-layer fully connected network on randomly sampled batches with
    Adam + MSE loss, and prints the loss every 100 steps.

    Args:
        num_samples: Rows in the synthetic dataset (default ``NUM_SAMPLES``).
        dim_in: Input feature dimension (default ``DIM_IN``).
        dim_out: Output dimension (default ``DIM_OUT``).
        batch_size: Samples drawn per step (default ``BATCH_SIZE``).
        num_steps: Number of optimisation steps (default ``NUM_STEPS``).

    Returns:
        float: Wall-clock duration of the training loop in seconds.
    """
    # Module-level constants are looked up only when a value is not supplied.
    num_samples = NUM_SAMPLES if num_samples is None else num_samples
    dim_in = DIM_IN if dim_in is None else dim_in
    dim_out = DIM_OUT if dim_out is None else dim_out
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    num_steps = NUM_STEPS if num_steps is None else num_steps

    # Use CUDA when available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    print(f"Инициализация стресс-теста на устройстве: {device}")

    # Synthetic dataset created directly on the device to occupy its memory.
    x_data = torch.randn(num_samples, dim_in, device=device, dtype=torch.float32)
    y_data = torch.randn(num_samples, dim_out, device=device, dtype=torch.float32)

    # Fully connected test network.
    model = nn.Sequential(
        nn.Linear(dim_in, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, dim_out),
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print("Начало прогрева GPU и симуляции цикла обучения...")

    # CUDA kernels are launched asynchronously: wait for pending setup work
    # before starting the clock so it is not attributed to the loop.
    if on_cuda:
        torch.cuda.synchronize(device)
    start_time = time.perf_counter()

    for step in range(num_steps):
        # Sample a random batch (with replacement).
        idx = torch.randint(0, num_samples, (batch_size,), device=device)
        x_batch = x_data[idx]
        y_batch = y_data[idx]

        optimizer.zero_grad()
        predictions = model(x_batch)
        loss = loss_fn(predictions, y_batch)
        loss.backward()
        optimizer.step()

        # Log every 100 iterations only, to keep I/O overhead low.
        if step % 100 == 0:
            print(f"Итерация {step}/{num_steps} | Текущий loss: {loss.item():.4f}")

    # Without this barrier the timer could stop while kernels are still queued.
    if on_cuda:
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start_time

    print(f"Стресс-тест завершен. Общее время: {elapsed:.2f} сек.")
    return elapsed
# Script entry point: run the stress test with the module-level configuration.
if __name__ == "__main__":
run_gpu_benchmark()
@@ -3,30 +3,22 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9336560f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0c00b67b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"import timm\n",
"import numpy as np\n"
"import timm"
]
},
{
@@ -47,40 +39,52 @@
}
],
"source": [
"# === CONFIG ===\n",
"# Конфигурация параметров обучения и путей файловой системы\n",
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"BATCH_SIZE = 64\n",
"EPOCHS = 15\n",
"LR = 3e-4\n",
"NUM_WORKERS = 24\n",
"NUM_WORKERS = 40\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"DEVICE\n"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Аппаратное ускорение: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "9f749add",
"metadata": {},
"outputs": [],
"source": [
"class EmoSetDataset(Dataset):\n",
" def __init__(self, root, split):\n",
" def __init__(self, root: Path | str, split: str):\n",
" self.root = Path(root) / split\n",
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
"\n",
" # Формирование словарей маппинга классов\n",
" self.labels = sorted(self.df[\"label\"].unique())\n",
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
"\n",
" self.transform = T.Compose([\n",
" T.Resize((224, 224)),\n",
" # Базовые трансформации для валидации и теста\n",
" base_tf = [\n",
" T.ToTensor(),\n",
" T.Normalize(\n",
" mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]\n",
" )\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" ]\n",
"\n",
" # Внедрение аугментации исключительно для обучающей выборки (предотвращение переобучения)\n",
" if split == \"train\":\n",
" self.transform = T.Compose([\n",
" T.RandomResizedCrop(224),\n",
" T.RandomHorizontalFlip(),\n",
" *base_tf\n",
" ])\n",
" else:\n",
" self.transform = T.Compose([\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" *base_tf\n",
" ])\n",
"\n",
" def __len__(self):\n",
@@ -90,16 +94,21 @@
" row = self.df.iloc[idx]\n",
" img_path = self.root / \"images\" / row[\"filename\"]\n",
"\n",
" # Обработка возможных исключений ввода-вывода (поврежденные JPEG-файлы в датасете)\n",
" try:\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" img = self.transform(img)\n",
" except Exception:\n",
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
"\n",
" label = self.label2idx[row[\"label\"]]\n",
" return img, label\n"
" img_tensor = self.transform(img)\n",
" label_idx = self.label2idx[row[\"label\"]]\n",
" \n",
" return img_tensor, label_idx"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "c8805341",
"metadata": {},
"outputs": [
@@ -112,9 +121,11 @@
}
],
"source": [
"# Подготовка объектов выборки\n",
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
"\n",
"# Инициализация итераторов с закреплением памяти (pin_memory) для ускорения передачи на GPU\n",
"train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=BATCH_SIZE,\n",
@@ -131,12 +142,12 @@
" pin_memory=True\n",
")\n",
"\n",
"print(\"Classes:\", train_ds.labels)\n"
"print(f\"Индексированные классы: {train_ds.labels}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "dffce582",
"metadata": {},
"outputs": [
@@ -391,55 +402,51 @@
}
],
"source": [
"# TODO перед защитой, повторить оптимизаторы\n",
"# Загрузка предобученной архитектуры ResNet-50 с заменой классификационного слоя\n",
"model = timm.create_model(\n",
" \"resnet50\",\n",
" pretrained=True,\n",
" num_classes=len(train_ds.labels)\n",
")\n",
"model.to(device)\n",
"\n",
"model.to(DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"# Функция потерь для многоклассовой классификации\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# Оптимизатор AdamW с L2-регуляризацией (weight_decay) для повышения обобщающей способности\n",
"optimizer = torch.optim.AdamW(\n",
" model.parameters(),\n",
" lr=LR,\n",
" weight_decay=1e-4\n",
")\n",
"\n",
"# Планировщик скорости обучения: косинусный отжиг\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer,\n",
" T_max=EPOCHS\n",
")\n"
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "951aa9e3",
"execution_count": null,
"id": "81a457ef",
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, loader):\n",
" model.train()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"def train_epoch(current_model, loader):\n",
" current_model.train()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in tqdm(loader, leave=False):\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
" for imgs, labels in tqdm(loader, desc=\"Тренировка\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" logits = model(imgs)\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" loss.backward()\n",
@@ -447,292 +454,67 @@
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total_samples, correct_preds / total_samples\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb7e9398",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def val_epoch(model, loader):\n",
" model.eval()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
"def val_epoch(current_model, loader):\n",
" # Перевод модели в режим инференса (отключение Dropout и фиксация BatchNorm)\n",
" current_model.eval()\n",
" total_loss = 0.0\n",
" correct_preds = 0\n",
" total_samples = 0\n",
"\n",
" for imgs, labels in loader:\n",
" imgs = imgs.to(DEVICE)\n",
" labels = labels.to(DEVICE)\n",
" for imgs, labels in tqdm(loader, desc=\"Валидация\", leave=False):\n",
" imgs = imgs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" logits = model(imgs)\n",
" logits = current_model(imgs)\n",
" loss = criterion(logits, labels)\n",
"\n",
" total_loss += loss.item() * imgs.size(0)\n",
" preds = logits.argmax(dim=1)\n",
" correct += (preds == labels).sum().item()\n",
" total += labels.size(0)\n",
" correct_preds += (preds == labels).sum().item()\n",
" total_samples += labels.size(0)\n",
"\n",
" return total_loss / total, correct / total\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9e870e5d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1477 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
]
}
],
"source": [
"best_val_acc = 0.0\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Epoch {epoch:02d} | \"\n",
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
" return total_loss / total_samples, correct_preds / total_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7796ef11",
"id": "951aa9e3",
"metadata": {},
"outputs": [],
"source": []
"source": [
"best_val_acc = 0.0\n",
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
"\n",
"print(\"Старт процесса обучения...\")\n",
"\n",
"for epoch in range(1, EPOCHS + 1):\n",
" train_loss, train_acc = train_epoch(model, train_loader)\n",
" val_loss, val_acc = val_epoch(model, val_loader)\n",
"\n",
" # Обновление шага планировщика\n",
" scheduler.step()\n",
"\n",
" print(\n",
" f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
" f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
" f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
" )\n",
"\n",
" # Экспорт весов при улучшении целевой метрики\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), checkpoint_path)\n",
" print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
"\n",
"print(\"Обучение завершено.\")"
]
}
],
"metadata": {
File diff suppressed because one or more lines are too long
+69
View File
@@ -0,0 +1,69 @@
import pandas as pd
from pathlib import Path
from tqdm import tqdm
# Path configuration and target features (paths are relative to the current working directory).
# NOTE(review): the enriched table is written to music_db_enriched.csv, while the
# application loader in this commit points MusicMatcher at music_db.csv and
# MusicMatcher drops rows on the acoustic columns produced here — confirm which
# file the application is expected to read.
BASE_DIR = Path("../../dataset/DEAM")
MUSIC_DB_PATH = BASE_DIR / "music_db.csv"
FEATURES_DIR = BASE_DIR / "features" / "features"
OUTPUT_PATH = BASE_DIR / "music_db_enriched.csv"
# Mapping from the extractor's low-level feature columns (openSMILE/GeMAPS)
# to the short descriptor names used by the system.
# NOTE(review): 'entropy' and 'sharpness' are exported here, but MusicMatcher's
# acoustic_features list contains only the first six — presumably reserved for later use.
TARGET_FEATURES = {
'pcm_RMSenergy_sma_amean': 'energy',
'pcm_fftMag_spectralFlux_sma_amean': 'flux',
'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',
'F0final_sma_amean': 'pitch',
'logHNR_sma_amean': 'hnr',
'pcm_zcr_sma_amean': 'zcr',
'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',
'pcm_fftMag_psySharpness_sma_amean': 'sharpness'
}
def aggregate_acoustic_features():
    """Enrich the DEAM annotation table with per-track mean acoustic features.

    Reads ``MUSIC_DB_PATH`` and, for every ``song_id`` in it, looks for
    ``FEATURES_DIR/<song_id>.csv`` (semicolon-separated, one row per frame).
    The columns named by the keys of ``TARGET_FEATURES`` are averaged over all
    rows, renamed to the corresponding short descriptor names, and inner-joined
    back onto the annotations. Rows with a NaN descriptor are dropped and the
    result is written to ``OUTPUT_PATH``.

    Tracks whose feature file is missing are skipped silently; a file that
    cannot be parsed is reported on stdout and skipped. If the annotation file
    is absent, or no track yields any features, a message is printed and
    nothing is written.

    Returns:
        None.
    """
    if not MUSIC_DB_PATH.exists():
        print(f"Базовый файл аннотаций не найден: {MUSIC_DB_PATH}")
        return

    print("Загрузка эмоциональной разметки DEAM...")
    df_main = pd.read_csv(MUSIC_DB_PATH)

    print("Агрегация фреймовых акустических признаков...")
    source_columns = list(TARGET_FEATURES)
    aggregated_data = []

    # Only the id column is needed, so iterate it directly instead of paying
    # for a full row Series per track with iterrows().
    for raw_id in tqdm(df_main['song_id'], total=len(df_main), desc="Обработка аудио-векторов"):
        song_id = int(raw_id)
        feature_file = FEATURES_DIR / f"{song_id}.csv"
        if not feature_file.exists():
            continue
        try:
            # Raw frame-level vectors (CSV with ';' as the separator).
            df_feat = pd.read_csv(feature_file, sep=';')
            # Average each target feature along the time (frame) axis.
            mean_features = df_feat[source_columns].mean()
            track_data = {'song_id': song_id}
            for orig_col, new_col in TARGET_FEATURES.items():
                track_data[new_col] = mean_features[orig_col]
            aggregated_data.append(track_data)
        except Exception as e:
            # Best effort: one malformed file must not abort the whole run.
            print(f"Ошибка парсинга файла {feature_file.name}: {e}")

    if not aggregated_data:
        # An empty DataFrame has no 'song_id' column, so the merge below would
        # fail with a KeyError; report the situation explicitly instead.
        print(f"Акустические признаки не найдены ни для одного трека: {FEATURES_DIR}")
        return

    # Join acoustic descriptors with the emotional coordinates (inner join).
    df_features = pd.DataFrame(aggregated_data)
    df_enriched = pd.merge(df_main, df_features, on='song_id', how='inner')

    # Drop rows where aggregation produced NaN descriptors.
    df_enriched = df_enriched.dropna(subset=list(TARGET_FEATURES.values()))
    df_enriched.to_csv(OUTPUT_PATH, index=False)

    print(f"Экспорт завершен. Сформирована обогащенная база: {OUTPUT_PATH.name}")
    print(f"Итоговый размер выборки: {len(df_enriched)} треков.")
if __name__ == "__main__":
aggregate_acoustic_features()
+80
View File
@@ -0,0 +1,80 @@
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# Projection of discrete emotion classes onto Russell's continuous
# (Valence, Arousal) space. Keys are integer class ids, values are
# (valence, arousal) pairs calibrated to the range [1.0, 9.0].
EMOTION_TO_VA_COORDS = {
0: (7.5, 6.5), # amusement
1: (2.0, 8.0), # anger
2: (6.5, 5.0), # awe
3: (7.0, 3.0), # contentment
4: (3.0, 6.0), # disgust
5: (8.0, 8.0), # excitement
6: (2.5, 7.5), # fear
7: (2.0, 2.0), # sadness
}
def train_va_regressor():
    """Fit and save a pipeline that maps image embeddings to (valence, arousal).

    Loads ``emoset_test_embeddings.npy`` and ``emoset_test_labels.npy`` from
    the directory two levels above this file, converts each discrete label to
    its ``EMOTION_TO_VA_COORDS`` pair, holds out 20% of the samples
    (``random_state=42``), fits ``StandardScaler`` followed by a per-output
    ``RidgeCV``, prints MSE / R^2 and the range of the predictions on the
    held-out part, and dumps the fitted pipeline to
    ``music_engine/va_regressor.pkl`` with joblib.

    If either input array is missing, a message is printed and nothing else
    happens.

    Returns:
        None.
    """
    # Resolve all locations relative to the parent of this script's folder.
    src_root = Path(__file__).resolve().parent.parent
    embeddings_path = src_root / "emoset_test_embeddings.npy"
    labels_path = src_root / "emoset_test_labels.npy"
    model_output_path = src_root / "music_engine" / "va_regressor.pkl"

    if not (embeddings_path.exists() and labels_path.exists()):
        print(f"Артефакты признаков не найдены в директории: {src_root}")
        return

    print("Загрузка вектора признаков и меток классов...")
    x_features = np.load(embeddings_path)
    y_discrete = np.load(labels_path)

    # Target transformation: class id -> continuous (valence, arousal) pair.
    y_continuous = np.array(list(map(EMOTION_TO_VA_COORDS.__getitem__, y_discrete)))

    x_train, x_test, y_train, y_test = train_test_split(
        x_features, y_continuous, test_size=0.2, random_state=42
    )

    # Z-score scaling followed by L2-regularised regression; RidgeCV selects
    # alpha from the supplied grid for each output independently.
    print("Инициализация и обучение пайплайна RidgeCV...")
    ridge = RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])
    steps = [
        ('scaler', StandardScaler()),
        ('regressor', MultiOutputRegressor(ridge)),
    ]
    regression_pipeline = Pipeline(steps)
    regression_pipeline.fit(x_train, y_train)

    # Quality on the held-out split.
    y_pred = regression_pipeline.predict(x_test)
    mse_score = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    print("Обучение завершено. Метрики качества на тестовой выборке:")
    print(f" - MSE: {mse_score:.4f}")
    print(f" - R^2: {r2:.4f}")

    # Spread of the predictions per output column (0: valence, 1: arousal).
    valence_pred = y_pred[:, 0]
    arousal_pred = y_pred[:, 1]
    v_min, v_max = valence_pred.min(), valence_pred.max()
    a_min, a_max = arousal_pred.min(), arousal_pred.max()
    print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
    print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")

    # Persist the fitted pipeline, creating the target folder if needed.
    model_output_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(regression_pipeline, model_output_path)
    print(f"Пайплайн сохранен: {model_output_path.name}")
if __name__ == "__main__":
train_va_regressor()
+264
View File
@@ -0,0 +1,264 @@
import os
import gc
import pickle
import random
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm
# Runtime device and filesystem path configuration.
# NOTE(review): all paths below are absolute and machine-specific.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_ROOT = Path("/home/zin/projects/Thesis/dataset/Original-2.41M")
CACHE_PATH = Path("/home/zin/projects/Thesis/src/dataset_paths_cache.pkl")
PREVIOUS_WEIGHTS = Path("/home/zin/projects/Thesis/src/emoset_resnet50_best.pth")
RESUME_CHECKPOINT = Path("/home/zin/projects/Thesis/src/emoset_resnet50_resume.pth")
SAVE_MODEL_PATH = Path("/home/zin/projects/Thesis/src/emoset_resnet50_finetuned_2_41M.pth")
# Emotion folder name -> class id; "sad" and "sadness" both map to class 7.
CLASS_MAPPING = {
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
}
# Training pipeline hyperparameters.
BATCH_SIZE = 82
EPOCHS = 15
LR = 5e-5
NUM_TRAIN_WORKERS = 48
NUM_VAL_WORKERS = 18
# Number of consecutive epochs without validation-loss improvement that the
# early-stopping check compares against.
PATIENCE = 4
def prepare_dataset_index():
    """Return ``(image_paths, labels)`` for the dataset, using an on-disk cache.

    If ``CACHE_PATH`` exists, the index is unpickled from it and returned
    as-is. Otherwise every ``*.jpg`` under ``DATA_ROOT`` is visited: the
    lower-cased third-from-last path component (the grandparent directory of
    the file) is looked up in ``CLASS_MAPPING`` and files whose folder name is
    not a key are ignored. The resulting index is pickled to ``CACHE_PATH``
    before being returned, so later runs avoid rescanning the directory tree.

    Returns:
        Two parallel sequences: image paths as strings and integer class ids.
    """
    if not CACHE_PATH.exists():
        print(f"Сканирование сетевой директории {DATA_ROOT} (первичная индексация)...")
        image_paths, class_ids = [], []
        for jpg in DATA_ROOT.rglob('*.jpg'):
            folder_key = jpg.parts[-3].lower()
            if folder_key not in CLASS_MAPPING:
                continue
            image_paths.append(str(jpg))
            class_ids.append(CLASS_MAPPING[folder_key])

        with open(CACHE_PATH, 'wb') as sink:
            pickle.dump({'image_paths': image_paths, 'labels': class_ids}, sink)
        return image_paths, class_ids

    print(f"Загрузка карты файловой системы из кэша: {CACHE_PATH.name}")
    with open(CACHE_PATH, 'rb') as source:
        cached = pickle.load(source)
    return cached['image_paths'], cached['labels']
class EmoSetDirectDataset(Dataset):
    """Dataset that only decodes and resizes images, deferring augmentation.

    Each item is a float32 RGB tensor scaled to [0, 1] and resized to
    256x256, paired with its label. Cropping, flipping, colour jitter and
    normalisation are not applied here; the caller is expected to apply them
    to the batch afterwards.
    """

    def __init__(self, image_paths, labels):
        """Store the parallel sequences of image paths and labels."""
        self.labels = labels
        self.image_paths = image_paths
        self.base_transform = T.Resize((256, 256), antialias=True)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; an unreadable image becomes all zeros."""
        try:
            decoded = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
            image = self.base_transform(decoded.to(torch.float32) / 255.0)
        except Exception:
            # Isolate I/O failures (e.g. corrupted files on the network
            # drive): substitute a black image instead of failing the batch.
            image = torch.zeros((3, 256, 256), dtype=torch.float32)
        return image, self.labels[idx]
def build_gpu_transforms():
    """Build the batch-level train and validation transforms on ``DEVICE``.

    The train pipeline is a random 224x224 crop, a horizontal flip with
    probability 0.5, a mild colour jitter and channel normalisation; the
    validation pipeline is a 224x224 centre crop followed by the same
    normalisation. Both are ``torch.nn.Sequential`` modules moved to
    ``DEVICE`` so that they can be applied to whole batches after transfer.

    Returns:
        A ``(train_transform, val_transform)`` tuple.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    train_steps = [
        T.RandomCrop((224, 224)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        T.Normalize(channel_mean, channel_std),
    ]
    val_steps = [
        T.CenterCrop((224, 224)),
        T.Normalize(channel_mean, channel_std),
    ]

    train_tf = torch.nn.Sequential(*train_steps).to(DEVICE)
    val_tf = torch.nn.Sequential(*val_steps).to(DEVICE)
    return train_tf, val_tf
if __name__ == "__main__":
    # Fine-tuning entry point: index the dataset, split it 95/5 into
    # train/validation, then train ResNet-50 with mixed precision, writing a
    # resume checkpoint after every epoch and a snapshot of the weights with
    # the best validation loss.
    print(f"Инициализация конвейера обучения. Устройство: {DEVICE}")
    all_paths, all_labels = prepare_dataset_index()
    if not all_paths:
        # zip(*combined) below would otherwise fail with an opaque
        # "not enough values to unpack" error.
        raise SystemExit(f"Индекс датасета пуст: проверьте {DATA_ROOT} и кэш {CACHE_PATH.name}.")

    # Fixed seed so the train/validation split is identical across restarts.
    random.seed(42)
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)
    all_paths, all_labels = zip(*combined)
    split_idx = int(len(all_paths) * 0.95)

    train_loader = DataLoader(
        EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    val_loader = DataLoader(
        EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_VAL_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    gpu_train_tf, gpu_val_tf = build_gpu_transforms()

    # Decide the weight initialisation BEFORE building the optimizer: the
    # optimizer captures model.parameters(), so replacing the model object
    # afterwards would leave it updating a network that is never trained.
    has_resume = RESUME_CHECKPOINT.exists()
    has_previous = PREVIOUS_WEIGHTS.exists()
    model = timm.create_model(
        'resnet50', pretrained=not (has_resume or has_previous), num_classes=8
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler()

    best_val_loss = float('inf')
    epochs_no_improve = 0
    start_epoch = 1

    # Restore a previous session, or load earlier weights, if available.
    if has_resume:
        print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        start_epoch = checkpoint['epoch'] + 1
    elif has_previous:
        print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
    else:
        # The model above was already created with pretrained=True.
        print("Веса не найдены. Инициализация с ImageNet.")

    # Same file name as SAVE_MODEL_PATH with "_best" inserted before the suffix.
    best_model_path = SAVE_MODEL_PATH.with_name(
        f"{SAVE_MODEL_PATH.stem}_best{SAVE_MODEL_PATH.suffix}"
    )

    # Last epoch that ran to completion; an interrupted epoch must not be
    # recorded as finished, otherwise a resume would skip the rest of it.
    completed_epoch = start_epoch - 1

    try:
        for epoch in range(start_epoch, EPOCHS + 1):
            # Training pass.
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]")
            for inputs, labels in pbar:
                try:
                    inputs = inputs.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    inputs = gpu_train_tf(inputs)

                    optimizer.zero_grad()
                    # Mixed precision to reduce VRAM usage.
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    pbar.set_postfix({'loss': f"{loss.item():.4f}"})
                except RuntimeError as memory_err:
                    # Skip a batch that ran out of GPU memory; any other
                    # RuntimeError is re-raised unchanged.
                    if "out of memory" in str(memory_err).lower():
                        if 'outputs' in locals():
                            del outputs
                        if 'loss' in locals():
                            del loss
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    raise

            train_loss = running_loss / total if total > 0 else 0
            train_acc = correct / total if total > 0 else 0
            gc.collect()
            torch.cuda.empty_cache()

            # Validation pass.
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False):
                    val_inputs = val_inputs.to(DEVICE, non_blocking=True)
                    val_labels = val_labels.to(DEVICE, non_blocking=True)
                    val_inputs = gpu_val_tf(val_inputs)
                    with autocast(device_type="cuda"):
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)
                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = val_outputs.max(1)
                    val_total += val_labels.size(0)
                    val_correct += val_predicted.eq(val_labels).sum().item()

            epoch_val_loss = val_loss / val_total if val_total > 0 else 0
            epoch_val_acc = val_correct / val_total if val_total > 0 else 0
            scheduler.step()
            print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")

            # Early-stopping bookkeeping and best-weights snapshot.
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), best_model_path)
            else:
                epochs_no_improve += 1
                # NOTE(review): with EPOCHS = 15 the `epoch >= 15` clause means
                # this can only fire on the final epoch; confirm whether a
                # lower warm-up bound was intended.
                if epochs_no_improve >= PATIENCE and epoch >= 15:
                    print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.")
                    break

            # End-of-epoch resume checkpoint.
            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }
            torch.save(checkpoint_state, RESUME_CHECKPOINT)
            gc.collect()
            completed_epoch = epoch
    except KeyboardInterrupt:
        print("\nВыполнение прервано пользователем (SIGINT).")
        print(f"Дамп памяти конвейера зафиксирован на эпохе {completed_epoch + 1}.")
        # Record the last fully completed epoch so that a resume repeats the
        # interrupted epoch instead of skipping its remainder.
        checkpoint_state = {
            'epoch': completed_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss
        }
        torch.save(checkpoint_state, RESUME_CHECKPOINT)
    else:
        # Normal completion (including early stopping): export the final
        # weights and discard the resume checkpoint.
        if SAVE_MODEL_PATH.parent.exists():
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}")
        if RESUME_CHECKPOINT.exists():
            RESUME_CHECKPOINT.unlink()
-125
View File
@@ -1,125 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0336fd0c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ База загружена. Треков: 1744\n",
"🔍 Собираем акустические признаки...\n",
"\n",
"🚀 ГОТОВО! Обогащенная база сохранена: ../../dataset/DEAM/music_db_enriched.csv\n",
"Собрано фичей для 1744 из 1744 треков.\n",
" song_id valence arousal energy flux centroid pitch \\\n",
"0 2 3.1 3.0 0.097268 0.846947 483.421751 93.884056 \n",
"1 3 3.5 3.3 0.126809 0.959460 173.219616 62.682589 \n",
"2 4 5.7 5.5 0.156699 1.333944 466.434797 92.850316 \n",
"3 5 4.4 5.3 0.126455 1.009927 546.152506 158.673853 \n",
"4 7 5.8 6.4 0.268180 1.589191 175.369162 83.823484 \n",
"\n",
" hnr zcr entropy sharpness \n",
"0 3.615380 0.034270 3.299075 0.426490 \n",
"1 -2.600122 0.017893 2.294971 0.165583 \n",
"2 -0.579130 0.042936 3.258138 0.395410 \n",
"3 1.751148 0.043781 3.514585 0.494367 \n",
"4 12.006770 0.014783 2.177862 0.170058 \n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from tqdm import tqdm # для красивого прогресс-бара, если не установлен - убери\n",
"\n",
"# 1. Пути к файлам\n",
"base_dir = Path(\"../../dataset/DEAM\") # Поправь, если запускаешь из другого места\n",
"music_db_path = base_dir / \"music_db.csv\"\n",
"features_dir = base_dir / \"features\" / \"features\"\n",
"output_path = base_dir / \"music_db_enriched.csv\"\n",
"\n",
"# 2. Наш \"Золотой список\" (8 признаков)\n",
"target_columns = {\n",
" 'pcm_RMSenergy_sma_amean': 'energy',\n",
" 'pcm_fftMag_spectralFlux_sma_amean': 'flux',\n",
" 'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',\n",
" 'F0final_sma_amean': 'pitch',\n",
" 'logHNR_sma_amean': 'hnr',\n",
" 'pcm_zcr_sma_amean': 'zcr',\n",
" 'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',\n",
" 'pcm_fftMag_psySharpness_sma_amean': 'sharpness'\n",
"}\n",
"\n",
"# 3. Загружаем текущую базу с V/A\n",
"if not music_db_path.exists():\n",
" print(f\"ОШИБКА: Не найден файл {music_db_path}\")\n",
"else:\n",
" df_main = pd.read_csv(music_db_path)\n",
" print(f\"База загружена. Треков: {len(df_main)}\")\n",
"\n",
" # Подготавливаем новые колонки\n",
" for col_name in target_columns.values():\n",
" df_main[col_name] = np.nan\n",
"\n",
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
" print(\"Собираем акустические признаки...\")\n",
" found_count = 0\n",
" \n",
" for index, row in df_main.iterrows():\n",
" song_id = int(row['song_id'])\n",
" feature_file = features_dir / f\"{song_id}.csv\"\n",
" \n",
" if feature_file.exists():\n",
" try:\n",
" # Читаем CSV с признаками (разделитель там обычно точка с запятой)\n",
" df_feat = pd.read_csv(feature_file, sep=';')\n",
" \n",
" # Усредняем значения по всем фреймам (одна песня разбита на сотни строк-фреймов)\n",
" mean_features = df_feat[list(target_columns.keys())].mean()\n",
" \n",
" # Записываем в главную базу\n",
" for orig_col, new_col in target_columns.items():\n",
" df_main.at[index, new_col] = mean_features[orig_col]\n",
" \n",
" found_count += 1\n",
" except Exception as e:\n",
" print(f\"Ошибка чтения {feature_file}: {e}\")\n",
" \n",
" # 5. Сохраняем результат\n",
" # Удаляем треки, для которых не нашлось фичей (если такие есть)\n",
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
" \n",
" df_main.to_csv(output_path, index=False)\n",
" print(f\"\\nГотово! Обогащенная база сохранена: {output_path}\")\n",
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
" print(df_main.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-614
View File
@@ -1,614 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09f9237a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
]
}
],
"source": [
"!pip install datasets tqdm pillow requests\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f0b2e2c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "868d872a109d49f9966f2f19985e7048",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06741794289540849ad179c5966dcab8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e47aad5270144913996cb5b226213ab9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30d1492a948245e3b6b58e92218cd760",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "931823b458cb4696b459e9011537cf1e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "846f4245b16d4cc096a43c940590ad11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172981d77fc941cfa32c05f5a34bf742",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2baa935ed3524a73883909752cb15907",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9c0baac101b449794155392f07b49c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a5b966f79314e069251462bff82395f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "422974f938924910a0712b30a9c2bd84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f155a08427094de7ad1a5884e623db2b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a94a4621d19f45f690e0064fee83767b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50f55b00a27b4213b573b398e5b0d708",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c5815040f0a4a31903348a8327811a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 94481\n",
" })\n",
" val: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 5905\n",
" })\n",
" test: Dataset({\n",
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
" num_rows: 17716\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"# куда сохраняем датасет\n",
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
"\n",
"# загружаем через Hugging Face\n",
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
"\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "052ab073",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"from pathlib import Path\n",
"\n",
"def save_split(split):\n",
" split_dir = DATA_DIR / split\n",
" img_dir = split_dir / \"images\"\n",
" img_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" labels_path = split_dir / \"labels.csv\"\n",
"\n",
" # перезаписываем labels.csv\n",
" with open(labels_path, \"w\") as f:\n",
" f.write(\"filename,label\\n\")\n",
"\n",
" for example in tqdm(ds[split]):\n",
" img = example[\"image\"] # уже PIL.Image\n",
" label = example[\"emotion\"]\n",
" image_id = example[\"image_id\"]\n",
"\n",
" fname = f\"{image_id}.jpg\"\n",
" img.save(img_dir / fname)\n",
"\n",
" with open(labels_path, \"a\") as f:\n",
" f.write(f\"{fname},{label}\\n\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a74ceedf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
]
}
],
"source": [
"save_split(\"train\")\n",
"save_split(\"val\")\n",
"save_split(\"test\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-140
View File
@@ -1,140 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Загрузка датасета DEAM\n",
"\n",
"Этот ноутбук предназначен для автоматизации процесса скачивания и подготовки музыкального датасета **DEAM** (Database for Emotional Analysis in Music).\n",
"Данные будут помещены в папку `dataset/DEAM`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting kagglehub\n",
" Downloading kagglehub-1.0.1-py3-none-any.whl.metadata (40 kB)\n",
"Collecting kagglesdk<1.0,>=0.1.22 (from kagglehub)\n",
" Downloading kagglesdk-0.1.23-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (25.0)\n",
"Requirement already satisfied: pyyaml in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (6.0.3)\n",
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (2.32.5)\n",
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (4.67.1)\n",
"Requirement already satisfied: protobuf in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglesdk<1.0,>=0.1.22->kagglehub) (6.33.4)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.4.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.11)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2.6.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2026.1.4)\n",
"Downloading kagglehub-1.0.1-py3-none-any.whl (70 kB)\n",
"Downloading kagglesdk-0.1.23-py3-none-any.whl (217 kB)\n",
"Installing collected packages: kagglesdk, kagglehub\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [kagglehub]\n",
"\u001b[1A\u001b[2KSuccessfully installed kagglehub-1.0.1 kagglesdk-0.1.23\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.1.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install kagglehub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Скачиваем датасет DEAM...\n",
"Downloading to /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/1.archive...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.83G/1.83G [01:09<00:00, 28.2MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting files...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Датасет скачан во временную директорию: /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/versions/1\n",
"Переносим файлы в ../dataset/DEAM...\n",
"\n",
"[УСПЕХ] Датасет DEAM готов к работе!\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import kagglehub\n",
"from pathlib import Path\n",
"\n",
"# 1. Настройка путей\n",
"DATASET_ROOT = Path(\"../dataset\")\n",
"DEAM_ROOT = DATASET_ROOT / \"DEAM\"\n",
"DEAM_ROOT.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# 2. Загрузка через kagglehub\n",
"print(\"Скачиваем датасет DEAM...\")\n",
"kaggle_cache_path = kagglehub.dataset_download(\"imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music\")\n",
"print(f\"Датасет скачан во временную директорию: {kaggle_cache_path}\")\n",
"\n",
"# 3. Перемещение файлов в проект\n",
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
"\n",
"print(\"\\nУСПЕХ! Датасет DEAM готов к работе!\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (my-python-project)",
"language": "python",
"name": "my-python-project"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
-171
View File
@@ -1,171 +0,0 @@
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import json
from PIL import Image
class EmoSet(Dataset):
    """EmoSet image-emotion dataset backed by JSON index files.

    ``data_root`` must contain ``info.json`` (label vocabularies) and one
    ``{phase}.json`` index per split. Every index entry is read as
    ``[emotion_label, image_id, image_relpath, annotation_relpath]``; both
    relative paths are resolved against ``data_root``.
    """

    # Attributes with exactly one label per image (encoded as a single index).
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes that may carry several labels per image (encoded as a multi-hot vector).
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Vocabulary size per attribute; only 'object' is read in this class
    # (it sizes the multi-hot vector in __getitem__).
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load label info and the split index.

        Args:
            data_root: Directory holding ``info.json`` and ``{phase}.json``.
            num_emotion_classes: 8 (labels from ``info.json``) or 2
                (positive / negative grouping).
            phase: One of ``'train'``, ``'val'``, ``'test'``.

        Raises:
            AssertionError: On an unsupported class count or phase.
        """
        assert num_emotion_classes in (8, 2)
        assert phase in ('train', 'val', 'test')
        self.transforms_dict = self.get_data_transforms()
        self.info = self.get_info(data_root, num_emotion_classes)
        # Same outcome as the former if/elif chain, including the
        # NotImplementedError fallback when asserts are stripped (-O).
        if phase not in self.transforms_dict:
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        data_store = self._read_json(os.path.join(data_root, f'{phase}.json'))
        self.data_store = [
            [
                self.info['emotion']['label2idx'][item[0]],
                item[1],
                os.path.join(data_root, item[2]),
                os.path.join(data_root, item[3])
            ]
            for item in data_store
        ]

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at *path*, closing the handle deterministically."""
        with open(path) as f:
            return json.load(f)

    @classmethod
    def get_data_transforms(cls):
        """Return the per-phase torchvision pipelines keyed by phase name."""
        transforms_dict = {
            # Training: random crop + horizontal flip.
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            # Validation and test share the same deterministic resize + center crop.
            'val': transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'test': transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
        return transforms_dict

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` and, for the 2-class setup, replace its emotion map.

        Args:
            data_root: Directory containing ``info.json``.
            num_emotion_classes: 8 keeps the file's content untouched; 2
                overrides ``info['emotion']`` with a positive (0) /
                negative (1) grouping of the eight emotion names.

        Returns:
            The parsed (and possibly overridden) info dictionary.
        """
        assert num_emotion_classes in (8, 2)
        info = self._read_json(os.path.join(data_root, 'info.json'))
        if num_emotion_classes == 8:
            pass
        elif num_emotion_classes == 2:
            emotion_info = {
                'label2idx': {
                    'amusement': 0,
                    'awe': 0,
                    'contentment': 0,
                    'excitement': 0,
                    'anger': 1,
                    'disgust': 1,
                    'fear': 1,
                    'sadness': 1,
                },
                'idx2label': {
                    '0': 'positive',
                    '1': 'negative',
                }
            }
            info['emotion'] = emotion_info
        else:
            # Only reachable when asserts are disabled.
            raise NotImplementedError
        return info

    def load_image_by_path(self, path):
        """Open the image at *path* as RGB and apply the phase transform."""
        image = Image.open(path).convert('RGB')
        image = self.transform(image)
        return image

    def load_annotation_by_path(self, path):
        """Return the parsed per-image annotation JSON stored at *path*."""
        return self._read_json(path)

    def __getitem__(self, item):
        """Return one sample as a dict.

        Keys: ``image_id``, ``image``, ``emotion_label_idx`` and one
        ``{attribute}_label_idx`` entry per attribute.
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            # -1 marks an attribute that is absent from the annotation.
            attribute_label_idx = -1
            if attribute in annotation_data:
                attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            data.update({f'{attribute}_label_idx': attribute_label_idx})

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            # Multi-hot vector: 1 at every annotated label index, 0 elsewhere.
            assert attribute == 'object'
            num_classes = self.NUM_CLASSES[attribute]
            attribute_label_idx = torch.zeros(num_classes)
            if attribute in annotation_data:
                for label in annotation_data[attribute]:
                    attribute_label_idx[self.info[attribute]['label2idx'][label]] = 1
            data.update({f'{attribute}_label_idx': attribute_label_idx})

        return data

    def __len__(self):
        """Return the number of entries in the split index."""
        return len(self.data_store)
if __name__ == '__main__':
    # Manual smoke test: build the 8-class training split and walk every
    # batch once so that each sample is loaded and transformed.
    data_root = r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val'
    num_emotion_classes = 8
    phase = 'train'

    dataset_kwargs = {
        'data_root': data_root,
        'num_emotion_classes': num_emotion_classes,
        'phase': phase,
    }
    dataset = EmoSet(**dataset_kwargs)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Each batch is a dict; to inspect one, print data['emotion_label_idx'] or
    # any '<attribute>_label_idx' entry (scene, facial_expression,
    # human_action, brightness, colorfulness, object) and break.
    for i, data in enumerate(dataloader):
        continue
File diff suppressed because one or more lines are too long
-314
View File
@@ -1,314 +0,0 @@
import os
import gc
import pickle
import random
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.io as tv_io
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import timm
# ==========================================
# 1. CONFIGURATION AND PATHS
# ==========================================
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Используем устройство: {DEVICE}")

_PROJECT_DIR = Path("/home/zin/projects/Thesis")
_SRC_DIR = _PROJECT_DIR / "src"

# Image tree to fine-tune on, and the pickled list of its file paths.
DATA_ROOT = _PROJECT_DIR / "dataset" / "Original-2.41M"
CACHE_PATH = _SRC_DIR / "dataset_paths_cache.pkl"

# Model files: starting weights (from the 118K run), the resumable session
# checkpoint, and the final output.
PREVIOUS_WEIGHTS = _SRC_DIR / "emoset_resnet50_best.pth"
RESUME_CHECKPOINT = _SRC_DIR / "emoset_resnet50_resume.pth"
SAVE_MODEL_PATH = _SRC_DIR / "emoset_resnet50_finetuned_2.41M.pth"

# Folder name -> class index; "sad" and "sadness" share index 7.
EMO_MAP = {
    name: index
    for index, name in enumerate(
        ["amusement", "anger", "awe", "contentment",
         "disgust", "excitement", "fear", "sad"]
    )
}
EMO_MAP["sadness"] = EMO_MAP["sad"]

# --- Training hyperparameters ---
BATCH_SIZE = 82
EPOCHS = 15
LR = 5e-5  # kept low because this is fine-tuning
NUM_TRAIN_WORKERS = 48
NUM_VAL_WORKERS = 18

# --- Early-stopping bookkeeping ---
PATIENCE = 4
best_val_loss = float('inf')
epochs_no_improve = 0
start_epoch = 1
# ==========================================
# 2. DATA PREPARATION (FAST PATH CACHE)
# ==========================================
if CACHE_PATH.exists():
    print(f"Загрузка списка файлов из кэша: {CACHE_PATH}...")
    # NOTE: pickle.load executes arbitrary code from the file; only load a
    # cache that this script wrote itself.
    with open(CACHE_PATH, 'rb') as f:
        cache_data = pickle.load(f)
    all_paths = cache_data['image_paths']
    all_labels = cache_data['labels']
    print(f"Готово! Моментально загружено {len(all_paths)} путей.")
else:
    print(f"Сканирование NFS директории {DATA_ROOT} (Выполняется один раз)...")
    all_paths, all_labels = [], []
    for img_path in DATA_ROOT.rglob('*.jpg'):
        # The emotion name is taken from the directory two levels above the file.
        emotion_folder = img_path.parts[-3].lower()
        if emotion_folder in EMO_MAP:
            all_paths.append(str(img_path))
            all_labels.append(EMO_MAP[emotion_folder])
    # Never persist an empty scan: a cached empty list would be reloaded on
    # every later run and hide the real problem (e.g. the share not mounted).
    if all_paths:
        with open(CACHE_PATH, 'wb') as f:
            pickle.dump({'image_paths': all_paths, 'labels': all_labels}, f)
        print(f"Сохранено в кэш: {len(all_paths)} изображений.")

if not all_paths:
    # Fail with an actionable message instead of the opaque ValueError that
    # zip(*[]) unpacking would raise below.
    raise RuntimeError(
        f"Не найдено ни одного изображения в {DATA_ROOT}. "
        f"Проверьте, что датасет доступен, и удалите кэш {CACHE_PATH}, если он пустой."
    )

# Train / validation split (95% / 5%).
random.seed(42)  # fixed seed so the split is identical across restarts
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
all_paths, all_labels = zip(*combined)
split_idx = int(len(all_paths) * 0.95)
train_paths, train_labels = all_paths[:split_idx], all_labels[:split_idx]
val_paths, val_labels = all_paths[split_idx:], all_labels[split_idx:]
print(f"Трейн: {len(train_paths)} | Валидация: {len(val_paths)}")
# ==========================================
# 3. DATASET & DATALOADER
# ==========================================
class EmoSetDirectDataset(Dataset):
    """Dataset over parallel sequences of image file paths and integer labels.

    Images are decoded to float32 RGB tensors scaled to [0, 1] and resized to
    256x256. An image that cannot be read or resized is replaced by an
    all-zero tensor (and reported through ``logging``) so that one bad file
    does not abort the run.
    """

    def __init__(self, image_paths, labels):
        """Store the path and label sequences; index ``i`` pairs them up."""
        self.image_paths = image_paths
        self.labels = labels
        # Only a plain resize happens per sample; the remaining augmentations
        # are left to the caller.
        self.base_transform = T.Resize((256, 256), antialias=True)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        ``image`` is a ``(3, 256, 256)`` float32 tensor; it is all zeros when
        loading failed.
        """
        path = self.image_paths[idx]
        try:
            image = tv_io.read_image(path, mode=tv_io.ImageReadMode.RGB)
            image = image.to(torch.float32) / 255.0
            image = self.base_transform(image)
        except Exception as exc:
            # Deliberate best-effort for broken files, but leave a trace of
            # which file was replaced instead of failing silently.
            import logging
            logging.getLogger(__name__).warning(
                "Не удалось загрузить изображение %s (%s); подставлен нулевой тензор.",
                path, exc,
            )
            image = torch.zeros((3, 256, 256), dtype=torch.float32)
        return image, self.labels[idx]
# --- DataLoaders with prefetching ---
# Options common to both loaders; they differ only in the dataset,
# shuffling and the number of worker processes.
_shared_loader_options = dict(
    batch_size=BATCH_SIZE,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
)

_train_dataset = EmoSetDirectDataset(train_paths, train_labels)
train_loader = DataLoader(
    _train_dataset,
    shuffle=True,
    num_workers=NUM_TRAIN_WORKERS,
    **_shared_loader_options,
)

_val_dataset = EmoSetDirectDataset(val_paths, val_labels)
val_loader = DataLoader(
    _val_dataset,
    shuffle=False,
    num_workers=NUM_VAL_WORKERS,
    **_shared_loader_options,
)
# ==========================================
# 4. BATCH-LEVEL AUGMENTATION MODULES
# ==========================================
# Per-channel normalization statistics shared by both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: random 224x224 crop, horizontal flip, mild color jitter, normalize.
_train_steps = [
    T.RandomCrop((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    T.Normalize(_NORM_MEAN, _NORM_STD),
]
gpu_train_transforms = nn.Sequential(*_train_steps).to(DEVICE)

# Validation: deterministic center crop, then the same normalization.
_val_steps = [
    T.CenterCrop((224, 224)),
    T.Normalize(_NORM_MEAN, _NORM_STD),
]
gpu_val_transforms = nn.Sequential(*_val_steps).to(DEVICE)
# ==========================================
# 5. MODEL INITIALIZATION
# ==========================================
print("\nСоздание архитектуры ResNet-50...")
model = timm.create_model('resnet50', pretrained=False, num_classes=8).to(DEVICE)
criterion = nn.CrossEntropyLoss()
# The optimizer and scheduler hold references to THIS model's parameters, so
# from here on weights must be loaded into `model` in place; rebinding the
# name to a new network would leave the optimizer updating the old one.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = GradScaler()

# --- Seamless session resume ---
if RESUME_CHECKPOINT.exists():
    print(f"ВОССТАНОВЛЕНИЕ СЕССИИ из: {RESUME_CHECKPOINT}")
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Optional keys: checkpoints written without them keep the defaults.
    if 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
    epochs_no_improve = checkpoint.get('epochs_no_improve', epochs_no_improve)
    start_epoch = checkpoint['epoch'] + 1
    print(f"УСПЕХ: Продолжаем обучение с эпохи {start_epoch}")
else:
    print("Чекпоинт сессии не найден. Проверяем наличие базовых весов...")
    if PREVIOUS_WEIGHTS.exists():
        print(f"📥 Загрузка базовых весов (от 118K): {PREVIOUS_WEIGHTS}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
    else:
        print("ВНИМАНИЕ: Базовые веса не найдены. Обучение начнется с нуля (ImageNet).")
        # Copy the pretrained weights INTO the existing model so the
        # optimizer/scheduler created above stay bound to the trained network.
        pretrained_model = timm.create_model('resnet50', pretrained=True, num_classes=8)
        model.load_state_dict(pretrained_model.state_dict())
        del pretrained_model
# ==========================================
# 6. MAIN TRAINING LOOP
# ==========================================
# Auxiliary weight files (best / per-epoch / interrupted) are written next to
# the final model instead of a cwd-relative "../", so their location does not
# depend on where the script is launched from.
WEIGHTS_DIR = SAVE_MODEL_PATH.parent

# Early stopping is only honoured from this epoch on.
# NOTE(review): this equals EPOCHS (15), so early stopping can never end the
# run before the final epoch; lower it if early stopping is meant to work.
MIN_EPOCHS_BEFORE_EARLY_STOP = 15


def _build_session_state(completed_epoch):
    """Return the resumable session state after *completed_epoch* full epochs."""
    return {
        'epoch': completed_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'best_val_loss': best_val_loss,
        'epochs_no_improve': epochs_no_improve,
    }


if start_epoch > EPOCHS:
    print(f"\nОбучение уже было завершено (цель: {EPOCHS} эпох).")
else:
    print(f"\nСтарт обучения. Целевое количество эпох: {EPOCHS}")
    # Last epoch whose training, validation and scheduler step all finished.
    last_completed_epoch = start_epoch - 1
    try:
        for epoch in range(start_epoch, EPOCHS + 1):
            # --- PHASE 1: TRAINING ---
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            pbar = tqdm(train_loader, desc=f"Эпоха {epoch}/{EPOCHS} [Тренировка]")
            for inputs, labels in pbar:
                try:
                    # Move the batch to the device, then augment it there.
                    inputs = inputs.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    inputs = gpu_train_transforms(inputs)
                    optimizer.zero_grad()
                    # Mixed precision (AMP) forward pass.
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    pbar.set_postfix({'loss': f"{loss.item():.4f}", 'acc': f"{correct/total:.4f}"})
                except RuntimeError as e:
                    # Out-of-memory: drop the batch, free the cache, carry on.
                    if "out of memory" in str(e).lower():
                        print("\nВНИМАНИЕ: Нехватка VRAM! Очистка...")
                        if 'outputs' in locals(): del outputs
                        if 'loss' in locals(): del loss
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    raise
            train_loss = running_loss / total if total > 0 else 0
            train_acc = correct / total if total > 0 else 0

            # Free memory before validation.
            gc.collect()
            torch.cuda.empty_cache()

            # --- PHASE 2: VALIDATION ---
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for val_inputs, val_labels in tqdm(val_loader, desc=f"Эпоха {epoch}/{EPOCHS} [Валидация]", leave=False):
                    val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)
                    val_inputs = gpu_val_transforms(val_inputs)
                    with autocast(device_type="cuda"):
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)
                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = val_outputs.max(1)
                    val_total += val_labels.size(0)
                    val_correct += val_predicted.eq(val_labels).sum().item()
            epoch_val_loss = val_loss / val_total if val_total > 0 else 0
            epoch_val_acc = val_correct / val_total if val_total > 0 else 0
            scheduler.step()
            # From here on the scheduler state belongs to the next epoch.
            last_completed_epoch = epoch
            print(f"🏁 Эпоха {epoch} | Train Loss: {train_loss:.4f} (Acc: {train_acc:.4f}) | Val Loss: {epoch_val_loss:.4f} (Acc: {epoch_val_acc:.4f})")

            # --- PHASE 3: EARLY STOPPING AND SAVING ---
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), WEIGHTS_DIR / "emoset_resnet50_best_2_41M.pth")
                print("Новая лучшая модель найдена! Веса сохранены.")
            else:
                epochs_no_improve += 1
                print(f"Валидация не улучшается {epochs_no_improve}/{PATIENCE} эпох.")
                if epochs_no_improve >= PATIENCE and epoch >= MIN_EPOCHS_BEFORE_EARLY_STOP:
                    print("\nСРАБОТАЛА ЗАЩИТА ОТ ПЕРЕОБУЧЕНИЯ (Early Stopping)!")
                    break

            # Persist the full session state so the run can be resumed.
            torch.save(_build_session_state(epoch), RESUME_CHECKPOINT)
            # Keep this epoch's weights as a backup.
            torch.save(model.state_dict(), WEIGHTS_DIR / f"emoset_resnet50_finetuned_ep{epoch}.pth")
            gc.collect()
            torch.cuda.empty_cache()
    # ==========================================
    # 7. MANUAL STOP (CTRL+C)
    # ==========================================
    except KeyboardInterrupt:
        print("\n\nОБУЧЕНИЕ ПРЕРВАНО ВРУЧНУЮ (KeyboardInterrupt)!")
        print(f"Экстренное сохранение состояния конвейера на эпохе {epoch}...")
        # Record only fully completed epochs: the resume logic restarts at
        # checkpoint['epoch'] + 1, so storing the interrupted epoch here would
        # silently skip the rest of it (and one scheduler step).
        torch.save(_build_session_state(last_completed_epoch), RESUME_CHECKPOINT)
        interrupted_weights_path = WEIGHTS_DIR / f"emoset_resnet50_interrupted_ep{epoch}.pth"
        torch.save(model.state_dict(), interrupted_weights_path)
        print(f"Прогресс безопасно зафиксирован в файле {interrupted_weights_path}. Выходим.")
    # ==========================================
    # 8. FINAL SAVE
    # ==========================================
    else:
        if SAVE_MODEL_PATH.parent.exists():
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            print(f"\nОБУЧЕНИЕ УСПЕШНО ЗАВЕРШЕНО! Финальная модель: {SAVE_MODEL_PATH}")
        if RESUME_CHECKPOINT.exists():
            RESUME_CHECKPOINT.unlink()  # the resume file is no longer needed
-467
View File
@@ -1,467 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "71ef58af",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Используем устройство: cuda\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as T\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm.notebook import tqdm\n",
"import timm\n",
"\n",
"# Проверяем GPU\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Используем устройство: {DEVICE}\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f4ae931c",
"metadata": {},
"outputs": [],
"source": [
"# === НАСТРОЙКИ ДООБУЧЕНИЯ ===\n",
"\n",
"# Абсолютный путь к смонтированному NFS\n",
"DATA_ROOT = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M\")\n",
"\n",
"# Пути относительно src/scripts/\n",
"PREVIOUS_WEIGHTS = Path(\"../emoset_resnet50_best.pth\")\n",
"SAVE_MODEL_PATH = Path(\"../emoset_resnet50_finetuned_2.41M.pth\")\n",
"\n",
"# Маппинг эмоций в те же индексы (0-7), которые использовались при первоначальном обучении\n",
"EMO_MAP = {\n",
" \"amusement\": 0,\n",
" \"anger\": 1,\n",
" \"awe\": 2,\n",
" \"contentment\": 3,\n",
" \"disgust\": 4,\n",
" \"excitement\": 5,\n",
" \"fear\": 6,\n",
" \"sad\": 7, # В твоем сообщении папка называется \"sad\"\n",
" \"sadness\": 7 # На всякий случай оставляем и классическое название\n",
"}\n",
"\n",
"BATCH_SIZE = 96\n",
"EPOCHS = 15\n",
"LR = 5e-5\n",
"NUM_WORKERS = 42"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "934cfe2c",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import os\n",
"import torchvision.io as tv_io\n",
"\n",
"# Твои трансформации (только Resize, как мы сделали ранее)\n",
"train_transforms = T.Compose([\n",
" T.Resize((256, 256), antialias=True)\n",
"])\n",
"\n",
"class EmoSetNestedDataset(Dataset):\n",
" def __init__(self, root_dir, transform=None, cache_file=\"../dataset_cache_2.41M.pkl\"):\n",
" self.root_dir = Path(root_dir)\n",
" self.transform = transform\n",
" self.image_paths = []\n",
" self.labels = []\n",
" \n",
" # === ЛОГИКА КЭШИРОВАНИЯ ===\n",
" if os.path.exists(cache_file):\n",
" print(f\"📦 Загрузка списка файлов из локального кэша: {cache_file}...\")\n",
" with open(cache_file, 'rb') as f:\n",
" cache_data = pickle.load(f)\n",
" self.image_paths = cache_data['image_paths']\n",
" self.labels = cache_data['labels']\n",
" print(f\"⚡ Готово! Моментально загружено {len(self.image_paths)} путей.\")\n",
" else:\n",
" print(f\"🔍 Сканирование NFS директории {self.root_dir}...\")\n",
" print(\"Это займет около 8-10 минут. Выполняется один раз...\")\n",
" \n",
" for img_path in self.root_dir.rglob('*.jpg'):\n",
" emotion_folder = img_path.parts[-3].lower()\n",
" if emotion_folder in EMO_MAP:\n",
" self.image_paths.append(str(img_path))\n",
" self.labels.append(EMO_MAP[emotion_folder])\n",
" \n",
" print(f\"💾 Сохранение результатов сканирования в кэш: {cache_file}...\")\n",
" with open(cache_file, 'wb') as f:\n",
" pickle.dump({'image_paths': self.image_paths, 'labels': self.labels}, f)\n",
" \n",
" print(f\"✅ Успешно найдено и закэшировано {len(self.image_paths)} изображений.\")\n",
"\n",
" def __len__(self):\n",
" return len(self.image_paths)\n",
"\n",
" def __getitem__(self, idx):\n",
" img_path = self.image_paths[idx]\n",
" label = self.labels[idx]\n",
" \n",
" try:\n",
" image = tv_io.read_image(str(img_path), mode=tv_io.ImageReadMode.RGB)\n",
" image = image.to(torch.float32) / 255.0\n",
" except Exception as e:\n",
" image = torch.zeros((3, 256, 256), dtype=torch.float32)\n",
" \n",
" if self.transform:\n",
" image = self.transform(image)\n",
" \n",
" return image, label"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b10adc06",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📦 Загрузка списка файлов из локального кэша: ../dataset_paths_cache.pkl...\n",
"⚡ Готово! Моментально загружено 2048377 путей.\n",
"Батчей за одну эпоху: 21338\n",
"\n",
"Создание архитектуры ResNet50...\n",
"📝 Чекпоинт прерванной сессии не найден. Проверяем базовые веса...\n",
"УСПЕХ: Найдены предыдущие веса '../emoset_resnet50_best.pth' (из EmoSet-118K). Загружаем...\n"
]
}
],
"source": [
"# Путь к кэш-файлу (лучше положить его в src, рядом со скриптами, чтобы быстро читался)\n",
"CACHE_PATH = Path(\"../dataset_paths_cache.pkl\")\n",
"\n",
"# 1. Загрузка данных напрямую из папок (или из кэша!)\n",
"train_dataset = EmoSetNestedDataset(DATA_ROOT, transform=train_transforms, cache_file=CACHE_PATH)\n",
"\n",
"train_loader = DataLoader(\n",
" train_dataset, \n",
" batch_size=BATCH_SIZE, \n",
" shuffle=True, \n",
" num_workers=NUM_WORKERS,\n",
" pin_memory=True,\n",
" prefetch_factor=1,\n",
" persistent_workers=True\n",
")\n",
"\n",
"print(f\"Батчей за одну эпоху: {len(train_loader)}\")\n",
"\n",
"# Путь к файлу автоматического восстановления (чекпоинт полной сессии)\n",
"RESUME_CHECKPOINT = Path(\"../emoset_resnet50_checkpoint_latest.pth\")\n",
"\n",
"# 2. Инициализация архитектуры модели ResNet-50\n",
"print(\"\\nСоздание архитектуры ResNet50...\")\n",
"model = timm.create_model('resnet50', pretrained=False, num_classes=8)\n",
"model = model.to(DEVICE)\n",
"\n",
"# Инициализируем компоненты оптимизации\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)\n",
"\n",
"# По умолчанию начинаем с 1-й эпохи\n",
"start_epoch = 1\n",
"\n",
"# === ЛОГИКА ВОССТАНОВЛЕНИЯ / СТАРТА ===\n",
"if RESUME_CHECKPOINT.exists():\n",
" print(f\"🔄 ОБНАРУЖЕН ДЕЙСТВУЮЩИЙ ЧЕКПОИНТ: '{RESUME_CHECKPOINT}'\")\n",
" print(\"Восстанавливаем полное состояние сессии...\")\n",
" \n",
" # Загружаем сохраненный словарь состояния\n",
" checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)\n",
" \n",
" # Восстанавливаем всё до единого\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
" scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
" start_epoch = checkpoint['epoch'] + 1 # Начинаем со следующей эпохи\n",
" \n",
" print(f\"🚀 УСПЕХ: Сессия восстановлена! Продолжаем обучение с эпохи {start_epoch}\")\n",
"else:\n",
" print(\"📝 Чекпоинт прерванной сессии не найден. Проверяем базовые веса...\")\n",
" if PREVIOUS_WEIGHTS.exists():\n",
" print(f\"УСПЕХ: Найдены предыдущие веса '{PREVIOUS_WEIGHTS}' (из EmoSet-118K). Загружаем...\")\n",
" model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))\n",
" else:\n",
" print(\"ВНИМАНИЕ: Базовые веса не найдены. Начинаем обучение с нуля (ImageNet pretrained).\")\n",
" model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a7480834",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⏰ Старт обучения. Целевое количество эпох: 15\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Эпоха 1/15 [Тренировка]: 0%| | 0/21338 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Эпоха 1/15 [Тренировка]: 13%|█▎ | 2705/21338 [32:27<2:39:51, 1.94it/s, loss=1.6873, acc=0.3036] Corrupt JPEG data: 80 extraneous bytes before marker 0xd9\n",
"Эпоха 1/15 [Тренировка]: 17%|█▋ | 3623/21338 [43:15<4:46:57, 1.03it/s, loss=1.7141, acc=0.3102] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 18%|█▊ | 3741/21338 [44:35<2:23:26, 2.04it/s, loss=1.7848, acc=0.3109] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 19%|█▉ | 4109/21338 [49:12<2:17:42, 2.09it/s, loss=1.8072, acc=0.3133] Corrupt JPEG data: 485 extraneous bytes before marker 0xd9\n",
"Эпоха 1/15 [Тренировка]: 22%|██▏ | 4729/21338 [56:22<3:06:07, 1.49it/s, loss=1.6499, acc=0.3173] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 38%|███▊ | 8033/21338 [1:35:20<2:21:51, 1.56it/s, loss=1.5897, acc=0.3338] Corrupt JPEG data: 41 extraneous bytes before marker 0xd9\n",
"Эпоха 1/15 [Тренировка]: 45%|████▌ | 9684/21338 [1:54:34<3:40:06, 1.13s/it, loss=1.6740, acc=0.3399] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 50%|█████ | 10679/21338 [2:06:15<47:11, 3.76it/s, loss=1.6234, acc=0.3431] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 55%|█████▌ | 11802/21338 [2:19:39<1:50:22, 1.44it/s, loss=1.5677, acc=0.3463] Unknown Adobe color transform code 2\n",
"Эпоха 1/15 [Тренировка]: 67%|██████▋ | 14253/21338 [2:48:12<27:27, 4.30it/s, loss=1.7579, acc=0.3525] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 77%|███████▋ | 16377/21338 [3:13:17<1:06:23, 1.25it/s, loss=1.6855, acc=0.3572] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 92%|█████████▏| 19575/21338 [3:51:14<11:13, 2.62it/s, loss=1.5876, acc=0.3631] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 92%|█████████▏| 19679/21338 [3:52:26<07:42, 3.59it/s, loss=1.7134, acc=0.3633] Invalid SOS parameters for sequential JPEG\n",
"Эпоха 1/15 [Тренировка]: 100%|█████████▉| 21283/21338 [4:11:18<00:20, 2.70it/s, loss=1.4613, acc=0.3658] Unknown Adobe color transform code 2\n",
"Эпоха 1/15 [Тренировка]: 100%|██████████| 21338/21338 [4:11:47<00:00, 1.41it/s, loss=1.6304, acc=0.3659]\n"
]
},
{
"ename": "NameError",
"evalue": "name 'val_loader' is not defined",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 88\u001b[39m\n\u001b[32m 85\u001b[39m \u001b[38;5;66;03m# ВАЖНО: Если у тебя нет val_loader, создай его (откуси 5-10% от датасета)\u001b[39;00m\n\u001b[32m 86\u001b[39m \u001b[38;5;66;03m# На валидации мы НЕ применяем gpu_transforms (только нормализацию)\u001b[39;00m\n\u001b[32m 87\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m---> \u001b[39m\u001b[32m88\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m val_inputs, val_labels \u001b[38;5;129;01min\u001b[39;00m \u001b[43mval_loader\u001b[49m:\n\u001b[32m 89\u001b[39m val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)\n\u001b[32m 91\u001b[39m \u001b[38;5;66;03m# Валидация тоже идет в смешанной точности для скорости\u001b[39;00m\n",
"\u001b[31mNameError\u001b[39m: name 'val_loader' is not defined"
]
}
],
"source": [
"import gc\n",
"import torch\n",
"from torch.amp import autocast, GradScaler\n",
"from tqdm import tqdm\n",
"import torchvision.transforms as T\n",
"\n",
"# --- НАСТРОЙКИ EARLY STOPPING ---\n",
"PATIENCE = 4 # Сколько эпох ждем, если валидация не улучшается\n",
"best_val_loss = float('inf')\n",
"epochs_no_improve = 0\n",
"start_epoch = 1\n",
"\n",
"# --- ПЕРЕНОСИМ ТЯЖЕЛЫЕ АУГМЕНТАЦИИ НА GPU ---\n",
"gpu_transforms = torch.nn.Sequential(\n",
" T.RandomCrop((224, 224)),\n",
" T.RandomHorizontalFlip(p=0.5),\n",
" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
").to(DEVICE)\n",
"# -------------------------------------------\n",
"\n",
"print(f\"⏰ Старт обучения. Целевое количество эпох: {EPOCHS}\")\n",
"scaler = GradScaler()\n",
"\n",
"try:\n",
" for epoch in range(start_epoch, EPOCHS + 1):\n",
" # ==========================================\n",
" # 1. ТРЕНИРОВКА (TRAIN)\n",
" # ==========================================\n",
" model.train()\n",
" running_loss = 0.0\n",
" correct = 0\n",
" total = 0\n",
" \n",
" pbar = tqdm(train_loader, desc=f\"Эпоха {epoch}/{EPOCHS} [Тренировка]\")\n",
" for inputs, labels in pbar:\n",
" try:\n",
" inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)\n",
" \n",
" # Применяем аугментации на лету к батчу (на GPU)\n",
" inputs = gpu_transforms(inputs)\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" with autocast(device_type=\"cuda\"): # Для PyTorch >= 2.0 лучше указывать \"cuda\"\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" scaler.scale(loss).backward()\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" \n",
" # Статистика\n",
" running_loss += loss.item() * inputs.size(0)\n",
" _, predicted = outputs.max(1)\n",
" total += labels.size(0)\n",
" correct += predicted.eq(labels).sum().item()\n",
" \n",
" pbar.set_postfix({'loss': f\"{loss.item():.4f}\", 'acc': f\"{correct/total:.4f}\"})\n",
" \n",
" except RuntimeError as e:\n",
" # Обработчик OOM\n",
" if \"out of memory\" in str(e).lower():\n",
" print(f\"\\n⚠️ ВНИМАНИЕ: Нехватка VRAM на батче! Выполняем экстренную очистку...\")\n",
" if 'outputs' in locals(): del outputs\n",
" if 'loss' in locals(): del loss\n",
" torch.cuda.empty_cache()\n",
" optimizer.zero_grad()\n",
" print(\"♻️ Кэш очищен. Батч пропущен. Продолжаем обучение...\")\n",
" continue\n",
" else:\n",
" raise e\n",
" \n",
" epoch_loss = running_loss / total if total > 0 else 0\n",
" epoch_acc = correct / total if total > 0 else 0\n",
" \n",
" # ==========================================\n",
" # 2. ВАЛИДАЦИЯ И EARLY STOPPING\n",
" # ==========================================\n",
" model.eval()\n",
" val_loss = 0.0\n",
" val_correct = 0\n",
" val_total = 0\n",
" \n",
" # ВАЖНО: Если у тебя нет val_loader, создай его (откуси 5-10% от датасета)\n",
" # На валидации мы НЕ применяем gpu_transforms (только нормализацию)\n",
" with torch.no_grad():\n",
" for val_inputs, val_labels in val_loader:\n",
" val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)\n",
" \n",
" # Валидация тоже идет в смешанной точности для скорости\n",
" with autocast(device_type=\"cuda\"):\n",
" val_outputs = model(val_inputs)\n",
" v_loss = criterion(val_outputs, val_labels)\n",
" \n",
" val_loss += v_loss.item() * val_inputs.size(0)\n",
" _, val_predicted = val_outputs.max(1)\n",
" val_total += val_labels.size(0)\n",
" val_correct += val_predicted.eq(val_labels).sum().item()\n",
" \n",
" epoch_val_loss = val_loss / val_total if val_total > 0 else 0\n",
" epoch_val_acc = val_correct / val_total if val_total > 0 else 0\n",
" \n",
" scheduler.step()\n",
" print(f\"🏁 Эпоха {epoch} завершена | Train Loss: {epoch_loss:.4f} (Acc: {epoch_acc:.4f}) | Val Loss: {epoch_val_loss:.4f} (Acc: {epoch_val_acc:.4f})\")\n",
" \n",
" # --- ЛОГИКА РАННЕЙ ОСТАНОВКИ ---\n",
" if epoch_val_loss < best_val_loss:\n",
" best_val_loss = epoch_val_loss\n",
" epochs_no_improve = 0\n",
" # Сохраняем идеальные веса, пока сеть не переобучилась\n",
" torch.save(model.state_dict(), \"../emoset_resnet50_best.pth\")\n",
" print(\"🌟 Найдена лучшая модель! Веса сохранены.\")\n",
" else:\n",
" epochs_no_improve += 1\n",
" print(f\"⚠️ Валидационный Loss не улучшается {epochs_no_improve}/{PATIENCE} эпох.\")\n",
" \n",
" # Условие: если переобучение длится долго И мы прошли хотя бы 15 эпох\n",
" if epochs_no_improve >= PATIENCE and epoch >= 15:\n",
" print(\"\\n🛑 СРАБОТАЛА ЗАЩИТА ОТ ПЕРЕОБУЧЕНИЯ (Early Stopping)!\")\n",
" print(\"Модель начала запоминать данные вместо обобщения. Обучение досрочно завершено.\")\n",
" break # Прерываем цикл эпох\n",
" \n",
" # ==========================================\n",
" # 3. РЕГУЛЯРНОЕ СОХРАНЕНИЕ ПРОГРЕССА\n",
" # ==========================================\n",
" checkpoint_state = {\n",
" 'epoch': epoch,\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'scheduler_state_dict': scheduler.state_dict(),\n",
" 'scaler_state_dict': scaler.state_dict(),\n",
" 'loss': epoch_loss,\n",
" 'val_loss': epoch_val_loss\n",
" }\n",
" torch.save(checkpoint_state, RESUME_CHECKPOINT)\n",
" \n",
" # Сохранение весов конкретной эпохи (на всякий случай)\n",
" epoch_weights_path = f\"../emoset_resnet50_finetuned_ep{epoch}.pth\"\n",
" torch.save(model.state_dict(), epoch_weights_path)\n",
" \n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
"\n",
"# ==========================================\n",
"# 4. БЕЗОПАСНЫЙ ВЫХОД ПРИ РУЧНОМ ПРЕРЫВАНИИ\n",
"# ==========================================\n",
"except KeyboardInterrupt:\n",
" print(\"\\n\\n🛑 ОБУЧЕНИЕ ПРЕРВАНО ВРУЧНУЮ (KeyboardInterrupt)!\")\n",
" print(f\"💾 Экстренное сохранение состояния конвейера на эпохе {epoch}...\")\n",
" \n",
" # Сохраняем всё, чтобы потом можно было продолжить с этого же места\n",
" checkpoint_state = {\n",
" 'epoch': epoch,\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'scheduler_state_dict': scheduler.state_dict(),\n",
" 'scaler_state_dict': scaler.state_dict()\n",
" }\n",
" torch.save(checkpoint_state, RESUME_CHECKPOINT)\n",
" \n",
" # Сохраняем промежуточные веса на момент остановки\n",
" interrupted_weights_path = f\"../emoset_resnet50_interrupted_ep{epoch}.pth\"\n",
" torch.save(model.state_dict(), interrupted_weights_path)\n",
" \n",
" print(f\"✅ Прогресс безопасно зафиксирован в файле {interrupted_weights_path}. Выходим.\")\n",
"\n",
"# Финальное сохранение (если цикл дошел до конца сам)\n",
"else:\n",
" if SAVE_MODEL_PATH.parent.exists():\n",
" torch.save(model.state_dict(), SAVE_MODEL_PATH)\n",
" print(f\"\\n🎉 ОБУЧЕНИЕ УСПЕШНО ЗАВЕРШЕНО! Финальная модель: {SAVE_MODEL_PATH}\")\n",
" if RESUME_CHECKPOINT.exists():\n",
" RESUME_CHECKPOINT.unlink()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-88
View File
@@ -1,88 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "b92e0213",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1763c51e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
"Всего треков в базе: 1744\n",
"Пример данных:\n",
" song_id valence arousal\n",
"0 2 3.1 3.0\n",
"1 3 3.5 3.3\n",
"2 4 5.7 5.5\n",
"3 5 4.4 5.3\n",
"4 7 5.8 6.4\n"
]
}
],
"source": [
"# Точный путь к оригинальным аннотациям\n",
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
"# Путь, куда сохраним очищенную базу для движка\n",
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
"\n",
"if not source_path.exists():\n",
" print(f\"Исходный файл не найден по пути: {source_path}\")\n",
"else:\n",
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
" \n",
" # Берем только нужные колонки (по твоему примеру)\n",
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
" \n",
" # Переименовываем для простоты кода в движке\n",
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
" \n",
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
" \n",
" # Сохраняем финальный файл\n",
" clean_df.to_csv(output_path, index=False)\n",
" \n",
" print(f\"Успех! База создана: {output_path}\")\n",
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
" print(\"Пример данных:\")\n",
" print(clean_df.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (thesis)",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-114
View File
@@ -1,114 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d70d8e32",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "31b0fa82",
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
"TRANSFORM = transforms.Compose([\n",
" transforms.Resize((224,224)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a17ecf5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
]
},
{
"ename": "PicklingError",
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
]
}
],
"source": [
"def process_row(row, split_dir, tensor_dir):\n",
" img_path = split_dir / row.filename\n",
" img = Image.open(img_path).convert(\"RGB\")\n",
" tensor = TRANSFORM(img)\n",
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
" torch.save(tensor, tensor_path)\n",
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
"\n",
"for split in [\"train\",\"val\",\"test\"]:\n",
" split_dir = DATA_ROOT / split / \"images\"\n",
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
"\n",
" results = []\n",
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
" for f in tqdm(futures):\n",
" results.append(f.result())\n",
"\n",
" new_df = pd.DataFrame(results)\n",
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "thesis-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-134
View File
@@ -1,134 +0,0 @@
import os
import random
from pathlib import Path
from tqdm.notebook import tqdm
import webdataset as wds
# --- SETTINGS ---
# Original dataset directory (on NFS).
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M")
# Directory that receives the finished .tar archives (shards).
# NOTE(review): SHARDS_DIR is assigned again further down in this script before
# any shard is written, so this NFS directory is created here but apparently
# never filled — confirm which location is the intended one.
SHARDS_DIR = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/shards-2.41M")
SHARDS_DIR.mkdir(parents=True, exist_ok=True)
# Emotion folder name -> class index ("sad" and "sadness" map to the same class).
EMO_MAP = {
    "amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
    "disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7,
}
# Number of images packed into one archive.
MAX_SAMPLES_PER_SHARD = 10000

# Collected (image path, class index) pairs.
samples = []
print(f"🔍 Сканирование директории {DATA_ROOT}...")
# os.walk was chosen by the author as often faster than rglob on network drives.
for root, _dirs, files in os.walk(DATA_ROOT):
    # Expected layout: .../<emotion>/<subfolder>/<image>.jpg
    # The emotion folder is the parent of the directory that holds the files,
    # i.e. parts[-3] of every file path == parts[-2] of `root`. It is identical
    # for all files in `root`, so it is resolved once per directory instead of
    # building and splitting a Path for every single file.
    root_parts = Path(root).parts
    if len(root_parts) < 2:
        continue
    label = EMO_MAP.get(root_parts[-2].lower())
    if label is None:
        # Directory does not sit under a known emotion folder: nothing to collect.
        continue
    samples.extend(
        (os.path.join(root, file), label)
        for file in files
        if file.lower().endswith('.jpg')
    )
print(f"✅ Найдено изображений: {len(samples)}")
# The walk appends samples directory by directory, so they are grouped by
# emotion folder; shuffle globally before the list is cut into shards.
print("🔀 Перемешиваем датасет...")
random.shuffle(samples)
print("✅ Перемешивание завершено!")
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import webdataset as wds
from PIL import Image
import io
from pathlib import Path
# ВАЖНО: Импортируем базовый tqdm, а не notebook-версию.
# Notebook-версия в мультипроцессинге вызывает зависание графического интерфейса Jupyter.
from tqdm import tqdm
# --- PATHS AND SETTINGS ---
# Size of the process pool; also the number of progress-bar slots used by workers.
NUM_WORKERS = 50
# Output directory for the packed shards (relative to the working directory).
SHARDS_DIR = Path("../../dataset/EmoSet-2.41M-shards")
SHARDS_DIR.mkdir(parents=True, exist_ok=True)

# 1. Cut the sample list into consecutive slices of at most
#    MAX_SAMPLES_PER_SHARD items; each slice becomes one shard task.
shard_starts = range(0, len(samples), MAX_SAMPLES_PER_SHARD)
chunks = [samples[start:start + MAX_SAMPLES_PER_SHARD] for start in shard_starts]

print(
    f"📦 Подготовлено {len(chunks)} задач (шардов).",
    f"💾 Целевая папка (Локальный SSD): {SHARDS_DIR}",
    f"🚀 Запуск упаковки и сжатия в {NUM_WORKERS} потоков...\n",
    sep="\n",
)

# Install a multiprocessing lock into tqdm so that progress bars drawn from
# several processes do not overwrite each other.
progress_lock = mp.RLock()
tqdm.set_lock(progress_lock)
# 2. Worker function: executed in a pool process once per shard.
def build_shard(args):
    """Pack one chunk of samples into a single WebDataset tar shard.

    Each image is opened, converted to RGB, resized to 256x256 (bilinear) and
    re-encoded as JPEG with quality 85, then written together with its class
    label. Images that fail at any of these steps are skipped; if any were
    skipped, a one-line summary is emitted through ``tqdm.write`` so the loss
    is visible instead of silent.

    Args:
        args: ``(shard_idx, chunk)`` tuple, where ``chunk`` is a sequence of
            ``(img_path, label)`` pairs.

    Returns:
        The ``shard_idx`` that was processed.
    """
    shard_idx, chunk = args
    shard_path = SHARDS_DIR / f"emoset-{shard_idx:06d}.tar"
    error_count = 0
    first_error = None
    # Progress-bar slot: shard index modulo the worker count, shifted by one
    # because slot 0 is used by the overall bar in the main process. Slots are
    # therefore reused across shards instead of adding a new screen line per
    # shard; leave=False removes the bar once the shard is finished.
    worker_pos = (shard_idx % NUM_WORKERS) + 1
    with wds.TarWriter(str(shard_path)) as sink:
        progress = tqdm(chunk, desc=f"Шард {shard_idx:03d}", position=worker_pos, leave=False)
        for i, (img_path, label) in enumerate(progress):
            try:
                # --- Re-encode the image to a fixed size / quality ---
                with Image.open(img_path) as img:
                    img = img.convert("RGB")
                    img = img.resize((256, 256), Image.Resampling.BILINEAR)
                    with io.BytesIO() as img_byte_arr:
                        img.save(img_byte_arr, format='JPEG', quality=85)
                        image_data = img_byte_arr.getvalue()
                # The key uses the position inside the chunk, so skipped
                # images leave gaps in the numbering.
                key = f"{shard_idx:06d}_{i:05d}"
                sink.write({
                    "__key__": key,
                    "jpg": image_data,
                    "cls": label,
                })
            except Exception as exc:
                # Deliberately best-effort: one unreadable file must not abort
                # the whole shard. Count it and remember the first cause.
                error_count += 1
                if first_error is None:
                    first_error = exc
    if error_count:
        # tqdm.write prints without breaking the progress bars.
        tqdm.write(
            f"⚠️ Шард {shard_idx:03d}: пропущено файлов: {error_count} из {len(chunk)} "
            f"(первая ошибка: {first_error!r})"
        )
    return shard_idx
# 3. Fan the shard tasks out over the process pool.
# (shard index, exception) for every shard whose worker raised.
failed_shards = []
with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
    tasks = list(enumerate(chunks))
    # Map every future back to its shard index so a failure can be named.
    futures = {executor.submit(build_shard, task): task[0] for task in tasks}
    # Overall progress bar: position=0 keeps it on the first line, above the
    # per-shard bars drawn by the workers.
    for future in tqdm(as_completed(futures), total=len(tasks), desc="📊 ОБЩИЙ ПРОГРЕСС", position=0, leave=True):
        try:
            future.result()
        except Exception as exc:
            # Keep going so the remaining shards are still built, but record
            # the failure instead of discarding it.
            failed_shards.append((futures[future], exc))
# Emit blank lines so the final text does not overlap the progress bars.
print("\n" * (NUM_WORKERS + 2))
if failed_shards:
    print(f"⚠️ УПАКОВКА ЗАВЕРШЕНА С ОШИБКАМИ: не удалось собрать шардов: {len(failed_shards)} из {len(tasks)}")
    for shard_idx, exc in sorted(failed_shards, key=lambda item: item[0]):
        print(f"   шард {shard_idx:06d}: {exc!r}")
else:
    print("🎉 ПАРАЛЛЕЛЬНАЯ УПАКОВКА И СЖАТИЕ ПОЛНОСТЬЮ ЗАВЕРШЕНЫ!")
-919
View File
@@ -1,919 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "e6aa65e8",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"from pathlib import Path\n",
"from tqdm.notebook import tqdm\n",
"import webdataset as wds\n",
"\n",
"# --- НАСТРОЙКИ ---\n",
"# Оригинальная папка с датасетом (на NFS)\n",
"DATA_ROOT = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M\")\n",
"\n",
"# Новая папка, куда мы сложим готовые .tar архивы (шарды)\n",
"# Лучше создать её рядом с оригинальным датасетом на NFS\n",
"SHARDS_DIR = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/shards-2.41M\")\n",
"SHARDS_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Маппинг классов\n",
"EMO_MAP = {\n",
" \"amusement\": 0, \"anger\": 1, \"awe\": 2, \"contentment\": 3,\n",
" \"disgust\": 4, \"excitement\": 5, \"fear\": 6, \"sad\": 7, \"sadness\": 7\n",
"}\n",
"\n",
"# Размер одного архива. 10 000 картинок — идеальный баланс\n",
"MAX_SAMPLES_PER_SHARD = 10000"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "09d0e56c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🔍 Сканирование директории /home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M...\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 8\u001b[39m full_path = os.path.join(root, file)\n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# Извлекаем эмоцию (зависит от структуры папок, берем предпоследнюю папку)\u001b[39;00m\n\u001b[32m 10\u001b[39m \u001b[38;5;66;03m# Путь: .../amusement/0/image.jpg -> root_parts[-2] будет 'amusement'\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m path_parts = \u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfull_path\u001b[49m\u001b[43m)\u001b[49m.parts\n\u001b[32m 12\u001b[39m emotion_folder = path_parts[-\u001b[32m3\u001b[39m].lower()\n\u001b[32m 14\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m emotion_folder \u001b[38;5;129;01min\u001b[39;00m EMO_MAP:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:871\u001b[39m, in \u001b[36mPath.__new__\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 869\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m Path:\n\u001b[32m 870\u001b[39m \u001b[38;5;28mcls\u001b[39m = WindowsPath \u001b[38;5;28;01mif\u001b[39;00m os.name == \u001b[33m'\u001b[39m\u001b[33mnt\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m PosixPath\n\u001b[32m--> \u001b[39m\u001b[32m871\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_from_parts\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 872\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._flavour.is_supported:\n\u001b[32m 873\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mcannot instantiate \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m on your system\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 874\u001b[39m % (\u001b[38;5;28mcls\u001b[39m.\u001b[34m__name__\u001b[39m,))\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:509\u001b[39m, in \u001b[36mPurePath._from_parts\u001b[39m\u001b[34m(cls, args)\u001b[39m\n\u001b[32m 504\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 505\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_from_parts\u001b[39m(\u001b[38;5;28mcls\u001b[39m, args):\n\u001b[32m 506\u001b[39m \u001b[38;5;66;03m# We need to call _parse_args on the instance, so as to get the\u001b[39;00m\n\u001b[32m 507\u001b[39m \u001b[38;5;66;03m# right flavour.\u001b[39;00m\n\u001b[32m 508\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28mobject\u001b[39m.\u001b[34m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m509\u001b[39m drv, root, parts = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_parse_args\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 510\u001b[39m \u001b[38;5;28mself\u001b[39m._drv = drv\n\u001b[32m 511\u001b[39m \u001b[38;5;28mself\u001b[39m._root = root\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:502\u001b[39m, in \u001b[36mPurePath._parse_args\u001b[39m\u001b[34m(cls, args)\u001b[39m\n\u001b[32m 497\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 498\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 499\u001b[39m \u001b[33m\"\u001b[39m\u001b[33margument should be a str object or an os.PathLike \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 500\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mobject returning str, not \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 501\u001b[39m % \u001b[38;5;28mtype\u001b[39m(a))\n\u001b[32m--> \u001b[39m\u001b[32m502\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_flavour\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparse_parts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparts\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:67\u001b[39m, in \u001b[36m_Flavour.parse_parts\u001b[39m\u001b[34m(self, parts)\u001b[39m\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m altsep:\n\u001b[32m 66\u001b[39m part = part.replace(altsep, sep)\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m drv, root, rel = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msplitroot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpart\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m sep \u001b[38;5;129;01min\u001b[39;00m rel:\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mreversed\u001b[39m(rel.split(sep)):\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:241\u001b[39m, in \u001b[36m_PosixFlavour.splitroot\u001b[39m\u001b[34m(self, part, sep)\u001b[39m\n\u001b[32m 239\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msplitroot\u001b[39m(\u001b[38;5;28mself\u001b[39m, part, sep=sep):\n\u001b[32m 240\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m part \u001b[38;5;129;01mand\u001b[39;00m part[\u001b[32m0\u001b[39m] == sep:\n\u001b[32m--> \u001b[39m\u001b[32m241\u001b[39m stripped_part = part.lstrip(sep)\n\u001b[32m 242\u001b[39m \u001b[38;5;66;03m# According to POSIX path resolution:\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# http://pubs.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap04.html#tag_04_11\u001b[39;00m\n\u001b[32m 244\u001b[39m \u001b[38;5;66;03m# \"A pathname that begins with two successive slashes may be\u001b[39;00m\n\u001b[32m 245\u001b[39m \u001b[38;5;66;03m# interpreted in an implementation-defined manner, although more\u001b[39;00m\n\u001b[32m 246\u001b[39m \u001b[38;5;66;03m# than two leading slashes shall be treated as a single slash\".\u001b[39;00m\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(part) - \u001b[38;5;28mlen\u001b[39m(stripped_part) == \u001b[32m2\u001b[39m:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"samples = []\n",
"\n",
"print(f\"🔍 Сканирование директории {DATA_ROOT}...\")\n",
"# Используем os.walk, он часто работает быстрее rglob на сетевых дисках\n",
"for root, dirs, files in os.walk(DATA_ROOT):\n",
" for file in files:\n",
" if file.lower().endswith('.jpg'):\n",
" full_path = os.path.join(root, file)\n",
" # Извлекаем эмоцию (зависит от структуры папок, берем предпоследнюю папку)\n",
" # Путь: .../amusement/0/image.jpg -> root_parts[-2] будет 'amusement'\n",
" path_parts = Path(full_path).parts\n",
" emotion_folder = path_parts[-3].lower()\n",
" \n",
" if emotion_folder in EMO_MAP:\n",
" samples.append((full_path, EMO_MAP[emotion_folder]))\n",
"\n",
"print(f\"✅ Найдено изображений: {len(samples)}\")\n",
"\n",
"# САМЫЙ ВАЖНЫЙ ШАГ: Глобальное перемешивание перед упаковкой\n",
"print(\"🔀 Перемешиваем датасет...\")\n",
"random.shuffle(samples)\n",
"print(\"✅ Перемешивание завершено!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fe71d72",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📦 Подготовлено 205 задач (шардов).\n",
"💾 Целевая папка: ../../dataset/EmoSet-2.41M-shards\n",
"🚀 Запуск упаковки в 42 потоков...\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Exception in thread Thread-4 (ui_thread_func):\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 1045, in _bootstrap_inner\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 982, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/tmp/ipykernel_16083/731719608.py\", line 72, in ui_thread_func\n",
" File \"/home/zin/projects/Thesis/.venv/lib/python3.11/site-packages/tqdm/notebook.py\", line 223, in __init__\n",
" super().__init__(*args, **kwargs)\n",
" File \"/home/zin/projects/Thesis/.venv/lib/python3.11/site-packages/tqdm/std.py\", line 1001, in __init__\n",
" raise (\n",
"tqdm.std.TqdmKeyError: \"Unknown argument(s): {'color': 'blue'}\"\n",
"Process ForkProcess-39:\n",
"Process ForkProcess-35:\n",
"Process ForkProcess-28:\n",
"Process ForkProcess-29:\n",
"Process ForkProcess-43:\n",
"Process ForkProcess-38:\n",
"Process ForkProcess-36:\n",
"Process ForkProcess-37:\n",
"Process ForkProcess-34:\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-27:\n",
"Process ForkProcess-25:\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-30:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Process ForkProcess-41:\n",
"Process ForkProcess-31:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Process ForkProcess-22:\n",
"Process ForkProcess-32:\n",
"Process ForkProcess-40:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-42:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Process ForkProcess-20:\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-24:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Process ForkProcess-23:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Process ForkProcess-26:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Process ForkProcess-21:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-16:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-13:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Process ForkProcess-9:\n",
"Process ForkProcess-17:\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Process ForkProcess-6:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Process ForkProcess-8:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Process ForkProcess-3:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-10:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Process ForkProcess-7:\n",
"Process ForkProcess-11:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-14:\n",
"Process ForkProcess-4:\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-5:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 982, in run\n",
"Process ForkProcess-2:\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-12:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-15:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
"Process ForkProcess-18:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 102\u001b[39m\n\u001b[32m 101\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ProcessPoolExecutor(max_workers=NUM_WORKERS) \u001b[38;5;28;01mas\u001b[39;00m executor:\n\u001b[32m--> \u001b[39m\u001b[32m102\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuild_shard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtasks\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 103\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mpass\u001b[39;49;00m \u001b[38;5;66;03m# Просто ждем завершения всех задач\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py:620\u001b[39m, in \u001b[36m_chain_from_iterable_of_lists\u001b[39m\u001b[34m(iterable)\u001b[39m\n\u001b[32m 615\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 616\u001b[39m \u001b[33;03mSpecialized implementation of itertools.chain.from_iterable.\u001b[39;00m\n\u001b[32m 617\u001b[39m \u001b[33;03mEach item in *iterable* should be a list. This function is\u001b[39;00m\n\u001b[32m 618\u001b[39m \u001b[33;03mcareful not to keep references to yielded objects.\u001b[39;00m\n\u001b[32m 619\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m620\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43melement\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 621\u001b[39m \u001b[43m \u001b[49m\u001b[43melement\u001b[49m\u001b[43m.\u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:619\u001b[39m, in \u001b[36mExecutor.map.<locals>.result_iterator\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 618\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m619\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m \u001b[43m_result_or_cancel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpop\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 620\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:317\u001b[39m, in \u001b[36m_result_or_cancel\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 316\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m317\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfut\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 318\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:451\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.__get_result()\n\u001b[32m--> \u001b[39m\u001b[32m451\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_condition\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:327\u001b[39m, in \u001b[36mCondition.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 326\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 328\u001b[39m gotit = \u001b[38;5;28;01mTrue\u001b[39;00m\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: ",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 101\u001b[39m\n\u001b[32m 99\u001b[39m \u001b[38;5;66;03m# 3. Запускаем 42 боевых ядра\u001b[39;00m\n\u001b[32m 100\u001b[39m tasks = [(i, chunk, queue) \u001b[38;5;28;01mfor\u001b[39;00m i, chunk \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(chunks)]\n\u001b[32m--> \u001b[39m\u001b[32m101\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mProcessPoolExecutor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_workers\u001b[49m\u001b[43m=\u001b[49m\u001b[43mNUM_WORKERS\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mas\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 102\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuild_shard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtasks\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 103\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mpass\u001b[39;49;00m \u001b[38;5;66;03m# Просто ждем завершения всех задач\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:647\u001b[39m, in \u001b[36mExecutor.__exit__\u001b[39m\u001b[34m(self, exc_type, exc_val, exc_tb)\u001b[39m\n\u001b[32m 646\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, exc_type, exc_val, exc_tb):\n\u001b[32m--> \u001b[39m\u001b[32m647\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshutdown\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py:851\u001b[39m, in \u001b[36mProcessPoolExecutor.shutdown\u001b[39m\u001b[34m(self, wait, cancel_futures)\u001b[39m\n\u001b[32m 848\u001b[39m \u001b[38;5;28mself\u001b[39m._executor_manager_thread_wakeup.wakeup()\n\u001b[32m 850\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._executor_manager_thread \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m wait:\n\u001b[32m--> \u001b[39m\u001b[32m851\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_executor_manager_thread\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 852\u001b[39m \u001b[38;5;66;03m# To reduce the risk of opening too many files, remove references to\u001b[39;00m\n\u001b[32m 853\u001b[39m \u001b[38;5;66;03m# objects that use file descriptors.\u001b[39;00m\n\u001b[32m 854\u001b[39m \u001b[38;5;28mself\u001b[39m._executor_manager_thread = \u001b[38;5;28;01mNone\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:1119\u001b[39m, in \u001b[36mThread.join\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 1116\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mcannot join current thread\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 1118\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1119\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1120\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1121\u001b[39m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[32m 1122\u001b[39m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[32m 1123\u001b[39m \u001b[38;5;28mself\u001b[39m._wait_for_tstate_lock(timeout=\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[32m0\u001b[39m))\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:1139\u001b[39m, in \u001b[36mThread._wait_for_tstate_lock\u001b[39m\u001b[34m(self, block, timeout)\u001b[39m\n\u001b[32m 1136\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 1138\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1139\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 1140\u001b[39m lock.release()\n\u001b[32m 1141\u001b[39m \u001b[38;5;28mself\u001b[39m._stop()\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"Process ForkProcess-19:\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"Traceback (most recent call last):\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
" self.run()\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
" self._target(*self._args, **self._kwargs)\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 103, in get\n",
" res = self._recv_bytes()\n",
" ^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
" call_item = call_queue.get(block=True)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 216, in recv_bytes\n",
" buf = self._recv_bytes(maxlength)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
" with self._rlock:\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 430, in _recv_bytes\n",
" buf = self._recv(4)\n",
" ^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
" return self._semlock.__enter__()\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n",
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 395, in _recv\n",
" chunk = read(handle, remaining)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^\n",
"KeyboardInterrupt\n",
"KeyboardInterrupt\n"
]
}
],
"source": [
"import multiprocessing as mp\n",
"from concurrent.futures import ProcessPoolExecutor\n",
"import webdataset as wds\n",
"from PIL import Image\n",
"import io\n",
"import threading\n",
"from pathlib import Path\n",
"\n",
"# Включаем красивую версию tqdm специально для Jupyter\n",
"from tqdm.notebook import tqdm\n",
"\n",
"# --- ПУТИ И НАСТРОЙКИ ---\n",
"SHARDS_DIR = Path(\"../../dataset/EmoSet-2.41M-shards\")\n",
"SHARDS_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"NUM_WORKERS = 42\n",
"MAX_SAMPLES_PER_SHARD = 10000\n",
"\n",
"# Дробим список на чанки\n",
"chunks = [samples[i:i + MAX_SAMPLES_PER_SHARD] for i in range(0, len(samples), MAX_SAMPLES_PER_SHARD)]\n",
"TOTAL_FILES = len(samples)\n",
"TOTAL_SHARDS = len(chunks)\n",
"\n",
"print(f\"📦 Подготовлено {TOTAL_SHARDS} задач (шардов).\")\n",
"print(f\"💾 Целевая папка: {SHARDS_DIR}\")\n",
"print(f\"🚀 Запуск упаковки в {NUM_WORKERS} потоков...\\n\")\n",
"\n",
"# --- ФУНКЦИЯ ДЛЯ ЯДЕР ПРОЦЕССОРА ---\n",
"def build_shard(args):\n",
" shard_idx, chunk, queue = args\n",
" shard_path = SHARDS_DIR / f\"emoset-{shard_idx:06d}.tar\"\n",
" \n",
" with wds.TarWriter(str(shard_path)) as sink:\n",
" for i, (img_path, label) in enumerate(chunk):\n",
" try:\n",
" # Магия сжатия\n",
" with Image.open(img_path) as img:\n",
" img = img.convert(\"RGB\")\n",
" img = img.resize((256, 256), Image.Resampling.BILINEAR)\n",
" with io.BytesIO() as img_byte_arr:\n",
" img.save(img_byte_arr, format='JPEG', quality=85)\n",
" image_data = img_byte_arr.getvalue()\n",
" \n",
" key = f\"{shard_idx:06d}_{i:05d}\"\n",
" sink.write({\n",
" \"__key__\": key,\n",
" \"jpg\": image_data,\n",
" \"cls\": label\n",
" })\n",
" \n",
" # Чтобы не перегружать очередь, отправляем отчет каждые 50 файлов\n",
" if (i + 1) % 50 == 0:\n",
" queue.put((\"file\", 50))\n",
" \n",
" except Exception:\n",
" # Если файл битый, всё равно считаем его \"пройденным\", чтобы бар не застрял\n",
" queue.put((\"file\", 1))\n",
" continue\n",
" \n",
" # Сообщаем об остатке файлов в чанке, которые не попали в % 50\n",
" remainder = len(chunk) % 50\n",
" if remainder != 0:\n",
" queue.put((\"file\", remainder))\n",
" \n",
" # Сообщаем, что целый шард готов\n",
" queue.put((\"shard\", 1))\n",
" return shard_idx\n",
"\n",
"# --- ФУНКЦИЯ ОТРИСОВКИ ИНТЕРФЕЙСА (Фоновый поток) ---\n",
"def ui_thread_func(q, total_files, total_shards):\n",
" # Создаем две красивые независимые полоски\n",
" pbar_files = tqdm(total=total_files, desc=\"🖼️ Сжато файлов\", color=\"blue\")\n",
" pbar_shards = tqdm(total=total_shards, desc=\"📦 Готово архивов\", color=\"green\")\n",
" \n",
" while True:\n",
" msg = q.get()\n",
" if msg == \"DONE\":\n",
" break\n",
" \n",
" msg_type, count = msg\n",
" if msg_type == \"file\":\n",
" pbar_files.update(count)\n",
" elif msg_type == \"shard\":\n",
" pbar_shards.update(count)\n",
" \n",
" pbar_files.close()\n",
" pbar_shards.close()\n",
"\n",
"# === ГЛАВНЫЙ ЗАПУСК ===\n",
"if __name__ == '__main__':\n",
" # 1. Создаем диспетчер очередей\n",
" manager = mp.Manager()\n",
" queue = manager.Queue()\n",
" \n",
" # 2. Запускаем фоновый поток отрисовки\n",
" ui_thread = threading.Thread(target=ui_thread_func, args=(queue, TOTAL_FILES, TOTAL_SHARDS))\n",
" ui_thread.start()\n",
" \n",
" # 3. Запускаем 42 боевых ядра\n",
" tasks = [(i, chunk, queue) for i, chunk in enumerate(chunks)]\n",
" with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:\n",
" for _ in executor.map(build_shard, tasks):\n",
" pass # Просто ждем завершения всех задач\n",
" \n",
" # 4. Убиваем поток отрисовки и завершаем работу\n",
" queue.put(\"DONE\")\n",
" ui_thread.join()\n",
" \n",
" print(\"\\n🎉 ПАРАЛЛЕЛЬНАЯ УПАКОВКА И СЖАТИЕ ПОЛНОСТЬЮ ЗАВЕРШЕНЫ!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (my-python-project)",
"language": "python",
"name": "my-python-project"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-199
View File
@@ -1,199 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ca08df84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0/1000, Loss: 1.0013\n",
"Step 10/1000, Loss: 1.0088\n",
"Step 20/1000, Loss: 0.9956\n",
"Step 30/1000, Loss: 0.9781\n",
"Step 40/1000, Loss: 0.9613\n",
"Step 50/1000, Loss: 0.9313\n",
"Step 60/1000, Loss: 0.8927\n",
"Step 70/1000, Loss: 0.8503\n",
"Step 80/1000, Loss: 0.7537\n",
"Step 90/1000, Loss: 0.6689\n",
"Step 100/1000, Loss: 0.6063\n",
"Step 110/1000, Loss: 0.5172\n",
"Step 120/1000, Loss: 0.4592\n",
"Step 130/1000, Loss: 0.4044\n",
"Step 140/1000, Loss: 0.3610\n",
"Step 150/1000, Loss: 0.3175\n",
"Step 160/1000, Loss: 0.2825\n",
"Step 170/1000, Loss: 0.2560\n",
"Step 180/1000, Loss: 0.2360\n",
"Step 190/1000, Loss: 0.2203\n",
"Step 200/1000, Loss: 0.1930\n",
"Step 210/1000, Loss: 0.1854\n",
"Step 220/1000, Loss: 0.1723\n",
"Step 230/1000, Loss: 0.1546\n",
"Step 240/1000, Loss: 0.1386\n",
"Step 250/1000, Loss: 0.1271\n",
"Step 260/1000, Loss: 0.1109\n",
"Step 270/1000, Loss: 0.1032\n",
"Step 280/1000, Loss: 0.0899\n",
"Step 290/1000, Loss: 0.0807\n",
"Step 300/1000, Loss: 0.0750\n",
"Step 310/1000, Loss: 0.0813\n",
"Step 320/1000, Loss: 0.0612\n",
"Step 330/1000, Loss: 0.0544\n",
"Step 340/1000, Loss: 0.0552\n",
"Step 350/1000, Loss: 0.0446\n",
"Step 360/1000, Loss: 0.0403\n",
"Step 370/1000, Loss: 0.0350\n",
"Step 380/1000, Loss: 0.0612\n",
"Step 390/1000, Loss: 0.0364\n",
"Step 400/1000, Loss: 0.0322\n",
"Step 410/1000, Loss: 0.0302\n",
"Step 420/1000, Loss: 0.0519\n",
"Step 430/1000, Loss: 0.0319\n",
"Step 440/1000, Loss: 0.0260\n",
"Step 450/1000, Loss: 0.0208\n",
"Step 460/1000, Loss: 0.0409\n",
"Step 470/1000, Loss: 0.0291\n",
"Step 480/1000, Loss: 0.0234\n",
"Step 490/1000, Loss: 0.0194\n",
"Step 500/1000, Loss: 0.0274\n",
"Step 510/1000, Loss: 0.0231\n",
"Step 520/1000, Loss: 0.0199\n",
"Step 530/1000, Loss: 0.0154\n",
"Step 540/1000, Loss: 0.0278\n",
"Step 550/1000, Loss: 0.0185\n",
"Step 560/1000, Loss: 0.0180\n",
"Step 570/1000, Loss: 0.0152\n",
"Step 580/1000, Loss: 0.0132\n",
"Step 590/1000, Loss: 0.0111\n",
"Step 600/1000, Loss: 0.0396\n",
"Step 610/1000, Loss: 0.0179\n",
"Step 620/1000, Loss: 0.0148\n",
"Step 630/1000, Loss: 0.0123\n",
"Step 640/1000, Loss: 0.0265\n",
"Step 650/1000, Loss: 0.0133\n",
"Step 660/1000, Loss: 0.0128\n",
"Step 670/1000, Loss: 0.0107\n",
"Step 680/1000, Loss: 0.0142\n",
"Step 690/1000, Loss: 0.0202\n",
"Step 700/1000, Loss: 0.0125\n",
"Step 710/1000, Loss: 0.0107\n",
"Step 720/1000, Loss: 0.0140\n",
"Step 730/1000, Loss: 0.0195\n",
"Step 740/1000, Loss: 0.0148\n",
"Step 750/1000, Loss: 0.0109\n",
"Step 760/1000, Loss: 0.0094\n",
"Step 770/1000, Loss: 0.0121\n",
"Step 780/1000, Loss: 0.0233\n",
"Step 790/1000, Loss: 0.0151\n",
"Step 800/1000, Loss: 0.0134\n",
"Step 810/1000, Loss: 0.0117\n",
"Step 820/1000, Loss: 0.0124\n",
"Step 830/1000, Loss: 0.0221\n",
"Step 840/1000, Loss: 0.0161\n",
"Step 850/1000, Loss: 0.0136\n",
"Step 860/1000, Loss: 0.0161\n",
"Step 870/1000, Loss: 0.0194\n",
"Step 880/1000, Loss: 0.0145\n",
"Step 890/1000, Loss: 0.0149\n",
"Step 900/1000, Loss: 0.0232\n",
"Step 910/1000, Loss: 0.0166\n",
"Step 920/1000, Loss: 0.0156\n",
"Step 930/1000, Loss: 0.0276\n",
"Step 940/1000, Loss: 0.0176\n",
"Step 950/1000, Loss: 0.0152\n",
"Step 960/1000, Loss: 0.0162\n",
"Step 970/1000, Loss: 0.0143\n",
"Step 980/1000, Loss: 0.0136\n",
"Step 990/1000, Loss: 0.0117\n",
"Total time: 67.25 s\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import time\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Using device:\", device)\n",
"\n",
"\n",
"# Огромные параметры\n",
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
"batch_size = 16_384 # большой батч\n",
"steps = 1000 # много итераций для длительной нагрузки\n",
"\n",
"# Случайные данные на GPU\n",
"x = torch.randn(N, D_in, device=device, dtype=torch.float32)\n",
"y = torch.randn(N, D_out, device=device, dtype=torch.float32)\n",
"\n",
"model = nn.Sequential(\n",
" nn.Linear(D_in, H1),\n",
" nn.ReLU(),\n",
" nn.Linear(H1, H2),\n",
" nn.ReLU(),\n",
" nn.Linear(H2, H3),\n",
" nn.ReLU(),\n",
" nn.Linear(H3, D_out)\n",
").to(device)\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"start = time.time()\n",
"for t in range(steps):\n",
" idx = torch.randint(0, N, (batch_size,), device=device)\n",
" x_batch = x[idx]\n",
" y_batch = y[idx]\n",
"\n",
" y_pred = model(x_batch)\n",
" loss = loss_fn(y_pred, y_batch)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if t % 10 == 0:\n",
" # замедляем вывод, чтобы можно было наблюдать\n",
" print(f\"Step {t}/{steps}, Loss: {loss.item():.4f}\")\n",
"\n",
"end = time.time()\n",
"print(f\"Total time: {end-start:.2f} s\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-58
View File
@@ -1,58 +0,0 @@
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.linear_model import RidgeCV
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import joblib
# Hand-assigned (valence, arousal) target for each emotion class. The position
# of a pair in the tuple is the integer class label used as the lookup key.
_VA_ANCHORS = (
    (7.5, 6.5),  # 0: amusement
    (2.0, 8.0),  # 1: anger
    (6.5, 5.0),  # 2: awe
    (7.0, 3.0),  # 3: contentment
    (3.0, 6.0),  # 4: disgust
    (8.0, 8.0),  # 5: excitement
    (2.5, 7.5),  # 6: fear
    (2.0, 2.0),  # 7: sadness
)
EMO_VA_MAP = dict(enumerate(_VA_ANCHORS))

# Parent of the directory that contains this script; inputs are read from it
# and the trained model is written below it.
BASE_DIR = Path(__file__).resolve().parent.parent
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"

# --- Load the precomputed embeddings and their class labels ---
print("Загрузка данных...")
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)

# Turn every class label into its (valence, arousal) regression target.
# A label missing from EMO_VA_MAP raises KeyError.
y_va = np.array(list(map(EMO_VA_MAP.__getitem__, y_labels)))

# --- Hold out 20% of the samples for evaluation (fixed seed) ---
split_parts = train_test_split(X, y_va, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split_parts

# --- Fit: standardize the features, then one RidgeCV per output column ---
print("Обучение масштабатора и RidgeCV регрессора...")
ridge_cv = RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])
pipeline_steps = [
    ("scaler", StandardScaler()),
    ("regressor", MultiOutputRegressor(ridge_cv)),
]
model = Pipeline(pipeline_steps)
model.fit(X_train, y_train)

# --- Evaluate on the held-out part ---
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"\nУспех! Обучение завершено!")
print(f"MSE: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")

# Predicted min/max per output column, shown next to the 2.0 - 8.0 span that
# the hand-assigned targets cover (column 0 = valence, column 1 = arousal).
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
print(f"Valence: от {y_pred[:, 0].min():.2f} до {y_pred[:, 0].max():.2f} (Эталон: 2.0 - 8.0)")
print(f"Arousal: от {y_pred[:, 1].min():.2f} до {y_pred[:, 1].max():.2f} (Эталон: 2.0 - 8.0)")

# --- Persist the fitted pipeline, creating the target folder if needed ---
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
model_dir = output_model_path.parent
model_dir.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
print(f"\nМодель сохранена в: {output_model_path}")
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB