blob: 0c49a3f7862314de76ca0649a48c8629ed167b57 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Automatically generated by addcopyright.py at 01/29/2013
'''
Created on Jan 2, 2013
@author: frank
'''
import cherrypy
import sglib
import xmlobject
import types
import uuid
import os.path
import sys
import os
class SGRule(object):
    """Mutable record describing one security group rule.

    Fields start out empty and are filled in by the request parser.
    """

    def __init__(self):
        # Protocol name and port bounds; None until parsed. They are later
        # string-joined into ipset/iptables arguments, so they hold strings.
        self.protocol, self.start_port, self.end_port = None, None, None
        # Addresses/CIDRs permitted by this rule (may include '0.0.0.0/0').
        self.allowed_ips = []
class IPSet(object):
    """An ipset of type ``IPSET_TYPE`` named ``setname`` whose members are ``ips``."""

    IPSET_TYPE = 'hash:ip'

    def __init__(self, setname, ips):
        self.ips = ips
        self.name = setname

    def create(self):
        """(Re)populate ipset ``self.name`` with ``self.ips``.

        Members are loaded into a uniquely named temporary set. Only after
        every member was added is the temporary set swapped with the target
        (which is created first if it does not exist yet), so a failure while
        loading leaves the live set untouched. The temporary set is always
        flushed and destroyed afterwards.

        Raises whatever ``sglib.ShellCmd`` raises if creating the temporary
        set, adding a member, or swapping fails.
        """
        # Random name, truncated to stay within ipset's set-name length limit.
        tmpname = str(uuid.uuid4()).replace('-', '')[0:30]
        sglib.ShellCmd('ipset -N %s %s' % (tmpname, self.IPSET_TYPE))()
        try:
            for ip in self.ips:
                sglib.ShellCmd('ipset -A %s %s' % (tmpname, ip))()
            try:
                sglib.ShellCmd('ipset -N %s %s' % (self.name, self.IPSET_TYPE))()
                cherrypy.log('created new ipset: %s' % self.name)
            except Exception:
                cherrypy.log('%s already exists, no need to create new' % self.name)
            # Swap only once the temporary set is complete.
            sglib.ShellCmd('ipset -W %s %s' % (tmpname, self.name))()
        finally:
            # After a swap tmpname holds the old members; before it, the
            # partial new ones. Either way it must go. A cleanup failure is
            # logged rather than raised so it cannot mask the original error;
            # a leftover temp set is swept later by destroy_sets().
            try:
                sglib.ShellCmd('ipset -F %s' % tmpname)()
                sglib.ShellCmd('ipset -X %s' % tmpname)()
            except Exception:
                cherrypy.log('failed to destroy temporary ipset: %s' % tmpname)

    @staticmethod
    def _parse_set_names(listing):
        """Return the non-empty set names from ``ipset list`` output, in order."""
        names = []
        for line in listing.split('\n'):
            line = line.strip()
            if line.startswith('Name:'):
                name = line.split(':', 1)[1].strip()
                # Never yield '' -- a bare 'ipset destroy' destroys every set.
                if name:
                    names.append(name)
        return names

    @staticmethod
    def destroy_sets(sets_to_keep):
        """Destroy every ipset on the host whose name is not in ``sets_to_keep``.

        Best effort: a set that cannot be destroyed (ipset refuses to destroy
        a set that is still referenced, e.g. by iptables rules outside the
        chains just rebuilt) is logged and skipped, and the sweep continues.
        """
        for set_name in IPSet._parse_set_names(sglib.ShellCmd('ipset list')()):
            if set_name in sets_to_keep:
                continue
            try:
                sglib.ShellCmd('ipset destroy %s' % set_name)()
            except Exception:
                cherrypy.log('failed to destroy ipset %s, leaving it in place' % set_name)
                continue
            cherrypy.log('destroyed unused ipset: %s' % set_name)
class SGAgent(object):
    """CherryPy handler that applies per-VM security group rules via iptables/ipset.

    ``index`` dispatches each request to the method named by its ``command``
    header; only names listed in ``_COMMANDS`` may be invoked that way.
    """

    # Methods callable remotely through the 'command' request header. Lifecycle
    # methods (start/stop), private helpers and dunders are refused.
    _COMMANDS = ('set_rules', 'echo')

    def __init__(self):
        pass

    def _self_list(self, obj):
        """Return ``obj`` if it is a list, otherwise ``[obj]``.

        Presumably xmlobject yields a list for repeated elements and a bare
        node for a single one; this normalizes both to a list.
        """
        if isinstance(obj, list):
            return obj
        return [obj]

    def _parse_rules(self, doc, tag):
        """Return the SGRule objects described by ``doc.<tag>`` ([] if absent)."""
        rules = []
        if not hasattr(doc, tag):
            return rules
        for node in self._self_list(getattr(doc, tag)):
            r = SGRule()
            r.protocol = node.protocol.text_
            r.start_port = node.startPort.text_
            r.end_port = node.endPort.text_
            if hasattr(node, 'ip'):
                r.allowed_ips = [ip.text_ for ip in self._self_list(node.ip)]
            rules.append(r)
        return rules

    @staticmethod
    def _create_chain(name):
        """Empty iptables chain ``name``, creating it if flushing fails."""
        try:
            sglib.ShellCmd('iptables -F %s' % name)()
        except Exception:
            # Presumably the flush failed because the chain does not exist yet.
            sglib.ShellCmd('iptables -N %s' % name)()

    @staticmethod
    def _rule_cmd(chainname, protocol, start_port, end_port, direction, action, setname=None):
        """Build the ``iptables -I`` command string for one rule.

        :param chainname: chain to insert into.
        :param protocol: 'all', 'icmp' or a port-based protocol such as 'tcp'.
        :param start_port: first port; for ICMP the type ('-1' means any).
        :param end_port: last port; for ICMP the code.
        :param direction: 'src' or 'dst' flag for the ipset match.
        :param action: iptables target, e.g. 'ACCEPT' or 'RETURN'.
        :param setname: ipset to match addresses against; None matches any address.
        :returns: the command as a single space-joined string.
        """
        if protocol == 'icmp':
            icmp_type = 'any' if start_port == '-1' else '/'.join([start_port, end_port])
            proto_match = ['-p', 'icmp', '--icmp-type', icmp_type]
            state_match = []
        else:
            state_match = ['-m', 'state', '--state', 'NEW']
            if protocol == 'all':
                proto_match = []
            else:
                port_range = ':'.join([start_port, end_port])
                proto_match = ['-p', protocol, '-m', protocol, '--dport', port_range]
        addr_match = ['-m', 'set', '--set', setname, direction] if setname else []
        cmd = ['iptables', '-I', chainname] + proto_match + state_match + addr_match + ['-j', action]
        return ' '.join(cmd)

    def _apply_rules(self, rules, chainname, direction, action, current_set_names):
        """Rebuild chain ``chainname`` from ``rules``.

        Each rule's specific addresses go into an ipset (its name is appended
        to ``current_set_names``); '0.0.0.0/0' becomes an address-less rule.
        """
        self._create_chain(chainname)
        for r in rules:
            allow_any = '0.0.0.0/0' in r.allowed_ips
            ips = [ip for ip in r.allowed_ips if ip != '0.0.0.0/0']
            if ips:
                setname = '_'.join([chainname, r.protocol, r.start_port, r.end_port])
                IPSet(setname, ips).create()
                current_set_names.append(setname)
                sglib.ShellCmd(self._rule_cmd(chainname, r.protocol, r.start_port,
                                              r.end_port, direction, action, setname))()
            if allow_any:
                sglib.ShellCmd(self._rule_cmd(chainname, r.protocol, r.start_port,
                                              r.end_port, direction, action))()

    def set_rules(self, req):
        """Apply the ingress/egress rules carried in the XML body of ``req``.

        Rebuilds the '<vmName>-in' and '<vmName>-eg' chains, appends their
        default actions, then destroys ipsets not used by these chains.
        """
        doc = xmlobject.loads(req.body)
        vm_name = doc.vmName.text_
        # Currently unused here; read so the request is parsed as before.
        vm_id = doc.vmId.text_
        vm_ip = doc.vmIp.text_
        vm_mac = doc.vmMac.text_
        sig = doc.signature.text_
        seq = doc.sequenceNumber.text_

        i_rules = self._parse_rules(doc, 'ingressRules')
        e_rules = self._parse_rules(doc, 'egressRules')

        current_sets = []
        i_chain_name = vm_name + '-in'
        self._apply_rules(i_rules, i_chain_name, 'src', 'ACCEPT', current_sets)
        e_chain_name = vm_name + '-eg'
        self._apply_rules(e_rules, e_chain_name, 'dst', 'RETURN', current_sets)

        # Egress: when rules exist, only what they RETURN may pass, so the
        # rest is dropped (a RETURN default would make the rules no-ops);
        # with no egress rules, egress is left unrestricted.
        if e_rules:
            sglib.ShellCmd('iptables -A %s -j DROP' % e_chain_name)()
        else:
            sglib.ShellCmd('iptables -A %s -j RETURN' % e_chain_name)()
        # Ingress: anything not ACCEPTed above is dropped.
        sglib.ShellCmd('iptables -A %s -j DROP' % i_chain_name)()
        IPSet.destroy_sets(current_sets)

    def echo(self, req):
        """Liveness probe: log a message and return nothing."""
        cherrypy.log("echo: I am alive")

    def index(self):
        """Dispatch the current CherryPy request to its 'command' handler.

        :raises ValueError: if the command is not an allowed handler.
        """
        req = sglib.Request.from_cherrypy_request(cherrypy.request)
        cmd_name = req.headers['command']
        # The header is untrusted input: never let it reach arbitrary attributes.
        if cmd_name not in self._COMMANDS or not hasattr(self, cmd_name):
            raise ValueError("SecurityGroupAgent doesn't have a method called '%s'" % cmd_name)
        method = getattr(self, cmd_name)
        return method(req)
    index.exposed = True

    @staticmethod
    def start():
        """Configure logging/listen address and serve an SGAgent (blocks)."""
        cherrypy.log.access_file = '/var/log/cs-securitygroup.log'
        cherrypy.log.error_file = '/var/log/cs-securitygroup.log'
        cherrypy.server.socket_host = '0.0.0.0'
        cherrypy.server.socket_port = 9988
        cherrypy.quickstart(SGAgent())

    @staticmethod
    def stop():
        """Shut down the CherryPy engine."""
        cherrypy.engine.exit()
# PID file path handed to sglib.Daemon.
PID_FILE = '/var/run/cssgagent.pid'


class SGAgentDaemon(sglib.Daemon):
    """sglib.Daemon wrapper that runs one SGAgent."""

    def __init__(self):
        super(SGAgentDaemon, self).__init__(PID_FILE)
        self.agent = SGAgent()
        self.is_stopped = False
        # Presumably invoked at process exit; stops the agent's CherryPy engine.
        sglib.Daemon.register_atexit_hook(self._do_stop)

    def _do_stop(self):
        """Stop the agent the first time this is called; later calls do nothing."""
        if not self.is_stopped:
            self.is_stopped = True
            self.agent.stop()

    def run(self):
        """Daemon body: serve requests until the agent stops."""
        self.agent.start()

    def stop(self):
        """Stop the agent, then let sglib.Daemon perform its own stop."""
        self.agent.stop()
        super(SGAgentDaemon, self).stop()
def main():
    """Command-line entry point: start, stop or restart the agent daemon.

    Prints usage and exits with status 1 on a bad argument; otherwise runs
    the requested daemon action and exits with status 0.
    """
    usage = 'usage: python -c "from security_group_agent import cs_sg_agent; cs_sg_agent.main()" start|stop|restart'
    args = sys.argv[1:]
    if len(args) != 1 or args[0] not in ('start', 'stop', 'restart'):
        print(usage)
        sys.exit(1)
    action = args[0]
    daemon = SGAgentDaemon()
    # The action names match the daemon's method names one-to-one.
    getattr(daemon, action)()
    sys.exit(0)