datadings.torch package
- class datadings.torch.Compose(*transforms, prefix='')[source]
Bases: object
Compose a sequence of transform functions. Functions must accept the intended value from samples as their first argument. They may have an arbitrary number of additional positional and keyword arguments.
Example usage with Dataset:

import random

from datadings.torch import Compose
from datadings.torch import Dataset
from datadings.reader import ListReader

def add(v, number):
    return v + number

def sub(x, value):
    return x - value

def rng(_):
    return {
        'number': random.randrange(1, 10),
        'value': random.randrange(1, 10),
    }

samples = [{'a': 0, 'b': 0, 'c': 0} for _ in range(10)]
reader = ListReader(samples)
transforms = {
    'a': Compose(add),
    'b': Compose(sub),
    'c': Compose((add, sub)),
}
dataset = Dataset(reader, transforms=transforms, rng=rng)
for i in range(len(dataset)):
    print(dataset[i])
- Parameters:
transforms – sequence of transform functions, either one iterable or varargs
prefix – string prefix for parameter names, i.e., if a function normally requires a parameter size, then given prefix='mask_' the parameter 'mask_size' is used instead
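The prefix lookup described above can be illustrated with a small stand-alone sketch. Note that apply_with_prefix is a hypothetical re-implementation for illustration only, not the actual Compose internals:

```python
import inspect

def apply_with_prefix(func, value, params, prefix=''):
    # Hypothetical sketch of prefix-based parameter lookup: each keyword
    # parameter of func is looked up in params under its prefixed name.
    sig = inspect.signature(func)
    kwargs = {
        name: params[prefix + name]
        for name in list(sig.parameters)[1:]  # skip the value argument
        if prefix + name in params
    }
    return func(value, **kwargs)

def resize(img, size):
    return f'resized to {size}'

params = {'mask_size': 32, 'size': 64}
print(apply_with_prefix(resize, None, params, prefix='mask_'))  # resized to 32
```

With prefix='mask_' the function's size parameter is filled from params['mask_size'], so two transforms with the same parameter names can draw different values from one params dict.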
- class datadings.torch.CompressedToPIL[source]
Bases: object
Compatible torchvision transform that takes a compressed image as bytes (or similar) and returns a PIL image.
- class datadings.torch.Dataset(*args: Any, **kwargs: Any)[source]
Bases: DatasetBase, Dataset
Implementation of torch.utils.data.Dataset.

Warning
Dataset can be significantly slower than IterableDataset. If shuffling is necessary, consider using QuasiShuffler instead.

Example usage with the PyTorch DataLoader:

path = '.../train.msgpack'
batch_size = 256
reader = MsgpackReader(path)
transforms = {'image': Compose(
    CompressedToPIL(),
    ...,
    ToTensor(),
)}
ds = Dataset(reader, transforms=transforms)
train = DataLoader(dataset=ds, batch_size=batch_size)
for epoch in range(3):
    for x, y in dict2tuple(tqdm(train)):
        pass
- Parameters:
reader – the datadings reader instance
transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or a callable with signature f(sample: dict) -> dict that is applied directly to samples. In the dict form, keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose) with params the value returned by the rng callable.
rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms
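Besides the dict form shown in the examples, transforms also accepts a single callable applied to the whole sample dict. A minimal sketch of such a callable (normalize_sample is a hypothetical example, not part of datadings):

```python
def normalize_sample(sample):
    # Hypothetical whole-sample transform with signature
    # f(sample: dict) -> dict: scales numeric values into [0, 1],
    # assuming 8-bit inputs, and passes everything else through.
    return {
        k: v / 255 if isinstance(v, (int, float)) else v
        for k, v in sample.items()
    }

print(normalize_sample({'image': 255, 'label': 'cat'}))
```

Such a callable would be passed as transforms=normalize_sample and receives each sample dict directly, without per-key dispatch.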
- class datadings.torch.IterableDataset(*args: Any, **kwargs: Any)[source]
Bases: DatasetBase, IterableDataset
Implementation of torch.utils.data.IterableDataset to use with datadings readers.

With distributed training the reader is divided into world_size * num_workers shards. Each dataloader worker of each rank iterates over a different shard. The final batch delivered by a worker may be smaller than the batch size if the length of the reader is not divisible by batch_size * num_shards.

Note
Set persistent_workers=True for the DataLoader to let the dataset object track the current epoch. Ranks then cycle through the shards of the dataset across epochs. Without this option torch may create new worker processes at any time, which resets the dataset to its initial state.

Warning
Raises RuntimeError if 0 < len(shard) % batch_size < 1, since this may lead to an uneven number of batches generated by each worker. This can lead to crashes if it happens between rank workers, or deadlock if ranks receive a different number of batches. Change num_workers, batch_size, or world_size to avoid this.

Example usage with the PyTorch DataLoader:

path = '.../train.msgpack'
batch_size = 256
reader = MsgpackReader(path)
transforms = {'image': Compose(
    CompressedToPIL(),
    ...,
    ToTensor(),
)}
ds = IterableDataset(
    reader,
    transforms=transforms,
    batch_size=batch_size,
)
train = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    num_workers=4,
    persistent_workers=True,
)
for epoch in range(3):
    print('Epoch', epoch)
    for x, y in dict2tuple(tqdm(train)):
        pass
- Parameters:
reader – the datadings reader instance
transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or a callable with signature f(sample: dict) -> dict that is applied directly to samples. In the dict form, keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose) with params the value returned by the rng callable.
rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms
batch_size – same batch size as given to the DataLoader
epoch – starting epoch, zero indexed; only relevant when resuming
copy – see datadings.reader.reader.Reader.iter()
chunk_size – see datadings.reader.reader.Reader.iter()
group – distributed process group to use (if not using the default)
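The fractional-shard condition from the warning above can be checked before training. The helper below is a hypothetical sketch of the documented 0 < len(shard) % batch_size < 1 check, not the library's own validation code:

```python
def check_sharding(reader_len, batch_size, world_size, num_workers):
    # Hypothetical pre-flight check mirroring the documented constraint:
    # the reader is split into world_size * num_workers shards, and a
    # fractional per-shard length that leaves a remainder in (0, 1)
    # relative to batch_size would produce uneven batch counts.
    num_shards = world_size * num_workers
    shard_len = reader_len / num_shards
    remainder = shard_len % batch_size
    if 0 < remainder < 1:
        raise RuntimeError(
            f'shard length {shard_len} leaves fractional remainder '
            f'{remainder:.3f}; change num_workers, batch_size, or world_size'
        )
    return shard_len

print(check_sharding(100_000, 256, 4, 4))  # 6250.0 per shard -> OK
```

For example, 100,000 samples over 4 ranks with 4 workers each gives 16 shards of 6250 samples, which is safe, while 4104 samples with the same setup yields shards of length 256.5 and would trigger the error.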