/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.opennlp.utils.cfg;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * A probabilistic context-free grammar (PCFG): a set of non-terminal symbols, a set of terminal symbols,
 * a start symbol and a map assigning a probability to each {@link Rule}.
 * <p>
 * The grammar can rewrite symbols into sequences of terminals ({@link #leftMostDerivation(String...)}) and
 * parse sentences with the CKY algorithm ({@link #cky(List)}, {@link #pi(List, int, int, String)}).
 * Parsing only uses binary rules {@code X -> Y Z} (with {@code Y} and {@code Z} non-terminals) and
 * pre-terminal rules {@code X -> word}, i.e. it expects the grammar to be in Chomsky normal form.
 */
public class ProbabilisticContextFreeGrammar {

  private final Collection<String> nonTerminalSymbols;
  private final Collection<String> terminalSymbols;
  private final Map<Rule, Double> rules;
  private final String startSymbol;
  // when true a symbol is expanded with a uniformly chosen matching rule, otherwise with the first match
  // in rules.keySet() iteration order
  private final boolean randomExpansion;
  // created once instead of on every expansion
  private final Random random = new Random();

  /**
   * Creates a grammar.
   *
   * @param nonTerminalSymbols the non-terminal symbols; must contain {@code startSymbol} (checked with assertions)
   * @param terminalSymbols    the terminal symbols
   * @param rules              each rule mapped to its probability
   * @param startSymbol        the start symbol
   * @param randomExpansion    whether {@link #leftMostDerivation(String...)} picks rules at random (uniformly)
   */
  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols,
                                         Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) {

    assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";

    this.nonTerminalSymbols = nonTerminalSymbols;
    this.terminalSymbols = terminalSymbols;
    this.rules = rules;
    this.startSymbol = startSymbol;
    this.randomExpansion = randomExpansion;
  }

  /**
   * Creates a grammar whose derivations always use the first matching rule (no random expansion).
   */
  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) {
    this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false);
  }

  public Collection<String> getNonTerminalSymbols() {
    return nonTerminalSymbols;
  }

  public Collection<String> getTerminalSymbols() {
    return terminalSymbols;
  }

  public Map<Rule, Double> getRules() {
    return rules;
  }

  public String getStartSymbol() {
    return startSymbol;
  }

  /**
   * Rewrites the given symbols into terminals. Terminal symbols are kept as they are; each non-terminal is
   * expanded recursively, left to right, with one of its rules (the first match, or a uniformly random one
   * when random expansion is enabled).
   * <p>
   * Note: without random expansion a recursive first rule (e.g. {@code NP -> NP PP}) recurses forever.
   *
   * @param words symbols to expand; with assertions enabled the first one must be the start symbol
   * @return the concatenated terminals
   * @throws RuntimeException if a non-terminal has no rule to expand it
   */
  public String[] leftMostDerivation(String... words) {
    assert words.length > 0 && startSymbol.equals(words[0]);

    List<String> derivation = new ArrayList<String>(words.length);
    for (String word : words) {
      collectTerminals(word, derivation);
    }
    return derivation.toArray(new String[derivation.size()]);
  }

  // appends the terminals derived from symbol to out, depth-first and left to right
  private void collectTerminals(String symbol, List<String> out) {
    if (terminalSymbols.contains(symbol)) {
      out.add(symbol);
      return;
    }
    assert nonTerminalSymbols.contains(symbol) : "word " + symbol + " is not contained in non terminals";
    for (String child : getExpansionForSymbol(symbol)) {
      collectTerminals(child, out);
    }
  }

  private String[] getExpansionForSymbol(String currentSymbol) {
    Rule r = getRuleForSymbol(currentSymbol);
    return r.getExpansion();
  }

  // picks a rule whose entry is word: the first one found, or a uniformly random one if randomExpansion
  private Rule getRuleForSymbol(String word) {
    List<Rule> possibleRules = new ArrayList<Rule>();
    for (Rule r : rules.keySet()) {
      if (word.equals(r.getEntry())) {
        if (!randomExpansion) {
          return r;
        }
        possibleRules.add(r);
      }
    }
    if (possibleRules.isEmpty()) {
      throw new IllegalStateException("could not find a rule for expanding symbol " + word);
    }
    return possibleRules.get(random.nextInt(possibleRules.size()));
  }

  /**
   * Computes the most probable derivation of {@code sentence[i..j]} (both inclusive) from the non-terminal
   * {@code x}.
   * <p>
   * For {@code i == j} the result is the leaf for the rule {@code x -> sentence[i]} with its probability
   * (possibly 0). For longer spans only binary rules whose entry is {@code x} are tried at every split point.
   * If no derivation has non-zero probability, the result has probability 0 and a {@code null} rule.
   * Sub-spans are recomputed on every call (no memoization), so prefer {@link #cky(List)} for whole sentences.
   */
  public BackPointer pi(List<String> sentence, int i, int j, String x) {
    if (i == j) {
      Rule rule = new Rule(x, sentence.get(i));
      return new BackPointer(q(rule), i, rule);
    }
    BackPointer backPointer = new BackPointer(0, 0, null);
    double max = 0;
    // only rules rewriting x are relevant; the left child covers i..s, the right child s+1..j
    for (Rule rule : getRulesForNonTerminal(x)) {
      double q = q(rule);
      for (int s = i; s < j; s++) {
        BackPointer left = pi(sentence, i, s, rule.getExpansion()[0]);
        BackPointer right = pi(sentence, s + 1, j, rule.getExpansion()[1]);
        double cp = q * left.getProbability() * right.getProbability();
        if (cp > max) {
          max = cp;
          backPointer = new BackPointer(max, s, rule, left, right);
        }
      }
    }
    return backPointer;
  }

  /**
   * Parses the sentence with the CKY algorithm, filling a chart bottom-up over span lengths.
   *
   * @param sentence the words to parse
   * @return the most probable parse of the whole sentence rooted at the start symbol, or {@code null} when the
   *         sentence is empty or has no derivation with non-zero probability
   */
  public BackPointer cky(List<String> sentence) {
    int n = sentence.size();
    if (n == 0) {
      return null;
    }

    // binary rules X -> Y Z grouped by X, computed once instead of once per span
    Map<String, Collection<Rule>> binaryRules = new HashMap<String, Collection<Rule>>();
    for (String x : nonTerminalSymbols) {
      binaryRules.put(x, getRulesForNonTerminal(x));
    }

    // chart.get(i * n + j) maps a non-terminal X to its best derivation of words i..j (inclusive);
    // X is absent when it cannot derive that span with non-zero probability
    List<Map<String, BackPointer>> chart = new ArrayList<Map<String, BackPointer>>(n * n);
    for (int k = 0; k < n * n; k++) {
      chart.add(new HashMap<String, BackPointer>());
    }

    // single-word spans come from pre-terminal rules X -> word
    for (int i = 0; i < n; i++) {
      Map<String, BackPointer> cell = chart.get(i * n + i);
      for (String x : nonTerminalSymbols) {
        Rule rule = new Rule(x, sentence.get(i));
        double q = q(rule);
        if (q > 0) {
          cell.put(x, new BackPointer(q, i, rule));
        }
      }
    }

    for (int l = 1; l < n; l++) {
      for (int i = 0; i < n - l; i++) {
        int j = i + l;
        Map<String, BackPointer> cell = chart.get(i * n + j);
        for (String x : nonTerminalSymbols) {
          BackPointer best = null;
          double max = 0;
          for (Rule r : binaryRules.get(x)) {
            double q = q(r);
            String[] expansion = r.getExpansion();
            // s runs up to j - 1 inclusive so the right child may be the single word j
            for (int s = i; s < j; s++) {
              BackPointer left = chart.get(i * n + s).get(expansion[0]);
              BackPointer right = chart.get((s + 1) * n + j).get(expansion[1]);
              if (left == null || right == null) {
                continue;
              }
              double cp = q * left.getProbability() * right.getProbability();
              if (cp > max) {
                max = cp;
                best = new BackPointer(cp, s, r, left, right);
              }
            }
          }
          if (best != null) {
            cell.put(x, best);
          }
        }
      }
    }
    return chart.get(n - 1).get(startSymbol);
  }

  // binary rules X -> Y Z (Y, Z non-terminals) whose entry is x, in rules.keySet() iteration order
  private Collection<Rule> getRulesForNonTerminal(String x) {
    Collection<Rule> ntRules = new ArrayList<Rule>();
    for (Rule r : rules.keySet()) {
      if (x.equals(r.getEntry()) && isBinaryNonTerminalRule(r)) {
        ntRules.add(r);
      }
    }
    return ntRules;
  }

  // true for rules expanding into exactly two non-terminals; the length check avoids indexing past
  // the end of unary expansions such as S -> VP
  private boolean isBinaryNonTerminalRule(Rule r) {
    String[] expansion = r.getExpansion();
    return expansion != null && expansion.length == 2
        && nonTerminalSymbols.contains(expansion[0]) && nonTerminalSymbols.contains(expansion[1]);
  }

  // probability of the rule, 0 when the grammar does not contain it
  private double q(Rule rule) {
    Double probability = rules.get(rule);
    return probability != null ? probability : 0d;
  }

  /**
   * A node of a parse tree built by {@link #pi(List, int, int, String)} or {@link #cky(List)}. It holds the
   * rule applied at this node, the probability of the subtree, the split point and, for binary rules, the
   * left and right subtrees (both {@code null} for leaves).
   */
  public class BackPointer {

    private final double probability;
    private final int splitPoint;
    private final Rule rule;
    private final BackPointer leftTree;
    private final BackPointer rightTree;

    private BackPointer(double probability, int splitPoint, Rule rule) {
      this(probability, splitPoint, rule, null, null);
    }

    public BackPointer(double probability, int splitPoint, Rule rule, BackPointer leftTree, BackPointer rightTree) {
      this.probability = probability;
      this.splitPoint = splitPoint;
      this.rule = rule;
      this.leftTree = leftTree;
      this.rightTree = rightTree;
    }

    public double getProbability() {
      return probability;
    }

    public int getSplitPoint() {
      return splitPoint;
    }

    /**
     * @return the rule applied at this node; {@code null} for the "no derivation" result of {@code pi}
     */
    public Rule getRule() {
      return rule;
    }

    public BackPointer getLeftTree() {
      return leftTree;
    }

    public BackPointer getRightTree() {
      return rightTree;
    }

    /**
     * Renders the tree in bracketed form, e.g. {@code (S (NP ...) (VP ...))}; leaves print their first
     * expansion symbol.
     */
    @Override
    public String toString() {
      if (rule == null) {
        // "no derivation" pointer: there is no rule to print
        return "()";
      }
      return "(" +
              rule.getEntry() + " " +
              (leftTree != null && rightTree != null ?
                      leftTree.toString() + " " + rightTree.toString() :
                      rule.getExpansion()[0]
              ) +
              ')';
    }
  }

}
