# combination of this
#   https://keras.io/examples/vision/mixup/
# and
#   https://github.com/shenasa-ai/blog_tutorials/blob/main/mixup-data-augmentation.ipynb

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

# Data loading and tf.data pipeline from the keras.io example
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = tf.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = tf.one_hot(y_test, 10)

# Put aside a few samples to create our validation set
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we
# combine two independently shuffled datasets built from the same training data.
train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# this next set of code is from:
# https://github.com/shenasa-ai/blog_tutorials/blob/main/mixup-data-augmentation.ipynb
###################
def sample_beta_distribution(size, alpha):
    # If G1 and G2 are independent Gamma(alpha, 1) draws, then
    # G1 / (G1 + G2) follows a Beta(alpha, alpha) distribution.
    gamma_left = tf.random.gamma(shape=[size], alpha=alpha)
    gamma_right = tf.random.gamma(shape=[size], alpha=alpha)
    beta = gamma_left / (gamma_left + gamma_right)
    return beta


def linear_combination(x1, x2, lam):
    # Convex combination lam * x1 + (1 - lam) * x2; the weight is named `lam`
    # here to avoid confusion with the Beta parameter `alpha` used above.
    return x1 * lam + x2 * (1 - lam)


def mix_up(ds_one, ds_two, alpha=0.2):
    # Unpack the two batches
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = tf.shape(images_one)[0]

    # Sample lambda and reshape it so it broadcasts correctly
    λ = sample_beta_distribution(batch_size, alpha)
    images_λ = tf.reshape(λ, (batch_size, 1, 1, 1))  # broadcasts over height, width, channels
    labels_λ = tf.reshape(λ, (batch_size, 1))

    # Perform mixup on both images and labels by combining a pair of
    # images/labels (one from each dataset) into one image/label
    images = linear_combination(images_one, images_two, images_λ)
    labels = linear_combination(labels_one, labels_two, labels_λ)
    return (images, labels)
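# Optional sanity check (my addition, not from either source): sample_beta_distribution
# relies on the identity above. A symmetric Beta(alpha, alpha) has mean 0.5 and
# variance 1 / (4 * (2 * alpha + 1)), which is about 0.179 for alpha = 0.2, so a
# large sample should come out close to those values:
check = sample_beta_distribution(100_000, 0.2)
print("beta sanity check - mean:", tf.reduce_mean(check).numpy(),
      "variance:", tf.math.reduce_variance(check).numpy())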
# Create the mixed-up training dataset by mapping `mix_up` over the zipped pairs.
train_ds_mu = train_ds.map(mix_up, num_parallel_calls=AUTO)
###################

############
# preview of mixup samples, from keras.io
# Let's preview 9 samples from the dataset; the mixed label vectors are printed.
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    print(label.numpy().tolist())
    plt.axis("off")
###########

############
# This displays it in a nicer format: each title shows the two dominant
# classes (C) and their label weights (S).
# preview of mixup samples, from:
# https://github.com/shenasa-ai/blog_tutorials/blob/main/mixup-data-augmentation.ipynb
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    label = label.numpy()
    classes = np.argsort(label)[-2:][::-1]  # indices of the two largest label weights
    scores = np.round(label, 2)[classes]
    plt.title(f"C:{classes}, S:{scores}")
    plt.imshow(image.numpy().squeeze(), cmap="gray")
    plt.axis("off")
############

# model from keras.io example
def get_training_model():
    model = tf.keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model


# Save one set of initial weights so both runs start from the same parameters.
initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")

# train with mixup samples and test
model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("After training WITH MIXUP: Test accuracy: {:.2f}%".format(test_acc * 100))

# train without mixup and test
model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Notice that we are NOT using the mixed-up dataset here
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("After training WITHOUT MIXUP: Test accuracy: {:.2f}%".format(test_acc * 100))
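# Optional extension (my addition, not from either tutorial): a minimal sketch for
# sweeping the mixup `alpha` hyperparameter from the same initial weights. The
# candidate values in the commented loop are illustrative, not recommendations
# from the sources.
def test_accuracy_for_alpha(alpha):
    ds = tf.data.Dataset.zip((train_ds_one, train_ds_two)).map(
        lambda one, two: mix_up(one, two, alpha=alpha), num_parallel_calls=AUTO
    )
    m = get_training_model()
    m.load_weights("initial_weights.weights.h5")
    m.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    m.fit(ds, validation_data=val_ds, epochs=EPOCHS, verbose=0)
    _, acc = m.evaluate(test_ds, verbose=0)
    return acc

# for a in [0.1, 0.2, 0.4]:
#     print("alpha={}: test accuracy {:.2f}%".format(a, test_accuracy_for_alpha(a) * 100))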
'''
--- training WITH mixup ---
Epoch 1/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 6s 5ms/step - accuracy: 0.5375 - loss: 1.4093 - val_accuracy: 0.7515 - val_loss: 0.6717
Epoch 2/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.7208 - loss: 0.9537 - val_accuracy: 0.7875 - val_loss: 0.5825
Epoch 3/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.7514 - loss: 0.8885 - val_accuracy: 0.8225 - val_loss: 0.5249
Epoch 4/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.7772 - loss: 0.8295 - val_accuracy: 0.8410 - val_loss: 0.4742
Epoch 5/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.7905 - loss: 0.7969 - val_accuracy: 0.8480 - val_loss: 0.4588
Epoch 6/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8040 - loss: 0.7697 - val_accuracy: 0.8620 - val_loss: 0.4141
Epoch 7/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8069 - loss: 0.7498 - val_accuracy: 0.8635 - val_loss: 0.4045
Epoch 8/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8197 - loss: 0.7215 - val_accuracy: 0.8695 - val_loss: 0.3827
Epoch 9/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8216 - loss: 0.7129 - val_accuracy: 0.8710 - val_loss: 0.3813
Epoch 10/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8244 - loss: 0.6967 - val_accuracy: 0.8665 - val_loss: 0.3821
157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8694 - loss: 0.3873
Test accuracy: 86.30%

--- training WITHOUT mixup ---
Epoch 1/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 6s 5ms/step - accuracy: 0.5734 - loss: 1.1591 - val_accuracy: 0.7670 - val_loss: 0.6384
Epoch 2/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.7528 - loss: 0.6558 - val_accuracy: 0.8110 - val_loss: 0.5361
Epoch 3/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.7931 - loss: 0.5621 - val_accuracy: 0.8460 - val_loss: 0.4629
Epoch 4/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8169 - loss: 0.5054 - val_accuracy: 0.8425 - val_loss: 0.4595
Epoch 5/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8314 - loss: 0.4684 - val_accuracy: 0.8585 - val_loss: 0.4215
Epoch 6/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8364 - loss: 0.4495 - val_accuracy: 0.8610 - val_loss: 0.4135
Epoch 7/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8491 - loss: 0.4213 - val_accuracy: 0.8645 - val_loss: 0.3763
Epoch 8/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8552 - loss: 0.4034 - val_accuracy: 0.8740 - val_loss: 0.3632
Epoch 9/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8591 - loss: 0.3875 - val_accuracy: 0.8665 - val_loss: 0.3602
Epoch 10/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.8646 - loss: 0.3730 - val_accuracy: 0.8745 - val_loss: 0.3381
157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8720 - loss: 0.3592
Test accuracy: 86.73%

another run:
Test accuracy: 86.55% (with mixup)
Test accuracy: 87.16% (without mixup)
'''
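# A possible follow-up (my addition, not from either tutorial): the per-epoch
# numbers above are easier to compare as curves. If the two model.fit(...) calls
# are changed to keep their return values, e.g.
#   history_mixup = model.fit(train_ds_mu, ...)
#   history_plain = model.fit(train_ds_one, ...)
# (hypothetical names), the validation curves can be plotted side by side:
def plot_val_accuracy(histories, names):
    for history, name in zip(histories, names):
        plt.plot(history.history["val_accuracy"], label=name)
    plt.xlabel("epoch")
    plt.ylabel("validation accuracy")
    plt.legend()
    plt.show()

# plot_val_accuracy([history_mixup, history_plain], ["with mixup", "without mixup"])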