Source code for datadings.sets.ILSVRC2012_write

"""Create ILSVRC 2012 challenge data set files.

This tool will look for the following files in the input directory:

- ILSVRC2012_img_train.tar
- ILSVRC2012_img_val.tar

See also:
    http://image-net.org/challenges/LSVRC/2012/index

Note:
    Registration is required to download this dataset.
    Please visit the website to download it.

Important:
    For performance reasons, samples are read in the same order as they are
    stored in the source tar files. It is recommended to use the
    datadings-shuffle command to create a shuffled copy.
"""
import os.path as pt
import tarfile
import io
import gzip
from multiprocessing.dummy import Pool as ThreadPool

import numpy as np
from PIL import Image
from PIL import UnidentifiedImageError
from simplejpeg import decode_jpeg
from simplejpeg import decode_jpeg_header
from simplejpeg import encode_jpeg

from ..writer import FileWriter
from .tools import open_comp
from ..tools import yield_threaded
from . import ImageClassificationData
from .ILSVRC2012_synsets import SYNSETS
from ..tools import document_keys


# Extend the module docstring with whatever document_keys generates for
# ImageClassificationData (presumably a description of the sample keys).
__doc__ += document_keys(ImageClassificationData)


# Source archives per split.
# 'url' is None because the archives have to be downloaded manually after
# registration (see the module docstring); 'path' is the file name expected
# in the input directory; 'md5' is presumably the expected checksum used by
# prepare_indir -- TODO confirm.
FILES = {
    'train': {
        'url': None,
        'path': 'ILSVRC2012_img_train.tar',
        'md5': '1d675b47d978889d74fa0da5fadfb00e',
    },
    'val': {
        'url': None,
        'path': 'ILSVRC2012_img_val.tar',
        'md5': '29b22e2961454d5413ddabcf34fc5622',
    },
}
# Absolute path of the directory that contains this module.
SET_ROOT = pt.abspath(pt.dirname(__file__))
# 4 MiB; passed as ``bufsize`` to ``tarfile.open`` when reading the archives.
READ_SIZE = 4 * 1024 * 1024


def yield_train(tar):
    """Yield ``(name, data, label)`` for every image of the training archive.

    Every member of ``tar`` is itself opened as a tar archive. The label
    of all images in such a nested archive is looked up in ``SYNSETS``
    using the member name without its extension.

    Parameters:
        tar: Open ``tarfile.TarFile`` of the training set.

    Yields:
        Tuple of image member name, raw image bytes and label.

    Raises:
        KeyError: If a member name (without extension) is not in ``SYNSETS``.
    """
    for member in tar:
        synset_id, _ = pt.splitext(member.name)
        label = SYNSETS[synset_id]
        nested = tar.extractfile(member)
        with tarfile.open(fileobj=nested, bufsize=READ_SIZE) as synset_tar:
            for entry in synset_tar:
                payload = synset_tar.extractfile(entry).read()
                yield entry.name, payload, label
def yield_val(tar):
    """Yield ``(name, data, label)`` for every image of the validation archive.

    Labels are read from ``ILSVRC2012_val.txt.xz`` (opened with
    ``open_comp``), where each line holds an image name and its label
    separated by a single space.

    Parameters:
        tar: Open ``tarfile.TarFile`` of the validation set.

    Yields:
        Tuple of image member name, raw image bytes and integer label.
    """
    labels = {}
    with open_comp('ILSVRC2012_val.txt.xz', 'rt', encoding='utf8') as f:
        for line in f:
            # split only once so everything after the first space is the label
            name, label = line.strip('\n').split(' ', 1)
            labels[name] = label
    for member in tar:
        payload = tar.extractfile(member).read()
        yield member.name, payload, int(labels[member.name])
def yield_samples(split, tar):
    """Return the sample generator that matches ``split``.

    Parameters:
        split: Name of the split, either ``'train'`` or ``'val'``.
        tar: Open ``tarfile.TarFile`` of the requested split.

    Returns:
        Generator of ``(name, data, label)`` tuples.

    Raises:
        ValueError: If ``split`` is ``'test'`` (not supported) or any
            other unknown value.
    """
    if split == 'train':
        return yield_train(tar)
    if split == 'val':
        return yield_val(tar)
    if split == 'test':
        raise ValueError('test set not supported')
    # fail loudly instead of silently returning None for a misspelled split
    raise ValueError(f'unknown split {split!r}')
def verify_image(
        data,
        quality=None,
        short_side=375,
        long_side=500,
        colorsubsampling='422',
):
    """Check that ``data`` is a decodable image and re-encode it if required.

    Decoding is attempted with simplejpeg first. If that raises
    ``ValueError`` Pillow is used instead. If Pillow cannot identify the
    data and it starts with the gzip magic number, it is decompressed and
    verified recursively.

    The image is re-encoded as JPEG if it is a CMYK JPEG, if simplejpeg
    could not decode it, or if ``quality`` is given and the decoded array
    has more than half of ``3 * short_side * long_side`` elements.
    Arrays with more than 1.5 times that many elements are downscaled so
    their longer side equals ``long_side``. In all other cases ``data``
    is returned unchanged.

    Parameters:
        data: Encoded image as bytes.
        quality: JPEG quality used for re-encoding, or ``None`` to only
            re-encode when unavoidable (quality 98 is used then).
        short_side: Target length of the shorter image side.
        long_side: Target length of the longer image side.
        colorsubsampling: Color subsampling used for re-encoding;
            replaced by ``'444'`` for small images and for images that
            are only re-encoded because they are CMYK.

    Returns:
        Encoded image bytes, or ``None`` if the data could not be decoded.
    """
    target_size = 3 * short_side * long_side
    # try to decode data using simplejpeg
    try:
        h, w, colorspace, _ = decode_jpeg_header(data)
        # decode images to match at least target size
        im = decode_jpeg(
            data,
            min_width=short_side if w < h else long_side,
            min_height=short_side if h < w else long_side,
        )
        # encode if quality is given
        # and image is big enough to not suffer from re-encoding
        compress = quality is not None and im.size > 0.5*target_size
    # simplejpeg could not decode image, fall back to Pillow
    # could be faulty JPEG or other image format, e.g. PNG
    except ValueError as e1:
        try:
            bio = io.BytesIO(data)
            im = np.array(Image.open(bio).convert('RGB'))
            colorspace = 'RGB'  # converted to RGB guaranteed
            compress = True  # force compression since simplejpeg failed
        except UnidentifiedImageError as e2:
            # For ImageNet21k:
            # check if images are actually gzip file and try to decompress
            try:
                if data[:2] == b'\x1f\x8b':
                    bio.seek(0)
                    with gzip.open(bio) as zf:
                        return verify_image(
                            zf.read(),
                            quality=quality,
                            short_side=short_side,
                            long_side=long_side,
                            colorsubsampling=colorsubsampling,
                        )
            except IOError:
                # not a valid gzip stream either, reported below
                pass
            # Reached for data that is not gzip as well as for broken gzip
            # data, so undecodable input always yields None instead of
            # running into unbound local variables further down.
            print(f'image could not be decoded, first 8 bytes: {data[:8].hex()}')
            print(e1)
            print(e2)
            return None

    # images that do not need re-encoding are passed through untouched
    if colorspace != 'CMYK' and not compress:
        return data

    # for CMYK or non-JPEG images,
    # quality might not be given, so assume 98
    if quality is None:
        quality = 98
    # Use full color resolution for small images or if compression was not
    # requested, i.e. the image is only re-encoded because it is CMYK.
    if not compress or im.size <= 0.5*target_size:
        colorsubsampling = '444'
    # downscale large images so the longer side matches long_side
    if im.size > target_size*1.5:
        h, w = im.shape[:2]
        s = max(h, w)
        r = long_side/s
        h, w = int(round(r*h)), int(round(r*w))
        pil = Image.fromarray(im, 'RGB')
        im = np.array(pil.resize((w, h), resample=Image.LANCZOS))
    return encode_jpeg(im, quality=quality, colorsubsampling=colorsubsampling)
# Expected number of samples per split; write_set passes TOTAL[split] as
# ``total`` to FileWriter. The 'test' entry is not used by the code in this
# module, since yield_samples rejects the 'test' split.
TOTAL = {'train': 1281167, 'val': 50000, 'test': 100000}
def write_set(split, outdir, gen, args):
    """Verify all samples of one split and write them to a dataset file.

    Samples from ``gen`` are passed through ``verify_image`` by a pool of
    threads and written to ``<outdir>/<split>.msgpack`` in the order in
    which the workers finish.

    Parameters:
        split: Name of the split; used for the output file name and to
            look up the expected number of samples in ``TOTAL``.
        outdir: Directory the output file is written to.
        gen: Iterable of ``(key, data, label)`` tuples.
        args: Parsed command line arguments; ``compress``, ``subsampling``,
            ``threads`` and ``no_confirm`` are read.
    """
    outfile = pt.join(outdir, split + '.msgpack')

    def _verify_item(item):
        key, data, label = item
        # NOTE(review): verify_image returns None for undecodable data;
        # such samples are still written with None as image.
        data = verify_image(data, args.compress, colorsubsampling=args.subsampling)
        return ImageClassificationData(key, data, label)

    # The context manager shuts the worker threads down once the split is
    # written (or an error occurs) instead of leaking them on every call.
    with ThreadPool(args.threads) as pool:
        with FileWriter(outfile, total=TOTAL[split], overwrite=args.no_confirm) as writer:
            for sample in pool.imap_unordered(_verify_item, gen):
                writer.write(sample)
def write_sets(files, outdir, args):
    """Write the dataset files for the ``'train'`` and ``'val'`` splits.

    Parameters:
        files: Mapping of split name to a dict whose ``'path'`` entry is
            the location of the source tar archive.
        outdir: Directory the output files are written to.
        args: Parsed command line arguments, forwarded to ``write_set``.
    """
    for split in ('train', 'val'):
        archive_path = files[split]['path']
        with tarfile.open(archive_path, bufsize=READ_SIZE) as tar:
            samples = yield_samples(split, tar)
            threaded = yield_threaded(samples)
            write_set(split, outdir, threaded, args)
def argument_compress(parser):
    """Add the optional ``--compress`` argument to ``parser``.

    Without the flag the value is ``None``; the bare flag selects quality
    85; an explicit integer between 0 and 100 selects that quality.

    Parameters:
        parser: ``argparse`` parser to extend.
    """
    help_text = (
        'Use JPEG compression with optional quality. '
        'Default quality is 85. '
        'Big images are resized to roughly fit 500x375. '
    )
    options = {
        'nargs': '?',
        'default': None,
        'const': 85,
        'choices': range(101),
        'metavar': 'quality 0-100',
        'type': int,
        'help': help_text,
    }
    parser.add_argument('--compress', **options)
def argument_subsampling(parser):
    """Add the ``--subsampling`` argument to ``parser``.

    The value is one of ``'444'``, ``'422'``, ``'420'``, ``'440'``,
    ``'411'`` or ``'Gray'`` and defaults to ``'422'``.

    Parameters:
        parser: ``argparse`` parser to extend.
    """
    allowed = ('444', '422', '420', '440', '411', 'Gray')
    help_text = (
        'Color subsampling factor used with compress option. '
        '444 is forced for small images to preserve details.'
    )
    parser.add_argument(
        '--subsampling',
        type=str,
        choices=allowed,
        default='422',
        help=help_text,
    )
def main():
    """Command line entry point: parse arguments and write both splits.

    The output directory defaults to the input directory if none is given.
    """
    from ..tools.argparse import make_parser, argument_threads
    from ..tools import prepare_indir

    parser = make_parser(__doc__, shuffle=False)
    argument_threads(parser, default=1)
    # add the options that are specific to this dataset
    for add_option in (argument_compress, argument_subsampling):
        add_option(parser)
    args = parser.parse_args()

    outdir = args.outdir if args.outdir else args.indir
    files = prepare_indir(FILES, args)
    write_sets(files, outdir, args)
# Script entry point.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # exit quietly when interrupted by the user
        pass
    finally:
        # always emit a final newline on stdout, whether main() finished,
        # was interrupted or failed
        print()