PyTorch: rewriting the Dataset class to load your own data

Keywords: Python, GitHub

1. Application scenarios

The torchvision.datasets.ImageFolder() class (tv.datasets.ImageFolder()) is what is usually used to load data when doing classification with PyTorch. However, it only works when the data on disk follows the folder-per-class layout it expects. If your data is stored differently, or you simply want full control over how samples are loaded, you should consider rewriting the Dataset class yourself.

ImageFolder expects the data to be organized with one folder per class, for example (a minimal loading sketch follows this layout):
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        ... ...
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
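
For comparison, here is a minimal sketch of loading such a folder layout with ImageFolder; the root path ./data/train and the transform are assumptions for illustration:

import torchvision as tv
import torchvision.transforms as transforms
import torch

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
trainset = tv.datasets.ImageFolder(root='./data/train', transform=transform)  # assumed layout: ./data/train/<class>/<image>
print(trainset.classes)  # folder names become the classes, e.g. ['cat', 'dog']
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)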

2. Customize your own data loading method

Generally, we do not want to copy data back and forth on disk. It is better to record only the image paths and labels, and fetch each sample ourselves when the model needs it. To do that, we simply inherit from Dataset and re-implement the methods below.

The general approach can be divided into three steps:
  1. Write each image path and its label into a text file (any plain-text format works, as long as it is easy to parse later; a small script sketch for this step follows the train.txt sample below).
  2. Parse that file and store the (path, label) pairs in a list.
  3. Have __getitem__() read one sample and its label by index and return them.
train.txt (first column: image path, second column: label)
        root/dog/xxx.png	0
        root/dog/xxy.png	0
        root/dog/xxz.png	0
        root/cat/123.png	1
        root/cat/nsdf3.png	1
        root/cat/asd932_.png	1
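
As a rough sketch of step 1, the script below walks an ImageFolder-style directory and writes one "path<TAB>label" line per image into train.txt; the root directory and the class-to-index mapping are assumptions for illustration:

import os

root = './data/train'                    # assumed root directory
class_to_idx = {'dog': 0, 'cat': 1}      # assumed class -> label mapping

with open('train.txt', 'w') as f:
    for cls, idx in class_to_idx.items():
        cls_dir = os.path.join(root, cls)
        for name in sorted(os.listdir(cls_dir)):
            f.write('{}\t{}\n'.format(os.path.join(cls_dir, name), idx))
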
Specific Code Implementation
#!/usr/bin/python
# -*- coding: UTF-8 -*-

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch


__all__ = ['MyDataset']


class MyDataset(Dataset):

    def __init__(self, dataPath, transform=None, target_transform=None):
        imgs = []
        with open(dataPath, 'r') as f:      # each line: "<image path>\t<label>"
            for line in f:
                words = line.rstrip().split()
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs                    # list of (path, label) tuples
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')  # load the image from its path
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':

    transform_train = transforms.Compose([transforms.Resize(256),  # resize the shorter side to 256
                                          transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
                                          transforms.RandomHorizontalFlip(),  # random horizontal flip (p=0.5)
                                          transforms.RandomVerticalFlip(),  # random vertical flip (p=0.5)
                                          transforms.ToTensor(),])
    trainset = MyDataset(dataPath='train.txt', transform=transform_train)  # Training set
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    for step, (tx, ty) in enumerate(trainloader, 0):
        
        print('---test---', tx, ty)
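
Because MyDataset only needs a path file and a transform, the same class can also serve a validation split with deterministic preprocessing. Continuing from the imports and class above, and assuming a hypothetical val.txt in the same path-and-label format:

transform_val = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),  # deterministic crop, no augmentation
                                    transforms.ToTensor()])
valset = MyDataset(dataPath='val.txt', transform=transform_val)
valloader = torch.utils.data.DataLoader(valset, batch_size=8, shuffle=False, num_workers=2)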

Note: these are study notes; if anything is wrong or could be improved, criticism and corrections are welcome. Thank you.

Excellent reference links

[1]: https://github.com/tensor-yu/PyTorch_Tutorial
[2]: https://blog.csdn.net/u011995719/article/details/85102770
