datadings.torch package
- class datadings.torch.Compose(*transforms, prefix='')
Bases: object
Compose a sequence of transform functions. Functions must accept the intended value from samples as first argument. They may have an arbitrary number of positional and keyword arguments.
Example usage with Dataset:

    import random
    from datadings.torch import Compose
    from datadings.torch import Dataset
    from datadings.reader import ListReader

    def add(v, number):
        return v + number

    def sub(x, value):
        return x - value

    def rng(_):
        return {
            'number': random.randrange(1, 10),
            'value': random.randrange(1, 10),
        }

    samples = [{'a': 0, 'b': 0, 'c': 0} for _ in range(10)]
    reader = ListReader(samples)
    transforms = {
        'a': Compose(add),
        'b': Compose(sub),
        'c': Compose((add, sub)),
    }
    dataset = Dataset(reader, transforms=transforms, rng=rng)
    for i in range(len(dataset)):
        print(dataset[i])

- Parameters:
transforms – sequence of transform functions, either one iterable or varargs
prefix – string prefix for parameter names, i.e., if a function normally requires parameter size, given prefix='mask_' the parameter 'mask_size' is used instead
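The prefix lookup described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the Compose implementation: the helper call_with_prefixed_params and the stand-in transform resize are hypothetical names, and matching parameters by inspecting the function signature is an assumption.

```python
import inspect


def call_with_prefixed_params(func, value, params, prefix=''):
    """Call func(value, ...) and fill the remaining arguments from params.

    Hypothetical helper for illustration only; not part of datadings.
    """
    # Skip the first argument, which receives the value taken from the sample.
    names = list(inspect.signature(func).parameters)[1:]
    # Look each argument up under its prefixed name, e.g. 'size' -> 'mask_size'.
    kwargs = {name: params[prefix + name] for name in names}
    return func(value, **kwargs)


def resize(value, size):
    # Stand-in transform that only reports the size it was given.
    return f'{value} resized to {size}'


params = {'size': 224, 'mask_size': 56}
plain = call_with_prefixed_params(resize, 'image', params)
prefixed = call_with_prefixed_params(resize, 'mask', params, prefix='mask_')
print(plain)
print(prefixed)
```

With an empty prefix the function receives params['size']; with prefix='mask_' the same function receives params['mask_size'], so one transform function can be reused for several sample keys with different parameters.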
- class datadings.torch.CompressedToPIL
Bases: object
A torchvision-compatible transform that takes a compressed image as bytes (or similar) and returns a PIL image.
- class datadings.torch.Dataset(*args: Any, **kwargs: Any)
Bases: DatasetBase, Dataset
Implementation of torch.utils.data.Dataset.
Warning: Dataset can be significantly slower than IterableDataset. If shuffling is necessary, consider using QuasiShuffler instead.
Example usage with the PyTorch DataLoader:

    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor
    from tqdm import tqdm
    from datadings.reader import MsgpackReader
    from datadings.torch import Compose, CompressedToPIL, Dataset, dict2tuple

    path = '.../train.msgpack'
    batch_size = 256
    reader = MsgpackReader(path)
    transforms = {'image': Compose(
        CompressedToPIL(),
        ...,
        ToTensor(),
    )}
    ds = Dataset(reader, transforms=transforms)
    train = DataLoader(dataset=ds, batch_size=batch_size)
    for epoch in range(3):
        for x, y in dict2tuple(tqdm(train)):
            pass

- Parameters:
reader – the datadings reader instance
transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or callable with signature f(sample: dict) -> dict that is applied directly to samples. In the dict form keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose) with params the value returned by the rng callable.
rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms
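The two accepted forms of transforms can be illustrated with a small stand-alone sketch. It mirrors only the contract stated in the parameter descriptions and is not the datadings implementation: apply_transforms is a hypothetical helper, and passing the sample to rng is an assumption.

```python
def apply_transforms(sample, transforms=None, rng=None):
    """Hypothetical helper mirroring the documented contract; not datadings code."""
    if transforms is None:
        return sample
    if callable(transforms):
        # Callable form: f(sample: dict) -> dict, applied directly to the sample.
        return transforms(sample)
    # Dict form: draw parameters once, then transform each listed key.
    # Assumption: rng is called with the sample as its argument.
    params = rng(sample) if rng is not None else {}
    out = dict(sample)
    for key, transform in transforms.items():
        # Each value is a callable t(value, params) -> new value.
        out[key] = transform(out[key], params)
    return out


sample = {'image': 3, 'label': 1}

# Dict form: only 'image' is transformed, using the parameters from rng.
dict_result = apply_transforms(
    sample,
    transforms={'image': lambda value, params: value * params['scale']},
    rng=lambda _: {'scale': 2},
)

# Callable form: the function receives and returns the whole sample.
callable_result = apply_transforms(
    sample,
    transforms=lambda s: {**s, 'label': s['label'] + 1},
)
print(dict_result)
print(callable_result)
```

Because every transform in the dict form sees the same params dict, transforms for different keys (for example an image and its mask) can share random parameters drawn once per sample.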
- class datadings.torch.IterableDataset(*args: Any, **kwargs: Any)
Bases: DatasetBase, IterableDataset
Implementation of torch.utils.data.IterableDataset to use with datadings readers.
With distributed training the reader is divided into world_size * num_workers shards. Each dataloader worker of each rank iterates over a different shard. The final batch delivered by a worker may be smaller than the batch size if the length of the reader is not divisible by batch_size * num_shards.
Note: Set persistent_workers=True for the DataLoader to let the dataset object track the current epoch. This makes ranks cycle through shards of the dataset. Without this option torch may create new worker processes at any time, which resets the dataset to its initial state.
Warning: Raises RuntimeError if 0 < len(shard) % batch_size < 1, since this may lead to an uneven number of batches generated by each worker. This can lead to crashes if it happens between rank workers, or deadlock if ranks receive a different number of batches. Change num_workers, batch_size, or world_size to avoid this.
Example usage with the PyTorch DataLoader:

    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor
    from tqdm import tqdm
    from datadings.reader import MsgpackReader
    from datadings.torch import Compose, CompressedToPIL, IterableDataset, dict2tuple

    path = '.../train.msgpack'
    batch_size = 256
    reader = MsgpackReader(path)
    transforms = {'image': Compose(
        CompressedToPIL(),
        ...,
        ToTensor(),
    )}
    ds = IterableDataset(
        reader,
        transforms=transforms,
        batch_size=batch_size,
    )
    train = DataLoader(
        dataset=ds,
        batch_size=batch_size,
        num_workers=4,
        persistent_workers=True,
    )
    for epoch in range(3):
        print('Epoch', epoch)
        for x, y in dict2tuple(tqdm(train)):
            pass

- Parameters:
reader – the datadings reader instance
transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or callable with signature f(sample: dict) -> dict that is applied directly to samples. In the dict form keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose) with params the value returned by the rng callable.
rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms
batch_size – same batch size as given to the DataLoader
epoch – starting epoch, zero indexed; only relevant when resuming
copy – see datadings.reader.reader.Reader.iter()
chunk_size – see datadings.reader.reader.Reader.iter()
group – distributed process group to use (if not using the default)
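The shard arithmetic described for IterableDataset can be checked ahead of time with a few lines of plain Python. This sketch only reproduces the two statements from the description (number of shards, and divisibility by batch_size * num_shards). The helper shard_plan is a hypothetical name, and the sketch does not model how datadings assigns samples to shards or when it raises RuntimeError.

```python
def shard_plan(reader_len, world_size, num_workers, batch_size):
    """Hypothetical pre-flight check; not part of datadings."""
    # One shard per dataloader worker of each rank.
    num_shards = world_size * num_workers
    # Samples consumed when every worker delivers one full batch.
    samples_per_round = batch_size * num_shards
    remainder = reader_len % samples_per_round
    return {
        'num_shards': num_shards,
        'samples_per_round': samples_per_round,
        'remainder': remainder,
        # A non-zero remainder means some final batch may be smaller.
        'final_batch_may_be_short': remainder != 0,
    }


# Example values chosen for illustration: 2 ranks with 4 workers each.
plan = shard_plan(reader_len=10000, world_size=2, num_workers=4, batch_size=256)
even_plan = shard_plan(reader_len=8192, world_size=2, num_workers=4, batch_size=256)
print(plan)
print(even_plan)
```

If final_batch_may_be_short is true for a configuration, adjusting num_workers, batch_size, or world_size as suggested in the warning is the documented way to get even batches.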