#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

#-------------------------------------------------------------
# A simple example using the attention layer
# for single-headed self attention
# in combination with the LSTM recurrent layer.
#
# It uses the clickbait dataset
# (https://www.kaggle.com/datasets/amananandrai/clickbait-dataset?select=clickbait_data.csv),
# a simple binary text classification task with 32000 samples.
#
# To download the dataset and place it into the expected data folder, simply run download_attentionExample.sh.
#
#-------------------------------------------------------------
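
# The script can be run with the SystemDS command line from the repository root,
# e.g. (the exact script path is an assumption and depends on where this file is saved):
#
#   systemds scripts/nn/examples/attention_example.dml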


source("nn/layers/attention.dml") as attention
source("nn/layers/affine.dml") as affine
source("nn/layers/lstm.dml") as lstm
source("nn/layers/relu.dml") as relu
source("nn/layers/sigmoid.dml") as sigmoid
source("nn/optim/adam.dml") as adam
source("nn/layers/log_loss.dml") as log_loss


# 1 get data
data_loc = "scripts/nn/examples/data/"
tableschema = "string,int"
N = 32000       # Samples of whole dataset
n = 1000        # Samples to use for training
max_length = 32 # maximum sequence length
epochs = 10
batch_size = 32
val_size = 100

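# The csv is expected to contain two columns per row, the (headline) text and a
# binary clickbait label (0/1), matching the "string,int" schema above.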
data = read(data_loc + "clickbait_data.csv", format="csv", header=TRUE, sep=",", data_type="frame", schema=tableschema, cols=2, rows=N)


[x_train, y_train, vocab_size] = preprocess(data, max_length, N)

x_train = x_train[1:n]
y_train = y_train[1:n]

# train network

[biases, weights] = train(x_train, y_train, epochs, batch_size, max_length, vocab_size, val_size)
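
# A minimal, commented-out sketch of how the trained parameters could be used to
# score further (already preprocessed) samples x_new; it assumes the same
# hyper-parameters that train() uses internally (embedding_size=64, lstm_neurons=150)
# and an attention buffer allocated like the one in train():
#
#   att = matrix(0, rows=nrow(x_new), cols=max_length*64)
#   [y_hat, o5, o4, o3, o2, o1, q, k, e, co, cc, ci] =
#     predict(x_new, biases, weights, max_length, 64, 150, att)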


preprocess = function(frame[unknown] data, integer max_length, integer n)
  return (matrix[double] features, matrix[double] targets, integer vocab_size)
{
  /*
   * Preprocesses the raw text data into integer tokens and shuffles features and targets.
   *
   * Inputs:
   *  - data: data frame with [string, int] schema and n rows.
   *  - max_length: maximum sequence length used for training.
   *  - n: number of samples.
   *
   * Outputs:
   *  - features: feature matrix of shape (n, max_length).
   *  - targets: labels vector of shape (n, 1).
   *  - vocab_size: vocabulary size, used to define the size of the embedding matrix during training.
   */

  # map to lowercase, remove non-alphanumeric characters
  formatted = map(data[,1], "x -> x.toLowerCase().replaceAll(\"[^\\p{Alnum}]+\", \" \").replaceAll(\"[\\s]+\", \" \")")
  ids = as.frame(seq(1, nrow(formatted), 1))
  formatted = cbind(ids, formatted)

  # tokenize the text and recode the tokens into integer-id sequences for the lstm
  spec = "{\"algo\" : \"split\", \"out\": \"position\", \"tokenize_col\": 2, \"id_cols\": [1]}"
  tokenized = tokenize(target=formatted, spec=spec, max_tokens=max_length)
  recode_spec = "{ \"recode\": [C3]}"
  [tokens, mapping] = transformencode(target=tokenized, spec=recode_spec)
  features = matrix(0, rows=n, cols=max_length)
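  # scatter the (id, position, token-id) triples into the dense (n x max_length)
  # matrix, restarting the position counter whenever a new sample id starts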
  row_old = as.scalar(tokens[1,1])
  pos = 1
  for(i in 1:nrow(tokens))
  {
    row = as.scalar(tokens[i,1])
    if (row != row_old)
    {
      row_old = row
      pos = 1
    }
    features[row,pos] = tokens[i,3]
    pos += 1
  }
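  # replace any NaN tokens with -1 and shift all values by 2 so that every entry
  # (including the 0-padding) becomes a positive row index into the embedding matrix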
  features = replace(target=features, pattern=NaN, replacement=-1)
  features = features + 2
  vocab_size = as.integer(max(features))

  targets = as.matrix(data[,2])

  # shuffle data
  r = rand(rows=n, cols=1, min=0, max=1, pdf="uniform")
  x = order(target=cbind(r, features), by=1)
  y = order(target=cbind(r, targets), by=1)
  features = x[,2:ncol(x)]
  targets = y[,2:ncol(y)]
}

train = function(matrix[double] x_train,
                 matrix[double] y_train,
                 integer epochs,
                 integer batch_size,
                 integer max_sequence_length,
                 integer vocab_size,
                 integer val_size)
  return (List[unknown] biases, List[unknown] weights)
{
  /*
   * Trains our example model.
   *
   * Inputs:
   *  - x_train: training features, matrix of shape (N, max_sequence_length).
   *  - y_train: training labels, matrix of shape (N, 1).
   *  - epochs: number of epochs to train the model.
   *  - batch_size: batch size used in each iteration.
   *  - max_sequence_length: maximum sequence length of the data.
   *  - vocab_size: size of the considered vocabulary.
   *  - val_size: size of the validation set, which is split off from x_train and y_train.
   *
   * Outputs:
   *  - biases: list of trained biases.
   *  - weights: list of trained weights.
   */
  samples = nrow(x_train)
  print("Start Training")

  # validation split
  x_val = x_train[1:val_size]
  y_val = y_train[1:val_size]

  x_train = x_train[val_size+1:samples]
  y_train = y_train[val_size+1:samples]

  samples = nrow(x_train)
  features = ncol(x_train)
  output_size = 1

  # We use a trainable embedding; each row is the embedding of one word
  embedding_size = 64
  W_E = rand(rows=vocab_size, cols=embedding_size)

  # 1 lstm layer
  lstm_neurons = 150
  [W_0, b_0, out0, c0] = lstm::init(batch_size, embedding_size, lstm_neurons)

  # 2 attention layer: learnable query with half max_sequence_length
  [W_query, b_query] = affine::init(max_sequence_length*lstm_neurons, max_sequence_length*lstm_neurons/2, -1)
  [W_key, b_key] = affine::init(max_sequence_length*lstm_neurons, max_sequence_length*lstm_neurons, -1)
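  # query and key are learnable linear projections of the lstm output (the value):
  # the query is projected down to half the sequence length, so the attention output
  # per sample has max_sequence_length/2 * lstm_neurons values, which is also the
  # input size of the first dense layer below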


  # 3 dense layer -> (hidden_size)
  hidden_neurons = 128

  [W_1, b_1] = affine::init(max_sequence_length/2 * lstm_neurons, hidden_neurons, -1)

  # 4 dense layer -> (output_size)
  [W_2, b_2] = affine::init(hidden_neurons, output_size, -1)

  # 5 sigmoid layer: no weights

  # put weights & biases into list
  biases = list(b_0, b_1, b_2, b_query, b_key)
  weights = list(W_0, W_1, W_2, W_E, W_query, W_key)

  # optimizer init
  [mW_E, vW_E] = adam::init(W_E)

  [mW_query, vW_query] = adam::init(W_query)
  [mb_query, vb_query] = adam::init(b_query)

  [mW_key, vW_key] = adam::init(W_key)
  [mb_key, vb_key] = adam::init(b_key)

  [mW_0, vW_0] = adam::init(W_0)
  [mW_1, vW_1] = adam::init(W_1)
  [mW_2, vW_2] = adam::init(W_2)

  [mb_0, vb_0] = adam::init(b_0)
  [mb_1, vb_1] = adam::init(b_1)
  [mb_2, vb_2] = adam::init(b_2)

  # optimizer params
  lr = 0.001
  beta1 = 0.99
  beta2 = 0.999
  epsilon = 1e-8
  t = 0

  # allocate matrices for attention layer
  dQuery = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
  dValue = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
  dKey = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
  out2 = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)

  # training loop
  iters = ceil(samples/batch_size)
  for (ep in 1:epochs)
  {
    print("Start ep: " + ep)
    for (i in 1:iters)
    {
      print("Iteration: " + i)
      # 1 get batch data
      start = ((i-1) * batch_size) %% samples + 1
      end = min(samples, start + batch_size - 1)

      x_batch = x_train[start:end,]
      y_batch = y_train[start:end,]

      # 2 predict
      [y_hat, out5, out4, out3, out2, out1, query, key, emb, cache_out_out, cache_c_out, cache_ifog_out] =
        predict(x_batch, biases, weights, max_sequence_length, embedding_size, lstm_neurons, out2)

      # 3 backpropagation
      dout = log_loss::backward(y_hat, y_batch)
      dprobs = sigmoid::backward(dout, out5)
      [dout_2, dW_2, db_2] = affine::backward(dprobs, out4, W_2, b_2)
      drelu = relu::backward(dout_2, out3)
      [dout_1, dW_1, db_1] = affine::backward(drelu, out2, W_1, b_1)
      [dQuery, dValue, dKey] = attention::backward(dattention=dout_1,
                                                   query=query,
                                                   key=key,
                                                   value=out1,
                                                   D=max_sequence_length,
                                                   dquery=dQuery,
                                                   dvalue=dValue,
                                                   dkey=dKey)
      [dq, dW_query, db_query] = affine::backward(dQuery, out1, W_query, b_query)
      [dk, dW_key, db_key] = affine::backward(dKey, out1, W_key, b_key)
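      # the lstm output out1 receives gradient from three paths: the attention value
      # input (dValue) and the query and key projections (dq, dk); the initial
      # hidden/cell states and the upstream cell-state gradient are zero matrices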
      dc = matrix(0, rows=nrow(x_batch), cols=lstm_neurons)
      out0 = dc
      c0 = dc
      [dEmb, dW_0, db_0, dout0, dc0] = lstm::backward(dValue + dq + dk,
                                                      dc,
                                                      emb,
                                                      W_0,
                                                      b_0,
                                                      max_sequence_length,
                                                      embedding_size,
                                                      TRUE,
                                                      out0,
                                                      c0,
                                                      cache_out_out,
                                                      cache_c_out,
                                                      cache_ifog_out)

      # 4 update weights & biases
      t = ep * i - 1
      # embedding
      [W_E, mW_E, vW_E] = update_embeddings(x_batch, dEmb, W_E, mW_E, vW_E,
        lr, beta1, beta2, epsilon, t, max_sequence_length, embedding_size)

      # lstm
      [b_0, mb_0, vb_0] = adam::update(b_0, db_0, lr, beta1, beta2, epsilon, t, mb_0, vb_0)
      [W_0, mW_0, vW_0] = adam::update(W_0, dW_0, lr, beta1, beta2, epsilon, t, mW_0, vW_0)

      # affine query
      [W_query, mW_query, vW_query] = adam::update(W_query, dW_query, lr, beta1, beta2, epsilon, t, mW_query, vW_query)
      [b_query, mb_query, vb_query] = adam::update(b_query, db_query, lr, beta1, beta2, epsilon, t, mb_query, vb_query)

      # affine key
      [W_key, mW_key, vW_key] = adam::update(W_key, dW_key, lr, beta1, beta2, epsilon, t, mW_key, vW_key)
      [b_key, mb_key, vb_key] = adam::update(b_key, db_key, lr, beta1, beta2, epsilon, t, mb_key, vb_key)

      # hidden affine
      [b_1, mb_1, vb_1] = adam::update(b_1, db_1, lr, beta1, beta2, epsilon, t, mb_1, vb_1)
      [W_1, mW_1, vW_1] = adam::update(W_1, dW_1, lr, beta1, beta2, epsilon, t, mW_1, vW_1)

      # output affine
      [b_2, mb_2, vb_2] = adam::update(b_2, db_2, lr, beta1, beta2, epsilon, t, mb_2, vb_2)
      [W_2, mW_2, vW_2] = adam::update(W_2, dW_2, lr, beta1, beta2, epsilon, t, mW_2, vW_2)

      # put weights & biases into list
      biases = list(b_0, b_1, b_2, b_query, b_key)
      weights = list(W_0, W_1, W_2, W_E, W_query, W_key)
    }
    [loss, accuracy] = evaluate(x_train, y_train, biases, weights, lstm_neurons, max_sequence_length, embedding_size, out2)
    [val_loss, val_accuracy] = evaluate(x_val, y_val, biases, weights, lstm_neurons, max_sequence_length, embedding_size, out2)
    print("Epoch: " + ep + "; Train Loss: " + loss + "; Train Acc: " + accuracy + "; Val. Loss: " + val_loss + "; Val. Accuracy: " + val_accuracy)
  }
}

predict = function(matrix[double] x,
                   List[unknown] biases,
                   List[unknown] weights,
                   integer max_sequence_length,
                   integer embedding_size,
                   integer lstm_neurons,
                   matrix[double] out2)
  return (matrix[double] y_hat, matrix[double] out5, matrix[double] out4, matrix[double] out3,
          matrix[double] out2, matrix[double] out1, matrix[double] query, matrix[double] key,
          matrix[double] emb, matrix[double] cache_out_out, matrix[double] cache_c_out,
          matrix[double] cache_ifog_out)
{
  /*
   * Predicts an output y_hat for the given samples x.
   *
   * Inputs:
   *  - x: sample features of shape (batch_size, max_sequence_length).
   *  - biases: list of biases of length 5 (lstm, hidden affine, output affine, query affine, key affine).
   *  - weights: list of weights of length 6 (lstm, hidden affine, output affine, embedding, query affine, key affine).
   *  - max_sequence_length: number of words per sample.
   *  - embedding_size: size of the embedding vector for one word.
   *  - lstm_neurons: number of neurons in the lstm layer.
   *  - out2: matrix of shape (batch_size, max_sequence_length*embedding_size), used as attention buffer for the attention layer.
   *
   * Outputs:
   *  - y_hat: matrix of shape (batch_size, 1), prediction for log loss, output of the sigmoid layer.
   *  - out5: output of the final affine layer, shape (batch_size, 1).
   *  - out4: output of the relu layer.
   *  - out3: output of the hidden affine layer.
   *  - out2: output of the attention layer.
   *  - out1: output states of the lstm layer, of shape (batch_size, max_sequence_length * lstm_neurons).
   *  - query: transformation of out1 by an affine layer, of shape (batch_size, max_sequence_length * lstm_neurons/2).
   *  - key: transformation of out1 by an affine layer, of shape (batch_size, max_sequence_length * lstm_neurons).
   *  - emb: embedded input of shape (batch_size, max_sequence_length * embedding_size).
   *  - cache_out_out: cache_out output from the lstm layer.
   *  - cache_c_out: cache_c output from the lstm layer.
   *  - cache_ifog_out: cache_ifog output from the lstm layer.
   */

  # unpack weights & biases
  W_0 = as.matrix(weights[1])
  W_1 = as.matrix(weights[2])
  W_2 = as.matrix(weights[3])
  W_E = as.matrix(weights[4])
  W_query = as.matrix(weights[5])
  W_key = as.matrix(weights[6])

  b_0 = as.matrix(biases[1])
  b_1 = as.matrix(biases[2])
  b_2 = as.matrix(biases[3])
  b_query = as.matrix(biases[4])
  b_key = as.matrix(biases[5])

  # fetch embedding
  emb = fetch_embeddings(x, W_E, max_sequence_length, embedding_size)
  # put input through layers
  batch_size = nrow(x)
  out0 = matrix(0, batch_size, lstm_neurons)
  c0 = out0
  [out1, c_out, cache_out_out, cache_c_out, cache_ifog_out] =
    lstm::forward(emb, W_0, b_0, max_sequence_length, embedding_size, TRUE, out0, c0)
  query = affine::forward(out1, W_query, b_query)
  key = affine::forward(out1, W_key, b_key)
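  # single-headed self attention: the lstm states out1 act as the values and are
  # attended over using their own learned query and key projections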
  out2 = attention::forward(query=query, key=key, value=out1, D=max_sequence_length, attention=out2)
  out3 = affine::forward(out2, W_1, b_1)
  out4 = relu::forward(out3)
  out5 = affine::forward(out4, W_2, b_2)
  y_hat = sigmoid::forward(out5)
}

fetch_embeddings = function(matrix[double] indexes, matrix[double] W_E,
                            integer max_sequence_length, integer embedding_size)
  return (matrix[double] emb)
{
  /*
   * Fetches embeddings for given tokens (indexes).
   *
   * Inputs:
   *  - indexes: tokens for fetching embeddings, shape (batch_size, max_sequence_length).
   *  - W_E: trainable embedding matrix of shape (vocab_size, embedding_size).
   *  - max_sequence_length: number of words per sample.
   *  - embedding_size: size of an embedding vector for one word.
   *
   * Outputs:
   *  - emb: embedded version of indexes, of shape (batch_size, max_sequence_length * embedding_size).
   */

  emb = matrix(0, rows=nrow(indexes), cols=embedding_size*max_sequence_length)
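  # look up one embedding row per token and concatenate the rows of a sequence
  # into a single row of emb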
  for (i in 1:nrow(indexes))
  {
    for (j in 1:max_sequence_length)
    {
      index = as.integer(as.scalar(indexes[i,j]))
      emb[i,(j-1)*embedding_size+1:j*embedding_size] = W_E[index]
    }
  }
}

update_embeddings = function(matrix[double] indexes, matrix[double] dEmb, matrix[double] W_E,
                             matrix[double] mW_E, matrix[double] vW_E, double lr, double beta1, double beta2,
                             double epsilon, integer t, integer max_sequence_length, integer embedding_size)
  return (matrix[double] W_E, matrix[double] mW_E, matrix[double] vW_E)
{
  /*
   * Updates the embedding matrix for given tokens (indexes).
   *
   * Inputs:
   *  - indexes: tokens for fetching embeddings, shape (batch_size, max_sequence_length).
   *  - dEmb: gradient from upstream, of shape (batch_size, max_sequence_length * embedding_size).
   *  - W_E: trainable embedding matrix of shape (vocab_size, embedding_size).
   *  - mW_E: m variable (1st moment estimate) for the adam optimizer.
   *  - vW_E: v variable (2nd moment estimate) for the adam optimizer.
   *  - lr: learning rate.
   *  - beta1: exponential decay rate for the 1st moment estimate of the adam optimizer.
   *  - beta2: exponential decay rate for the 2nd moment estimate of the adam optimizer.
   *  - epsilon: constant for numerical stability of the adam optimizer.
   *  - t: timestep for the adam optimizer.
   *  - max_sequence_length: number of words per sample.
   *  - embedding_size: size of an embedding vector for one word.
   *
   * Outputs:
   *  - W_E: updated trainable embedding matrix of shape (vocab_size, embedding_size).
   *  - mW_E: updated m variable (1st moment estimate) for the adam optimizer.
   *  - vW_E: updated v variable (2nd moment estimate) for the adam optimizer.
   */
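  # apply one adam update per token occurring in the batch, touching only the
  # corresponding rows of the embedding matrix and its optimizer state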
  for (i in 1:nrow(indexes))
  {
    for (j in 1:max_sequence_length)
    {
      index = as.integer(as.scalar(indexes[i,j]))
      [W_Ei, mW_Ei, vW_Ei] = adam::update(
        W_E[index], dEmb[i,(j-1)*embedding_size+1:j*embedding_size], lr, beta1, beta2,
        epsilon, t, mW_E[index], vW_E[index])
      W_E[index] = W_Ei
      mW_E[index] = mW_Ei
      vW_E[index] = vW_Ei
    }
  }
}

evaluate = function(matrix[double] x, matrix[double] y,
                    list[unknown] biases, list[unknown] weights, integer lstm_neurons,
                    integer max_sequence_length, integer embedding_size, matrix[double] out2)
  return (double loss, double accuracy)
{
  /*
   * Evaluates the fit by computing log loss and accuracy via the predict function.
   *
   * Inputs:
   *  - x: feature matrix of shape (batch_size, max_sequence_length).
   *  - y: target matrix for x of shape (batch_size, 1).
   *  - biases: list of biases.
   *  - weights: list of weights.
   *  - lstm_neurons: number of neurons used in the lstm layer.
   *  - max_sequence_length: number of words per sample.
   *  - embedding_size: size of an embedding vector for one word.
   *  - out2: matrix of shape (batch_size, max_sequence_length*embedding_size), used as attention buffer for the attention layer.
   *
   * Outputs:
   *  - loss: log loss of the prediction.
   *  - accuracy: fraction of correct classifications.
   */
  batch_size = nrow(x)
  [y_hat, out5, out4, out3, out2, out1, query, key, emb, cache_out_out, cache_c_out, cache_ifog_out] =
    predict(x, biases, weights, max_sequence_length, embedding_size, lstm_neurons, out2)
  loss = log_loss::forward(y_hat, y)

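  # threshold the sigmoid outputs at 0.5 to obtain binary class predictions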
  z = y_hat >= 0.5
  accuracy = 1 - sum(abs(z - y)) / batch_size
}