Files
Thesis/src/scripts/download_dataset.py
T
2026-05-06 19:48:18 +00:00

172 lines
5.6 KiB
Python

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import json
from PIL import Image
class EmoSet(Dataset):
    """PyTorch ``Dataset`` over one split of an EmoSet directory.

    ``data_root`` must contain ``info.json`` (label vocabularies, read by
    :meth:`get_info`) and one ``{phase}.json`` file per split.  Each entry of
    a split file is read positionally as
    ``[emotion_label, image_id, image_relpath, annotation_relpath]``; the two
    relative paths are resolved against ``data_root``.

    Every sample is a dict with ``image_id``, ``image`` (the transformed
    image), ``emotion_label_idx`` and one ``<attribute>_label_idx`` entry per
    attribute in ``ATTRIBUTES_MULTI_CLASS`` / ``ATTRIBUTES_MULTI_LABEL``.
    """

    # Attributes with at most one label per image; a missing one becomes -1.
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes with a set of labels per image, encoded as a multi-hot tensor.
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Vocabulary size per attribute.  Only the multi-label entries are read in
    # this class (to size the multi-hot tensor in __getitem__).
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load the label vocabulary and the index of one split.

        Args:
            data_root: Directory holding ``info.json`` and ``{phase}.json``.
            num_emotion_classes: ``8`` keeps the emotion vocabulary from
                ``info.json``; ``2`` collapses it to positive/negative.
            phase: ``'train'``, ``'val'`` or ``'test'``; selects both the
                split file and the image transform.

        Raises:
            AssertionError: if ``num_emotion_classes`` or ``phase`` is invalid.
            KeyError: if a split entry's emotion label is not in the
                emotion ``label2idx`` vocabulary.
        """
        assert num_emotion_classes in (8, 2)
        assert phase in ('train', 'val', 'test')
        self.transforms_dict = self.get_data_transforms()
        self.info = self.get_info(data_root, num_emotion_classes)
        # phase has already been validated above, so a direct lookup suffices.
        self.transform = self.transforms_dict[phase]

        emotion_label2idx = self.info['emotion']['label2idx']
        entries = self._load_json(os.path.join(data_root, f'{phase}.json'))
        # Store [emotion index, image id, absolute image path, absolute
        # annotation path]; any extra fields in an entry are ignored.
        self.data_store = [
            [
                emotion_label2idx[emotion_label],
                image_id,
                os.path.join(data_root, image_relpath),
                os.path.join(data_root, annotation_relpath),
            ]
            for emotion_label, image_id, image_relpath, annotation_relpath, *_ in entries
        ]

    @staticmethod
    def _load_json(path):
        """Parse the UTF-8 JSON file at ``path``, closing the handle afterwards.

        The encoding is explicit because ``open`` otherwise falls back to the
        platform's locale encoding, which on Windows is usually not UTF-8.
        """
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    @classmethod
    def get_data_transforms(cls):
        """Return the torchvision preprocessing pipeline for each phase.

        ``'train'`` augments with a random resized 224x224 crop and a random
        horizontal flip.  ``'val'`` and ``'test'`` share a deterministic
        pipeline: ``Resize(224)`` followed by a 224x224 center crop.  Every
        pipeline ends with ``ToTensor`` and per-channel normalization using
        the commonly used ImageNet mean/std constants.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        eval_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        return {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]),
            # The evaluation transforms are deterministic, so val and test
            # can share one pipeline object.
            'val': eval_transform,
            'test': eval_transform,
        }

    def get_info(self, data_root, num_emotion_classes):
        """Load ``info.json`` from ``data_root`` and adapt the emotion labels.

        Args:
            data_root: Directory containing ``info.json``.
            num_emotion_classes: ``8`` returns the file's contents unchanged.
                ``2`` replaces ``info['emotion']`` with a binary vocabulary:
                amusement/awe/contentment/excitement -> 0 ('positive') and
                anger/disgust/fear/sadness -> 1 ('negative').

        Returns:
            The parsed ``info.json`` dict (emotion entry possibly replaced).

        Raises:
            AssertionError: if ``num_emotion_classes`` is not 8 or 2.
            NotImplementedError: same condition when asserts are disabled.
        """
        assert num_emotion_classes in (8, 2)
        info = self._load_json(os.path.join(data_root, 'info.json'))
        if num_emotion_classes == 8:
            return info
        if num_emotion_classes != 2:
            raise NotImplementedError
        info['emotion'] = {
            'label2idx': {
                'amusement': 0,
                'awe': 0,
                'contentment': 0,
                'excitement': 0,
                'anger': 1,
                'disgust': 1,
                'fear': 1,
                'sadness': 1,
            },
            'idx2label': {
                '0': 'positive',
                '1': 'negative',
            },
        }
        return info

    def load_image_by_path(self, path):
        """Open the image at ``path`` as RGB and apply this split's transform.

        ``convert`` returns a fully loaded copy, so the source file can be
        closed immediately instead of waiting for garbage collection.
        """
        with Image.open(path) as raw:
            rgb = raw.convert('RGB')
        return self.transform(rgb)

    def load_annotation_by_path(self, path):
        """Return the parsed per-image annotation JSON stored at ``path``."""
        return self._load_json(path)

    def __getitem__(self, item):
        """Load sample ``item``: transformed image plus all label indices.

        Multi-class attributes absent from the annotation get ``-1``, so every
        sample carries the same keys.  Multi-label attributes become a float
        multi-hot tensor of length ``NUM_CLASSES[attribute]``.
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[item]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            attribute_label_idx = -1
            if attribute in annotation_data:
                # label2idx is keyed by strings; str() presumably handles
                # attributes whose annotation values are non-string — confirm.
                attribute_label_idx = self.info[attribute]['label2idx'][str(annotation_data[attribute])]
            data[f'{attribute}_label_idx'] = attribute_label_idx

        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            # Only 'object' is handled as multi-label; guard against silently
            # mis-encoding a new attribute added to the list.
            assert attribute == 'object'
            label2idx = self.info[attribute]['label2idx']
            multi_hot = torch.zeros(self.NUM_CLASSES[attribute])
            for label in annotation_data.get(attribute, ()):
                multi_hot[label2idx[label]] = 1
            data[f'{attribute}_label_idx'] = multi_hot

        return data

    def __len__(self):
        """Return the number of entries in this split."""
        return len(self.data_store)
if __name__ == '__main__':
    # Smoke test: build one split and iterate it once end to end, so every
    # image and annotation file in the split is opened at least once.
    import argparse

    parser = argparse.ArgumentParser(
        description='Iterate once over an EmoSet split as a data-loading smoke test.')
    parser.add_argument(
        '--data-root',
        default=r'F:\common_file_system\EmoSet\EmoSet_v5_划分train-test-val',
        help='directory containing info.json and the {phase}.json split files')
    parser.add_argument('--num-emotion-classes', type=int, choices=(8, 2), default=8)
    parser.add_argument('--phase', choices=('train', 'val', 'test'), default='train')
    parser.add_argument('--batch-size', type=int, default=16)
    args = parser.parse_args()

    dataset = EmoSet(
        data_root=args.data_root,
        num_emotion_classes=args.num_emotion_classes,
        phase=args.phase,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    num_batches = sum(1 for _ in dataloader)
    print(f'Loaded {len(dataset)} samples in {num_batches} batches '
          f'from the {args.phase!r} split of {args.data_root}')