/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mrql;

import org.apache.mrql.gen.*;


/** Translates physical plans to BSP plans to be executed on Hama */
public class BSPTranslator extends TypeInference {
    // Method numbers of built-in functions, each looked up once through
    // ClassImporter.find_method_number by function name and argument-type list.
    // The plans built in this class embed some of them in callM(name,number,...) nodes.
    final static int orM = ClassImporter.find_method_number("or",#[bool,bool]);
    final static int andM = ClassImporter.find_method_number("and",#[bool,bool]);
    final static int notM = ClassImporter.find_method_number("not",#[bool]);
    final static int eqM = ClassImporter.find_method_number("eq",#[int,int]);
    final static int plusM = ClassImporter.find_method_number("plus",#[int,int]);
    final static int neqM = ClassImporter.find_method_number("neq",#[int,int]);
    final static int geqM = ClassImporter.find_method_number("geq",#[long,long]);
    final static int gtM = ClassImporter.find_method_number("gt",#[int,int]);
    final static int unionM = ClassImporter.find_method_number("union",#[bag(any),bag(any)]);
    final static int countM = ClassImporter.find_method_number("count",#[bag(any)]);
    final static int coerceM = ClassImporter.find_method_number("coerce",#[any,int]);

    // Operator names that new_source treats as plans: a node whose operator is in
    // this list is preprocessed as is instead of being wrapped in a new BSPSource.
    final static Trees planNames = plans_with_distributed_lambdas
        .append(#[ParsedSource,Generator,BinarySource,GroupByJoin,Repeat,Closure,repeat]);

    // The next unused source number; it is read and then incremented each time
    // this class creates a BSPSource node or a BSP node that needs a fresh number.
    private static int source_num = 0;

    /** restart the source numbering from 0 */
    public static void reset () {
        source_num = 0;
    }

    /** Prepare the input of a BSP operation: an input whose operator is listed in
     *  planNames is simply preprocessed; any other input is preprocessed and then
     *  wrapped in a BSPSource that carries a newly allocated source number.
     * @param e the input of a BSP operation
     * @return the preprocessed input, wrapped in BSPSource(n,...) unless it is a plan
     */
    private static Tree new_source ( Tree e ) {
        match e {
        case `op(...):
            if (planNames.member(#<`op>)) {
                return preprocess(e);
            }
        };
        // not a plan operator (or not a node at all): number it and wrap it
        return #<BSPSource(`(source_num++),`(preprocess(e)))>;
    }

    /** add a new source num to every BSP operation */
    private static Tree preprocess ( Tree e ) {
        match e {
        case lambda(`v,`b):
            return #<lambda(`v,`(preprocess(b)))>;
        case Aggregate(`acc,`zero,`S):
            return #<Aggregate(`acc,`zero,
                               `(new_source(S)))>;
        case cMap(`m,`S):
            return #<cMap(`m,
                          `(new_source(S)))>;
        case AggregateMap(`m,`acc,`zero,`S):
            return #<AggregateMap(`m,`acc,`zero,
                                  `(new_source(S)))>;
        case MapReduce(`m,`r,`S,`o):
            return #<MapReduce(`m,`r,
                               `(new_source(S)),`o)>;
        case MapAggregateReduce(`m,`r,`acc,`zero,`S,`o):
            return #<MapAggregateReduce(`m,`r,`acc,`zero,
                                        `(new_source(S)),`o)>;
        case MapCombineReduce(`m,`c,`r,`S,`o):
            return #<MapCombineReduce(`m,`c,`r,
                                      `(new_source(S)),`o)>;
        case MapReduce2(`mx,`my,`r,`x,`y,`o):
            return #<MapReduce2(`mx,`my,`r,
                                `(new_source(x)),
                                `(new_source(y)),`o)>;
        case MapReduce2(`mx,`my,`c,`r,`x,`y,`o):
            return #<MapCombineReduce2(`mx,`my,`c,`r,
                                       `(new_source(x)),
                                       `(new_source(y)),`o)>;
        case MapAggregateReduce2(`mx,`my,`r,`acc,`zero,`x,`y,`o):
            return #<MapAggregateReduce2(`mx,`my,`r,`acc,`zero,
                                         `(new_source(x)),
                                         `(new_source(y)),`o)>;
        case MapJoin(`mx,`my,`r,`x,`y):
            return #<MapJoin(`mx,`my,`r,
                             `(new_source(x)),
                             `(new_source(y)))>;
        case MapAggregateJoin(`mx,`my,`r,`acc,`zero,`x,`y):
            return #<MapAggregateJoin(`mx,`my,`r,`acc,`zero,
                                      `(new_source(x)),
                                      `(new_source(y)))>;
        case CrossProduct(`mx,`my,`r,`x,`y):
            return #<CrossProduct(`mx,`my,`r,
                                  `(new_source(x)),
                                  `(new_source(y)))>;
        case CrossAggregateProduct(`mx,`my,`r,`acc,`zero,`x,`y):
            return #<CrossAggregateProduct(`mx,`my,`r,`acc,`zero,
                                           `(new_source(x)),
                                           `(new_source(y)))>;
        case GroupByJoin(`kx,`ky,`gx,`gy,`acc,`z,`r,`x,`y,`o):
            return #<GroupByJoin(`kx,`ky,`gx,`gy,`acc,`z,`r,
                                 `(new_source(x)),
                                 `(new_source(y)),`o)>;
        case Repeat(`f,`ds,`max):
            return #<Repeat(`(preprocess(f)),`(preprocess(ds)),`max)>;
        case repeat(`f,`ds,`max):
            return #<repeat(`(preprocess(f)),`(preprocess(ds)),`max)>;
        case Closure(`f,`ds,`max):
            return #<Closure(`(preprocess(f)),`(preprocess(ds)),`max)>;
        case `f(...):
            if (! #[ParsedSource,Generator,BinarySource].member(#<`f>))
                fail;
            return #<BSPSource(`(source_num++),`e)>;
        case `f(...as):
            Trees bs = #[];
            for (Tree a: as)
                bs = bs.append(preprocess(a));
            return #<`f(...bs)>;
        };
        return e;
    }

    /** Returns the source num of a BSP operation.
     *  For BSP(tuple(n,...),...) it is the first number of the tuple, for
     *  BSP(n,...) and BSPSource(n,_) it is n itself.
     * @param e a BSP or BSPSource node
     * @return the source number, or -1 if e is neither a BSP nor a BSPSource node
     */
    private static int source_num ( Tree e ) {
        Tree num = null;
        match e {
        case BSP(tuple(`first,...),...):
            num = first;
        case BSP(`single,...):
            num = single;
        case BSPSource(`id,_):
            num = id;
        };
        if (num == null) {
            return -1;
        }
        // the number is stored as a long leaf; narrow it to an int
        return (int)((LongLeaf)num).value();
    }

    /** Redirect cache reads: every two-argument getCache(cache,n) inside e is
     *  replaced by getCache(cache,m); all other nodes are rebuilt from their
     *  rewritten children and leaves are returned unchanged.
     * @param n the cache number to look for
     * @param m the cache number to use instead
     * @param e the expression to rewrite
     * @return the rewritten expression
     */
    private static Tree subst_getCache_num ( int n, int m, Tree e ) {
        match e {
        case getCache(`c,`num):
            // a read of some other cache entry is handled by the generic case
            if (!num.equals(#<`n>))
                fail;
            return #<getCache(`c,`m)>;
        case `op(...args):
            Trees rewritten = #[];
            for ( Tree arg: args ) {
                rewritten = rewritten.append(subst_getCache_num(n,m,arg));
            }
            return #<`op(...rewritten)>;
        };
        return e;
    }

    /** Redirect cache writes: every setCache(cache,n,...) inside e is replaced by
     *  setCache(cache,m,...) with its remaining arguments copied as they are; all
     *  other nodes are rebuilt from their rewritten children and leaves are
     *  returned unchanged.
     * @param n the cache number to look for
     * @param m the cache number to use instead
     * @param e the expression to rewrite
     * @return the rewritten expression
     */
    private static Tree subst_setCache_num ( int n, int m, Tree e ) {
        match e {
        case setCache(`c,`num,...rest):
            // a write to some other cache entry is handled by the generic case
            if (!num.equals(#<`n>))
                fail;
            return #<setCache(`c,`m,...rest)>;
        case `op(...args):
            Trees rewritten = #[];
            for ( Tree arg: args ) {
                rewritten = rewritten.append(subst_setCache_num(n,m,arg));
            }
            return #<`op(...rewritten)>;
        };
        return e;
    }

    /** Give a BSP operation or a BSP source the cache number n.
     *  For a BSP node the leading number is replaced by n and the setCache calls
     *  that wrote to the old number are redirected to n; for a BSPSource only the
     *  number is replaced.
     * @param e a BSP or BSPSource node
     * @param n the new cache number
     * @return the renumbered node, or e itself if it is neither a BSP nor a BSPSource
     */
    private static Tree set_cache_num ( Tree e, int n ) {
        match e {
        case BSP(_,...rest):
            int old = source_num(e);
            return subst_setCache_num(old,n,#<BSP(`n,...rest)>);
        case BSPSource(_,`src):
            return #<BSPSource(`n,`src)>;
        };
        return e;
    }

    /** Return the source numbers of a BSP operation: the leading argument of a
     *  BSP or BSPSource node, or the numbers of both branches of a Merge.
     *  NOTE(review): for a BSP node whose leading argument is a tuple of numbers
     *  this returns the tuple itself as the only element -- confirm this is intended.
     * @param e a BSP, Merge or BSPSource node
     * @return the source numbers, or an empty list for any other expression
     */
    private static Trees source_nums ( Tree e ) {
        match e {
        case BSP(`num,...):
            return #[`num];
        case Merge(`left,`right):
            return source_nums(left).append(source_nums(right));
        case BSPSource(`num,_):
            return #[`num];
        };
        return #[];
    }

    /** Return the source numbers of a list of BSP operations, in list order.
     * @param s the BSP operations
     * @return the concatenation of the source numbers of each element of s
     */
    private static Trees source_nums ( Trees s ) {
        Trees nums = #[];
        for ( Tree op: s ) {
            nums = nums.append(source_nums(op));
        }
        return nums;
    }

    /** Bind the placeholder var in body to a read of the cache entries of e.
     * @param var the placeholder variable to replace
     * @param e the BSP operation or source that produces the data
     * @param body the expression in which var is replaced
     * @return body with var replaced by getCache(cache,n1,...), or body unchanged
     *         when e has no source numbers
     */
    private static Tree getCache ( Tree var, Tree e, Tree body ) {
        Trees nums = source_nums(e);
        if (nums.length() <= 0) {
            return body;
        }
        return subst(var,#<getCache(cache,...nums)>,body);
    }

    /** Bind the placeholder var in body to a read of the cache entries of all the
     *  operations in s. Unlike the single-input variant, the substitution is done
     *  even when s yields no source numbers.
     * @param var the placeholder variable to replace
     * @param s the BSP operations or sources that produce the data
     * @param body the expression in which var is replaced
     * @return body with var replaced by getCache(cache,n1,...)
     */
    private static Tree getCache ( Tree var, Trees s, Tree body ) {
        Trees nums = source_nums(s);
        Tree read = #<getCache(cache,...nums)>;
        return subst(var,read,body);
    }

    /** Optimize a BSP plan after BSP fusion.
     *  A cmap whose function body is a one-argument constructor accepted by
     *  is_collection is turned into a map over the constructor's argument; the
     *  rewrite is applied recursively to every node of the plan.
     * @param e the BSP plan
     * @return the simplified plan
     */
    private static Tree post_simplify ( Tree e ) {
        match e {
        case cmap(lambda(`v,`C(`elem)),`src):
            if (is_collection(C))
                return post_simplify(#<map(lambda(`v,`elem),`src)>);
            else fail
        case `op(...args):
            Trees simplified = #[];
            for ( Tree arg: args ) {
                simplified = simplified.append(post_simplify(arg));
            }
            return #<`op(...simplified)>;
        };
        return e;
    }

    /** Clean up the code of a superstep: rename it with Simplification.rename,
     *  simplify it with Simplification.simplify_all, then apply post_simplify.
     * @param e the superstep code
     * @return the simplified superstep code
     */
    private static Tree processBSP ( Tree e ) {
        Tree renamed = Simplification.rename(e);
        Tree simplified = Simplification.simplify_all(renamed);
        return post_simplify(simplified);
    }

    /** Build a BSP operation with one input and a newly allocated source number.
     *  In the superstep, the placeholder i_ is bound to a cache read of the input
     *  and the placeholder o_ to the new source number.
     * @param superstep the superstep function (may mention i_ and o_)
     * @param state the initial state
     * @param orderp the order flag stored in the BSP node
     * @param input the BSP operation or source that feeds this one
     * @return BSP(n,superstep,state,orderp,input)
     */
    private static Tree mkBSP ( Tree superstep, Tree state, Tree orderp, Tree input ) {
        int rn = source_num++;
        Tree withInput = getCache(#<i_>,input,superstep);
        Tree code = processBSP(subst(#<o_>,#<`rn>,withInput));
        return #<BSP(`rn,`code,`state,`orderp,`input)>;
    }

    /** Build a BSP operation with a list of inputs and a newly allocated source number.
     *  In the superstep, the placeholder i_ is bound to a cache read of all the
     *  inputs and the placeholder o_ to the new source number.
     * @param superstep the superstep function (may mention i_ and o_)
     * @param state the initial state
     * @param orderp the order flag stored in the BSP node
     * @param input the BSP operations or sources that feed this one
     * @return BSP(n,superstep,state,orderp,input...)
     */
    private static Tree mkBSPL ( Tree superstep, Tree state, Tree orderp, Trees input ) {
        int rn = source_num++;
        // The input placeholder is i_, as in every other BSP builder of this class:
        // substituting i would leave i_ unbound in the superstep and would capture
        // any ordinary variable that happens to be called i.
        superstep = processBSP(subst(#<o_>,#<`rn>,getCache(#<i_>,input,superstep)));
        return #<BSP(`rn,`superstep,`state,`orderp,...input)>;
    }

    /** Build a BSP operation with a list of inputs and the given source numbers.
     *  In the superstep, the placeholder i_ is bound to a cache read of all the
     *  inputs and the placeholder o_ to the first number of ns; the BSP node
     *  carries all the numbers as a tuple.
     * @param ns the source numbers of the operation (ns[0] is read, so it must not be empty)
     * @param superstep the superstep function (may mention i_ and o_)
     * @param state the initial state
     * @param orderp the order flag stored in the BSP node
     * @param input the BSP operations or sources that feed this one
     * @return BSP(tuple(ns...),superstep,state,orderp,input...)
     */
    private static Tree mkBSPL ( int[] ns, Tree superstep, Tree state, Tree orderp, Trees input ) {
        Tree code = processBSP(subst(#<o_>,#<`(ns[0])>,getCache(#<i_>,input,superstep)));
        Trees nums = #[];
        for ( int num: ns ) {
            nums = nums.append(#<`num>);
        }
        return #<BSP(tuple(...nums),`code,`state,`orderp,...input)>;
    }

    /** Build a BSP operation with a list of inputs and the given source number.
     *  In the superstep, the placeholder i_ is bound to a cache read of all the
     *  inputs and the placeholder o_ to n.
     * @param n the source number of the operation
     * @param superstep the superstep function (may mention i_ and o_)
     * @param state the initial state
     * @param orderp the order flag stored in the BSP node
     * @param input the BSP operations or sources that feed this one
     * @return BSP(n,superstep,state,orderp,input...)
     */
    private static Tree mkBSPL ( int n, Tree superstep, Tree state, Tree orderp, Trees input ) {
        Tree withInput = getCache(#<i_>,input,superstep);
        Tree code = processBSP(subst(#<o_>,#<`n>,withInput));
        return #<BSP(`n,`code,`state,`orderp,...input)>;
    }

    /** Build a BSP operation with two inputs and a newly allocated source number.
     *  In the superstep, the placeholder j_ is bound to a cache read of the right
     *  input, then i_ to a cache read of the left input, and finally o_ to the
     *  new source number.
     * @param superstep the superstep function (may mention i_, j_ and o_)
     * @param state the initial state
     * @param orderp the order flag stored in the BSP node
     * @param left the BSP operation or source bound to i_
     * @param right the BSP operation or source bound to j_
     * @return BSP(n,superstep,state,orderp,left,right)
     */
    private static Tree mkBSP2 ( Tree superstep, Tree state, Tree orderp, Tree left, Tree right ) {
        int rn = source_num++;
        Tree withRight = getCache(#<j_>,right,superstep);
        Tree withBoth = getCache(#<i_>,left,withRight);
        Tree code = processBSP(subst(#<o_>,#<`rn>,withBoth));
        return #<BSP(`rn,`code,`state,`orderp,`left,`right)>;
    }

    /** construct a BSP plan from a physical plan
     * @param e the physical plan
     * @return the BSP plan
     */
    private static Tree mr2bsp ( Tree e ) {
        match e {
        case Aggregate(`acc,`zero,`S):
            return #<Aggregate(`acc,`zero,`(mr2bsp(S)))>;
        case cMap(`m,`S):
            return mkBSP(#<lambda(tuple(cache,ms,k,peer),
                                  setCache(cache,o_,cmap(`m,i_),
                                           tuple(BAG(),tuple(),TRUE())))>,
                         #<tuple()>,
                         #<false>,
                         mr2bsp(S));
        case AggregateMap(`m,`acc,`zero,`S):
            return mkBSP(#<lambda(tuple(cache,ms,k,peer),
                                  setCache(cache,o_,aggregate(`acc,`zero,cmap(`m,i_)),
                                           tuple(BAG(),tuple(),TRUE())))>,
                         #<tuple()>,
                         #<false>,
                         mr2bsp(S));
        case MapReduce(`m,`r,`S,`o):
            return mkBSP(#<lambda(tuple(cache,ms,map_step,peer),
                                  if(map_step,
                                     tuple(cmap(lambda(tuple(k,c),
                                                       bag(tuple(k,tuple(k,c)))),
                                                cmap(`m,i_)),
                                           false,
                                           FALSE()),
                                     setCache(cache,o_,`(o.equals(#<true>)   // need to sort the result?
                                                   ? #<cmap(lambda(tuple(k,s),
                                                                   cmap(lambda(x,bag(tuple(x,k))),
                                                                        apply(`r,tuple(k,s)))),
                                                            groupBy(ms))>
                                                   : #<cmap(`r,groupBy(ms))>),
                                              tuple(BAG(),false,TRUE()))))>,
                         #<true>,
                         o,
                         mr2bsp(S));
        case MapAggregateReduce(`m,`r,`acc,`zero,`S,_):
            return mkBSP(#<lambda(tuple(cache,ms,map_step,peer),
                                  if(map_step,
                                     tuple(cmap(lambda(tuple(k,c),
                                                       bag(tuple(k,tuple(k,c)))),
                                                cmap(`m,i_)),
                                           false,
                                           FALSE()),
                                     setCache(cache,o_,aggregate(`acc,`zero,cmap(`r,groupBy(ms))),
                                              tuple(BAG(),false,TRUE()))))>,
                         #<true>,
                         #<false>,
                         mr2bsp(S));
        case MapCombineReduce(`m,`c,`r,`S,`o):
            return mkBSP(#<lambda(tuple(cache,ms,map_step,peer),
                                  if(map_step,
                                     tuple(cmap(lambda(tuple(k,s),
                                                       cmap(lambda(x,bag(tuple(k,tuple(k,x)))),
                                                            apply(`c,tuple(k,s)))),
                                                groupBy(cmap(`m,i_))),
                                           false,
                                           FALSE()),
                                     setCache(cache,o_,`(o.equals(#<true>)   // need to sort the result?
                                                   ? #<cmap(lambda(tuple(k,s),
                                                                   cmap(lambda(x,bag(tuple(x,k))),
                                                                        apply(`r,tuple(k,s)))),
                                                            groupBy(ms))>
                                                   : #<cmap(`r,groupBy(ms))>),
                                              tuple(BAG(),false,TRUE()))))>,
                         #<true>,
                         o,
                         mr2bsp(S));
        case MapCombineReduce2(`mx,`my,`c,`r,`x,`y,`o):
            return mkBSP2(#<lambda(tuple(cache,ms,map_step,peer),
                                   if(map_step,
                                      tuple(callM(union,`unionM,
                                                  cmap(lambda(tuple(kx,x),
                                                              bag(tuple(kx,tuple(kx,tuple(1,x))))),
                                                       cmap(`mx,i_)),
                                                  cmap(lambda(tuple(ky,y),
                                                              bag(tuple(ky,tuple(ky,tuple(2,y))))),
                                                       cmap(`my,j_))),
                                            false,
                                            FALSE()),
                                      setCache(cache,o_,cmap(lambda(tuple(k,s),
                                                              cmap(lambda(tuple(kk,ss),
                                                                          cmap(lambda(x,bag(tuple(kk,x))),
                                                                               apply(`c,tuple(kk,ss)))),
                                                                   groupBy(apply(`r,
                                                                                 tuple(cmap(lambda(tuple(kx,x),
                                                                                                   if(callM(eq,`eqM,kx,1),
                                                                                                      bag(x),
                                                                                                      bag())),
                                                                                            s),
                                                                                       cmap(lambda(tuple(ky,y),
                                                                                                   if(callM(eq,`eqM,ky,2),
                                                                                                      bag(y),
                                                                                                      bag())),
                                                                                            s)))))),
                                                       groupBy(ms)),
                                               tuple(BAG(),false,TRUE()))))>,
                          #<true>,
                          o,
                          mr2bsp(x),
                          mr2bsp(y));
        case MapReduce2(`mx,`my,`r,`x,`y,`o):
            return mkBSP2(#<lambda(tuple(cache,ms,map_step,peer),
                                   if(map_step,
                                      tuple(callM(union,`unionM,
                                                  cmap(lambda(tuple(kx,x),
                                                              bag(tuple(kx,tuple(kx,tuple(1,x))))),
                                                       cmap(`mx,i_)),
                                                  cmap(lambda(tuple(ky,y),
                                                              bag(tuple(ky,tuple(ky,tuple(2,y))))),
                                                       cmap(`my,j_))),
                                            false,
                                            FALSE()),
                                      setCache(cache,o_,cmap(lambda(tuple(k,s),
                                                              cmap(lambda(x,bag(`(o.equals(#<true>) ? #<tuple(x,k)> : #<x>))),
                                                                   apply(`r,
                                                                         tuple(cmap(lambda(tuple(kx,x),
                                                                                           if(callM(eq,`eqM,kx,1),
                                                                                              bag(x),
                                                                                              bag())),
                                                                                    s),
                                                                               cmap(lambda(tuple(ky,y),
                                                                                           if(callM(eq,`eqM,ky,2),
                                                                                              bag(y),
                                                                                              bag())),
                                                                                    s))))),
                                                       groupBy(ms)),
                                               tuple(BAG(),false,TRUE()))))>,
                          #<true>,
                          o,
                          mr2bsp(x),
                          mr2bsp(y));
        case MapAggregateReduce2(`mx,`my,`r,`acc,`zero,`x,`y,_):
            return mkBSP2(#<lambda(tuple(cache,ms,map_step,peer),
                                   if(map_step,
                                      tuple(callM(union,`unionM,
                                                  cmap(lambda(tuple(kx,x),
                                                              bag(tuple(kx,tuple(kx,tuple(1,x))))),
                                                       cmap(`mx,i_)),
                                                  cmap(lambda(tuple(ky,y),
                                                              bag(tuple(ky,tuple(ky,tuple(2,y))))),
                                                       cmap(`my,j_))),
                                            false,
                                            FALSE()),
                                      setCache(cache,o_,aggregate(`acc,`zero,
                                                      cmap(lambda(tuple(k,s),
                                                              apply(`r,tuple(cmap(lambda(tuple(kx,x),
                                                                                        if(callM(eq,`eqM,kx,1),
                                                                                           bag(x),
                                                                                           bag())),
                                                                                 s),
                                                                            cmap(lambda(tuple(ky,y),
                                                                                        if(callM(eq,`eqM,ky,2),
                                                                                           bag(y),
                                                                                           bag())),
                                                                                 s)))),
                                                           groupBy(ms))),
                                               tuple(BAG(),false,TRUE()))))>,
                          #<true>,
                          #<false>,
                          mr2bsp(x),
                          mr2bsp(y));
        case MapJoin(`mx,`my,`r,`x,`y):
             return mr2bsp(#<MapReduce2(`mx,`my,
                                        lambda(tuple(xs,ys),cmap(lambda(x,apply(`r,tuple(x,ys))),xs)),
                                        `x,`y,false)>);
        case MapAggregateJoin(`mx,`my,`r,`acc,`zero,`x,`y):
            return mr2bsp(#<MapAggregateReduce2(`mx,`my,
                                                lambda(tuple(xs,ys),cmap(lambda(x,apply(`r,tuple(x,ys))),xs)),
                                                `acc,`zero,`x,`y)>);
        case GroupByJoin(`kx,`ky,lambda(`vx,`gx),lambda(`vy,`gy),`acc,`zero,`r,`x,`y,`o):
            int n = (int)Math.floor(Math.sqrt(Config.nodes));
            int m = n;
            // System.err.println("Using a groupBy join on a "+n+"*"+m+" grid of partitions");
            Tree xkey = #<cmap(lambda(i,bag(tuple(call(plus,call(mod,call(hash_code,`gx),`m),call(times,`m,i)),
                                                  tuple(1,call(plus,call(mod,call(hash_code,`gx),`m),call(times,`m,i)),`vx)))),
                               range(0,`(n-1)))>;
            Tree ykey = #<cmap(lambda(j,bag(tuple(call(plus,call(times,call(mod,call(hash_code,`gy),`n),`m),j),
                                                  tuple(2,call(plus,call(times,call(mod,call(hash_code,`gy),`n),`m),j),`vy)))),
                               range(0,`(m-1)))>;
            type_inference(xkey);
            type_inference(ykey);
            xkey = PlanGeneration.makePlan(xkey);
            ykey = PlanGeneration.makePlan(ykey);
            return mkBSP2(#<lambda(tuple(cache,ms,map_step,peer),
                                   if(map_step,
                                      tuple(callM(union,`unionM,
                                                  cmap(lambda(`vx,`xkey),i_),
                                                  cmap(lambda(`vy,`ykey),j_)),
                                            false,
                                            FALSE()),
                                      setCache(cache,o_,
                                               mergeGroupByJoin(`kx,`ky,lambda(`vx,`gx),lambda(`vy,`gy),`acc,`zero,`r,
                                                      cmap(lambda(tuple(kx,p,x),
                                                                  if(callM(eq,`eqM,kx,1),bag(tuple(p,x)),bag())),ms),
                                                      cmap(lambda(tuple(ky,p,y),
                                                                  if(callM(eq,`eqM,ky,2),bag(tuple(p,y)),bag())),ms),`o),
                                               tuple(BAG(),false,TRUE()))))>,
                          #<true>,
                          o,
                          mr2bsp(x),
                          mr2bsp(y));
        case CrossProduct(`mx,`my,`r,`x,`y):
            return mkBSP(#<lambda(tuple(cache,ms,ys,peer),
                                  tuple(BAG(),
                                        setCache(cache,o_,
                                                 cmap(lambda(x,
                                                             cmap(lambda(y,apply(`r,tuple(x,y))),
                                                                  cmap(`my,ys))),
                                                      cmap(`mx,i_)),
                                                 tuple()),
                                        TRUE()))>,
                         #<Collect(`(mr2bsp(y)))>,
                         #<false>,
                         mr2bsp(x));
        case CrossAggregateProduct(`mx,`my,`r,`acc,`zero,`x,`y):
            return mkBSP(#<lambda(tuple(cache,ms,ys,peer),
                                  tuple(BAG(),
                                        setCache(cache,o_,
                                                 aggregate(`acc,`zero,
                                                           cmap(lambda(x,
                                                                       cmap(lambda(y,apply(`r,tuple(x,y))),
                                                                            cmap(`my,ys))),
                                                                cmap(`mx,i_)),
                                                           tuple())),
                                        TRUE()))>,
                         #<Collect(`(mr2bsp(y)))>,
                         #<false>,
                         mr2bsp(x));
        case Repeat(lambda(`v,`b),`ds,`max):
            Tree step = bspSimplify(mr2bsp(b));
            int step_cache_num = source_num(step);
            match step {
            case BSP(`n,`s,`k0,_,...as):
                Tree ds_source = mr2bsp(ds);
                int ds_cache_num = source_num(ds_source);
                ds_source = set_cache_num(ds_source,step_cache_num);
                // the initial values of all data sources
                Trees sources = #[`ds_source];
                Tree step_input = #<0>;
                for ( Tree x: as )
                    match x {
                    case BSPSource(`j,`y):
                        if (y.equals(v))
                            step_input = j;
                        else sources = sources.append(x);
                    case _: sources = sources.append(x);
                    };
                s = subst_getCache_num((int)((LongLeaf)step_input).value(),step_cache_num,s);
                return mkBSPL(step_cache_num,
                              #<lambda(tuple(cache,ms,tuple(k,steps),peer),
                                       let(tuple(ts,kk,step_end),
                                           apply(`s,tuple(cache,ms,k,peer)),
                                           if(step_end,   // end of repeat step
                                              setCache(cache,`step_cache_num,
                                                       map(lambda(tuple(x,bb),x),
                                                           getCache(cache,`step_cache_num)),
                                                      tuple(bag(),
                                                            tuple(`k0,callM(plus,`plusM,steps,1)),
                                                            if(callM(gt,`gtM,steps,`max),
                                                               TRUE(),           // if # of steps > limit, exit
                                                               callM(not,`notM,  // ... else check the stopping condition
                                                                     aggregate(lambda(tuple(x,y),callM(or,`orM,x,y)),
                                                                               false,
                                                                               map(lambda(tuple(x,bb),bb),
                                                                                   getCache(cache,`step_cache_num))))))),
                                              tuple(ts,tuple(kk,steps),FALSE()))))>,
                              #<tuple(`k0,2)>,
                              #<false>,
                              sources);
            case `x: throw new Error("Cannot compile the repeat function: "+x);
            }
        // when the repeat variable is in memory
        case repeat(lambda(`v,`b),`ds,`max):
            if (!Config.hadoop_mode)
                fail;
            Tree step = bspSimplify(mr2bsp(b));
            int step_cache_num = source_num(step);
            step = subst(v,#<map(lambda(tuple(x,b),x),getCache(cache,`step_cache_num))>,step);
            match step {
            case BSP(`n,`s,`k0,_,...as):
                // the initial values of all data sources
                Trees sources = #[];
                for ( Tree x: as )
                    if (!x.equals(v))
                        sources = sources.append(x);
                Tree res = mkBSPL(#<lambda(tuple(cache,ms,tuple(k,steps,firstp,S),peer),
                                      let(ignore,if(firstp,setCache(cache,`step_cache_num,S,0),0),
                                          let(tuple(ts,kk,step_end),
                                              apply(`s,tuple(cache,ms,k,peer)),
                                              if(step_end,   // end of repeat step
                                                 setCache(cache,o_,map(lambda(tuple(x,b),x),getCache(cache,`step_cache_num)),
                                                    setCache(cache,`step_cache_num,
                                                             distribute(peer,getCache(cache,`step_cache_num)),
                                                             tuple(bag(),
                                                                   tuple(`k0,callM(plus,`plusM,steps,1),false,bag()),
                                                                   if(callM(gt,`gtM,steps,`max),
                                                                      TRUE(),           // if # of steps > limit, exit
                                                                      callM(not,`notM,  // ... else check the stopping condition
                                                                            aggregate(lambda(tuple(x,y),callM(or,`orM,x,y)),
                                                                                      false,
                                                                                      map(lambda(tuple(x,bb),bb),
                                                                                          getCache(cache,`step_cache_num)))))))),
                                                 tuple(ts,tuple(kk,steps,false,bag()),FALSE())))))>,
                                  #<tuple(`k0,2,true,map(lambda(x,tuple(x,false)),`ds))>,
                                  #<false>,
                                  sources);
                return #<Collect(`res)>;   // must return a memory bag
            case `x: fail
            }
        case Closure(lambda(`v,`b),`ds,`max):
            Tree step = bspSimplify(mr2bsp(b));
            int step_cache_num = source_num(step);
            match step {
            case BSP(`n,`s,`k0,_,...as):
                Tree ds_source = mr2bsp(ds);
                int ds_cache_num = source_num(ds_source);
                ds_source = set_cache_num(ds_source,step_cache_num);
                // the initial values of all data sources
                Trees sources = #[`ds_source];
                Tree step_input = #<0>;
                for ( Tree x: as )
                    match x {
                    case BSPSource(`j,`y):
                        if (y.equals(v))
                            step_input = j;
                        else sources = sources.append(x);
                    case _: sources = sources.append(x);
                    };
                s = subst_getCache_num((int)((LongLeaf)step_input).value(),step_cache_num,s);
                return mkBSPL(step_cache_num,
                              #<lambda(tuple(cache,ms,tuple(k,steps,len),peer),
                                       let(tuple(ts,kk,step_end),
                                           apply(`s,tuple(cache,ms,k,peer)),
                                           if(step_end,   // end of repeat step
                                              let(newLen,callM(count,`countM,getCache(cache,`step_cache_num)),
                                                  tuple(bag(),
                                                        tuple(`k0,callM(plus,`plusM,steps,1),newLen),
                                                        if(callM(gt,`gtM,steps,`max),
                                                           TRUE(),           // if # of steps > limit, exit
                                                           // ... else check if the new size is the same as the old size
                                                           callM(geq,`geqM,len,newLen)))),
                                              tuple(ts,tuple(kk,steps,len),FALSE()))))>,
                              #<tuple(`k0,1,callM(coerce,`coerceM,0,4))>,
                              #<false>,
                              sources);
            case `x: throw new Error("Cannot compile the closure function: "+x);
            }
        case Loop(lambda(tuple(...vs),tuple(...bs)),tuple(...ss),`max):
            Tree[] steps = new Tree[vs.length()];
            Tree[] inits = new Tree[vs.length()];
            int[] cache_num = new int[vs.length()];
            Tree[] k = new Tree[vs.length()];
            Trees sources = #[];
            Trees all_cache = #[];
            boolean dont_fuse = false;
            for ( int i = 0; i < vs.length(); i++ ) {
                inits[i] = mr2bsp(ss.nth(i));
                sources = sources.append(inits[i]);
                cache_num[i] = source_num(inits[i]);
                all_cache = all_cache.append(#<getCache(cache,`(cache_num[i]))>);
            };
            for ( int i = 0; i < vs.length(); i++ ) {
                Tree ee = bspSimplify(mr2bsp(bs.nth(i)));
                match ee {
                case BSP(`n,`s,`k0,_,...as):
                    steps[i] = subst_setCache_num(source_num(ee),cache_num[i],s);
                    k[i] = k0;
                    loop: for ( Tree x: as )
                        match x {
                        case BSPSource(`m,`w):
                            if (!vs.member(w))
                                fail;
                            for ( int j = 0; j < vs.length(); j++ )
                                if (w.equals(vs.nth(j)))
                                    steps[i] = subst_getCache_num(source_num(x),cache_num[j],steps[i]);
                        case BSPSource(`n1,`d1):
                            for ( Tree y: sources )
                                match y {
                                case BSPSource(`n2,`d2):
                                    if (d1.equals(d2)) {
                                        steps[i] = subst_getCache_num(source_num(x),source_num(y),steps[i]);
                                        continue loop;
                                    }
                                };
                            sources = sources.append(x);
                        case _: sources = sources.append(x);
                        };
                case _: dont_fuse = true;
                }
            };
            if (dont_fuse)
                fail;
            Tree code = #<tuple(BAG(),tuple(`(k[0]),0,1),TRUE())>;
            for ( int i = 0; i < vs.length(); i++ )
                code = #<if(callM(eq,`eqM,i,`i),
                            let(tuple(ts,kk,step_end),
                                apply(`(steps[i]),tuple(cache,ms,k,peer)),
                                if(step_end,   // end of repeat step
                                   `((i+1 < vs.length())
                                     ? #<tuple(bag(),tuple(`(k[i+1]),`(i+1),steps),FALSE())>
                                     : #<tuple(bag(),
                                               tuple(`(k[0]),0,callM(plus,`plusM,steps,1)),
                                               if(callM(gt,`gtM,steps,`max),      // if # of steps > limit, exit
                                                  TRUE(),
                                                  FALSE()))>),
                                   tuple(ts,tuple(kk,i,steps),FALSE()))),
                            `code)>;
            return mkBSPL(cache_num,
                          #<lambda(tuple(cache,ms,tuple(k,i,steps),peer),`code)>,
                          #<tuple(`(k[0]),0,2)>,
                          #<false>,
                          sources);
        case `f(...as):
            Trees bs = #[];
            for ( Tree a: as )
                bs = bs.append(mr2bsp(a));
            return #<`f(...bs)>;
        };
        return e;
    }

    /** Simplify the BSP plan by fusing consecutive BSP plans.
     * When a BSP node has another BSP node among its inputs, the two are replaced
     * by a single BSP node whose state is extended with a boolean flag that tells
     * which of the two original superstep functions to run. The result is simplified
     * again, so chains of nested BSP nodes are handled by repeated fusion.
     * Any other node is rebuilt with its children simplified.
     * @param e the plan to simplify
     * @return the simplified plan
     */
    private static Tree bspSimplify ( Tree e ) {
        match e {
        // an outer BSP node (s2,k2,o) that has an inner BSP node (s1,k1) among its inputs;
        // r and t are the outer inputs before and after the inner BSP, s are the inner inputs
        case BSP(`n,`s2,`k2,`o,...r,BSP(_,`s1,`k1,_,...s),...t):
            // all the outer inputs except the inner BSP
            Trees ys = r.append(t);
            // the inner inputs that are kept as inputs of the fused BSP
            Trees ns = #[];
            loop: for ( Tree x: s )
                match x {
                case BSPSource(`n1,`d1):
                    // if an outer input is a BSPSource over an equal d, drop the inner one
                    // and rewrite s1 with the two source numbers
                    // (presumably redirecting its cache reads to the outer source --
                    //  confirm against subst_getCache_num)
                    for ( Tree y: ys )
                        match y {
                        case BSPSource(`n2,`d2):
                            if (d1.equals(d2)) {
                                s1 = subst_getCache_num(source_num(x),source_num(y),s1);
                                continue loop;
                            }
                        };
                    // no matching outer source: keep x
                    ns = ns.append(x);
                case _: ns = ns.append(x);
                };
            // convert n of the outer BSP into an int array: n is either a node whose
            // children are numeric leaves or a single numeric leaf
            int[] nvs = null;
            match n {
            case `f(...as):
                nvs = new int[as.length()];
                for ( int i = 0; i < as.length(); i++ )
                    nvs[i] = (int)((LongLeaf)as.nth(i)).value();
            case _:
                nvs = new int[1];
                nvs[0] = (int)((LongLeaf)n).value();
            };
            // The fused state is tuple(first,k) and starts as tuple(true,k1).
            // While first is true, s1 is applied; when synchronize(peer,b) yields true,
            // the state switches to (false,k2), otherwise it stays in s1 with state kk.
            // When first is false, s2 is applied and its third result bb is returned as is.
            // The fused node is simplified again in case more BSP nodes remain to be fused.
            return bspSimplify(mkBSPL(nvs,
                                      #<lambda(tuple(cache,ms,tuple(first,k),peer),
                                               if(first,
                                                  let(tuple(ts,kk,b),apply(`s1,tuple(cache,ms,k,peer)),
                                                      let(exit,
                                                          synchronize(peer,b),   // poll all peers: do you want to exit?
                                                          // all peers must agree to exit the inner BSP
                                                          //    and proceed to the outer BSP
                                                          tuple(ts,tuple(callM(not,`notM,exit),
                                                                         if(exit,`k2,kk)),FALSE()))),
                                                  let(tuple(ts,kk,bb),apply(`s2,tuple(cache,ms,k,peer)),
                                                      tuple(ts,tuple(false,kk),bb))))>,
                                      #<tuple(true,`k1)>,
                                      o,
                                      #[...r,...ns,...t]));
        // any other node: simplify its children and rebuild it under the same name
        case `f(...as):
            Trees bs = #[];
            for ( Tree a: as )
                bs = bs.append(bspSimplify(a));
            return #<`f(...bs)>;
        };
        // nothing matched: return e unchanged
        return e;
    }

    /** Rewrite the cache operations of a BSP plan:
     * setCache(cache,a,v,ret) becomes setNth(cache,a,materialize(v),ret), which is then
     * rewritten recursively; getCache(cache,a,x1,...,xn) becomes nth(cache,a) combined
     * through the union method with nth(cache,xi) for each extra xi.
     * Any other node is rebuilt from its rewritten children.
     * @param e the plan to rewrite
     * @return the rewritten plan
     */
    private static Tree post_simplify_plan ( Tree e ) {
        match e {
        case setCache(`store,`slot,`value,`result):
            // materialize the value before storing it, then rewrite the new node
            Tree stored = #<setNth(`store,`slot,materialize(`value),`result)>;
            return post_simplify_plan(stored);
        case getCache(`store,`slot,...others):
            // read the first slot and union it with every additional slot
            Tree merged = #<nth(`store,`slot)>;
            for ( Tree other: others )
                merged = #<callM(union,`unionM,`merged,nth(`store,`other))>;
            return merged;
        case `op(...args):
            // rebuild the node from its rewritten children
            Trees rewritten = #[];
            for ( Tree arg: args )
                rewritten = rewritten.append(post_simplify_plan(arg));
            return #<`op(...rewritten)>;
        };
        return e;
    }

    /** Construct the BSP plan of a physical plan and simplify it.
     * The plan goes through preprocess, mr2bsp, bspSimplify and post_simplify_plan,
     * in that order.
     * @param plan the physical plan
     * @return the BSP plan
     */
    public static Tree constructBSPplan ( Tree plan ) {
        Tree prepared = preprocess(plan);
        Tree translated = mr2bsp(prepared);
        Tree fused = bspSimplify(translated);
        return post_simplify_plan(fused);
    }
}
