Update ES backend library to work with all ES versions (2,5,6,7).

git-svn-id: https://svn.apache.org/repos/asf/steve/trunk/pysteve@1862546 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/backends/es.py b/lib/backends/es.py
index 0052dac..29fb574 100644
--- a/lib/backends/es.py
+++ b/lib/backends/es.py
@@ -20,26 +20,146 @@
 import random
 import time
 from lib import constants
+import elasticsearch
+
+
+class SteveESWrapper(object):
+    """
+       Class for rewriting old-style queries to the new ones,
+       where doc_type is an integral part of the DB name
+    """
+    def __init__(self, ES):
+        self.ES = ES
+    
+    def get(self, index, doc_type, id):
+        return self.ES.get(index = index+'_'+doc_type, doc_type = '_doc', id = id)
+    def exists(self, index, doc_type, id):
+        return self.ES.exists(index = index+'_'+doc_type, doc_type = '_doc', id = id)
+    def delete(self, index, doc_type, id):
+        return self.ES.delete(index = index+'_'+doc_type, doc_type = '_doc', id = id)
+    def index(self, index, doc_type, id = None, body = None):
+        return self.ES.index(index = index+'_'+doc_type, doc_type = '_doc', id = id, body = body)
+    def update(self, index, doc_type, id, body):
+        return self.ES.update(index = index+'_'+doc_type, doc_type = '_doc', id = id, body = body)
+    def scroll(self, scroll_id, scroll):
+        return self.ES.scroll(scroll_id = scroll_id, scroll = scroll)
+    def delete_by_query(self, **kwargs):
+        return self.ES.delete_by_query(**kwargs)
+    def search(self, index, doc_type, size = 100, scroll = None, _source_include = None, body = None, q = None, sort = None):
+        if q:
+            body = {
+                "query": {
+                    "query_string": {
+                        "query": q
+                    }
+                }
+            }
+        if sort and body:
+            if '.keyword' not in sort:
+                sort = sort + ".keyword"
+            body['sort'] = [
+                { sort: 'asc'}
+            ]
+        return self.ES.search(
+            index = index+'_'+doc_type,
+            doc_type = '_doc',
+            size = size,
+            scroll = scroll,
+            _source_include = _source_include,
+            body = body
+            )
+    def count(self, index, doc_type = '*', body = None):
+        return self.ES.count(
+            index = index+'_'+doc_type,
+            doc_type = '_doc',
+            body = body
+            )
+
+class SteveESWrapperSeven(object):
+    """
+       Class for rewriting old-style queries to the >= 7.x ones,
+       where doc_type is an integral part of the DB name and NO DOC_TYPE!
+    """
+    def __init__(self, ES):
+        self.ES = ES
+    
+    def get(self, index, doc_type, id):
+        return self.ES.get(index = index+'_'+doc_type, id = id)
+    def exists(self, index, doc_type, id):
+        return self.ES.exists(index = index+'_'+doc_type, id = id)
+    def delete(self, index, doc_type, id):
+        return self.ES.delete(index = index+'_'+doc_type, id = id)
+    def index(self, index, doc_type, id = None, body = None):
+        return self.ES.index(index = index+'_'+doc_type, id = id, body = body)
+    def update(self, index, doc_type, id, body):
+        return self.ES.update(index = index+'_'+doc_type, id = id, body = body)
+    def scroll(self, scroll_id, scroll):
+        return self.ES.scroll(scroll_id = scroll_id, scroll = scroll)
+    def delete_by_query(self, **kwargs):
+        return self.ES.delete_by_query(**kwargs)
+    def search(self, index, doc_type, size = 100, scroll = None, _source_include = None, body = None, q = None, sort = None):
+        if q:
+            body = {
+                "query": {
+                    "query_string": {
+                        "query": q
+                    }
+                }
+            }
+        if sort and body:
+            if '.keyword' not in sort:
+                sort = sort + ".keyword"
+            body['sort'] = [
+                { sort: 'asc'}
+            ]
+        return self.ES.search(
+            index = index+'_'+doc_type,
+            size = size,
+            scroll = scroll,
+            _source_includes = _source_include,
+            body = body
+            )
+    def count(self, index, doc_type = '*', body = None):
+        return self.ES.count(
+            index = index+'_'+doc_type,
+            body = body
+            )
+    
+
+class SteveDatabase(object):
+    def __init__(self, config):
+        self.config = config
+        self.dbname = config.get('elasticsearch','index')
+        self.ES = elasticsearch.Elasticsearch([{
+                'host': config.get('elasticsearch', 'host'),
+                'port': int(config.get('elasticsearch','port')),
+                'use_ssl': config.get('elasticsearch', 'secure'),
+                'verify_certs': False,
+            }],
+                max_retries=5,
+                retry_on_timeout=True
+            )
+        
+        # IMPORTANT BIT: Figure out if this is ES < 6.x, 6.x or >= 7.x.
+        # If so, we're using the new ES DB mappings, and need to adjust ALL
+        # ES calls to match this.
+        self.ESversion = int(self.ES.info()['version']['number'].split('.')[0])
+        if self.ESversion >= 7:
+            self.ES = SteveESWrapperSeven(self.ES)
+        elif self.ESVersion >= 6:
+            self.ES = SteveESWrapper(self.ES)
 
 class ElasticSearchBackend:
     es = None
     
     def __init__(self, config):
         " Init - get config and turn it into an ES instance"
-        from elasticsearch import Elasticsearch
         self.index = config.get("elasticsearch", "index") if config.has_option("elasticsearch", "index") else "steve"
-        self.es = Elasticsearch([
-                        {
-                            'host': config.get("elasticsearch", "host"),
-                            'port': int(config.get("elasticsearch", "port")),
-                            'url_prefix': config.get("elasticsearch", "uri"),
-                            'use_ssl': False if config.get("elasticsearch", "secure") == "false" else True
-                        },
-                    ])
+        self.DB = SteveDatabase(config)
         
         # Check that we have a 'steve' index. If not, create it.
-        if not self.es.indices.exists(self.index):
-            self.es.indices.create(index = self.index, body = {
+        if not self.DB.ES.ES.indices.exists(self.index):
+            self.DB.ES.ES.indices.create(index = self.index, body = {
                     "settings": {
                         "number_of_shards" : 3,
                         "number_of_replicas" : 1
@@ -56,13 +176,13 @@
         if issue and issue[0]:
             doc = "issues"
             eid = hashlib.sha224(election + "/" + issue[0]).hexdigest()
-        return self.es.exists(index=self.index, doc_type=doc, id=eid)
+        return self.DB.ES.exists(index=self.index, doc_type=doc, id=eid)
         
         
     
     def get_basedata(self, election):
         "Get base data from an election"
-        res = self.es.get(index=self.index, doc_type="elections", id=election)
+        res = self.DB.ES.get(index=self.index, doc_type="elections", id=election)
         if res:
             return res['_source']
         return None
@@ -75,7 +195,7 @@
             basedata['closed'] = False
         else:
             basedata['closed'] = True
-        self.es.index(index=self.index, doc_type="elections", id=election, body = basedata )
+        self.DB.ES.index(index=self.index, doc_type="elections", id=election, body = basedata )
             
     
     def issue_get(self, electionID, issueID):
@@ -83,7 +203,7 @@
         issuedata = None
         ihash = ""
         iid = hashlib.sha224(electionID + "/" + issueID).hexdigest()
-        res = self.es.get(index=self.index, doc_type="issues", id=iid)
+        res = self.DB.ES.get(index=self.index, doc_type="issues", id=iid)
         if res:
             issuedata = res['_source']
             ihash = hashlib.sha224(json.dumps(issuedata)).hexdigest()
@@ -92,7 +212,7 @@
     
     def votes_get(self, electionID, issueID):
         "Read votes and return as a dict"
-        res = self.es.search(index=self.index, doc_type="votes", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
+        res = self.DB.ES.search(index=self.index, doc_type="votes", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
         results = len(res['hits']['hits'])
         if results > 0:
             votes = {}
@@ -105,7 +225,7 @@
     
     def votes_get_raw(self, electionID, issueID):
         "Read votes and return raw format"
-        res = self.es.search(index=self.index, doc_type="votes", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
+        res = self.DB.ES.search(index=self.index, doc_type="votes", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
         results = len(res['hits']['hits'])
         if results > 0:
             votes = []
@@ -116,7 +236,7 @@
     
     def vote_history(self, electionID, issueID):
         "Read vote history and return raw format"
-        res = self.es.search(index=self.index, doc_type="vote_history", sort = "data.timestamp", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
+        res = self.DB.ES.search(index=self.index, doc_type="vote_history", sort = "data.timestamp", q = "election:%s AND issue:%s" % (electionID, issueID), size = 9999)
         results = len(res['hits']['hits'])
         if results > 0:
             votes = []
@@ -127,25 +247,25 @@
     
     def election_create(self,electionID, basedata):
         "Create a new election"
-        self.es.index(index=self.index, doc_type="elections", id=electionID, body =
+        self.DB.ES.index(index=self.index, doc_type="elections", id=electionID, body =
             basedata
         );
     
     def election_update(self,electionID, basedata):
         "Update an election with new data"
-        self.es.index(index = self.index, doc_type = "elections", id=electionID, body = basedata)
+        self.DB.ES.index(index = self.index, doc_type = "elections", id=electionID, body = basedata)
     
     
     def issue_update(self,electionID, issueID, issueData):
         "Update an issue with new data"
-        self.es.index(index = self.index, doc_type = "issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest(), body = issueData)
+        self.DB.ES.index(index = self.index, doc_type = "issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest(), body = issueData)
     
     
     def issue_list(self, election):
         "List all issues in an election"
         issues = []
         try:
-            res = self.es.search(index=self.index, doc_type="issues", sort = "id", q = "election:%s" % election, size = 999, _source_include = 'id')
+            res = self.DB.ES.search(index=self.index, doc_type="issues", sort = "id", q = "election:%s" % election, size = 999, _source_include = 'id')
             results = len(res['hits']['hits'])
             if results > 0:
                 for entry in res['hits']['hits']:
@@ -158,7 +278,7 @@
         "List all elections"
         elections = []
         try:
-            res = self.es.search(index=self.index, doc_type="elections", sort = "id", q = "*", size = 9999)
+            res = self.DB.ES.search(index=self.index, doc_type="elections", sort = "id", q = "*", size = 9999)
             results = len(res['hits']['hits'])
             if results > 0:
                 for entry in res['hits']['hits']:
@@ -174,7 +294,7 @@
         now = time.time()
         if vhash:
             eid = vhash
-        self.es.index(index=self.index, doc_type="votes", id=eid, body =
+        self.DB.ES.index(index=self.index, doc_type="votes", id=eid, body =
             {
                 'issue': issueID,
                 'election': electionID,
@@ -186,7 +306,7 @@
             }
         );
         # Backlog of changesets
-        self.es.index(index=self.index, doc_type="vote_history", body =
+        self.DB.ES.index(index=self.index, doc_type="vote_history", body =
             {
                 'issue': issueID,
                 'election': electionID,
@@ -201,11 +321,11 @@
         
     def issue_delete(self, electionID, issueID):
         "Deletes an issue if it exists"
-        self.es.delete(index=self.index, doc_type="issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest());
+        self.DB.ES.delete(index=self.index, doc_type="issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest());
         
     def issue_create(self,electionID, issueID, data):
         "Create an issue"
-        self.es.index(index=self.index, doc_type="issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest(), body = data);
+        self.DB.ES.index(index=self.index, doc_type="issues", id=hashlib.sha224(electionID + "/" + issueID).hexdigest(), body = data);
     
     
     
@@ -214,7 +334,7 @@
         
         # First, try the raw hash as an ID
         try:
-            res = self.es.get(index=self.index, doc_type="voters", id=votekey)
+            res = self.DB.ES.get(index=self.index, doc_type="voters", id=votekey)
             if res:
                 return res['_source']['uid']
         except:
@@ -222,7 +342,7 @@
         
         # Now, look for it as hash inside the doc
         try:
-            res = self.es.search(index=self.index, doc_type="voters", q = "election:%s" % electionID, size = 9999)
+            res = self.DB.ES.search(index=self.index, doc_type="voters", q = "election:%s" % electionID, size = 9999)
             results = len(res['hits']['hits'])
             if results > 0:
                 for entry in res['hits']['hits']:
@@ -235,7 +355,7 @@
     def voter_add(self,election, PID, xhash):
         "Add a voter to the DB"
         eid = hashlib.sha224(election + ":" + PID).hexdigest()
-        self.es.index(index=self.index, doc_type="voters", id=eid, body = {
+        self.DB.ES.index(index=self.index, doc_type="voters", id=eid, body = {
             'election': election,
             'hash': xhash,
             'uid': PID
@@ -255,7 +375,7 @@
         for issue in issues:
             vhash = hashlib.sha224(xhash + issue).hexdigest()
             try:
-                self.es.delete(index=self.index, doc_type="votes", id=vhash);
+                self.DB.ES.delete(index=self.index, doc_type="votes", id=vhash);
             except:
                 pass
         return True
@@ -263,13 +383,13 @@
     def voter_remove(self,election, UID):
         "Remove the voter with the given UID"
         votehash = hashlib.sha224(election + ":" + UID).hexdigest()
-        self.es.delete(index=self.index, doc_type="voters", id=votehash);
+        self.DB.ES.delete(index=self.index, doc_type="voters", id=votehash);
     
     def voter_has_voted(self,election, issue, uid):
         "Return true if the voter has voted on this issue, otherwise false"
         eid = hashlib.sha224(election + ":" + issue + ":" + uid).hexdigest()
         try:
-            return self.es.exists(index=self.index, doc_type="votes", id=eid)
+            return self.DB.ES.exists(index=self.index, doc_type="votes", id=eid)
         except:
             return False
 
@@ -279,7 +399,7 @@
         # First, get all elections
         elections = {}
         
-        res = self.es.search(index=self.index, doc_type="elections", sort = "id", q = "*", size = 9999)
+        res = self.DB.ES.search(index=self.index, doc_type="elections", sort = "id", q = "*", size = 9999)
         results = len(res['hits']['hits'])
         if results > 0:
             for entry in res['hits']['hits']:
@@ -292,7 +412,7 @@
                 
         # Then, get all ballots and note whether they still apply or not
         ballots = {}
-        res = self.es.search(index=self.index, doc_type="voters", body = {
+        res = self.DB.ES.search(index=self.index, doc_type="voters", body = {
             "query": {
                 "match": {
                     "uid": UID