/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package opennlp.tools.dl;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;

import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * A basic char/word-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's
 * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio
 * and Andrej Karpathy's notes on RNNs.
 * <p>
 * The first hidden layer is fed with the 1-of-k encoded input, the second hidden layer is fed with the first
 * layer's state, and the output scores are computed from the second layer's state. Parameters are trained with
 * either Adagrad or RMSprop, depending on the {@code rmsProp} constructor flag.
 *
 * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
 * @see <a href="https://arxiv.org/abs/1312.6026">How to Construct Deep Recurrent Neural Networks</a>
 */
public class StackedRNN extends RNN {

  // small constant added to the denominator of the adaptive updates to avoid division by zero
  private static final double EPS = 1e-4;
  // decay rate of the RMSprop running average of squared gradients
  private static final double DECAY = 0.9;

  // model parameters
  private final INDArray wxh; // input to hidden
  private final INDArray whh; // hidden to hidden
  private final INDArray whh2; // hidden2 to hidden2
  private final INDArray wh2y; // hidden2 to output
  private final INDArray wxh2; // hidden to hidden2
  private final INDArray bh; // hidden bias
  private final INDArray bh2; // hidden2 bias
  private final INDArray by; // output bias

  private final boolean rmsProp; // true: RMSprop updates, false: Adagrad updates

  private INDArray hPrev = null; // memory state of the first hidden layer
  private INDArray hPrev2 = null; // memory state of the second hidden layer

  /**
   * Creates a stacked RNN that is updated on every iteration ({@code batch = 1}), with {@code useChars = true}
   * and Adagrad updates.
   *
   * @param learningRate    step size used by the parameter updates
   * @param seqLength       number of time steps processed per training iteration
   * @param hiddenLayerSize size of each of the two hidden layers
   * @param epochs          number of passes over the data performed by {@link #learn()}
   * @param text            training text, handed to the {@link RNN} constructor
   */
  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
    this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true, false);
  }

  /**
   * Creates a stacked RNN with randomly initialized weights and zero biases.
   *
   * @param learningRate    step size used by the parameter updates
   * @param seqLength       number of time steps processed per training iteration
   * @param hiddenLayerSize size of each of the two hidden layers
   * @param epochs          number of passes over the data performed by {@link #learn()}
   * @param text            training text, handed to the {@link RNN} constructor
   * @param batch           parameters are updated only on iterations whose counter is a multiple of this value
   * @param useChars        handed to the {@link RNN} constructor
   * @param rmsProp         {@code true} to update parameters with RMSprop, {@code false} to use Adagrad
   */
  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars, boolean rmsProp) {
    super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars);

    this.rmsProp = rmsProp;
    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    wxh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
    bh = Nd4j.zeros(hiddenLayerSize, 1);
    bh2 = Nd4j.zeros(hiddenLayerSize, 1);
    by = Nd4j.zeros(vocabSize, 1);
  }

  /**
   * Trains the model by sweeping over the data in windows of {@code seqLength} steps until {@code epochs}
   * passes have been completed, or until the smoothed loss becomes NaN/infinite.
   * Progress is printed every 100 iterations and a text sample every 1000 iterations.
   */
  public void learn() {

    int currentEpoch = -1;

    int n = 0; // iteration counter
    int p = 0; // data pointer

    // per-parameter squared-gradient accumulators, used by both the Adagrad and the RMSprop update
    INDArray mWxh = Nd4j.zerosLike(wxh);
    INDArray mWxh2 = Nd4j.zerosLike(wxh2);
    INDArray mWhh = Nd4j.zerosLike(whh);
    INDArray mWhh2 = Nd4j.zerosLike(whh2);
    INDArray mWh2y = Nd4j.zerosLike(wh2y);

    INDArray mbh = Nd4j.zerosLike(bh);
    INDArray mbh2 = Nd4j.zerosLike(bh2);
    INDArray mby = Nd4j.zerosLike(by);

    // loss at iteration 0
    double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;

    while (true) {
      // prepare inputs (we're sweeping from left to right in steps seqLength long)
      if (p + seqLength + 1 >= data.size() || n == 0) {
        hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
        hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
        p = 0; // go from start of data
        currentEpoch++;
        if (currentEpoch == epochs) {
          System.out.println("training finished: e:" + epochs + ", l: " + smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize + ")");
          break;
        }
      }

      INDArray inputs = getSequence(p);
      INDArray targets = getSequence(p + 1);

      // sample from the model every now and then
      if (n % 1000 == 0 && n > 0) {
        String txt = sample(inputs.getInt(0));
        System.out.printf("\n---\n %s \n----\n", txt);
      }

      // gradient accumulators for this iteration, filled in by lossFun
      INDArray dWxh = Nd4j.zerosLike(wxh);
      INDArray dWxh2 = Nd4j.zerosLike(wxh2);
      INDArray dWhh = Nd4j.zerosLike(whh);
      INDArray dWhh2 = Nd4j.zerosLike(whh2);
      INDArray dWh2y = Nd4j.zerosLike(wh2y);

      INDArray dbh = Nd4j.zerosLike(bh);
      INDArray dbh2 = Nd4j.zerosLike(bh2);
      INDArray dby = Nd4j.zerosLike(by);

      // forward seqLength characters through the net and fetch gradient
      double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
      smoothLoss = smoothLoss * 0.999 + loss * 0.001;
      if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
        System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
        break;
      }
      if (n % 100 == 0) {
        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
      }

      // NOTE(review): the gradients are re-created on every iteration, so with batch > 1 the gradients of the
      // skipped iterations are discarded rather than accumulated
      if (n % batch == 0) {
        updateParameter(wxh, dWxh, mWxh);
        updateParameter(wxh2, dWxh2, mWxh2);
        updateParameter(whh, dWhh, mWhh);
        updateParameter(whh2, dWhh2, mWhh2);
        updateParameter(bh2, dbh2, mbh2);
        updateParameter(wh2y, dWh2y, mWh2y);
        updateParameter(bh, dbh, mbh);
        updateParameter(by, dby, mby);
      }

      p += seqLength; // move data pointer
      n++; // iteration counter
    }
  }

  /**
   * Applies one adaptive gradient step to a single parameter, in place.
   * With RMSprop the cache becomes {@code DECAY * cache + (1 - DECAY) * grad^2}; with Adagrad it becomes
   * {@code cache + grad^2}. The parameter is then moved by {@code learningRate * grad / (sqrt(cache) + EPS)}.
   *
   * @param param the parameter to update (modified in place)
   * @param grad  the gradient of the loss with respect to {@code param}
   * @param cache the squared-gradient accumulator belonging to {@code param} (modified in place)
   */
  private void updateParameter(INDArray param, INDArray grad, INDArray cache) {
    INDArray squaredGrad = grad.mul(grad);
    if (rmsProp) {
      // exponentially decaying average of the squared gradients
      cache.muli(DECAY).addi(squaredGrad.mul(1 - DECAY));
    } else {
      // Adagrad: plain running sum of the squared gradients
      cache.addi(squaredGrad);
    }
    param.subi(grad.mul(learningRate).div(Transforms.sqrt(cache).add(EPS)));
  }

  /**
   * Runs a forward and a backward pass over one sequence of {@code seqLength} steps.
   * The hidden states stored in {@code hPrev} / {@code hPrev2} are used as initial states and are replaced by
   * the states of the last time step before returning.
   *
   * @param inputs  indices of the input tokens, one per time step
   * @param targets indices of the expected next tokens, one per time step
   * @param dWxh    gradient accumulator for {@code wxh} (added to in place)
   * @param dWhh    gradient accumulator for {@code whh} (added to in place)
   * @param dWxh2   gradient accumulator for {@code wxh2} (added to in place)
   * @param dWhh2   gradient accumulator for {@code whh2} (added to in place)
   * @param dWh2y   gradient accumulator for {@code wh2y} (added to in place)
   * @param dbh     gradient accumulator for {@code bh} (added to in place)
   * @param dbh2    gradient accumulator for {@code bh2} (added to in place)
   * @param dby     gradient accumulator for {@code by} (added to in place)
   * @return the cross-entropy loss summed over all time steps
   */
  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
                         INDArray dbh, INDArray dbh2, INDArray dby) {

    // States entering this sequence. The forward loop below reassigns hPrev / hPrev2 (it never mutates them in
    // place), so these references still hold the initial states when the backward pass reaches t == 0.
    final INDArray hInit = hPrev;
    final INDArray h2Init = hPrev2;

    INDArray xs = Nd4j.zeros(seqLength, vocabSize);
    INDArray hs = null;
    INDArray hs2 = null;
    INDArray ps = null;

    double loss = 0;

    // forward pass
    for (int t = 0; t < seqLength; t++) {
      int tIndex = inputs.getScalar(t).getInt(0);
      xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation

      INDArray xst = xs.getRow(t);

      hPrev = Transforms.tanh((wxh.mmul(xst.transpose()).add(whh.mmul(hPrev)).add(bh))); // hidden state
      if (hs == null) {
        hs = init(seqLength, hPrev.shape());
      }
      hs.putRow(t, hPrev.dup());

      hPrev2 = Transforms.tanh((wxh2.mmul(hPrev).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
      if (hs2 == null) {
        hs2 = init(seqLength, hPrev2.shape());
      }
      hs2.putRow(t, hPrev2.dup());

      INDArray yst = wh2y.mmul(hPrev2).add(by); // unnormalized log probabilities for next chars

      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
      if (ps == null) {
        ps = init(seqLength, pst.shape());
      }
      ps.putRow(t, pst);

      loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
    }

    // backward pass: compute gradients going backwards
    INDArray dhNext = Nd4j.zerosLike(hPrev);
    INDArray dh2Next = Nd4j.zerosLike(hPrev2);

    for (int t = seqLength - 1; t >= 0; t--) {
      INDArray dy = ps.getRow(t);
      dy.getRow(targets.getInt(t)).subi(1); // backprop into y

      INDArray hs2t = hs2.getRow(t);
      // previous state of layer 2: the state that entered the sequence when t == 0
      INDArray hs2tm1 = t == 0 ? h2Init : hs2.getRow(t - 1);

      dWh2y.addi(dy.mmul(hs2t.transpose()));
      dby.addi(dy);

      INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into h2
      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); //  backprop through tanh nonlinearity
      dbh2.addi(dhraw2);
      INDArray hst = hs.getRow(t);
      dWxh2.addi(dhraw2.mmul(hst.transpose()));
      dWhh2.addi(dhraw2.mmul(hs2tm1.transpose()));
      dh2Next = whh2.transpose().mmul(dhraw2);

      INDArray dh = wxh2.transpose().mmul(dhraw2).add(dhNext); // backprop into h
      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
      dbh.addi(dhraw);
      dWxh.addi(dhraw.mmul(xs.getRow(t)));
      // previous state of layer 1: the state that entered the sequence when t == 0
      INDArray hsRow = t == 0 ? hInit : hs.getRow(t - 1);
      dWhh.addi(dhraw.mmul(hsRow.transpose()));
      dhNext = whh.transpose().mmul(dhraw);
    }

    // carry the last hidden states over to the next sequence
    this.hPrev = hs.getRow(seqLength - 1);
    this.hPrev2 = hs2.getRow(seqLength - 1);

    return loss;
  }

  /**
   * Samples a sequence of {@code 2 * seqLength} token indices from the model, starting from copies of the
   * current memory states ({@code hPrev}, {@code hPrev2}); the stored memory states are not modified.
   *
   * @param seedIx index of the token fed to the network at the first time step
   * @return the sampled indices converted to text by {@code getSampleString}
   */
  @Override
  public String sample(int seedIx) {

    INDArray x = Nd4j.zeros(vocabSize, 1);
    x.putScalar(seedIx, 1);
    int sampleSize = seqLength * 2;
    INDArray ixes = Nd4j.create(sampleSize);

    INDArray h = hPrev.dup();
    INDArray h2 = hPrev2.dup();

    for (int t = 0; t < sampleSize; t++) {
      h = Transforms.tanh((wxh.mmul(x)).add(whh.mmul(h)).add(bh));
      h2 = Transforms.tanh((wxh2.mmul(h)).add(whh2.mmul(h2)).add(bh2));
      INDArray y = wh2y.mmul(h2).add(by);
      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();

      List<Pair<Integer, Double>> d = new LinkedList<>();
      for (int pi = 0; pi < vocabSize; pi++) {
        d.add(new Pair<>(pi, pm.getDouble(0, pi)));
      }
      try {
        EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);

        int ix = distribution.sample();

        x = Nd4j.zeros(vocabSize, 1);
        x.putScalar(ix, 1);
        ixes.putScalar(t, ix);
      } catch (RuntimeException e) {
        // Sampling stays best-effort: the previous input is reused and this position of the sample is left
        // untouched, but the failure is reported instead of being silently dropped.
        System.err.println("could not sample token at step " + t + ": " + e);
      }
    }

    return getSampleString(ixes);
  }

  /**
   * Writes a textual dump of all model parameters to a file named
   * {@code prefix + <current date> + ".txt"}. Each parameter is written as its name immediately followed by
   * its {@code toString()} representation.
   *
   * @param prefix leading part of the output file path
   * @throws IOException if the file cannot be created or written
   */
  @Override
  public void serialize(String prefix) throws IOException {
    File target = new File(prefix + new Date().toString() + ".txt");
    // try-with-resources guarantees the writer is closed even when a write fails
    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(target))) {
      bufferedWriter.write("wxh");
      bufferedWriter.write(wxh.toString());
      bufferedWriter.write("whh");
      bufferedWriter.write(whh.toString());
      bufferedWriter.write("wxh2");
      bufferedWriter.write(wxh2.toString());
      bufferedWriter.write("whh2");
      bufferedWriter.write(whh2.toString());
      bufferedWriter.write("wh2y");
      bufferedWriter.write(wh2y.toString());
      bufferedWriter.write("bh");
      bufferedWriter.write(bh.toString());
      bufferedWriter.write("bh2");
      bufferedWriter.write(bh2.toString());
      bufferedWriter.write("by");
      bufferedWriter.write(by.toString());
      bufferedWriter.flush();
    }
  }

}