Comments (6)
Here's an attempt that now executes, though it may butcher some of the intended functionality of the original. It's modeled after the example here: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
import numpy as np
import keras

class Mixup_threadsafe(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(self, X_train, y_train, batch_size=32, shuffle=True, alpha=.2, datagen=None):
        'Initialization'
        self.batch_size = batch_size
        self.X_train = X_train
        self.y_train = y_train
        self.shuffle = shuffle
        self.on_epoch_end()
        self.alpha = alpha
        self.datagen = datagen

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.X_train) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Generate data
        X, y = self.__data_generation(indexes)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.X_train))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, batch_ids):
        _, h, w, c = self.X_train.shape
        # One mixing coefficient per sample, drawn from Beta(alpha, alpha)
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)
        X1 = self.X_train[batch_ids]
        X2 = self.X_train[np.flip(batch_ids)]  # replaced this with flip
        X = X1 * X_l + X2 * (1 - X_l)
        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])
                X[i] = self.datagen.standardize(X[i])
        y1 = self.y_train[batch_ids]
        y2 = self.y_train[np.flip(batch_ids)]
        y = y1 * y_l + y2 * (1 - y_l)  # removed the list option
        return X/255, y  # Rex added dividing by 255 here
from mixup-generator.
If you want to handle the last batch (which is smaller than the others), you just need to replace self.batch_size with len(batch_ids) in __data_generation.
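A minimal sketch of that suggestion, written as a standalone function so it can be tried outside Keras. The function name mixup_batch and the rng parameter are hypothetical, introduced only for illustration; the mixing itself follows the code posted above, except that every array is sized from len(batch_ids) instead of self.batch_size.

```python
import numpy as np

def mixup_batch(X_train, y_train, batch_ids, alpha=0.2, rng=None):
    """Mix each sample with its mirrored partner, sizing everything from the batch."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(batch_ids)  # actual batch size; smaller for a partial last batch
    lam = rng.beta(alpha, alpha, n)  # one mixing coefficient per sample
    X_l = lam.reshape(n, 1, 1, 1)    # broadcasts over (h, w, c)
    y_l = lam.reshape(n, 1)          # broadcasts over the class axis
    partner_ids = np.flip(batch_ids)
    X = X_train[batch_ids] * X_l + X_train[partner_ids] * (1 - X_l)
    y = y_train[batch_ids] * y_l + y_train[partner_ids] * (1 - y_l)
    return X, y

# Usage: a "last batch" of 2 samples taken from a 10-sample dataset.
X_train = np.random.rand(10, 4, 4, 3)
y_train = np.eye(2)[np.random.randint(0, 2, 10)]  # one-hot labels
X, y = mixup_batch(X_train, y_train, np.arange(8, 10))
```

The loop over self.datagen in the original would need the same change, i.e. range(len(batch_ids)) rather than range(self.batch_size).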
Thank you for your suggestion.
Feel free to make a pull request!
I prefer something like this for batch diversity:
batch_ids2 = np.random.permutation(batch_ids)
X2 = self.X_train[batch_ids2]
...
y2 = self.y_train[batch_ids2]
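To make the difference concrete, here is a small sketch comparing the two ways of picking partners. The helper name mixup_partners and its mode argument are hypothetical. With np.flip, the pairing inside a batch is fixed by the batch order, and in an odd-sized batch the middle sample is always mixed with itself; with np.random.permutation, partners are redrawn on every call (self-pairing can still happen by chance).

```python
import numpy as np

def mixup_partners(batch_ids, mode="permutation", rng=None):
    """Return the index of the sample each batch element is mixed with."""
    rng = np.random.default_rng() if rng is None else rng
    if mode == "flip":
        return np.flip(batch_ids)       # sample i is paired with sample n-1-i
    return rng.permutation(batch_ids)   # random pairing, redrawn on every call

batch_ids = np.arange(5)
flipped = mixup_partners(batch_ids, mode="flip")  # [4 3 2 1 0]; index 2 pairs with itself
permuted = mixup_partners(batch_ids)              # some random reordering of 0..4
```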
Regarding __len__ (the number of batches per epoch): shouldn't you use ceil instead of floor, so that the last, partial batch is counted too?
You are right, but the floor version is simpler because every batch then has the same size, and dropping the leftover samples does not cause any problems.
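For comparison, a rough sketch of what the two choices mean for the batch count and for the slice taken in __getitem__. The helper names num_batches and batch_slice are hypothetical. Note that the ceil variant only works if the per-batch code sizes its arrays from len(batch_ids), as suggested earlier in this thread; otherwise the reshape to self.batch_size fails on the short last batch.

```python
import numpy as np

def num_batches(n_samples, batch_size, keep_partial=True):
    """Batches per epoch: ceil keeps the partial last batch, floor drops it."""
    if keep_partial:
        return int(np.ceil(n_samples / batch_size))
    return n_samples // batch_size

def batch_slice(indexes, index, batch_size):
    """Same slicing as __getitem__; slicing past the end just returns fewer items."""
    return indexes[index * batch_size:(index + 1) * batch_size]

indexes = np.arange(10)
n_ceil = num_batches(len(indexes), 4)                       # 3 batches of sizes 4, 4, 2
n_floor = num_batches(len(indexes), 4, keep_partial=False)  # 2 batches of sizes 4, 4
last = batch_slice(indexes, n_ceil - 1, 4)                  # array([8, 9])
```

With shuffle=True the indexes are reshuffled every epoch, so in the floor variant the dropped samples differ from epoch to epoch.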
Oh, I see!