1. Application scenarios
When doing classification tasks with PyTorch, the `torchvision.datasets.ImageFolder()` class is commonly used. However, the way it expects the data to be stored on disk does not always fit your own dataset. If you want to load your own data more flexibly, you should consider writing a custom Dataset class instead.
ImageFolder requires the data to be organized like this:
```
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
...
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
```
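For comparison, here is a minimal sketch of loading such a directory with `ImageFolder`; the `root/` path and the transform pipeline are assumptions for illustration.

```python
import torch
import torchvision as tv
import torchvision.transforms as transforms

# Assumed directory layout: root/<class_name>/<image files>
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# ImageFolder assigns an integer label to each subfolder automatically
dataset = tv.datasets.ImageFolder(root='root/', transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

print(dataset.classes)       # subfolder names in sorted order, e.g. ['cat', 'dog']
print(dataset.class_to_idx)  # e.g. {'cat': 0, 'dog': 1}
```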
2. Customize your own data loading method
Usually we don't want to copy or move the data around; it is better to simply record the image paths and fetch each image ourselves when the model needs it. To do that, we just inherit from Dataset and re-implement it as follows.
The general method can be divided into three steps:
- Write the image paths and labels into a text file (any format will do, as long as it is easy for us to parse); see the sketch after the example file below.
- Parse the file and store the (path, label) pairs in a list.
- Read each image and its label in the `__getitem__()` function and return them.
train.txt (first column: image path, second column: label)
```
root/dog/xxx.png 0
root/dog/xxy.png 0
root/dog/xxz.png 0
root/cat/123.png 1
root/cat/nsdf3.png 1
root/cat/asd932_.png 1
```
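As a sketch of the first step, the following script builds such a `train.txt` from an ImageFolder-style directory. The directory name `root/` and the class-to-label mapping are assumptions for illustration.

```python
import os

# Hypothetical class-to-label mapping matching the example layout above
class_to_label = {'dog': 0, 'cat': 1}
root = 'root/'

with open('train.txt', 'w') as f:
    for cls, label in class_to_label.items():
        cls_dir = os.path.join(root, cls)
        for name in sorted(os.listdir(cls_dir)):
            # One line per sample: "<image path> <label>"
            f.write('{} {}\n'.format(os.path.join(cls_dir, name), label))
```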
Specific Code Implementation
```python
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch

__all__ = ['MyDataset']


class MyDataset(Dataset):
    def __init__(self, dataPath, transform=None, target_transform=None):
        # Parse the text file: each line is "<image path> <label>"
        imgs = []
        with open(dataPath, 'r') as f:
            for line in f:
                line = line.rstrip()
                words = line.split()
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':
    transform_train = transforms.Compose([
        transforms.Resize(256),             # Resize the image
        transforms.RandomResizedCrop(224),  # Random crop and resize to 224x224
        transforms.RandomHorizontalFlip(),  # Horizontal flip with probability p
        transforms.RandomVerticalFlip(),    # Vertical flip with probability p
        transforms.ToTensor(),
    ])
    trainset = MyDataset(dataPath='train.txt', transform=transform_train)  # Training set
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True,
                                              num_workers=2, pin_memory=True)
    for step, (tx, ty) in enumerate(trainloader, 0):
        print('---test---', tx, ty)
```
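The same class can be reused for a validation set with deterministic preprocessing. A minimal sketch, assuming a second file `val.txt` in the same format and `MyDataset` as defined above:

```python
import torch
import torchvision.transforms as transforms
# MyDataset as defined above; 'val.txt' is a hypothetical file in the same format as train.txt

transform_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),   # deterministic center crop instead of random augmentation
    transforms.ToTensor(),
])
valset = MyDataset(dataPath='val.txt', transform=transform_val)
valloader = torch.utils.data.DataLoader(valset, batch_size=8, shuffle=False,
                                        num_workers=2, pin_memory=True)
```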
Note: this is a learning summary; if anything is wrong or inappropriate, criticism and corrections are welcome. Thank you.
Excellent reference links
[1]: https://github.com/tensor-yu/PyTorch_Tutorial
[2]: https://blog.csdn.net/u011995719/article/details/85102770