ref: refactor before checkout
This commit is contained in:
+20
-19
@@ -1,54 +1,55 @@
|
|||||||
import streamlit as st
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
from music_engine.matcher import MusicMatcher
|
from music_engine.matcher import MusicMatcher
|
||||||
from music_engine.image_processor import ImageProcessor
|
from music_engine.image_processor import ImageProcessor
|
||||||
|
|
||||||
# Определяем базовую директорию (папка src)
|
|
||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
@st.cache_resource
|
@st.cache_resource
|
||||||
def load_music_engine():
|
def load_music_engine():
|
||||||
"""Загрузка базы данных и модели регрессора."""
|
# Инициализация базы данных и регрессора для музыкального мэтчинга
|
||||||
# music_db.csv лежит в dataset/DEAM/ (на уровень выше от src)
|
|
||||||
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
db_path = BASE_DIR.parent / "dataset" / "DEAM" / "music_db.csv"
|
||||||
# va_regressor.pkl лежит в src/music_engine/
|
|
||||||
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
|
||||||
|
|
||||||
if not db_path.exists():
|
if not db_path.exists():
|
||||||
print(f"⚠️ Файл базы {db_path} не найден!")
|
print(f"Музыкальная БД не найдена: {db_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return MusicMatcher(db_path=db_path, model_path=model_path)
|
return MusicMatcher(db_path=db_path, model_path=model_path)
|
||||||
|
|
||||||
@st.cache_resource
|
@st.cache_resource
|
||||||
def load_image_processor():
|
def load_image_processor():
|
||||||
"""Загрузка ResNet-50 для извлечения признаков на лету."""
|
# Модуль обработки визуальных признаков
|
||||||
# Файл весов лежит в той же папке src, что и этот скрипт
|
|
||||||
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
model_path = BASE_DIR / "emoset_resnet50_best.pth"
|
||||||
|
|
||||||
|
# Обработка пути при вызове из корневой директории
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
print(f"Ошибка: Веса не найдены по пути: {model_path}")
|
|
||||||
# Если не нашли в src, попробуем поискать в корне проекта на всякий случай
|
|
||||||
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
model_path = BASE_DIR.parent / "emoset_resnet50_best.pth"
|
||||||
|
|
||||||
return ImageProcessor(model_path=model_path)
|
return ImageProcessor(model_path=model_path)
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def load_emoset_data():
|
def load_emoset_data():
|
||||||
"""Загрузка тестовой выборки EmoSet для первой вкладки."""
|
# Выборка данных датасета для вкладки отладки
|
||||||
# Пути относительно корня проекта
|
dataset_root = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test"
|
||||||
csv_path = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "labels.csv"
|
|
||||||
img_dir = BASE_DIR.parent / "dataset" / "EmoSet-118K" / "test" / "images"
|
csv_path = dataset_root / "labels.csv"
|
||||||
|
img_dir = dataset_root / "images"
|
||||||
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
emb_path = BASE_DIR / "emoset_test_embeddings.npy"
|
||||||
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
lbl_path = BASE_DIR / "emoset_test_labels.npy"
|
||||||
|
|
||||||
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
if not all([csv_path.exists(), emb_path.exists(), lbl_path.exists()]):
|
||||||
|
print("Тестовые файлы датасета не найдены, вкладка отладки может работать некорректно")
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
df = pd.read_csv(csv_path)
|
labels_df = pd.read_csv(csv_path)
|
||||||
image_list = df['filename'].tolist()
|
|
||||||
embs = np.load(emb_path)
|
|
||||||
lbls = np.load(lbl_path)
|
|
||||||
|
|
||||||
return image_list, embs, lbls, img_dir
|
test_filenames = labels_df['filename'].tolist()
|
||||||
|
test_embeddings = np.load(emb_path)
|
||||||
|
test_labels = np.load(lbl_path)
|
||||||
|
|
||||||
|
return test_filenames, test_embeddings, test_labels, img_dir
|
||||||
@@ -1,49 +1,62 @@
|
|||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from PIL import Image
|
|
||||||
import timm
|
import timm
|
||||||
from pathlib import Path
|
|
||||||
import numpy as np
|
|
||||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||||
|
|
||||||
class ImageProcessor:
|
class ImageProcessor:
|
||||||
def __init__(self, model_path: Path | str):
|
def __init__(self, weights_path: str | Path):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
# Модель извлечения визуальных признаков
|
||||||
if Path(model_path).exists():
|
self.feature_extractor = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||||
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
|
||||||
|
|
||||||
self.emo_model.fc = torch.nn.Identity()
|
if Path(weights_path).exists():
|
||||||
self.emo_model.to(self.device).eval()
|
self.feature_extractor.load_state_dict(torch.load(weights_path, map_location=self.device))
|
||||||
|
else:
|
||||||
|
print(f"Не удалось найти веса ResNet по пути: {weights_path}")
|
||||||
|
|
||||||
self.emo_transform = T.Compose([
|
# Удаление слоя классификации для вывода сырого вектора эмбеддингов
|
||||||
|
self.feature_extractor.fc = torch.nn.Identity()
|
||||||
|
self.feature_extractor.to(self.device).eval()
|
||||||
|
|
||||||
|
# Трансформации для предварительной обработки изображений
|
||||||
|
self.preprocess_image = T.Compose([
|
||||||
T.Resize((224, 224)),
|
T.Resize((224, 224)),
|
||||||
T.ToTensor(),
|
T.ToTensor(),
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
print("Загрузка BLIP-2...")
|
# Модуль семантического описания сцены
|
||||||
|
print("Инициализация BLIP-2...")
|
||||||
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
||||||
"Salesforce/blip2-opt-2.7b",
|
"Salesforce/blip2-opt-2.7b",
|
||||||
torch_dtype=torch.float16
|
torch_dtype=torch.float16
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
print("BLIP-2 и ResNet-50 готовы.")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||||
img_rgb = image.convert('RGB')
|
# Извлечение эмбеддингов из изображения
|
||||||
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
|
rgb_image = image.convert('RGB')
|
||||||
return self.emo_model(img_tensor).cpu().numpy().flatten()
|
img_tensor = self.preprocess_image(rgb_image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
features = self.feature_extractor(img_tensor)
|
||||||
|
features_np = features.cpu().numpy()
|
||||||
|
|
||||||
|
return features_np.flatten()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def describe_scene(self, image: Image.Image) -> str:
|
def describe_scene(self, image: Image.Image) -> str:
|
||||||
"""Генерирует описание через BLIP-2."""
|
# Генерация текстового описания сцены
|
||||||
img_rgb = image.convert('RGB')
|
rgb_image = image.convert('RGB')
|
||||||
|
|
||||||
inputs = self.blip_processor(images=img_rgb, return_tensors="pt").to(self.device, torch.float16)
|
|
||||||
|
|
||||||
|
inputs = self.blip_processor(images=rgb_image, return_tensors="pt").to(self.device, torch.float16)
|
||||||
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
|
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
|
||||||
caption = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
|
||||||
return caption
|
scene_description = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
|
return scene_description.strip()
|
||||||
@@ -1,31 +1,31 @@
|
|||||||
import requests
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
class LLMAcousticBridge:
|
class LLMAcousticBridge:
|
||||||
def __init__(self, model_name="dolphin-llama3:8b"):
|
def __init__(self, target_model="dolphin-llama3:8b"):
|
||||||
self.model_name = model_name
|
|
||||||
self.api_url = "http://localhost:11434/api/generate"
|
self.api_url = "http://localhost:11434/api/generate"
|
||||||
|
self.model = target_model
|
||||||
|
|
||||||
def _clean_json(self, text):
|
def _extract_json(self, raw_text: str):
|
||||||
"""Вытаскивает чистый JSON из ответа нейросети."""
|
# Проверка на ИИдиота, LLM иногда игнорирует format="json" и оборачивает ответ в маркдаун
|
||||||
try:
|
try:
|
||||||
match = re.search(r'\{.*\}', text, re.DOTALL)
|
match = re.search(r'\{.*\}', raw_text, re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
return json.loads(match.group(0))
|
return json.loads(match.group(0))
|
||||||
return json.loads(text)
|
return json.loads(raw_text)
|
||||||
except:
|
except json.JSONDecodeError:
|
||||||
|
# Если ИИдиот
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_acoustic_profile(self, valence, arousal, scene_descriptions):
|
def get_acoustic_profile(self, v_score: float, a_score: float, scene_context: list) -> dict | None:
|
||||||
"""Просит LLM сгенерировать идеальный звук под описание."""
|
# Агрегация контекста для обработки серии снимков (события)
|
||||||
# Объединяем описания, если загружено несколько фото
|
context_merged = " | ".join(scene_context) if scene_context else "abstract scene"
|
||||||
context_str = " | ".join(scene_descriptions) if scene_descriptions else "abstract scene"
|
|
||||||
|
|
||||||
prompt = f"""You are an expert music producer and acoustic engineer.
|
system_prompt = f"""You are an expert music producer and acoustic engineer.
|
||||||
Analyze the visual context and emotions to determine the ideal background music properties.
|
Analyze the visual context and emotions to determine the ideal background music properties.
|
||||||
Emotions: Valence {valence:.1f}/9.0 (Positivity), Arousal {arousal:.1f}/9.0 (Energy).
|
Emotions: Valence {v_score:.1f}/9.0 (Positivity), Arousal {a_score:.1f}/9.0 (Energy).
|
||||||
Visual Context: {context_str}.
|
Visual Context: {context_merged}.
|
||||||
|
|
||||||
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
|
Map this scene to exactly 6 acoustic features. Values MUST be floats between 0.0 and 1.0.
|
||||||
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
|
1. "energy": (Loudness/Density. High for massive/busy scenes, Low for calm)
|
||||||
@@ -39,22 +39,27 @@ Return ONLY a valid JSON object. Do not add any text or explanation.
|
|||||||
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
|
Example: {{"energy": 0.5, "flux": 0.2, "centroid": 0.4, "pitch": 0.3, "hnr": 0.8, "zcr": 0.1}}"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Отправка промпта локальной Ollama
|
||||||
response = requests.post(self.api_url, json={
|
response = requests.post(self.api_url, json={
|
||||||
"model": self.model_name,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"prompt": system_prompt,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"format": "json"
|
"format": "json"
|
||||||
}, timeout=30)
|
}, timeout=45)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result_text = response.json().get("response", "")
|
raw_response = response.json().get("response", "")
|
||||||
profile = self._clean_json(result_text)
|
profile_data = self._extract_json(raw_response)
|
||||||
|
|
||||||
# Проверяем, что все нужные ключи есть
|
# Валидация структуры ответа
|
||||||
required_keys = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
|
expected_features = {'energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr'}
|
||||||
if profile and all(k in profile for k in required_keys):
|
|
||||||
return profile
|
if profile_data and expected_features.issubset(profile_data.keys()):
|
||||||
|
return profile_data
|
||||||
|
|
||||||
|
print("LLM вернула неполный или некорректный набор акустических признаков")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка связи с локальной LLM: {e}")
|
except requests.exceptions.RequestException as req_err:
|
||||||
|
print(f"Не удалось подключиться к Ollama: {req_err}")
|
||||||
return None
|
return None
|
||||||
+41
-25
@@ -1,67 +1,83 @@
|
|||||||
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import joblib
|
|
||||||
|
|
||||||
class MusicMatcher:
|
class MusicMatcher:
|
||||||
def __init__(self, db_path: Path | str, model_path: Path | str):
|
def __init__(self, db_path: Path | str, model_path: Path | str):
|
||||||
# Загружаем твою новую, обогащенную базу
|
# Загрузка базы данных музыкальных произведений
|
||||||
self.music_db = pd.read_csv(db_path)
|
self.music_db = pd.read_csv(db_path)
|
||||||
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
|
self.acoustic_features = ['energy', 'flux', 'centroid', 'pitch', 'hnr', 'zcr']
|
||||||
|
|
||||||
# Удаляем строки, где нет акустических фич
|
# Удаление записей с пропущенными целевыми или акустическими признаками
|
||||||
self.music_db = self.music_db.dropna(subset=['valence', 'arousal'] + self.acoustic_features)
|
target_columns = ['valence', 'arousal'] + self.acoustic_features
|
||||||
|
self.music_db = self.music_db.dropna(subset=target_columns)
|
||||||
|
|
||||||
# Нормализуем акустику от 0 до 1, чтобы сравнивать с ответом LLM
|
# Масштабирование акустических параметров к диапазону [0, 1]
|
||||||
self.norm_db = self.music_db.copy()
|
self.norm_db = self.music_db.copy()
|
||||||
for feat in self.acoustic_features:
|
for feat in self.acoustic_features:
|
||||||
f_min, f_max = self.norm_db[feat].min(), self.norm_db[feat].max()
|
f_min = self.norm_db[feat].min()
|
||||||
|
f_max = self.norm_db[feat].max()
|
||||||
if f_max > f_min:
|
if f_max > f_min:
|
||||||
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
|
self.norm_db[f"norm_{feat}"] = (self.norm_db[feat] - f_min) / (f_max - f_min)
|
||||||
else:
|
else:
|
||||||
self.norm_db[f"norm_{feat}"] = 0.0
|
self.norm_db[f"norm_{feat}"] = 0.0
|
||||||
|
|
||||||
|
# Определение путей к аудиофайлам и загрузка модели регрессии
|
||||||
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
|
self.audio_dir = Path(db_path).parent / "DEAM_audio" / "MEMD_audio"
|
||||||
self.regressor = joblib.load(model_path) if Path(model_path).exists() else None
|
|
||||||
|
|
||||||
def predict_va(self, embedding: np.ndarray):
|
if Path(model_path).exists():
|
||||||
if self.regressor:
|
self.regressor = joblib.load(model_path)
|
||||||
prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
|
else:
|
||||||
return np.clip(prediction[0], 1.0, 9.0), np.clip(prediction[1], 1.0, 9.0)
|
self.regressor = None
|
||||||
|
|
||||||
|
def predict_va(self, embedding: np.ndarray) -> tuple[float, float]:
|
||||||
|
# Прогнозирование координат Valence/Arousal по визуальному эмбеддингу
|
||||||
|
if not self.regressor:
|
||||||
return 5.0, 5.0
|
return 5.0, 5.0
|
||||||
|
|
||||||
def get_audio_path(self, song_id):
|
raw_prediction = self.regressor.predict(embedding.reshape(1, -1))[0]
|
||||||
if not self.audio_dir.exists(): return None
|
valence_pred = np.clip(raw_prediction[0], 1.0, 9.0)
|
||||||
|
arousal_pred = np.clip(raw_prediction[1], 1.0, 9.0)
|
||||||
|
|
||||||
|
return float(valence_pred), float(arousal_pred)
|
||||||
|
|
||||||
|
def get_audio_path(self, song_id: int | float | str) -> Path | None:
|
||||||
|
# Поиск физического пути к аудиофайлу в зависимости от расширения
|
||||||
|
if not self.audio_dir.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
clean_id = str(int(float(song_id)))
|
clean_id = str(int(float(song_id)))
|
||||||
for ext in ['.mp3', '.wav']:
|
for ext in ['.mp3', '.wav']:
|
||||||
path = self.audio_dir / f"{clean_id}{ext}"
|
path = self.audio_dir / f"{clean_id}{ext}"
|
||||||
if path.exists(): return path
|
if path.exists():
|
||||||
|
return path
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5):
|
def find_nearest_tracks(self, target_v: float, target_a: float, llm_profile: dict = None, top_k: int = 5) -> pd.DataFrame:
|
||||||
# 1. Эмоциональная дистанция (как и раньше)
|
# Расчет евклидова расстояния в эмоциональном пространстве Рассела
|
||||||
emo_dist = np.sqrt(
|
v_dist = (self.norm_db['valence'] - target_v) ** 2
|
||||||
1.0 * (self.norm_db['valence'] - target_v)**2 +
|
a_dist = (self.norm_db['arousal'] - target_a) ** 2
|
||||||
2.5 * (self.norm_db['arousal'] - target_a)**2
|
|
||||||
)
|
|
||||||
self.norm_db['emo_distance'] = emo_dist
|
|
||||||
|
|
||||||
# Если LLM не дала ответ, сортируем только по эмоциям
|
# Взвешенное расстояние с приоритетом оси активации (Arousal)
|
||||||
|
self.norm_db['emo_distance'] = np.sqrt(1.0 * v_dist + 2.5 * a_dist)
|
||||||
|
|
||||||
|
# Ранжирование только по эмоциональному критерию при отсутствии профиля LLM
|
||||||
if not llm_profile:
|
if not llm_profile:
|
||||||
self.norm_db['final_score'] = self.norm_db['emo_distance']
|
self.norm_db['final_score'] = self.norm_db['emo_distance']
|
||||||
return self.norm_db.sort_values(by='final_score').head(top_k)
|
return self.norm_db.sort_values(by='final_score').head(top_k)
|
||||||
|
|
||||||
# 2. Акустическая дистанция (сравниваем треки с запросом LLM)
|
# Расчет отклонений по вектору акустических параметров LLM
|
||||||
acoustic_penalty = np.zeros(len(self.norm_db))
|
acoustic_penalty = np.zeros(len(self.norm_db))
|
||||||
for feat in self.acoustic_features:
|
for feat in self.acoustic_features:
|
||||||
if feat in llm_profile:
|
if feat in llm_profile:
|
||||||
target_val = llm_profile[feat]
|
target_val = llm_profile[feat]
|
||||||
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
|
acoustic_penalty += np.abs(self.norm_db[f"norm_{feat}"] - target_val)
|
||||||
|
|
||||||
# Усредняем штраф
|
# Нормирование акустической дистанции
|
||||||
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
|
self.norm_db['acoustic_distance'] = acoustic_penalty / len(self.acoustic_features)
|
||||||
|
|
||||||
# 3. Финальный Score (Смесь Эмоций и Акустики). Коэф 4.0 делает акустику важной!
|
# Вычисление интегральной метрики соответствия (мультимодальный скоринг)
|
||||||
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
|
self.norm_db['final_score'] = self.norm_db['emo_distance'] + (self.norm_db['acoustic_distance'] * 4.0)
|
||||||
|
|
||||||
return self.norm_db.sort_values(by='final_score').head(top_k)
|
return self.norm_db.sort_values(by='final_score').head(top_k)
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
import kagglehub
|
||||||
|
|
||||||
|
dataset_dir = Path("../dataset/DEAM")
|
||||||
|
dataset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print("Скачивание датасета DEAM...")
|
||||||
|
|
||||||
|
# kagglehub по умолчанию тянет данные в системный кэш (~/.cache)
|
||||||
|
cache_path = kagglehub.dataset_download("imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music")
|
||||||
|
|
||||||
|
print(f"Загружено в кэш: {cache_path}")
|
||||||
|
print(f"Перенос файлов в {dataset_dir} и очистка временной директории...")
|
||||||
|
|
||||||
|
# Перемещаем данные
|
||||||
|
shutil.copytree(cache_path, dataset_dir, dirs_exist_ok=True)
|
||||||
|
shutil.rmtree(cache_path)
|
||||||
|
|
||||||
|
print("Готово. Датасет DEAM загружен, кэш очищен.")
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
import csv
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Конфигурация корневой директории локального датасета
|
||||||
|
DATASET_DIR = Path("../dataset/EmoSet-118K")
|
||||||
|
|
||||||
|
def process_and_save_split(dataset_split, split_name: str, output_dir: Path):
|
||||||
|
# Подготовка структуры директорий для текущей выборки
|
||||||
|
split_dir = output_dir / split_name
|
||||||
|
img_dir = split_dir / "images"
|
||||||
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
labels_path = split_dir / "labels.csv"
|
||||||
|
|
||||||
|
print(f"Обработка выборки: {split_name}...")
|
||||||
|
|
||||||
|
# Открытие файла разметки перед циклом для минимизации I/O операций диска
|
||||||
|
with open(labels_path, mode="w", newline="", encoding="utf-8") as csv_file:
|
||||||
|
writer = csv.writer(csv_file)
|
||||||
|
writer.writerow(["filename", "label"])
|
||||||
|
|
||||||
|
for example in tqdm(dataset_split, desc=split_name):
|
||||||
|
img = example["image"]
|
||||||
|
emotion_label = example["emotion"]
|
||||||
|
img_id = example["image_id"]
|
||||||
|
|
||||||
|
file_name = f"{img_id}.jpg"
|
||||||
|
|
||||||
|
# Принудительная конвертация в RGB для безопасного сохранения в JPEG-формате
|
||||||
|
if img.mode != "RGB":
|
||||||
|
img = img.convert("RGB")
|
||||||
|
|
||||||
|
img.save(img_dir / file_name, format="JPEG")
|
||||||
|
|
||||||
|
writer.writerow([file_name, emotion_label])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
DATASET_DIR.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
# Инициализация подключения к Hugging Face Hub
|
||||||
|
print("Загрузка метаданных EmoSet-118K...")
|
||||||
|
raw_dataset = load_dataset("Woleek/EmoSet-118K")
|
||||||
|
|
||||||
|
# Итеративная выгрузка размеченных данных
|
||||||
|
for split_key in ["train", "val", "test"]:
|
||||||
|
if split_key in raw_dataset:
|
||||||
|
process_and_save_split(
|
||||||
|
dataset_split=raw_dataset[split_key],
|
||||||
|
split_name=split_key,
|
||||||
|
output_dir=DATASET_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Экспорт датасета завершен.")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Конфигурация локальных путей
|
||||||
|
SOURCE_CSV = Path("../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv")
|
||||||
|
OUTPUT_CSV = Path("../../dataset/DEAM/music_db.csv")
|
||||||
|
|
||||||
|
def prepare_deam_database():
|
||||||
|
if not SOURCE_CSV.exists():
|
||||||
|
print(f"Исходный файл аннотаций не найден: {SOURCE_CSV}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("Обработка разметки датасета DEAM...")
|
||||||
|
|
||||||
|
# Загрузка сырых данных с очисткой артефактов форматирования
|
||||||
|
raw_df = pd.read_csv(SOURCE_CSV, skipinitialspace=True)
|
||||||
|
|
||||||
|
# Экстракция координат пространства Рассела (Valence/Arousal)
|
||||||
|
processed_df = raw_df[['song_id', 'valence_mean', 'arousal_mean']].copy()
|
||||||
|
processed_df.columns = ['song_id', 'valence', 'arousal']
|
||||||
|
|
||||||
|
# Приведение идентификаторов к формату файловой системы (int)
|
||||||
|
processed_df['song_id'] = processed_df['song_id'].astype(int)
|
||||||
|
|
||||||
|
processed_df.to_csv(OUTPUT_CSV, index=False)
|
||||||
|
|
||||||
|
print(f"База успешно сформирована. Всего записей: {len(processed_df)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
prepare_deam_database()
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
# Конфигурация параметров нагрузочного тестирования
|
||||||
|
NUM_SAMPLES = 300_000
|
||||||
|
DIM_IN = 4096
|
||||||
|
DIM_OUT = 10
|
||||||
|
BATCH_SIZE = 16_384
|
||||||
|
NUM_STEPS = 1000
|
||||||
|
|
||||||
|
def run_gpu_benchmark():
|
||||||
|
# Проверка доступности аппаратного ускорения
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Инициализация стресс-теста на устройстве: {device}")
|
||||||
|
|
||||||
|
# Генерация синтетического датасета для аллокации VRAM
|
||||||
|
x_data = torch.randn(NUM_SAMPLES, DIM_IN, device=device, dtype=torch.float32)
|
||||||
|
y_data = torch.randn(NUM_SAMPLES, DIM_OUT, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Архитектура тестовой полносвязной сети
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(DIM_IN, 2048),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(2048, 1024),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(512, DIM_OUT)
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
loss_fn = nn.MSELoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
print("Начало прогрева GPU и симуляции цикла обучения...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for step in range(NUM_STEPS):
|
||||||
|
# Сэмплирование случайного батча
|
||||||
|
idx = torch.randint(0, NUM_SAMPLES, (BATCH_SIZE,), device=device)
|
||||||
|
x_batch = x_data[idx]
|
||||||
|
y_batch = y_data[idx]
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
predictions = model(x_batch)
|
||||||
|
loss = loss_fn(predictions, y_batch)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# Логирование статуса (каждые 100 итераций для снижения I/O overhead)
|
||||||
|
if step % 100 == 0:
|
||||||
|
print(f"Итерация {step}/{NUM_STEPS} | Текущий loss: {loss.item():.4f}")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Стресс-тест завершен. Общее время: {end_time - start_time:.2f} сек.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_gpu_benchmark()
|
||||||
@@ -3,30 +3,22 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "9336560f",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "0c00b67b",
|
"id": "0c00b67b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"import torch.nn as nn\n",
|
"import torch.nn as nn\n",
|
||||||
"from torch.utils.data import Dataset, DataLoader\n",
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
"import torchvision.transforms as T\n",
|
"import torchvision.transforms as T\n",
|
||||||
"\n",
|
"import timm"
|
||||||
"import pandas as pd\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"\n",
|
|
||||||
"import timm\n",
|
|
||||||
"import numpy as np\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -47,40 +39,52 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# === CONFIG ===\n",
|
"# Конфигурация параметров обучения и путей файловой системы\n",
|
||||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
||||||
"BATCH_SIZE = 64\n",
|
"BATCH_SIZE = 64\n",
|
||||||
"EPOCHS = 15\n",
|
"EPOCHS = 15\n",
|
||||||
"LR = 3e-4\n",
|
"LR = 3e-4\n",
|
||||||
"NUM_WORKERS = 24\n",
|
"NUM_WORKERS = 40\n",
|
||||||
"\n",
|
"\n",
|
||||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"DEVICE\n"
|
"print(f\"Аппаратное ускорение: {device}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"id": "9f749add",
|
"id": "9f749add",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class EmoSetDataset(Dataset):\n",
|
"class EmoSetDataset(Dataset):\n",
|
||||||
" def __init__(self, root, split):\n",
|
" def __init__(self, root: Path | str, split: str):\n",
|
||||||
" self.root = Path(root) / split\n",
|
" self.root = Path(root) / split\n",
|
||||||
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
" self.df = pd.read_csv(self.root / \"labels.csv\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" # Формирование словарей маппинга классов\n",
|
||||||
" self.labels = sorted(self.df[\"label\"].unique())\n",
|
" self.labels = sorted(self.df[\"label\"].unique())\n",
|
||||||
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
" self.label2idx = {l: i for i, l in enumerate(self.labels)}\n",
|
||||||
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
" self.idx2label = {i: l for l, i in self.label2idx.items()}\n",
|
||||||
"\n",
|
"\n",
|
||||||
" self.transform = T.Compose([\n",
|
" # Базовые трансформации для валидации и теста\n",
|
||||||
" T.Resize((224, 224)),\n",
|
" base_tf = [\n",
|
||||||
" T.ToTensor(),\n",
|
" T.ToTensor(),\n",
|
||||||
" T.Normalize(\n",
|
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||||
" mean=[0.485, 0.456, 0.406],\n",
|
" ]\n",
|
||||||
" std=[0.229, 0.224, 0.225]\n",
|
"\n",
|
||||||
" )\n",
|
" # Внедрение аугментации исключительно для обучающей выборки (предотвращение переобучения)\n",
|
||||||
|
" if split == \"train\":\n",
|
||||||
|
" self.transform = T.Compose([\n",
|
||||||
|
" T.RandomResizedCrop(224),\n",
|
||||||
|
" T.RandomHorizontalFlip(),\n",
|
||||||
|
" *base_tf\n",
|
||||||
|
" ])\n",
|
||||||
|
" else:\n",
|
||||||
|
" self.transform = T.Compose([\n",
|
||||||
|
" T.Resize(256),\n",
|
||||||
|
" T.CenterCrop(224),\n",
|
||||||
|
" *base_tf\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def __len__(self):\n",
|
" def __len__(self):\n",
|
||||||
@@ -90,16 +94,21 @@
|
|||||||
" row = self.df.iloc[idx]\n",
|
" row = self.df.iloc[idx]\n",
|
||||||
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
" img_path = self.root / \"images\" / row[\"filename\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" # Обработка возможных исключений ввода-вывода (поврежденные JPEG-файлы в датасете)\n",
|
||||||
|
" try:\n",
|
||||||
" img = Image.open(img_path).convert(\"RGB\")\n",
|
" img = Image.open(img_path).convert(\"RGB\")\n",
|
||||||
" img = self.transform(img)\n",
|
" except Exception:\n",
|
||||||
|
" img = Image.new(\"RGB\", (224, 224), (0, 0, 0))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" label = self.label2idx[row[\"label\"]]\n",
|
" img_tensor = self.transform(img)\n",
|
||||||
" return img, label\n"
|
" label_idx = self.label2idx[row[\"label\"]]\n",
|
||||||
|
" \n",
|
||||||
|
" return img_tensor, label_idx"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"id": "c8805341",
|
"id": "c8805341",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -112,9 +121,11 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Подготовка объектов выборки\n",
|
||||||
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
|
"train_ds = EmoSetDataset(DATA_ROOT, \"train\")\n",
|
||||||
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
|
"val_ds = EmoSetDataset(DATA_ROOT, \"val\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Инициализация итераторов с закреплением памяти (pin_memory) для ускорения передачи на GPU\n",
|
||||||
"train_loader = DataLoader(\n",
|
"train_loader = DataLoader(\n",
|
||||||
" train_ds,\n",
|
" train_ds,\n",
|
||||||
" batch_size=BATCH_SIZE,\n",
|
" batch_size=BATCH_SIZE,\n",
|
||||||
@@ -131,12 +142,12 @@
|
|||||||
" pin_memory=True\n",
|
" pin_memory=True\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Classes:\", train_ds.labels)\n"
|
"print(f\"Индексированные классы: {train_ds.labels}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": null,
|
||||||
"id": "dffce582",
|
"id": "dffce582",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -391,55 +402,51 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# TODO перед защитой, повторить оптимизаторы\n",
|
||||||
|
"# Загрузка предобученной архитектуры ResNet-50 с заменой классификационного слоя\n",
|
||||||
"model = timm.create_model(\n",
|
"model = timm.create_model(\n",
|
||||||
" \"resnet50\",\n",
|
" \"resnet50\",\n",
|
||||||
" pretrained=True,\n",
|
" pretrained=True,\n",
|
||||||
" num_classes=len(train_ds.labels)\n",
|
" num_classes=len(train_ds.labels)\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
"model.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model.to(DEVICE)\n"
|
"# Функция потерь для многоклассовой классификации\n",
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "81a457ef",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"criterion = nn.CrossEntropyLoss()\n",
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Оптимизатор AdamW с L2-регуляризацией (weight_decay) для повышения обобщающей способности\n",
|
||||||
"optimizer = torch.optim.AdamW(\n",
|
"optimizer = torch.optim.AdamW(\n",
|
||||||
" model.parameters(),\n",
|
" model.parameters(),\n",
|
||||||
" lr=LR,\n",
|
" lr=LR,\n",
|
||||||
" weight_decay=1e-4\n",
|
" weight_decay=1e-4\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Планировщик скорости обучения: косинусный отжиг\n",
|
||||||
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
|
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
|
||||||
" optimizer,\n",
|
" optimizer,\n",
|
||||||
" T_max=EPOCHS\n",
|
" T_max=EPOCHS\n",
|
||||||
")\n"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"id": "951aa9e3",
|
"id": "81a457ef",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def train_epoch(model, loader):\n",
|
"def train_epoch(current_model, loader):\n",
|
||||||
" model.train()\n",
|
" current_model.train()\n",
|
||||||
" total_loss = 0\n",
|
" total_loss = 0.0\n",
|
||||||
" correct = 0\n",
|
" correct_preds = 0\n",
|
||||||
" total = 0\n",
|
" total_samples = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
" for imgs, labels in tqdm(loader, leave=False):\n",
|
" for imgs, labels in tqdm(loader, desc=\"Тренировка\", leave=False):\n",
|
||||||
" imgs = imgs.to(DEVICE)\n",
|
" imgs = imgs.to(device)\n",
|
||||||
" labels = labels.to(DEVICE)\n",
|
" labels = labels.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" optimizer.zero_grad()\n",
|
" optimizer.zero_grad()\n",
|
||||||
" logits = model(imgs)\n",
|
" logits = current_model(imgs)\n",
|
||||||
" loss = criterion(logits, labels)\n",
|
" loss = criterion(logits, labels)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" loss.backward()\n",
|
" loss.backward()\n",
|
||||||
@@ -447,292 +454,67 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" total_loss += loss.item() * imgs.size(0)\n",
|
" total_loss += loss.item() * imgs.size(0)\n",
|
||||||
" preds = logits.argmax(dim=1)\n",
|
" preds = logits.argmax(dim=1)\n",
|
||||||
" correct += (preds == labels).sum().item()\n",
|
" correct_preds += (preds == labels).sum().item()\n",
|
||||||
" total += labels.size(0)\n",
|
" total_samples += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
" return total_loss / total_samples, correct_preds / total_samples\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return total_loss / total, correct / total\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "fb7e9398",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"@torch.no_grad()\n",
|
"@torch.no_grad()\n",
|
||||||
"def val_epoch(model, loader):\n",
|
"def val_epoch(current_model, loader):\n",
|
||||||
" model.eval()\n",
|
" # Перевод модели в режим инференса (отключение Dropout и фиксация BatchNorm)\n",
|
||||||
" total_loss = 0\n",
|
" current_model.eval()\n",
|
||||||
" correct = 0\n",
|
" total_loss = 0.0\n",
|
||||||
" total = 0\n",
|
" correct_preds = 0\n",
|
||||||
|
" total_samples = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
" for imgs, labels in loader:\n",
|
" for imgs, labels in tqdm(loader, desc=\"Валидация\", leave=False):\n",
|
||||||
" imgs = imgs.to(DEVICE)\n",
|
" imgs = imgs.to(device)\n",
|
||||||
" labels = labels.to(DEVICE)\n",
|
" labels = labels.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" logits = model(imgs)\n",
|
" logits = current_model(imgs)\n",
|
||||||
" loss = criterion(logits, labels)\n",
|
" loss = criterion(logits, labels)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" total_loss += loss.item() * imgs.size(0)\n",
|
" total_loss += loss.item() * imgs.size(0)\n",
|
||||||
" preds = logits.argmax(dim=1)\n",
|
" preds = logits.argmax(dim=1)\n",
|
||||||
" correct += (preds == labels).sum().item()\n",
|
" correct_preds += (preds == labels).sum().item()\n",
|
||||||
" total += labels.size(0)\n",
|
" total_samples += labels.size(0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return total_loss / total, correct / total\n"
|
" return total_loss / total_samples, correct_preds / total_samples"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "9e870e5d",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/1477 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 01 | Train loss: 0.8383, acc: 0.6954 | Val loss: 0.6694, acc: 0.7563\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 02 | Train loss: 0.5462, acc: 0.7972 | Val loss: 0.6592, acc: 0.7594\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 03 | Train loss: 0.3654, acc: 0.8632 | Val loss: 0.7263, acc: 0.7600\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 04 | Train loss: 0.2111, acc: 0.9230 | Val loss: 0.8572, acc: 0.7472\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 05 | Train loss: 0.1187, acc: 0.9585 | Val loss: 1.0372, acc: 0.7453\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 06 | Train loss: 0.0690, acc: 0.9768 | Val loss: 1.1982, acc: 0.7529\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 07 | Train loss: 0.0466, acc: 0.9843 | Val loss: 1.3178, acc: 0.7492\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 08 | Train loss: 0.0295, acc: 0.9905 | Val loss: 1.3926, acc: 0.7551\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 09 | Train loss: 0.0204, acc: 0.9938 | Val loss: 1.4682, acc: 0.7497\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 10 | Train loss: 0.0146, acc: 0.9955 | Val loss: 1.4784, acc: 0.7604\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 11 | Train loss: 0.0087, acc: 0.9975 | Val loss: 1.5263, acc: 0.7580\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 12 | Train loss: 0.0057, acc: 0.9987 | Val loss: 1.5689, acc: 0.7558\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 13 | Train loss: 0.0044, acc: 0.9990 | Val loss: 1.5952, acc: 0.7566\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 14 | Train loss: 0.0030, acc: 0.9993 | Val loss: 1.6130, acc: 0.7600\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 15 | Train loss: 0.0025, acc: 0.9995 | Val loss: 1.5921, acc: 0.7627\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"best_val_acc = 0.0\n",
|
|
||||||
"\n",
|
|
||||||
"for epoch in range(1, EPOCHS + 1):\n",
|
|
||||||
" train_loss, train_acc = train_epoch(model, train_loader)\n",
|
|
||||||
" val_loss, val_acc = val_epoch(model, val_loader)\n",
|
|
||||||
"\n",
|
|
||||||
" scheduler.step()\n",
|
|
||||||
"\n",
|
|
||||||
" print(\n",
|
|
||||||
" f\"Epoch {epoch:02d} | \"\n",
|
|
||||||
" f\"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | \"\n",
|
|
||||||
" f\"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}\"\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
|
||||||
" if val_acc > best_val_acc:\n",
|
|
||||||
" best_val_acc = val_acc\n",
|
|
||||||
" torch.save(model.state_dict(), \"emoset_resnet50_best.pth\")\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "7796ef11",
|
"id": "951aa9e3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": [
|
||||||
|
"best_val_acc = 0.0\n",
|
||||||
|
"checkpoint_path = \"../emoset_resnet50_best.pth\"\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Старт процесса обучения...\")\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(1, EPOCHS + 1):\n",
|
||||||
|
" train_loss, train_acc = train_epoch(model, train_loader)\n",
|
||||||
|
" val_loss, val_acc = val_epoch(model, val_loader)\n",
|
||||||
|
"\n",
|
||||||
|
" # Обновление шага планировщика\n",
|
||||||
|
" scheduler.step()\n",
|
||||||
|
"\n",
|
||||||
|
" print(\n",
|
||||||
|
" f\"Эпоха {epoch:02d}/{EPOCHS} | \"\n",
|
||||||
|
" f\"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | \"\n",
|
||||||
|
" f\"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Экспорт весов при улучшении целевой метрики\n",
|
||||||
|
" if val_acc > best_val_acc:\n",
|
||||||
|
" best_val_acc = val_acc\n",
|
||||||
|
" torch.save(model.state_dict(), checkpoint_path)\n",
|
||||||
|
" print(f\" -> Сохранен новый лучший чекпоинт (Acc: {best_val_acc:.4f})\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Обучение завершено.\")"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,69 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Locations of the DEAM inputs/outputs. The base path is relative, so it is
# resolved against the current working directory at run time.
BASE_DIR = Path("../../dataset/DEAM")
MUSIC_DB_PATH = BASE_DIR / "music_db.csv"
FEATURES_DIR = BASE_DIR / "features" / "features"
OUTPUT_PATH = BASE_DIR / "music_db_enriched.csv"

# Maps the low-level extractor columns (openSMILE/GeMAPS) that are kept to the
# short descriptor names used in the enriched table. Insertion order defines
# the column order of the output.
TARGET_FEATURES = dict(
    pcm_RMSenergy_sma_amean="energy",
    pcm_fftMag_spectralFlux_sma_amean="flux",
    pcm_fftMag_spectralCentroid_sma_amean="centroid",
    F0final_sma_amean="pitch",
    logHNR_sma_amean="hnr",
    pcm_zcr_sma_amean="zcr",
    pcm_fftMag_spectralEntropy_sma_amean="entropy",
    pcm_fftMag_psySharpness_sma_amean="sharpness",
)
|
||||||
|
|
||||||
|
def aggregate_acoustic_features():
    """Build the enriched music database by attaching per-track acoustic features.

    Reads the annotation table at ``MUSIC_DB_PATH``; for every ``song_id`` that
    has a ``FEATURES_DIR/<song_id>.csv`` file, averages the columns listed in
    ``TARGET_FEATURES`` over all rows of that file; inner-joins the averages
    onto the annotations by ``song_id``; drops rows with NaN in any descriptor
    and writes the result to ``OUTPUT_PATH``.

    Tracks without a feature file, or whose file cannot be read, are left out
    of the output (a message is printed for unreadable files).

    Returns:
        None. Progress and results are reported on stdout.
    """
    if not MUSIC_DB_PATH.exists():
        print(f"Базовый файл аннотаций не найден: {MUSIC_DB_PATH}")
        return

    print("Загрузка эмоциональной разметки DEAM...")
    df_main = pd.read_csv(MUSIC_DB_PATH)

    print("Агрегация фреймовых акустических признаков...")
    source_columns = list(TARGET_FEATURES)
    aggregated_data = []

    # Only song_id is needed per track, so iterate over that column directly
    # rather than building a full row Series with iterrows().
    for raw_id in tqdm(df_main['song_id'], total=len(df_main), desc="Обработка аудио-векторов"):
        song_id = int(raw_id)
        feature_file = FEATURES_DIR / f"{song_id}.csv"
        if not feature_file.exists():
            continue

        try:
            # Frame-level vectors are stored as ';'-separated CSV. Restricting
            # the parse to the target columns avoids loading unused columns.
            df_feat = pd.read_csv(feature_file, sep=';', usecols=source_columns)
            # Average every descriptor over the time axis (one row per frame).
            mean_features = df_feat[source_columns].mean()
        except Exception as e:
            # Deliberate best-effort: one broken file must not abort the run.
            print(f"Ошибка парсинга файла {feature_file.name}: {e}")
            continue

        track_data = {'song_id': song_id}
        for orig_col, new_col in TARGET_FEATURES.items():
            track_data[new_col] = mean_features[orig_col]
        aggregated_data.append(track_data)

    if not aggregated_data:
        # An empty DataFrame has no 'song_id' column, so the merge below would
        # raise KeyError; report the situation instead of crashing.
        print("Не удалось агрегировать признаки ни для одного трека; обогащенная база не сформирована.")
        return

    # Inner join: keep only tracks that have both annotations and descriptors.
    df_features = pd.DataFrame(aggregated_data)
    df_enriched = pd.merge(df_main, df_features, on='song_id', how='inner')

    # Remove rows where aggregation produced NaN for any descriptor.
    df_enriched = df_enriched.dropna(subset=list(TARGET_FEATURES.values()))

    df_enriched.to_csv(OUTPUT_PATH, index=False)
    print(f"Экспорт завершен. Сформирована обогащенная база: {OUTPUT_PATH.name}")
    print(f"Итоговый размер выборки: {len(df_enriched)} треков.")


if __name__ == "__main__":
    aggregate_acoustic_features()
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sklearn.linear_model import RidgeCV
|
||||||
|
from sklearn.multioutput import MultiOutputRegressor
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import mean_squared_error, r2_score
|
||||||
|
|
||||||
|
# Projection of the discrete emotion classes onto Russell's continuous
# (valence, arousal) space. Values are calibrated to the [1.0, 9.0] range.
# The position in the list is the class index used as the dictionary key.
EMOTION_TO_VA_COORDS = dict(enumerate([
    (7.5, 6.5),  # 0: amusement
    (2.0, 8.0),  # 1: anger
    (6.5, 5.0),  # 2: awe
    (7.0, 3.0),  # 3: contentment
    (3.0, 6.0),  # 4: disgust
    (8.0, 8.0),  # 5: excitement
    (2.5, 7.5),  # 6: fear
    (2.0, 2.0),  # 7: sadness
]))
|
||||||
|
|
||||||
|
def train_va_regressor():
    """Train and save the regressor that maps embeddings to (valence, arousal).

    Loads ``emoset_test_embeddings.npy`` and ``emoset_test_labels.npy`` from
    the parent of this file's directory, converts every discrete label to its
    ``EMOTION_TO_VA_COORDS`` pair, fits a StandardScaler + per-target RidgeCV
    pipeline on an 80/20 split (``random_state=42``), prints MSE / R^2 and the
    predicted value ranges on the held-out part, and dumps the fitted pipeline
    to ``music_engine/va_regressor.pkl`` under that same parent directory.

    Returns:
        None. If either input array is missing a message is printed and
        nothing is trained.
    """
    # Resolve all artefact locations relative to this file, not the CWD.
    src_dir = Path(__file__).resolve().parent.parent
    embeddings_path = src_dir / "emoset_test_embeddings.npy"
    labels_path = src_dir / "emoset_test_labels.npy"
    model_output_path = src_dir / "music_engine" / "va_regressor.pkl"

    if not (embeddings_path.exists() and labels_path.exists()):
        print(f"Артефакты признаков не найдены в директории: {src_dir}")
        return

    print("Загрузка вектора признаков и меток классов...")
    x_features = np.load(embeddings_path)
    y_discrete = np.load(labels_path)

    # Target transformation: class index -> continuous (valence, arousal) pair.
    va_targets = [EMOTION_TO_VA_COORDS[class_idx] for class_idx in y_discrete]
    y_continuous = np.array(va_targets)

    x_train, x_test, y_train, y_test = train_test_split(
        x_features,
        y_continuous,
        test_size=0.2,
        random_state=42,
    )

    # Z-scaling followed by L2-regularised regression. RidgeCV picks alpha
    # from the candidate grid, independently for each output dimension.
    print("Инициализация и обучение пайплайна RidgeCV...")
    alpha_grid = [0.1, 1.0, 10.0, 100.0, 1000.0]
    per_target_ridge = MultiOutputRegressor(RidgeCV(alphas=alpha_grid))
    regression_pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('regressor', per_target_ridge),
    ])
    regression_pipeline.fit(x_train, y_train)

    # Generalisation check on the held-out split.
    y_pred = regression_pipeline.predict(x_test)
    mse_score = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)

    print("Обучение завершено. Метрики качества на тестовой выборке:")
    print(f"  - MSE: {mse_score:.4f}")
    print(f"  - R^2: {r2:.4f}")

    # Spread diagnostics: column 0 is valence, column 1 is arousal.
    valence_pred = y_pred[:, 0]
    arousal_pred = y_pred[:, 1]
    v_min, v_max = valence_pred.min(), valence_pred.max()
    a_min, a_max = arousal_pred.min(), arousal_pred.max()
    print(f"Распределение Valence (прогноз): [{v_min:.2f}, {v_max:.2f}] (Эталон: 1.0 - 9.0)")
    print(f"Распределение Arousal (прогноз): [{a_min:.2f}, {a_max:.2f}] (Эталон: 1.0 - 9.0)")

    # Persist the fitted pipeline, creating the target directory if needed.
    model_output_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(regression_pipeline, model_output_path)
    print(f"Пайплайн сохранен: {model_output_path.name}")


if __name__ == "__main__":
    train_va_regressor()
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import torchvision.io as tv_io
|
||||||
|
from torch.amp import autocast, GradScaler
|
||||||
|
from tqdm import tqdm
|
||||||
|
import timm
|
||||||
|
|
||||||
|
# Runtime device and filesystem layout of the training rig.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_PROJECT_ROOT = Path("/home/zin/projects/Thesis")
_SRC_DIR = _PROJECT_ROOT / "src"

DATA_ROOT = _PROJECT_ROOT / "dataset" / "Original-2.41M"
CACHE_PATH = _SRC_DIR / "dataset_paths_cache.pkl"

PREVIOUS_WEIGHTS = _SRC_DIR / "emoset_resnet50_best.pth"
RESUME_CHECKPOINT = _SRC_DIR / "emoset_resnet50_resume.pth"
SAVE_MODEL_PATH = _SRC_DIR / "emoset_resnet50_finetuned_2_41M.pth"

# Folder name -> class index. "sad" and "sadness" deliberately share index 7.
CLASS_MAPPING = {
    "amusement": 0,
    "anger": 1,
    "awe": 2,
    "contentment": 3,
    "disgust": 4,
    "excitement": 5,
    "fear": 6,
    "sad": 7,
    "sadness": 7,
}

# Hyperparameters of the training pipeline.
BATCH_SIZE = 82
EPOCHS = 15
LR = 5e-5
NUM_TRAIN_WORKERS = 48
NUM_VAL_WORKERS = 18
PATIENCE = 4
|
||||||
|
|
||||||
|
def prepare_dataset_index(data_root=None, cache_path=None, class_mapping=None):
    """Return ``(image_paths, labels)`` for the dataset, backed by an on-disk cache.

    If the cache file exists it is loaded and returned as is. Otherwise the
    dataset directory is scanned recursively for ``*.jpg`` files; a file is
    kept when the lower-cased name of its grandparent folder (third path
    component from the end) is a key of the class mapping. The result is then
    written to the cache so later runs can skip the scan.

    Args:
        data_root: Directory to scan. Defaults to the module-level ``DATA_ROOT``.
        cache_path: Pickle file holding the index. Defaults to ``CACHE_PATH``.
        class_mapping: Folder name -> class index. Defaults to ``CLASS_MAPPING``.

    Returns:
        A pair ``(paths, labels)``: a list of path strings and the list of
        their class indices, in matching order.
    """
    data_root = DATA_ROOT if data_root is None else Path(data_root)
    cache_path = CACHE_PATH if cache_path is None else Path(cache_path)
    if class_mapping is None:
        class_mapping = CLASS_MAPPING

    if cache_path.exists():
        print(f"Загрузка карты файловой системы из кэша: {cache_path.name}")
        # NOTE(security): unpickling can execute arbitrary code. This is only
        # acceptable because the cache is a file this function wrote itself;
        # never point cache_path at data from an untrusted source.
        with open(cache_path, 'rb') as f:
            cache_data = pickle.load(f)
        return cache_data['image_paths'], cache_data['labels']

    print(f"Сканирование сетевой директории {data_root} (первичная индексация)...")
    paths, labels = [], []
    # rglob yields entries in filesystem order, which is not stable across
    # machines or rescans. Sorting makes a rebuilt index identical, so the
    # seeded shuffle done by the caller keeps producing the same split.
    for img_path in sorted(data_root.rglob('*.jpg')):
        if len(img_path.parts) < 3:
            # Too shallow to have an emotion folder two levels up.
            continue
        emotion_folder = img_path.parts[-3].lower()
        if emotion_folder in class_mapping:
            paths.append(str(img_path))
            labels.append(class_mapping[emotion_folder])

    if not paths:
        # Do not persist an empty index: it would be reloaded on every later
        # run and hide the data even after the directory is fixed.
        print(f"В директории {data_root} не найдено ни одного подходящего изображения; кэш не создан.")
        return paths, labels

    with open(cache_path, 'wb') as f:
        pickle.dump({'image_paths': paths, 'labels': labels}, f)

    return paths, labels
|
||||||
|
|
||||||
|
class EmoSetDirectDataset(Dataset):
    """Dataset that decodes and resizes images, leaving augmentation to the caller.

    Each item is ``(image, label)`` where ``image`` is a float32 tensor of
    shape ``(3, H, W)`` scaled to ``[0, 1]`` and ``label`` is the entry of
    ``labels`` at the same index. Cropping, flipping and normalisation are
    not applied here.
    """

    def __init__(self, image_paths, labels, image_size=(256, 256)):
        """Store the index and build the resize transform.

        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of class indices aligned with ``image_paths``.
            image_size: ``(height, width)`` every image is resized to; also the
                size of the placeholder returned for unreadable files.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.image_size = tuple(image_size)
        self.base_transform = T.Resize(self.image_size, antialias=True)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        A file that cannot be decoded or resized does not abort the epoch: an
        all-zero image of the configured size is returned with the original
        label, and the failure is logged so it does not go unnoticed.
        """
        path = self.image_paths[idx]
        try:
            image = tv_io.read_image(path, mode=tv_io.ImageReadMode.RGB)
            image = image.to(torch.float32) / 255.0
            image = self.base_transform(image)
        except Exception as exc:
            # Deliberately broad: isolates I/O failures such as corrupted
            # files on the network drive. Logged because the placeholder keeps
            # the real label and would otherwise silently add noise.
            import logging

            logging.getLogger(__name__).warning(
                "Не удалось прочитать изображение %s (%s); подставлен пустой кадр", path, exc
            )
            image = torch.zeros((3, *self.image_size), dtype=torch.float32)
        return image, self.labels[idx]
|
||||||
|
|
||||||
|
def build_gpu_transforms():
    """Create the training and validation transform stacks placed on ``DEVICE``.

    Returns:
        ``(train_tf, val_tf)``: two ``torch.nn.Sequential`` modules. The
        training stack applies a random 224x224 crop, a random horizontal flip
        (p=0.5), a mild colour jitter and channel normalisation; the validation
        stack applies a 224x224 centre crop and the same normalisation.
    """
    # NOTE(review): the caller feeds whole batches through these modules.
    # torchvision random transforms presumably draw their parameters once per
    # call, so every image in a batch would share one crop/flip/jitter.
    # Confirm that this is acceptable.
    crop_size = (224, 224)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    augmenting_steps = [
        T.RandomCrop(crop_size),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        T.Normalize(channel_mean, channel_std),
    ]
    deterministic_steps = [
        T.CenterCrop(crop_size),
        T.Normalize(channel_mean, channel_std),
    ]

    train_pipeline = torch.nn.Sequential(*augmenting_steps).to(DEVICE)
    val_pipeline = torch.nn.Sequential(*deterministic_steps).to(DEVICE)
    return train_pipeline, val_pipeline
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Fine-tuning entry point: builds the data pipeline, restores state from a
    # resume checkpoint or earlier weights when present, trains with optional
    # mixed precision, and checkpoints after every epoch.
    print(f"Инициализация конвейера обучения. Устройство: {DEVICE}")

    all_paths, all_labels = prepare_dataset_index()
    if not all_paths:
        # zip(*[]) below cannot be unpacked into two names; stop with a clear
        # message instead of an opaque ValueError.
        raise SystemExit(f"В {DATA_ROOT} не найдено ни одного изображения, обучение невозможно.")

    # Fixed seed so the train/validation split is identical across restarts.
    random.seed(42)
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)
    all_paths, all_labels = zip(*combined)

    split_idx = int(len(all_paths) * 0.95)

    train_loader = DataLoader(
        EmoSetDirectDataset(all_paths[:split_idx], all_labels[:split_idx]),
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_TRAIN_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )

    val_loader = DataLoader(
        EmoSetDirectDataset(all_paths[split_idx:], all_labels[split_idx:]),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_VAL_WORKERS, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )

    gpu_train_tf, gpu_val_tf = build_gpu_transforms()

    # Mixed precision is only used on CUDA; on CPU autocast and the gradient
    # scaler are created disabled, so the same code path runs in full precision.
    use_amp = DEVICE.type == "cuda"

    # Decide where the initial weights come from BEFORE building the optimizer.
    # The model must be created exactly once: re-creating it afterwards would
    # leave the optimizer and scheduler bound to the discarded parameters.
    resuming = RESUME_CHECKPOINT.exists()
    warm_start = not resuming and PREVIOUS_WEIGHTS.exists()
    use_imagenet_init = not (resuming or warm_start)
    if use_imagenet_init:
        print("Веса не найдены. Инициализация с ImageNet.")

    model = timm.create_model('resnet50', pretrained=use_imagenet_init, num_classes=8).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler(enabled=use_amp)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    start_epoch = 1

    if resuming:
        print(f"Восстановление контекста выполнения из: {RESUME_CHECKPOINT.name}")
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # A scaler state saved by a disabled scaler is empty and cannot be
        # loaded into an enabled one, hence the truthiness check.
        if use_amp and checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        # Older checkpoints do not carry the patience counter; start from 0.
        epochs_no_improve = checkpoint.get('epochs_no_improve', epochs_no_improve)
        start_epoch = checkpoint['epoch'] + 1
    elif warm_start:
        print(f"Интеграция претренированных весов: {PREVIOUS_WEIGHTS.name}")
        model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))

    best_model_path = SAVE_MODEL_PATH.with_name(SAVE_MODEL_PATH.stem + "_best.pth")

    # NOTE(review): early stopping is additionally gated on this epoch number.
    # With EPOCHS = 15 the gate only opens on the final epoch, so early
    # stopping cannot shorten a run. Kept as found; confirm this is intended.
    min_epoch_for_early_stop = 15

    def _save_resume_checkpoint(completed_epoch):
        """Atomically persist the state needed to continue after ``completed_epoch``."""
        state = {
            'epoch': completed_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_loss': best_val_loss,
            'epochs_no_improve': epochs_no_improve,
        }
        # Write to a sibling file and rename it over the target, so an
        # interruption during the write cannot leave a truncated checkpoint.
        tmp_path = RESUME_CHECKPOINT.with_name(RESUME_CHECKPOINT.name + ".tmp")
        torch.save(state, tmp_path)
        os.replace(tmp_path, RESUME_CHECKPOINT)

    # Last epoch whose training, validation and scheduler step all finished.
    last_completed_epoch = start_epoch - 1
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, EPOCHS + 1):

            # Pass over the training split.
            model.train()
            running_loss, correct, total = 0.0, 0, 0

            pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]")
            for inputs, labels in pbar:
                try:
                    inputs = inputs.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    inputs = gpu_train_tf(inputs)

                    optimizer.zero_grad()

                    # Mixed precision to save VRAM (disabled off-CUDA).
                    with autocast(device_type=DEVICE.type, enabled=use_amp):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    pbar.set_postfix({'loss': f"{loss.item():.4f}"})

                except RuntimeError as memory_err:
                    # Ride out transient VRAM spikes: drop the batch, release
                    # cached blocks and move on. Anything else is re-raised.
                    if "out of memory" in str(memory_err).lower():
                        outputs = loss = None
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    raise

            train_loss = running_loss / total if total > 0 else 0
            train_acc = correct / total if total > 0 else 0

            gc.collect()
            torch.cuda.empty_cache()

            # Pass over the validation split.
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0

            with torch.no_grad():
                for val_inputs, val_labels in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [Val]", leave=False):
                    val_inputs = val_inputs.to(DEVICE, non_blocking=True)
                    val_labels = val_labels.to(DEVICE, non_blocking=True)
                    val_inputs = gpu_val_tf(val_inputs)

                    with autocast(device_type=DEVICE.type, enabled=use_amp):
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)

                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = val_outputs.max(1)
                    val_total += val_labels.size(0)
                    val_correct += val_predicted.eq(val_labels).sum().item()

            epoch_val_loss = val_loss / val_total if val_total > 0 else 0
            epoch_val_acc = val_correct / val_total if val_total > 0 else 0

            scheduler.step()
            print(f"[{epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")

            # Early-stopping bookkeeping and best-model export.
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), best_model_path)
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= PATIENCE and epoch >= min_epoch_for_early_stop:
                    print(f"Сработал механизм Early Stopping. Валидация не улучшается {PATIENCE} эпох.")
                    break

            # Epoch fully done: persist the resume state atomically.
            _save_resume_checkpoint(epoch)
            last_completed_epoch = epoch
            gc.collect()

    except KeyboardInterrupt:
        print("\nВыполнение прервано пользователем (SIGINT).")
        print(f"Дамп памяти конвейера зафиксирован на эпохе {epoch}.")
        # Record the last FINISHED epoch, not the interrupted one, so that a
        # resume re-runs the interrupted epoch and the LR schedule stays in
        # step with the epoch counter.
        _save_resume_checkpoint(last_completed_epoch)

    else:
        # Normal completion (including early stop): export the final weights
        # and remove the resume file, which is no longer needed.
        SAVE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(f"Процесс Fine-Tuning завершен. Артефакт сохранен: {SAVE_MODEL_PATH.name}")
        if RESUME_CHECKPOINT.exists():
            RESUME_CHECKPOINT.unlink()
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "0336fd0c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"✅ База загружена. Треков: 1744\n",
|
|
||||||
"🔍 Собираем акустические признаки...\n",
|
|
||||||
"\n",
|
|
||||||
"🚀 ГОТОВО! Обогащенная база сохранена: ../../dataset/DEAM/music_db_enriched.csv\n",
|
|
||||||
"Собрано фичей для 1744 из 1744 треков.\n",
|
|
||||||
" song_id valence arousal energy flux centroid pitch \\\n",
|
|
||||||
"0 2 3.1 3.0 0.097268 0.846947 483.421751 93.884056 \n",
|
|
||||||
"1 3 3.5 3.3 0.126809 0.959460 173.219616 62.682589 \n",
|
|
||||||
"2 4 5.7 5.5 0.156699 1.333944 466.434797 92.850316 \n",
|
|
||||||
"3 5 4.4 5.3 0.126455 1.009927 546.152506 158.673853 \n",
|
|
||||||
"4 7 5.8 6.4 0.268180 1.589191 175.369162 83.823484 \n",
|
|
||||||
"\n",
|
|
||||||
" hnr zcr entropy sharpness \n",
|
|
||||||
"0 3.615380 0.034270 3.299075 0.426490 \n",
|
|
||||||
"1 -2.600122 0.017893 2.294971 0.165583 \n",
|
|
||||||
"2 -0.579130 0.042936 3.258138 0.395410 \n",
|
|
||||||
"3 1.751148 0.043781 3.514585 0.494367 \n",
|
|
||||||
"4 12.006770 0.014783 2.177862 0.170058 \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from tqdm import tqdm # для красивого прогресс-бара, если не установлен - убери\n",
|
|
||||||
"\n",
|
|
||||||
"# 1. Пути к файлам\n",
|
|
||||||
"base_dir = Path(\"../../dataset/DEAM\") # Поправь, если запускаешь из другого места\n",
|
|
||||||
"music_db_path = base_dir / \"music_db.csv\"\n",
|
|
||||||
"features_dir = base_dir / \"features\" / \"features\"\n",
|
|
||||||
"output_path = base_dir / \"music_db_enriched.csv\"\n",
|
|
||||||
"\n",
|
|
||||||
"# 2. Наш \"Золотой список\" (8 признаков)\n",
|
|
||||||
"target_columns = {\n",
|
|
||||||
" 'pcm_RMSenergy_sma_amean': 'energy',\n",
|
|
||||||
" 'pcm_fftMag_spectralFlux_sma_amean': 'flux',\n",
|
|
||||||
" 'pcm_fftMag_spectralCentroid_sma_amean': 'centroid',\n",
|
|
||||||
" 'F0final_sma_amean': 'pitch',\n",
|
|
||||||
" 'logHNR_sma_amean': 'hnr',\n",
|
|
||||||
" 'pcm_zcr_sma_amean': 'zcr',\n",
|
|
||||||
" 'pcm_fftMag_spectralEntropy_sma_amean': 'entropy',\n",
|
|
||||||
" 'pcm_fftMag_psySharpness_sma_amean': 'sharpness'\n",
|
|
||||||
"}\n",
|
|
||||||
"\n",
|
|
||||||
"# 3. Загружаем текущую базу с V/A\n",
|
|
||||||
"if not music_db_path.exists():\n",
|
|
||||||
" print(f\"ОШИБКА: Не найден файл {music_db_path}\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" df_main = pd.read_csv(music_db_path)\n",
|
|
||||||
" print(f\"База загружена. Треков: {len(df_main)}\")\n",
|
|
||||||
"\n",
|
|
||||||
" # Подготавливаем новые колонки\n",
|
|
||||||
" for col_name in target_columns.values():\n",
|
|
||||||
" df_main[col_name] = np.nan\n",
|
|
||||||
"\n",
|
|
||||||
" # 4. Проходимся по всем трекам и ищем их акустические CSV\n",
|
|
||||||
" print(\"Собираем акустические признаки...\")\n",
|
|
||||||
" found_count = 0\n",
|
|
||||||
" \n",
|
|
||||||
" for index, row in df_main.iterrows():\n",
|
|
||||||
" song_id = int(row['song_id'])\n",
|
|
||||||
" feature_file = features_dir / f\"{song_id}.csv\"\n",
|
|
||||||
" \n",
|
|
||||||
" if feature_file.exists():\n",
|
|
||||||
" try:\n",
|
|
||||||
" # Читаем CSV с признаками (разделитель там обычно точка с запятой)\n",
|
|
||||||
" df_feat = pd.read_csv(feature_file, sep=';')\n",
|
|
||||||
" \n",
|
|
||||||
" # Усредняем значения по всем фреймам (одна песня разбита на сотни строк-фреймов)\n",
|
|
||||||
" mean_features = df_feat[list(target_columns.keys())].mean()\n",
|
|
||||||
" \n",
|
|
||||||
" # Записываем в главную базу\n",
|
|
||||||
" for orig_col, new_col in target_columns.items():\n",
|
|
||||||
" df_main.at[index, new_col] = mean_features[orig_col]\n",
|
|
||||||
" \n",
|
|
||||||
" found_count += 1\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" print(f\"Ошибка чтения {feature_file}: {e}\")\n",
|
|
||||||
" \n",
|
|
||||||
" # 5. Сохраняем результат\n",
|
|
||||||
" # Удаляем треки, для которых не нашлось фичей (если такие есть)\n",
|
|
||||||
" df_main = df_main.dropna(subset=list(target_columns.values()))\n",
|
|
||||||
" \n",
|
|
||||||
" df_main.to_csv(output_path, index=False)\n",
|
|
||||||
" print(f\"\\nГотово! Обогащенная база сохранена: {output_path}\")\n",
|
|
||||||
" print(f\"Собрано фичей для {found_count} из {len(df_main)} треков.\")\n",
|
|
||||||
" print(df_main.head())"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python (thesis)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "thesis"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,614 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "09f9237a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Requirement already satisfied: datasets in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.4.2)\n",
|
|
||||||
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (4.67.1)\n",
|
|
||||||
"Requirement already satisfied: pillow in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (12.1.0)\n",
|
|
||||||
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (2.32.5)\n",
|
|
||||||
"Requirement already satisfied: filelock in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.20.3)\n",
|
|
||||||
"Requirement already satisfied: numpy>=1.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.4.1)\n",
|
|
||||||
"Requirement already satisfied: pyarrow>=21.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (22.0.0)\n",
|
|
||||||
"Requirement already satisfied: dill<0.4.1,>=0.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.4.0)\n",
|
|
||||||
"Requirement already satisfied: pandas in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (2.3.3)\n",
|
|
||||||
"Requirement already satisfied: httpx<1.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.28.1)\n",
|
|
||||||
"Requirement already satisfied: xxhash in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (3.6.0)\n",
|
|
||||||
"Requirement already satisfied: multiprocess<0.70.19 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (0.70.18)\n",
|
|
||||||
"Requirement already satisfied: fsspec<=2025.10.0,>=2023.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2025.10.0)\n",
|
|
||||||
"Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (1.3.1)\n",
|
|
||||||
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (25.0)\n",
|
|
||||||
"Requirement already satisfied: pyyaml>=5.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from datasets) (6.0.3)\n",
|
|
||||||
"Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (3.13.3)\n",
|
|
||||||
"Requirement already satisfied: anyio in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (4.12.1)\n",
|
|
||||||
"Requirement already satisfied: certifi in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (2026.1.4)\n",
|
|
||||||
"Requirement already satisfied: httpcore==1.* in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (1.0.9)\n",
|
|
||||||
"Requirement already satisfied: idna in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
|
|
||||||
"Requirement already satisfied: h11>=0.16 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\n",
|
|
||||||
"Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
|
|
||||||
"Requirement already satisfied: shellingham in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)\n",
|
|
||||||
"Requirement already satisfied: typer-slim in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.21.1)\n",
|
|
||||||
"Requirement already satisfied: typing-extensions>=4.1.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
|
|
||||||
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (3.4.4)\n",
|
|
||||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests) (2.6.3)\n",
|
|
||||||
"Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (2.6.1)\n",
|
|
||||||
"Requirement already satisfied: aiosignal>=1.4.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.4.0)\n",
|
|
||||||
"Requirement already satisfied: attrs>=17.3.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (25.4.0)\n",
|
|
||||||
"Requirement already satisfied: frozenlist>=1.1.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.8.0)\n",
|
|
||||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (6.7.0)\n",
|
|
||||||
"Requirement already satisfied: propcache>=0.2.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (0.4.1)\n",
|
|
||||||
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets) (1.22.0)\n",
|
|
||||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
|
|
||||||
"Requirement already satisfied: pytz>=2020.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.2)\n",
|
|
||||||
"Requirement already satisfied: tzdata>=2022.7 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from pandas->datasets) (2025.3)\n",
|
|
||||||
"Requirement already satisfied: six>=1.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
|
|
||||||
"Requirement already satisfied: click>=8.0.0 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub<2.0,>=0.25.0->datasets) (8.3.1)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!pip install datasets tqdm pillow requests\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "6f0b2e2c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "95f07577d20642b09f2cda6f0b2cca14",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "868d872a109d49f9966f2f19985e7048",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "06741794289540849ad179c5966dcab8",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Downloading data: 0%| | 0/18 [00:00<?, ?files/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "e47aad5270144913996cb5b226213ab9",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00000-of-00018.parquet: 0%| | 0.00/509M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "30d1492a948245e3b6b58e92218cd760",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00001-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "931823b458cb4696b459e9011537cf1e",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00002-of-00018.parquet: 0%| | 0.00/489M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "846f4245b16d4cc096a43c940590ad11",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00003-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "71df201ff1a24811af67458c3fe3f2f4",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00004-of-00018.parquet: 0%| | 0.00/495M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "404dce6c69fc413dbe4aa84c289a0ab6",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00005-of-00018.parquet: 0%| | 0.00/501M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "e52b0bbbfdd14c599f44f02a48542317",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00006-of-00018.parquet: 0%| | 0.00/510M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "172981d77fc941cfa32c05f5a34bf742",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00007-of-00018.parquet: 0%| | 0.00/497M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "cc9d886ff22f4165bf696c8b4d758931",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00008-of-00018.parquet: 0%| | 0.00/512M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "5f118a9923c64ee2aa2001a1414927a3",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00009-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "db61d8d556dc4574adbd8f916f790fa7",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00010-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "75414190b19c4affbe190f6dd4f7bc4f",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00011-of-00018.parquet: 0%| | 0.00/500M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "172aa22ed0c44a289e0ac68b240c13c4",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00012-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "2baa935ed3524a73883909752cb15907",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00013-of-00018.parquet: 0%| | 0.00/491M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "5e716611b29b44788e0bf2e7ad05be5b",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00014-of-00018.parquet: 0%| | 0.00/502M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "d9c0baac101b449794155392f07b49c3",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00015-of-00018.parquet: 0%| | 0.00/504M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "b31cdc7f17ac4ac8a04593e8a01a300a",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00016-of-00018.parquet: 0%| | 0.00/507M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "ed6766f750c54b4194957bfe3db78ed6",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/train-00017-of-00018.parquet: 0%| | 0.00/494M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "5454d2ecded64b82a12823f02a7ab12d",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/val-00000-of-00002.parquet: 0%| | 0.00/282M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "62dd1439e0514c98b0c24cc8f600c57e",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/val-00001-of-00002.parquet: 0%| | 0.00/283M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "3a5b966f79314e069251462bff82395f",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/test-00000-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "422974f938924910a0712b30a9c2bd84",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/test-00001-of-00004.parquet: 0%| | 0.00/430M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "f155a08427094de7ad1a5884e623db2b",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/test-00002-of-00004.parquet: 0%| | 0.00/420M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "a94a4621d19f45f690e0064fee83767b",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"data/test-00003-of-00004.parquet: 0%| | 0.00/422M [00:00<?, ?B/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "50f55b00a27b4213b573b398e5b0d708",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Generating train split: 0%| | 0/94481 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "8658b8414f604f0ca2fd248a214ad4aa",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Generating val split: 0%| | 0/5905 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "d59b7dea75f84b64bb8b262b43730e51",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Generating test split: 0%| | 0/17716 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "0c5815040f0a4a31903348a8327811a5",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Loading dataset shards: 0%| | 0/18 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"DatasetDict({\n",
|
|
||||||
" train: Dataset({\n",
|
|
||||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
|
||||||
" num_rows: 94481\n",
|
|
||||||
" })\n",
|
|
||||||
" val: Dataset({\n",
|
|
||||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
|
||||||
" num_rows: 5905\n",
|
|
||||||
" })\n",
|
|
||||||
" test: Dataset({\n",
|
|
||||||
" features: ['image', 'label', 'image_id', 'emotion', 'brightness', 'colorfulness', 'facial_expression', 'human_action', 'scene', 'object'],\n",
|
|
||||||
" num_rows: 17716\n",
|
|
||||||
" })\n",
|
|
||||||
"})\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from datasets import load_dataset\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"import requests\n",
|
|
||||||
"\n",
|
|
||||||
"# куда сохраняем датасет\n",
|
|
||||||
"DATA_DIR = Path(\"../dataset/EmoSet-118K\")\n",
|
|
||||||
"DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
|
|
||||||
"\n",
|
|
||||||
"# загружаем через Hugging Face\n",
|
|
||||||
"ds = load_dataset(\"Woleek/EmoSet-118K\")\n",
|
|
||||||
"\n",
|
|
||||||
"print(ds)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "052ab073",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"\n",
|
|
||||||
"def save_split(split):\n",
|
|
||||||
" split_dir = DATA_DIR / split\n",
|
|
||||||
" img_dir = split_dir / \"images\"\n",
|
|
||||||
" img_dir.mkdir(parents=True, exist_ok=True)\n",
|
|
||||||
"\n",
|
|
||||||
" labels_path = split_dir / \"labels.csv\"\n",
|
|
||||||
"\n",
|
|
||||||
" # перезаписываем labels.csv\n",
|
|
||||||
" with open(labels_path, \"w\") as f:\n",
|
|
||||||
" f.write(\"filename,label\\n\")\n",
|
|
||||||
"\n",
|
|
||||||
" for example in tqdm(ds[split]):\n",
|
|
||||||
" img = example[\"image\"] # уже PIL.Image\n",
|
|
||||||
" label = example[\"emotion\"]\n",
|
|
||||||
" image_id = example[\"image_id\"]\n",
|
|
||||||
"\n",
|
|
||||||
" fname = f\"{image_id}.jpg\"\n",
|
|
||||||
" img.save(img_dir / fname)\n",
|
|
||||||
"\n",
|
|
||||||
" with open(labels_path, \"a\") as f:\n",
|
|
||||||
" f.write(f\"{fname},{label}\\n\")\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "a74ceedf",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 94481/94481 [18:43<00:00, 84.10it/s] \n",
|
|
||||||
"100%|██████████| 5905/5905 [01:08<00:00, 86.57it/s] \n",
|
|
||||||
"100%|██████████| 17716/17716 [02:57<00:00, 100.01it/s]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"save_split(\"train\")\n",
|
|
||||||
"save_split(\"val\")\n",
|
|
||||||
"save_split(\"test\")\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "thesis-py3.11",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Загрузка датасета DEAM\n",
|
|
||||||
"\n",
|
|
||||||
"Этот ноутбук предназначен для автоматизации процесса скачивания и подготовки музыкального датасета **DEAM** (Database for Emotional Analysis in Music).\n",
|
|
||||||
"Данные будут помещены в папку `dataset/DEAM`."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Collecting kagglehub\n",
|
|
||||||
" Downloading kagglehub-1.0.1-py3-none-any.whl.metadata (40 kB)\n",
|
|
||||||
"Collecting kagglesdk<1.0,>=0.1.22 (from kagglehub)\n",
|
|
||||||
" Downloading kagglesdk-0.1.23-py3-none-any.whl.metadata (13 kB)\n",
|
|
||||||
"Requirement already satisfied: packaging in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (25.0)\n",
|
|
||||||
"Requirement already satisfied: pyyaml in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (6.0.3)\n",
|
|
||||||
"Requirement already satisfied: requests in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (2.32.5)\n",
|
|
||||||
"Requirement already satisfied: tqdm in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglehub) (4.67.1)\n",
|
|
||||||
"Requirement already satisfied: protobuf in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from kagglesdk<1.0,>=0.1.22->kagglehub) (6.33.4)\n",
|
|
||||||
"Requirement already satisfied: charset_normalizer<4,>=2 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.4.4)\n",
|
|
||||||
"Requirement already satisfied: idna<4,>=2.5 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (3.11)\n",
|
|
||||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2.6.3)\n",
|
|
||||||
"Requirement already satisfied: certifi>=2017.4.17 in /home/zin/projects/Thesis/.venv/lib/python3.11/site-packages (from requests->kagglehub) (2026.1.4)\n",
|
|
||||||
"Downloading kagglehub-1.0.1-py3-none-any.whl (70 kB)\n",
|
|
||||||
"Downloading kagglesdk-0.1.23-py3-none-any.whl (217 kB)\n",
|
|
||||||
"Installing collected packages: kagglesdk, kagglehub\n",
|
|
||||||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [kagglehub]\n",
|
|
||||||
"\u001b[1A\u001b[2KSuccessfully installed kagglehub-1.0.1 kagglesdk-0.1.23\n",
|
|
||||||
"\n",
|
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.1.1\u001b[0m\n",
|
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!pip install kagglehub"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Скачиваем датасет DEAM...\n",
|
|
||||||
"Downloading to /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/1.archive...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 1.83G/1.83G [01:09<00:00, 28.2MB/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Extracting files...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Датасет скачан во временную директорию: /home/zin/.cache/kagglehub/datasets/imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music/versions/1\n",
|
|
||||||
"Переносим файлы в ../dataset/DEAM...\n",
|
|
||||||
"\n",
|
|
||||||
"[УСПЕХ] Датасет DEAM готов к работе!\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import shutil\n",
|
|
||||||
"import kagglehub\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"\n",
|
|
||||||
"# 1. Настройка путей\n",
|
|
||||||
"DATASET_ROOT = Path(\"../dataset\")\n",
|
|
||||||
"DEAM_ROOT = DATASET_ROOT / \"DEAM\"\n",
|
|
||||||
"DEAM_ROOT.mkdir(parents=True, exist_ok=True)\n",
|
|
||||||
"\n",
|
|
||||||
"# 2. Загрузка через kagglehub\n",
|
|
||||||
"print(\"Скачиваем датасет DEAM...\")\n",
|
|
||||||
"kaggle_cache_path = kagglehub.dataset_download(\"imsparsh/deam-mediaeval-dataset-emotional-analysis-in-music\")\n",
|
|
||||||
"print(f\"Датасет скачан во временную директорию: {kaggle_cache_path}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# 3. Перемещение файлов в проект\n",
|
|
||||||
"print(f\"Переносим файлы в {DEAM_ROOT}...\")\n",
|
|
||||||
"shutil.copytree(kaggle_cache_path, DEAM_ROOT, dirs_exist_ok=True)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"\\nУСПЕХ! Датасет DEAM готов к работе!\")\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python (my-python-project)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "my-python-project"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from torchvision import transforms
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
class EmoSet(Dataset):
    """EmoSet samples loaded from JSON index files under ``data_root``.

    Files read:
      * ``<data_root>/info.json`` - lookup tables; the code uses
        ``info[<name>]['label2idx']`` for ``'emotion'`` and every attribute.
      * ``<data_root>/<phase>.json`` - a list of records indexed positionally:
        ``item[0]`` emotion label, ``item[1]`` image id, ``item[2]`` image path
        and ``item[3]`` annotation path (both joined onto ``data_root``).

    Each sample is a dict with ``image_id``, ``image``, ``emotion_label_idx``
    and one ``<attribute>_label_idx`` entry per attribute listed below.
    """

    # Attributes carrying a single label per image (-1 when not annotated).
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes carrying a set of labels per image (encoded as a 0/1 vector).
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Number of classes per attribute; only 'object' is read by this class.
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load the index of one split.

        Args:
            data_root: directory containing ``info.json`` and ``<phase>.json``.
            num_emotion_classes: 8 (labels as given by ``info.json``) or
                2 (positive / negative grouping, see ``get_info``).
            phase: ``'train'``, ``'val'`` or ``'test'``.
        """
        assert num_emotion_classes in (8, 2)
        assert phase in ('train', 'val', 'test')
        self.transforms_dict = self.get_data_transforms()

        self.info = self.get_info(data_root, num_emotion_classes)

        # Still raises NotImplementedError for an unknown phase when asserts
        # are stripped (python -O), exactly like the former if/elif chain.
        if phase not in self.transforms_dict:
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        emotion_to_idx = self.info['emotion']['label2idx']
        records = self._read_json(os.path.join(data_root, f'{phase}.json'))
        self.data_store = [
            [
                emotion_to_idx[item[0]],
                item[1],
                os.path.join(data_root, item[2]),
                os.path.join(data_root, item[3])
            ]
            for item in records
        ]

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at ``path`` and close the handle right away."""
        with open(path, encoding='utf-8') as handle:
            return json.load(handle)

    @classmethod
    def get_data_transforms(cls):
        """Return the per-phase torchvision pipelines keyed by phase name.

        ``'train'`` uses a random resized crop plus horizontal flip; ``'val'``
        and ``'test'`` use the same resize + centre-crop pipeline. All three
        end with ``ToTensor`` and the same normalisation statistics.
        """
        def _normalize():
            return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        def _eval_pipeline():
            return transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                _normalize(),
            ])

        transforms_dict = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                _normalize(),
            ]),
            'val': _eval_pipeline(),
            'test': _eval_pipeline(),
        }
        return transforms_dict

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` and, for the 2-class setup, replace its emotion table.

        With ``num_emotion_classes == 2`` the ``'emotion'`` entry is overwritten
        so that amusement/awe/contentment/excitement map to 0 ('positive') and
        anger/disgust/fear/sadness map to 1 ('negative').
        """
        assert num_emotion_classes in (8, 2)
        info = self._read_json(os.path.join(data_root, 'info.json'))
        if num_emotion_classes == 8:
            pass
        elif num_emotion_classes == 2:
            emotion_info = {
                'label2idx': {
                    'amusement': 0,
                    'awe': 0,
                    'contentment': 0,
                    'excitement': 0,
                    'anger': 1,
                    'disgust': 1,
                    'fear': 1,
                    'sadness': 1,
                },
                'idx2label': {
                    '0': 'positive',
                    '1': 'negative',
                }
            }
            info['emotion'] = emotion_info
        else:
            raise NotImplementedError

        return info

    def load_image_by_path(self, path):
        """Open the image at ``path`` as RGB and apply this split's transform."""
        image = Image.open(path).convert('RGB')
        image = self.transform(image)
        return image

    def load_annotation_by_path(self, path):
        """Return the parsed JSON annotation stored at ``path``."""
        return self._read_json(path)

    def __getitem__(self, item):
        """Build the sample dict for index ``item``."""
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            # -1 marks "attribute not annotated for this image".
            attribute_label_idx = -1
            if attribute in annotation_data:
                attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            data[f'{attribute}_label_idx'] = attribute_label_idx

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            # Zero vector with a 1 at the index of every label present.
            assert attribute == 'object'
            num_classes = self.NUM_CLASSES[attribute]
            attribute_label_idx = torch.zeros(num_classes)
            if attribute in annotation_data:
                for label in annotation_data[attribute]:
                    attribute_label_idx[self.info[attribute]['label2idx'][label]] = 1
            data[f'{attribute}_label_idx'] = attribute_label_idx

        return data

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data_store)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Smoke test: build the 8-class training split and iterate over every
    # batch once without using the result.
    settings = dict(
        data_root=r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val',
        num_emotion_classes=8,
        phase='train',
    )
    dataset = EmoSet(**settings)

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Each batch is a dict; to inspect it, print e.g. batch['emotion_label_idx'],
    # batch['scene_label_idx'], batch['facial_expression_label_idx'],
    # batch['human_action_label_idx'], batch['brightness_label_idx'],
    # batch['colorfulness_label_idx'] or batch['object_label_idx'].
    for _batch in dataloader:
        pass
|
|
||||||
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -1,314 +0,0 @@
|
|||||||
import os
|
|
||||||
import gc
|
|
||||||
import pickle
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import torchvision.io as tv_io
|
|
||||||
from torch.amp import autocast, GradScaler
|
|
||||||
from tqdm import tqdm
|
|
||||||
import timm
|
|
||||||
|
|
||||||
# ==========================================
# 1. CONFIGURATION AND PATHS
# ==========================================
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Используем устройство: {DEVICE}")

_PROJECT_ROOT = Path("/home/zin/projects/Thesis")
_SRC_DIR = _PROJECT_ROOT / "src"

# Root of the large image dataset (described as living on NFS)
DATA_ROOT = _PROJECT_ROOT / "dataset" / "Original-2.41M"
CACHE_PATH = _SRC_DIR / "dataset_paths_cache.pkl"

# Model files
PREVIOUS_WEIGHTS = _SRC_DIR / "emoset_resnet50_best.pth"  # earlier weights (118K)
RESUME_CHECKPOINT = _SRC_DIR / "emoset_resnet50_resume.pth"  # session-restore file
SAVE_MODEL_PATH = _SRC_DIR / "emoset_resnet50_finetuned_2.41M.pth"  # final output

# Emotion folder name -> class index; "sad" and "sadness" share index 7.
EMO_MAP = {
    name: index
    for index, name in enumerate((
        "amusement", "anger", "awe", "contentment",
        "disgust", "excitement", "fear", "sad",
    ))
}
EMO_MAP["sadness"] = EMO_MAP["sad"]

# --- TRAINING SETTINGS ---
BATCH_SIZE = 82
EPOCHS = 15
LR = 5e-5  # low learning rate because this is fine-tuning
NUM_TRAIN_WORKERS = 48
NUM_VAL_WORKERS = 18

# Overfitting-protection (early stopping) settings and mutable run state
PATIENCE = 4
best_val_loss = float('inf')
epochs_no_improve = 0
start_epoch = 1
|
|
||||||
|
|
||||||
# ==========================================
# 2. DATA PREPARATION (FAST PATH-LIST CACHE)
# ==========================================
if CACHE_PATH.exists():
    print(f"Загрузка списка файлов из кэша: {CACHE_PATH}...")
    # NOTE: pickle executes arbitrary code on load; this is acceptable only
    # because the cache is a local file written by this same script.
    with open(CACHE_PATH, 'rb') as f:
        cache_data = pickle.load(f)
    all_paths = cache_data['image_paths']
    all_labels = cache_data['labels']
    if not all_paths:
        # An empty cache would otherwise fail below with a cryptic unpack error.
        raise ValueError(
            f"Кэш {CACHE_PATH} пуст: удалите его и запустите сканирование заново."
        )
    print(f"Готово! Моментально загружено {len(all_paths)} путей.")
else:
    print(f"Сканирование NFS директории {DATA_ROOT} (Выполняется один раз)...")
    all_paths, all_labels = [], []
    for img_path in DATA_ROOT.rglob('*.jpg'):
        # The emotion name is taken from the directory two levels above the file.
        emotion_folder = img_path.parts[-3].lower()
        if emotion_folder in EMO_MAP:
            all_paths.append(str(img_path))
            all_labels.append(EMO_MAP[emotion_folder])

    if not all_paths:
        # Do not persist an empty scan (e.g. dataset not mounted): every later
        # run would load the empty cache and fail without ever rescanning.
        raise ValueError(
            f"Не найдено ни одного изображения в {DATA_ROOT}: кэш не записан."
        )

    with open(CACHE_PATH, 'wb') as f:
        pickle.dump({'image_paths': all_paths, 'labels': all_labels}, f)
    print(f"Сохранено в кэш: {len(all_paths)} изображений.")

# Train / validation split (95% / 5%)
random.seed(42)  # fixed seed so the split is identical across restarts
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
all_paths, all_labels = zip(*combined)

split_idx = int(len(all_paths) * 0.95)
train_paths, train_labels = all_paths[:split_idx], all_labels[:split_idx]
val_paths, val_labels = all_paths[split_idx:], all_labels[split_idx:]
print(f"Трейн: {len(train_paths)} | Валидация: {len(val_paths)}")
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 3. DATASET & DATALOADER
|
|
||||||
# ==========================================
|
|
||||||
class EmoSetDirectDataset(Dataset):
    """Dataset over parallel sequences of image file paths and labels.

    ``__getitem__`` returns ``(image, label)``. The image is read as RGB,
    scaled to float32 by dividing by 255 and resized to 256x256. If anything
    in that step raises, an all-zero 3x256x256 float32 tensor is returned
    instead, paired with the sample's normal label.
    """

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        # Only the basic resize happens inside the dataset itself.
        self.base_transform = T.Resize((256, 256), antialias=True)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            raw = tv_io.read_image(self.image_paths[idx], mode=tv_io.ImageReadMode.RGB)
            image = self.base_transform(raw.to(torch.float32) / 255.0)
        except Exception:
            # Best-effort fallback for unreadable / corrupt files.
            image = torch.zeros((3, 256, 256), dtype=torch.float32)
        label = self.labels[idx]
        return image, label
|
|
||||||
|
|
||||||
# --- DATA LOADERS (WITH PREFETCH) ---
# Settings shared by both loaders; only the dataset, shuffling and worker
# count differ between training and validation.
_shared_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
)

train_loader = DataLoader(
    EmoSetDirectDataset(train_paths, train_labels),
    shuffle=True,
    num_workers=NUM_TRAIN_WORKERS,
    **_shared_loader_kwargs,
)

val_loader = DataLoader(
    EmoSetDirectDataset(val_paths, val_labels),
    shuffle=False,
    num_workers=NUM_VAL_WORKERS,
    **_shared_loader_kwargs,
)
|
|
||||||
|
|
||||||
# ==========================================
# 4. AUGMENTATIONS APPLIED ON DEVICE
# ==========================================
# Per-channel normalisation statistics shared by both pipelines.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training: random 224x224 crop, horizontal flip, mild colour jitter, normalise.
_train_stages = [
    T.RandomCrop((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    T.Normalize(list(_NORM_MEAN), list(_NORM_STD)),
]
gpu_train_transforms = torch.nn.Sequential(*_train_stages).to(DEVICE)

# Validation: deterministic centre crop followed by the same normalisation.
_val_stages = [
    T.CenterCrop((224, 224)),
    T.Normalize(list(_NORM_MEAN), list(_NORM_STD)),
]
gpu_val_transforms = torch.nn.Sequential(*_val_stages).to(DEVICE)
|
|
||||||
|
|
||||||
# ==========================================
# 5. MODEL INITIALISATION
# ==========================================
print("\nСоздание архитектуры ResNet-50...")

# Decide the weight source BEFORE building the model. The optimizer below
# keeps references to the parameters it is given, so the model must not be
# replaced after the optimizer exists (the previous code rebuilt `model` in
# the "no base weights" branch, leaving the optimizer attached to the
# discarded instance).
_resuming = RESUME_CHECKPOINT.exists()
_load_base_weights = (not _resuming) and PREVIOUS_WEIGHTS.exists()
_use_pretrained_init = (not _resuming) and (not _load_base_weights)

if not _resuming:
    print("Чекпоинт сессии не найден. Проверяем наличие базовых весов...")
    if _use_pretrained_init:
        print("ВНИМАНИЕ: Базовые веса не найдены. Обучение начнется с нуля (ImageNet).")

model = timm.create_model(
    'resnet50', pretrained=_use_pretrained_init, num_classes=8
).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = GradScaler()

# --- SEAMLESS SESSION RESTORE ---
if _resuming:
    print(f"ВОССТАНОВЛЕНИЕ СЕССИИ из: {RESUME_CHECKPOINT}")
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Older checkpoints may lack these two entries, hence the membership checks.
    if 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    if 'best_val_loss' in checkpoint:
        best_val_loss = checkpoint['best_val_loss']
    # The stored epoch is treated as finished; training continues with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"УСПЕХ: Продолжаем обучение с эпохи {start_epoch}")
elif _load_base_weights:
    print(f"📥 Загрузка базовых весов (от 118K): {PREVIOUS_WEIGHTS}")
    model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))
|
|
||||||
|
|
||||||
# ==========================================
# 6. MAIN TRAINING LOOP
# ==========================================
def _build_checkpoint_state(completed_epoch):
    """Snapshot the session state; the resume logic restarts at ``completed_epoch + 1``."""
    return {
        'epoch': completed_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'best_val_loss': best_val_loss,
    }


if start_epoch > EPOCHS:
    print(f"\nОбучение уже было завершено (цель: {EPOCHS} эпох).")
else:
    print(f"\nСтарт обучения. Целевое количество эпох: {EPOCHS}")

    # Last epoch whose training, validation and resume checkpoint all finished.
    # The interrupt handler stores THIS value, so a resumed run repeats the
    # epoch that was cut short instead of skipping it.
    last_completed_epoch = start_epoch - 1
    # Keeps `epoch` bound for the interrupt handler even if Ctrl+C arrives
    # before the loop assigns it.
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, EPOCHS + 1):

            # --- PHASE 1: TRAINING ---
            model.train()
            running_loss, correct, total = 0.0, 0, 0

            pbar = tqdm(train_loader, desc=f"Эпоха {epoch}/{EPOCHS} [Тренировка]")
            for inputs, labels in pbar:
                try:
                    # Move the batch to the device and augment it there
                    inputs = inputs.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    inputs = gpu_train_transforms(inputs)

                    optimizer.zero_grad()

                    # Mixed precision (AMP) forward pass
                    with autocast(device_type="cuda"):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    pbar.set_postfix({'loss': f"{loss.item():.4f}", 'acc': f"{correct/total:.4f}"})

                except RuntimeError as e:
                    # Out-of-memory handler: drop this batch and carry on
                    if "out of memory" in str(e).lower():
                        print(f"\nВНИМАНИЕ: Нехватка VRAM! Очистка...")
                        if 'outputs' in locals(): del outputs
                        if 'loss' in locals(): del loss
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    # Anything else is a real error; re-raise with its traceback
                    raise

            train_loss = running_loss / total if total > 0 else 0
            train_acc = correct / total if total > 0 else 0

            # Free memory before validation
            gc.collect()
            torch.cuda.empty_cache()

            # --- PHASE 2: VALIDATION ---
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0

            with torch.no_grad():
                for val_inputs, val_labels in tqdm(val_loader, desc=f"Эпоха {epoch}/{EPOCHS} [Валидация]", leave=False):
                    val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)
                    val_inputs = gpu_val_transforms(val_inputs)

                    with autocast(device_type="cuda"):
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)

                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = val_outputs.max(1)
                    val_total += val_labels.size(0)
                    val_correct += val_predicted.eq(val_labels).sum().item()

            epoch_val_loss = val_loss / val_total if val_total > 0 else 0
            epoch_val_acc = val_correct / val_total if val_total > 0 else 0

            scheduler.step()
            print(f"🏁 Эпоха {epoch} | Train Loss: {train_loss:.4f} (Acc: {train_acc:.4f}) | Val Loss: {epoch_val_loss:.4f} (Acc: {epoch_val_acc:.4f})")

            # --- PHASE 3: EARLY STOPPING AND SAVING ---
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), "../emoset_resnet50_best_2_41M.pth")
                print("Новая лучшая модель найдена! Веса сохранены.")
            else:
                epochs_no_improve += 1
                print(f"Валидация не улучшается {epochs_no_improve}/{PATIENCE} эпох.")
                # The model gets at least 15 epochs before early stopping applies.
                # NOTE(review): with EPOCHS = 15 this can only trigger on the
                # final epoch - confirm that is intended.
                if epochs_no_improve >= PATIENCE and epoch >= 15:
                    print("\nСРАБОТАЛА ЗАЩИТА ОТ ПЕРЕОБУЧЕНИЯ (Early Stopping)!")
                    break

            # Save the full session state for this finished epoch
            torch.save(_build_checkpoint_state(epoch), RESUME_CHECKPOINT)
            last_completed_epoch = epoch

            # Keep this epoch's weights as a backup
            torch.save(model.state_dict(), f"../emoset_resnet50_finetuned_ep{epoch}.pth")

            gc.collect()
            torch.cuda.empty_cache()

    # ==========================================
    # 7. MANUAL STOP HANDLING (CTRL+C)
    # ==========================================
    except KeyboardInterrupt:
        print("\n\nОБУЧЕНИЕ ПРЕРВАНО ВРУЧНУЮ (KeyboardInterrupt)!")
        print(f"Экстренное сохранение состояния конвейера на эпохе {epoch}...")

        # Record the last FINISHED epoch, not the interrupted one: resume
        # restarts at stored epoch + 1, which is then the interrupted epoch.
        torch.save(_build_checkpoint_state(last_completed_epoch), RESUME_CHECKPOINT)

        interrupted_weights_path = f"../emoset_resnet50_interrupted_ep{epoch}.pth"
        torch.save(model.state_dict(), interrupted_weights_path)
        print(f"Прогресс безопасно зафиксирован в файле {interrupted_weights_path}. Выходим.")

    # ==========================================
    # 8. FINAL SAVE
    # ==========================================
    else:
        # Create the target directory instead of silently skipping the save
        # when it is missing.
        SAVE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(f"\nОБУЧЕНИЕ УСПЕШНО ЗАВЕРШЕНО! Финальная модель: {SAVE_MODEL_PATH}")
        if RESUME_CHECKPOINT.exists():
            RESUME_CHECKPOINT.unlink()  # the resume file is no longer needed
|
|
||||||
@@ -1,467 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "71ef58af",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Используем устройство: cuda\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import torch.nn as nn\n",
|
|
||||||
"from torch.utils.data import Dataset, DataLoader\n",
|
|
||||||
"import torchvision.transforms as T\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"import timm\n",
|
|
||||||
"\n",
|
|
||||||
"# Проверяем GPU\n",
|
|
||||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
|
||||||
"print(f\"Используем устройство: {DEVICE}\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "f4ae931c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# === НАСТРОЙКИ ДООБУЧЕНИЯ ===\n",
|
|
||||||
"\n",
|
|
||||||
"# Абсолютный путь к смонтированному NFS\n",
|
|
||||||
"DATA_ROOT = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Пути относительно src/scripts/\n",
|
|
||||||
"PREVIOUS_WEIGHTS = Path(\"../emoset_resnet50_best.pth\")\n",
|
|
||||||
"SAVE_MODEL_PATH = Path(\"../emoset_resnet50_finetuned_2.41M.pth\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Маппинг эмоций в те же индексы (0-7), которые использовались при первоначальном обучении\n",
|
|
||||||
"EMO_MAP = {\n",
|
|
||||||
" \"amusement\": 0,\n",
|
|
||||||
" \"anger\": 1,\n",
|
|
||||||
" \"awe\": 2,\n",
|
|
||||||
" \"contentment\": 3,\n",
|
|
||||||
" \"disgust\": 4,\n",
|
|
||||||
" \"excitement\": 5,\n",
|
|
||||||
" \"fear\": 6,\n",
|
|
||||||
" \"sad\": 7, # В твоем сообщении папка называется \"sad\"\n",
|
|
||||||
" \"sadness\": 7 # На всякий случай оставляем и классическое название\n",
|
|
||||||
"}\n",
|
|
||||||
"\n",
|
|
||||||
"BATCH_SIZE = 96\n",
|
|
||||||
"EPOCHS = 15\n",
|
|
||||||
"LR = 5e-5\n",
|
|
||||||
"NUM_WORKERS = 42"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "934cfe2c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import pickle\n",
|
|
||||||
"import os\n",
|
|
||||||
"import torchvision.io as tv_io\n",
|
|
||||||
"\n",
|
|
||||||
"# Твои трансформации (только Resize, как мы сделали ранее)\n",
|
|
||||||
"train_transforms = T.Compose([\n",
|
|
||||||
" T.Resize((256, 256), antialias=True)\n",
|
|
||||||
"])\n",
|
|
||||||
"\n",
|
|
||||||
"class EmoSetNestedDataset(Dataset):\n",
|
|
||||||
" def __init__(self, root_dir, transform=None, cache_file=\"../dataset_cache_2.41M.pkl\"):\n",
|
|
||||||
" self.root_dir = Path(root_dir)\n",
|
|
||||||
" self.transform = transform\n",
|
|
||||||
" self.image_paths = []\n",
|
|
||||||
" self.labels = []\n",
|
|
||||||
" \n",
|
|
||||||
" # === ЛОГИКА КЭШИРОВАНИЯ ===\n",
|
|
||||||
" if os.path.exists(cache_file):\n",
|
|
||||||
" print(f\"📦 Загрузка списка файлов из локального кэша: {cache_file}...\")\n",
|
|
||||||
" with open(cache_file, 'rb') as f:\n",
|
|
||||||
" cache_data = pickle.load(f)\n",
|
|
||||||
" self.image_paths = cache_data['image_paths']\n",
|
|
||||||
" self.labels = cache_data['labels']\n",
|
|
||||||
" print(f\"⚡ Готово! Моментально загружено {len(self.image_paths)} путей.\")\n",
|
|
||||||
" else:\n",
|
|
||||||
" print(f\"🔍 Сканирование NFS директории {self.root_dir}...\")\n",
|
|
||||||
" print(\"Это займет около 8-10 минут. Выполняется один раз...\")\n",
|
|
||||||
" \n",
|
|
||||||
" for img_path in self.root_dir.rglob('*.jpg'):\n",
|
|
||||||
" emotion_folder = img_path.parts[-3].lower()\n",
|
|
||||||
" if emotion_folder in EMO_MAP:\n",
|
|
||||||
" self.image_paths.append(str(img_path))\n",
|
|
||||||
" self.labels.append(EMO_MAP[emotion_folder])\n",
|
|
||||||
" \n",
|
|
||||||
" print(f\"💾 Сохранение результатов сканирования в кэш: {cache_file}...\")\n",
|
|
||||||
" with open(cache_file, 'wb') as f:\n",
|
|
||||||
" pickle.dump({'image_paths': self.image_paths, 'labels': self.labels}, f)\n",
|
|
||||||
" \n",
|
|
||||||
" print(f\"✅ Успешно найдено и закэшировано {len(self.image_paths)} изображений.\")\n",
|
|
||||||
"\n",
|
|
||||||
" def __len__(self):\n",
|
|
||||||
" return len(self.image_paths)\n",
|
|
||||||
"\n",
|
|
||||||
" def __getitem__(self, idx):\n",
|
|
||||||
" img_path = self.image_paths[idx]\n",
|
|
||||||
" label = self.labels[idx]\n",
|
|
||||||
" \n",
|
|
||||||
" try:\n",
|
|
||||||
" image = tv_io.read_image(str(img_path), mode=tv_io.ImageReadMode.RGB)\n",
|
|
||||||
" image = image.to(torch.float32) / 255.0\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" image = torch.zeros((3, 256, 256), dtype=torch.float32)\n",
|
|
||||||
" \n",
|
|
||||||
" if self.transform:\n",
|
|
||||||
" image = self.transform(image)\n",
|
|
||||||
" \n",
|
|
||||||
" return image, label"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "b10adc06",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"📦 Загрузка списка файлов из локального кэша: ../dataset_paths_cache.pkl...\n",
|
|
||||||
"⚡ Готово! Моментально загружено 2048377 путей.\n",
|
|
||||||
"Батчей за одну эпоху: 21338\n",
|
|
||||||
"\n",
|
|
||||||
"Создание архитектуры ResNet50...\n",
|
|
||||||
"📝 Чекпоинт прерванной сессии не найден. Проверяем базовые веса...\n",
|
|
||||||
"УСПЕХ: Найдены предыдущие веса '../emoset_resnet50_best.pth' (из EmoSet-118K). Загружаем...\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Путь к кэш-файлу (лучше положить его в src, рядом со скриптами, чтобы быстро читался)\n",
|
|
||||||
"CACHE_PATH = Path(\"../dataset_paths_cache.pkl\")\n",
|
|
||||||
"\n",
|
|
||||||
"# 1. Загрузка данных напрямую из папок (или из кэша!)\n",
|
|
||||||
"train_dataset = EmoSetNestedDataset(DATA_ROOT, transform=train_transforms, cache_file=CACHE_PATH)\n",
|
|
||||||
"\n",
|
|
||||||
"train_loader = DataLoader(\n",
|
|
||||||
" train_dataset, \n",
|
|
||||||
" batch_size=BATCH_SIZE, \n",
|
|
||||||
" shuffle=True, \n",
|
|
||||||
" num_workers=NUM_WORKERS,\n",
|
|
||||||
" pin_memory=True,\n",
|
|
||||||
" prefetch_factor=1,\n",
|
|
||||||
" persistent_workers=True\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"Батчей за одну эпоху: {len(train_loader)}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Путь к файлу автоматического восстановления (чекпоинт полной сессии)\n",
|
|
||||||
"RESUME_CHECKPOINT = Path(\"../emoset_resnet50_checkpoint_latest.pth\")\n",
|
|
||||||
"\n",
|
|
||||||
"# 2. Инициализация архитектуры модели ResNet-50\n",
|
|
||||||
"print(\"\\nСоздание архитектуры ResNet50...\")\n",
|
|
||||||
"model = timm.create_model('resnet50', pretrained=False, num_classes=8)\n",
|
|
||||||
"model = model.to(DEVICE)\n",
|
|
||||||
"\n",
|
|
||||||
"# Инициализируем компоненты оптимизации\n",
|
|
||||||
"criterion = nn.CrossEntropyLoss()\n",
|
|
||||||
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)\n",
|
|
||||||
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)\n",
|
|
||||||
"\n",
|
|
||||||
"# По умолчанию начинаем с 1-й эпохи\n",
|
|
||||||
"start_epoch = 1\n",
|
|
||||||
"\n",
|
|
||||||
"# === ЛОГИКА ВОССТАНОВЛЕНИЯ / СТАРТА ===\n",
|
|
||||||
"if RESUME_CHECKPOINT.exists():\n",
|
|
||||||
" print(f\"🔄 ОБНАРУЖЕН ДЕЙСТВУЮЩИЙ ЧЕКПОИНТ: '{RESUME_CHECKPOINT}'\")\n",
|
|
||||||
" print(\"Восстанавливаем полное состояние сессии...\")\n",
|
|
||||||
" \n",
|
|
||||||
" # Загружаем сохраненный словарь состояния\n",
|
|
||||||
" checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)\n",
|
|
||||||
" \n",
|
|
||||||
" # Восстанавливаем всё до единого\n",
|
|
||||||
" model.load_state_dict(checkpoint['model_state_dict'])\n",
|
|
||||||
" optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
|
|
||||||
" scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
|
|
||||||
" start_epoch = checkpoint['epoch'] + 1 # Начинаем со следующей эпохи\n",
|
|
||||||
" \n",
|
|
||||||
" print(f\"🚀 УСПЕХ: Сессия восстановлена! Продолжаем обучение с эпохи {start_epoch}\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" print(\"📝 Чекпоинт прерванной сессии не найден. Проверяем базовые веса...\")\n",
|
|
||||||
" if PREVIOUS_WEIGHTS.exists():\n",
|
|
||||||
" print(f\"УСПЕХ: Найдены предыдущие веса '{PREVIOUS_WEIGHTS}' (из EmoSet-118K). Загружаем...\")\n",
|
|
||||||
" model.load_state_dict(torch.load(PREVIOUS_WEIGHTS, map_location=DEVICE))\n",
|
|
||||||
" else:\n",
|
|
||||||
" print(\"ВНИМАНИЕ: Базовые веса не найдены. Начинаем обучение с нуля (ImageNet pretrained).\")\n",
|
|
||||||
" model = timm.create_model('resnet50', pretrained=True, num_classes=8).to(DEVICE)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "a7480834",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"⏰ Старт обучения. Целевое количество эпох: 15\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Эпоха 1/15 [Тренировка]: 0%| | 0/21338 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Эпоха 1/15 [Тренировка]: 13%|█▎ | 2705/21338 [32:27<2:39:51, 1.94it/s, loss=1.6873, acc=0.3036] Corrupt JPEG data: 80 extraneous bytes before marker 0xd9\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 17%|█▋ | 3623/21338 [43:15<4:46:57, 1.03it/s, loss=1.7141, acc=0.3102] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 18%|█▊ | 3741/21338 [44:35<2:23:26, 2.04it/s, loss=1.7848, acc=0.3109] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 19%|█▉ | 4109/21338 [49:12<2:17:42, 2.09it/s, loss=1.8072, acc=0.3133] Corrupt JPEG data: 485 extraneous bytes before marker 0xd9\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 22%|██▏ | 4729/21338 [56:22<3:06:07, 1.49it/s, loss=1.6499, acc=0.3173] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 38%|███▊ | 8033/21338 [1:35:20<2:21:51, 1.56it/s, loss=1.5897, acc=0.3338] Corrupt JPEG data: 41 extraneous bytes before marker 0xd9\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 45%|████▌ | 9684/21338 [1:54:34<3:40:06, 1.13s/it, loss=1.6740, acc=0.3399] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 50%|█████ | 10679/21338 [2:06:15<47:11, 3.76it/s, loss=1.6234, acc=0.3431] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 55%|█████▌ | 11802/21338 [2:19:39<1:50:22, 1.44it/s, loss=1.5677, acc=0.3463] Unknown Adobe color transform code 2\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 67%|██████▋ | 14253/21338 [2:48:12<27:27, 4.30it/s, loss=1.7579, acc=0.3525] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 77%|███████▋ | 16377/21338 [3:13:17<1:06:23, 1.25it/s, loss=1.6855, acc=0.3572] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 92%|█████████▏| 19575/21338 [3:51:14<11:13, 2.62it/s, loss=1.5876, acc=0.3631] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 92%|█████████▏| 19679/21338 [3:52:26<07:42, 3.59it/s, loss=1.7134, acc=0.3633] Invalid SOS parameters for sequential JPEG\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 100%|█████████▉| 21283/21338 [4:11:18<00:20, 2.70it/s, loss=1.4613, acc=0.3658] Unknown Adobe color transform code 2\n",
|
|
||||||
"Эпоха 1/15 [Тренировка]: 100%|██████████| 21338/21338 [4:11:47<00:00, 1.41it/s, loss=1.6304, acc=0.3659]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "NameError",
|
|
||||||
"evalue": "name 'val_loader' is not defined",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
||||||
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 88\u001b[39m\n\u001b[32m 85\u001b[39m \u001b[38;5;66;03m# ВАЖНО: Если у тебя нет val_loader, создай его (откуси 5-10% от датасета)\u001b[39;00m\n\u001b[32m 86\u001b[39m \u001b[38;5;66;03m# На валидации мы НЕ применяем gpu_transforms (только нормализацию)\u001b[39;00m\n\u001b[32m 87\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m---> \u001b[39m\u001b[32m88\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m val_inputs, val_labels \u001b[38;5;129;01min\u001b[39;00m \u001b[43mval_loader\u001b[49m:\n\u001b[32m 89\u001b[39m val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)\n\u001b[32m 91\u001b[39m \u001b[38;5;66;03m# Валидация тоже идет в смешанной точности для скорости\u001b[39;00m\n",
|
|
||||||
"\u001b[31mNameError\u001b[39m: name 'val_loader' is not defined"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import gc\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torch.amp import autocast, GradScaler\n",
|
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"import torchvision.transforms as T\n",
|
|
||||||
"\n",
|
|
||||||
"# --- НАСТРОЙКИ EARLY STOPPING ---\n",
|
|
||||||
"PATIENCE = 4 # Сколько эпох ждем, если валидация не улучшается\n",
|
|
||||||
"best_val_loss = float('inf')\n",
|
|
||||||
"epochs_no_improve = 0\n",
|
|
||||||
"start_epoch = 1\n",
|
|
||||||
"\n",
|
|
||||||
"# --- ПЕРЕНОСИМ ТЯЖЕЛЫЕ АУГМЕНТАЦИИ НА GPU ---\n",
|
|
||||||
"gpu_transforms = torch.nn.Sequential(\n",
|
|
||||||
" T.RandomCrop((224, 224)),\n",
|
|
||||||
" T.RandomHorizontalFlip(p=0.5),\n",
|
|
||||||
" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),\n",
|
|
||||||
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
|
|
||||||
").to(DEVICE)\n",
|
|
||||||
"# -------------------------------------------\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"⏰ Старт обучения. Целевое количество эпох: {EPOCHS}\")\n",
|
|
||||||
"scaler = GradScaler()\n",
|
|
||||||
"\n",
|
|
||||||
"try:\n",
|
|
||||||
" for epoch in range(start_epoch, EPOCHS + 1):\n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" # 1. ТРЕНИРОВКА (TRAIN)\n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" model.train()\n",
|
|
||||||
" running_loss = 0.0\n",
|
|
||||||
" correct = 0\n",
|
|
||||||
" total = 0\n",
|
|
||||||
" \n",
|
|
||||||
" pbar = tqdm(train_loader, desc=f\"Эпоха {epoch}/{EPOCHS} [Тренировка]\")\n",
|
|
||||||
" for inputs, labels in pbar:\n",
|
|
||||||
" try:\n",
|
|
||||||
" inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)\n",
|
|
||||||
" \n",
|
|
||||||
" # Применяем аугментации на лету к батчу (на GPU)\n",
|
|
||||||
" inputs = gpu_transforms(inputs)\n",
|
|
||||||
" \n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" \n",
|
|
||||||
" with autocast(device_type=\"cuda\"): # Для PyTorch >= 2.0 лучше указывать \"cuda\"\n",
|
|
||||||
" outputs = model(inputs)\n",
|
|
||||||
" loss = criterion(outputs, labels)\n",
|
|
||||||
" \n",
|
|
||||||
" scaler.scale(loss).backward()\n",
|
|
||||||
" scaler.step(optimizer)\n",
|
|
||||||
" scaler.update()\n",
|
|
||||||
" \n",
|
|
||||||
" # Статистика\n",
|
|
||||||
" running_loss += loss.item() * inputs.size(0)\n",
|
|
||||||
" _, predicted = outputs.max(1)\n",
|
|
||||||
" total += labels.size(0)\n",
|
|
||||||
" correct += predicted.eq(labels).sum().item()\n",
|
|
||||||
" \n",
|
|
||||||
" pbar.set_postfix({'loss': f\"{loss.item():.4f}\", 'acc': f\"{correct/total:.4f}\"})\n",
|
|
||||||
" \n",
|
|
||||||
" except RuntimeError as e:\n",
|
|
||||||
" # Обработчик OOM\n",
|
|
||||||
" if \"out of memory\" in str(e).lower():\n",
|
|
||||||
" print(f\"\\n⚠️ ВНИМАНИЕ: Нехватка VRAM на батче! Выполняем экстренную очистку...\")\n",
|
|
||||||
" if 'outputs' in locals(): del outputs\n",
|
|
||||||
" if 'loss' in locals(): del loss\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" print(\"♻️ Кэш очищен. Батч пропущен. Продолжаем обучение...\")\n",
|
|
||||||
" continue\n",
|
|
||||||
" else:\n",
|
|
||||||
" raise e\n",
|
|
||||||
" \n",
|
|
||||||
" epoch_loss = running_loss / total if total > 0 else 0\n",
|
|
||||||
" epoch_acc = correct / total if total > 0 else 0\n",
|
|
||||||
" \n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" # 2. ВАЛИДАЦИЯ И EARLY STOPPING\n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" model.eval()\n",
|
|
||||||
" val_loss = 0.0\n",
|
|
||||||
" val_correct = 0\n",
|
|
||||||
" val_total = 0\n",
|
|
||||||
" \n",
|
|
||||||
" # ВАЖНО: Если у тебя нет val_loader, создай его (откуси 5-10% от датасета)\n",
|
|
||||||
" # На валидации мы НЕ применяем gpu_transforms (только нормализацию)\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" for val_inputs, val_labels in val_loader:\n",
|
|
||||||
" val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)\n",
|
|
||||||
" \n",
|
|
||||||
" # Валидация тоже идет в смешанной точности для скорости\n",
|
|
||||||
" with autocast(device_type=\"cuda\"):\n",
|
|
||||||
" val_outputs = model(val_inputs)\n",
|
|
||||||
" v_loss = criterion(val_outputs, val_labels)\n",
|
|
||||||
" \n",
|
|
||||||
" val_loss += v_loss.item() * val_inputs.size(0)\n",
|
|
||||||
" _, val_predicted = val_outputs.max(1)\n",
|
|
||||||
" val_total += val_labels.size(0)\n",
|
|
||||||
" val_correct += val_predicted.eq(val_labels).sum().item()\n",
|
|
||||||
" \n",
|
|
||||||
" epoch_val_loss = val_loss / val_total if val_total > 0 else 0\n",
|
|
||||||
" epoch_val_acc = val_correct / val_total if val_total > 0 else 0\n",
|
|
||||||
" \n",
|
|
||||||
" scheduler.step()\n",
|
|
||||||
" print(f\"🏁 Эпоха {epoch} завершена | Train Loss: {epoch_loss:.4f} (Acc: {epoch_acc:.4f}) | Val Loss: {epoch_val_loss:.4f} (Acc: {epoch_val_acc:.4f})\")\n",
|
|
||||||
" \n",
|
|
||||||
" # --- ЛОГИКА РАННЕЙ ОСТАНОВКИ ---\n",
|
|
||||||
" if epoch_val_loss < best_val_loss:\n",
|
|
||||||
" best_val_loss = epoch_val_loss\n",
|
|
||||||
" epochs_no_improve = 0\n",
|
|
||||||
" # Сохраняем идеальные веса, пока сеть не переобучилась\n",
|
|
||||||
" torch.save(model.state_dict(), \"../emoset_resnet50_best.pth\")\n",
|
|
||||||
" print(\"🌟 Найдена лучшая модель! Веса сохранены.\")\n",
|
|
||||||
" else:\n",
|
|
||||||
" epochs_no_improve += 1\n",
|
|
||||||
" print(f\"⚠️ Валидационный Loss не улучшается {epochs_no_improve}/{PATIENCE} эпох.\")\n",
|
|
||||||
" \n",
|
|
||||||
" # Условие: если переобучение длится долго И мы прошли хотя бы 15 эпох\n",
|
|
||||||
" if epochs_no_improve >= PATIENCE and epoch >= 15:\n",
|
|
||||||
" print(\"\\n🛑 СРАБОТАЛА ЗАЩИТА ОТ ПЕРЕОБУЧЕНИЯ (Early Stopping)!\")\n",
|
|
||||||
" print(\"Модель начала запоминать данные вместо обобщения. Обучение досрочно завершено.\")\n",
|
|
||||||
" break # Прерываем цикл эпох\n",
|
|
||||||
" \n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" # 3. РЕГУЛЯРНОЕ СОХРАНЕНИЕ ПРОГРЕССА\n",
|
|
||||||
" # ==========================================\n",
|
|
||||||
" checkpoint_state = {\n",
|
|
||||||
" 'epoch': epoch,\n",
|
|
||||||
" 'model_state_dict': model.state_dict(),\n",
|
|
||||||
" 'optimizer_state_dict': optimizer.state_dict(),\n",
|
|
||||||
" 'scheduler_state_dict': scheduler.state_dict(),\n",
|
|
||||||
" 'scaler_state_dict': scaler.state_dict(),\n",
|
|
||||||
" 'loss': epoch_loss,\n",
|
|
||||||
" 'val_loss': epoch_val_loss\n",
|
|
||||||
" }\n",
|
|
||||||
" torch.save(checkpoint_state, RESUME_CHECKPOINT)\n",
|
|
||||||
" \n",
|
|
||||||
" # Сохранение весов конкретной эпохи (на всякий случай)\n",
|
|
||||||
" epoch_weights_path = f\"../emoset_resnet50_finetuned_ep{epoch}.pth\"\n",
|
|
||||||
" torch.save(model.state_dict(), epoch_weights_path)\n",
|
|
||||||
" \n",
|
|
||||||
" gc.collect()\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
"\n",
|
|
||||||
"# ==========================================\n",
|
|
||||||
"# 4. БЕЗОПАСНЫЙ ВЫХОД ПРИ РУЧНОМ ПРЕРЫВАНИИ\n",
|
|
||||||
"# ==========================================\n",
|
|
||||||
"except KeyboardInterrupt:\n",
|
|
||||||
" print(\"\\n\\n🛑 ОБУЧЕНИЕ ПРЕРВАНО ВРУЧНУЮ (KeyboardInterrupt)!\")\n",
|
|
||||||
" print(f\"💾 Экстренное сохранение состояния конвейера на эпохе {epoch}...\")\n",
|
|
||||||
" \n",
|
|
||||||
" # Сохраняем всё, чтобы потом можно было продолжить с этого же места\n",
|
|
||||||
" checkpoint_state = {\n",
|
|
||||||
" 'epoch': epoch,\n",
|
|
||||||
" 'model_state_dict': model.state_dict(),\n",
|
|
||||||
" 'optimizer_state_dict': optimizer.state_dict(),\n",
|
|
||||||
" 'scheduler_state_dict': scheduler.state_dict(),\n",
|
|
||||||
" 'scaler_state_dict': scaler.state_dict()\n",
|
|
||||||
" }\n",
|
|
||||||
" torch.save(checkpoint_state, RESUME_CHECKPOINT)\n",
|
|
||||||
" \n",
|
|
||||||
" # Сохраняем промежуточные веса на момент остановки\n",
|
|
||||||
" interrupted_weights_path = f\"../emoset_resnet50_interrupted_ep{epoch}.pth\"\n",
|
|
||||||
" torch.save(model.state_dict(), interrupted_weights_path)\n",
|
|
||||||
" \n",
|
|
||||||
" print(f\"✅ Прогресс безопасно зафиксирован в файле {interrupted_weights_path}. Выходим.\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Финальное сохранение (если цикл дошел до конца сам)\n",
|
|
||||||
"else:\n",
|
|
||||||
" if SAVE_MODEL_PATH.parent.exists():\n",
|
|
||||||
" torch.save(model.state_dict(), SAVE_MODEL_PATH)\n",
|
|
||||||
" print(f\"\\n🎉 ОБУЧЕНИЕ УСПЕШНО ЗАВЕРШЕНО! Финальная модель: {SAVE_MODEL_PATH}\")\n",
|
|
||||||
" if RESUME_CHECKPOINT.exists():\n",
|
|
||||||
" RESUME_CHECKPOINT.unlink()"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python (thesis)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "thesis"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "b92e0213",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"from pathlib import Path"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "1763c51e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"✅ УСПЕХ! База создана: ../../dataset/DEAM/music_db.csv\n",
|
|
||||||
"Всего треков в базе: 1744\n",
|
|
||||||
"Пример данных:\n",
|
|
||||||
" song_id valence arousal\n",
|
|
||||||
"0 2 3.1 3.0\n",
|
|
||||||
"1 3 3.5 3.3\n",
|
|
||||||
"2 4 5.7 5.5\n",
|
|
||||||
"3 5 4.4 5.3\n",
|
|
||||||
"4 7 5.8 6.4\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Точный путь к оригинальным аннотациям\n",
|
|
||||||
"source_path = Path(\"../../dataset/DEAM/DEAM_Annotations/annotations/annotations averaged per song/song_level/static_annotations_averaged_songs_1_2000.csv\")\n",
|
|
||||||
"# Путь, куда сохраним очищенную базу для движка\n",
|
|
||||||
"output_path = Path(\"../../dataset/DEAM/music_db.csv\")\n",
|
|
||||||
"\n",
|
|
||||||
"if not source_path.exists():\n",
|
|
||||||
" print(f\"Исходный файл не найден по пути: {source_path}\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" # skipinitialspace=True уберет лишние пробелы в названиях колонок, если они есть\n",
|
|
||||||
" df = pd.read_csv(source_path, skipinitialspace=True)\n",
|
|
||||||
" \n",
|
|
||||||
" # Берем только нужные колонки (по твоему примеру)\n",
|
|
||||||
" clean_df = df[['song_id', 'valence_mean', 'arousal_mean']].copy()\n",
|
|
||||||
" \n",
|
|
||||||
" # Переименовываем для простоты кода в движке\n",
|
|
||||||
" clean_df.columns = ['song_id', 'valence', 'arousal']\n",
|
|
||||||
" \n",
|
|
||||||
" # Приводим ID к целому числу (2, 3, 4...), чтобы искать файлы '2.mp3'\n",
|
|
||||||
" clean_df['song_id'] = clean_df['song_id'].astype(int)\n",
|
|
||||||
" \n",
|
|
||||||
" # Сохраняем финальный файл\n",
|
|
||||||
" clean_df.to_csv(output_path, index=False)\n",
|
|
||||||
" \n",
|
|
||||||
" print(f\"Успех! База создана: {output_path}\")\n",
|
|
||||||
" print(f\"Всего треков в базе: {len(clean_df)}\")\n",
|
|
||||||
" print(\"Пример данных:\")\n",
|
|
||||||
" print(clean_df.head())"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python (thesis)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "thesis"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "d70d8e32",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from concurrent.futures import ProcessPoolExecutor\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torchvision import transforms\n",
|
|
||||||
"from tqdm import tqdm"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "31b0fa82",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"DATA_ROOT = Path(\"../dataset/EmoSet-118K\")\n",
|
|
||||||
"TRANSFORM = transforms.Compose([\n",
|
|
||||||
" transforms.Resize((224,224)),\n",
|
|
||||||
" transforms.ToTensor(),\n",
|
|
||||||
" transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
|
|
||||||
"])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "1a17ecf5",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/94481 [00:00<?, ?it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "PicklingError",
|
|
||||||
"evalue": "Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
||||||
"\u001b[31m_RemoteTraceback\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[31m_RemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 244, in _feed\n obj = _ForkingPickler.dumps(obj)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py\", line 51, in dumps\n cls(buf, protocol).dump(obj)\n_pickle.PicklingError: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed\n\"\"\"",
|
|
||||||
"\nThe above exception was the direct cause of the following exception:\n",
|
|
||||||
"\u001b[31mPicklingError\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 18\u001b[39m futures = [executor.submit(process_row, row, split_dir, tensor_dir) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m df.itertuples()]\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tqdm(futures):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m results.append(\u001b[43mf\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 22\u001b[39m new_df = pd.DataFrame(results)\n\u001b[32m 23\u001b[39m new_df.to_csv(DATA_ROOT / split / \u001b[33m\"\u001b[39m\u001b[33mlabels_tensor.csv\u001b[39m\u001b[33m\"\u001b[39m, index=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py:244\u001b[39m, in \u001b[36mQueue._feed\u001b[39m\u001b[34m(buffer, notempty, send_bytes, writelock, reader_close, writer_close, ignore_epipe, onerror, queue_sem)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# serialize the data before acquiring the lock\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m obj = \u001b[43m_ForkingPickler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m wacquire \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 246\u001b[39m send_bytes(obj)\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/reduction.py:51\u001b[39m, in \u001b[36mForkingPickler.dumps\u001b[39m\u001b[34m(cls, obj, protocol)\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdumps\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, protocol=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 50\u001b[39m buf = io.BytesIO()\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbuf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m buf.getbuffer()\n",
|
|
||||||
"\u001b[31mPicklingError\u001b[39m: Can't pickle <class 'pandas.core.frame.Pandas'>: attribute lookup Pandas on pandas.core.frame failed"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"def process_row(row, split_dir, tensor_dir):\n",
|
|
||||||
" img_path = split_dir / row.filename\n",
|
|
||||||
" img = Image.open(img_path).convert(\"RGB\")\n",
|
|
||||||
" tensor = TRANSFORM(img)\n",
|
|
||||||
" tensor_path = tensor_dir / f\"{row.filename}.pt\"\n",
|
|
||||||
" torch.save(tensor, tensor_path)\n",
|
|
||||||
" return {\"tensor_path\": str(tensor_path), \"label\": row.label}\n",
|
|
||||||
"\n",
|
|
||||||
"for split in [\"train\",\"val\",\"test\"]:\n",
|
|
||||||
" split_dir = DATA_ROOT / split / \"images\"\n",
|
|
||||||
" tensor_dir = DATA_ROOT / split / \"tensors\"\n",
|
|
||||||
" tensor_dir.mkdir(exist_ok=True, parents=True)\n",
|
|
||||||
"\n",
|
|
||||||
" df = pd.read_csv(DATA_ROOT / split / \"labels.csv\")\n",
|
|
||||||
"\n",
|
|
||||||
" results = []\n",
|
|
||||||
" with ProcessPoolExecutor(max_workers=12) as executor:\n",
|
|
||||||
" futures = [executor.submit(process_row, row, split_dir, tensor_dir) for row in df.itertuples()]\n",
|
|
||||||
" for f in tqdm(futures):\n",
|
|
||||||
" results.append(f.result())\n",
|
|
||||||
"\n",
|
|
||||||
" new_df = pd.DataFrame(results)\n",
|
|
||||||
" new_df.to_csv(DATA_ROOT / split / \"labels_tensor.csv\", index=False)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "thesis-py3.11",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
from tqdm.notebook import tqdm
|
|
||||||
import webdataset as wds
|
|
||||||
|
|
||||||
# --- НАСТРОЙКИ ---
|
|
||||||
# Оригинальная папка с датасетом (на NFS)
|
|
||||||
DATA_ROOT = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M")
|
|
||||||
|
|
||||||
# Новая папка, куда мы сложим готовые .tar архивы (шарды)
|
|
||||||
# Лучше создать её рядом с оригинальным датасетом на NFS
|
|
||||||
SHARDS_DIR = Path("/home/zin/projects/Thesis/NFS/Thesis/Emoset/shards-2.41M")
|
|
||||||
SHARDS_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Маппинг классов
|
|
||||||
EMO_MAP = {
|
|
||||||
"amusement": 0, "anger": 1, "awe": 2, "contentment": 3,
|
|
||||||
"disgust": 4, "excitement": 5, "fear": 6, "sad": 7, "sadness": 7
|
|
||||||
}
|
|
||||||
|
|
||||||
# Размер одного архива. 10 000 картинок — идеальный баланс
|
|
||||||
MAX_SAMPLES_PER_SHARD = 10000
|
|
||||||
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
print(f"🔍 Сканирование директории {DATA_ROOT}...")
|
|
||||||
# Используем os.walk, он часто работает быстрее rglob на сетевых дисках
|
|
||||||
for root, dirs, files in os.walk(DATA_ROOT):
|
|
||||||
for file in files:
|
|
||||||
if file.lower().endswith('.jpg'):
|
|
||||||
full_path = os.path.join(root, file)
|
|
||||||
# Извлекаем эмоцию (зависит от структуры папок, берем предпоследнюю папку)
|
|
||||||
                # Path: .../amusement/0/image.jpg -> path_parts[-3] is 'amusement' (path_parts[-2] is '0', path_parts[-1] is the file name)
|
|
||||||
path_parts = Path(full_path).parts
|
|
||||||
emotion_folder = path_parts[-3].lower()
|
|
||||||
|
|
||||||
if emotion_folder in EMO_MAP:
|
|
||||||
samples.append((full_path, EMO_MAP[emotion_folder]))
|
|
||||||
|
|
||||||
print(f"✅ Найдено изображений: {len(samples)}")
|
|
||||||
|
|
||||||
# САМЫЙ ВАЖНЫЙ ШАГ: Глобальное перемешивание перед упаковкой
|
|
||||||
print("🔀 Перемешиваем датасет...")
|
|
||||||
random.shuffle(samples)
|
|
||||||
print("✅ Перемешивание завершено!")
|
|
||||||
|
|
||||||
import multiprocessing as mp
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
||||||
import webdataset as wds
|
|
||||||
from PIL import Image
|
|
||||||
import io
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# ВАЖНО: Импортируем базовый tqdm, а не notebook-версию.
|
|
||||||
# Notebook-версия в мультипроцессинге вызывает зависание графического интерфейса Jupyter.
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# --- PATHS AND SETTINGS --- (SHARDS_DIR below is a relative path, resolved against the current working directory)
|
|
||||||
SHARDS_DIR = Path("../../dataset/EmoSet-2.41M-shards")
|
|
||||||
SHARDS_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
NUM_WORKERS = 50
|
|
||||||
|
|
||||||
# 1. Split the sample list into chunks of at most MAX_SAMPLES_PER_SHARD items (one chunk per shard)
|
|
||||||
chunks = [samples[i:i + MAX_SAMPLES_PER_SHARD] for i in range(0, len(samples), MAX_SAMPLES_PER_SHARD)]
|
|
||||||
|
|
||||||
print(f"📦 Подготовлено {len(chunks)} задач (шардов).")
|
|
||||||
print(f"💾 Целевая папка (Локальный SSD): {SHARDS_DIR}")
|
|
||||||
print(f"🚀 Запуск упаковки и сжатия в {NUM_WORKERS} потоков...\n")
|
|
||||||
|
|
||||||
# Install a multiprocessing lock for tqdm so that bars drawn from several processes do not garble each other
|
|
||||||
tqdm.set_lock(mp.RLock())
|
|
||||||
|
|
||||||
# 2. The function executed by each worker process
def build_shard(args):
    """Pack one chunk of samples into a WebDataset tar shard.

    Every image is opened, converted to RGB, resized to 256x256 with bilinear
    resampling, re-encoded as JPEG (quality 85) and written to the shard
    together with its class label. An image that fails at any of these steps
    is skipped so that one bad file does not abort the whole shard; if any
    were skipped, a one-line summary is printed after the shard is closed.

    Args:
        args: ``(shard_idx, chunk)`` tuple, where ``chunk`` is a sequence of
            ``(img_path, label)`` pairs.

    Returns:
        The ``shard_idx`` that was passed in.
    """
    shard_idx, chunk = args
    shard_path = SHARDS_DIR / f"emoset-{shard_idx:06d}.tar"

    success_count = 0
    error_count = 0

    # Progress-bar row for this shard: shard index modulo NUM_WORKERS, plus
    # one (row 0 is used by the overall bar drawn in the parent process).
    # leave=False makes the bar disappear once the shard is finished.
    # NOTE(review): rows are keyed by shard index, not by worker process, so
    # two in-flight shards can map to the same row - confirm this is acceptable.
    worker_pos = (shard_idx % NUM_WORKERS) + 1

    with wds.TarWriter(str(shard_path)) as sink:
        shard_bar = tqdm(
            chunk,
            desc=f"Шард {shard_idx:03d}",
            position=worker_pos,
            leave=False,
        )
        for i, (img_path, label) in enumerate(shard_bar):
            try:
                # Normalise and shrink the image before storing it.
                with Image.open(img_path) as img:
                    img = img.convert("RGB")
                    img = img.resize((256, 256), Image.Resampling.BILINEAR)

                with io.BytesIO() as img_byte_arr:
                    img.save(img_byte_arr, format='JPEG', quality=85)
                    image_data = img_byte_arr.getvalue()

                # The key uses the position inside the chunk, so skipped
                # images leave gaps in the numbering.
                key = f"{shard_idx:06d}_{i:05d}"

                sink.write({
                    "__key__": key,
                    "jpg": image_data,
                    "cls": label
                })
                success_count += 1

            except Exception:
                # Deliberate best-effort: skip corrupt or unreadable files,
                # but keep count so the loss is visible afterwards.
                error_count += 1
                continue

    if error_count:
        # tqdm.write prints through tqdm so the message does not tear the bars.
        tqdm.write(
            f"⚠️ Шард {shard_idx:06d}: упаковано {success_count}, "
            f"пропущено {error_count}"
        )

    return shard_idx
|
|
||||||
|
|
||||||
# 3. Launch the worker pool
# (shard_idx, exception) for every shard whose worker raised or died.
failed_shards = []

with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
    tasks = [(i, chunk) for i, chunk in enumerate(chunks)]

    # Submit all tasks, remembering which shard each future belongs to so a
    # failure can be attributed to a concrete shard.
    shard_by_future = {executor.submit(build_shard, task): task[0] for task in tasks}
    futures = list(shard_by_future)

    # Overall progress bar (position=0, i.e. the first row on screen).
    for future in tqdm(as_completed(futures), total=len(tasks), desc="📊 ОБЩИЙ ПРОГРЕСС", position=0, leave=True):
        try:
            future.result()
        except Exception as exc:
            # Keep processing the remaining shards, but do not lose the error.
            failed_shards.append((shard_by_future[future], exc))

# Print blank lines so the final text does not overlap the progress bars
print("\n" * (NUM_WORKERS + 2))

if failed_shards:
    failed_shards.sort(key=lambda item: item[0])
    print(f"⚠️ УПАКОВКА ЗАВЕРШЕНА С ОШИБКАМИ: не собрано шардов: {len(failed_shards)} из {len(tasks)}")
    for shard_idx, exc in failed_shards:
        print(f"   шард {shard_idx:06d}: {exc!r}")
else:
    print("🎉 ПАРАЛЛЕЛЬНАЯ УПАКОВКА И СЖАТИЕ ПОЛНОСТЬЮ ЗАВЕРШЕНЫ!")
|
|
||||||
@@ -1,919 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "e6aa65e8",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import random\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"import webdataset as wds\n",
|
|
||||||
"\n",
|
|
||||||
"# --- НАСТРОЙКИ ---\n",
|
|
||||||
"# Оригинальная папка с датасетом (на NFS)\n",
|
|
||||||
"DATA_ROOT = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Новая папка, куда мы сложим готовые .tar архивы (шарды)\n",
|
|
||||||
"# Лучше создать её рядом с оригинальным датасетом на NFS\n",
|
|
||||||
"SHARDS_DIR = Path(\"/home/zin/projects/Thesis/NFS/Thesis/Emoset/shards-2.41M\")\n",
|
|
||||||
"SHARDS_DIR.mkdir(parents=True, exist_ok=True)\n",
|
|
||||||
"\n",
|
|
||||||
"# Маппинг классов\n",
|
|
||||||
"EMO_MAP = {\n",
|
|
||||||
" \"amusement\": 0, \"anger\": 1, \"awe\": 2, \"contentment\": 3,\n",
|
|
||||||
" \"disgust\": 4, \"excitement\": 5, \"fear\": 6, \"sad\": 7, \"sadness\": 7\n",
|
|
||||||
"}\n",
|
|
||||||
"\n",
|
|
||||||
"# Размер одного архива. 10 000 картинок — идеальный баланс\n",
|
|
||||||
"MAX_SAMPLES_PER_SHARD = 10000"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "09d0e56c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"🔍 Сканирование директории /home/zin/projects/Thesis/NFS/Thesis/Emoset/Original-2.41M...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "KeyboardInterrupt",
|
|
||||||
"evalue": "",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 8\u001b[39m full_path = os.path.join(root, file)\n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# Извлекаем эмоцию (зависит от структуры папок, берем предпоследнюю папку)\u001b[39;00m\n\u001b[32m 10\u001b[39m \u001b[38;5;66;03m# Путь: .../amusement/0/image.jpg -> root_parts[-2] будет 'amusement'\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m path_parts = \u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfull_path\u001b[49m\u001b[43m)\u001b[49m.parts\n\u001b[32m 12\u001b[39m emotion_folder = path_parts[-\u001b[32m3\u001b[39m].lower()\n\u001b[32m 14\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m emotion_folder \u001b[38;5;129;01min\u001b[39;00m EMO_MAP:\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:871\u001b[39m, in \u001b[36mPath.__new__\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 869\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m Path:\n\u001b[32m 870\u001b[39m \u001b[38;5;28mcls\u001b[39m = WindowsPath \u001b[38;5;28;01mif\u001b[39;00m os.name == \u001b[33m'\u001b[39m\u001b[33mnt\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m PosixPath\n\u001b[32m--> \u001b[39m\u001b[32m871\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_from_parts\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 872\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._flavour.is_supported:\n\u001b[32m 873\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mcannot instantiate \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m on your system\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 874\u001b[39m % (\u001b[38;5;28mcls\u001b[39m.\u001b[34m__name__\u001b[39m,))\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:509\u001b[39m, in \u001b[36mPurePath._from_parts\u001b[39m\u001b[34m(cls, args)\u001b[39m\n\u001b[32m 504\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 505\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_from_parts\u001b[39m(\u001b[38;5;28mcls\u001b[39m, args):\n\u001b[32m 506\u001b[39m \u001b[38;5;66;03m# We need to call _parse_args on the instance, so as to get the\u001b[39;00m\n\u001b[32m 507\u001b[39m \u001b[38;5;66;03m# right flavour.\u001b[39;00m\n\u001b[32m 508\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28mobject\u001b[39m.\u001b[34m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m509\u001b[39m drv, root, parts = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_parse_args\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 510\u001b[39m \u001b[38;5;28mself\u001b[39m._drv = drv\n\u001b[32m 511\u001b[39m \u001b[38;5;28mself\u001b[39m._root = root\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:502\u001b[39m, in \u001b[36mPurePath._parse_args\u001b[39m\u001b[34m(cls, args)\u001b[39m\n\u001b[32m 497\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 498\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 499\u001b[39m \u001b[33m\"\u001b[39m\u001b[33margument should be a str object or an os.PathLike \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 500\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mobject returning str, not \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 501\u001b[39m % \u001b[38;5;28mtype\u001b[39m(a))\n\u001b[32m--> \u001b[39m\u001b[32m502\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_flavour\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparse_parts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparts\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:67\u001b[39m, in \u001b[36m_Flavour.parse_parts\u001b[39m\u001b[34m(self, parts)\u001b[39m\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m altsep:\n\u001b[32m 66\u001b[39m part = part.replace(altsep, sep)\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m drv, root, rel = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msplitroot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpart\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m sep \u001b[38;5;129;01min\u001b[39;00m rel:\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mreversed\u001b[39m(rel.split(sep)):\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/pathlib.py:241\u001b[39m, in \u001b[36m_PosixFlavour.splitroot\u001b[39m\u001b[34m(self, part, sep)\u001b[39m\n\u001b[32m 239\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msplitroot\u001b[39m(\u001b[38;5;28mself\u001b[39m, part, sep=sep):\n\u001b[32m 240\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m part \u001b[38;5;129;01mand\u001b[39;00m part[\u001b[32m0\u001b[39m] == sep:\n\u001b[32m--> \u001b[39m\u001b[32m241\u001b[39m stripped_part = part.lstrip(sep)\n\u001b[32m 242\u001b[39m \u001b[38;5;66;03m# According to POSIX path resolution:\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[38;5;66;03m# http://pubs.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap04.html#tag_04_11\u001b[39;00m\n\u001b[32m 244\u001b[39m \u001b[38;5;66;03m# \"A pathname that begins with two successive slashes may be\u001b[39;00m\n\u001b[32m 245\u001b[39m \u001b[38;5;66;03m# interpreted in an implementation-defined manner, although more\u001b[39;00m\n\u001b[32m 246\u001b[39m \u001b[38;5;66;03m# than two leading slashes shall be treated as a single slash\".\u001b[39;00m\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(part) - \u001b[38;5;28mlen\u001b[39m(stripped_part) == \u001b[32m2\u001b[39m:\n",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m: "
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"samples = []\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"🔍 Сканирование директории {DATA_ROOT}...\")\n",
|
|
||||||
"# Используем os.walk, он часто работает быстрее rglob на сетевых дисках\n",
|
|
||||||
"for root, dirs, files in os.walk(DATA_ROOT):\n",
|
|
||||||
" for file in files:\n",
|
|
||||||
" if file.lower().endswith('.jpg'):\n",
|
|
||||||
" full_path = os.path.join(root, file)\n",
|
|
||||||
" # Извлекаем эмоцию (зависит от структуры папок, берем предпоследнюю папку)\n",
|
|
||||||
" # Путь: .../amusement/0/image.jpg -> root_parts[-2] будет 'amusement'\n",
|
|
||||||
" path_parts = Path(full_path).parts\n",
|
|
||||||
" emotion_folder = path_parts[-3].lower()\n",
|
|
||||||
" \n",
|
|
||||||
" if emotion_folder in EMO_MAP:\n",
|
|
||||||
" samples.append((full_path, EMO_MAP[emotion_folder]))\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"✅ Найдено изображений: {len(samples)}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# САМЫЙ ВАЖНЫЙ ШАГ: Глобальное перемешивание перед упаковкой\n",
|
|
||||||
"print(\"🔀 Перемешиваем датасет...\")\n",
|
|
||||||
"random.shuffle(samples)\n",
|
|
||||||
"print(\"✅ Перемешивание завершено!\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "0fe71d72",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"📦 Подготовлено 205 задач (шардов).\n",
|
|
||||||
"💾 Целевая папка: ../../dataset/EmoSet-2.41M-shards\n",
|
|
||||||
"🚀 Запуск упаковки в 42 потоков...\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Exception in thread Thread-4 (ui_thread_func):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 1045, in _bootstrap_inner\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 982, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/tmp/ipykernel_16083/731719608.py\", line 72, in ui_thread_func\n",
|
|
||||||
" File \"/home/zin/projects/Thesis/.venv/lib/python3.11/site-packages/tqdm/notebook.py\", line 223, in __init__\n",
|
|
||||||
" super().__init__(*args, **kwargs)\n",
|
|
||||||
" File \"/home/zin/projects/Thesis/.venv/lib/python3.11/site-packages/tqdm/std.py\", line 1001, in __init__\n",
|
|
||||||
" raise (\n",
|
|
||||||
"tqdm.std.TqdmKeyError: \"Unknown argument(s): {'color': 'blue'}\"\n",
|
|
||||||
"Process ForkProcess-39:\n",
|
|
||||||
"Process ForkProcess-35:\n",
|
|
||||||
"Process ForkProcess-28:\n",
|
|
||||||
"Process ForkProcess-29:\n",
|
|
||||||
"Process ForkProcess-43:\n",
|
|
||||||
"Process ForkProcess-38:\n",
|
|
||||||
"Process ForkProcess-36:\n",
|
|
||||||
"Process ForkProcess-37:\n",
|
|
||||||
"Process ForkProcess-34:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-27:\n",
|
|
||||||
"Process ForkProcess-25:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-30:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Process ForkProcess-41:\n",
|
|
||||||
"Process ForkProcess-31:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Process ForkProcess-22:\n",
|
|
||||||
"Process ForkProcess-32:\n",
|
|
||||||
"Process ForkProcess-40:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-42:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Process ForkProcess-20:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-24:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Process ForkProcess-23:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Process ForkProcess-26:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Process ForkProcess-21:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-16:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-13:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Process ForkProcess-9:\n",
|
|
||||||
"Process ForkProcess-17:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Process ForkProcess-6:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Process ForkProcess-8:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Process ForkProcess-3:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-10:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Process ForkProcess-7:\n",
|
|
||||||
"Process ForkProcess-11:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-14:\n",
|
|
||||||
"Process ForkProcess-4:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-5:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/threading.py\", line 982, in run\n",
|
|
||||||
"Process ForkProcess-2:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-12:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-15:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
"Process ForkProcess-18:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "KeyboardInterrupt",
|
|
||||||
"evalue": "",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 102\u001b[39m\n\u001b[32m 101\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ProcessPoolExecutor(max_workers=NUM_WORKERS) \u001b[38;5;28;01mas\u001b[39;00m executor:\n\u001b[32m--> \u001b[39m\u001b[32m102\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuild_shard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtasks\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 103\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mpass\u001b[39;49;00m \u001b[38;5;66;03m# Просто ждем завершения всех задач\u001b[39;00m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py:620\u001b[39m, in \u001b[36m_chain_from_iterable_of_lists\u001b[39m\u001b[34m(iterable)\u001b[39m\n\u001b[32m 615\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 616\u001b[39m \u001b[33;03mSpecialized implementation of itertools.chain.from_iterable.\u001b[39;00m\n\u001b[32m 617\u001b[39m \u001b[33;03mEach item in *iterable* should be a list. This function is\u001b[39;00m\n\u001b[32m 618\u001b[39m \u001b[33;03mcareful not to keep references to yielded objects.\u001b[39;00m\n\u001b[32m 619\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m620\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43melement\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 621\u001b[39m \u001b[43m \u001b[49m\u001b[43melement\u001b[49m\u001b[43m.\u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:619\u001b[39m, in \u001b[36mExecutor.map.<locals>.result_iterator\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 618\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m619\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m \u001b[43m_result_or_cancel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpop\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 620\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:317\u001b[39m, in \u001b[36m_result_or_cancel\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 316\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m317\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfut\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 318\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:451\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.__get_result()\n\u001b[32m--> \u001b[39m\u001b[32m451\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_condition\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:327\u001b[39m, in \u001b[36mCondition.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 326\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 328\u001b[39m gotit = \u001b[38;5;28;01mTrue\u001b[39;00m\n",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m: ",
|
|
||||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 101\u001b[39m\n\u001b[32m 99\u001b[39m \u001b[38;5;66;03m# 3. Запускаем 42 боевых ядра\u001b[39;00m\n\u001b[32m 100\u001b[39m tasks = [(i, chunk, queue) \u001b[38;5;28;01mfor\u001b[39;00m i, chunk \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(chunks)]\n\u001b[32m--> \u001b[39m\u001b[32m101\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mProcessPoolExecutor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_workers\u001b[49m\u001b[43m=\u001b[49m\u001b[43mNUM_WORKERS\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mas\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 102\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mexecutor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuild_shard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtasks\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 103\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mpass\u001b[39;49;00m \u001b[38;5;66;03m# Просто ждем завершения всех задач\u001b[39;00m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/_base.py:647\u001b[39m, in \u001b[36mExecutor.__exit__\u001b[39m\u001b[34m(self, exc_type, exc_val, exc_tb)\u001b[39m\n\u001b[32m 646\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, exc_type, exc_val, exc_tb):\n\u001b[32m--> \u001b[39m\u001b[32m647\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshutdown\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py:851\u001b[39m, in \u001b[36mProcessPoolExecutor.shutdown\u001b[39m\u001b[34m(self, wait, cancel_futures)\u001b[39m\n\u001b[32m 848\u001b[39m \u001b[38;5;28mself\u001b[39m._executor_manager_thread_wakeup.wakeup()\n\u001b[32m 850\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._executor_manager_thread \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m wait:\n\u001b[32m--> \u001b[39m\u001b[32m851\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_executor_manager_thread\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 852\u001b[39m \u001b[38;5;66;03m# To reduce the risk of opening too many files, remove references to\u001b[39;00m\n\u001b[32m 853\u001b[39m \u001b[38;5;66;03m# objects that use file descriptors.\u001b[39;00m\n\u001b[32m 854\u001b[39m \u001b[38;5;28mself\u001b[39m._executor_manager_thread = \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:1119\u001b[39m, in \u001b[36mThread.join\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 1116\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mcannot join current thread\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 1118\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1119\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1120\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1121\u001b[39m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[32m 1122\u001b[39m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[32m 1123\u001b[39m \u001b[38;5;28mself\u001b[39m._wait_for_tstate_lock(timeout=\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[32m0\u001b[39m))\n",
|
|
||||||
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/versions/3.11.7/lib/python3.11/threading.py:1139\u001b[39m, in \u001b[36mThread._wait_for_tstate_lock\u001b[39m\u001b[34m(self, block, timeout)\u001b[39m\n\u001b[32m 1136\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 1138\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1139\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 1140\u001b[39m lock.release()\n\u001b[32m 1141\u001b[39m \u001b[38;5;28mself\u001b[39m._stop()\n",
|
|
||||||
"\u001b[31mKeyboardInterrupt\u001b[39m: "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"Process ForkProcess-19:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 314, in _bootstrap\n",
|
|
||||||
" self.run()\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/process.py\", line 108, in run\n",
|
|
||||||
" self._target(*self._args, **self._kwargs)\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 103, in get\n",
|
|
||||||
" res = self._recv_bytes()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/concurrent/futures/process.py\", line 249, in _process_worker\n",
|
|
||||||
" call_item = call_queue.get(block=True)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 216, in recv_bytes\n",
|
|
||||||
" buf = self._recv_bytes(maxlength)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/queues.py\", line 102, in get\n",
|
|
||||||
" with self._rlock:\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 430, in _recv_bytes\n",
|
|
||||||
" buf = self._recv(4)\n",
|
|
||||||
" ^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/synchronize.py\", line 95, in __enter__\n",
|
|
||||||
" return self._semlock.__enter__()\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
" File \"/home/zin/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/connection.py\", line 395, in _recv\n",
|
|
||||||
" chunk = read(handle, remaining)\n",
|
|
||||||
" ^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
||||||
"KeyboardInterrupt\n",
|
|
||||||
"KeyboardInterrupt\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import multiprocessing as mp\n",
|
|
||||||
"from concurrent.futures import ProcessPoolExecutor\n",
|
|
||||||
"import webdataset as wds\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"import io\n",
|
|
||||||
"import threading\n",
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"\n",
|
|
||||||
"# Включаем красивую версию tqdm специально для Jupyter\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"\n",
|
|
# --- Paths and settings ---
SHARDS_DIR = Path("../../dataset/EmoSet-2.41M-shards")
SHARDS_DIR.mkdir(parents=True, exist_ok=True)

NUM_WORKERS = 42
MAX_SAMPLES_PER_SHARD = 10000

# `samples` is defined outside this cell; presumably a list of
# (image_path, label) pairs -- verify against the cell that builds it.
# Slice it into consecutive chunks, one chunk per output shard.
TOTAL_FILES = len(samples)
chunks = [
    samples[start:start + MAX_SAMPLES_PER_SHARD]
    for start in range(0, TOTAL_FILES, MAX_SAMPLES_PER_SHARD)
]
TOTAL_SHARDS = len(chunks)

print(f"📦 Подготовлено {TOTAL_SHARDS} задач (шардов).")
print(f"💾 Целевая папка: {SHARDS_DIR}")
print(f"🚀 Запуск упаковки в {NUM_WORKERS} потоков...\n")
|
|
||||||
"\n",
|
|
||||||
"# --- ФУНКЦИЯ ДЛЯ ЯДЕР ПРОЦЕССОРА ---\n",
|
|
def build_shard(args):
    """Write one chunk of samples into a single tar shard and report progress.

    Each image is converted to RGB, resized to 256x256 (bilinear) and
    re-encoded as JPEG (quality 85) before being written with its label.

    Args:
        args: ``(shard_idx, chunk, queue)`` packed into one tuple so the
            function can be passed to ``executor.map``. ``chunk`` is a
            sequence of ``(img_path, label)`` pairs; ``queue`` receives
            ``("file", count)`` and ``("shard", 1)`` progress messages.

    Returns:
        ``shard_idx``, unchanged.
    """
    shard_idx, chunk, queue = args
    shard_path = SHARDS_DIR / f"emoset-{shard_idx:06d}.tar"

    # Progress is reported in batches so the queue is not flooded.
    report_every = 50
    pending = 0

    with wds.TarWriter(str(shard_path)) as sink:
        for i, (img_path, label) in enumerate(chunk):
            try:
                with Image.open(img_path) as img:
                    img = img.convert("RGB")
                    img = img.resize((256, 256), Image.Resampling.BILINEAR)
                    with io.BytesIO() as img_byte_arr:
                        img.save(img_byte_arr, format='JPEG', quality=85)
                        image_data = img_byte_arr.getvalue()

                key = f"{shard_idx:06d}_{i:05d}"
                sink.write({
                    "__key__": key,
                    "jpg": image_data,
                    "cls": label
                })
            except Exception:
                # Deliberate best-effort: a broken sample is skipped, but it
                # still counts as processed below so the bar can reach 100%.
                pass

            # Count every file exactly once, whether it succeeded or failed.
            # (Previously failures were reported separately AND included in
            # the batch/remainder totals, so the bar drifted off the total.)
            pending += 1
            if pending == report_every:
                queue.put(("file", pending))
                pending = 0

    # Flush whatever did not fill a complete batch.
    if pending:
        queue.put(("file", pending))

    # Signal that the whole shard is finished.
    queue.put(("shard", 1))
    return shard_idx
|
|
||||||
"\n",
|
|
||||||
"# --- ФУНКЦИЯ ОТРИСОВКИ ИНТЕРФЕЙСА (Фоновый поток) ---\n",
|
|
def ui_thread_func(q, total_files, total_shards):
    """Drain progress messages from ``q`` and render two progress bars.

    Runs until the sentinel string ``"DONE"`` is received. Every other
    message is expected to be a ``(msg_type, count)`` pair, where
    ``msg_type`` is ``"file"`` or ``"shard"``; other types are ignored.

    Args:
        q: Queue the workers put progress messages on.
        total_files: Total for the per-file bar.
        total_shards: Total for the per-shard bar.
    """
    # tqdm's keyword is spelled `colour`; `color` is rejected as an unknown
    # argument, which used to kill this thread before any bar appeared.
    pbar_files = tqdm(total=total_files, desc="🖼️ Сжато файлов", colour="blue")
    pbar_shards = tqdm(total=total_shards, desc="📦 Готово архивов", colour="green")
    bars = {"file": pbar_files, "shard": pbar_shards}

    try:
        while True:
            msg = q.get()
            if msg == "DONE":
                break

            msg_type, count = msg
            bar = bars.get(msg_type)
            if bar is not None:
                bar.update(count)
    finally:
        # Close the bars even if reading from the queue fails.
        pbar_files.close()
        pbar_shards.close()
|
|
||||||
"\n",
|
|
||||||
"# === ГЛАВНЫЙ ЗАПУСК ===\n",
|
|
if __name__ == '__main__':
    # 1. The manager hosts a queue that can be shared with worker processes.
    #    Using it as a context manager shuts its server process down on exit,
    #    including after an error or an interrupt.
    with mp.Manager() as manager:
        queue = manager.Queue()

        # 2. Background thread that renders the progress bars. It is a daemon
        #    so a reader stuck on the queue can never block interpreter exit.
        ui_thread = threading.Thread(
            target=ui_thread_func,
            args=(queue, TOTAL_FILES, TOTAL_SHARDS),
            daemon=True,
        )
        ui_thread.start()

        try:
            # 3. Fan the shards out over the worker processes.
            tasks = [(i, chunk, queue) for i, chunk in enumerate(chunks)]
            with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
                # Consuming the iterator waits for every task and re-raises
                # the first worker exception, if any.
                for _ in executor.map(build_shard, tasks):
                    pass
        finally:
            # 4. Always release the UI thread; without this a failed or
            #    interrupted run left it blocked on q.get() forever.
            queue.put("DONE")
            ui_thread.join()

    print("\n🎉 ПАРАЛЛЕЛЬНАЯ УПАКОВКА И СЖАТИЕ ПОЛНОСТЬЮ ЗАВЕРШЕНЫ!")
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python (my-python-project)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "my-python-project"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,199 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "ca08df84",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Using device: cuda\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Step 0/1000, Loss: 1.0013\n",
|
|
||||||
"Step 10/1000, Loss: 1.0088\n",
|
|
||||||
"Step 20/1000, Loss: 0.9956\n",
|
|
||||||
"Step 30/1000, Loss: 0.9781\n",
|
|
||||||
"Step 40/1000, Loss: 0.9613\n",
|
|
||||||
"Step 50/1000, Loss: 0.9313\n",
|
|
||||||
"Step 60/1000, Loss: 0.8927\n",
|
|
||||||
"Step 70/1000, Loss: 0.8503\n",
|
|
||||||
"Step 80/1000, Loss: 0.7537\n",
|
|
||||||
"Step 90/1000, Loss: 0.6689\n",
|
|
||||||
"Step 100/1000, Loss: 0.6063\n",
|
|
||||||
"Step 110/1000, Loss: 0.5172\n",
|
|
||||||
"Step 120/1000, Loss: 0.4592\n",
|
|
||||||
"Step 130/1000, Loss: 0.4044\n",
|
|
||||||
"Step 140/1000, Loss: 0.3610\n",
|
|
||||||
"Step 150/1000, Loss: 0.3175\n",
|
|
||||||
"Step 160/1000, Loss: 0.2825\n",
|
|
||||||
"Step 170/1000, Loss: 0.2560\n",
|
|
||||||
"Step 180/1000, Loss: 0.2360\n",
|
|
||||||
"Step 190/1000, Loss: 0.2203\n",
|
|
||||||
"Step 200/1000, Loss: 0.1930\n",
|
|
||||||
"Step 210/1000, Loss: 0.1854\n",
|
|
||||||
"Step 220/1000, Loss: 0.1723\n",
|
|
||||||
"Step 230/1000, Loss: 0.1546\n",
|
|
||||||
"Step 240/1000, Loss: 0.1386\n",
|
|
||||||
"Step 250/1000, Loss: 0.1271\n",
|
|
||||||
"Step 260/1000, Loss: 0.1109\n",
|
|
||||||
"Step 270/1000, Loss: 0.1032\n",
|
|
||||||
"Step 280/1000, Loss: 0.0899\n",
|
|
||||||
"Step 290/1000, Loss: 0.0807\n",
|
|
||||||
"Step 300/1000, Loss: 0.0750\n",
|
|
||||||
"Step 310/1000, Loss: 0.0813\n",
|
|
||||||
"Step 320/1000, Loss: 0.0612\n",
|
|
||||||
"Step 330/1000, Loss: 0.0544\n",
|
|
||||||
"Step 340/1000, Loss: 0.0552\n",
|
|
||||||
"Step 350/1000, Loss: 0.0446\n",
|
|
||||||
"Step 360/1000, Loss: 0.0403\n",
|
|
||||||
"Step 370/1000, Loss: 0.0350\n",
|
|
||||||
"Step 380/1000, Loss: 0.0612\n",
|
|
||||||
"Step 390/1000, Loss: 0.0364\n",
|
|
||||||
"Step 400/1000, Loss: 0.0322\n",
|
|
||||||
"Step 410/1000, Loss: 0.0302\n",
|
|
||||||
"Step 420/1000, Loss: 0.0519\n",
|
|
||||||
"Step 430/1000, Loss: 0.0319\n",
|
|
||||||
"Step 440/1000, Loss: 0.0260\n",
|
|
||||||
"Step 450/1000, Loss: 0.0208\n",
|
|
||||||
"Step 460/1000, Loss: 0.0409\n",
|
|
||||||
"Step 470/1000, Loss: 0.0291\n",
|
|
||||||
"Step 480/1000, Loss: 0.0234\n",
|
|
||||||
"Step 490/1000, Loss: 0.0194\n",
|
|
||||||
"Step 500/1000, Loss: 0.0274\n",
|
|
||||||
"Step 510/1000, Loss: 0.0231\n",
|
|
||||||
"Step 520/1000, Loss: 0.0199\n",
|
|
||||||
"Step 530/1000, Loss: 0.0154\n",
|
|
||||||
"Step 540/1000, Loss: 0.0278\n",
|
|
||||||
"Step 550/1000, Loss: 0.0185\n",
|
|
||||||
"Step 560/1000, Loss: 0.0180\n",
|
|
||||||
"Step 570/1000, Loss: 0.0152\n",
|
|
||||||
"Step 580/1000, Loss: 0.0132\n",
|
|
||||||
"Step 590/1000, Loss: 0.0111\n",
|
|
||||||
"Step 600/1000, Loss: 0.0396\n",
|
|
||||||
"Step 610/1000, Loss: 0.0179\n",
|
|
||||||
"Step 620/1000, Loss: 0.0148\n",
|
|
||||||
"Step 630/1000, Loss: 0.0123\n",
|
|
||||||
"Step 640/1000, Loss: 0.0265\n",
|
|
||||||
"Step 650/1000, Loss: 0.0133\n",
|
|
||||||
"Step 660/1000, Loss: 0.0128\n",
|
|
||||||
"Step 670/1000, Loss: 0.0107\n",
|
|
||||||
"Step 680/1000, Loss: 0.0142\n",
|
|
||||||
"Step 690/1000, Loss: 0.0202\n",
|
|
||||||
"Step 700/1000, Loss: 0.0125\n",
|
|
||||||
"Step 710/1000, Loss: 0.0107\n",
|
|
||||||
"Step 720/1000, Loss: 0.0140\n",
|
|
||||||
"Step 730/1000, Loss: 0.0195\n",
|
|
||||||
"Step 740/1000, Loss: 0.0148\n",
|
|
||||||
"Step 750/1000, Loss: 0.0109\n",
|
|
||||||
"Step 760/1000, Loss: 0.0094\n",
|
|
||||||
"Step 770/1000, Loss: 0.0121\n",
|
|
||||||
"Step 780/1000, Loss: 0.0233\n",
|
|
||||||
"Step 790/1000, Loss: 0.0151\n",
|
|
||||||
"Step 800/1000, Loss: 0.0134\n",
|
|
||||||
"Step 810/1000, Loss: 0.0117\n",
|
|
||||||
"Step 820/1000, Loss: 0.0124\n",
|
|
||||||
"Step 830/1000, Loss: 0.0221\n",
|
|
||||||
"Step 840/1000, Loss: 0.0161\n",
|
|
||||||
"Step 850/1000, Loss: 0.0136\n",
|
|
||||||
"Step 860/1000, Loss: 0.0161\n",
|
|
||||||
"Step 870/1000, Loss: 0.0194\n",
|
|
||||||
"Step 880/1000, Loss: 0.0145\n",
|
|
||||||
"Step 890/1000, Loss: 0.0149\n",
|
|
||||||
"Step 900/1000, Loss: 0.0232\n",
|
|
||||||
"Step 910/1000, Loss: 0.0166\n",
|
|
||||||
"Step 920/1000, Loss: 0.0156\n",
|
|
||||||
"Step 930/1000, Loss: 0.0276\n",
|
|
||||||
"Step 940/1000, Loss: 0.0176\n",
|
|
||||||
"Step 950/1000, Loss: 0.0152\n",
|
|
||||||
"Step 960/1000, Loss: 0.0162\n",
|
|
||||||
"Step 970/1000, Loss: 0.0143\n",
|
|
||||||
"Step 980/1000, Loss: 0.0136\n",
|
|
||||||
"Step 990/1000, Loss: 0.0117\n",
|
|
||||||
"Total time: 67.25 s\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"import torch.nn as nn\n",
|
|
||||||
"import torch.optim as optim\n",
|
|
||||||
"import time\n",
|
|
||||||
"\n",
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# Deliberately large sizes so the run keeps the device busy for a while.
N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10
batch_size = 16_384  # large batch
steps = 1000         # many iterations for a sustained load


def run_benchmark(n=N, d_in=D_in, h1=H1, h2=H2, h3=H3, d_out=D_out,
                  batch_size=batch_size, steps=steps, log_every=10):
    """Train a 4-layer MLP on random data and time the training loop.

    The defaults reproduce the original hard-coded configuration; smaller
    values make the function usable as a quick smoke test.

    Args:
        n: Number of random samples generated on ``device``.
        d_in: Input feature size.
        h1, h2, h3: Hidden layer widths.
        d_out: Output size.
        batch_size: Number of rows sampled (with replacement) per step.
        steps: Number of optimisation steps.
        log_every: Print the loss every ``log_every`` steps.

    Returns:
        Elapsed wall-clock time of the training loop in seconds.
    """
    # Random data created directly on the target device.
    x = torch.randn(n, d_in, device=device, dtype=torch.float32)
    y = torch.randn(n, d_out, device=device, dtype=torch.float32)

    model = nn.Sequential(
        nn.Linear(d_in, h1),
        nn.ReLU(),
        nn.Linear(h1, h2),
        nn.ReLU(),
        nn.Linear(h2, h3),
        nn.ReLU(),
        nn.Linear(h3, d_out)
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # CUDA kernels are launched asynchronously: synchronise before reading
    # the clock so setup work is not billed to the loop.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()

    for t in range(steps):
        idx = torch.randint(0, n, (batch_size,), device=device)
        x_batch = x[idx]
        y_batch = y[idx]

        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if t % log_every == 0:
            # Log only occasionally so the output stays readable.
            print(f"Step {t}/{steps}, Loss: {loss.item():.4f}")

    # Wait for the last queued steps to finish before stopping the clock.
    if device.type == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"Total time: {elapsed:.2f} s")
    return elapsed


if __name__ == "__main__":
    run_benchmark()
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.7"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from pathlib import Path
|
|
||||||
from sklearn.linear_model import RidgeCV
|
|
||||||
from sklearn.multioutput import MultiOutputRegressor
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from sklearn.pipeline import Pipeline
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.metrics import mean_squared_error, r2_score
|
|
||||||
import joblib
|
|
||||||
|
|
||||||
# Class index -> (valence, arousal) target used as the regression label.
EMO_VA_MAP = {
    0: (7.5, 6.5),  # amusement
    1: (2.0, 8.0),  # anger
    2: (6.5, 5.0),  # awe
    3: (7.0, 3.0),  # contentment
    4: (3.0, 6.0),  # disgust
    5: (8.0, 8.0),  # excitement
    6: (2.5, 7.5),  # fear
    7: (2.0, 2.0),  # sadness
}

# Two directory levels above this file.
BASE_DIR = Path(__file__).resolve().parents[1]
EMBEDDINGS_PATH = BASE_DIR / "emoset_test_embeddings.npy"
LABELS_PATH = BASE_DIR / "emoset_test_labels.npy"

print("Загрузка данных...")
X = np.load(EMBEDDINGS_PATH)
y_labels = np.load(LABELS_PATH)

# Replace every class label with its (valence, arousal) pair.
y_va = np.array([EMO_VA_MAP[class_id] for class_id in y_labels])

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_va,
    test_size=0.2,
    random_state=42,
)

print("Обучение масштабатора и RidgeCV регрессора...")
# Standardise the features, then fit one RidgeCV per output column.
ridge = RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0, 1000.0])
model = Pipeline([
    ('scaler', StandardScaler()),
    ('regressor', MultiOutputRegressor(ridge)),
])
model.fit(X_train, y_train)

# Evaluate on the held-out 20% split.
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"\nУспех! Обучение завершено!")
print(f"MSE: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")

# Show the spread of the predictions per output column.
valence_pred = y_pred[:, 0]
arousal_pred = y_pred[:, 1]
print("\n--- ДИАГНОСТИКА РАЗБРОСА ПРЕДСКАЗАНИЙ ---")
print(f"Valence: от {valence_pred.min():.2f} до {valence_pred.max():.2f} (Эталон: 2.0 - 8.0)")
print(f"Arousal: от {arousal_pred.min():.2f} до {arousal_pred.max():.2f} (Эталон: 2.0 - 8.0)")

# Persist the fitted pipeline, creating the target directory if needed.
output_model_path = BASE_DIR / "music_engine" / "va_regressor.pkl"
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(model, output_model_path)
print(f"\nМодель сохранена в: {output_model_path}")
|
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
Reference in New Issue
Block a user