Original: http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler
DataBase (IMDB) + DataSet + Sampler / BatchSampler = DataLoader

from torch.utils.data import *
DataBase
Image DataBase, short for IMDB, refers to the data information stored in files.
The file formats can vary: for example, xml, yaml, json, or sql.
VOC is in XML format and COCO is in JSON format.
Constructing an IMDB is the process of parsing these files and building a data index.
The result is usually a Python list, which makes subsequent iteration convenient.
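For illustration, constructing an IMDB from VOC-style XML annotations is just a matter of walking the annotation files and collecting one record per sample into a list. The sketch below is an assumption-laden example: the directory layout, tag names, and the build_imdb helper are hypothetical, not part of any library.

import os
import xml.etree.ElementTree as ET

def build_imdb(ann_dir):
    """Parse every VOC-style .xml annotation under ann_dir into a plain Python list."""
    imdb = []
    for fname in sorted(os.listdir(ann_dir)):
        if not fname.endswith('.xml'):
            continue
        root = ET.parse(os.path.join(ann_dir, fname)).getroot()
        imdb.append({
            'filename': root.findtext('filename'),                          # image this record refers to
            'objects': [o.findtext('name') for o in root.iter('object')],   # object class names
        })
    return imdb   # list index -> record, ready to be wrapped by a Dataset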
DataSet
DataSet: provides single-item or slice access to the data, built on top of the IMDB.
In other words, it defines how objects in the database are indexed, i.e. how single-item or slice indexing is realized.
In short, DataSet turns the data into an indexable object, An Indexable Object, through __getitem__.
That is, given an index, it defines how to access a single item or a slice; which one you get depends on whether the index is a single value or a list.
The PyTorch source code is as follows:
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    # Define a single-item/slice access method, namely dataItem = dataset[index]
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
Based on the Dataset base class and an IMDB base class, there are two ways to customize a dataset.
# Method 1: single inheritance; the IMDB is fed in as a parameter and wrapped a second time
class XxDataset(Dataset):
    imdb = IMDB()
    pass

# Method 2: double inheritance
class XxDataset(IMDB, Dataset):
    pass
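As a minimal sketch of Method 1, a Dataset can wrap a pre-built IMDB list and expose it through __getitem__ and __len__; the record fields 'data' and 'label' below are illustrative assumptions, not a fixed format.

import torch
from torch.utils.data import Dataset

class XxDataset(Dataset):
    def __init__(self, imdb):
        self.imdb = imdb                          # the pre-built IMDB, e.g. a list of dicts

    def __getitem__(self, index):
        record = self.imdb[index]                 # single-item access driven by the index
        data = torch.as_tensor(record['data'])
        label = torch.as_tensor(record['label'])
        return data, label

    def __len__(self):
        return len(self.imdb)

# usage: dataset = XxDataset(imdb); data, label = dataset[0]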
Sampler & BatchSampler
In practical applications, data are not necessarily accessed in sequential order; they may be accessed in a randomly shuffled order or with random weighting.
Reading data according to a specific rule is therefore a sampling operation, which requires defining a sampler: Sampler.
In addition, data may not be read one item at a time but in batches, i.e. batch sampling, which requires defining a batch sampler: BatchSampler.
Therefore, the single-item access provided by Dataset alone is not enough; on top of it we need to define a batch access method.
In short, the sampler defines the rule by which indices are generated, and generates indices according to that rule, thereby controlling how the data are read.
BatchSampler is built on top of Sampler: BatchSampler = Sampler + batch_size.
The PyTorch source code is as follows:
class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """
    # Sampler base class; custom samplers can be built on top of it
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


# Sequential sampling
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


# Random sampling
class RandomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)


# Random subset sampling
class SubsetRandomSampler(Sampler):
    pass


# Weighted random sampling
class WeightedRandomSampler(Sampler):
    pass
class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
As can be seen above, a Sampler is essentially an iterable object with a specific rule, but it can only yield one index at a time.
For example, in [x for x in range(10)], range(10) is the most basic Sampler: each loop iteration takes out only one value.
[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]
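Any index-generation rule can be plugged in by subclassing Sampler. Below is a minimal sketch of a hypothetical reverse-order sampler; ReverseSampler is illustrative and not part of PyTorch.

from torch.utils.data.sampler import Sampler

class ReverseSampler(Sampler):
    """Hypothetical custom sampler: yields indices in reverse order."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

# [x for x in ReverseSampler(range(10))] -> [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]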
BatchSampler wraps a Sampler a second time and introduces the batch_size parameter to enable batch iteration.
from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]
DataLoader
In practical computation, when the amount of data is large, memory is limited and IO is slow,
so the data can neither be loaded into memory all at once nor be loaded by a single worker.
Multi-process, iterative loading is therefore required, and a dedicated loader is defined for it: DataLoader.
DataLoader is an iterable object, An Iterable Object; internally it implements the magic function __iter__, and calling it returns an iterator.
This function can be invoked through the built-in function iter(), i.e. DataIterator = iter(dataLoader).

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_workers=..., ...)

The __init__ parameters consist of two parts: the first part specifies the dataset + sampler, and the second part consists of the multi-process loading parameters.
class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        # Detect parameter conflicts: default BatchSampler vs. custom BatchSampler
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # A BatchSampler is always set up here
        if batch_sampler is None:
            # A Sampler is always set up here
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        # Use the (possibly custom) sampler and batch sampler
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        # Load data by returning PyTorch's multi-process iterator
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
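To see how these pieces cooperate in practice, here is a minimal usage sketch; TensorDataset merely stands in for a custom Dataset, and the sizes and batch_size are arbitrary choices.

import torch
from torch.utils.data import DataLoader, TensorDataset

# a toy dataset: 10 samples with one feature each, plus integer labels
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# shuffle=True makes DataLoader build a RandomSampler and a BatchSampler
# internally, exactly as in the __init__ logic above
loader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

for data, target in loader:
    print(data.shape, target)    # batches of up to 3 samples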
DataLoaderIter
There is a difference between an iterator and an iterable object.
An iterable object is one on which the built-in function iter() can be used to return an iterator, so that the object can then be accessed iteratively, element by element.
An iterator object additionally implements the magic function __next__; applying the built-in function next() to it keeps producing the next piece of data, and the generation rule is determined by that function.
An iterable object states that the object can be iterated over, while the concrete iteration rule is described by the iterator. The advantage of this decoupling is that different iterators, with different rules, can be attached to the same iterable object.
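A minimal sketch of this distinction, using a toy dataset that is purely illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=3)    # iterable object: implements __iter__

data_iter = iter(loader)    # iter() calls loader.__iter__() and returns an iterator (implements __next__)
first = next(data_iter)     # next() calls data_iter.__next__() to produce one batch
second = next(data_iter)    # ... and the next one

# a for-loop does the same thing implicitly: iter() once, then next() until StopIteration
for batch in loader:
    pass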
The general process for traversing a dataset/container: NILIS.
NILIS rule: data = next(iter(dataSet[sampler]))
- The sampler defines the rule by which indices are generated; it returns an index list and controls the subsequent index-based access.
- The indexer defines index-based access on the container through __getitem__, so that the container becomes an indexable object on which the [] operation can be used.
- The loader defines iterability on the container through __iter__ and describes the loading rule; it returns an iterator, so that the container becomes an iterable object on which iter() can be used.
- next defines the iteration on the container through __next__ and describes the concrete iteration rule, so that the container becomes an iterator object on which next() can be used.
## Initialization
sampler = Sampler()
dataSet = DataSet(sampler)                   # __getitem__
dataLoader = DataLoader(dataSet, sampler)    # __iter__(), an iterable object
dataIterator = DataLoaderIter(dataLoader)    # __next__()
data_iter = iter(dataLoader)

## Traversal method 1
for _ in range(len(data_iter)):
    data = next(data_iter)

## Traversal method 2
for i, data in enumerate(dataLoader):
    data = data
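Translated into real PyTorch calls, the same NILIS flow looks roughly like the sketch below; the toy dataset and the explicit RandomSampler are illustrative assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import RandomSampler

dataset = TensorDataset(torch.arange(10).float())                 # indexer: __getitem__
sampler = RandomSampler(dataset)                                  # sampler: index generation rule
dataLoader = DataLoader(dataset, sampler=sampler, batch_size=3)   # loader: __iter__
data_iter = iter(dataLoader)                                      # iterator: __next__

## Traversal method 1
for _ in range(len(dataLoader)):
    data = next(data_iter)

## Traversal method 2
for i, data in enumerate(dataLoader):
    pass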