# blob: 63110ac5f51aa4c24fea0f530419522fc204b3f0 [file] [log] [blame]
# Copyright 2022 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
import time
import casbin
import etcd3
class ETCDWatcher(object):
    """
    Casbin watcher that signals policy changes through a single etcd key.

    Every ``update*`` method writes the current timestamp to ``keyName``;
    ``start_watch`` listens on the same key and passes each put/delete event
    to the registered callback.
    """

    def __init__(self, endpoints, running, callback, key_name):
        """
        :param endpoints: indexable pair; ``endpoints[0]`` is used as the etcd
            host and ``endpoints[1]`` as the port
        :param running: initial running flag; the watch loop stops once it is False
        :param callback: callable invoked with each watched event, or None
        :param key_name: etcd key used as the change signal
        """
        self.client = None
        self.endpoints = endpoints
        self.running = running
        self.callback = callback
        self.keyName = key_name
        # guards reads and writes of self.callback across threads
        self.mutex = threading.Lock()
        self.watch_thread = threading.Thread(target=self.start_watch, daemon=True)
        self.logger = logging.getLogger(__name__)
        # cancel handle of the active watch; set by start_watch, used by close
        self._cancel_watch = None

    def create_client(self):
        """
        creates the etcd client from the configured endpoints
        :return:
        """
        self.client = etcd3.Etcd3Client(host=self.endpoints[0], port=self.endpoints[1])

    def close(self):
        """
        stops the watcher: clears the running flag and cancels the active
        watch (if any) so that the loop in start_watch terminates
        :return:
        """
        self.running = False
        cancel = self._cancel_watch
        if cancel is not None:
            cancel()
        self.logger.info("ETCD watcher closed")

    def set_update_callback(self, callback):
        """
        sets the callback function to be called when the policy is updated
        :param callback: callable taking the etcd event, or None to disable
        :return:
        """
        with self.mutex:
            self.callback = callback

    def update(self):
        """
        update the policy: writes the current timestamp to the watched key
        :return: True once the key has been written
        """
        self.client.put(self.keyName, str(time.time()))
        return True

    def update_for_add_policy(self, section, ptype, *params):
        """
        update for add policy
        :param section: section
        :param ptype: policy type
        :param params: other params
        :return: True if updated
        """
        # lazy %s formatting: non-string arguments no longer raise TypeError
        self.logger.info("Update for add policy: %s %s %s", section, ptype, params)
        return self.update()

    def update_for_remove_policy(self, section, ptype, *params):
        """
        update for remove policy
        :param section: section
        :param ptype: policy type
        :param params: other params
        :return: True if updated
        """
        self.logger.info("Update for remove policy: %s %s %s", section, ptype, params)
        return self.update()

    def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
        """
        update for remove filtered policy
        :param section: section
        :param ptype: policy type
        :param field_index: field index
        :param params: other params
        :return: True if updated
        """
        self.logger.info(
            "Update for remove filtered policy: %s %s %s %s",
            section,
            ptype,
            field_index,
            params,
        )
        return self.update()

    # the annotation is quoted so it is not evaluated when the class is defined
    def update_for_save_policy(self, model: "casbin.Model"):
        """
        update for save policy
        :param model: casbin model; its to_text() output is logged
        :return: True if updated
        """
        self.logger.info("Update for save policy: %s", model.to_text())
        return self.update()

    def update_for_add_policies(self, section, ptype, *params):
        """
        update for add policies
        :param section: section
        :param ptype: policy type
        :param params: other params
        :return: True if updated
        """
        self.logger.info("Update for add policies: %s %s %s", section, ptype, params)
        return self.update()

    def update_for_remove_policies(self, section, ptype, *params):
        """
        update for remove policies
        :param section: section
        :param ptype: policy type
        :param params: other params
        :return: True if updated
        """
        self.logger.info(
            "Update for remove policies: %s %s %s", section, ptype, params
        )
        return self.update()

    def start_watch(self):
        """
        watch loop (target of watch_thread): watches keyName and calls the
        registered callback for every put/delete event until the watch is
        cancelled or the running flag is cleared
        :return:
        """
        events_iterator, cancel = self.client.watch(self.keyName)
        # keep the handle so close() can end the otherwise endless iterator
        self._cancel_watch = cancel
        if not self.running:
            # close() was called before the watch was established
            cancel()
            return
        watched_events = (etcd3.events.PutEvent, etcd3.events.DeleteEvent)
        for event in events_iterator:
            if not self.running:
                break
            if isinstance(event, watched_events):
                # "with" releases the lock even when the callback raises
                with self.mutex:
                    if self.callback is not None:
                        self.callback(event)
def new_watcher(endpoints, keyname):
    """
    builds an ETCDWatcher, connects it to etcd and starts its watch thread
    :param endpoints: passed to ETCDWatcher as its endpoints
    :param keyname: etcd key the watcher writes to and watches
    :return: the started watcher (running flag set, no callback registered)
    """
    watcher = ETCDWatcher(endpoints, True, None, keyname)
    # the client must exist before the watch thread starts using it
    watcher.create_client()
    watcher.watch_thread.start()
    watcher.logger.info("ETCD watcher started")
    return watcher