Fix tab_dataset
This commit is contained in:
@@ -5,18 +5,18 @@ import timm
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# НОВЫЙ ИМПОРТ ДЛЯ VLM
|
# ТЕПЕРЬ BLIP-2
|
||||||
from transformers import BlipProcessor, BlipForConditionalGeneration
|
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||||
|
|
||||||
class ImageProcessor:
|
class ImageProcessor:
|
||||||
def __init__(self, model_path: Path | str):
|
def __init__(self, model_path: Path | str):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
# --- ПОТОК 1: ЭМОЦИИ (ResNet-50) ---
|
# --- ПОТОК 1: ТВОЙ АВТОРСКИЙ ЭМБЕДДИНГ (Core) ---
|
||||||
print("⏳ Загрузка эмоционального модуля (ResNet-50)...")
|
|
||||||
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
self.emo_model = timm.create_model('resnet50', pretrained=False, num_classes=8)
|
||||||
if Path(model_path).exists():
|
if Path(model_path).exists():
|
||||||
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
self.emo_model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||||
|
|
||||||
self.emo_model.fc = torch.nn.Identity()
|
self.emo_model.fc = torch.nn.Identity()
|
||||||
self.emo_model.to(self.device).eval()
|
self.emo_model.to(self.device).eval()
|
||||||
|
|
||||||
@@ -26,32 +26,31 @@ class ImageProcessor:
|
|||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
# --- ПОТОК 2: СЕМАНТИКА И КОНТЕКСТ (BLIP Large) ---
|
# --- ПОТОК 2: СЕМАНТИЧЕСКИЙ ЭКСПЕРТ BLIP-2 ---
|
||||||
print("⏳ Загрузка мощной VLM модели (BLIP) для описания сцен...")
|
print("⏳ Загрузка тяжелой артиллерии: BLIP-2...")
|
||||||
# Используем версию Large, так как позволяет железо V100
|
# Используем версию opt-2.7b — она идеально сбалансирована для V100
|
||||||
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||||
self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
|
self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
||||||
|
"Salesforce/blip2-opt-2.7b",
|
||||||
print("✅ Обе нейросети визуального анализа успешно загружены на V100!")
|
torch_dtype=torch.float16 # Обязательно для скорости на V100
|
||||||
|
).to(self.device)
|
||||||
|
print("✅ BLIP-2 и ResNet-50 готовы.")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
def extract_embedding(self, image: Image.Image) -> np.ndarray:
|
||||||
"""Извлекает 2048-мерный вектор эмоций."""
|
|
||||||
img_rgb = image.convert('RGB')
|
img_rgb = image.convert('RGB')
|
||||||
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
|
img_tensor = self.emo_transform(img_rgb).unsqueeze(0).to(self.device)
|
||||||
return self.emo_model(img_tensor).cpu().numpy().flatten()
|
return self.emo_model(img_tensor).cpu().numpy().flatten()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def describe_scene(self, image: Image.Image) -> str:
|
def describe_scene(self, image: Image.Image) -> str:
|
||||||
"""Генерирует текстовое описание картинки (Captioning) для LLM."""
|
"""Генерирует описание через BLIP-2."""
|
||||||
img_rgb = image.convert('RGB')
|
img_rgb = image.convert('RGB')
|
||||||
|
|
||||||
# Готовим картинку для BLIP
|
# Инференс BLIP-2 требует float16 для V100
|
||||||
inputs = self.blip_processor(img_rgb, return_tensors="pt").to(self.device)
|
inputs = self.blip_processor(images=img_rgb, return_tensors="pt").to(self.device, torch.float16)
|
||||||
|
|
||||||
# Генерируем описание (max_new_tokens ограничим, чтобы было лаконично)
|
# Генерируем описание
|
||||||
out = self.blip_model.generate(**inputs, max_new_tokens=30)
|
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=40)
|
||||||
|
caption = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||||
# Декодируем тензор в строку
|
|
||||||
caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
|
||||||
return caption
|
return caption
|
||||||
@@ -3,9 +3,9 @@ import json
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
class LLMAcousticBridge:
|
class LLMAcousticBridge:
|
||||||
def __init__(self, model_name="phi3", host="http://localhost:11434"):
|
def __init__(self, model_name="dolphin-llama3:8b"):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.api_url = f"{host}/api/generate"
|
self.api_url = "http://localhost:11434/api/generate"
|
||||||
|
|
||||||
def _clean_json(self, text):
|
def _clean_json(self, text):
|
||||||
"""Вытаскивает чистый JSON из ответа нейросети."""
|
"""Вытаскивает чистый JSON из ответа нейросети."""
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ def render_dataset_tab(matcher, image_files, embeddings, labels_array, images_pa
|
|||||||
c1, c2 = st.columns([1, 3])
|
c1, c2 = st.columns([1, 3])
|
||||||
with c1:
|
with c1:
|
||||||
st.write(f"**ID:** {int(row['song_id'])}")
|
st.write(f"**ID:** {int(row['song_id'])}")
|
||||||
st.caption(f"L2 Dist: {row['distance']:.2f}")
|
score_val = row.get('final_score', row.get('emo_distance', 0))
|
||||||
|
st.caption(f"Dist Score: {score_val:.2f}")
|
||||||
with c2:
|
with c2:
|
||||||
audio_path = matcher.get_audio_path(row['song_id'])
|
audio_path = matcher.get_audio_path(row['song_id'])
|
||||||
if audio_path:
|
if audio_path:
|
||||||
|
|||||||
Reference in New Issue
Block a user