Beta V.1.1
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
import timm
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
class ImageProcessor:
|
||||
def __init__(self, model_path: Path | str):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Загружаем базовую архитектуру, как при обучении EmoSet
|
||||
self.model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||
|
||||
# Подгружаем обученные веса
|
||||
if Path(model_path).exists():
|
||||
# map_location позволяет загрузить модель на CPU, если нет видеокарты
|
||||
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||
print(f"✅ Веса ResNet-50 успешно загружены из {model_path}")
|
||||
else:
|
||||
print(f"⚠️ ОШИБКА: Файл весов {model_path} не найден! Модель будет выдавать случайный шум.")
|
||||
|
||||
# Удаляем последний слой (классификатор на 8 эмоций),
|
||||
# чтобы на выходе получать сырой вектор (embedding) на 2048 чисел
|
||||
self.model.fc = torch.nn.Identity()
|
||||
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# Стандартные трансформации ImageNet (строго как при обучении)
|
||||
self.transform = T.Compose([
|
||||
T.Resize((224, 224)),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||
"""Принимает PIL Image, возвращает numpy-вектор."""
|
||||
# Переводим в RGB (на случай если загрузят PNG с прозрачностью или ЧБ)
|
||||
img_rgb = image.convert('RGB')
|
||||
img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)
|
||||
|
||||
embedding = self.model(img_tensor)
|
||||
return embedding.cpu().numpy().flatten()
|
||||
Reference in New Issue
Block a user