import glob
import os
import keras
import numpy as np
import skimage.io
from imgaug import augmenters as iaa


class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras.

    Subclassing keras.utils.Sequence guarantees that the network will only
    train once on each sample per epoch.
    """

    def __init__(self, list_IDs, im_path, label_path, batch_size=4, dim=(128, 128, 128),
                 n_classes=4, shuffle=True, augment=False):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.im_path = im_path
        self.label_path = label_path
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augment = augment
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # Generate data
        X, y = self.__data_generation(list_IDs_temp)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates one batch of volumes and their one-hot encoded labels'
        X = np.empty([self.batch_size, *self.dim])
        Y = np.empty([self.batch_size, *self.dim, self.n_classes])
        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            img_X = skimage.io.imread(os.path.join(self.im_path, ID))
            X[i,] = img_X
            img_Y = skimage.io.imread(os.path.join(self.label_path, ID))
            Y[i,] = keras.utils.to_categorical(img_Y, num_classes=self.n_classes)
        if self.augment:
            # Augmentation hook: the original listing leaves this branch empty
            # (see the sketch after the class for one possible fill-in).
            pass
        # Add a trailing channel axis so X matches a single-channel 3D input.
        X = X.reshape(self.batch_size, *self.dim, 1)
        return X, Y
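
# --- Sketch, not part of the original listing --------------------------------
# The `augment` branch above is left empty. One minimal way to fill it in is a
# random left-right flip applied identically to a volume and its one-hot label
# so the two stay aligned; both the helper name `random_flip` and the choice of
# flip augmentation are assumptions, not something the original specifies.
# imgaug (imported as `iaa`) could be used instead, but its 2D augmenters would
# have to be applied slice by slice to these 3D volumes.
def random_flip(x, y, p=0.5):
    """Flip a (D, H, W) volume and its (D, H, W, C) one-hot label together."""
    if np.random.rand() < p:
        x = np.flip(x, axis=1)   # flip along the same spatial axis...
        y = np.flip(y, axis=1)   # ...for image and label alike
    return x, y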


params = {'dim': (128, 128, 128),
          'batch_size': 4,
          'im_path': "some/path/for/the/images/",
          'label_path': "some/path/for/the/label_images",
          'n_classes': 4,
          'shuffle': True,
          'augment': True}
partition = {}
im_path = "some/path/for/the/images/"
label_path = "some/path/for/the/label_images/"
images = glob.glob(os.path.join(im_path, "*.tif"))
images_IDs = [os.path.basename(name) for name in images]  # keep only the file name
partition['train'] = images_IDs
training_generator = DataGenerator(partition['train'], **params)
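
# Usage sketch, not part of the original listing: a keras.utils.Sequence can be
# handed straight to training, assuming `model` is an already-compiled Keras
# model that takes (128, 128, 128, 1) inputs and predicts n_classes channels:
#
#     model.fit(training_generator, epochs=10,
#               use_multiprocessing=True, workers=4)
#
# (on older Keras versions, model.fit_generator(training_generator, ...) is the
# equivalent call).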