Source code for datadings.sets.CAT2000_write

"""Create CAT2000 data set files.

This tool will look for the following files in the input directory
and download them if necessary:

- trainSet.zip
- testSet.zip

See also:
    http://saliency.mit.edu/results_cat2000.html
"""
import os
import os.path as pt
import zipfile
import random
from multiprocessing.dummy import Pool as ThreadPool

import numpy as np
from PIL import Image
from PIL import ImageChops
from simplejpeg import decode_jpeg
from simplejpeg import encode_jpeg

from ..writer import FileWriter
from ..tools.matlab import loadmat
from ..tools import yield_threaded
from . import SaliencyData
from . import SaliencyExperiment
from ..tools import document_keys


# Append the text produced by document_keys() to the module docstring.
# The inner call (for SaliencyExperiment) is handed to the outer call
# (for SaliencyData) as its ``postfix``. main() later passes the
# resulting __doc__ to make_parser().
__doc__ += document_keys(
    SaliencyData,
    postfix=document_keys(
        SaliencyExperiment,
        block='',
        prefix='Each experiment has the following keys:'
    )
)


# Root that both archive URLs are built from.
BASE_URL = 'http://saliency.mit.edu/'

# One entry per split: the archive file name ('path'), the location it
# is fetched from ('url') and the 'md5' string recorded for it.
FILES = {
    split: {'path': archive, 'url': BASE_URL + archive, 'md5': md5}
    for split, archive, md5 in (
        ('train', 'trainSet.zip', '56ad5c77e6c8f72ed9ef2901628d6e48'),
        ('test', 'testSet.zip', '903ec668df2e5a8470aef9d8654e7985'),
    )
}


def __find_bbox(im):
    """Return the bounding box of the region that differs from the corner.

    The pixel at (0, 0) is taken as the reference color. The result is
    whatever ``getbbox()`` reports for the thresholded difference
    between ``im`` and an image filled with that color.
    """
    corner_color = im.getpixel((0, 0))
    background = Image.new(im.mode, im.size, corner_color)
    delta = ImageChops.difference(im, background)
    # Pillow documents add() as (a + b) / scale + offset, so this doubles
    # the difference and subtracts 20 before taking the bounding box.
    delta = ImageChops.add(delta, delta, 1.0, -20)
    return delta.getbbox()


def __transform_image(im, bbox, size=1024):
    """Crop ``im`` to ``bbox`` and scale its longer side to ``size``.

    Parameters:
        im: PIL image to crop and resize.
        bbox: Box passed unchanged to ``im.crop``.
        size: Target length in pixels of the longer side of the crop.

    Returns:
        Tuple ``(r, resized)`` where ``r = size / max(width, height)``
        of the cropped image and ``resized`` is the crop scaled by ``r``
        (both sides rounded to the nearest integer).
    """
    cropped = im.crop(bbox)
    w, h = cropped.size
    r = size / max(w, h)
    new_size = (int(round(w * r)), int(round(h * r)))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
    # filter under its current name.
    return r, cropped.resize(new_size, Image.LANCZOS)


def __decode(data):
    """Decode ``data`` with ``decode_jpeg`` and wrap it as an RGB PIL image.

    ``fastupsample`` and ``fastdct`` are both passed as False.
    """
    pixels = decode_jpeg(data, fastupsample=False, fastdct=False)
    return Image.fromarray(pixels, 'RGB')


def __crompress(image, quality=90):
    """Return ``image`` encoded by ``encode_jpeg`` at the given ``quality``.

    The image is converted to a numpy array before encoding.
    """
    pixels = np.array(image)
    return encode_jpeg(pixels, quality=quality)


def __load_fixmap(imagezip, stimuluspath):
    """Load the ``fixLocs`` entry of the .mat file belonging to a stimulus.

    The member name is derived from ``stimuluspath`` by replacing
    'Stimuli' with 'FIXATIONLOCS' and swapping the file extension for
    '.mat'.

    Parameters:
        imagezip: Opened archive; only its ``open(name)`` is used.
        stimuluspath: Member name of the stimulus image.

    Returns:
        ``loadmat(data)['fixLocs']`` for the bytes of that member.

    Raises:
        KeyError: propagated from the lookup; ``yield_samples`` treats
            it as "no fixation data" (for a ``zipfile.ZipFile`` a missing
            member is reported this way).
    """
    # Only the extension is swapped; a plain replace('jpg', 'mat') would
    # also rewrite any 'jpg' inside a directory or file name.
    root, _ = pt.splitext(stimuluspath.replace('Stimuli', 'FIXATIONLOCS'))
    with imagezip.open(root + '.mat') as f:
        data = f.read()
    return loadmat(data)['fixLocs']


def find_fixpoints(arr):
    """Return the positions of all non-zero entries of ``arr`` as float32.

    Parameters:
        arr: Array-like fixation map.

    Returns:
        Array with one row per non-zero entry. The index arrays from
        ``np.nonzero`` are reversed before stacking, so for a 2-D
        (row, column) map each row reads (column, row), i.e. (x, y).
        An all-zero 2-D map yields an array of shape (0, 2).
    """
    # must flip (x,y) coordinate
    return np.transpose(np.nonzero(arr)[::-1]).astype(np.float32)
def transform_points(points, offset, scale_factor):
    """Shift ``points`` by the first two values of ``offset``, then scale.

    Parameters:
        points: Array-like of coordinates, one point per row.
        offset: Sequence whose first two entries are subtracted from
            every point (``create_sample`` passes a bounding box).
        scale_factor: Factor the shifted coordinates are multiplied by.

    Returns:
        ``(points - offset[:2]) * scale_factor`` as a numpy array.
    """
    # asarray leaves ndarrays untouched and lets plain lists work too.
    points = np.asarray(points)
    return (points - offset[:2]) * scale_factor
def filter_invalid_fixpoints(points, size):
    """Keep only the points that lie inside an image of the given size.

    Parameters:
        points: 2-D numpy array with x in column 0 and y in column 1.
        size: ``(width, height)`` of the image.

    Returns:
        The rows of ``points`` with ``0 <= x < width`` and
        ``0 <= y < height``.
    """
    w, h = size
    # Both coordinates must be non-negative. The previous any()-based
    # test let points with a single negative coordinate through and
    # dropped the valid origin (0, 0).
    inside = (points >= 0).all(axis=1)
    inside &= points[:, 0] < w
    inside &= points[:, 1] < h
    return points[inside]
def yield_samples(imagezip, names):
    """Yield ``(stimuluspath, stimulusdata, response)`` for every name.

    ``stimulusdata`` holds the raw bytes of the archive member and
    ``response`` is the result of ``__load_fixmap`` for it, or None
    when that lookup raises KeyError.
    """
    for stimuluspath in names:
        with imagezip.open(stimuluspath) as stimulusfile:
            stimulusdata = stimulusfile.read()
        try:
            response = __load_fixmap(imagezip, stimuluspath)
        except KeyError:
            # No fixation data could be looked up for this stimulus.
            response = None
        yield stimuluspath, stimulusdata, response
def create_sample(item):
    """Turn one item from ``yield_samples`` into a ``SaliencyData`` sample.

    The stimulus is decoded, cropped to the box found by
    ``__find_bbox``, resized by ``__transform_image`` and re-encoded.
    Fixation locations, if present, are moved into the coordinates of
    the resized crop and filtered to those inside it.

    Parameters:
        item: Tuple ``(stimuluspath, stimulusdata, response)``.

    Returns:
        ``SaliencyData(filename, stimulusdata, [SaliencyExperiment(...)])``
        where ``filename`` is the last two components of
        ``stimuluspath`` and the experiment carries the locations (or
        None when ``response`` is None).
    """
    stimuluspath, stimulusdata, response = item
    stimulus = __decode(stimulusdata)
    bbox = __find_bbox(stimulus)
    if bbox is None:
        # getbbox() found nothing to crop to; fall back to the full
        # image so the offset used for the fixation points is defined.
        bbox = (0, 0) + stimulus.size
    r, cropped = __transform_image(stimulus, bbox)
    stimulusdata = __crompress(cropped)
    if response is not None:
        locations = transform_points(find_fixpoints(response), bbox, r)
        locations = filter_invalid_fixpoints(locations, cropped.size)
    else:
        locations = None
    # Zip member names use '/' on every platform, so split on that
    # rather than os.sep (which is '\\' on Windows).
    filename = '/'.join(stimuluspath.split('/')[-2:])
    return SaliencyData(
        filename,
        stimulusdata,
        [SaliencyExperiment(locations, None)],
    )
def __is_stimulus(path):
    """Return True for '.jpg' paths that contain 'Stimuli' but not 'Output'."""
    if 'Stimuli' not in path:
        return False
    if 'Output' in path:
        return False
    return path.endswith('.jpg')
def write_set(imagezip, outdir, split, args):
    """Write all stimuli of one archive to ``<outdir>/<split>.msgpack``.

    Parameters:
        imagezip: Opened archive providing ``namelist()`` and ``open()``.
        outdir: Directory the output file is created in.
        split: Name of the split; used as the output file stem.
        args: Parsed arguments; ``shuffle``, ``no_confirm`` and
            ``threads`` are read.

    Samples are created by a thread pool and written in the order they
    finish, which is not necessarily the order of the names.
    """
    names = [f for f in imagezip.namelist() if __is_stimulus(f)]
    if args.shuffle:
        random.shuffle(names)
    gen = yield_threaded(yield_samples(imagezip, names))
    outfile = pt.join(outdir, split + '.msgpack')
    with FileWriter(outfile, total=len(names), overwrite=args.no_confirm) as writer:
        # Use the pool as a context manager so its worker threads are
        # shut down on success and on error instead of being leaked.
        with ThreadPool(args.threads) as pool:
            for sample in pool.imap_unordered(create_sample, gen):
                writer.write(sample)
def write_sets(files, outdir, args):
    """Write the 'train' and 'test' splits, one archive after the other.

    ``files[split]['path']`` is opened as a zip archive and handed to
    ``write_set``. A FileExistsError from ``write_set`` skips that split
    and processing continues with the next one.
    """
    for split in ('train', 'test'):
        archive = files[split]['path']
        with zipfile.ZipFile(archive) as imagezip:
            try:
                write_set(imagezip, outdir, split, args)
            except FileExistsError:
                # Presumably the output already exists and was not to be
                # overwritten; leave it alone and move on.
                continue
def main():
    """Parse the command line, prepare the input files and write both sets.

    Output goes to ``args.outdir`` when it is set, otherwise to
    ``args.indir``.
    """
    from ..tools.argparse import make_parser
    from ..tools.argparse import argument_threads
    from ..tools import prepare_indir

    # The module docstring is handed to the parser factory.
    parser = make_parser(__doc__)
    argument_threads(parser)
    args = parser.parse_args()

    outdir = args.outdir or args.indir
    files = prepare_indir(FILES, args)
    write_sets(files, outdir, args)
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C ends the run quietly, without a traceback.
        pass
    finally:
        # Always finish with a newline, whether main() completed or not.
        print()