### Tutorial: Basic Classification --------------------------------------------

##1# DATA Source : INPUT LAYER --------------------------------------------------
# Load the Fashion-MNIST dataset and inspect its nested list structure.
fashion_mnist <- dataset_fashion_mnist()
str(fashion_mnist)

# List of 2
# $ train:List of 2
# ..$ x: int [1:60000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
# ..$ y: int [1:60000(1d)] 9 0 0 3 0 2 7 2 5 5 ...
# $ test :List of 2
# ..$ x: int [1:10000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
# ..$ y: int [1:10000(1d)] 9 2 1 1 6 1 4 6 5 7 ...

# Pull the image arrays (x) and label vectors (y) out of each split.
train_images <- fashion_mnist$train$x
train_labels <- fashion_mnist$train$y
test_images  <- fashion_mnist$test$x
test_labels  <- fashion_mnist$test$y

# x_train <- fashion_mnist$train$x
# y_train <- fashion_mnist$train$y
# identical(x_train, train_images) # TRUE
# identical(y_train, train_labels) # TRUE

# Human-readable names for the ten Fashion-MNIST classes.
# Labels are 0-based (0..9), so look a name up with class_names[label + 1].
class_names <- c('T-shirt/top',
                 'Trouser',
                 'Pullover',
                 'Dress',
                 'Coat',
                 'Sandal',
                 'Shirt',
                 'Sneaker',
                 'Bag',
                 'Ankle boot')

##2# Preprocess the data ---------------------------------------------------------

# ` ` rescale ------------------------------------------------------------------
# from integers ranging between 0 to 255 
# into floating point values ranging between 0 and 1
# Scale both splits by the same constant (the maximum 8-bit pixel value) so
# train and test inputs are guaranteed to share one scale, independent of
# whichever maximum happens to occur in each split.
pixel_max <- 255
train_images <- train_images / pixel_max
test_images  <- test_images  / pixel_max

# First training image as a 28 x 28 matrix, used by the plotting examples.
image_1 <- train_images[1, , ]

# ` ` Plotting -Way1 -----------------------------------------------------------
# Matrix to DF to ggplot
# Turn the first image matrix into a data.table with one column per pixel
# column, named "1".."28".
imgDT_1 <- train_images[1, , ] %>% data.table::as.data.table()
colnames(imgDT_1) <- seq_len(ncol(imgDT_1)) %>% as.character()

# Add the pixel row index, then reshape to long form: one row per pixel with
# id = row, variable = column, value = intensity.
imgDT_1[ , id := seq_len(nrow(imgDT_1))]
imgDT_1_melt <- data.table::melt(imgDT_1, id.vars = c("id"))
# melt() yields `variable` as a factor; convert to the integer column index.
imgDT_1_melt[ , variable := as.integer(variable)]

# Heatmap of the image; the y axis is reversed so row 1 is drawn at the top.
imgDT_1_melt %>% ggplot(aes(x = variable, y = id, fill = value)) + geom_tile() +
  scale_y_reverse() +
  scale_fill_gradient(low = "white", high = "black", na.value = NA) +
  theme_ipsum() +
  theme(panel.grid = element_blank()) + labs(x = "", y = "")

# ` ` Plotting -Way2 ----
# Matrix to image
# Draw the matrix as-is with base graphics (no rotation applied yet).
image(image_1)

# Rotate a matrix 90 degrees clockwise.
#
# x: a matrix (anything as.matrix() accepts).
# Returns a matrix with ncol(x) rows and nrow(x) columns.
#
# The rows are reversed by indexing with drop = FALSE and then transposed, so
# single-row and single-column inputs keep their matrix shape instead of
# collapsing to a vector.
rotate90 <- function(x) {
  x <- as.matrix(x)
  t(x[rev(seq_len(nrow(x))), , drop = FALSE])
}
# Same image after rotating it 90 degrees clockwise.
image_1 %>% rotate90 %>% image()

# Variations on the image() call, one option at a time, then all combined.
grey_ramp <- grey(seq(from = 0, to = 1, length.out = 256))  # 256 grey levels

image_1 %>% rotate90 %>% image(useRaster = TRUE)  # draw as a raster image (pixel grid, like a CRT display)
image_1 %>% rotate90 %>% image(axes = FALSE)      # remove axis
image_1 %>% rotate90 %>% image(col = grey_ramp)   # black&white
image_1 %>% rotate90 %>% image(useRaster = TRUE, axes = FALSE, col = grey_ramp)

# ` ` Plotting -Way2-1 ----
# Draw the first 25 training images, each titled with its class name.
# Margins are removed except for a strip at the top that holds the title.
par(mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in 1:25) {
  # rotate 90 degrees clockwise before drawing
  img <- rotate90(train_images[i, , ])
  # labels are 0-based, class_names is 1-based
  image(x = 1:28, y = 1:28, img, col = gray((0:255) / 255),
        xaxt = 'n', yaxt = 'n',
        main = paste(class_names[train_labels[i] + 1]))
}

##3# Build the model -------------------------------------------------------------
# ` ` Setup the layers ---------------------------------------------------------

# Sequential network: flatten each 28 x 28 image, one hidden dense layer of
# 128 ReLU units, then a 10-unit softmax output (one unit per class).
model <- keras_model_sequential()
# The layers are piped from `model` without re-assigning the result, so this
# relies on the layer_*() calls updating `model` in place.
# NOTE(review): confirm this holds for the installed keras version.
model %>%
  layer_flatten(input_shape=c(28, 28)) %>%
  layer_dense(units=128, activation='relu') %>%
  layer_dense(units=10,  activation='softmax')

# Print the layer-by-layer summary of the model.
model %>% summary
# ` ` reshape - Image Flatten --------------------------------------------------
# dim(train_images) <- c(nrow(train_images), 28*28) # 60000   784
# dim(test_images)  <- c(nrow(test_images),  784)

# model %>% 
#   layer_dense(units=256, activation="relu", input_shape=c(28*28)) %>% 
#   layer_dropout(rate=0.4) %>% 
#   layer_dense(units=128, activation="relu") %>% 
#   layer_dropout(rate=0.3) %>% 
#   layer_dense(units=10,  activation="softmax")

# ` ` Compile the model --------------------------------------------------------
# Configure training: Adam optimizer, sparse categorical cross-entropy loss
# (the labels are passed as integer class ids, not one-hot vectors), and
# accuracy reported as the metric.
model %>% compile(
  optimizer = 'adam',
  loss = 'sparse_categorical_crossentropy',
  metrics = c('accuracy')
)
# model %>% compile(
#   optimizer=optimizer_rmsprop(),
#   loss="categorical_crossentropy",
#   metrics = c("accuracy"))

##4# Train the model --------------------------------------------------------------
# Train for 5 epochs, holding out 20% of the training data for validation.
fit(model, train_images, train_labels, epochs = 5, validation_split = 0.2)

# history <- model %>% fit(train_images, train_labels, 
#                          epochs=5, batch_size=128, validation_split=0.2)
# history %>% plot()

##5# Evaluate accuracy ------------------------------------------------------------
# Evaluate the trained model on the held-out test set and report the result.
score <- model %>% evaluate(test_images, test_labels)
# NOTE(review): `$` assumes evaluate() returns a list, and `$acc` relies on
# partial matching if the element is named "accuracy" - confirm against the
# installed keras version.
cat('Test loss:',     score$loss, "\n")
cat('Test accuracy:', score$acc,  "\n")

##6# Make predictions -------------------------------------------------------------
# Class scores for every test image: one row per image, one column per class.
predicted <- model %>% predict(test_images)
predicted[1, ]
# 1-based index of the highest-scoring class for the first image
which.max(predicted[1, ])

# 0-based predicted class id per image, taken from the scores computed above.
# predict_classes() is no longer available in recent keras releases, so the
# arg-max is computed in base R; subtract 1 because labels go from 0 to 9.
class_pred <- apply(predicted, 1, which.max) - 1L

# Draw the first 25 test images titled "predicted (true)"; the title is green
# when the prediction matches the true label and red otherwise.
par(mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in 1:25) {
  # rotate 90 degrees clockwise before drawing
  img <- rotate90(test_images[i, , ])
  # subtract 1 as labels go from 0 to 9
  predicted_label <- which.max(predicted[i, ]) - 1
  true_label <- test_labels[i]
  if (predicted_label == true_label) {
    color <- '#008800'
  } else {
    color <- '#bb0000'
  }
  image(1:28, 1:28, img, col = gray((0:255) / 255), xaxt = 'n', yaxt = 'n',
        main = paste0(class_names[predicted_label + 1], " (",
                      class_names[true_label + 1], ")"),
        col.main = color)
}

# Grab an image from the test dataset.
# drop = FALSE keeps the leading batch dimension (1 x 28 x 28), which the
# model expects.
img <- test_images[1, , , drop = FALSE]

# One row of class scores for the single image in the batch.
predictions <- model %>% predict(img)

# Predicted class id: position of the highest score, minus 1 because labels
# are 0-based. (Subtracting 1 from the scores themselves would not give a label.)
prediction <- which.max(predictions[1, ]) - 1

# Same result for the whole batch, computed in base R because
# predict_classes() is no longer available in recent keras releases.
class_pred <- apply(predictions, 1, which.max) - 1L
