blob: fb3a8f352deee8db04eb8d04908bdc60c17de17d [file] [log] [blame]
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns rnn.lstm
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.executor :as executor]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.symbol :as sym]))
(defn lstm-param
  "Bundles the four parameter values of one LSTM layer into a map with the
  keys :i2h-weight, :i2h-bias, :h2h-weight and :h2h-bias."
  [i2h-weight i2h-bias
   h2h-weight h2h-bias]
  (zipmap [:i2h-weight :i2h-bias :h2h-weight :h2h-bias]
          [i2h-weight i2h-bias h2h-weight h2h-bias]))
(defn lstm-state
  "Pairs a cell value `c` and a hidden value `h` into a map with the keys
  :c and :h. The entries are kept in that order (:c first, then :h)."
  [c h]
  (array-map :c c :h h))
(defn lstm
  "Builds the symbols of one LSTM cell for time step `seq-idx` of layer
  `layer-idx` and returns the resulting state as an `lstm-state` map.

  - `num-hidden`  hidden size; each of the two projections emits
                  (* num-hidden 4) units, later split into four slices
  - `in-data`     input symbol of this step
  - `prev-state`  map holding the previous :c and :h symbols
  - `param`       map with :i2h-weight, :i2h-bias, :h2h-weight, :h2h-bias
  - `dropout`     when positive, dropout with this probability is applied to
                  `in-data` before the input projection"
  [num-hidden in-data prev-state param seq-idx layer-idx dropout]
  (let [node-name   (fn [suffix] (str "t" seq-idx "_l" layer-idx suffix))
        gate-units  (* num-hidden 4)
        x           (if (pos? dropout)
                      (sym/dropout {:data in-data :p dropout})
                      in-data)
        ;; input-to-hidden and hidden-to-hidden projections
        input-proj  (sym/fully-connected (node-name "_i2h")
                                         {:data x
                                          :weight (:i2h-weight param)
                                          :bias (:i2h-bias param)
                                          :num-hidden gate-units})
        hidden-proj (sym/fully-connected (node-name "_h2h")
                                         {:data (:h prev-state)
                                          :weight (:h2h-weight param)
                                          :bias (:h2h-bias param)
                                          :num-hidden gate-units})
        summed      (sym/+ input-proj hidden-proj)
        ;; split the summed projection into the four gate pre-activations
        pre-acts    (sym/slice-channel (node-name "_slice")
                                       {:data summed :num-outputs 4})
        gate        (fn [idx act-type]
                      (sym/activation {:data (sym/get pre-acts idx)
                                       :act-type act-type}))
        input-g     (gate 0 "sigmoid")
        candidate   (gate 1 "tanh")
        forget-g    (gate 2 "sigmoid")
        output-g    (gate 3 "sigmoid")
        ;; new cell = forget * previous cell + input * candidate
        retained    (sym/* forget-g (:c prev-state))
        written     (sym/* input-g candidate)
        cell        (sym/+ retained written)
        ;; new hidden = output gate * tanh(new cell)
        squashed    (sym/activation {:data cell :act-type "tanh"})
        hidden      (sym/* output-g squashed)]
    (lstm-state cell hidden)))
(defn lstm-unroll
  "Builds the training symbol of a stacked LSTM unrolled over `seq-len` time
  steps and returns the final softmax-output symbol.

  - `num-lstm-layer`  number of stacked LSTM layers
  - `seq-len`         number of time steps to unroll
  - `input-size`      :input-dim of the embedding
  - `num-hidden`      hidden size handed to every `lstm` cell
  - `num-embed`       :output-dim of the embedding
  - `num-label`       :num-hidden of the final \"pred\" fully-connected layer
  - `dropout`         dropout probability; applied to the input of every layer
                      except layer 0 and, when positive, to each step's
                      top-layer output

  The per-layer parameters are variables named l<i>_i2h_weight, l<i>_i2h_bias,
  l<i>_h2h_weight and l<i>_h2h_bias; the initial states are variables named
  l<i>_init_c_beta and l<i>_init_h_beta."
  [num-lstm-layer seq-len input-size num-hidden num-embed num-label dropout]
  (let [embed-weight (sym/variable "embed_weight")
        cls-weight (sym/variable "cls_weight")
        cls-bias (sym/variable "cls_bias")
        param-cells (mapv (fn [i]
                            (lstm-param (sym/variable (str "l" i "_i2h_weight"))
                                        (sym/variable (str "l" i "_i2h_bias"))
                                        (sym/variable (str "l" i "_h2h_weight"))
                                        (sym/variable (str "l" i "_h2h_bias"))))
                          (range 0 num-lstm-layer))
        init-states (mapv (fn [i]
                            (lstm-state (sym/variable (str "l" i "_init_c_beta"))
                                        (sym/variable (str "l" i "_init_h_beta"))))
                          (range 0 num-lstm-layer))
        ;; embedding layer
        data (sym/variable "data")
        label (sym/variable "softmax_label")
        embed (sym/embedding "embed" {:data data :input-dim input-size :weight embed-weight
                                      :output-dim num-embed})
        wordvec (sym/slice-channel {:data embed :num-outputs seq-len :squeeze-axis 1})
        ;; Runs every layer for one time step, feeding each layer the state it
        ;; produced at the previous step; returns the vector of new states.
        step (fn [prev-states seq-idx]
               (loop [i 0
                      hidden (sym/get wordvec seq-idx)
                      next-states []]
                 (if (>= i num-lstm-layer)
                   next-states
                   (let [dp-ratio (if (zero? i) 0 dropout) ; no dropout on the first layer's input
                         next-state (lstm num-hidden
                                          hidden
                                          (get prev-states i)
                                          (get param-cells i)
                                          seq-idx
                                          i
                                          dp-ratio)]
                     (recur (inc i)
                            (:h next-state)
                            (conj next-states next-state))))))
        ;; stack lstm
        ;; The states are threaded from one time step to the next. Reusing the
        ;; initial states at every step would make each step independent of
        ;; everything that came before it, i.e. the network would not be
        ;; recurrent at all.
        hidden-all (loop [seq-idx 0
                          prev-states init-states
                          outputs []]
                     (if (>= seq-idx seq-len)
                       outputs
                       (let [next-states (step prev-states seq-idx)
                             hidden (:h (last next-states))
                             output (if (pos? dropout)
                                      (sym/dropout {:data hidden :p dropout})
                                      hidden)]
                         (recur (inc seq-idx)
                                next-states
                                (conj outputs output)))))
        hidden-concat (sym/concat "concat" nil hidden-all {:dim 0})
        pred (sym/fully-connected "pred" {:data hidden-concat :num-hidden num-label
                                          :weight cls-weight :bias cls-bias})
        ;; the per-step outputs were concatenated along dim 0, so the label is
        ;; transposed before being reshaped to line up with them
        label (sym/transpose {:data label})
        label (sym/reshape {:data label :target-shape [0]})]
    (sym/softmax-output "softmax" {:data pred :label label})))
(defn lstm-inference-symbol
  "Builds the single-step inference symbol of the stacked LSTM.

  Uses the same variable names as the unrolled training network
  (embed_weight, cls_weight, cls_bias, l<i>_i2h_weight, l<i>_i2h_bias,
  l<i>_h2h_weight, l<i>_h2h_bias, l<i>_init_c_beta, l<i>_init_h_beta) and
  always builds time step 0.

  Returns a grouped symbol whose outputs are, in order: the softmax output,
  followed for every layer by that layer's new cell state and then its new
  hidden state."
  [num-lstm-layer input-size num-hidden
   num-embed num-label dropout]
  (let [seq-idx 0
        embed-weight (sym/variable "embed_weight")
        cls-weight (sym/variable "cls_weight")
        cls-bias (sym/variable "cls_bias")
        param-cells (mapv (fn [i]
                            (lstm-param (sym/variable (str "l" i "_i2h_weight"))
                                        (sym/variable (str "l" i "_i2h_bias"))
                                        (sym/variable (str "l" i "_h2h_weight"))
                                        (sym/variable (str "l" i "_h2h_bias"))))
                          (range 0 num-lstm-layer))
        last-states (mapv (fn [i]
                            (lstm-state (sym/variable (str "l" i "_init_c_beta"))
                                        (sym/variable (str "l" i "_init_h_beta"))))
                          (range 0 num-lstm-layer))
        data (sym/variable "data")
        ;; stack lstm
        next-states (loop [i 0
                           hidden (sym/embedding "embed" {:data data
                                                          :input-dim input-size
                                                          :weight embed-weight
                                                          :output-dim num-embed})
                           next-states []]
                      (if (= i num-lstm-layer)
                        next-states
                        (let [dp-ratio (if (zero? i) 0 dropout) ; no dropout on the first layer's input
                              next-state (lstm num-hidden
                                               hidden
                                               (get last-states i)
                                               (get param-cells i)
                                               seq-idx
                                               i
                                               dp-ratio)]
                          (recur (inc i)
                                 (:h next-state)
                                 (conj next-states next-state)))))
        ;; decoder
        hidden (:h (last next-states))
        hidden (if (pos? dropout) (sym/dropout {:data hidden :p dropout}) hidden)
        fc (sym/fully-connected "pred" {:data hidden :num-hidden num-label
                                        :weight cls-weight :bias cls-bias})
        sm (sym/softmax-output "softmax" {:data fc})
        ;; Pick :c then :h explicitly so the output order does not hinge on
        ;; the iteration order of the state maps.
        outs (into [sm] (mapcat (juxt :c :h)) next-states)]
    (sym/group outs)))
(defn lstm-inference-model
  "Binds the single-step inference symbol to an executor and loads trained
  parameters into it.

  Takes a map with :num-lstm-layer, :input-size, :num-hidden, :num-embed,
  :num-label, :arg-params (name -> ndarray), :ctx (defaults to (context/cpu))
  and :dropout (defaults to 0.0). The executor is bound with a batch size
  of 1.

  Returns {:exec <executor> :states-map <map>}, where the states map pairs
  the names l<i>_init_c_beta / l<i>_init_h_beta with the executor outputs
  that follow the first one, in that order."
  [{:keys [num-lstm-layer input-size num-hidden
           num-embed num-label arg-params
           ctx dropout]
    :or {ctx (context/cpu)
         dropout 0.0}}]
  (let [net          (lstm-inference-symbol num-lstm-layer
                                            input-size
                                            num-hidden
                                            num-embed
                                            num-label
                                            dropout)
        batch-size   1
        layer-ids    (range num-lstm-layer)
        state-shape  [batch-size num-hidden]
        state-name   (fn [layer kind] (str "l" layer "_init_" kind "_beta"))
        init-c       (into {} (for [l layer-ids] [(state-name l "c") state-shape]))
        init-h       (into {} (for [l layer-ids] [(state-name l "h") state-shape]))
        input-shape  (merge init-c init-h {"data" [batch-size]})
        exec         (sym/simple-bind net ctx input-shape)
        exec-arg-map (executor/arg-map exec)
        state-names  (mapcat (fn [l] [(state-name l "c") (state-name l "h")])
                             layer-ids)
        ;; the first output is the softmax; the remaining ones are the states
        states-map   (zipmap state-names (rest (executor/outputs exec)))]
    ;; Copy every supplied parameter the executor knows about, skipping the
    ;; inputs bound above and the training-only label.
    (doseq [[k v] arg-params
            :let [target-v (get exec-arg-map k)]
            :when (and target-v
                       (not (get input-shape k))
                       (not= "softmax_label" k))]
      (ndarray/copy-to v target-v))
    {:exec exec
     :states-map states-map}))
(defn forward
  "Runs one forward step of the inference model and returns the executor's
  first output.

  - `lstm-model`  map with :exec (the executor) and :states-map (executor
                  argument name -> ndarray)
  - `input-data`  ndarray copied into the executor argument named \"data\"
  - `new-seq`     when truthy, every argument named in :states-map is set
                  to 0 before the step

  After the forward pass each value of :states-map is copied into the
  executor argument of the same name."
  [{:keys [exec states-map] :as lstm-model} input-data new-seq]
  (let [arg-named (fn [arg-name] (get (executor/arg-map exec) arg-name))]
    ;; start of a new sequence: reset the state arguments
    (when new-seq
      (run! (fn [state-name] (ndarray/set (arg-named state-name) 0))
            (keys states-map)))
    (ndarray/copy-to input-data (arg-named "data"))
    (executor/forward exec)
    ;; write the recorded state arrays back into the state arguments
    (run! (fn [[state-name state-array]]
            (ndarray/copy-to state-array (arg-named state-name)))
          states-map)
    (first (executor/outputs exec))))