Added NFS and runs to .gitignore
This commit is contained in:
@@ -202,3 +202,5 @@ dataset/
|
||||
|
||||
src/emoset_test_embeddings.npy
|
||||
src/emoset_test_labels.npy
|
||||
runs
|
||||
NFS
|
||||
@@ -0,0 +1,171 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import os
|
||||
import json
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class EmoSet(Dataset):
    """Dataset that serves EmoSet images with emotion and attribute labels.

    The constructor reads two JSON files under ``data_root``: ``info.json``
    (per-attribute ``label2idx`` tables) and ``<phase>.json`` (the split
    listing). Each entry of the split listing is indexed as
    ``[emotion_label, image_id, image_path, annotation_path]``, with both
    paths relative to ``data_root``.

    Each sample is a dict with the keys ``image_id``, ``image``,
    ``emotion_label_idx`` and one ``<attribute>_label_idx`` entry per
    attribute listed in ``ATTRIBUTES_MULTI_CLASS`` and
    ``ATTRIBUTES_MULTI_LABEL``.
    """

    # Attributes carrying a single label per image (encoded as one index).
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes carrying any number of labels per image (encoded multi-hot).
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Number of classes per attribute; used to size the multi-hot vector.
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load the label tables and the listing of one split.

        Args:
            data_root: Directory containing ``info.json``, ``<phase>.json``
                and the files referenced by the split listing.
            num_emotion_classes: ``8`` to use the emotion table from
                ``info.json``, ``2`` for the positive/negative grouping.
            phase: One of ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            AssertionError: If ``num_emotion_classes`` or ``phase`` is not
                one of the supported values.
            NotImplementedError: If ``phase`` has no transform pipeline
                (only reachable when assertions are disabled).
        """
        assert num_emotion_classes in (8, 2)
        assert phase in ('train', 'val', 'test')
        self.transforms_dict = self.get_data_transforms()

        self.info = self.get_info(data_root, num_emotion_classes)

        if phase not in self.transforms_dict:
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        # Read the split listing inside a context manager so the handle is
        # closed deterministically instead of being left to the GC.
        split_path = os.path.join(data_root, f'{phase}.json')
        with open(split_path, encoding='utf-8') as split_file:
            raw_entries = json.load(split_file)

        emotion_to_idx = self.info['emotion']['label2idx']
        self.data_store = [
            [
                emotion_to_idx[item[0]],
                item[1],
                os.path.join(data_root, item[2]),
                os.path.join(data_root, item[3]),
            ]
            for item in raw_entries
        ]

    @classmethod
    def get_data_transforms(cls):
        """Return the torchvision pipelines keyed by ``'train'``/``'val'``/``'test'``.

        Training uses a random resized crop plus horizontal flip; validation
        and test use resize followed by a center crop. All three end with
        ``ToTensor`` and the same per-channel normalization.
        """
        def normalize():
            # Same mean/std for every phase so inputs are on one scale.
            return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        def eval_pipeline():
            # Deterministic preprocessing shared by 'val' and 'test'.
            return transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize(),
            ])

        transforms_dict = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize(),
            ]),
            'val': eval_pipeline(),
            'test': eval_pipeline(),
        }
        return transforms_dict

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` and select the emotion label table.

        Args:
            data_root: Directory containing ``info.json``.
            num_emotion_classes: ``8`` returns the file content unchanged;
                ``2`` replaces ``info['emotion']`` with a table mapping the
                eight emotion names onto ``0`` (positive) / ``1`` (negative).

        Returns:
            The parsed ``info.json`` dict, possibly with ``'emotion'`` replaced.

        Raises:
            AssertionError: If ``num_emotion_classes`` is not ``8`` or ``2``.
            NotImplementedError: Same condition when assertions are disabled.
        """
        assert num_emotion_classes in (8, 2)
        with open(os.path.join(data_root, 'info.json'), encoding='utf-8') as info_file:
            info = json.load(info_file)

        if num_emotion_classes == 8:
            # Keep the emotion table exactly as stored in info.json.
            pass
        elif num_emotion_classes == 2:
            emotion_info = {
                'label2idx': {
                    'amusement': 0,
                    'awe': 0,
                    'contentment': 0,
                    'excitement': 0,
                    'anger': 1,
                    'disgust': 1,
                    'fear': 1,
                    'sadness': 1,
                },
                'idx2label': {
                    '0': 'positive',
                    '1': 'negative',
                }
            }
            info['emotion'] = emotion_info
        else:
            raise NotImplementedError

        return info

    def load_image_by_path(self, path):
        """Open the image at ``path``, convert it to RGB and apply ``self.transform``."""
        # convert() returns a new image, so the source file can be closed
        # as soon as the conversion is done.
        with Image.open(path) as source:
            image = source.convert('RGB')
        image = self.transform(image)
        return image

    def load_annotation_by_path(self, path):
        """Return the parsed JSON content of the annotation file at ``path``."""
        with open(path, encoding='utf-8') as annotation_file:
            json_data = json.load(annotation_file)
        return json_data

    def __getitem__(self, item):
        """Build the sample dict for index ``item``.

        Returns:
            A dict holding ``image_id``, the transformed ``image``,
            ``emotion_label_idx``, one ``<attribute>_label_idx`` int per
            multi-class attribute (``-1`` when the annotation lacks it) and a
            multi-hot ``object_label_idx`` tensor of length
            ``NUM_CLASSES['object']`` (all zeros when the annotation has no
            objects).
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            # if empty, set to -1, else set to label index
            attribute_label_idx = -1
            if attribute in annotation_data:
                # Annotation values are stringified to match the label2idx keys.
                attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            data[f'{attribute}_label_idx'] = attribute_label_idx

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            # if empty, set to 0, else set to 1
            assert attribute == 'object'
            num_classes = self.NUM_CLASSES[attribute]
            attribute_label_idx = torch.zeros(num_classes)
            if attribute in annotation_data:
                label2idx = self.info[attribute]['label2idx']
                for label in annotation_data[attribute]:
                    attribute_label_idx[label2idx[label]] = 1
            data[f'{attribute}_label_idx'] = attribute_label_idx

        return data

    def __len__(self):
        """Return the number of entries in the loaded split."""
        return len(self.data_store)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the 8-class training split from a local copy of the
    # dataset and iterate over it once in shuffled batches of 16.
    dataset = EmoSet(
        data_root=r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val',
        num_emotion_classes=8,
        phase='train',
    )

    # To inspect the label tables: print(dataset.info)
    dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

    for _batch in dataloader:
        # Nothing is checked here; the loop only confirms that every sample
        # can be loaded. To inspect a batch, print any of
        # _batch['emotion_label_idx'], _batch['scene_label_idx'],
        # _batch['facial_expression_label_idx'],
        # _batch['human_action_label_idx'], _batch['brightness_label_idx'],
        # _batch['colorfulness_label_idx'] or _batch['object_label_idx'],
        # then break.
        pass
|
||||
|
||||
+112
-105
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "ca08df84",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -10,108 +10,114 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cuda\n",
|
||||
"Step 0/1000, Loss: 1.0055\n",
|
||||
"Step 10/1000, Loss: 0.9936\n",
|
||||
"Step 20/1000, Loss: 0.9853\n",
|
||||
"Step 30/1000, Loss: 0.9698\n",
|
||||
"Step 40/1000, Loss: 0.9523\n",
|
||||
"Step 50/1000, Loss: 0.9206\n",
|
||||
"Step 60/1000, Loss: 0.8736\n",
|
||||
"Step 70/1000, Loss: 0.7981\n",
|
||||
"Step 80/1000, Loss: 0.7176\n",
|
||||
"Step 90/1000, Loss: 0.6491\n",
|
||||
"Step 100/1000, Loss: 0.5748\n",
|
||||
"Step 110/1000, Loss: 0.5016\n",
|
||||
"Step 120/1000, Loss: 0.4303\n",
|
||||
"Step 130/1000, Loss: 0.3937\n",
|
||||
"Step 140/1000, Loss: 0.3528\n",
|
||||
"Step 150/1000, Loss: 0.2982\n",
|
||||
"Step 160/1000, Loss: 0.2696\n",
|
||||
"Step 170/1000, Loss: 0.2489\n",
|
||||
"Step 180/1000, Loss: 0.2180\n",
|
||||
"Step 190/1000, Loss: 0.2003\n",
|
||||
"Step 200/1000, Loss: 0.1671\n",
|
||||
"Step 210/1000, Loss: 0.1725\n",
|
||||
"Step 220/1000, Loss: 0.1364\n",
|
||||
"Step 230/1000, Loss: 0.1171\n",
|
||||
"Step 240/1000, Loss: 0.1252\n",
|
||||
"Step 250/1000, Loss: 0.1032\n",
|
||||
"Step 260/1000, Loss: 0.0834\n",
|
||||
"Step 270/1000, Loss: 0.0892\n",
|
||||
"Step 280/1000, Loss: 0.0858\n",
|
||||
"Step 290/1000, Loss: 0.0685\n",
|
||||
"Step 300/1000, Loss: 0.0546\n",
|
||||
"Step 310/1000, Loss: 0.0829\n",
|
||||
"Step 320/1000, Loss: 0.0516\n",
|
||||
"Step 330/1000, Loss: 0.0446\n",
|
||||
"Step 340/1000, Loss: 0.0569\n",
|
||||
"Step 350/1000, Loss: 0.0406\n",
|
||||
"Step 360/1000, Loss: 0.0373\n",
|
||||
"Step 370/1000, Loss: 0.0377\n",
|
||||
"Step 380/1000, Loss: 0.0395\n",
|
||||
"Step 390/1000, Loss: 0.0299\n",
|
||||
"Step 400/1000, Loss: 0.0311\n",
|
||||
"Step 410/1000, Loss: 0.0272\n",
|
||||
"Step 420/1000, Loss: 0.0223\n",
|
||||
"Step 430/1000, Loss: 0.0261\n",
|
||||
"Step 440/1000, Loss: 0.0230\n",
|
||||
"Step 450/1000, Loss: 0.0191\n",
|
||||
"Step 460/1000, Loss: 0.0196\n",
|
||||
"Step 470/1000, Loss: 0.0219\n",
|
||||
"Step 480/1000, Loss: 0.0193\n",
|
||||
"Step 490/1000, Loss: 0.0244\n",
|
||||
"Step 500/1000, Loss: 0.0231\n",
|
||||
"Step 510/1000, Loss: 0.0167\n",
|
||||
"Step 520/1000, Loss: 0.0167\n",
|
||||
"Step 530/1000, Loss: 0.0252\n",
|
||||
"Step 540/1000, Loss: 0.0215\n",
|
||||
"Step 550/1000, Loss: 0.0190\n",
|
||||
"Step 560/1000, Loss: 0.0175\n",
|
||||
"Step 570/1000, Loss: 0.0204\n",
|
||||
"Step 580/1000, Loss: 0.0184\n",
|
||||
"Step 590/1000, Loss: 0.0159\n",
|
||||
"Step 600/1000, Loss: 0.0334\n",
|
||||
"Step 610/1000, Loss: 0.0177\n",
|
||||
"Step 620/1000, Loss: 0.0173\n",
|
||||
"Step 630/1000, Loss: 0.0215\n",
|
||||
"Step 640/1000, Loss: 0.0180\n",
|
||||
"Step 650/1000, Loss: 0.0165\n",
|
||||
"Step 660/1000, Loss: 0.0194\n",
|
||||
"Step 670/1000, Loss: 0.0244\n",
|
||||
"Step 680/1000, Loss: 0.0193\n",
|
||||
"Step 690/1000, Loss: 0.0169\n",
|
||||
"Step 700/1000, Loss: 0.0195\n",
|
||||
"Step 710/1000, Loss: 0.0162\n",
|
||||
"Step 720/1000, Loss: 0.0355\n",
|
||||
"Step 730/1000, Loss: 0.0217\n",
|
||||
"Step 740/1000, Loss: 0.0158\n",
|
||||
"Step 750/1000, Loss: 0.0145\n",
|
||||
"Step 760/1000, Loss: 0.0135\n",
|
||||
"Step 770/1000, Loss: 0.0240\n",
|
||||
"Step 780/1000, Loss: 0.0182\n",
|
||||
"Step 790/1000, Loss: 0.0233\n",
|
||||
"Step 800/1000, Loss: 0.0181\n",
|
||||
"Step 810/1000, Loss: 0.0161\n",
|
||||
"Step 820/1000, Loss: 0.0187\n",
|
||||
"Step 830/1000, Loss: 0.0149\n",
|
||||
"Step 840/1000, Loss: 0.0198\n",
|
||||
"Step 850/1000, Loss: 0.0270\n",
|
||||
"Step 860/1000, Loss: 0.0176\n",
|
||||
"Step 870/1000, Loss: 0.0168\n",
|
||||
"Step 880/1000, Loss: 0.0230\n",
|
||||
"Step 890/1000, Loss: 0.0282\n",
|
||||
"Step 900/1000, Loss: 0.0193\n",
|
||||
"Step 910/1000, Loss: 0.0193\n",
|
||||
"Step 920/1000, Loss: 0.0153\n",
|
||||
"Step 930/1000, Loss: 0.0157\n",
|
||||
"Step 940/1000, Loss: 0.0196\n",
|
||||
"Step 950/1000, Loss: 0.0168\n",
|
||||
"Step 960/1000, Loss: 0.0132\n",
|
||||
"Step 970/1000, Loss: 0.0124\n",
|
||||
"Step 980/1000, Loss: 0.0224\n",
|
||||
"Step 990/1000, Loss: 0.0164\n",
|
||||
"Total time: 65.72 s\n"
|
||||
"Using device: cuda\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Step 0/1000, Loss: 1.0013\n",
|
||||
"Step 10/1000, Loss: 1.0088\n",
|
||||
"Step 20/1000, Loss: 0.9956\n",
|
||||
"Step 30/1000, Loss: 0.9781\n",
|
||||
"Step 40/1000, Loss: 0.9613\n",
|
||||
"Step 50/1000, Loss: 0.9313\n",
|
||||
"Step 60/1000, Loss: 0.8927\n",
|
||||
"Step 70/1000, Loss: 0.8503\n",
|
||||
"Step 80/1000, Loss: 0.7537\n",
|
||||
"Step 90/1000, Loss: 0.6689\n",
|
||||
"Step 100/1000, Loss: 0.6063\n",
|
||||
"Step 110/1000, Loss: 0.5172\n",
|
||||
"Step 120/1000, Loss: 0.4592\n",
|
||||
"Step 130/1000, Loss: 0.4044\n",
|
||||
"Step 140/1000, Loss: 0.3610\n",
|
||||
"Step 150/1000, Loss: 0.3175\n",
|
||||
"Step 160/1000, Loss: 0.2825\n",
|
||||
"Step 170/1000, Loss: 0.2560\n",
|
||||
"Step 180/1000, Loss: 0.2360\n",
|
||||
"Step 190/1000, Loss: 0.2203\n",
|
||||
"Step 200/1000, Loss: 0.1930\n",
|
||||
"Step 210/1000, Loss: 0.1854\n",
|
||||
"Step 220/1000, Loss: 0.1723\n",
|
||||
"Step 230/1000, Loss: 0.1546\n",
|
||||
"Step 240/1000, Loss: 0.1386\n",
|
||||
"Step 250/1000, Loss: 0.1271\n",
|
||||
"Step 260/1000, Loss: 0.1109\n",
|
||||
"Step 270/1000, Loss: 0.1032\n",
|
||||
"Step 280/1000, Loss: 0.0899\n",
|
||||
"Step 290/1000, Loss: 0.0807\n",
|
||||
"Step 300/1000, Loss: 0.0750\n",
|
||||
"Step 310/1000, Loss: 0.0813\n",
|
||||
"Step 320/1000, Loss: 0.0612\n",
|
||||
"Step 330/1000, Loss: 0.0544\n",
|
||||
"Step 340/1000, Loss: 0.0552\n",
|
||||
"Step 350/1000, Loss: 0.0446\n",
|
||||
"Step 360/1000, Loss: 0.0403\n",
|
||||
"Step 370/1000, Loss: 0.0350\n",
|
||||
"Step 380/1000, Loss: 0.0612\n",
|
||||
"Step 390/1000, Loss: 0.0364\n",
|
||||
"Step 400/1000, Loss: 0.0322\n",
|
||||
"Step 410/1000, Loss: 0.0302\n",
|
||||
"Step 420/1000, Loss: 0.0519\n",
|
||||
"Step 430/1000, Loss: 0.0319\n",
|
||||
"Step 440/1000, Loss: 0.0260\n",
|
||||
"Step 450/1000, Loss: 0.0208\n",
|
||||
"Step 460/1000, Loss: 0.0409\n",
|
||||
"Step 470/1000, Loss: 0.0291\n",
|
||||
"Step 480/1000, Loss: 0.0234\n",
|
||||
"Step 490/1000, Loss: 0.0194\n",
|
||||
"Step 500/1000, Loss: 0.0274\n",
|
||||
"Step 510/1000, Loss: 0.0231\n",
|
||||
"Step 520/1000, Loss: 0.0199\n",
|
||||
"Step 530/1000, Loss: 0.0154\n",
|
||||
"Step 540/1000, Loss: 0.0278\n",
|
||||
"Step 550/1000, Loss: 0.0185\n",
|
||||
"Step 560/1000, Loss: 0.0180\n",
|
||||
"Step 570/1000, Loss: 0.0152\n",
|
||||
"Step 580/1000, Loss: 0.0132\n",
|
||||
"Step 590/1000, Loss: 0.0111\n",
|
||||
"Step 600/1000, Loss: 0.0396\n",
|
||||
"Step 610/1000, Loss: 0.0179\n",
|
||||
"Step 620/1000, Loss: 0.0148\n",
|
||||
"Step 630/1000, Loss: 0.0123\n",
|
||||
"Step 640/1000, Loss: 0.0265\n",
|
||||
"Step 650/1000, Loss: 0.0133\n",
|
||||
"Step 660/1000, Loss: 0.0128\n",
|
||||
"Step 670/1000, Loss: 0.0107\n",
|
||||
"Step 680/1000, Loss: 0.0142\n",
|
||||
"Step 690/1000, Loss: 0.0202\n",
|
||||
"Step 700/1000, Loss: 0.0125\n",
|
||||
"Step 710/1000, Loss: 0.0107\n",
|
||||
"Step 720/1000, Loss: 0.0140\n",
|
||||
"Step 730/1000, Loss: 0.0195\n",
|
||||
"Step 740/1000, Loss: 0.0148\n",
|
||||
"Step 750/1000, Loss: 0.0109\n",
|
||||
"Step 760/1000, Loss: 0.0094\n",
|
||||
"Step 770/1000, Loss: 0.0121\n",
|
||||
"Step 780/1000, Loss: 0.0233\n",
|
||||
"Step 790/1000, Loss: 0.0151\n",
|
||||
"Step 800/1000, Loss: 0.0134\n",
|
||||
"Step 810/1000, Loss: 0.0117\n",
|
||||
"Step 820/1000, Loss: 0.0124\n",
|
||||
"Step 830/1000, Loss: 0.0221\n",
|
||||
"Step 840/1000, Loss: 0.0161\n",
|
||||
"Step 850/1000, Loss: 0.0136\n",
|
||||
"Step 860/1000, Loss: 0.0161\n",
|
||||
"Step 870/1000, Loss: 0.0194\n",
|
||||
"Step 880/1000, Loss: 0.0145\n",
|
||||
"Step 890/1000, Loss: 0.0149\n",
|
||||
"Step 900/1000, Loss: 0.0232\n",
|
||||
"Step 910/1000, Loss: 0.0166\n",
|
||||
"Step 920/1000, Loss: 0.0156\n",
|
||||
"Step 930/1000, Loss: 0.0276\n",
|
||||
"Step 940/1000, Loss: 0.0176\n",
|
||||
"Step 950/1000, Loss: 0.0152\n",
|
||||
"Step 960/1000, Loss: 0.0162\n",
|
||||
"Step 970/1000, Loss: 0.0143\n",
|
||||
"Step 980/1000, Loss: 0.0136\n",
|
||||
"Step 990/1000, Loss: 0.0117\n",
|
||||
"Total time: 67.25 s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -124,6 +130,7 @@
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(\"Using device:\", device)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Огромные параметры\n",
|
||||
"N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n",
|
||||
"batch_size = 16_384 # большой батч\n",
|
||||
@@ -170,9 +177,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (thesis)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "thesis"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
||||
Reference in New Issue
Block a user