PyTorch learning notes 5: the torchvision library

Keywords: PyTorch, Deep Learning

1, Introduction

torchvision mainly handles image data and provides commonly used datasets, model architectures, image transformation functions, and more. torchvision is a separate package from PyTorch and has to be installed on its own.

torchvision mainly consists of the following four parts:

  • torchvision.models: classic deep learning network architectures, with pretrained weights, such as AlexNet, VGG, ResNet, Inception, etc.
  • torchvision.datasets: common datasets, all implemented as subclasses of torch.utils.data.Dataset, including MNIST, CIFAR10/100, ImageNet, COCO, etc.
  • torchvision.transforms: common data preprocessing operations that work on Tensor and PIL Image objects.
  • torchvision.utils: utility functions, such as saving a tensor to disk as an image file and arranging a mini-batch of images into a grid.

2, Installation

pip3 install torchvision 

The torchvision version must match the installed PyTorch version (and its CUDA build).
To check the installed versions of PyTorch and torchvision:

import torch
import torchvision

print(torch.__version__)

print(torchvision.__version__)

3, Examples of the main functions of torchvision

1. Loading models

(1) Loading pretrained models

Passing pretrained=True loads a model with pretrained weights. pretrained defaults to False, so omitting it is the same as passing False. (Since torchvision 0.13, pretrained is deprecated in favor of the weights parameter, e.g. models.resnet18(weights=models.ResNet18_Weights.DEFAULT).)

import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
densenet = models.densenet161()  # random weights: pretrained defaults to False

Expected input of the pretrained models:

  • A mini-batch of RGB images of shape (batch_size, 3, H, W), where H and W are expected to be at least 224.
  • Pixel values scaled to the range [0, 1] and then normalized with mean=[0.485, 0.456, 0.406] and standard deviation std=[0.229, 0.224, 0.225].

Calling state_dict() on a model returns a dictionary that maps each parameter and persistent buffer (for example, BatchNorm running statistics) to its tensor, which you can print or save.

import torchvision.models as models

vgg16 = models.vgg16(pretrained=True)
# Returns a dictionary with the whole state of the module: parameters and persistent buffers
pretrained_dict = vgg16.state_dict()

(2) Loading only the model structure, without pretrained parameters

If you only need the network structure and do not want to initialize it with pretrained parameters, set pretrained=False. You can then load parameters yourself, for example from a previously downloaded weight file:

import torch
import torchvision.models as models

# Build the model structure with randomly initialized weights
resnet18 = models.resnet18(pretrained=False)
# Load previously downloaded pretrained parameters into resnet18
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))

(4) Loading part of a pretrained model

In practice, a pretrained model is often adapted, that is, some of its layers are modified.
The example below drops the pretrained parameters whose keys do not exist in the new model and loads the rest. For this to work, a modified layer in the new model must be named differently from the corresponding layer of the original model. For example, the last layer of a ResNet in PyTorch is named fc, so a modified last layer should not reuse this name but could be called, say, fc_; otherwise the pretrained fc parameters would be kept and load_state_dict would fail on the shape mismatch.

import torchvision.models as models

# Pretrained source model
resnet152 = models.resnet152(pretrained=True)
# Extract its parameters
pretrained_dict = resnet152.state_dict()
# The pretrained parameters can also be downloaded through model_zoo
# pretrained_dict = model_zoo.load_url(model_urls['resnet152'])

# Target model: the same architecture here, but in practice your modified model
model = models.resnet152()
model_dict = model.state_dict()
# Remove the keys of pretrained_dict that do not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching entries of model_dict
model_dict.update(pretrained_dict)
# Load the merged state_dict
model.load_state_dict(model_dict)

(5) Adapting the model

Some layers of a pretrained model cannot be used as-is and need small modifications. For example, the last fully connected layer of ResNet outputs 1000 classes while we may have only 21, or the first convolution of ResNet expects 3 input channels while our images may have 4. Such layers can be replaced as follows:

from torchvision import models
from torch import nn

resnet = models.resnet50(pretrained=True)
# Modify the number of input channels of the first convolution (3 -> 4)
resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Replace the classifier: 21 is the number of classes
# (2048 is the input size of fc in resnet50/101/152; it is 512 in resnet18/34)
resnet.fc = nn.Linear(2048, 21)

# Load the pretrained model; weights are cached under ~/.torch/models/ in old
# versions (~/.cache/torch/hub/checkpoints/ in recent ones)
resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# The default is 1000-class ImageNet classification; replace the last fully connected layer with a 10-class one
resnet34.fc = nn.Linear(512, 10)

(6) Saving and loading your own (non-pretrained) models

This part is plain PyTorch and does not depend on torchvision.

3.1.6.1 save and load the whole model

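import torch

# Save the entire model object (structure and parameters) with pickle
torch.save(model_object, 'resnet.pth')
# weights_only=False is required on PyTorch 2.6+ to unpickle a whole model
model = torch.load('resnet.pth', weights_only=False)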

3.1.6.2 Save and load the network structure and parameters separately

# Save the state_dict of my_resnet to my_resnet.pth
torch.save(my_resnet.state_dict(), "my_resnet.pth")
# Load the parameters from my_resnet.pth into a model with the same structure
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

2. Loading datasets

Every dataset in torchvision.datasets is a subclass of torch.utils.data.Dataset, so it can be passed to torch.utils.data.DataLoader, which batches and shuffles samples and can load them in parallel worker processes (num_workers).

Official documentation:
https://pytorch.org/vision/stable/datasets.html#

(1) Example: loading MNIST

from torchvision import datasets
dataset = datasets.MNIST('data/', download=True, train=False, transform=None)

(2) Example: loading fashion MNIST

from torchvision import datasets
dataset = datasets.FashionMNIST('data/', download=True, train=False, transform=None)

(3) Loading images from a folder with ImageFolder

datasets.ImageFolder loads images from a directory tree in which each subdirectory of root holds the images of one class.

ImageFolder(root,transform=None,target_transform=None,loader=default_loader)

Parameter Description:

  • root: root directory in which to look for images.
  • transform: a function/transform that takes a PIL Image and returns a transformed version. Several transforms can be combined with transforms.Compose (see "3. transforms" below).
  • target_transform: a function/transform applied to the label.
  • loader: the function used to load an image given its path. By default it reads the file as a PIL Image.

ImageFolder returns a Dataset whose items are (image, class_index) pairs; wrap it in torch.utils.data.DataLoader to get batches of tensors.
Example:

import os

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),  # crop a region of random size and aspect ratio, resize it to 224x224
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.ToTensor(),  # PIL Image / numpy.ndarray (H, W, C) in [0, 255] -> FloatTensor (C, H, W) in [0.0, 1.0]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per channel: (x - mean) / std
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Image data directory: must contain train/ and val/ subdirectories
data_root = ''
image_datasets = {x: datasets.ImageFolder(os.path.join(data_root, x),
                                          data_transforms[x]) for x in ['train', 'val']}
# Wrap each dataset in a DataLoader that yields batches of (image, label) tensors
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=10,
                                              shuffle=True,
                                              num_workers=4) for x in ['train', 'val']}

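Iterating over dataloaders['train'] or dataloaders['val'] yields (images, labels) batches of tensors, with images of shape (10, 3, 224, 224), which can be passed directly to the model; wrapping tensors in Variable has not been necessary since PyTorch 0.4.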

3. transforms

torchvision.transforms contains image preprocessing operations, which can be chained and applied in sequence with torchvision.transforms.Compose.

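These operations include the following (the __all__ list of the transforms module in an older torchvision release; Scale and RandomSizedCrop are deprecated aliases of Resize and RandomResizedCrop and are not available in recent releases):

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
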
  • Compose(transforms): chains several transforms and applies them in order.
  • ToTensor(): converts a PIL Image or numpy.ndarray (H x W x C, values in [0, 255]) to a FloatTensor of shape (C x H x W) with values in [0.0, 1.0].
  • Normalize(mean, std): normalizes each channel of a tensor image as (input - mean) / std.
  • Resize(size): resizes the input image. size can be an int (the shorter side is matched to it, keeping the aspect ratio) or an (h, w) tuple.
  • CenterCrop(size): crops the center of the image to the given size.
  • RandomCrop(size, padding=None): crops a region of the given size at a random location, after optional padding.
  • RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3/4, 4/3)): crops a region of random size and aspect ratio and resizes it to size.
  • RandomHorizontalFlip(p=0.5): flips the image horizontally with probability p.
  • RandomVerticalFlip(p=0.5): flips the image vertically with probability p.
  • ToPILImage(): converts a Tensor or numpy.ndarray to a PIL Image.
  • FiveCrop(size): crops the four corners and the center of the image, returning a tuple of five images.
  • Pad(padding, fill=0, padding_mode='constant'): pads the image borders.
  • RandomAffine(degrees, translate=None, scale=None): applies a random affine transformation that keeps the image center fixed.
  • RandomApply(transforms, p=0.5): applies the given list of transforms with probability p, otherwise leaves the input unchanged.

Posted by justinede on Tue, 12 Oct 2021 01:21:30 -0700