57 lines
2.8 KiB
Python
57 lines
2.8 KiB
Python
import torch
|
|
import torchvision.transforms as T
|
|
from PIL import Image
|
|
import timm
|
|
from pathlib import Path
|
|
import numpy as np
|
|
|
|
# НОВЫЙ ИМПОРТ ДЛЯ VLM
|
|
from transformers import BlipProcessor, BlipForConditionalGeneration
|
|
|
|
class ImageProcessor:
|
|
def __init__(self, model_path: Path | str):
|
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
# --- ПОТОК 1: ЭМОЦИИ (ResNet-50) ---
|
|
print("⏳ Загрузка эмоционального модуля (ResNet-50)...")
|
|
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
|
if Path(model_path).exists():
|
|
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
|
self.emo_model.fc = torch.nn.Identity()
|
|
self.emo_model.to(self.device).eval()
|
|
|
|
self.emo_transform = T.Compose([
|
|
T.Resize((224, 224)),
|
|
T.ToTensor(),
|
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
])
|
|
|
|
# --- ПОТОК 2: СЕМАНТИКА И КОНТЕКСТ (BLIP Large) ---
|
|
print("⏳ Загрузка мощной VLM модели (BLIP) для описания сцен...")
|
|
# Используем версию Large, так как позволяет железо V100
|
|
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
|
self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
|
|
|
|
print("✅ Обе нейросети визуального анализа успешно загружены на V100!")
|
|
|
|
@torch.no_grad()
|
|
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
|
"""Извлекает 2048-мерный вектор эмоций."""
|
|
img_rgb = image.convert('RGB')
|
|
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
|
|
return self.emo_model(img_tensor).cpu().numpy().flatten()
|
|
|
|
@torch.no_grad()
|
|
def describe_scene(self, image: Image.Image) -> str:
|
|
"""Генерирует текстовое описание картинки (Captioning) для LLM."""
|
|
img_rgb = image.convert('RGB')
|
|
|
|
# Готовим картинку для BLIP
|
|
inputs = self.blip_processor(img_rgb, return_tensors="pt").to(self.device)
|
|
|
|
# Генерируем описание (max_new_tokens ограничим, чтобы было лаконично)
|
|
out = self.blip_model.generate(**inputs, max_new_tokens=30)
|
|
|
|
# Декодируем тензор в строку
|
|
caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
|
return caption |