﻿/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using System;
using System.IO;
using System.Transactions;
using System.Threading;

using NUnit.Framework;
using Apache.NMS.Test;
using Apache.NMS.ActiveMQ.Commands;
using Apache.NMS.ActiveMQ.Transport;
using System.Data.SqlClient;
using System.Collections;

namespace Apache.NMS.ActiveMQ.Test
{
    // PREREQUISITES to run these tests:
    // - A local instance of SQL Server 2008 running, with a database (e.g. TestDB) that
    //   has a table (e.g. TestTable) with a single column (e.g. TestID) of type INT.
    //   The tests default to using a SQL connection string with a user id of
    //   'user' and the password 'password'.
    // - AMQ Server 5.4.2+
    // - NMS 1.5+
    //
    // IMPORTANT
    // Because recovery can only be performed once per process, these tests cannot
    // be run sequentially in the NUnit GUI or NUnit Console runner. You must run
    // them one at a time from the console, or use a tool such as the ReSharper
    // plugin for Visual Studio.
    //

    public class DtcTransactionsTestSupport : NMSTestSupport
    {
        // Default number of rows / messages used by the parameterless helper overloads.
        protected const int MSG_COUNT = 5;
        // Assigned in SetUp to a GUID-named path under the current directory;
        // nothing in this class creates anything at that path.
        protected string nonExistantPath;
        // Factory used by VerifyBrokerQueueCount and VerifyBrokerHasMessagesInQueue.
        // NOTE(review): never assigned in this class - presumably set by derived fixtures; confirm.
        protected NetTxConnectionFactory dtcFactory;
        
        // Tracer captured in SetUp and reinstated in TearDown.
        private ITrace oldTracer;

        // Connection string used by every database helper in this class.
        protected const string sqlConnectionString =
            "Data Source=localhost;Initial Catalog=TestDB;User ID=user;Password=password";
        // Table and column the database helpers read from and write to.
        protected const string testTable = "TestTable";
        protected const string testColumn = "TestID";
        // Queue used by every broker helper in this class.
        protected const string testQueueName = "TestQueue";
        // Broker URI; it is passed through ReplaceEnvVar before a ConnectionFactory is built from it.
        protected const string connectionURI = "tcpfaulty://${activemqhost}:61616";

        /// <summary>
        /// Per-test initialisation: remembers the currently installed tracer, picks a
        /// fresh GUID-named path under the current directory, runs the base fixture
        /// setup and finally drains the test queue.
        /// </summary>
        [SetUp]
        public override void SetUp()
        {
            oldTracer = Tracer.Trace;

            string uniqueName = Guid.NewGuid().ToString();
            nonExistantPath = Path.Combine(Directory.GetCurrentDirectory(), uniqueName);

            base.SetUp();
            PurgeDestination();
        }

        /// <summary>
        /// Per-test cleanup: deletes the test queue, runs the base fixture teardown and
        /// reinstates the tracer that was active when <see cref="SetUp"/> ran.
        /// </summary>
        /// <remarks>
        /// Each step runs even if an earlier one throws; the exception still propagates
        /// to the test runner once cleanup has finished.
        /// </remarks>
        [TearDown]
        public override void TearDown()
        {
            try
            {
                try
                {
                    DeleteDestination();
                }
                finally
                {
                    // Base cleanup must not be skipped because the broker-side delete failed.
                    base.TearDown();
                }
            }
            finally
            {
                // Always put the original tracer back so a failing teardown cannot
                // leave a test-installed tracer in place for later tests.
                Tracer.Trace = this.oldTracer;
            }
        }

        /// <summary>
        /// Exception listener for test connections; writes the exception message to the debug trace.
        /// </summary>
        /// <param name="ex">The exception being reported.</param>
        protected void OnException(Exception ex)
        {
            string reason = ex.Message;
            Tracer.DebugFormat("Test Driver received Error Notification: {0}", reason);
        }

        #region Database Utility Methods

        /// <summary>
        /// Resets the test table and fills it with <see cref="MSG_COUNT"/> rows whose
        /// single column holds the values 0 .. MSG_COUNT - 1.
        /// </summary>
        protected static void PrepareDatabase()
        {
            using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
            {
                sqlConnection.Open();

                // remove all data from test table
                using (SqlCommand truncateCommand = new SqlCommand(string.Format("TRUNCATE TABLE {0}", testTable), sqlConnection))
                {
                    truncateCommand.ExecuteNonQuery();
                }

                // add some data to test table
                // Identifiers cannot be bound as parameters, so only the table and column
                // names are formatted in; the value is bound and the command is reused.
                using (SqlCommand insertCommand = new SqlCommand(
                    string.Format("INSERT INTO {0} ({1}) VALUES (@id)", testTable, testColumn),
                    sqlConnection))
                {
                    insertCommand.Parameters.AddWithValue("@id", 0);

                    for (int i = 0; i < MSG_COUNT; ++i)
                    {
                        insertCommand.Parameters["@id"].Value = i;
                        insertCommand.ExecuteNonQuery();
                    }
                }

                // No explicit Close(): disposing the connection at the end of the using block closes it.
            }
        }

        /// <summary>
        /// Empties the test table by issuing a TRUNCATE TABLE statement against it.
        /// </summary>
        protected static void PurgeDatabase()
        {
            string truncateSql = string.Format("TRUNCATE TABLE {0}", testTable);

            using (SqlConnection db = new SqlConnection(sqlConnectionString))
            {
                db.Open();

                // wipe every row from the test table
                using (SqlCommand truncate = new SqlCommand(truncateSql, db))
                {
                    truncate.ExecuteNonQuery();
                }

                db.Close();
            }
        }

        /// <summary>
        /// Reads every row of the test table and turns each integer value into a
        /// message body of the form "Hello World &lt;value&gt;".
        /// </summary>
        /// <returns>A list of strings, one per row read.</returns>
        protected static IList ExtractDataSet()
        {
            string selectSql = string.Format("SELECT {0} FROM {1}", testColumn, testTable);
            IList bodies = new ArrayList();

            using (SqlConnection db = new SqlConnection(sqlConnectionString))
            {
                db.Open();

                using (SqlCommand query = new SqlCommand(selectSql, db))
                {
                    using (SqlDataReader rows = query.ExecuteReader())
                    {
                        while (rows.Read())
                        {
                            int id = (int)rows[0];
                            bodies.Add("Hello World " + id);
                        }
                    }
                }
            }

            return bodies;
        }

        /// <summary>
        /// Asserts that the test table contains no rows.
        /// </summary>
        protected static void VerifyDatabaseTableIsEmpty()
        {
            using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
            {
                sqlConnection.Open();

                // Dispose the command deterministically, like the other database helpers do.
                using (SqlCommand sqlCommand = new SqlCommand(
                    string.Format("SELECT COUNT(*) FROM {0}", testTable),
                    sqlConnection))
                {
                    int count = (int)sqlCommand.ExecuteScalar();
                    Assert.AreEqual(0, count, "wrong number of rows in DB");
                }
            }
        }

        /// <summary>
        /// Asserts that the test table contains exactly <see cref="MSG_COUNT"/> rows.
        /// </summary>
        protected static void VerifyDatabaseTableIsFull()
        {
            VerifyDatabaseTableIsFull(MSG_COUNT);
        }

        /// <summary>
        /// Asserts that the test table contains exactly <paramref name="expected"/> rows.
        /// </summary>
        /// <param name="expected">The row count the table must have.</param>
        protected static void VerifyDatabaseTableIsFull(int expected)
        {
            using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
            {
                sqlConnection.Open();

                // Dispose the command deterministically, like the other database helpers do.
                using (SqlCommand sqlCommand = new SqlCommand(
                    string.Format("SELECT COUNT(*) FROM {0}", testTable),
                    sqlConnection))
                {
                    int count = (int)sqlCommand.ExecuteScalar();
                    Assert.AreEqual(expected, count, "wrong number of rows in DB");
                }
            }
        }

        #endregion

        #region Destination Utility Methods

        /// <summary>
        /// Best-effort removal of the test queue from the broker. A failure of the
        /// delete request itself is logged to the debug trace and otherwise ignored.
        /// </summary>
        protected static void DeleteDestination()
        {
            IConnectionFactory factory = new ConnectionFactory(ReplaceEnvVar(connectionURI));

            // Direct cast instead of 'as': a connection of an unexpected type fails here
            // with an InvalidCastException rather than a later NullReferenceException.
            using (Connection connection = (Connection)factory.CreateConnection())
            {
                using (ISession session = connection.CreateSession())
                {
                    IQueue queue = session.GetQueue(testQueueName);
                    try
                    {
                        connection.DeleteDestination(queue);
                    }
                    catch (Exception ex)
                    {
                        // Still best-effort, but leave a trace so a failed cleanup is diagnosable.
                        Tracer.DebugFormat("Unable to delete destination {0}: {1}", testQueueName, ex.Message);
                    }
                }
            }
        }

        /// <summary>
        /// Drains the test queue by consuming messages until a receive with a
        /// 3000 ms timeout returns nothing. Each drained message is written to the debug trace.
        /// </summary>
        protected static void PurgeDestination()
        {
            string brokerUri = ReplaceEnvVar(connectionURI);
            IConnectionFactory factory = new ConnectionFactory(brokerUri);

            using (IConnection connection = factory.CreateConnection())
            {
                connection.Start();

                using (ISession session = connection.CreateSession())
                {
                    using (IMessageConsumer consumer = session.CreateConsumer(session.GetQueue(testQueueName)))
                    {
                        while (true)
                        {
                            IMessage purged = consumer.Receive(TimeSpan.FromMilliseconds(3000));
                            if (purged == null)
                            {
                                break;
                            }

                            Tracer.Debug("Setup Purged Message: " + purged);
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Drains the test queue and then enqueues <see cref="MSG_COUNT"/> persistent text messages.
        /// </summary>
        protected static void PurgeAndFillQueue()
        {
            PurgeAndFillQueue(MSG_COUNT);
        }

        /// <summary>
        /// Drains the test queue and then enqueues <paramref name="msgCount"/> persistent
        /// text messages whose bodies are the decimal strings "0", "1", ...
        /// </summary>
        /// <param name="msgCount">Number of messages to enqueue after the queue has been drained.</param>
        protected static void PurgeAndFillQueue(int msgCount)
        {
            string brokerUri = ReplaceEnvVar(connectionURI);
            IConnectionFactory factory = new ConnectionFactory(brokerUri);

            using (IConnection connection = factory.CreateConnection())
            {
                connection.Start();

                using (ISession session = connection.CreateSession())
                {
                    IQueue queue = session.GetQueue(testQueueName);

                    // drain: keep receiving until a 2000 ms receive returns nothing
                    using (IMessageConsumer drain = session.CreateConsumer(queue))
                    {
                        IMessage drained;
                        do
                        {
                            drained = drain.Receive(TimeSpan.FromMilliseconds(2000));
                        }
                        while (drained != null);
                    }

                    // refill with the requested number of messages
                    using (IMessageProducer producer = session.CreateProducer(queue))
                    {
                        producer.DeliveryMode = MsgDeliveryMode.Persistent;

                        int next = 0;
                        while (next < msgCount)
                        {
                            producer.Send(session.CreateTextMessage(next.ToString()));
                            next++;
                        }
                    }
                }
            }
        }

        #endregion

        #region Broker Queue State Validation Routines

        /// <summary>
        /// Asserts that browsing the test queue over a plain connection yields
        /// exactly <see cref="MSG_COUNT"/> messages.
        /// </summary>
        protected static void VerifyBrokerQueueCountNoRecovery()
        {
            VerifyBrokerQueueCountNoRecovery(MSG_COUNT);
        }

        /// <summary>
        /// Asserts that browsing the test queue over a plain <c>ConnectionFactory</c>
        /// connection yields exactly <paramref name="expectedNumberOfMessages"/> messages.
        /// </summary>
        /// <param name="expectedNumberOfMessages">Number of messages the browser must enumerate.</param>
        protected static void VerifyBrokerQueueCountNoRecovery(int expectedNumberOfMessages)
        {
            string brokerUri = ReplaceEnvVar(connectionURI);
            IConnectionFactory factory = new ConnectionFactory(brokerUri);

            using (IConnection connection = factory.CreateConnection())
            using (ISession session = connection.CreateSession())
            using (IQueueBrowser browser = session.CreateBrowser(session.GetQueue(testQueueName)))
            {
                // the connection is started only once the browser exists
                connection.Start();

                IEnumerator cursor = browser.GetEnumerator();
                int seen = 0;

                for (; cursor.MoveNext(); seen++)
                {
                    Assert.IsNotNull(cursor.Current as IMessage, "message is not in the queue !");
                }

                // the number of browsed messages must equal the expectation
                Assert.AreEqual(expectedNumberOfMessages, seen);
            }
        }

        /// <summary>
        /// Asserts that the test queue holds <see cref="MSG_COUNT"/> messages.
        /// </summary>
        protected void VerifyBrokerQueueCount()
        {
            VerifyBrokerQueueCount(MSG_COUNT, connectionURI);
        }

        /// <summary>
        /// Asserts that the test queue holds <paramref name="expectedCount"/> messages.
        /// </summary>
        /// <param name="expectedCount">Number of messages the browser must enumerate.</param>
        protected void VerifyBrokerQueueCount(int expectedCount)
        {
            VerifyBrokerQueueCount(expectedCount, connectionURI);
        }

        /// <summary>
        /// Asserts that the test queue holds <see cref="MSG_COUNT"/> messages.
        /// </summary>
        /// <param name="connectionUri">Forwarded to the two-argument overload.</param>
        protected void VerifyBrokerQueueCount(string connectionUri)
        {
            VerifyBrokerQueueCount(MSG_COUNT, connectionUri);
        }

        /// <summary>
        /// Browses the test queue through a connection obtained from <c>dtcFactory</c>
        /// and asserts that exactly <paramref name="expectedCount"/> messages are enumerated.
        /// </summary>
        /// <param name="expectedCount">Number of messages the browser must enumerate.</param>
        /// <param name="connectionUri">
        /// Not read by this method; the connection always comes from <c>dtcFactory</c>.
        /// NOTE(review): confirm whether the URI was meant to be honoured.
        /// </param>
        protected void VerifyBrokerQueueCount(int expectedCount, string connectionUri)
        {
            using (INetTxConnection connection = dtcFactory.CreateNetTxConnection())
            using (INetTxSession session = connection.CreateNetTxSession())
            using (IQueueBrowser browser = session.CreateBrowser(session.GetQueue(testQueueName)))
            {
                // the connection is started only once the browser exists
                connection.Start();

                IEnumerator cursor = browser.GetEnumerator();
                int seen = 0;

                for (; cursor.MoveNext(); seen++)
                {
                    Assert.IsNotNull(cursor.Current as IMessage, "message is not in the queue !");
                }

                // the number of browsed messages must equal the expectation
                Assert.AreEqual(expectedCount, seen);
            }
        }

        /// <summary>
        /// Asserts, via <c>VerifyBrokerQueueCountNoRecovery</c>, that the test queue is empty.
        /// </summary>
        protected void VerifyNoMessagesInQueueNoRecovery()
        {
            const int noMessages = 0;
            VerifyBrokerQueueCountNoRecovery(noMessages);
        }

        /// <summary>
        /// Asserts, via <c>VerifyBrokerQueueCount</c>, that the test queue is empty.
        /// </summary>
        protected void VerifyNoMessagesInQueue()
        {
            const int noMessages = 0;
            VerifyBrokerQueueCount(noMessages);
        }

        /// <summary>
        /// Consumes from the test queue over a plain connection and asserts that exactly
        /// <paramref name="expectedNumberOfMessages"/> messages can be received, each
        /// within 2000 ms, and that one further receive then times out.
        /// </summary>
        /// <param name="expectedNumberOfMessages">Number of messages that must be receivable.</param>
        protected static void VerifyBrokerStateNoRecover(int expectedNumberOfMessages)
        {
            string brokerUri = ReplaceEnvVar(connectionURI);
            IConnectionFactory factory = new ConnectionFactory(brokerUri);

            using (IConnection connection = factory.CreateConnection())
            {
                using (ISession session = connection.CreateSession())
                {
                    IDestination queue = session.GetQueue(testQueueName);

                    using (IMessageConsumer consumer = session.CreateConsumer(queue))
                    {
                        connection.Start();

                        // every expected message must arrive
                        int outstanding = expectedNumberOfMessages;
                        while (outstanding > 0)
                        {
                            IMessage received = consumer.Receive(TimeSpan.FromMilliseconds(2000));
                            Assert.IsNotNull(received, "message is not in the queue !");
                            outstanding--;
                        }

                        // and nothing may follow them
                        IMessage surplus = consumer.Receive(TimeSpan.FromMilliseconds(2000));
                        Assert.IsNull(surplus, "message found but not expected !");

                        consumer.Close();
                    }
                }

                connection.Close();
            }
        }

        /// <summary>
        /// Consumes from the test queue through a connection obtained from <c>dtcFactory</c>
        /// and asserts that <see cref="MSG_COUNT"/> messages can be received, each within 2000 ms.
        /// </summary>
        /// <param name="connectionURI">
        /// Not read by this method (and it hides the class constant of the same name);
        /// the connection always comes from <c>dtcFactory</c>.
        /// NOTE(review): confirm whether the URI was meant to be honoured.
        /// </param>
        protected void VerifyBrokerHasMessagesInQueue(string connectionURI)
        {
            using (INetTxConnection connection = dtcFactory.CreateNetTxConnection())
            using (INetTxSession session = connection.CreateNetTxSession())
            {
                IDestination queue = session.GetQueue(testQueueName);

                using (IMessageConsumer consumer = session.CreateConsumer(queue))
                {
                    connection.Start();

                    // every one of the MSG_COUNT messages must arrive
                    int outstanding = MSG_COUNT;
                    while (outstanding > 0)
                    {
                        IMessage received = consumer.Receive(TimeSpan.FromMilliseconds(2000));
                        Assert.IsNotNull(received, "message is not in the queue !");
                        outstanding--;
                    }

                    consumer.Close();
                }
            }
        }


        #endregion

        #region Transport Hooks for controlling failure point.

        /// <summary>
        /// Transport hook that sleeps for one second and then throws when the command is a
        /// <c>TransactionInfo</c> of type <c>TransactionType.Prepare</c>; any other
        /// command is ignored.
        /// </summary>
        /// <param name="transport">The transport the command passes through; not used.</param>
        /// <param name="command">The command to inspect.</param>
        /// <exception cref="Exception">Thrown for a Prepare transaction command.</exception>
        public void FailOnPrepareTransportHook(ITransport transport, Command command)
        {
            if (!(command is TransactionInfo))
            {
                return;
            }

            TransactionInfo txInfo = (TransactionInfo)command;
            if (txInfo.Type != (byte)TransactionType.Prepare)
            {
                return;
            }

            Thread.Sleep(1000);
            Tracer.Debug("Throwing Error on Prepare.");
            throw new Exception("Error writing Prepare command");
        }

        /// <summary>
        /// Transport hook that throws when the command is a <c>TransactionInfo</c> of
        /// type <c>TransactionType.Rollback</c>; any other command is ignored.
        /// </summary>
        /// <param name="transport">The transport the command passes through; not used.</param>
        /// <param name="command">The command to inspect.</param>
        /// <exception cref="Exception">Thrown for a Rollback transaction command.</exception>
        public void FailOnRollbackTransportHook(ITransport transport, Command command)
        {
            if (!(command is TransactionInfo))
            {
                return;
            }

            TransactionInfo txInfo = (TransactionInfo)command;
            if (txInfo.Type != (byte)TransactionType.Rollback)
            {
                return;
            }

            Tracer.Debug("Throwing Error on Rollback.");
            throw new Exception("Error writing Rollback command");
        }

        /// <summary>
        /// Transport hook that throws when the command is a <c>TransactionInfo</c> of
        /// type <c>TransactionType.CommitTwoPhase</c>; any other command is ignored.
        /// </summary>
        /// <param name="transport">The transport the command passes through; not used.</param>
        /// <param name="command">The command to inspect.</param>
        /// <exception cref="Exception">Thrown for a two-phase Commit transaction command.</exception>
        public void FailOnCommitTransportHook(ITransport transport, Command command)
        {
            if (!(command is TransactionInfo))
            {
                return;
            }

            TransactionInfo txInfo = (TransactionInfo)command;
            if (txInfo.Type != (byte)TransactionType.CommitTwoPhase)
            {
                return;
            }

            Tracer.Debug("Throwing Error on Commit.");
            throw new Exception("Error writing Commit command");
        }

        #endregion

        #region Recovery Harness Hooks for controlling failure conditions

        /// <summary>
        /// Recovery hook that unconditionally traces a debug message and then throws;
        /// neither argument is inspected.
        /// </summary>
        /// <param name="xid">Transaction id handed to the hook; not used.</param>
        /// <param name="recoveryInformatio">Recovery data handed to the hook; not used.</param>
        /// <exception cref="Exception">Always thrown.</exception>
        public void FailOnPreLogRecoveryHook(XATransactionId xid, byte[] recoveryInformatio)
        {
            Tracer.Debug("Throwing Error before the Recovery Information is Logged.");

            Exception intentional = new Exception("Intentional Error Logging Recovery Information");
            throw intentional;
        }

        #endregion

        #region Produce Messages use cases

        /// <summary>
        /// Reads the test table via <c>ExtractDataSet</c>, then, inside a new
        /// <c>TransactionScope</c>, sends one text message per entry to the test queue,
        /// deletes all rows from the test table and marks the scope complete.
        /// </summary>
        /// <param name="connection">Connection used to create the session that produces the messages.</param>
        /// <remarks>
        /// Every exception raised inside the try block is caught, written to the debug
        /// trace and not rethrown.
        /// NOTE(review): that catch-all also covers exceptions raised by the Assert calls
        /// inside the try block - confirm that is intended.
        /// </remarks>
        protected static void ReadFromDbAndProduceToQueueWithCommit(INetTxConnection connection)
        {
            // The rows are read before the session and the transaction scope are created.
            IList entries = ExtractDataSet();

            using (INetTxSession session = connection.CreateNetTxSession())
            {
                IQueue queue = session.GetQueue(testQueueName);

                // enqueue several messages read from DB
                try
                {
                    using (IMessageProducer producer = session.CreateProducer(queue))
                    {
                        producer.DeliveryMode = MsgDeliveryMode.Persistent;

                        // RequiresNew: the scope always gets its own transaction.
                        using (TransactionScope scoped = new TransactionScope(TransactionScopeOption.RequiresNew))
                        using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
                        {
                            // The SQL connection is opened while the scope is active.
                            sqlConnection.Open();

                            Assert.IsNotNull(Transaction.Current);

                            Tracer.DebugFormat("Sending {0} messages to Broker in this TX", entries.Count);
                            foreach (string textBody in entries)
                            {
                                producer.Send(session.CreateTextMessage(textBody));
                            }

                            // Remove every row; the number deleted must equal the number read earlier.
                            using (SqlCommand sqlDeleteCommand = new SqlCommand(
                                string.Format("DELETE FROM {0}", testTable), sqlConnection))
                            {
                                int count = sqlDeleteCommand.ExecuteNonQuery();
                                Assert.AreEqual(entries.Count, count, "wrong number of rows deleted");
                            }

                            // Mark the scope as successful before it is disposed.
                            scoped.Complete();
                        }
                    }
                }
                catch (Exception e) // exception thrown in TransactionContext.Commit(Enlistment enlistment)
                {
                    Tracer.Debug("TX;Error from TransactionScope: " + e.Message);
                    Tracer.Debug(e.ToString());
                }
            }
        }

        /// <summary>
        /// Reads the test table via <c>ExtractDataSet</c>, then, inside a new
        /// <c>TransactionScope</c>, sends one text message per entry to the test queue
        /// and deletes all rows from the test table - but leaves the scope without
        /// calling <c>Complete()</c>.
        /// </summary>
        /// <param name="connection">Connection used to create the session that produces the messages.</param>
        /// <remarks>
        /// Every exception raised inside the try block is caught, written to the debug
        /// trace and not rethrown.
        /// NOTE(review): that catch-all also covers exceptions raised by the Assert calls
        /// inside the try block - confirm that is intended.
        /// </remarks>
        protected static void ReadFromDbAndProduceToQueueWithScopeAborted(INetTxConnection connection)
        {
            // The rows are read before the session and the transaction scope are created.
            IList entries = ExtractDataSet();

            using (INetTxSession session = connection.CreateNetTxSession())
            {
                IQueue queue = session.GetQueue(testQueueName);

                // enqueue several messages read from DB
                try
                {
                    using (IMessageProducer producer = session.CreateProducer(queue))
                    {
                        producer.DeliveryMode = MsgDeliveryMode.Persistent;

                        // RequiresNew: the scope always gets its own transaction.
                        using (TransactionScope scoped = new TransactionScope(TransactionScopeOption.RequiresNew))
                        using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
                        {
                            // The SQL connection is opened while the scope is active.
                            sqlConnection.Open();

                            Assert.IsNotNull(Transaction.Current);

                            Tracer.DebugFormat("Sending {0} messages to Broker in this TX", entries.Count);
                            foreach (string textBody in entries)
                            {
                                producer.Send(session.CreateTextMessage(textBody));
                            }

                            // Remove every row; the number deleted must equal the number read earlier.
                            using (SqlCommand sqlDeleteCommand = new SqlCommand(
                                string.Format("DELETE FROM {0}", testTable), sqlConnection))
                            {
                                int count = sqlDeleteCommand.ExecuteNonQuery();
                                Assert.AreEqual(entries.Count, count, "wrong number of rows deleted");
                            }

                            // Deliberately no scoped.Complete(): the scope is disposed uncompleted.
                        }
                    }
                }
                catch (Exception e)
                {
                    Tracer.Debug("TX;Error from TransactionScope: " + e.Message);
                    Tracer.Debug(e.ToString());
                }
            }
        }

        #endregion

        #region Consume Messages Use Cases

        /// <summary>
        /// Inside a new <c>TransactionScope</c>, receives <see cref="MSG_COUNT"/> text
        /// messages from the test queue, inserts the integer value of each message body
        /// into the test table and marks the scope complete.
        /// </summary>
        /// <param name="connection">Connection used to create the session that consumes the messages.</param>
        /// <remarks>
        /// Every exception raised inside the try block is caught, written to the debug
        /// trace and not rethrown.
        /// NOTE(review): that catch-all also covers exceptions raised by the Assert calls
        /// inside the try block - confirm that is intended.
        /// </remarks>
        protected static void ReadFromQueueAndInsertIntoDbWithCommit(INetTxConnection connection)
        {
            using (INetTxSession session = connection.CreateNetTxSession())
            {
                IQueue queue = session.GetQueue(testQueueName);

                // read message from queue and insert into db table
                try
                {
                    // The consumer is created before the transaction scope is opened.
                    using (IMessageConsumer consumer = session.CreateConsumer(queue))
                    {
                        // RequiresNew: the scope always gets its own transaction.
                        using (TransactionScope scoped = new TransactionScope(TransactionScopeOption.RequiresNew))
                        using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
                        using (SqlCommand sqlInsertCommand = new SqlCommand())
                        {
                            // The SQL connection is opened while the scope is active.
                            sqlConnection.Open();
                            sqlInsertCommand.Connection = sqlConnection;

                            Assert.IsNotNull(Transaction.Current);

                            for (int i = 0; i < MSG_COUNT; i++)
                            {
                                // Receive() is called without a timeout.
                                // NOTE(review): presumably blocks if fewer than MSG_COUNT messages are queued - confirm.
                                ITextMessage message = consumer.Receive() as ITextMessage;
                                Assert.IsNotNull(message, "missing message");
                                // One command object is reused; only its text changes per message.
                                sqlInsertCommand.CommandText =
                                    string.Format("INSERT INTO {0} VALUES ({1})", testTable, Convert.ToInt32(message.Text));
                                sqlInsertCommand.ExecuteNonQuery();
                            }

                            // Mark the scope as successful before it is disposed.
                            scoped.Complete();
                        }
                    }
                }
                catch (Exception e)
                {
                    Tracer.Debug("TX;Error from TransactionScope: " + e.Message);
                    Tracer.Debug(e.ToString());
                }
            }
        }

        /// <summary>
        /// Inside a new <c>TransactionScope</c>, receives <see cref="MSG_COUNT"/> text
        /// messages from the test queue and inserts the integer value of each message
        /// body into the test table - but leaves the scope without calling <c>Complete()</c>.
        /// </summary>
        /// <param name="connection">Connection used to create the session that consumes the messages.</param>
        /// <remarks>
        /// Every exception raised inside the try block is caught, written to the debug
        /// trace and not rethrown.
        /// NOTE(review): that catch-all also covers exceptions raised by the Assert calls
        /// inside the try block - confirm that is intended.
        /// </remarks>
        protected static void ReadFromQueueAndInsertIntoDbWithScopeAborted(INetTxConnection connection)
        {
            using (INetTxSession session = connection.CreateNetTxSession())
            {
                IQueue queue = session.GetQueue(testQueueName);

                // read message from queue and insert into db table
                try
                {
                    // The consumer is created before the transaction scope is opened.
                    using (IMessageConsumer consumer = session.CreateConsumer(queue))
                    {
                        // RequiresNew: the scope always gets its own transaction.
                        using (TransactionScope scoped = new TransactionScope(TransactionScopeOption.RequiresNew))
                        using (SqlConnection sqlConnection = new SqlConnection(sqlConnectionString))
                        using (SqlCommand sqlInsertCommand = new SqlCommand())
                        {
                            // The SQL connection is opened while the scope is active.
                            sqlConnection.Open();
                            sqlInsertCommand.Connection = sqlConnection;
                            Assert.IsNotNull(Transaction.Current);

                            for (int i = 0; i < MSG_COUNT; i++)
                            {
                                // Receive() is called without a timeout.
                                // NOTE(review): presumably blocks if fewer than MSG_COUNT messages are queued - confirm.
                                ITextMessage message = consumer.Receive() as ITextMessage;
                                Assert.IsNotNull(message, "missing message");

                                // One command object is reused; only its text changes per message.
                                sqlInsertCommand.CommandText =
                                    string.Format("INSERT INTO {0} VALUES ({1})", testTable, Convert.ToInt32(message.Text));
                                sqlInsertCommand.ExecuteNonQuery();
                            }

                            // Deliberately no scoped.Complete(): the scope is disposed uncompleted.
                        }
                    }
                }
                catch (Exception e)
                {
                    Tracer.Debug("TX;Error from TransactionScope: " + e.Message);
                    Tracer.Debug(e.ToString());
                }
            }
        }

        #endregion
    }
}
