Source code for datadings.sets.CIFAR100_write

"""Create CIFAR 100 data set files.

This tool will look for the following files in the input directory
and download them if necessary:

- cifar-100-python.tar.gz

See also:
    https://www.cs.toronto.edu/~kriz/cifar.html
"""
import tarfile
import os.path as pt
import random
from pickle import load
from collections import defaultdict

from ..writer import FileWriter
from . import ImageCoarseClassificationData
from .CIFAR10_write import get_files
from .CIFAR10_write import row2image
from ..tools import document_keys


__doc__ += document_keys(ImageCoarseClassificationData)


BASE_URL = 'https://www.cs.toronto.edu/~kriz/'
FILES = {
    'all': {
        'path': 'cifar-100-python.tar.gz',
        'url': BASE_URL+'cifar-100-python.tar.gz',
        'md5': 'eb9058c3a382ffc7106e4002c42a8d85',
    }
}


def yield_rows(files):
    """Yield one sample tuple per row of each pickled batch file.

    Every object in ``files`` is unpickled with ``encoding='bytes'`` and
    must provide the keys ``b'data'``, ``b'fine_labels'``,
    ``b'coarse_labels'`` and ``b'filenames'``. The four sequences are
    iterated in lockstep.

    File names are decoded as UTF-8. A name that was already seen
    (counted across all given files) gets its occurrence number
    appended, e.g. the second ``x.png`` is yielded as ``x.png2``.

    Parameters:
        files: Iterable of binary file objects to unpickle.

    Yields:
        Tuples ``(image, label, coarse_label, filename)`` where
        ``image`` is the result of ``row2image(row)``.
    """
    # NOTE: unpickling can execute arbitrary code; only pass files that
    # come from a trusted archive.
    occurrences = defaultdict(int)
    for batch_file in files:
        batch = load(batch_file, encoding='bytes')
        samples = zip(
            batch[b'data'],
            batch[b'fine_labels'],
            batch[b'coarse_labels'],
            batch[b'filenames'],
        )
        for row, label, coarse_label, raw_name in samples:
            name = raw_name.decode('utf-8')
            occurrences[name] += 1
            count = occurrences[name]
            # apparently some files occur multiple times...
            filename = name if count == 1 else name + str(count)
            yield row2image(row), label, coarse_label, filename
def write_set(tar, outdir, split, args):
    """Write the samples of one split to ``<outdir>/<split>.msgpack``.

    Parameters:
        tar: Opened archive, handed to ``get_files``.
        outdir: Directory that receives the output file.
        split: Name of the split; passed to ``get_files`` together with
            the ``'cifar-100-python'`` prefix and used as the stem of
            the output file name.
        args: Parsed command line arguments; the attributes ``shuffle``
            and ``no_confirm`` are read.
    """
    batch_files = get_files(tar, 'cifar-100-python', [split])
    samples = yield_rows(batch_files)
    if args.shuffle:
        # Shuffling needs all samples in memory at once.
        samples = list(samples)
        random.shuffle(samples)
    target = pt.join(outdir, split + '.msgpack')
    # NOTE(review): total is the number of input files, not the number
    # of samples written below - confirm this is what FileWriter expects.
    with FileWriter(
        target,
        total=len(batch_files),
        overwrite=args.no_confirm,
    ) as writer:
        for image, label, coarse_label, filename in samples:
            record = ImageCoarseClassificationData(
                filename, image, int(label), int(coarse_label),
            )
            writer.write(record)
def write_sets(files, outdir, args):
    """Write the ``train`` and ``test`` splits from the archive.

    Parameters:
        files: Mapping whose ``['all']['path']`` entry is the path of
            the gzip-compressed tar archive to read.
        outdir: Output directory, passed on to ``write_set``.
        args: Parsed command line arguments, passed on to ``write_set``.

    A split for which ``write_set`` raises ``FileExistsError`` is
    skipped and the next split is attempted.
    """
    archive_path = files['all']['path']
    with tarfile.open(archive_path, 'r:gz') as tar:
        for split in ('train', 'test'):
            try:
                write_set(tar, outdir, split, args)
            except FileExistsError:
                # Presumably the output file already exists and was not
                # overwritten; carry on with the remaining split.
                continue
def main():
    """Command line entry point.

    Parses the arguments (the module docstring is given to
    ``make_parser``), resolves the input files through
    ``prepare_indir`` and writes both splits with ``write_sets``.
    """
    from ..tools.argparse import make_parser
    from ..tools import prepare_indir

    args = make_parser(__doc__).parse_args()
    # Fall back to the input directory when no output directory is set.
    outdir = args.outdir or args.indir
    files = prepare_indir(FILES, args)
    write_sets(files, outdir, args)
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C ends the run without printing a traceback.
        pass
    finally:
        # Always emit a final newline, whether main() finished, failed
        # or was interrupted.
        print()