This notebook is based on tutorials, including code, from RStudio (https://tensorflow.rstudio.com/) as well as Chollet and Allaire's Deep Learning with R (2018), currently accessible free to Penn State students through the Penn State library: https://learning.oreilly.com/library/view/deep-learning-with/9781617295546/. Notebooks for the originals of Deep Learning with R are available here: https://github.com/jjallaire/deep-learning-with-r-notebooks.
We will be implementing neural models in R through the keras package, which by default uses the tensorflow "backend." You can access TensorFlow directly, which provides more flexibility but requires more of the user, and you can also use different backends through keras, specifically CNTK and Theano. (The R library keras is an interface to Keras itself, which offers an API to a backend like TensorFlow.) Keras is generally described as "high-level" or "model-level": the researcher builds models from Keras building blocks, which is probably all most of you would ever want to do.
Warning 1: Keras (https://keras.io) is written in Python, so (a) installing keras and tensorflow creates a Python environment on your machine (in my case, it detects Anaconda and creates a conda environment called r-tensorflow), and (b) much of the keras syntax is Pythonic (like 0-based indexing in some contexts), as are the often untraceable error messages.
# devtools::install_github("rstudio/keras")
I used the install_keras() function to install a default CPU-based keras and tensorflow. There are more details on alternative installations here: https://tensorflow.rstudio.com/keras/
Warning 2: There is currently an error in TensorFlow that manifested in the code that follows. The fix in my case was installing the "nightly" build, which also installs Python 3.6 instead of 3.7 in a new r-tensorflow environment. By the time you read this, the error probably won't exist, in which case install_keras() should be sufficient.
library(keras)
# install_keras(tensorflow="nightly")
We'll work with the IMDB review dataset that comes with keras, keeping only the 5000 most frequent words.
imdb <- dataset_imdb(num_words = 5000)
c(c(train_data, train_labels), c(test_data, test_labels)) %<-% imdb
### This is equivalent to:
# imdb <- dataset_imdb(num_words = 5000)
# train_data <- imdb$train$x
# train_labels <- imdb$train$y
# test_data <- imdb$test$x
# test_labels <- imdb$test$y
Look at the training data (features) and labels (positive/negative).
str(train_data[[1]])
int [1:218] 1 14 22 16 43 530 973 1622 1385 65 ...
train_labels[[1]]
[1] 1
max(sapply(train_data, max))
[1] 4999
Probably a good idea to figure out how to get back to text. Decode review 1:
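The lookup code isn't shown here; a minimal sketch, following the same pattern used for review #168 later in this notebook (the offset of 3 leaves room for the reserved "padding," "start of sequence," and "unknown" tokens). The reverse_word_index built here is also used further below.
word_index <- dataset_imdb_word_index()    # named list mapping word -> integer rank
reverse_word_index <- names(word_index)    # invert: integer rank -> word
names(reverse_word_index) <- word_index
decoded_review <- sapply(train_data[[1]], function(index) {
  # indices are offset by 3; 0, 1, 2 are reserved tokens
  word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
  if (!is.null(word)) word else "?"
})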
cat(decoded_review)
? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly ? was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little ? that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big ? for the whole film but these children are amazing and should be ? for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was ? with us all
Create a "one-hot" vectorization of the input features: binary indicators of the presence or absence of each feature (here, a word/token) in each "sequence" (here, a review).
vectorize_sequences <- function(sequences, dimension = 5000) {
  # one row per review, one column per word index
  results <- matrix(0, nrow = length(sequences), ncol = dimension)
  for (i in seq_along(sequences))
    results[i, sequences[[i]]] <- 1  # flag the words present in review i
  results
}
x_train <- vectorize_sequences(train_data)
x_test <- vectorize_sequences(test_data)
# Also change labels from integer to numeric
y_train <- as.numeric(train_labels)
y_test <- as.numeric(test_labels)
We need a model architecture. In many cases, this can be built from simple layer building blocks in Keras. Here's a three-layer network for our classification problem:
model <- keras_model_sequential() %>%
layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
layer_dense(units = 16, activation = "relu") %>%
layer_dense(units = 1, activation = "sigmoid")
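As a quick sanity check (not shown in the original), summary(model) prints the layer-by-layer parameter counts; for a dense layer that's inputs × units, plus one bias per unit:
summary(model)
# layer 1: 5000*16 + 16 = 80016 parameters
# layer 2:   16*16 + 16 =   272 parameters
# output:    16*1  +  1 =    17 parameters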
We need to “compile” our model by adding information about our loss function, what optimizer we wish to use, and what metrics we want to keep track of.
model %>% compile(
optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("accuracy")
)
Create a held-out set of your training data for validation.
val_indices <- 1:10000 # not great practice if these are ordered
x_val <- x_train[val_indices,]
partial_x_train <- x_train[-val_indices,]
y_val <- y_train[val_indices]
partial_y_train <- y_train[-val_indices]
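As the comment above warns, taking the first 10,000 rows is risky if the data are ordered by label; a random split is the safer default. A minimal sketch:
set.seed(123)                                  # any seed; for reproducibility
val_indices <- sample(nrow(x_train), 10000)    # random rows instead of the first 10,000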
Fit the model, store the fit history.
history <- model %>% fit(
partial_x_train,
partial_y_train,
epochs = 20,
batch_size = 512,
validation_data = list(x_val, y_val)
)
Train on 15000 samples, validate on 10000 samples
Epoch 1/20
15000/15000 [==============================] - 1s 51us/sample - loss: 0.5266 - acc: 0.7797 - val_loss: 0.4061 - val_acc: 0.8590
Epoch 2/20
15000/15000 [==============================] - 1s 37us/sample - loss: 0.3337 - acc: 0.8880 - val_loss: 0.3255 - val_acc: 0.8745
Epoch 3/20
15000/15000 [==============================] - 1s 34us/sample - loss: 0.2595 - acc: 0.9099 - val_loss: 0.2926 - val_acc: 0.8844
Epoch 4/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.2220 - acc: 0.9212 - val_loss: 0.2873 - val_acc: 0.8830
Epoch 5/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.1929 - acc: 0.9307 - val_loss: 0.3036 - val_acc: 0.8783
Epoch 6/20
15000/15000 [==============================] - 1s 36us/sample - loss: 0.1726 - acc: 0.9388 - val_loss: 0.2886 - val_acc: 0.8841
Epoch 7/20
15000/15000 [==============================] - 1s 45us/sample - loss: 0.1586 - acc: 0.9443 - val_loss: 0.2977 - val_acc: 0.8807
Epoch 8/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.1414 - acc: 0.9500 - val_loss: 0.3113 - val_acc: 0.8759
Epoch 9/20
15000/15000 [==============================] - 1s 37us/sample - loss: 0.1310 - acc: 0.9541 - val_loss: 0.3268 - val_acc: 0.8776
Epoch 10/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.1194 - acc: 0.9596 - val_loss: 0.3521 - val_acc: 0.8697
Epoch 11/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.1054 - acc: 0.9657 - val_loss: 0.3571 - val_acc: 0.8696
Epoch 12/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.0991 - acc: 0.9673 - val_loss: 0.3736 - val_acc: 0.8680
Epoch 13/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0915 - acc: 0.9697 - val_loss: 0.3958 - val_acc: 0.8648
Epoch 14/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0767 - acc: 0.9764 - val_loss: 0.4126 - val_acc: 0.8631
Epoch 15/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0743 - acc: 0.9767 - val_loss: 0.4257 - val_acc: 0.8650
Epoch 16/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0634 - acc: 0.9816 - val_loss: 0.4482 - val_acc: 0.8633
Epoch 17/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0599 - acc: 0.9825 - val_loss: 0.4742 - val_acc: 0.8616
Epoch 18/20
15000/15000 [==============================] - 1s 38us/sample - loss: 0.0533 - acc: 0.9859 - val_loss: 0.5359 - val_acc: 0.8493
Epoch 19/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0468 - acc: 0.9882 - val_loss: 0.5153 - val_acc: 0.8619
Epoch 20/20
15000/15000 [==============================] - 1s 39us/sample - loss: 0.0405 - acc: 0.9904 - val_loss: 0.5282 - val_acc: 0.8611
str(history)
List of 2
$ params :List of 7
..$ batch_size : int 512
..$ epochs : int 20
..$ steps : NULL
..$ samples : int 15000
..$ verbose : int 0
..$ do_validation: logi TRUE
..$ metrics : chr [1:4] "loss" "acc" "val_loss" "val_acc"
$ metrics:List of 4
..$ loss : num [1:20] 0.527 0.334 0.26 0.222 0.193 ...
..$ acc : num [1:20] 0.78 0.888 0.91 0.921 0.931 ...
..$ val_loss: num [1:20] 0.406 0.326 0.293 0.287 0.304 ...
..$ val_acc : num [1:20] 0.859 0.874 0.884 0.883 0.878 ...
- attr(*, "class")= chr "keras_training_history"
plot(history)
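Since history is an ordinary R list (see str() above), we can also locate the best epoch numerically rather than eyeballing the plot:
which.min(history$metrics$val_loss)   # lowest validation loss: epoch 4, from the values above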
Overfitting: validation loss bottoms out around epoch 4 and then climbs while training loss keeps falling. Refit, this time on the full training data, at a smaller number of epochs.
model <- keras_model_sequential() %>%
layer_dense(units = 16, activation = "relu", input_shape = c(5000)) %>%
layer_dense(units = 16, activation = "relu") %>%
layer_dense(units = 1, activation = "sigmoid")
model %>% compile(
optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("accuracy")
)
model %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
Epoch 1/6
25000/25000 [==============================] - 1s 36us/sample - loss: 0.4709 - acc: 0.8064
Epoch 2/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.2915 - acc: 0.8962
Epoch 3/6
25000/25000 [==============================] - 1s 28us/sample - loss: 0.2389 - acc: 0.9125
Epoch 4/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.2112 - acc: 0.9222
Epoch 5/6
25000/25000 [==============================] - 1s 27us/sample - loss: 0.1959 - acc: 0.9275
Epoch 6/6
25000/25000 [==============================] - 1s 26us/sample - loss: 0.1832 - acc: 0.9323
results <- model %>% evaluate(x_test, y_test)
25000/25000 [==============================] - 1s 27us/sample - loss: 0.3068 - acc: 0.8796
In the test data, we get accuracy of 87.96% with this model.
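Instead of hand-picking the epoch count from the history plot, we could have let Keras stop on its own. A sketch using the built-in early-stopping callback (not used in the run above):
history <- model %>% fit(
  partial_x_train, partial_y_train,
  epochs = 20, batch_size = 512,
  validation_data = list(x_val, y_val),
  # stop after validation loss fails to improve for 2 consecutive epochs
  callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 2))
)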
So, what happened there? What is the model learning? It learned a function that converts the 5000 inputs into 16 hidden (latent, intermediate) numbers, converts those 16 into a different 16 intermediate numbers, and then converts those into 1 number at the output. The intermediate functions are nonlinear; otherwise there would be no gain from stacking them. But we can get an approximate idea of how the inputs map to the single output by treating each layer as linear and multiplying the weights through: \(W^{(5000\times 1)}_{io} \approx W_{i1}^{(5000\times 16)} \times W_{12}^{(16\times 16)} \times W_{2o}^{(16\times 1)}\). Those aggregate weights give us an approximate idea of the main effect of each input. (This will not work, generally speaking, in more complex models or contexts.)
# get_weights() alternates kernel matrices and bias vectors; elements 1, 3, 5 are the three kernels
model.weights.approx <- (get_weights(model)[[1]] %*% get_weights(model)[[3]] %*% get_weights(model)[[5]])[,1]
top_words <- reverse_word_index[as.character(1:5000)]
names(model.weights.approx) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])
sort(model.weights.approx, dec=T)[1:20]
7 refreshing wonderfully
1.6517927 1.4027301 1.3446769
rare edie hooked
1.2740762 1.2113184 1.2086754
excellent extraordinary noir
1.2033552 1.1940259 1.1527241
flawless appreciated 8
1.1384029 1.1199644 1.1187128
captures perfect tight
1.0980826 1.0617401 1.0599815
vengeance superb superbly
1.0371490 1.0346370 1.0314989
gem underrated
0.9951883 0.9877272
sort(model.weights.approx, dec=F)[1:20]
waste worst disappointment
-1.705202 -1.627594 -1.589494
poorly unfunny incoherent
-1.453674 -1.413595 -1.393460
unwatchable stinker awful
-1.391467 -1.353770 -1.352675
forgettable pointless laughable
-1.306769 -1.297418 -1.297009
lousy obnoxious lacks
-1.252515 -1.246433 -1.233059
mst3k uninteresting disappointing
-1.222244 -1.215452 -1.211375
dull tiresome
-1.198182 -1.186388
“7” is interesting. It comes from reviews like #168, which ends like this:
sapply(train_data[[168]], function(index) {
  word <- if (index >= 3) reverse_word_index[[as.character(index - 3)]]
  if (!is.null(word)) word else "?"
})[201:215]
 [1] "creates" "a" "great" "monster" "like" "atmosphere" "br" "br" "vote" "7" "and" "half" "out" "of" "10"
That means this whole thing is cheating, as far as I'm concerned. Some of the reviews end with text summarizing the rating as a number out of 10! The model then uses that number to "predict" whether the review is positive or negative (a label that is probably derived from that 10-point rating in the first place).
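To gauge how common these rating giveaways are, we can count the share of training reviews containing each digit token. A sketch, not in the original; rating.idx is an index vector introduced here, built from the names attached to model.weights.approx above:
rating.idx <- which(names(model.weights.approx) %in% as.character(1:10))
colMeans(x_train[, rating.idx])   # fraction of reviews containing each digit token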
And we can look at these similarly to how we looked at the coefficients from the classifiers in our earlier notebook. I’ll also highlight those numbers.
# Plot weights
plot(colSums(x_train), model.weights.approx, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x",
     main="Weights Learned in Deep Model, IMDB",
     ylab="<--- Negative Reviews --- Positive Reviews --->",
     xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train), model.weights.approx, names(model.weights.approx), pos=4,
     cex=1.5*abs(model.weights.approx), col=rgb(0,0,0,.5*abs(model.weights.approx)))
# x_train has no column names at this point, so index the digit tokens by position;
# abs() keeps cex positive for negatively weighted digits
digit.idx <- which(names(model.weights.approx) %in% as.character(1:9))
text(colSums(x_train[, digit.idx]), model.weights.approx[digit.idx],
     names(model.weights.approx)[digit.idx], pos=4,
     cex=1.5*abs(model.weights.approx[digit.idx]), col=rgb(1,0,0,1))
It’s worth pointing out that “deep” didn’t buy us much. If we just estimate with a single sigmoid (logistic) layer, we get nearly identical results, with 87.72% accuracy.
logistic.mod <- keras_model_sequential() %>%
layer_dense(units = 1, activation = "sigmoid", input_shape = c(5000))
logistic.mod %>% compile(
optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("accuracy")
)
logistic.mod %>% fit(x_train, y_train, epochs = 6, batch_size = 512)
Epoch 1/6
25000/25000 [==============================] - 1s 30us/sample - loss: 0.5915 - acc: 0.7577
Epoch 2/6
25000/25000 [==============================] - 1s 21us/sample - loss: 0.4797 - acc: 0.8470
Epoch 3/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.4156 - acc: 0.8677
Epoch 4/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.3723 - acc: 0.8801
Epoch 5/6
25000/25000 [==============================] - 1s 22us/sample - loss: 0.3413 - acc: 0.8874
Epoch 6/6
25000/25000 [==============================] - 1s 34us/sample - loss: 0.3180 - acc: 0.8949
results <- logistic.mod %>% evaluate(x_test, y_test)
25000/25000 [==============================] - 1s 23us/sample - loss: 0.3405 - acc: 0.8772
This one we can interpret directly. It's basically a (logistic) ridge regression like the one we saw in the earlier classification notebook.
logmod.weights <- get_weights(logistic.mod)[[1]][,1]
top_words <- reverse_word_index[as.character(1:5000)]
names(logmod.weights) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])
#Most positive words
sort(logmod.weights,dec=T)[1:20]
excellent wonderful perfect
0.2910813 0.2788198 0.2777115
8 loved fantastic
0.2681858 0.2612670 0.2599397
favorite 7 best
0.2543596 0.2516940 0.2509511
superb highly incredible
0.2509483 0.2504135 0.2423920
great fun definitely
0.2413276 0.2395824 0.2350059
captures gem wonderfully
0.2334028 0.2325189 0.2299354
rare superbly
0.2299205 0.2298841
#Most negative words
sort(logmod.weights,dec=F)[1:20]
awful poor terrible
-0.2987652 -0.2955288 -0.2779662
worst redeeming stupid
-0.2719380 -0.2696612 -0.2666538
bad badly pointless
-0.2661456 -0.2661098 -0.2659274
waste mess worse
-0.2612805 -0.2611499 -0.2577833
1 poorly crap
-0.2522472 -0.2493045 -0.2488209
wasted ridiculous disappointing
-0.2469009 -0.2464800 -0.2455676
save pathetic
-0.2444713 -0.2432728
# Plot weights
plot(colSums(x_train), logmod.weights, pch=19, col=rgb(0,0,0,.3), cex=.5, log="x",
     main="Weights Learned in Shallow Logistic Model, IMDB",
     ylab="<--- Negative Reviews --- Positive Reviews --->",
     xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train), logmod.weights, names(logmod.weights), pos=4,
     cex=10*abs(logmod.weights), col=rgb(0,0,0,3*abs(logmod.weights)))
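For an explicit version of that ridge comparison (a sketch, not run in the original), glmnet fits the L2-penalized logistic regression directly; alpha = 0 selects the ridge penalty:
library(glmnet)
ridge.mod <- cv.glmnet(x_train, y_train, family = "binomial", alpha = 0)  # cross-validated ridge logistic
ridge.weights <- coef(ridge.mod, s = "lambda.min")[-1, 1]                 # drop the intercept
names(ridge.weights) <- names(logmod.weights)                             # attach the word labels
sort(ridge.weights, decreasing = TRUE)[1:20]                              # most positive words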
Just to have another comparison, let’s check with Naive Bayes.
library(quanteda)
colnames(x_train) <- colnames(x_test) <- c("<PAD>","<START>","<UNK>",top_words[1:4997])
dfm.train <- as.dfm(x_train)
dfm.test <- as.dfm(x_test)
nb.mod <- textmodel_nb(dfm.train, y_train, distribution = "Bernoulli")
summary(nb.mod)
Call:
textmodel_nb.dfm(x = dfm.train, y = y_train, distribution = "Bernoulli")
Class Priors:
(showing first 2 elements)
0 1
0.5 0.5
Estimated Feature Scores:
<PAD> <START> <UNK> the and a
0 0.5 0.4998 0.5 0.5006 0.4972 0.5007
1 0.5 0.5002 0.5 0.4994 0.5028 0.4993
of to is br in it
0 0.5003 0.5049 0.494 0.512 0.4962 0.5031
1 0.4997 0.4951 0.506 0.488 0.5038 0.4969
i this that was as for
0 0.5199 0.51 0.5105 0.5283 0.4787 0.5025
1 0.4801 0.49 0.4895 0.4717 0.5213 0.4975
with movie but film on not
0 0.4939 0.5408 0.5126 0.4909 0.5088 0.5276
1 0.5061 0.4592 0.4874 0.5091 0.4912 0.4724
you are his have he be
0 0.5143 0.4932 0.462 0.5301 0.4785 0.527
1 0.4857 0.5068 0.538 0.4699 0.5215 0.473
y_test.pred.nb <- predict(nb.mod, newdata=dfm.test)
nb.class.table <- table(y_test,y_test.pred.nb)
sum(diag(nb.class.table/sum(nb.class.table)))
[1] 0.8418
As we might expect, that's not as good. And as before, Naive Bayes overfits: its most extreme posterior probabilities belong to rare words, like "edie" and "uwe" below, that appear in only a handful of reviews.
#Most positive words
sort(nb.mod$PcGw[2,],dec=T)[1:20]
edie paulie flawless
0.9756098 0.9500000 0.9200000
gundam matthau superbly
0.9130435 0.9104478 0.9098361
felix perfection captures
0.9019608 0.8985507 0.8829268
wonderfully masterful refreshing
0.8821656 0.8735632 0.8693467
breathtaking mildred delightful
0.8666667 0.8611111 0.8582677
polanski beautifully voight
0.8571429 0.8530120 0.8529412
underrated powell
0.8515284 0.8484848
#Most negative words
sort(nb.mod$PcGw[2,],dec=F)[1:20]
uwe boll unwatchable
0.03389831 0.03448276 0.04807692
stinker incoherent mst3k
0.04807692 0.05185185 0.06140351
unfunny waste seagal
0.06837607 0.07356322 0.07843137
atrocious pointless horrid
0.08762887 0.08874459 0.09009009
redeeming drivel blah
0.09148265 0.09836066 0.09876543
worst laughable lousy
0.09969122 0.09975062 0.10294118
awful wasting
0.10595568 0.10666667
# Plot weights
plot(colSums(x_train), nb.mod$PcGw[2,], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x",
     main="Posterior Probabilities, Naive Bayes Classifier, IMDB",
     ylab="<--- Negative Reviews --- Positive Reviews --->",
     xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train), nb.mod$PcGw[2,], names(logmod.weights), pos=4,
     cex=5*abs(.5-nb.mod$PcGw[2,]), col=rgb(0,0,0,1.5*abs(.5-nb.mod$PcGw[2,])))
Keras has its own "embedding" layer, layer_embedding, which you can use as the first layer in a model.
max_features <- 5000   # vocabulary size
maxlen <- 500          # pad / truncate every review to exactly 500 tokens
imdb.s <- dataset_imdb(num_words = max_features)
c(c(x_train.s, y_train.s), c(x_test.s, y_test.s)) %<-% imdb.s
x_train.s <- pad_sequences(x_train.s, maxlen = maxlen)
x_test.s <- pad_sequences(x_test.s, maxlen = maxlen)
emb.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_features, output_dim = 6,
                  input_length = maxlen) %>%   # learn a 6-dimensional vector per word
  layer_flatten() %>%                          # 500 positions x 6 dimensions -> 3000 inputs
  layer_dense(units = 1, activation = "sigmoid")
emb.mod %>% compile(
optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("acc")
)
summary(emb.mod)
Model: "sequential_5"
______________________________________________
Layer (type) Output Shape Param #
==============================================
embedding_1 (Embedd (None, 500, 6) 30000
______________________________________________
flatten_1 (Flatten) (None, 3000) 0
______________________________________________
dense_9 (Dense) (None, 1) 3001
==============================================
Total params: 33,001
Trainable params: 33,001
Non-trainable params: 0
______________________________________________
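The parameter counts check out: the embedding table stores 5000 × 6 = 30,000 weights, and the dense output layer takes the 500 × 6 = 3,000 flattened values plus one bias, for 3,001 parameters.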
emb.history <- emb.mod %>% fit(
x_train.s, y_train.s,
epochs = 6,
batch_size = 32,
validation_split = 0.2
)
Train on 20000 samples, validate on 5000 samples
Epoch 1/6
20000/20000 [==============================] - 2s 114us/sample - loss: 0.5859 - acc: 0.6935 - val_loss: 0.3996 - val_acc: 0.8404
Epoch 2/6
20000/20000 [==============================] - 2s 95us/sample - loss: 0.3139 - acc: 0.8754 - val_loss: 0.2994 - val_acc: 0.8768
Epoch 3/6
20000/20000 [==============================] - 2s 96us/sample - loss: 0.2467 - acc: 0.9041 - val_loss: 0.2816 - val_acc: 0.8868
Epoch 4/6
20000/20000 [==============================] - 2s 96us/sample - loss: 0.2166 - acc: 0.9161 - val_loss: 0.2762 - val_acc: 0.8908
Epoch 5/6
20000/20000 [==============================] - 2s 95us/sample - loss: 0.1969 - acc: 0.9245 - val_loss: 0.2784 - val_acc: 0.8898
Epoch 6/6
20000/20000 [==============================] - 2s 97us/sample - loss: 0.1821 - acc: 0.9329 - val_loss: 0.2811 - val_acc: 0.8918
emb.results <- emb.mod %>% evaluate(x_test.s, y_test.s)
25000/25000 [==============================] - 1s 22us/sample - loss: 0.2796 - acc: 0.8868
That got us up to 88.68%, the best yet, so the embedding layer bought us something.
I’ve never seen anybody else do this, but you can look inside the embedding weights themselves to get a sense of what’s going on. For space’s sake, we’ll plot the embeddings of just the first 500 words.
# Pull the trained 5000 x 6 embedding matrix out of the first layer
embeds.6 <- get_weights(emb.mod)[[1]]
# The first four rows correspond to keras's reserved indices
# (padding / start / unknown / unused under the default imdb encoding)
rownames(embeds.6) <- c("??","??","??","??",top_words[1:4996])
colnames(embeds.6) <- paste("Dim",1:6)
pairs(embeds.6[1:500,], col=rgb(0,0,0,.5), cex=.5, main="Six-dimensional Embeddings Trained in Classification Model")
Ahhhhh… our 6-dimensional embeddings are pretty close to one-dimensional. Why? Because they are trained for the essentially one-dimensional classification of positive-negative sentiment.
We can tease this out with something like PCA:
pc.emb <- princomp(embeds.6)
summary(pc.emb)
Importance of components:
                          Comp.1     Comp.2     Comp.3     Comp.4     Comp.5     Comp.6
Standard deviation     0.3566176 0.05058622 0.04372121 0.04287706 0.03973241 0.03804071
Proportion of Variance 0.9316192 0.01874552 0.01400288 0.01346738 0.01156439 0.01060059
Cumulative Proportion  0.9316192 0.95036475 0.96436763 0.97783501 0.98939941 1.00000000
# Most positive words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=T)[1:20]
excellent 7 perfect
1.660139 1.647258 1.629065
gem 8 favorite
1.490264 1.387230 1.360757
refreshing wonderfully superb
1.355163 1.303961 1.256989
captures amazing wonderful
1.209862 1.177840 1.133390
rare highly superbly
1.133343 1.121740 1.111449
marvelous noir perfectly
1.094616 1.089450 1.087435
finest favorites
1.064197 1.052959
# Most negative words (main dimension of embeddings)
sort(-pc.emb$scores[,1],dec=F)[1:20]
worst waste awful
-2.426991 -2.287825 -2.110015
poorly pointless dull
-1.897196 -1.736179 -1.693020
laughable avoid mess
-1.663206 -1.601402 -1.572703
disappointment disappointing worse
-1.530430 -1.523112 -1.520477
terrible unfunny badly
-1.484618 -1.475630 -1.440923
boring redeeming mst3k
-1.433651 -1.406507 -1.398853
lacks dreadful
-1.370802 -1.364966
# Most positive words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=T)[1:20]
didn't academy coming
0.1993631 0.1894763 0.1867834
piece must protagonists
0.1674755 0.1651370 0.1625642
towards delight so
0.1594424 0.1576414 0.1528698
reason through self
0.1481342 0.1480310 0.1461657
opening possibly to
0.1456153 0.1450613 0.1421859
victims known involving
0.1389948 0.1389144 0.1368349
point don't
0.1362795 0.1351084
# Most negative words (2nd dimension of embeddings)
sort(pc.emb$scores[,2],dec=F)[1:20]
bizarre takes rings become
-0.1997089 -0.1818287 -0.1748362 -0.1709388
frank descent steps those
-0.1627424 -0.1616023 -0.1497373 -0.1496597
students tons count mother
-0.1494445 -0.1474266 -0.1466301 -0.1449812
stop animal group segment
-0.1447104 -0.1437884 -0.1400553 -0.1348926
its toilet laughing father
-0.1347013 -0.1344255 -0.1320045 -0.1316030
# Plot weights
plot(colSums(x_train),-pc.emb$scores[,1], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),-pc.emb$scores[,1], rownames(embeds.6),pos=4,cex=1*abs(pc.emb$scores[,1]), col=rgb(0,0,0,.4*abs(pc.emb$scores[,1])))
This first embedding dimension is very similar to the weights we learned in the shallow logistic network:
cor(logmod.weights[3:4999],-pc.emb$scores[4:5000,1])
[1] 0.8978127
The second dimension captures something … perhaps in more nuanced criticism … that acted as a subtle bit of extra information for our classifier, providing a slight improvement in performance. In effect, by letting the model learn six numbers per word instead of one, we gave it the opportunity to learn some more subtle characteristics of review sentiment.
# Plot weights
plot(colSums(x_train),pc.emb$scores[,2], pch=19, col=rgb(0,0,0,.3), cex=.5, log="x", main="Second Principal Component of Embeddings Layer", ylab="<--- Negative Reviews --- Positive Reviews --->", xlab="Total Appearances")
2 x values <= 0 omitted from logarithmic plot
text(colSums(x_train),pc.emb$scores[,2], rownames(embeds.6),pos=4,cex=10*abs(pc.emb$scores[,2]), col=rgb(0,0,0,3*abs(pc.emb$scores[,2])))
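One way to poke at that second dimension (a quick sketch of my own, using the pc.emb scores computed above; the 50% cutoff is arbitrary) is to list words that sit near zero on the main sentiment component but far from zero on the second:
# Words neutral on the sentiment component but extreme on the second --
# candidates for whatever subtler signal the classifier is exploiting
neutral <- abs(pc.emb$scores[, 1]) < quantile(abs(pc.emb$scores[, 1]), 0.5)
head(sort(abs(pc.emb$scores[neutral, 2]), decreasing = TRUE), 20)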
The commented-out code below assumes you have downloaded the 822M zip file glove.6B.zip from https://nlp.stanford.edu/projects/glove/ and unzipped the folder (more than 2G). It then reads in the smallest file – 171M, defining 50-dimensional embeddings for 400,000 tokens – and creates embedding_matrix, a 5000 x 50 matrix covering the 5000 top words in the imdb data (a 1.9M object). I have included this object in an rds file, and the readRDS command reads it in directly.
# glove_dir <- "../Embeddings/glove.6B"
# lines <- readLines(file.path(glove_dir, "glove.6B.50d.txt"))
# embeddings_index <- new.env(hash = TRUE, parent = emptyenv())
# for (i in 1:length(lines)) {
# line <- lines[[i]]
# values <- strsplit(line, " ")[[1]]
# word <- values[[1]]
# embeddings_index[[word]] <- as.double(values[-1])
# }
# cat("Found", length(embeddings_index), "word vectors.\n")
max_words <- 5000
embedding_dim <- 50
# embedding_matrix <- array(0, c(max_words, embedding_dim))
# for (word in names(word_index)) {
# index <- word_index[[word]]
# if (index < max_words) {
# embedding_vector <- embeddings_index[[word]]
# if (!is.null(embedding_vector))
# embedding_matrix[index+1,] <- embedding_vector
# }
# }
# saveRDS(embedding_matrix, file="glove_imdb_example.rds")
embedding_matrix <- readRDS("glove_imdb_example.rds")
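A couple of quick sanity checks (my own aside, not part of the original recipe): the matrix should be 5000 x 50, and any top words missing from GloVe's vocabulary will have been left as all-zero rows.
dim(embedding_matrix)                     # expect 5000 x 50
sum(rowSums(embedding_matrix != 0) == 0)  # rows left at zero: words with no GloVe vector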
glove.mod <- keras_model_sequential() %>%
layer_embedding(input_dim = max_words, output_dim = embedding_dim,
input_length = maxlen) %>%
layer_flatten() %>%
layer_dense(units = 1, activation = "sigmoid")
summary(glove.mod)
Model: "sequential_7"
______________________________________________
Layer (type) Output Shape Param #
==============================================
embedding_2 (Embedd (None, 500, 50) 250000
______________________________________________
flatten_2 (Flatten) (None, 25000) 0
______________________________________________
dense_10 (Dense) (None, 1) 25001
==============================================
Total params: 275,001
Trainable params: 275,001
Non-trainable params: 0
______________________________________________
# Load the pretrained GloVe vectors into the embedding layer, then
# freeze them so they are not updated during training
get_layer(glove.mod, index = 1) %>%
  set_weights(list(embedding_matrix)) %>%
  freeze_weights()
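If you rerun the summary after freezing (output omitted here), the embedding layer's 250,000 parameters should move from trainable to non-trainable:
summary(glove.mod)  # expect Trainable params: 25,001; Non-trainable params: 250,000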
glove.mod %>% compile(
optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = c("acc")
)
glove.history <- glove.mod %>% fit(
x_train.s, y_train.s,
epochs = 6,
batch_size = 32,
validation_split = 0.2
)
Train on 20000 samples, validate on 5000 samples
Epoch 1/6
20000/20000 [==============================] - 3s 151us/sample - loss: 0.9058 - acc: 0.5295 - val_loss: 0.9894 - val_acc: 0.5142
Epoch 2/6
20000/20000 [==============================] - 3s 133us/sample - loss: 0.7298 - acc: 0.6247 - val_loss: 1.0100 - val_acc: 0.5286
Epoch 3/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.6382 - acc: 0.6751 - val_loss: 0.9295 - val_acc: 0.5582
Epoch 4/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.5810 - acc: 0.7085 - val_loss: 0.9726 - val_acc: 0.5626
Epoch 5/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.5351 - acc: 0.7340 - val_loss: 1.1786 - val_acc: 0.5578
Epoch 6/6
20000/20000 [==============================] - 3s 130us/sample - loss: 0.4974 - acc: 0.7560 - val_loss: 1.1034 - val_acc: 0.5508
plot(glove.history)
glove.results <- glove.mod %>% evaluate(x_test.s, y_test.s)
25000/25000 [==============================] - 1s 60us/sample - loss: 1.0862 - acc: 0.5531
Accuracy of 55%. Pretty bad. These GloVe embeddings are trained on Wikipedia and the Gigaword corpus. The 50 most important dimensions of that very general collection of corpora – dimensions relating words to a narrow window of surrounding words – are nowhere near as useful as the six (or even one) most important dimension(s) relating words to sentiment within our training data.
To get an idea of what’s going on …
library(text2vec)
text2vec is still in beta version - APIs can be changed.
For tutorials and examples visit http://text2vec.org.
For FAQ refer to
1. https://stackoverflow.com/questions/tagged/text2vec?sort=newest
2. https://github.com/dselivanov/text2vec/issues?utf8=%E2%9C%93&q=is%3Aissue%20label%3Aquestion
If you have questions please post them at StackOverflow and mark with 'text2vec' tag.
Attaching package: ‘text2vec’
The following objects are masked from ‘package:keras’:
fit, normalize
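Note that text2vec masks keras’s fit (and normalize), so if you fit more keras models after loading it, use the namespaced form:
# history <- keras::fit(model, x_train.s, y_train.s, ...)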
# Cosine similarity between one word's embedding and every row of the matrix
find_similar_words <- function(word, embedding_matrix, n = 5) {
  similarities <- embedding_matrix[word, , drop = FALSE] %>%
    sim2(embedding_matrix, y = ., method = "cosine")
  # sim2() returns an n-words x 1 matrix; sort and keep the top n
  similarities[,1] %>% sort(decreasing = TRUE) %>% head(n)
}
The concept embodied by a word in this review context, e.g., “waste”:
find_similar_words("waste", embeds.6, n=10)
may not be at all what’s captured about the word in the GloVe embeddings:
rownames(embedding_matrix) <- c("?",top_words[1:4999])
find_similar_words("waste", embedding_matrix, n=10)
waste water garbage clean
1.0000000 0.7724197 0.7605300 0.7390320
gas amounts trash mine
0.7258156 0.6982243 0.6903367 0.6879122
natural oil
0.6776923 0.6602530
Or “gem”:
find_similar_words("gem", embeds.6, n=10)
gem favorite anime
1.0000000 0.9992570 0.9991338
wonderful unforgettable faults
0.9990917 0.9989748 0.9985772
incredible suspenseful realistic
0.9985302 0.9985083 0.9983043
beauty
0.9982520
find_similar_words("gem", embedding_matrix, n=10)
gem diamond precious treasure
1.0000000 0.7375269 0.6808224 0.6128838
priceless gold valuable pulp
0.5930778 0.5867127 0.5440980 0.5299630
golden magical
0.5298125 0.5195156
Or “wooden”:
find_similar_words("wooden", embeds.6, n=10)
wooden unconvincing obnoxious
1.0000000 0.9991876 0.9991796
ludicrous destroy supposed
0.9991443 0.9989238 0.9988930
redeeming unfunny uwe
0.9988248 0.9987805 0.9987153
mess
0.9987089
find_similar_words("wooden", embedding_matrix, n=10)
wooden walls doors glass
1.0000000 0.7539166 0.7357615 0.7291114
cardboard floor stone wood
0.7199160 0.7160277 0.7153557 0.7100430
attached empty
0.7053251 0.7023291
Or “moving”:
find_similar_words("moving", embeds.6, n=10)
moving essential unexpected
1.0000000 0.9992714 0.9985157
berlin balance enjoyed
0.9983773 0.9982949 0.9979986
conventional stunning love
0.9970112 0.9969180 0.9968535
importance
0.9964524
find_similar_words("moving", embedding_matrix, n=10)
moving turning through turn
1.0000000 0.8934571 0.8823608 0.8565476
into quickly move coming
0.8531339 0.8527476 0.8474165 0.8441274
beyond way
0.8403356 0.8368897
In other contexts, pretrained embeddings can be very useful. But not, it appears, in this one, at least without further tweaking. A reasonable approach might use pretrained embeddings as a starting point and then allow them to be fine-tuned on data from your specific context, as sketched below.
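A minimal sketch of that idea (untested here, and reusing max_words, embedding_dim, maxlen, and embedding_matrix from above): initialize the embedding layer with the GloVe matrix but don't freeze it, so the vectors can drift toward the sentiment signal in our reviews.
tune.mod <- keras_model_sequential() %>%
  layer_embedding(input_dim = max_words, output_dim = embedding_dim,
                  input_length = maxlen) %>%
  layer_flatten() %>%
  layer_dense(units = 1, activation = "sigmoid")
# Same initialization as before, but no freeze_weights(): GloVe is a
# starting point rather than a constraint
get_layer(tune.mod, index = 1) %>%
  set_weights(list(embedding_matrix))
tune.mod %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)
# tune.history <- tune.mod %>% keras::fit(
#   x_train.s, y_train.s,
#   epochs = 6, batch_size = 32, validation_split = 0.2
# )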
OK, that’s enough for now.