Some problems I ran into:
- I wasn't able to get tfds.ImageFolder working with a "flat" folder of images. I had to nest a dummy label folder inside a dummy split folder (see the layout sketch after this list). I followed the instructions here: https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder
- There doesn't seem to be a num_examples property on tfds.core.DatasetInfo, so I had to use builder.info.splits['fake_split'].num_examples, where fake_split is the name of my dummy split folder. There does appear to be a total_num_examples property, but I'm not sure where it's exposed; maybe it's meant to be private (Python doesn't enforce private fields, but a leading underscore marks them by convention).
- I had to edit pre_process because it was expecting serialized protobufs instead of {image, label} dicts.
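For reference, here's the nested layout that ended up working for me, and how I pull the example count out of the builder. The names data/, fake_split, and fake_label are just my placeholders:

data/
    fake_split/
        fake_label/
            img_000001.jpg
            img_000002.jpg
            ...

builder = tfds.ImageFolder('data/')
ds = builder.as_dataset(split='fake_split', shuffle_files=True)
num_examples = builder.info.splits['fake_split'].num_examples
# I believe the total across all splits lives on the splits dict, though I
# haven't verified this:
# total = builder.info.splits.total_num_examples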
Note that the reason I'm using the ImageFolder approach is that the TFRecords approach blew my 3GB dataset up to 200GB, since I think it stores the raw decoded tensor data. I'm new to this, but it seems like it'd make more sense to just store the data in JPEG format, since JPEG decoding is so fast (see the sketch below for what I mean). That said, even if the TFRecords approach used a reasonable amount of space, I'd probably still prefer the ImageFolder approach since it just seems nicer and more portable. Even better, from my (newbie) perspective, would be the ability to load a tar of images with any internal directory structure.
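Here's a minimal (untested) sketch of what I mean by storing the encoded JPEG bytes in a TFRecord instead of raw tensors; the feature names and helper functions are my own, not from any particular library:

import tensorflow as tf

def serialize_example(jpeg_path, label):
    # Store the compressed JPEG bytes as-is; don't decode before writing.
    image_bytes = tf.io.read_file(jpeg_path).numpy()
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def parse_example(serialized_example):
    feature = {'image': tf.io.FixedLenFeature([], tf.string),
               'label': tf.io.FixedLenFeature([], tf.int64)}
    example = tf.io.parse_single_example(serialized_example, feature)
    # Decode at read time, so the file on disk stays roughly JPEG-sized.
    image = tf.io.decode_jpeg(example['image'], channels=3)
    return {'image': image, 'label': example['label']}

This way the record on disk is roughly the size of the JPEGs themselves, and decoding happens on the fly in the input pipeline.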
Below is my new data_pipeline.py so far. It seems to work okay now, but I haven't gotten training to work yet as I'm still debugging some things. I'll update this post if I run into any more problems with data_pipeline.py.
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
import os
import json
def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    # Converts TF tensors to NumPy arrays without an extra copy (via
    # memoryview), then optionally stages batches onto the device(s).
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000):
    """
    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tf.data.Dataset): Dataset yielding sharded {'image', 'label'} batches.
        (dict): Dataset info with 'num_examples' and 'num_classes'.
    """

    def pre_process(example):
        # Old TFRecord parsing, kept for reference:
        # feature = {'height': tf.io.FixedLenFeature([], tf.int64),
        #            'width': tf.io.FixedLenFeature([], tf.int64),
        #            'channels': tf.io.FixedLenFeature([], tf.int64),
        #            'image': tf.io.FixedLenFeature([], tf.string),
        #            'label': tf.io.FixedLenFeature([], tf.int64)}
        # example = tf.io.parse_single_example(serialized_example, feature)
        # height = tf.cast(example['height'], dtype=tf.int64)
        # width = tf.cast(example['width'], dtype=tf.int64)
        # channels = tf.cast(example['channels'], dtype=tf.int64)
        # image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        # image = tf.reshape(image, shape=[height, width, channels])
        image = example['image']
        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)
        # Scale pixel values from [0, 255] to [-1, 1].
        image = (image - 127.5) / 127.5
        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to
        # [num_devices, batch_size, H, W, C] because the first dimension
        # will be mapped across devices by jax.pmap.
        data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
        return data

    # Old TFRecord loading, kept for reference:
    # print('Loading TFRecord...')
    # with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
    #     dataset_info = json.load(fin)
    # ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))

    builder = tfds.ImageFolder(data_dir)
    print(builder.info)
    ds = builder.as_dataset(split='fake_split', shuffle_files=True)
    num_examples = builder.info.splits['fake_split'].num_examples
    dataset_info = {'num_examples': num_examples, 'num_classes': 1}
    ds = ds.shuffle(min(num_examples, shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_devices, drop_remainder=True)
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)
    return ds, dataset_info
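For completeness, here's roughly how I call it from my training script (the values are hypothetical, and the shapes in the comments follow from the shard() reshape above):

ds, dataset_info = get_data(
    data_dir='data/',  # contains fake_split/fake_label/*.jpg as described above
    img_size=128,
    img_channels=3,
    num_classes=1,
    num_devices=jax.device_count(),
    batch_size=32)
batches = prefetch(ds, n_prefetch=2)
for batch in batches:
    # batch['image'] has shape [num_devices, 32, 128, 128, 3]
    # batch['label'] has shape [num_devices, 32, 1]
    ...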