blob: 74edf71172c7dd81d2bc514ee27e6ced91900315 [file]
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns mnist-mlp
(:require [clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray])
(:gen-class))
;; Directory holding the MNIST idx files. The trailing slash matters because
;; file paths are built by plain `str` concatenation.
(def data-dir "data/")
;; Number of examples per batch requested from the MNIST iterators.
(def batch-size 10)
;; Number of epochs used by every training run in this example.
(def num-epoch 5)

;; Run the download script only when the training images file is missing.
(let [train-images (io/file (str data-dir "train-images-idx3-ubyte"))]
  (when-not (.exists train-images)
    (sh "../../scripts/get_mnist_data.sh")))

;; Make sure the "model/" directory exists so checkpoints can be saved and
;; loaded. `make-parents` only creates the parent directories; "dummy.txt"
;; itself is never written.
(io/make-parents "model/dummy.txt")
;;; Load the MNIST datasets.
;;; `defonce` keeps the iterators from being re-created when the namespace is
;;; reloaded.

;; Iterator over the training images/labels, shuffled with a fixed seed.
(defonce train-data
  (let [in-data-dir (fn [file-name] (str data-dir file-name))]
    (mx-io/mnist-iter
     {:image       (in-data-dir "train-images-idx3-ubyte")
      :label       (in-data-dir "train-labels-idx1-ubyte")
      :label-name  "softmax_label"
      :input-shape [784]
      :batch-size  batch-size
      :shuffle     true
      :flat        true
      :silent      false
      :seed        10})))

;; Iterator over the held-out t10k images/labels (no shuffling requested).
(defonce test-data
  (let [in-data-dir (fn [file-name] (str data-dir file-name))]
    (mx-io/mnist-iter
     {:image       (in-data-dir "t10k-images-idx3-ubyte")
      :label       (in-data-dir "t10k-labels-idx1-ubyte")
      :input-shape [784]
      :batch-size  batch-size
      :flat        true
      :silent      false})))
(defn get-symbol
  "Builds the MLP symbol used by every example in this file:
  data -> fc1 (128) -> relu -> fc2 (64) -> relu -> fc3 (10) -> softmax output."
  []
  (let [input (sym/variable "data")
        fc1   (sym/fully-connected "fc1" {:data input :num-hidden 128})
        act1  (sym/activation "relu1" {:data fc1 :act-type "relu"})
        fc2   (sym/fully-connected "fc2" {:data act1 :num-hidden 64})
        act2  (sym/activation "relu2" {:data fc2 :act-type "relu"})
        fc3   (sym/fully-connected "fc3" {:data act2 :num-hidden 10})]
    (sym/softmax-output "softmax" {:data fc3})))
(defn- print-header
  "Prints `message` to *out* as a banner: a blank line, a rule of `=`
  characters, the message prefixed with one space, another rule, and a
  trailing blank line. Returns nil."
  [message]
  (let [rule "================="]
    (doseq [line ["" rule (str " " message) rule ""]]
      (println line))))
(defn run-intermediate-level-api
  "Trains the MLP with the step-by-step module API (bind, init-params,
  init-optimizer, then forward / update-metric / backward / update per batch).

  Keyword arguments:
    :devs             - vector of contexts to run on
    :load-model-epoch - optional; when truthy, the module is loaded from the
                        checkpoint with prefix \"model/mnist-mlp\" at that epoch
                        instead of being built from `get-symbol`

  After each epoch the accuracy metric is printed and reset, and a checkpoint
  (including optimizer states) is saved under the same prefix."
  [& {:keys [devs load-model-epoch]}]
  (print-header
   (cond-> "Running Intermediate Level API"
     load-model-epoch (str " and loading from previous epoch " load-model-epoch)))
  (let [checkpoint-prefix "model/mnist-mlp"
        model (if-not load-model-epoch
                (m/module (get-symbol) {:contexts devs})
                (do
                  (println "Loading from checkpoint of epoch " load-model-epoch)
                  (m/load-checkpoint {:contexts devs
                                      :prefix checkpoint-prefix
                                      :epoch load-model-epoch})))
        accuracy (eval-metric/accuracy)
        ;; One training step: forward pass, metric update against the batch
        ;; labels, backward pass, parameter update.
        train-step (fn [batch]
                     (-> model
                         (m/forward batch)
                         (m/update-metric accuracy (mx-io/batch-label batch))
                         (m/backward)
                         (m/update)))]
    (-> model
        (m/bind {:data-shapes (mx-io/provide-data train-data)
                 :label-shapes (mx-io/provide-label train-data)})
        (m/init-params)
        (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01
                                                      :momentum 0.9})}))
    (doseq [epoch-num (range num-epoch)]
      (println "starting epoch " epoch-num)
      (mx-io/do-batches train-data train-step)
      (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset accuracy))
      (m/save-checkpoint model {:prefix checkpoint-prefix
                                :epoch epoch-num
                                :save-opt-states true}))))
(defn run-high-level-api
  "Trains the MLP on `devs` with the single-call `m/fit` API, runs `m/predict`
  over the test data, then prints the accuracy score returned by `m/score`."
  [devs]
  (print-header "Running High Level API")
  (let [model (m/module (get-symbol) {:contexts devs})
        eval-opts {:eval-data test-data}]
    ;; Training needs just this one call.
    (m/fit model {:train-data train-data :eval-data test-data :num-epoch num-epoch})
    ;; High level predict: the result is discarded here, the call only
    ;; demonstrates the API.
    (m/predict model eval-opts)
    ;; High level score (returns the eval values).
    (println "High level predict score is "
             (m/score model (assoc eval-opts :eval-metric (eval-metric/accuracy))))))
(defn run-predication-and-calc-accuracy-manually
  "Trains the MLP on `devs` with `m/fit`, gathers the predictions for every
  test batch at once with `m/predict-every-batch`, then cycles through the
  test batches and accumulates the accuracy statistics by hand.

  Prints the accumulated stats map ({:acc-sum :acc-cnt :index}) and the
  resulting accuracy (acc-sum / acc-cnt as a double)."
  [devs]
  (print-header "Running Predicting and Calcing the Accuracy Manually")
  (let [mod (m/module (get-symbol) {:contexts devs})]
    ;; Training needs just this one call.
    (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch})
    (let [preds (m/predict-every-batch mod {:eval-data test-data})
          stats (mx-io/reduce-batches
                 test-data
                 (fn [r b]
                   (let [;; Predictions for this batch are looked up by the
                         ;; running batch index kept in the accumulator.
                         pred-label (->> (ndarray/argmax-channel (first (get preds (:index r))))
                                         (ndarray/->vec)
                                         (mapv int))
                         label (->> (mx-io/batch-label b)
                                    (first)
                                    (ndarray/->vec)
                                    (mapv int))
                         ;; Number of positions where prediction and label agree.
                         ;; Computed once and reused below.
                         acc-sum (apply + (map (fn [pl l] (if (= pl l) 1 0))
                                               pred-label label))]
                     (-> r
                         (update :index inc)
                         (update :acc-cnt + (count pred-label))
                         (update :acc-sum + acc-sum))))
                 {:acc-sum 0 :acc-cnt 0 :index 0})]
      (println "Stats: " stats)
      (println "Accuracy: " (/ (:acc-sum stats)
                               (* 1.0 (:acc-cnt stats)))))))
(defn run-prediction-iterator-api
  "Trains the MLP on `devs` with `m/fit`, then cycles through all the test
  batches, predicting each one with `m/predict-batch` and printing that
  batch's accuracy. The reduce accumulator is the running batch number."
  [devs]
  (print-header "Running the Prediction Iterator API and Calcing the Accuracy Manually")
  (let [model (m/module (get-symbol) {:contexts devs})
        to-int-vec (fn [nd] (mapv int (ndarray/->vec nd)))
        hit (fn [predicted actual] (if (= predicted actual) 1 0))]
    ;; Training needs just this one call.
    (m/fit model {:train-data train-data :eval-data test-data :num-epoch num-epoch})
    (mx-io/reduce-batches
     test-data
     (fn [batch-num batch]
       (let [outputs (m/predict-batch model batch)
             predicted (to-int-vec (ndarray/argmax-channel (first outputs)))
             actual (to-int-vec (first (mx-io/batch-label batch)))
             correct (apply + (mapv hit predicted actual))
             acc (/ correct (* 1.0 (count predicted)))]
         (println "Batch " batch-num " acc: " acc)
         (inc batch-num))))))
(defn run-all
  "Runs every example in this file on `devs`, in order: the intermediate level
  API from scratch, the intermediate level API resumed from the last saved
  epoch, the high level API, the prediction iterator API, and the manual
  accuracy calculation."
  [devs]
  (run-intermediate-level-api :devs devs)
  ;; Resume from the checkpoint written by the final epoch of the run above.
  (run-intermediate-level-api :devs devs :load-model-epoch (dec num-epoch))
  (doseq [example [run-high-level-api
                   run-prediction-iterator-api
                   run-predication-and-calc-accuracy-manually]]
    (example devs)))
(defn -main
  "Command-line entry point.

  args: [dev dev-num]
    dev     - the string \":gpu\" selects GPU contexts; anything else selects CPU
    dev-num - number of devices as a string; defaults to \"1\" when absent

  Builds one context per device id in [0, dev-num) and runs all examples."
  [& args]
  (let [[dev dev-num] args
        device-count (Integer/parseInt (or dev-num "1"))
        make-context (if (= dev ":gpu")
                       #(context/gpu %)
                       #(context/cpu %))
        devs (mapv make-context (range device-count))]
    (println "Running Module MNIST example")
    (println "Running with context devices of" devs)
    (run-all devs)))
(comment
;;; REPL scratchpad: nothing in this form is evaluated when the namespace is
;;; loaded. Evaluate the individual forms by hand. The `;;=>` lines record
;;; example output from a run.

;;; run all the example functions
(run-all [(context/cpu)])
;;; train from scratch for `num-epoch` epochs, saving a checkpoint per epoch
(run-intermediate-level-api :devs [(context/cpu)])
;;=> starting epoch 0
;;=> result for epoch 0 is [accuracy 0.8531333]
;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0000.params
;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0000.states
;;=> ....
;;=> starting epoch 4
;;=> result for epoch 4 is [accuracy 0.91875]
;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0004.params
;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0004.states
;; load from the last saved checkpoint and train again
(run-intermediate-level-api :devs [(context/cpu)] :load-model-epoch (dec num-epoch))
;;=> Loading from checkpoint of epoch 4
;;=> starting epoch 0
;;=> result for epoch 0 is [accuracy 0.96258336]
;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0000.params
;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0000.states
;;=> ...
;;=> starting epoch 4
;;=> result for epoch 4 is [accuracy 0.9819833]
;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0004.params
;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0004.states
;; train, predict and score with the high level API
(run-high-level-api [(context/cpu)])
;;=> ["accuracy" 0.9454]
;; predict batch by batch and print each batch's accuracy
(run-prediction-iterator-api [(context/cpu)])
;;=> Batch 0 acc: 1.0
;;=> Batch 1 acc: 0.9
;;=> Batch 2 acc: 1.0
;;=> ...
;;=> Batch 999 acc: 1.0
;; predict every batch at once, then compute the overall accuracy by hand
(run-predication-and-calc-accuracy-manually [(context/cpu)])
;;=> Stats: {:acc-sum 9494, :acc-cnt 10000, :index 1000}
;;=> Accuracy: 0.9494
)