This notebook is based on tutorials, including code, from RStudio (https://tensorflow.rstudio.com/) as well as Chollet and Allaire’s Deep Learning with R (2018), currently accessible for free to Penn State students through the Penn State library: https://learning.oreilly.com/library/view/deep-learning-with/9781617295546/. The original notebooks for Deep Learning with R are available here: https://github.com/jjallaire/deep-learning-with-r-notebooks.

We will be implementing neural models in R through the keras package, which by default uses the TensorFlow “backend.” You can access TensorFlow directly – which provides more flexibility but requires more of the user – and you can also use other backends through keras, specifically CNTK and Theano. (The R package keras is an interface to Keras itself, which in turn offers an API to a backend like TensorFlow.) Keras is generally described as “high-level” or “model-level”, meaning the researcher builds models from Keras building blocks – which is probably all most of you would ever want to do.
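If you do want to try a different backend, the keras package exposes that directly; a minimal sketch (not run in this notebook):

# library(keras)
# use_backend("theano")   # or "cntk"; call this before any models are built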

Warning 1: Keras (https://keras.io) is written in Python, so (a) installing keras and tensorflow creates a Python environment on your machine (in my case, it detects Anaconda and creates a conda environment called r-tensorflow), and (b) much of the keras syntax is Pythonic (like 0-based indexing in some contexts), as are the often untraceable error messages.

# devtools::install_github("rstudio/keras")

I used the install_keras function to install a default CPU-based keras and tensorflow. There are more details for alternative installations here: https://tensorflow.rstudio.com/keras/
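For reference, the default route, plus a quick check that R can see the backend, looks roughly like this (none of it is run here):

# install.packages("keras")   # CRAN release, if not installing from GitHub as above
# library(keras)
# install_keras()             # installs a default CPU-only TensorFlow backend
# is_keras_available()        # TRUE once the Python environment is in place
# reticulate::py_config()     # shows which Python / conda environment R is using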

Warning 2: There is currently a bug in TensorFlow that manifests in the code that follows. The fix in my case was installing the “nightly” build, which also installs Python 3.6 (instead of 3.7) in a new r-tensorflow environment. By the time you read this, the bug probably won’t exist, in which case install_keras() should be sufficient.

library(keras)
# install_keras(tensorflow="nightly")

Building a Deep Classifier in Keras

We’ll work with the IMDB review dataset that comes with keras.

imdb <- dataset_imdb(num_words = 5000)
c(c(train_data, train_labels), c(test_data, test_labels)) %<-% imdb
### This is equivalent to:
# imdb <- dataset_imdb(num_words = 5000)
# train_data <- imdb$train$x
# train_labels <- imdb$train$y
# test_data <- imdb$test$x
# test_labels <- imdb$test$y

Look at the training data (features) and labels (positive/negative).

str(train_data[[1]])
 int [1:218] 1 14 22 16 43 530 973 1622 1385 65 ...
train_labels[[1]]
[1] 1
max(sapply(train_data, max))
[1] 4999

Probably a good idea to figure out how to get back to text. Decode review 1:
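The decoding code itself isn’t shown in this output; a sketch along the lines of the Deep Learning with R original (word indices are offset by 3 because 0, 1, and 2 are reserved for the padding, start-of-sequence, and unknown tokens) is:

word_index <- dataset_imdb_word_index()   # named list mapping words to integer indices
reverse_word_index <- names(word_index)   # flip it: indices (as names) to words
names(reverse_word_index) <- word_index
decoded_review <- sapply(train_data[[1]], function(index) {
  word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
  if (!is.null(word)) word else "?"       # reserved indices print as "?"
})

The reverse_word_index object is reused below when we label the learned weights with words.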

decoded_review
  [1] "?"          "this"       "film"      
  [4] "was"        "just"       "brilliant" 
  [7] "casting"    "location"   "scenery"   
 [10] "story"      "direction"  "everyone's"
 [13] "really"     "suited"     "the"       
 [16] "part"       "they"       "played"    
 [19] "and"        "you"        "could"     
 [22] "just"       "imagine"    "being"     
 [25] "there"      "robert"     "?"         
 [28] "is"         "an"         "amazing"   
 [31] "actor"      "and"        "now"       
 [34] "the"        "same"       "being"     
 [37] "director"   "?"          "father"    
 [40] "came"       "from"       "the"       
 [43] "same"       "scottish"   "island"    
 [46] "as"         "myself"     "so"        
 [49] "i"          "loved"      "the"       
 [52] "fact"       "there"      "was"       
 [55] "a"          "real"       "connection"
 [58] "with"       "this"       "film"      
 [61] "the"        "witty"      "remarks"   
 [64] "throughout" "the"        "film"      
 [67] "were"       "great"      "it"        
 [70] "was"        "just"       "brilliant" 
 [73] "so"         "much"       "that"      
 [76] "i"          "bought"     "the"       
 [79] "film"       "as"         "soon"      
 [82] "as"         "it"         "was"       
 [85] "released"   "for"        "?"         
 [88] "and"        "would"      "recommend" 
 [91] "it"         "to"         "everyone"  
 [94] "to"         "watch"      "and"       
 [97] "the"        "fly"        "?"         
[100] "was"        "amazing"    "really"    
[103] "cried"      "at"         "the"       
[106] "end"        "it"         "was"       
[109] "so"         "sad"        "and"       
[112] "you"        "know"       "what"      
[115] "they"       "say"        "if"        
[118] "you"        "cry"        "at"        
[121] "a"          "film"       "it"        
[124] "must"       "have"       "been"      
[127] "good"       "and"        "this"      
[130] "definitely" "was"        "also"      
[133] "?"          "to"         "the"       
[136] "two"        "little"     "?"         
[139] "that"       "played"     "the"       
[142] "?"          "of"         "norman"    
[145] "and"        "paul"       "they"      
[148] "were"       "just"       "brilliant" 
[151] "children"   "are"        "often"     
[154] "left"       "out"        "of"        
[157] "the"        "?"          "list"      
[160] "i"          "think"      "because"   
[163] "the"        "stars"      "that"      
[166] "play"       "them"       "all"       
[169] "grown"      "up"         "are"       
[172] "such"       "a"          "big"       
[175] "?"          "for"        "the"       
[178] "whole"      "film"       "but"       
[181] "these"      "children"   "are"       
[184] "amazing"    "and"        "should"    
[187] "be"         "?"          "for"       
[190] "what"       "they"       "have"      
[193] "done"       "don't"      "you"       
[196] "think"      "the"        "whole"     
[199] "story"      "was"        "so"        
[202] "lovely"     "because"    "it"        
[205] "was"        "true"       "and"       
[208] "was"        "someone's"  "life"      
[211] "after"      "all"        "that"      
[214] "was"        "?"          "with"      
[217] "us"         "all"       

Create a “one-hot” vectorization of the input features: binary indicators for the presence or absence of each feature (here, a word/token) in each “sequence” (here, a review).

vectorize_sequences <- function(sequences, dimension = 5000) {
  results <- matrix(0, nrow = length(sequences), ncol = dimension)
  for (i in 1:length(sequences))
    results[i, sequences[[i]]] <- 1 
  results
}
x_train <- vectorize_sequences(train_data)
x_test <- vectorize_sequences(test_data)
# Also change labels from integer to numeric
y_train <- as.numeric(train_labels)
y_test <- as.numeric(test_labels)
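A quick sanity check on the result (not part of the original output): the training matrix should be 25,000 reviews by 5,000 indicator columns, with a balanced label split.

# dim(x_train)    # 25000  5000
# table(y_train)  # 12500 zeros (negative) and 12500 ones (positive)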

We need a model architecture. In many cases, this can be built from simple layer building blocks in Keras. Here’s a three-layer network for our classification problem:

model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
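To inspect the architecture and parameter counts before training, summary() works on any keras model (output not shown here). For this network the counts are 5000 × 16 + 16 = 80,016 weights for the first layer, 16 × 16 + 16 = 272 for the second, and 16 + 1 = 17 for the output layer.

# summary(model)   # prints layer shapes and parameter counts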

We need to “compile” the model, specifying the loss function, the optimizer we wish to use, and the metrics we want to keep track of.

model %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)

Create a held-out set of your training data for validation.

val_indices <- 1:10000 # not great practice if these are ordered
x_val <- x_train[val_indices,]
partial_x_train <- x_train[-val_indices,]
y_val <- y_train[val_indices]
partial_y_train <- y_train[-val_indices]
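As the comment notes, peeling off the first 10,000 rows is only safe if the rows aren’t ordered by label (or anything else); a safer sketch samples the validation indices at random:

# set.seed(123)                                 # any seed, for reproducibility
# val_indices <- sample(nrow(x_train), 10000)   # random held-out rows instead of the first 10000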

Fit the model, store the fit history.

history <- model %>% fit(
  partial_x_train,
  partial_y_train,
  epochs = 20,
  batch_size = 512,
  validation_data = list(x_val, y_val)
)

Train on 15000 samples, validate on 10000 samples
Epoch 1/20
15000/15000 [==============================] - 1s 51us/sample - loss: 0.5266 - acc: 0.7797 - val_loss: 0.4061 - val_acc: 0.8590
Epoch 2/20
15000/15000 [==============================] - 1s 37us/sample - loss: 0.3337 - acc: 0.8880 - val_loss: 0.3255 - val_acc: 0.8745
Epoch 3/20
15000/15000 [==============================] - 1s 34us/sample - loss: 0.2595 - acc: 0.9099 - val_loss: 0.2926 - val_acc: 0.8844
Epoch 4/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.2220 - acc: 0.9212 - val_loss: 0.2873 - val_acc: 0.8830
Epoch 5/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.1929 - acc: 0.9307 - val_loss: 0.3036 - val_acc: 0.8783
Epoch 6/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.1726 - acc: 0.9388 - val_loss: 0.2886 - val_acc: 0.8841
Epoch 7/20
15000/15000 [==============================] - 1s 45us/sample - loss: 0.1586 - acc: 0.9443 - val_loss: 0.2977 - val_acc: 0.8807
Epoch 8/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.1414 - acc: 0.9500 - val_loss: 0.3113 - val_acc: 0.8759
Epoch 9/20
15000/15000 [==============================] - 1s 37us/sample - loss: 0.1310 - acc: 0.9541 - val_loss: 0.3268 - val_acc: 0.8776
Epoch 10/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.1194 - acc: 0.9596 - val_loss: 0.3521 - val_acc: 0.8697
Epoch 11/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.1054 - acc: 0.9657 - val_loss: 0.3571 - val_acc: 0.8696
Epoch 12/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.0991 - acc: 0.9673 - val_loss: 0.3736 - val_acc: 0.8680
Epoch 13/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0915 - acc: 0.9697 - val_loss: 0.3958 - val_acc: 0.8648
Epoch 14/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0767 - acc: 0.9764 - val_loss: 0.4126 - val_acc: 0.8631
Epoch 15/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0743 - acc: 0.9767 - val_loss: 0.4257 - val_acc: 0.8650
Epoch 16/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0634 - acc: 0.9816 - val_loss: 0.4482 - val_acc: 0.8633
Epoch 17/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0599 - acc: 0.9825 - val_loss: 0.4742 - val_acc: 0.8616
Epoch 18/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.0533 - acc: 0.9859 - val_loss: 0.5359 - val_acc: 0.8493
Epoch 19/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0468 - acc: 0.9882 - val_loss: 0.5153 - val_acc: 0.8619
Epoch 20/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0405 - acc: 0.9904 - val_loss: 0.5282 - val_acc: 0.8611
str(history)
List of 2
 $ params :List of 7
  ..$ batch_size   : int 512
  ..$ epochs       : int 20
  ..$ steps        : NULL
  ..$ samples      : int 15000
  ..$ verbose      : int 0
  ..$ do_validation: logi TRUE
  ..$ metrics      : chr [1:4] "loss" "acc" "val_loss" "val_acc"
 $ metrics:List of 4
  ..$ loss    : num [1:20] 0.527 0.334 0.26 0.222 0.193 ...
  ..$ acc     : num [1:20] 0.78 0.888 0.91 0.921 0.931 ...
  ..$ val_loss: num [1:20] 0.406 0.326 0.293 0.287 0.304 ...
  ..$ val_acc : num [1:20] 0.859 0.874 0.884 0.883 0.878 ...
 - attr(*, "class")= chr "keras_training_history"
plot(history)
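The stored history also makes it easy to see where validation loss bottoms out, for example (not run in the original):

# which.min(history$metrics$val_loss)   # epoch with the lowest validation loss
# max(history$metrics$val_acc)          # best validation accuracy reached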

Overfitting. Refit with a smaller number of epochs, this time on the full training set.

model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
model %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)
model %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
Epoch 1/6
25000/25000 [==============================] - 1s 36us/sample - loss: 0.4709 - acc: 0.8064
Epoch 2/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.2915 - acc: 0.8962
Epoch 3/6
25000/25000 [==============================] - 1s 28us/sample - loss: 0.2389 - acc: 0.9125
Epoch 4/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.2112 - acc: 0.9222
Epoch 5/6
25000/25000 [==============================] - 1s 27us/sample - loss: 0.1959 - acc: 0.9275
Epoch 6/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.1832 - acc: 0.9323
results <- model %>% evaluate(x_test, y_test)

25000/25000 [==============================] - 1s 27us/sample - loss: 0.3068 - acc: 0.8796

On the test data, we get an accuracy of 87.96% with this model.

So, what happened there? What is the model learning? It learned a function that converts the 5000 inputs into 16 hidden / latent / intermediate numbers, converts those 16 into a different set of 16 intermediate numbers, and then converts those into the 1 number at the output. The intermediate functions are nonlinear; otherwise there wouldn’t be any gain from stacking them together. But we can get an approximate idea of how the inputs map to the single output by treating each layer as linear and multiplying the weights through: \(W^{(5000\times 1)}_{io} \approx W_{i1}^{(5000\times 16)} \times W_{12}^{(16\times 16)} \times W_{2o}^{(16\times 1)}\). These aggregate weights give a rough sense of the main effect of each input. (This will not work, generally speaking, in more complex models or contexts.)

# elements 1, 3, and 5 of get_weights() are the three kernel (weight) matrices; 2, 4, and 6 are the biases
model.weights.approx <- (get_weights(model)[[1]] %*% get_weights(model)[[3]] %*% get_weights(model)[[5]])[,1]
# label each input position with its word; the first three indices are reserved tokens
top_words <- reverse_word_index[as.character(1:5000)]
names(model.weights.approx) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])
sort(model.weights.approx, dec=T)[1:20]
            7    refreshing   wonderfully 
    1.6517927     1.4027301     1.3446769 
         rare          edie        hooked 
    1.2740762     1.2113184     1.2086754 
    excellent extraordinary          noir 
    1.2033552     1.1940259     1.1527241 
     flawless   appreciated             8 
    1.1384029     1.1199644     1.1187128 
     captures       perfect         tight 
    1.0980826     1.0617401     1.0599815 
    vengeance        superb      superbly 
    1.0371490     1.0346370     1.0314989 
          gem    underrated 
    0.9951883     0.9877272 
sort(model.weights.approx, dec=F)[1:20]
         waste          worst disappointment 
     -1.705202      -1.627594      -1.589494 
        poorly        unfunny     incoherent 
     -1.453674      -1.413595      -1.393460 
   unwatchable        stinker          awful 
     -1.391467      -1.353770      -1.352675 
   forgettable      pointless      laughable 
     -1.306769      -1.297418      -1.297009 
         lousy      obnoxious          lacks 
     -1.252515      -1.246433      -1.233059 
         mst3k  uninteresting  disappointing 
     -1.222244      -1.215452      -1.211375 
          dull       tiresome 
     -1.198182      -1.186388 

“7” is interesting. It comes from reviews like #168, which ends like this:

sapply(train_data[[168]],function(index) {
  word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
  if (!is.null(word)) word 
  else "?"})[201:215]
 [1] "creates"    "a"          "great"     
 [4] "monster"    "like"       "atmosphere"
 [7] "br"         "br"         "vote"      
[10] "7"          "and"        "half"      
[13] "out"        "of"         "10"        

That means this whole thing is cheating, as far as I’m concerned. Some of the reviews end with text summarizing the rating as a number out of 10! The model then uses that to “predict” whether the review is positive or negative (a label that is probably derived from that 10-point rating in the first place).

We can look at these much as we looked at the coefficients from the classifiers in our earlier notebook. I’ll also highlight the digit tokens (“1” through “9”).

# Plot weights
plot(colSums(x_train),model.weights.approx, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Weights Learned in Deep Model, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),model.weights.approx, names(model.weights.approx),pos=4,cex=1.5*abs(model.weights.approx), col=rgb(0,0,0,.5*abs(model.weights.approx)))
text(colSums(x_train[,c("1","2","3","4","5","6","7","8","9")]),model.weights.approx[c("1","2","3","4","5","6","7","8","9")], names(model.weights.approx[c("1","2","3","4","5","6","7","8","9")]),pos=4,cex=1.5*model.weights.approx[c("1","2","3","4","5","6","7","8","9")], col=rgb(1,0,0,1))

Compare to shallow logistic classifier

It’s worth pointing out that “deep” didn’t buy us much. If we just estimate with a single sigmoid (logistic) layer, we get nearly identical results, with 87.72% accuracy.

logistic.mod <- keras_model_sequential() %>%
  layer_dense(units = 1, activation = "sigmoid", input_shape = c(5000))
logistic.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)
logistic.mod %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
Epoch 1/6
25000/25000 [==============================] - 1s 30us/sample - loss: 0.5915 - acc: 0.7577
Epoch 2/6
25000/25000 [==============================] - 1s 21us/sample - loss: 0.4797 - acc: 0.8470
Epoch 3/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.4156 - acc: 0.8677
Epoch 4/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.3723 - acc: 0.8801
Epoch 5/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.3413 - acc: 0.8874
Epoch 6/6
25000/25000 [==============================] - 1s 34us/sample - loss: 0.3180 - acc: 0.8949
results <- logistic.mod %>% evaluate(x_test, y_test)

25000/25000 [==============================] - 1s 23us/sample - loss: 0.3405 - acc: 0.8772

This we can interpret directly. It’s basically a ridge regression like we saw in the earlier classification notebook.

logmod.weights <- get_weights(logistic.mod)[[1]][,1]
top_words <- reverse_word_index[as.character(1:5000)]
names(logmod.weights) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])
#Most positive words
sort(logmod.weights,dec=T)[1:20]
  excellent   wonderful     perfect 
  0.2910813   0.2788198   0.2777115 
          8       loved   fantastic 
  0.2681858   0.2612670   0.2599397 
   favorite           7        best 
  0.2543596   0.2516940   0.2509511 
     superb      highly  incredible 
  0.2509483   0.2504135   0.2423920 
      great         fun  definitely 
  0.2413276   0.2395824   0.2350059 
   captures         gem wonderfully 
  0.2334028   0.2325189   0.2299354 
       rare    superbly 
  0.2299205   0.2298841 
#Most negative words
sort(logmod.weights,dec=F)[1:20]
        awful          poor      terrible 
   -0.2987652    -0.2955288    -0.2779662 
        worst     redeeming        stupid 
   -0.2719380    -0.2696612    -0.2666538 
          bad         badly     pointless 
   -0.2661456    -0.2661098    -0.2659274 
        waste          mess         worse 
   -0.2612805    -0.2611499    -0.2577833 
            1        poorly          crap 
   -0.2522472    -0.2493045    -0.2488209 
       wasted    ridiculous disappointing 
   -0.2469009    -0.2464800    -0.2455676 
         save      pathetic 
   -0.2444713    -0.2432728 
# Plot weights
plot(colSums(x_train),logmod.weights, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Weights Learned in Shallow Logistic Model, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),logmod.weights, names(logmod.weights),pos=4,cex=10*abs(logmod.weights), col=rgb(0,0,0,3*abs(logmod.weights)))

Compare to Naive Bayes

Just to have another comparison, let’s check with Naive Bayes.

library(quanteda)
colnames(x_train) <- colnames(x_test) <-  c("<PAD>","<START>","<UNK>",top_words[1:4997])
dfm.train <- as.dfm(x_train)
dfm.test <- as.dfm(x_test)
nb.mod <- textmodel_nb(dfm.train, y_train, distribution = "Bernoulli")
summary(nb.mod)

Call:
textmodel_nb.dfm(x = dfm.train, y = y_train, distribution = "Bernoulli")

Class Priors:
(showing first 2 elements)
  0   1 
0.5 0.5 

Estimated Feature Scores:
  <PAD> <START> <UNK>    the    and      a
0   0.5  0.4998   0.5 0.5006 0.4972 0.5007
1   0.5  0.5002   0.5 0.4994 0.5028 0.4993
      of     to    is    br     in     it
0 0.5003 0.5049 0.494 0.512 0.4962 0.5031
1 0.4997 0.4951 0.506 0.488 0.5038 0.4969
       i this   that    was     as    for
0 0.5199 0.51 0.5105 0.5283 0.4787 0.5025
1 0.4801 0.49 0.4895 0.4717 0.5213 0.4975
    with  movie    but   film     on    not
0 0.4939 0.5408 0.5126 0.4909 0.5088 0.5276
1 0.5061 0.4592 0.4874 0.5091 0.4912 0.4724
     you    are   his   have     he    be
0 0.5143 0.4932 0.462 0.5301 0.4785 0.527
1 0.4857 0.5068 0.538 0.4699 0.5215 0.473
y_test.pred.nb <- predict(nb.mod, newdata=dfm.test)
nb.class.table <- table(y_test,y_test.pred.nb)
sum(diag(nb.class.table/sum(nb.class.table)))
[1] 0.8418

As we might expect, that’s not as good.

As before, Naive Bayes overfits, latching onto rare words (such as edie, paulie, and gundam) that appear almost exclusively in one class.

#Most positive words
sort(nb.mod$PcGw[2,],dec=T)[1:20]
        edie       paulie     flawless 
   0.9756098    0.9500000    0.9200000 
      gundam      matthau     superbly 
   0.9130435    0.9104478    0.9098361 
       felix   perfection     captures 
   0.9019608    0.8985507    0.8829268 
 wonderfully    masterful   refreshing 
   0.8821656    0.8735632    0.8693467 
breathtaking      mildred   delightful 
   0.8666667    0.8611111    0.8582677 
    polanski  beautifully       voight 
   0.8571429    0.8530120    0.8529412 
  underrated       powell 
   0.8515284    0.8484848 
#Most negative words
sort(nb.mod$PcGw[2,],dec=F)[1:20]
        uwe        boll unwatchable 
 0.03389831  0.03448276  0.04807692 
    stinker  incoherent       mst3k 
 0.04807692  0.05185185  0.06140351 
    unfunny       waste      seagal 
 0.06837607  0.07356322  0.07843137 
  atrocious   pointless      horrid 
 0.08762887  0.08874459  0.09009009 
  redeeming      drivel        blah 
 0.09148265  0.09836066  0.09876543 
      worst   laughable       lousy 
 0.09969122  0.09975062  0.10294118 
      awful     wasting 
 0.10595568  0.10666667 
# Plot weights
plot(colSums(x_train),nb.mod$PcGw[2,], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Posterior Probabilities, Naive Bayes Classifier, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),nb.mod$PcGw[2,], names(logmod.weights),pos=4,cex=5*abs(.5-nb.mod$PcGw[2,]), col=rgb(0,0,0,1.5*abs(.5-nb.mod$PcGw[2,])))

Word Embeddings

Simple training of task-specific embeddings

Keras has its own embedding layer, layer_embedding(), that you can use as the first layer in a model; the embeddings are learned along with the rest of the model during training.

max_features <- 5000
maxlen <- 500
imdb.s <- dataset_imdb(num_words = max_features)
c(c(x_train.s, y_train.s), c(x_test.s, y_test.s)) %<-% imdb.s
x_train.s <- pad_sequences(x_train.s, maxlen = maxlen)
x_test.s <- pad_sequences(x_test.s, maxlen = maxlen)
emb.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_features, output_dim = 6,  
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")
emb.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)
summary(emb.mod)
Model: "sequential_5"
______________________________________________
Layer (type)        Output Shape       Param #
==============================================
embedding_1 (Embedd (None, 500, 6)     30000  
______________________________________________
flatten_1 (Flatten) (None, 3000)       0      
______________________________________________
dense_9 (Dense)     (None, 1)          3001   
==============================================
Total params: 33,001
Trainable params: 33,001
Non-trainable params: 0
______________________________________________
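The parameter counts check out: the embedding layer has 5000 × 6 = 30,000 weights (one 6-dimensional vector per word index), the flatten layer just reshapes the 500 × 6 output into a 3,000-vector (no parameters), and the dense layer has 3,000 weights plus 1 bias = 3,001.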
emb.history <- emb.mod %>% fit(
  x_train.s, y_train.s, 
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)
Train on 20000 samples, validate on 5000 samples
Epoch 1/6
20000/20000 [==============================] - 2s 114us/sample - loss: 0.5859 - acc: 0.6935 - val_loss: 0.3996 - val_acc: 0.8404
Epoch 2/6
20000/20000 [==============================] - 2s 95us/sample - loss: 0.3139 - acc: 0.8754 - val_loss: 0.2994 - val_acc: 0.8768
Epoch 3/6
20000/20000 [==============================] - 2s 96us/sample - loss: 0.2467 - acc: 0.9041 - val_loss: 0.2816 - val_acc: 0.8868
Epoch 4/6
20000/20000 [==============================] - 2s 96us/sample - loss: 0.2166 - acc: 0.9161 - val_loss: 0.2762 - val_acc: 0.8908
Epoch 5/6
20000/20000 [==============================] - 2s 95us/sample - loss: 0.1969 - acc: 0.9245 - val_loss: 0.2784 - val_acc: 0.8898
Epoch 6/6
 1184/20000 [>.............................] - ETA: 1s - loss: 0.1983 - acc: 0.9231
 1728/20000 [=>............................] - ETA: 1s - loss: 0.1830 - acc: 0.9317
 2272/20000 [==>...........................] - ETA: 1s - loss: 0.1808 - acc: 0.9340
 2784/20000 [===>..........................] - ETA: 1s - loss: 0.1789 - acc: 0.9357
 3392/20000 [====>.........................] - ETA: 1s - loss: 0.1872 - acc: 0.9337
 4032/20000 [=====>........................] - ETA: 1s - loss: 0.1876 - acc: 0.9328
 4672/20000 [======>.......................] - ETA: 1s - loss: 0.1877 - acc: 0.9317
 5344/20000 [=======>......................] - ETA: 1s - loss: 0.1860 - acc: 0.9311
 5984/20000 [=======>......................] - ETA: 1s - loss: 0.1825 - acc: 0.9325
 6656/20000 [========>.....................] - ETA: 1s - loss: 0.1816 - acc: 0.9325
 7296/20000 [=========>....................] - ETA: 1s - loss: 0.1826 - acc: 0.9313
 7936/20000 [==========>...................] - ETA: 1s - loss: 0.1811 - acc: 0.9330
 8544/20000 [===========>..................] - ETA: 0s - loss: 0.1786 - acc: 0.9346
 8672/20000 [============>.................] - ETA: 1s - loss: 0.1797 - acc: 0.9344
 9184/20000 [============>.................] - ETA: 0s - loss: 0.1828 - acc: 0.9329
 9760/20000 [=============>................] - ETA: 0s - loss: 0.1828 - acc: 0.9333
10368/20000 [==============>...............] - ETA: 0s - loss: 0.1820 - acc: 0.9345
10944/20000 [===============>..............] - ETA: 0s - loss: 0.1821 - acc: 0.9343
11584/20000 [================>.............] - ETA: 0s - loss: 0.1806 - acc: 0.9351
12160/20000 [=================>............] - ETA: 0s - loss: 0.1808 - acc: 0.9346
12800/20000 [==================>...........] - ETA: 0s - loss: 0.1801 - acc: 0.9345
13472/20000 [===================>..........] - ETA: 0s - loss: 0.1800 - acc: 0.9343
14112/20000 [====================>.........] - ETA: 0s - loss: 0.1794 - acc: 0.9349
14752/20000 [=====================>........] - ETA: 0s - loss: 0.1796 - acc: 0.9350
15328/20000 [=====================>........] - ETA: 0s - loss: 0.1792 - acc: 0.9346
15936/20000 [======================>.......] - ETA: 0s - loss: 0.1805 - acc: 0.9337
16544/20000 [=======================>......] - ETA: 0s - loss: 0.1800 - acc: 0.9339
17184/20000 [========================>.....] - ETA: 0s - loss: 0.1805 - acc: 0.9334
17792/20000 [=========================>....] - ETA: 0s - loss: 0.1805 - acc: 0.9336
18464/20000 [==========================>...] - ETA: 0s - loss: 0.1810 - acc: 0.9335
19136/20000 [===========================>..] - ETA: 0s - loss: 0.1810 - acc: 0.9335
19776/20000 [============================>.] - ETA: 0s - loss: 0.1821 - acc: 0.9329
20000/20000 [==============================] - 2s 97us/sample - loss: 0.1821 - acc: 0.9329 - val_loss: 0.2811 - val_acc: 0.8918
emb.results <- emb.mod %>% evaluate(x_test.s, y_test.s)

25000/25000 [==============================] - 1s 22us/sample - loss: 0.2796 - acc: 0.8868

That got us up to 88.68%, the best yet, so the embedding layer bought us something.

I’ve never seen anybody else do this, but you can look inside the embedding weights themselves to get a sense of what’s going on. For space’s sake, we’ll just plot the embeddings of the first 500 words.

embeds.6 <- get_weights(emb.mod)[[1]]
rownames(embeds.6) <- c("??","??","??","??",top_words[1:4996])
colnames(embeds.6) <- paste("Dim",1:6)
pairs(embeds.6[1:500,], col=rgb(0,0,0,.5),cex=.5, main="Six-dimensional Embeddings Trained in Classification Model")

Ahhhhh… our 6-dimensional embeddings are pretty close to one-dimensional. Why? Because they are trained for the essentially one-dimensional classification of positive-negative sentiment.
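One quick numeric check of that claim, a sketch using base R’s cor on the embeds.6 matrix we just extracted: if the six columns are nearly perfectly correlated (in absolute value), the learned embeddings are effectively one-dimensional.

# Correlations among the six learned embedding dimensions;
# entries near +/- 1 mean the dimensions are nearly collinear
round(cor(embeds.6), 2)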

We can tease this out with something like PCA:

pc.emb <- princomp(embeds.6)
summary(pc.emb)
Importance of components:
                          Comp.1     Comp.2     Comp.3     Comp.4     Comp.5     Comp.6
Standard deviation     0.3566176 0.05058622 0.04372121 0.04287706 0.03973241 0.03804071
Proportion of Variance 0.9316192 0.01874552 0.01400288 0.01346738 0.01156439 0.01060059
Cumulative Proportion  0.9316192 0.95036475 0.96436763 0.97783501 0.98939941 1.00000000
# Most positive words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=T)[1:20]
  excellent           7     perfect 
   1.660139    1.647258    1.629065 
        gem           8    favorite 
   1.490264    1.387230    1.360757 
 refreshing wonderfully      superb 
   1.355163    1.303961    1.256989 
   captures     amazing   wonderful 
   1.209862    1.177840    1.133390 
       rare      highly    superbly 
   1.133343    1.121740    1.111449 
  marvelous        noir   perfectly 
   1.094616    1.089450    1.087435 
     finest   favorites 
   1.064197    1.052959 
# Most negative words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=F)[1:20]
         worst          waste          awful 
     -2.426991      -2.287825      -2.110015 
        poorly      pointless           dull 
     -1.897196      -1.736179      -1.693020 
     laughable          avoid           mess 
     -1.663206      -1.601402      -1.572703 
disappointment  disappointing          worse 
     -1.530430      -1.523112      -1.520477 
      terrible        unfunny          badly 
     -1.484618      -1.475630      -1.440923 
        boring      redeeming          mst3k 
     -1.433651      -1.406507      -1.398853 
         lacks       dreadful 
     -1.370802      -1.364966 
# Most positive words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=T)[1:20]
      didn't      academy       coming 
   0.1993631    0.1894763    0.1867834 
       piece         must protagonists 
   0.1674755    0.1651370    0.1625642 
     towards      delight           so 
   0.1594424    0.1576414    0.1528698 
      reason      through         self 
   0.1481342    0.1480310    0.1461657 
     opening     possibly           to 
   0.1456153    0.1450613    0.1421859 
     victims        known    involving 
   0.1389948    0.1389144    0.1368349 
       point        don't 
   0.1362795    0.1351084 
# Most negative words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=F)[1:20]
   bizarre      takes      rings     become 
-0.1997089 -0.1818287 -0.1748362 -0.1709388 
     frank    descent      steps      those 
-0.1627424 -0.1616023 -0.1497373 -0.1496597 
  students       tons      count     mother 
-0.1494445 -0.1474266 -0.1466301 -0.1449812 
      stop     animal      group    segment 
-0.1447104 -0.1437884 -0.1400553 -0.1348926 
       its     toilet   laughing     father 
-0.1347013 -0.1344255 -0.1320045 -0.1316030 
# Plot weights
plot(colSums(x_train),-pc.emb$scores[,1], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),-pc.emb$scores[,1], rownames(embeds.6),pos=4,cex=1*abs(pc.emb$scores[,1]), col=rgb(0,0,0,.4*abs(pc.emb$scores[,1])))

This first embedding dimension is very similar to the weights we learned in the shallow logistic network:

cor(logmod.weights[3:4999],-pc.emb$scores[4:5000,1])
[1] 0.8978127

The second dimension captures something … perhaps reflecting more nuanced criticism … that acted as a subtle bit of extra information for our classifier and provided some slight improvement in performance. In effect, by letting the model learn six numbers per word instead of one, we gave it the opportunity to learn some more subtle characteristics of review sentiment.

# Plot weights
plot(colSums(x_train),pc.emb$scores[,2], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Second Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),pc.emb$scores[,2], rownames(embeds.6),pos=4,cex=10*abs(pc.emb$scores[,2]), col=rgb(0,0,0,3*abs(pc.emb$scores[,2])))

Using pretrained embeddings in classifier

The commented-out code assumes you have downloaded the 822M zip file glove.6B.zip from https://nlp.stanford.edu/projects/glove/ and unzipped the folder (more than 2G). It then reads in the smallest file, glove.6B.50d.txt (171M), which defines 50-dimensional embeddings for 400,000 tokens, and creates embedding_matrix, a 5000 x 50 matrix covering the 5000 top words in the imdb data (a 1.9M object).

I have included this object in an rds file, and the readRDS command reads it in directly.

# glove_dir <- "../Embeddings/glove.6B"
# lines <- readLines(file.path(glove_dir, "glove.6B.50d.txt"))
# embeddings_index <- new.env(hash = TRUE, parent = emptyenv())
# for (i in 1:length(lines)) {
#   line <- lines[[i]]
#   values <- strsplit(line, " ")[[1]]
#   word <- values[[1]]
#   embeddings_index[[word]] <- as.double(values[-1])
# }
# cat("Found", length(embeddings_index), "word vectors.\n")
max_words=5000
embedding_dim <- 50
# embedding_matrix <- array(0, c(max_words, embedding_dim))
# for (word in names(word_index)) {
#   index <- word_index[[word]]
#   if (index < max_words) {
#     embedding_vector <- embeddings_index[[word]]
#     if (!is.null(embedding_vector))
#       embedding_matrix[index+1,] <- embedding_vector
#     }
# }
# saveRDS(embedding_matrix, file="glove_imdb_example.rds")
embedding_matrix <- readRDS("glove_imdb_example.rds")
glove.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_words, output_dim = embedding_dim,
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")
summary(glove.mod)
Model: "sequential_7"
______________________________________________
Layer (type)        Output Shape       Param #
==============================================
embedding_2 (Embedd (None, 500, 50)    250000 
______________________________________________
flatten_2 (Flatten) (None, 25000)      0      
______________________________________________
dense_10 (Dense)    (None, 1)          25001  
==============================================
Total params: 275,001
Trainable params: 275,001
Non-trainable params: 0
______________________________________________
get_layer(glove.mod, index = 1) %>%
  set_weights(list(embedding_matrix)) %>%
  freeze_weights()
glove.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)
glove.history <- glove.mod %>% fit(
  x_train.s, y_train.s, 
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)
Train on 20000 samples, validate on 5000 samples
Epoch 1/6
20000/20000 [==============================] - 3s 151us/sample - loss: 0.9058 - acc: 0.5295 - val_loss: 0.9894 - val_acc: 0.5142
Epoch 2/6
20000/20000 [==============================] - 3s 133us/sample - loss: 0.7298 - acc: 0.6247 - val_loss: 1.0100 - val_acc: 0.5286
Epoch 3/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.6382 - acc: 0.6751 - val_loss: 0.9295 - val_acc: 0.5582
Epoch 4/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.5810 - acc: 0.7085 - val_loss: 0.9726 - val_acc: 0.5626
Epoch 5/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.5351 - acc: 0.7340 - val_loss: 1.1786 - val_acc: 0.5578
Epoch 6/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.4974 - acc: 0.7560 - val_loss: 1.1034 - val_acc: 0.5508
plot(glove.history)

glove.results <- glove.mod %>% evaluate(x_test.s, y_test.s)

25000/25000 [==============================] - 1s 60us/sample - loss: 1.0862 - acc: 0.5531

Accuracy of 55%. Pretty bad. These GloVe embeddings are trained on Wikipedia and the Gigaword corpus. The 50 most important dimensions of word meaning in that very general set of corpora, which relate words to a narrow window of surrounding words, are nowhere near as useful here as the six (or even one) most important dimension(s) relating words to sentiment within our training data.
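
Before blaming the embeddings themselves, a quick sketch of a sanity check on vocabulary coverage: because embedding_matrix was initialized to all zeros and rows were only filled in for words found in GloVe, counting the rows that are still all zero tells us how many of our 5000 top words got no pretrained vector at all.

# Count words among our top 5000 with no matching GloVe vector
# (row 1, the "?" padding slot, is never filled and is always zero)
sum(rowSums(abs(embedding_matrix)) == 0)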

To get an idea of what’s going on …

library(text2vec)
Attaching package: ‘text2vec’

The following objects are masked from ‘package:keras’:

    fit, normalize
find_similar_words <- function(word, embedding_matrix, n = 5) {
  # Cosine similarity between this word's vector and every row of the matrix
  similarities <- embedding_matrix[word, , drop = FALSE] %>%
    sim2(embedding_matrix, y = ., method = "cosine")
  
  # Return the n nearest neighbors (the word itself comes first, at similarity 1)
  similarities[,1] %>% sort(decreasing = TRUE) %>% head(n)
}

The concept embodied by a word in this review context, e.g., “waste”:

find_similar_words("waste", embeds.6, n=10)

may not be at all what’s captured about the word in the GloVe embeddings:

rownames(embedding_matrix) <- c("?",top_words[1:4999])
find_similar_words("waste", embedding_matrix, n=10)
    waste     water   garbage     clean 
1.0000000 0.7724197 0.7605300 0.7390320 
      gas   amounts     trash      mine 
0.7258156 0.6982243 0.6903367 0.6879122 
  natural       oil 
0.6776923 0.6602530 

Or “gem”:

find_similar_words("gem", embeds.6, n=10)
          gem      favorite         anime 
    1.0000000     0.9992570     0.9991338 
    wonderful unforgettable        faults 
    0.9990917     0.9989748     0.9985772 
   incredible   suspenseful     realistic 
    0.9985302     0.9985083     0.9983043 
       beauty 
    0.9982520 
find_similar_words("gem", embedding_matrix, n=10)
      gem   diamond  precious  treasure 
1.0000000 0.7375269 0.6808224 0.6128838 
priceless      gold  valuable      pulp 
0.5930778 0.5867127 0.5440980 0.5299630 
   golden   magical 
0.5298125 0.5195156 

Or “wooden”:

find_similar_words("wooden", embeds.6, n=10)
      wooden unconvincing    obnoxious 
   1.0000000    0.9991876    0.9991796 
   ludicrous      destroy     supposed 
   0.9991443    0.9989238    0.9988930 
   redeeming      unfunny          uwe 
   0.9988248    0.9987805    0.9987153 
        mess 
   0.9987089 
find_similar_words("wooden", embedding_matrix, n=10)
   wooden     walls     doors     glass 
1.0000000 0.7539166 0.7357615 0.7291114 
cardboard     floor     stone      wood 
0.7199160 0.7160277 0.7153557 0.7100430 
 attached     empty 
0.7053251 0.7023291 

Or “moving”:

find_similar_words("moving", embeds.6, n=10)
      moving    essential   unexpected 
   1.0000000    0.9992714    0.9985157 
      berlin      balance      enjoyed 
   0.9983773    0.9982949    0.9979986 
conventional     stunning         love 
   0.9970112    0.9969180    0.9968535 
  importance 
   0.9964524 
find_similar_words("moving", embedding_matrix, n=10)
   moving   turning   through      turn 
1.0000000 0.8934571 0.8823608 0.8565476 
     into   quickly      move    coming 
0.8531339 0.8527476 0.8474165 0.8441274 
   beyond       way 
0.8403356 0.8368897 

In other contexts, pretrained embeddings can be very useful. But not, it appears, in this one, at least without further tweaking. A reasonable approach might use the pretrained embeddings as a starting point and allow them to adjust based on the data from your specific context, as sketched below.
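
Here is a minimal sketch of that approach, reusing the pieces we already have: set the embedding layer to the GloVe weights but leave it trainable (no freeze_weights call), so the vectors can drift toward review-specific meanings as the classifier trains. The names tuned.mod and tuned.history are just illustrative; I haven’t run this variant here.

# A sketch: start from the GloVe vectors but let them be fine-tuned
tuned.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_words, output_dim = embedding_dim,
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")

# Initialize the embedding layer with the pretrained matrix,
# but do NOT freeze it this time
get_layer(tuned.mod, index = 1) %>%
  set_weights(list(embedding_matrix))

tuned.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)

tuned.history <- tuned.mod %>% fit(
  x_train.s, y_train.s,
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)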

OK, that’s enough for now.

---
title: "An Introduction to Keras and Tensorflow in R"
subtitle: Text as Data, PLSC 597, Penn State
author: Burt L. Monroe
output:
  html_notebook:
    code_folding: show
    highlight: tango
    theme: united
    toc: yes
    df_print: paged
---
This notebook is based on tutorials, including code, from Rstudio (https://tensorflow.rstudio.com/) as well as Chollet and Allaire's [Deep Learning with R](https://www.manning.com/books/deep-learning-with-r) (2018), currently accessible free to Penn State students through the Penn State library: https://learning.oreilly.com/library/view/deep-learning-with/9781617295546/. Notebooks for the originals of *Deep Learning with R* are available here: https://github.com/jjallaire/deep-learning-with-r-notebooks.

We will be implementing neural models in R through the **keras** package, which itself, by default, uses the **tensorflow** "backend." You can access TensorFlow directly -- which provides more flexibility but requires more of the user -- and you can also use different backends, specifically **CNTK** and **Theano** through keras. (The R library keras is an *interface* to Keras itself, which offers an *API* to a *backend* like TensorFlow.) Keras is generally described as "high-level" or "model-level", meaning the researcher can build models using Keras building blocks -- which is probably all most of you would ever want to do.

**Warning 1**: Keras (https://keras.io) is written in Python, so (a) installing keras and tensorflow creates a Python environment on your machine (in my case, it detects Anaconda and creates a conda environment called `r-tensorflow`), and (b) much of the keras syntax is Pythonic (like 0-based indexing in some contexts), as are the often untraceable error messages.

```{r}
# devtools::install_github("rstudio/keras")
```

I used the `install_keras` function to install a default CPU-based keras and tensorflow. There are more details for alternative installations here: https://tensorflow.rstudio.com/keras/ 

**Warning 2**: There is currently an error in TensorFlow that manifested in the code that follows. The fix in my case was installing the "nightly" build which also back installs Python 3.6 instead of 3.7 in a new `r-tensorflow` environment. By the time you read this, the error probably won't exist, in which case `install_keras()` should be sufficient.
```{r}
library(keras)
# install_keras(tensorflow="nightly")
```


## Building a Deep Classifier in Keras

We'll work with the IMDB review dataset that comes with keras

```{r}
imdb <- dataset_imdb(num_words = 5000)
c(c(train_data, train_labels), c(test_data, test_labels)) %<-% imdb
### This is equivalent to:
# imdb <- dataset_imdb(num_words = 5000)
# train_data <- imdb$train$x
# train_labels <- imdb$train$y
# test_data <- imdb$test$x
# test_labels <- imdb$test$y
```

Look at the training data (features) and labels (positive/negative).

```{r}
str(train_data[[1]])

train_labels[[1]]

max(sapply(train_data, max))
```

Probably a good idea to figure out how to get back to text. Decode review 1:
```{r}
word_index <- dataset_imdb_word_index()
reverse_word_index <- names(word_index)
names(reverse_word_index) <- word_index
decoded_review <- sapply(train_data[[1]], 
                         function(index) {
                           word <- if (index >= 3)
reverse_word_index[[as.character(index - 3)]]
                           if (!is.null(word)) word else "?"
                           })
decoded_review
```

Create "one-hot" vectorization of input features. (Binary indicators of presence or absence of feature - in this case word / token - in "sequence" - in this case review.)
```{r}
vectorize_sequences <- function(sequences, dimension = 5000) {
  results <- matrix(0, nrow = length(sequences), ncol = dimension)
  for (i in 1:length(sequences))
    results[i, sequences[[i]]] <- 1 
  results
}

x_train <- vectorize_sequences(train_data)
x_test <- vectorize_sequences(test_data)

# Also change labels from integer to numeric
y_train <- as.numeric(train_labels)
y_test <- as.numeric(test_labels)
```

We need a model architecture. In many cases, this can be built a simple layer building blocks from Keras. Here's a three layer network for our classification problem:

```{r}
model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
```

We need to "compile" our model by adding information about our loss function, what optimizer we wish to use, and what metrics we want to keep track of.

```{r}
model %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)
```

Create a held-out set of your training data for validation.
```{r}
val_indices <- 1:10000 # not great practice if these are ordered

x_val <- x_train[val_indices,]
partial_x_train <- x_train[-val_indices,]
y_val <- y_train[val_indices]
partial_y_train <- y_train[-val_indices]
```

Fit the model, store the fit history.
```{r}
history <- model %>% fit(
  partial_x_train,
  partial_y_train,
  epochs = 20,
  batch_size = 512,
  validation_data = list(x_val, y_val)
)
```

```{r}
str(history)
```

```{r}
plot(history)
```

Overfitting. Fit at smaller number of epochs.
```{r}
model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")

model %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)

model %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
results <- model %>% evaluate(x_test, y_test)
```

In the test data, we get accuracy of 87.96% with this model.

So, what happened there? What is the model learning? It learned a function that converts 5000 inputs into 16 hidden / latent / intermediate numbers, then converts those 16 into a different 16 intermediate numbers, and then those into 1 number at the output. Those intermediate functions are nonlinear, otherwise there wouldn't be any gain from stacking them together. But we can get an approximate idea of how these inputs map to the single output by treating each layer as linear and multiplying the weights through: $W^{(5000\times 1)}_{io} \approx W_{i1}^{(5000\times 16)} \times W_{12}^{(16\times 16)} \times W_{2o}^{(16\times 1)}$. Those aggregate weights can give us an approximate idea of the main effect of each input. (This will not work, generally speaking, in more complex models or contexts.)

```{r}
model.weights.approx <- (get_weights(model)[[1]] %*% get_weights(model)[[3]] %*% get_weights(model)[[5]])[,1]
top_words <- reverse_word_index[as.character(1:5000)]
names(model.weights.approx) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])

sort(model.weights.approx, dec=T)[1:20]

sort(model.weights.approx, dec=F)[1:20]
```

"7" is interesting. It comes from reviews like #168, which ends like this:

```{r}
sapply(train_data[[168]],function(index) {
  word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
  if (!is.null(word)) word 
  else "?"})[201:215]
```

That means this whole thing is cheating, as far as I'm concerned. Some of the reviews end with the text summarizing the rating with a number out of 10! And then use that to "predict" whether the review is positive or negative (which is probably based on that 10 point scale rating in the first place).

And we can look at these similarly to how we looked at the coefficients from the classifiers in our earlier notebook. I'll also highlight those numbers.

```{r,fig.width=6,fig.height=7}
# Plot weights
plot(colSums(x_train),model.weights.approx, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Weights Learned in Deep Model, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
text(colSums(x_train),model.weights.approx, names(model.weights.approx),pos=4,cex=1.5*abs(model.weights.approx), col=rgb(0,0,0,.5*abs(model.weights.approx)))
text(colSums(x_train[,c("1","2","3","4","5","6","7","8","9")]),model.weights.approx[c("1","2","3","4","5","6","7","8","9")], names(model.weights.approx[c("1","2","3","4","5","6","7","8","9")]),pos=4,cex=1.5*model.weights.approx[c("1","2","3","4","5","6","7","8","9")], col=rgb(1,0,0,1))
```


### Compare to shallow logistic classifier

It's worth pointing out that "deep" didn't buy us much. If we just estimate with a single sigmoid (logistic) layer, we get nearly identical results, with 87.72% accuracy.

```{r}
logistic.mod <- keras_model_sequential() %>%
  layer_dense(units = 1, activation = "sigmoid", input_shape = c(5000))

logistic.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)

logistic.mod %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
results <- logistic.mod %>% evaluate(x_test, y_test)
```

This we can interpret directly. It's basically a ridge regression like we saw in the earlier classification notebook.


```{r, fig.width=7, fig.height=6}
logmod.weights <- get_weights(logistic.mod)[[1]][,1]
top_words <- reverse_word_index[as.character(1:5000)]
names(logmod.weights) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])

#Most positive words
sort(logmod.weights,dec=T)[1:20]

#Most negative words
sort(logmod.weights,dec=F)[1:20]

# Plot weights
plot(colSums(x_train),logmod.weights, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Weights Learned in Shallow Logistic Model, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
text(colSums(x_train),logmod.weights, names(logmod.weights),pos=4,cex=10*abs(logmod.weights), col=rgb(0,0,0,3*abs(logmod.weights)))
```


### Compare to Naive Bayes

Just to have another comparison, let's check with Naive Bayes.

```{r}
library(quanteda)

colnames(x_train) <- colnames(x_test) <-  c("<PAD>","<START>","<UNK>",top_words[1:4997])
dfm.train <- as.dfm(x_train)
dfm.test <- as.dfm(x_test)
```

```{r}
nb.mod <- textmodel_nb(dfm.train, y_train, distribution = "Bernoulli")
summary(nb.mod)
```

```{r}
y_test.pred.nb <- predict(nb.mod, newdata=dfm.test)
nb.class.table <- table(y_test,y_test.pred.nb)
sum(diag(nb.class.table/sum(nb.class.table)))
```

As we might expect, that's not as good.

As before, Naive Bayes overfits.

```{r, fig.width=7, fig.height=6}

#Most positive words
sort(nb.mod$PcGw[2,],dec=T)[1:20]

#Most negative words
sort(nb.mod$PcGw[2,],dec=F)[1:20]

# Plot weights
plot(colSums(x_train),nb.mod$PcGw[2,], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Posterior Probabilities, Naive Bayes Classifier, IMDB", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
text(colSums(x_train),nb.mod$PcGw[2,], names(logmod.weights),pos=4,cex=5*abs(.5-nb.mod$PcGw[2,]), col=rgb(0,0,0,1.5*abs(.5-nb.mod$PcGw[2,])))
```

## Word Embeddings

### Simple training of task-specific embeddings

Keras has its own "embedding" that you can use as a layer (the first layer) in a model.

```{r}
max_features <- 5000
maxlen <- 500
imdb.s <- dataset_imdb(num_words = max_features)
c(c(x_train.s, y_train.s), c(x_test.s, y_test.s)) %<-% imdb.s
x_train.s <- pad_sequences(x_train.s, maxlen = maxlen)
x_test.s <- pad_sequences(x_test.s, maxlen = maxlen)
```

```{r}
emb.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_features, output_dim = 6,  
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")
emb.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)
summary(emb.mod)
emb.history <- emb.mod %>% fit(
  x_train.s, y_train.s, 
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)
```

```{r}
emb.results <- emb.mod %>% evaluate(x_test.s, y_test.s)
```

That got us up to 88.68%, the best yet, so the embedding layer bought us something.

I've never seen anybody else do this, but you can look inside the embedding weights themselves to get a sense of what's going on. For space sake, we'll just plot the embeddings of the first 500 words. 

```{r, fig.width=7, fig.height=6}
embeds.6 <- get_weights(emb.mod)[[1]]
rownames(embeds.6) <- c("??","??","??","??",top_words[1:4996])
colnames(embeds.6) <- paste("Dim",1:6)

pairs(embeds.6[1:500,], col=rgb(0,0,0,.5),cex=.5, main="Six-dimensional Embeddings Trained in Classification Model")
```

Ahhhhh... our 6-dimensional embeddings are pretty close to one-dimensional. Why? Because they are trained for the essentially one-dimensional classification of positive-negative sentiment.

We can tease this out with something like PCA:

```{r, fig.width=7, fig.height=6}
pc.emb <- princomp(embeds.6)

summary(pc.emb)

# Most positive words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=T)[1:20]

# Most negative words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=F)[1:20]

# Most positive words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=T)[1:20]

# Most negative words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=F)[1:20]

```

```{r, fig.width=7, fig.height=6}
# Plot weights
plot(colSums(x_train),-pc.emb$scores[,1], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
text(colSums(x_train),-pc.emb$scores[,1], rownames(embeds.6),pos=4,cex=1*abs(pc.emb$scores[,1]), col=rgb(0,0,0,.4*abs(pc.emb$scores[,1])))
```

This first embedding dimension is very similar to the weights we learned in the shallow logistic network:

```{r}
cor(logmod.weights[3:4999],-pc.emb$scores[4:5000,1])
```

The second dimension captures *something* -- perhaps related to more nuanced criticism -- that acted as a subtle bit of extra information for our classifier and provided a slight improvement in performance. In effect, by letting the model learn six numbers per word instead of one, we gave it the opportunity to learn some more subtle characteristics of review sentiment.

```{r, fig.width=7, fig.height=6}
# Plot weights
plot(colSums(x_train),pc.emb$scores[,2], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Second Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
text(colSums(x_train),pc.emb$scores[,2], rownames(embeds.6),pos=4,cex=10*abs(pc.emb$scores[,2]), col=rgb(0,0,0,3*abs(pc.emb$scores[,2])))
```

### Using pretrained embeddings in a classifier

The commented-out code assumes you have downloaded the 822M zip file `glove.6B.zip` from https://nlp.stanford.edu/projects/glove/ and unzipped the folder (more than 2G). It reads the smallest file (171M), which defines 50-dimensional embeddings for 400,000 tokens, and then builds `embedding_matrix`, a 5000 x 50 matrix (about 1.9M) for the 5000 top words in the IMDB data.

I have saved that object to an rds file, and the `readRDS` call below reads it in directly.

```{r}
# glove_dir <- "../Embeddings/glove.6B"
# lines <- readLines(file.path(glove_dir, "glove.6B.50d.txt"))

# embeddings_index <- new.env(hash = TRUE, parent = emptyenv())
# for (i in 1:length(lines)) {
#   line <- lines[[i]]
#   values <- strsplit(line, " ")[[1]]
#   word <- values[[1]]
#   embeddings_index[[word]] <- as.double(values[-1])
# }
# cat("Found", length(embeddings_index), "word vectors.\n")

max_words <- 5000

embedding_dim <- 50
# embedding_matrix <- array(0, c(max_words, embedding_dim))
# for (word in names(word_index)) {
#   index <- word_index[[word]]
#   if (index < max_words) {
#     embedding_vector <- embeddings_index[[word]]
#     if (!is.null(embedding_vector))
#       embedding_matrix[index+1,] <- embedding_vector
#     }
# }

# saveRDS(embedding_matrix, file="glove_imdb_example.rds")
```

```{r}
embedding_matrix <- readRDS("glove_imdb_example.rds")
```

```{r}
glove.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_words, output_dim = embedding_dim,
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")
summary(glove.mod)

get_layer(glove.mod, index = 1) %>%
  set_weights(list(embedding_matrix)) %>%
  freeze_weights()

glove.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)

glove.history <- glove.mod %>% fit(
  x_train.s, y_train.s, 
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)

plot(glove.history)
```

```{r}
glove.results <- glove.mod %>% evaluate(x_test.s, y_test.s)
```

Accuracy of 55%. Pretty bad. These GloVe embeddings are trained on Wikipedia and the Gigaword corpus. The 50 most important dimensions of those very general corpora -- dimensions that relate words to a narrow window of surrounding words -- are nowhere near as useful here as the six (or even one) most important dimension(s) relating words to sentiment within our training data.

To get an idea of what's going on, here is a helper function that finds a word's nearest neighbors (by cosine similarity) in a given embedding matrix:

```{r}
library(text2vec)

# Given a word, compute the cosine similarity between its embedding (one row of
# the embedding matrix) and every other row, and return the n closest words
# (the word itself always comes back first, with similarity 1).
find_similar_words <- function(word, embedding_matrix, n = 5) {
  similarities <- embedding_matrix[word, , drop = FALSE] %>%
    sim2(embedding_matrix, y = ., method = "cosine")
  
  similarities[, 1] %>% sort(decreasing = TRUE) %>% head(n)
}
```

The concept embodied by a word in this review context, e.g., "waste":
```{r}
find_similar_words("waste", embeds.6, n=10)
```

may not be at all what's captured about the word in the GloVe embeddings:

```{r}
rownames(embedding_matrix) <- c("?",top_words[1:4999])
find_similar_words("waste", embedding_matrix, n=10)
```

Or "gem":
```{r}
find_similar_words("gem", embeds.6, n=10)

find_similar_words("gem", embedding_matrix, n=10)
```

Or "wooden":
```{r}
find_similar_words("wooden", embeds.6, n=10)
find_similar_words("wooden", embedding_matrix, n=10)
```

Or "moving":
```{r}
find_similar_words("moving", embeds.6, n=10)
find_similar_words("moving", embedding_matrix, n=10)
```


In other contexts, pretrained embeddings can be very useful. But not, it appears, in this one -- at least not without further tweaking. A reasonable approach might use the pretrained embeddings as a starting point and then allow them to move ("fine-tune" them) based on the data from your specific context.
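
Here is a minimal sketch of that idea (not run; `finetune.mod` is just an illustrative name). It reuses `embedding_matrix`, `max_words`, `embedding_dim`, `maxlen`, and the padded data from above; the only change from the frozen GloVe model is that the embedding layer is initialized with the GloVe weights but left trainable, so the vectors can adapt to the sentiment task.

```{r, eval=FALSE}
# Sketch: same architecture as glove.mod, but the GloVe-initialized embedding
# layer is NOT frozen, so its weights are fine-tuned during training.
finetune.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_words, output_dim = embedding_dim,
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")

get_layer(finetune.mod, index = 1) %>%
  set_weights(list(embedding_matrix))   # start from GloVe; no freeze_weights()

finetune.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)

finetune.mod %>% fit(
  x_train.s, y_train.s,
  epochs = 6,
  batch_size = 32,
  validation_split = 0.2
)

finetune.mod %>% evaluate(x_test.s, y_test.s)
```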

OK, that's enough for now.



