# Source blob: 22afb39a7455bdab6a978831d4724e184f033483
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
forward = function(matrix[double] indices, matrix[double] embedding_dict)
    return (matrix[double] embeddings) {
  /*
   * Forward pass of an embedding layer. Looks up one embedding vector per
   * index and stacks the results row-wise into the output matrix.
   *
   * The lookup is expressed as a matrix product: a one-hot selection matrix
   * of shape n x v (a single '1' per row, in the column given by the index)
   * is multiplied with the embedding dictionary.
   *
   * Inputs:
   *  - indices: Indices referring to embedding vectors of embedding dictionary
   *      of shape n x 1 with each value in {1, ..., v}.
   *  - embedding_dict: Dictionary of embedding vectors of shape v x d.
   *
   * Outputs:
   *  - embeddings: Embedding matrix of shape n x d where row i is equal to
   *      embedding_dict[indices[i]].
   */
  n = nrow(indices)
  v = nrow(embedding_dict)

  # Build the one-hot selection matrix in a single vectorized call instead of
  # n scalar left-indexing writes: cell (i, indices[i]) is set to 1 for every
  # row i. The explicit output dimensions keep the shape at n x v even when
  # the largest dictionary index does not occur in 'indices'.
  # NOTE(review): with explicit output dimensions, table() is expected to
  # ignore indices greater than v instead of raising an out-of-bounds error;
  # confirm that callers only pass indices in {1, ..., v}.
  selection = table(seq(1, n), indices, n, v)

  embeddings = selection %*% embedding_dict
}
backward = function(matrix[double] dout, matrix[double] indices, int v,
    int padding_idx = -1)
    return (matrix[double] dembedding_dict) {
  /*
   * Backward pass of an embedding layer. Computes the gradients of the
   * embedding dictionary by accumulating, for every dictionary entry, the
   * upstream gradient rows of all positions that referenced that entry.
   *
   * Inputs:
   *  - dout: Gradient of the output, with one row per entry of indices.
   *  - indices: Indices referring to embedding vectors of embedding dictionary
   *      of shape n x 1 with each value in {1, ..., v}.
   *  - v: Embedding dictionary size.
   *  - padding_idx: Index of embedding vector of embedding dictionary which
   *      should not be updated (i.e. gradients are 0). Use -1 if
   *      there is no padding vector.
   *
   * Outputs:
   *  - dembedding_dict: Gradients of the dictionary of embedding vectors of
   *      shape v x d, where d = ncol(dout).
   */
  n = nrow(indices)

  # Build the one-hot selection matrix (n x v, a single '1' per row) in one
  # vectorized call instead of n scalar left-indexing writes.
  # NOTE(review): with explicit output dimensions, table() is expected to
  # ignore indices greater than v instead of raising an out-of-bounds error;
  # confirm that callers only pass indices in {1, ..., v}.
  selection = table(seq(1, n), indices, n, v)

  # Transposed selection sums the rows of dout per dictionary index, so
  # indices that occur several times accumulate their gradients.
  dembedding_dict = t(selection) %*% dout

  if (padding_idx != -1) {
    # Zero out the whole gradient row of the padding vector; the row range is
    # addressed explicitly so all d columns are clearly covered.
    dembedding_dict[padding_idx, ] = matrix(0, rows=1, cols=ncol(dout))
  }
}
init = function(int v, int d, int seed = -1)
    return (matrix[double] embedding_dict) {
  /*
   * Creates a fresh embedding dictionary whose entries are drawn from a
   * normal distribution (rand with pdf="normal").
   *
   * Inputs:
   *  - v: Number of embedding vectors, i.e. rows of the dictionary.
   *  - d: Dimension of each embedding vector, i.e. columns of the dictionary.
   *  - seed: Seed passed through to the random generator (default -1).
   *
   * Output:
   *  - embedding_dict: Randomly initialized dictionary of shape v x d.
   */
  embedding_dict = rand(
      rows=v,
      cols=d,
      pdf="normal",
      seed=seed)
}