"""EmoSet dataset loader: images with emotion labels and per-image attribute labels."""
import json
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class EmoSet(Dataset):
    """Map-style dataset over one split ('train', 'val' or 'test') of EmoSet.

    Files read from ``data_root``:

    * ``info.json`` -- label tables; the code reads ``info['emotion']['label2idx']``
      and ``info[<attribute>]['label2idx']`` from it.
    * ``<phase>.json`` -- a list of 4-item records
      ``[emotion_label, image_id, image_rel_path, annotation_rel_path]``;
      both paths are joined onto ``data_root``.

    Each sample is a dict with ``image_id``, ``image`` (the transformed image),
    ``emotion_label_idx`` and one ``<attribute>_label_idx`` entry per attribute
    listed in ``ATTRIBUTES_MULTI_CLASS`` / ``ATTRIBUTES_MULTI_LABEL``.
    """

    # Attributes that carry exactly one label per image (encoded as an index).
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes that may carry several labels per image (encoded as multi-hot).
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Class counts per attribute. Only the multi-label entries are read in this
    # file (to size the multi-hot vector); the rest are presumably for consumers
    # building classifier heads -- verify against callers.
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load the label tables and the record list of one split.

        Args:
            data_root: Directory holding ``info.json``, ``<phase>.json`` and the
                files referenced by the records.
            num_emotion_classes: ``8`` to use the emotion table from
                ``info.json`` as-is, ``2`` for the positive/negative grouping.
            phase: ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            AssertionError: On an unsupported ``num_emotion_classes`` or ``phase``.
            NotImplementedError: Same conditions when asserts are stripped (``-O``).
        """
        super().__init__()
        assert num_emotion_classes in (8, 2), 'num_emotion_classes must be 8 or 2'
        assert phase in ('train', 'val', 'test'), "phase must be 'train', 'val' or 'test'"
        self.transforms_dict = self.get_data_transforms()

        self.info = self.get_info(data_root, num_emotion_classes)

        # Only reachable when asserts are disabled; keeps the failure explicit.
        if phase not in self.transforms_dict:
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        emotion_label2idx = self.info['emotion']['label2idx']
        records = self._read_json(os.path.join(data_root, f'{phase}.json'))
        self.data_store = [
            [
                emotion_label2idx[record[0]],
                record[1],
                os.path.join(data_root, record[2]),
                os.path.join(data_root, record[3]),
            ]
            for record in records
        ]

    @staticmethod
    def _read_json(path):
        """Parse the JSON document stored at ``path``, closing the file afterwards."""
        # Binary mode lets the json module detect UTF-8/16/32 by itself instead
        # of decoding with the platform's locale encoding.
        with open(path, 'rb') as fp:
            return json.load(fp)

    @classmethod
    def get_data_transforms(cls):
        """Return the preprocessing pipelines keyed by 'train', 'val' and 'test'.

        'train' applies a random resized crop to 224 and a random horizontal
        flip; 'val' and 'test' share the same deterministic resize + center
        crop. All three end with ``ToTensor`` and the same normalization.
        """
        def to_normalized_tensor():
            # Fresh transform objects per pipeline, as in the original layout.
            # The mean/std values are the ones commonly used with
            # ImageNet-pretrained torchvision models.
            return [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]

        def eval_pipeline():
            return transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                *to_normalized_tensor(),
            ])

        return {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                *to_normalized_tensor(),
            ]),
            'val': eval_pipeline(),
            'test': eval_pipeline(),
        }

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` and, for the 2-class setup, replace its emotion table.

        Args:
            data_root: Directory containing ``info.json``.
            num_emotion_classes: ``8`` returns the file content unchanged; ``2``
                overwrites ``info['emotion']`` so the eight emotion names map to
                0 (positive) or 1 (negative).

        Returns:
            The parsed (and possibly patched) info dict.

        Raises:
            AssertionError: On an unsupported ``num_emotion_classes``.
            NotImplementedError: Same condition when asserts are stripped (``-O``).
        """
        assert num_emotion_classes in (8, 2), 'num_emotion_classes must be 8 or 2'
        info = self._read_json(os.path.join(data_root, 'info.json'))
        if num_emotion_classes == 8:
            return info
        if num_emotion_classes != 2:
            raise NotImplementedError

        info['emotion'] = {
            'label2idx': {
                'amusement': 0,
                'awe': 0,
                'contentment': 0,
                'excitement': 0,
                'anger': 1,
                'disgust': 1,
                'fear': 1,
                'sadness': 1,
            },
            # String keys, as in the original table.
            'idx2label': {
                '0': 'positive',
                '1': 'negative',
            },
        }
        return info

    def load_image_by_path(self, path):
        """Open the image at ``path``, convert it to RGB and apply ``self.transform``."""
        # convert() returns an independent image, so the source file can be
        # closed deterministically before the transform runs.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        return self.transform(image)

    def load_annotation_by_path(self, path):
        """Return the parsed JSON annotation stored at ``path``."""
        return self._read_json(path)

    def __getitem__(self, item):
        """Build the sample dict for record number ``item``.

        Multi-class attributes become the label index from
        ``info[attribute]['label2idx']`` or ``-1`` when the annotation lacks the
        attribute. Multi-label attributes become a float multi-hot tensor of
        length ``NUM_CLASSES[attribute]`` (all zeros when absent).
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            if attribute in annotation_data:
                # Annotation values are stringified before the table lookup.
                label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            else:
                label_idx = -1  # sentinel: attribute not annotated for this image
            data[f'{attribute}_label_idx'] = label_idx

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            multi_hot = torch.zeros(self.NUM_CLASSES[attribute])
            if attribute in annotation_data:
                label2idx = self.info[attribute]['label2idx']
                for label in annotation_data[attribute]:
                    multi_hot[label2idx[label]] = 1
            data[f'{attribute}_label_idx'] = multi_hot

        return data

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.data_store)


if __name__ == '__main__':
    # Smoke test: iterate the whole training split once through a DataLoader.
    # An optional first command-line argument overrides the default data root.
    import sys

    default_data_root = r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val'
    data_root = sys.argv[1] if len(sys.argv) > 1 else default_data_root
    num_emotion_classes = 8
    phase = 'train'

    dataset = EmoSet(
        data_root=data_root,
        num_emotion_classes=num_emotion_classes,
        phase=phase,
    )

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Each batch is a dict with 'image_id', 'image', 'emotion_label_idx' and one
    # '<attribute>_label_idx' entry per attribute (scene, facial_expression,
    # human_action, brightness, colorfulness, object).
    for i, data in enumerate(dataloader):
        pass
200/1000, Loss: 0.1671\n", - "Step 210/1000, Loss: 0.1725\n", - "Step 220/1000, Loss: 0.1364\n", - "Step 230/1000, Loss: 0.1171\n", - "Step 240/1000, Loss: 0.1252\n", - "Step 250/1000, Loss: 0.1032\n", - "Step 260/1000, Loss: 0.0834\n", - "Step 270/1000, Loss: 0.0892\n", - "Step 280/1000, Loss: 0.0858\n", - "Step 290/1000, Loss: 0.0685\n", - "Step 300/1000, Loss: 0.0546\n", - "Step 310/1000, Loss: 0.0829\n", - "Step 320/1000, Loss: 0.0516\n", - "Step 330/1000, Loss: 0.0446\n", - "Step 340/1000, Loss: 0.0569\n", - "Step 350/1000, Loss: 0.0406\n", - "Step 360/1000, Loss: 0.0373\n", - "Step 370/1000, Loss: 0.0377\n", - "Step 380/1000, Loss: 0.0395\n", - "Step 390/1000, Loss: 0.0299\n", - "Step 400/1000, Loss: 0.0311\n", - "Step 410/1000, Loss: 0.0272\n", - "Step 420/1000, Loss: 0.0223\n", - "Step 430/1000, Loss: 0.0261\n", - "Step 440/1000, Loss: 0.0230\n", - "Step 450/1000, Loss: 0.0191\n", - "Step 460/1000, Loss: 0.0196\n", - "Step 470/1000, Loss: 0.0219\n", - "Step 480/1000, Loss: 0.0193\n", - "Step 490/1000, Loss: 0.0244\n", - "Step 500/1000, Loss: 0.0231\n", - "Step 510/1000, Loss: 0.0167\n", - "Step 520/1000, Loss: 0.0167\n", - "Step 530/1000, Loss: 0.0252\n", - "Step 540/1000, Loss: 0.0215\n", - "Step 550/1000, Loss: 0.0190\n", - "Step 560/1000, Loss: 0.0175\n", - "Step 570/1000, Loss: 0.0204\n", - "Step 580/1000, Loss: 0.0184\n", - "Step 590/1000, Loss: 0.0159\n", - "Step 600/1000, Loss: 0.0334\n", - "Step 610/1000, Loss: 0.0177\n", - "Step 620/1000, Loss: 0.0173\n", - "Step 630/1000, Loss: 0.0215\n", - "Step 640/1000, Loss: 0.0180\n", - "Step 650/1000, Loss: 0.0165\n", - "Step 660/1000, Loss: 0.0194\n", - "Step 670/1000, Loss: 0.0244\n", - "Step 680/1000, Loss: 0.0193\n", - "Step 690/1000, Loss: 0.0169\n", - "Step 700/1000, Loss: 0.0195\n", - "Step 710/1000, Loss: 0.0162\n", - "Step 720/1000, Loss: 0.0355\n", - "Step 730/1000, Loss: 0.0217\n", - "Step 740/1000, Loss: 0.0158\n", - "Step 750/1000, Loss: 0.0145\n", - "Step 760/1000, Loss: 0.0135\n", - "Step 
770/1000, Loss: 0.0240\n", - "Step 780/1000, Loss: 0.0182\n", - "Step 790/1000, Loss: 0.0233\n", - "Step 800/1000, Loss: 0.0181\n", - "Step 810/1000, Loss: 0.0161\n", - "Step 820/1000, Loss: 0.0187\n", - "Step 830/1000, Loss: 0.0149\n", - "Step 840/1000, Loss: 0.0198\n", - "Step 850/1000, Loss: 0.0270\n", - "Step 860/1000, Loss: 0.0176\n", - "Step 870/1000, Loss: 0.0168\n", - "Step 880/1000, Loss: 0.0230\n", - "Step 890/1000, Loss: 0.0282\n", - "Step 900/1000, Loss: 0.0193\n", - "Step 910/1000, Loss: 0.0193\n", - "Step 920/1000, Loss: 0.0153\n", - "Step 930/1000, Loss: 0.0157\n", - "Step 940/1000, Loss: 0.0196\n", - "Step 950/1000, Loss: 0.0168\n", - "Step 960/1000, Loss: 0.0132\n", - "Step 970/1000, Loss: 0.0124\n", - "Step 980/1000, Loss: 0.0224\n", - "Step 990/1000, Loss: 0.0164\n", - "Total time: 65.72 s\n" + "Using device: cuda\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0/1000, Loss: 1.0013\n", + "Step 10/1000, Loss: 1.0088\n", + "Step 20/1000, Loss: 0.9956\n", + "Step 30/1000, Loss: 0.9781\n", + "Step 40/1000, Loss: 0.9613\n", + "Step 50/1000, Loss: 0.9313\n", + "Step 60/1000, Loss: 0.8927\n", + "Step 70/1000, Loss: 0.8503\n", + "Step 80/1000, Loss: 0.7537\n", + "Step 90/1000, Loss: 0.6689\n", + "Step 100/1000, Loss: 0.6063\n", + "Step 110/1000, Loss: 0.5172\n", + "Step 120/1000, Loss: 0.4592\n", + "Step 130/1000, Loss: 0.4044\n", + "Step 140/1000, Loss: 0.3610\n", + "Step 150/1000, Loss: 0.3175\n", + "Step 160/1000, Loss: 0.2825\n", + "Step 170/1000, Loss: 0.2560\n", + "Step 180/1000, Loss: 0.2360\n", + "Step 190/1000, Loss: 0.2203\n", + "Step 200/1000, Loss: 0.1930\n", + "Step 210/1000, Loss: 0.1854\n", + "Step 220/1000, Loss: 0.1723\n", + "Step 230/1000, Loss: 0.1546\n", + "Step 240/1000, Loss: 0.1386\n", + "Step 250/1000, Loss: 0.1271\n", + "Step 260/1000, Loss: 0.1109\n", + "Step 270/1000, Loss: 0.1032\n", + "Step 280/1000, Loss: 0.0899\n", + "Step 290/1000, Loss: 0.0807\n", + "Step 300/1000, Loss: 0.0750\n", + 
"Step 310/1000, Loss: 0.0813\n", + "Step 320/1000, Loss: 0.0612\n", + "Step 330/1000, Loss: 0.0544\n", + "Step 340/1000, Loss: 0.0552\n", + "Step 350/1000, Loss: 0.0446\n", + "Step 360/1000, Loss: 0.0403\n", + "Step 370/1000, Loss: 0.0350\n", + "Step 380/1000, Loss: 0.0612\n", + "Step 390/1000, Loss: 0.0364\n", + "Step 400/1000, Loss: 0.0322\n", + "Step 410/1000, Loss: 0.0302\n", + "Step 420/1000, Loss: 0.0519\n", + "Step 430/1000, Loss: 0.0319\n", + "Step 440/1000, Loss: 0.0260\n", + "Step 450/1000, Loss: 0.0208\n", + "Step 460/1000, Loss: 0.0409\n", + "Step 470/1000, Loss: 0.0291\n", + "Step 480/1000, Loss: 0.0234\n", + "Step 490/1000, Loss: 0.0194\n", + "Step 500/1000, Loss: 0.0274\n", + "Step 510/1000, Loss: 0.0231\n", + "Step 520/1000, Loss: 0.0199\n", + "Step 530/1000, Loss: 0.0154\n", + "Step 540/1000, Loss: 0.0278\n", + "Step 550/1000, Loss: 0.0185\n", + "Step 560/1000, Loss: 0.0180\n", + "Step 570/1000, Loss: 0.0152\n", + "Step 580/1000, Loss: 0.0132\n", + "Step 590/1000, Loss: 0.0111\n", + "Step 600/1000, Loss: 0.0396\n", + "Step 610/1000, Loss: 0.0179\n", + "Step 620/1000, Loss: 0.0148\n", + "Step 630/1000, Loss: 0.0123\n", + "Step 640/1000, Loss: 0.0265\n", + "Step 650/1000, Loss: 0.0133\n", + "Step 660/1000, Loss: 0.0128\n", + "Step 670/1000, Loss: 0.0107\n", + "Step 680/1000, Loss: 0.0142\n", + "Step 690/1000, Loss: 0.0202\n", + "Step 700/1000, Loss: 0.0125\n", + "Step 710/1000, Loss: 0.0107\n", + "Step 720/1000, Loss: 0.0140\n", + "Step 730/1000, Loss: 0.0195\n", + "Step 740/1000, Loss: 0.0148\n", + "Step 750/1000, Loss: 0.0109\n", + "Step 760/1000, Loss: 0.0094\n", + "Step 770/1000, Loss: 0.0121\n", + "Step 780/1000, Loss: 0.0233\n", + "Step 790/1000, Loss: 0.0151\n", + "Step 800/1000, Loss: 0.0134\n", + "Step 810/1000, Loss: 0.0117\n", + "Step 820/1000, Loss: 0.0124\n", + "Step 830/1000, Loss: 0.0221\n", + "Step 840/1000, Loss: 0.0161\n", + "Step 850/1000, Loss: 0.0136\n", + "Step 860/1000, Loss: 0.0161\n", + "Step 870/1000, Loss: 0.0194\n", + 
"Step 880/1000, Loss: 0.0145\n", + "Step 890/1000, Loss: 0.0149\n", + "Step 900/1000, Loss: 0.0232\n", + "Step 910/1000, Loss: 0.0166\n", + "Step 920/1000, Loss: 0.0156\n", + "Step 930/1000, Loss: 0.0276\n", + "Step 940/1000, Loss: 0.0176\n", + "Step 950/1000, Loss: 0.0152\n", + "Step 960/1000, Loss: 0.0162\n", + "Step 970/1000, Loss: 0.0143\n", + "Step 980/1000, Loss: 0.0136\n", + "Step 990/1000, Loss: 0.0117\n", + "Total time: 67.25 s\n" ] } ], @@ -124,6 +130,7 @@ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(\"Using device:\", device)\n", "\n", + "\n", "# Огромные параметры\n", "N, D_in, H1, H2, H3, D_out = 300_000, 4096, 2048, 1024, 512, 10\n", "batch_size = 16_384 # большой батч\n", @@ -170,9 +177,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (thesis)", + "display_name": ".venv", "language": "python", - "name": "thesis" + "name": "python3" }, "language_info": { "codemirror_mode": {