This package includes an example Recurrent Neural Network. The package is loaded using:
library(rnn)
We can view the code of the main rnn() function by typing its name without the parentheses, which prints the function instead of calling it.
rnn
## function(binary_dim, alpha, input_dim, hidden_dim, output_dim, iterations=5000) {
##
## # check what largest possible number is
## largest_number = 2^binary_dim
##
## # initialize neural network weights
## synapse_0 = matrix(runif(n = input_dim*hidden_dim, min=-1, max=1), nrow=input_dim)
## synapse_1 = matrix(runif(n = hidden_dim*output_dim, min=-1, max=1), nrow=hidden_dim)
## synapse_h = matrix(runif(n = hidden_dim*hidden_dim, min=-1, max=1), nrow=hidden_dim)
##
## synapse_0_update = matrix(0, nrow = input_dim, ncol = hidden_dim)
## synapse_1_update = matrix(0, nrow = hidden_dim, ncol = output_dim)
## synapse_h_update = matrix(0, nrow = hidden_dim, ncol = hidden_dim)
##
## # training logic
## for (j in 1:iterations) {
##
## # generate a simple addition problem (a + b = c)
## a_int = sample(1:(largest_number/2), 1) # int version
## a = int2binary(a_int, binary_dim) # binary encoding
##
## b_int = sample(1:(largest_number/2), 1) # int version
## b = int2binary(b_int, binary_dim)
##
## # true answer
## c_int = a_int + b_int
## c = int2binary(c_int, binary_dim)
##
## # where we'll store our best guesss (binary encoded)
## d = matrix(0, nrow = 1, ncol = binary_dim)
##
## overallError = 0
##
## layer_2_deltas = matrix(0)
## layer_1_values = matrix(0, nrow=1, ncol = hidden_dim)
## # layer_1_values = rbind(layer_1_values, matrix(0, nrow=1, ncol=hidden_dim))
##
## # moving along the positions in the binary encoding
## for (position in 0:(binary_dim-1)) {
##
## # generate input and output
## X = cbind(a[binary_dim - position],b[binary_dim - position])
## y = c[binary_dim - position]
##
## # hidden layer (input ~+ prev_hidden)
## layer_1 = sigmoid((X%*%synapse_0) + (layer_1_values[dim(layer_1_values)[1],] %*% synapse_h))
##
## # output layer (new binary representation)
## layer_2 = sigmoid(layer_1 %*% synapse_1)
##
## # did we miss?... if so, by how much?
## layer_2_error = y - layer_2
## layer_2_deltas = rbind(layer_2_deltas, layer_2_error * sigmoid_output_to_derivative(layer_2))
## overallError = overallError + abs(layer_2_error)
##
## # decode estimate so we can print it out
## d[binary_dim - position] = round(layer_2)
##
## # store hidden layer so we can print it out
## layer_1_values = rbind(layer_1_values, layer_1)
## }
##
## future_layer_1_delta = matrix(0, nrow = 1, ncol = hidden_dim)
##
## for (position in 0:(binary_dim-1)) {
##
## X = cbind(a[position+1], b[position+1])
## layer_1 = layer_1_values[dim(layer_1_values)[1]-position,]
## prev_layer_1 = layer_1_values[dim(layer_1_values)[1]-(position+1),]
##
## # error at output layer
## layer_2_delta = layer_2_deltas[dim(layer_2_deltas)[1]-position,]
## # error at hidden layer
## layer_1_delta = (future_layer_1_delta %*% t(synapse_h) + layer_2_delta %*% t(synapse_1)) *
## sigmoid_output_to_derivative(layer_1)
##
## # let's update all our weights so we can try again
## synapse_1_update = synapse_1_update + matrix(layer_1) %*% layer_2_delta
## synapse_h_update = synapse_h_update + matrix(prev_layer_1) %*% layer_1_delta
## synapse_0_update = synapse_0_update + t(X) %*% layer_1_delta
##
## future_layer_1_delta = layer_1_delta
## }
##
## synapse_0 = synapse_0 + ( synapse_0_update * alpha )
## synapse_1 = synapse_1 + ( synapse_1_update * alpha )
## synapse_h = synapse_h + ( synapse_h_update * alpha )
##
## synapse_0_update = synapse_0_update * 0
## synapse_1_update = synapse_1_update * 0
## synapse_h_update = synapse_h_update * 0
##
## # print out progress
## if(j %% 500 ==0) {
## print(paste("Error:", overallError))
## print(paste("Pred:", paste(d, collapse = " ")))
## print(paste("True:", paste(c, collapse = " ")))
## out = 0
## for (x in 1:length(d)) {
## out[x] = rev(d)[x]*2^(x-1) }
## print(paste(a_int, "+", b_int, "=", sum(out)))
## print("----------------")
## }
## }
## }
## <environment: namespace:rnn>
As can be seen from the above, the model relies on three other functions that are included with the package.
The first one is int2binary(), which stands for integer to binary; it converts an integer to its binary representation, returned as a vector of `length` bits with the most significant bit first:
int2binary(146, length=8)
## [1] 1 0 0 1 0 0 1 0
The code for this function is:
int2binary
## function(x, length) {
## tail(rev(as.integer(intToBits(x))), length) }
## <environment: namespace:rnn>
The second function is sigmoid(), which maps any number to its sigmoid (logistic) value, 1 / (1 + exp(-x)), a value between 0 and 1.
(a <- sigmoid(3))
## [1] 0.9525741
The code for the sigmoid() function is:
sigmoid
## function(x) {
## output = 1 / (1+exp(-x))
## return(output) }
## <environment: namespace:rnn>
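The final function, sigmoid_output_to_derivative(), takes a value that is already the output of sigmoid() and returns the derivative of the sigmoid at that point, computed as output * (1 - output).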
sigmoid_output_to_derivative(a) # a was created above using sigmoid()
## [1] 0.04517666
Finally, we can inspect this code using:
sigmoid_output_to_derivative
## function(output) {
## return( output*(1-output) ) }
## <environment: namespace:rnn>
By setting a seed for the random number generator, we ensure replicability.
set.seed(123)
An example is included in the help file.
help('rnn')
This example is:
# using the default of 5,000 iterations (progress is printed every 500 iterations)
rnn(binary_dim = 8,
alpha = 0.1,
input_dim = 2,
hidden_dim = 10,
output_dim = 1 )
## [1] "Error: 4.010345563468"
## [1] "Pred: 0 0 0 0 0 0 0 0"
## [1] "True: 0 1 0 1 1 0 0 1"
## [1] "16 + 73 = 0"
## [1] "----------------"
## [1] "Error: 4.02920967352848"
## [1] "Pred: 1 1 1 1 1 1 1 1"
## [1] "True: 1 0 1 1 0 0 0 0"
## [1] "56 + 120 = 255"
## [1] "----------------"
## [1] "Error: 4.12557055200652"
## [1] "Pred: 0 0 0 0 0 0 0 0"
## [1] "True: 1 0 1 0 0 0 1 1"
## [1] "39 + 124 = 0"
## [1] "----------------"
## [1] "Error: 4.06576951236069"
## [1] "Pred: 1 0 1 1 1 1 1 0"
## [1] "True: 0 1 0 1 0 0 0 0"
## [1] "74 + 6 = 190"
## [1] "----------------"
## [1] "Error: 4.05885825857722"
## [1] "Pred: 0 0 0 0 0 0 0 0"
## [1] "True: 1 0 1 1 0 1 1 0"
## [1] "76 + 106 = 0"
## [1] "----------------"
## [1] "Error: 4.07237068740385"
## [1] "Pred: 0 1 1 1 0 0 0 0"
## [1] "True: 1 0 1 1 1 0 1 1"
## [1] "127 + 60 = 112"
## [1] "----------------"
## [1] "Error: 3.70719891391563"
## [1] "Pred: 0 0 0 0 0 1 1 1"
## [1] "True: 0 1 0 1 0 1 1 1"
## [1] "71 + 16 = 7"
## [1] "----------------"
## [1] "Error: 3.54035352343777"
## [1] "Pred: 1 0 0 1 1 1 1 0"
## [1] "True: 1 0 0 1 0 1 1 0"
## [1] "79 + 71 = 158"
## [1] "----------------"
## [1] "Error: 3.87571289227445"
## [1] "Pred: 0 0 0 1 1 0 1 0"
## [1] "True: 0 1 1 0 0 0 1 0"
## [1] "77 + 21 = 26"
## [1] "----------------"
## [1] "Error: 3.09484833222603"
## [1] "Pred: 1 0 1 1 1 1 1 1"
## [1] "True: 1 0 1 1 1 1 1 1"
## [1] "125 + 66 = 191"
## [1] "----------------"
It is interesting to vary the number of hidden units.
# using 3 hidden units and 20,000 iterations instead of the default 5,000
rnn(binary_dim = 8,
alpha = 0.1,
input_dim = 2,
hidden_dim = 3,
output_dim = 1,
iterations = 20000)
## [1] "Error: 3.92094537253012"
## [1] "Pred: 0 1 0 1 0 0 0 0"
## [1] "True: 0 1 0 1 1 0 0 0"
## [1] "87 + 1 = 80"
## [1] "----------------"
## [1] "Error: 4.04666554905518"
## [1] "Pred: 1 1 1 1 1 0 1 1"
## [1] "True: 1 0 1 0 0 1 1 0"
## [1] "114 + 52 = 251"
## [1] "----------------"
## [1] "Error: 3.958835551849"
## [1] "Pred: 0 0 0 0 0 0 1 0"
## [1] "True: 0 0 1 1 0 0 0 0"
## [1] "3 + 45 = 2"
## [1] "----------------"
## [1] "Error: 4.01120927634392"
## [1] "Pred: 1 1 1 1 1 0 1 1"
## [1] "True: 0 0 0 0 1 1 0 1"
## [1] "8 + 5 = 251"
## [1] "----------------"
## [1] "Error: 4.05754646026104"
## [1] "Pred: 1 1 0 0 1 0 1 0"
## [1] "True: 1 0 1 0 0 1 0 1"
## [1] "64 + 101 = 202"
## [1] "----------------"
## [1] "Error: 4.00467259262997"
## [1] "Pred: 0 1 1 1 0 1 0 0"
## [1] "True: 1 1 0 0 1 0 1 0"
## [1] "96 + 106 = 116"
## [1] "----------------"
## [1] "Error: 3.97262278830339"
## [1] "Pred: 0 1 1 1 0 0 1 0"
## [1] "True: 0 1 0 1 1 0 1 0"
## [1] "66 + 24 = 114"
## [1] "----------------"
## [1] "Error: 3.87812958107059"
## [1] "Pred: 1 1 0 1 1 0 0 0"
## [1] "True: 0 1 1 0 1 0 0 0"
## [1] "16 + 88 = 216"
## [1] "----------------"
## [1] "Error: 3.76827039076875"
## [1] "Pred: 0 1 1 1 1 1 0 1"
## [1] "True: 0 1 0 1 1 0 0 1"
## [1] "29 + 60 = 125"
## [1] "----------------"
## [1] "Error: 3.69864880242451"
## [1] "Pred: 1 1 1 1 1 0 1 1"
## [1] "True: 0 1 1 1 1 0 1 0"
## [1] "41 + 81 = 251"
## [1] "----------------"
## [1] "Error: 3.72417373919211"
## [1] "Pred: 0 0 1 1 1 1 1 1"
## [1] "True: 0 0 1 0 0 0 0 1"
## [1] "4 + 29 = 63"
## [1] "----------------"
## [1] "Error: 3.45828716714469"
## [1] "Pred: 0 0 0 1 0 1 0 0"
## [1] "True: 0 0 1 0 1 0 0 0"
## [1] "26 + 14 = 20"
## [1] "----------------"
## [1] "Error: 4.01122874067438"
## [1] "Pred: 0 1 1 1 1 1 1 0"
## [1] "True: 0 1 0 0 0 0 0 0"
## [1] "27 + 37 = 126"
## [1] "----------------"
## [1] "Error: 3.22607020244502"
## [1] "Pred: 1 0 1 0 1 1 1 0"
## [1] "True: 1 0 1 0 1 1 1 0"
## [1] "76 + 98 = 174"
## [1] "----------------"
## [1] "Error: 3.84316603595037"
## [1] "Pred: 0 1 1 1 0 0 1 0"
## [1] "True: 0 1 1 1 1 1 0 0"
## [1] "47 + 77 = 114"
## [1] "----------------"
## [1] "Error: 4.16821433128999"
## [1] "Pred: 1 0 1 1 1 0 1 1"
## [1] "True: 1 1 0 0 0 0 0 1"
## [1] "109 + 84 = 187"
## [1] "----------------"
## [1] "Error: 2.33225702334997"
## [1] "Pred: 1 0 0 1 1 0 1 1"
## [1] "True: 1 0 0 1 1 0 0 1"
## [1] "25 + 128 = 155"
## [1] "----------------"
## [1] "Error: 2.27988603657025"
## [1] "Pred: 0 1 0 0 1 1 1 1"
## [1] "True: 0 1 0 0 1 0 1 1"
## [1] "11 + 64 = 79"
## [1] "----------------"
## [1] "Error: 3.00513563960477"
## [1] "Pred: 0 0 1 0 1 0 0 0"
## [1] "True: 0 0 1 1 0 0 0 1"
## [1] "26 + 23 = 40"
## [1] "----------------"
## [1] "Error: 3.34716635145008"
## [1] "Pred: 0 1 1 1 0 0 0 0"
## [1] "True: 0 1 1 0 0 1 0 1"
## [1] "7 + 94 = 112"
## [1] "----------------"
## [1] "Error: 3.2360257563136"
## [1] "Pred: 0 1 0 0 1 1 1 0"
## [1] "True: 1 0 0 0 1 1 0 1"
## [1] "72 + 69 = 78"
## [1] "----------------"
## [1] "Error: 3.03463950916919"
## [1] "Pred: 1 1 1 1 0 1 1 1"
## [1] "True: 1 0 0 1 0 1 0 1"
## [1] "48 + 101 = 247"
## [1] "----------------"
## [1] "Error: 1.95257318377524"
## [1] "Pred: 1 0 1 1 0 1 1 0"
## [1] "True: 1 0 1 1 0 1 1 0"
## [1] "96 + 86 = 182"
## [1] "----------------"
## [1] "Error: 4.11697073433738"
## [1] "Pred: 1 0 0 0 0 0 0 0"
## [1] "True: 1 1 1 0 0 1 0 1"
## [1] "126 + 103 = 128"
## [1] "----------------"
## [1] "Error: 2.19006366358445"
## [1] "Pred: 0 1 1 0 1 0 0 0"
## [1] "True: 0 1 1 0 1 0 0 0"
## [1] "103 + 1 = 104"
## [1] "----------------"
## [1] "Error: 2.99267071257363"
## [1] "Pred: 1 0 0 1 1 1 1 0"
## [1] "True: 1 0 0 0 1 1 1 0"
## [1] "46 + 96 = 158"
## [1] "----------------"
## [1] "Error: 3.6491616956578"
## [1] "Pred: 1 0 0 1 0 0 0 1"
## [1] "True: 1 0 0 1 1 1 0 1"
## [1] "111 + 46 = 145"
## [1] "----------------"
## [1] "Error: 3.16124301610558"
## [1] "Pred: 0 1 0 0 0 0 0 1"
## [1] "True: 0 1 0 0 1 0 0 1"
## [1] "42 + 31 = 65"
## [1] "----------------"
## [1] "Error: 1.86639289236952"
## [1] "Pred: 0 0 1 1 1 1 1 1"
## [1] "True: 0 0 1 1 1 1 1 1"
## [1] "30 + 33 = 63"
## [1] "----------------"
## [1] "Error: 3.90259175859386"
## [1] "Pred: 1 0 0 0 0 1 0 0"
## [1] "True: 1 1 1 1 0 1 0 0"
## [1] "123 + 121 = 132"
## [1] "----------------"
## [1] "Error: 2.58913084378957"
## [1] "Pred: 1 0 0 1 1 0 0 1"
## [1] "True: 1 0 0 1 1 0 0 1"
## [1] "103 + 50 = 153"
## [1] "----------------"
## [1] "Error: 3.11819733269934"
## [1] "Pred: 1 0 1 0 1 0 0 0"
## [1] "True: 1 0 1 0 0 0 0 0"
## [1] "87 + 73 = 168"
## [1] "----------------"
## [1] "Error: 1.21937745325075"
## [1] "Pred: 0 1 0 1 1 0 0 1"
## [1] "True: 0 1 0 1 1 0 0 1"
## [1] "64 + 25 = 89"
## [1] "----------------"
## [1] "Error: 3.92137735633988"
## [1] "Pred: 1 0 0 0 0 0 0 1"
## [1] "True: 1 0 1 1 1 0 0 1"
## [1] "60 + 125 = 129"
## [1] "----------------"
## [1] "Error: 3.03678943553476"
## [1] "Pred: 1 0 0 1 1 1 1 1"
## [1] "True: 1 0 0 1 0 1 1 1"
## [1] "113 + 38 = 159"
## [1] "----------------"
## [1] "Error: 2.52485337857573"
## [1] "Pred: 1 0 0 1 1 1 1 1"
## [1] "True: 1 0 0 1 0 1 1 1"
## [1] "86 + 65 = 159"
## [1] "----------------"
## [1] "Error: 0.686101476120311"
## [1] "Pred: 0 1 0 1 0 0 0 1"
## [1] "True: 0 1 0 1 0 0 0 1"
## [1] "17 + 64 = 81"
## [1] "----------------"
## [1] "Error: 2.17545526595644"
## [1] "Pred: 1 1 1 1 1 1 1 0"
## [1] "True: 0 1 1 1 1 1 1 0"
## [1] "38 + 88 = 254"
## [1] "----------------"
## [1] "Error: 2.54896625392598"
## [1] "Pred: 1 0 0 0 1 0 0 1"
## [1] "True: 1 0 0 0 1 0 0 1"
## [1] "114 + 23 = 137"
## [1] "----------------"
## [1] "Error: 1.87376443422062"
## [1] "Pred: 1 0 0 1 0 1 1 0"
## [1] "True: 1 1 0 1 0 1 1 0"
## [1] "118 + 96 = 150"
## [1] "----------------"