# blob: c753350eaf52cd84d5a9f1e648362d6d7a3c939f [file] [log] [blame]
# -- coding: utf-8 --
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import CsHelper
from .CsDatabag import CsCmdLine
import logging
class CsChain(object):
    """Track the chains known per netfilter table plus a per-chain rule counter.

    Chains are grouped by table name. The rule counter is keyed by chain
    name only (not by table) and is reset to 0 whenever ``add`` is called
    with a chain different from the previously added one.
    """

    def __init__(self):
        self.chain = {}        # table name -> list of chain names, in insertion order
        self.last_added = ''   # name of the most recently added chain
        self.count = {}        # chain name -> number of rules counted so far

    def add(self, table, chain):
        """Register *chain* under *table* and reset its counter if it is a new run."""
        self.chain.setdefault(table, []).append(chain)
        if self.last_added != chain:
            self.last_added = chain
            self.count[chain] = 0

    def add_rule(self, chain):
        """Increment the rule counter of *chain* (starting from 0 if never set).

        A counter may be missing, e.g. for ``add(table, '')`` on a fresh
        instance where ``last_added`` is already ``''``; defaulting avoids a
        KeyError in that case.
        """
        self.count[chain] = self.count.get(chain, 0) + 1

    def get(self, table):
        """Return the list of chains registered for *table* (empty list if none)."""
        return self.chain.get(table, [])

    def get_count(self, chain):
        """Return the current rule counter of *chain* (0 if it has none)."""
        return self.count.get(chain, 0)

    def last(self):
        """Return the most recently added chain name ('' if none)."""
        return self.last_added

    def has_chain(self, table, chain):
        """Return True if *chain* has been registered under *table*."""
        return chain in self.chain.get(table, ())
class CsTable(object):
    """Ordered, duplicate-free record of netfilter table names.

    Also remembers the most recently registered table name so callers that
    process input line by line know which table later lines belong to.
    """

    def __init__(self):
        self.table = []        # table names in the order they were first added
        self.last_added = ''   # most recently registered table name

    def add(self, name):
        """Append *name* if it is not known yet, and make it the latest table."""
        if name in self.table:
            return
        self.table.append(name)
        self.last_added = name

    def get(self):
        """Return the (live) list of registered table names."""
        return self.table

    def last(self):
        """Return the most recently registered table name ('' if none)."""
        return self.last_added
class CsNetfilters(object):
    """In-memory model of the router's netfilter configuration.

    Unless ``load`` is False, the current ruleset is read from ``iptables-save``
    on construction. :meth:`compare` reconciles it with a wanted rule list, and
    :meth:`apply_ip6_rules` programs IPv6 tables/chains/rules through ``nft``.

    NOTE(review): every command below is built by string formatting and handed
    to CsHelper as a single command string; if CsHelper runs it through a shell,
    rule/chain text must not come from untrusted input.
    """

    def __init__(self, load=True):
        self.rules = []          # CsNetfilter objects loaded from iptables-save (see save())
        self.table = CsTable()   # tables seen so far
        self.chain = CsChain()   # chains per table plus per-chain rule counters
        if load:
            self.get_all_rules()

    @staticmethod
    def _command_output(proc):
        """Return whatever a finished subprocess wrote to its pipes, decoded.

        Diagnostics only: undecodable bytes are replaced instead of raising, and
        pipes that were not captured (``None``) are skipped.
        """
        parts = []
        for chunk in proc.communicate():
            if not chunk:
                continue
            if isinstance(chunk, bytes):
                chunk = chunk.decode("utf-8", errors="replace")
            parts.append(chunk.strip())
        return "\n".join(parts)

    def get_all_rules(self):
        """Populate tables, chains and rules from the output of ``iptables-save``."""
        for line in CsHelper.execute("iptables-save"):
            if line.startswith('*'):  # table header
                self.table.add(line[1:])
            if line.startswith(':'):  # chain declaration; the name is the first word
                self.chain.add(self.table.last(), line[1:].split(' ')[0])
            if line.startswith('-A'):  # rule
                chain_name = line.split()[1]
                self.chain.add_rule(chain_name)
                rule = CsNetfilter()
                rule.parse(line)
                rule.set_table(self.table.last())
                rule.set_chain(chain_name)
                # The chain's running counter becomes this rule's count, i.e. its
                # position among the rules of that chain read so far.
                rule.set_count(self.chain.get_count(chain_name))
                self.save(rule)

    def save(self, rule):
        """Add *rule* to the list of known rules."""
        self.rules.append(rule)

    def get(self):
        """Return the (live) list of known rules."""
        return self.rules

    def has_table(self, table):
        """Return True if *table* has been registered."""
        return table in self.table.get()

    def has_chain(self, table, chain):
        """Return True if *chain* is registered under *table*."""
        return self.chain.has_chain(table, chain)

    def has_rule(self, new_rule):
        """Return True if *new_rule* equals a known rule, marking that rule as seen.

        A rule that carries a position (``get_count() > 0``) is never reported as
        present, so compare() always inserts it at that position.
        NOTE(review): consequently a matching rule already on the host is never
        marked seen and get_unseen() issues an ``iptables -D`` for it; confirm
        this delete-and-reinsert behaviour is intended.
        """
        if new_rule.get_count() > 0:
            return False
        for existing in self.get():
            if new_rule == existing:
                existing.mark_seen()
                return True
        return False

    def get_unseen(self):
        """Delete from the host every known rule that was not marked seen."""
        for rule in [r for r in self.rules if r.unseen()]:
            spec = rule.to_str(True)
            cmd = "iptables -t %s %s" % (rule.get_table(), spec)
            logging.debug("unseen cmd: %s ", cmd)
            CsHelper.execute(cmd)
            logging.info("Delete rule %s from table %s", spec, rule.get_table())

    def compare(self, list):
        """Reconcile the host's iptables rules with the wanted rule set.

        :param list: iterable of entries ``(table, position, rule_text)`` where
            ``position`` is ``"front"`` (insert at the head of the chain), an int
            (insert at that position) or anything else (append). The name shadows
            the builtin but is kept for backward compatibility.

        Pass 1 creates every chain the wanted rules reference. Pass 2 inserts
        each wanted rule that is not already present. Finally the standard rules
        are dropped from the known list (del_standard) and every remaining known
        rule that was not matched is deleted from the host (get_unseen).
        """
        # Materialise once: the input is traversed twice below.
        wanted = tuple(list)

        # PASS 1: ensure all chains are present.
        for fw in wanted:
            chain_rule = CsNetfilter()
            chain_rule.parse(fw[2])
            chain_rule.set_table(fw[0])
            self.add_chain(chain_rule)

        # PASS 2: create the missing rules.
        processed = set()
        for fw in wanted:
            fw_key = tuple(fw)
            if fw_key in processed:
                logging.debug("Already processed : %s", fw_key)
                continue
            table, position, rule_text = fw[0], fw[1], fw[2]

            new_rule = CsNetfilter()
            new_rule.parse(rule_text)
            new_rule.set_table(table)
            if isinstance(position, int):
                new_rule.set_count(position)
            rule_chain = new_rule.get_chain()

            logging.debug("Checking if the rule already exists: rule=%s table=%s chain=%s",
                          new_rule.get_rule(), new_rule.get_table(), rule_chain)
            if self.has_rule(new_rule):
                logging.debug("Exists: rule=%s table=%s", rule_text, new_rule.get_table())
                continue

            logging.info("Add: rule=%s table=%s", rule_text, new_rule.get_table())
            cmd = rule_text
            if position == "front":
                # "front" means insert instead of append. Only the first '-A' is the
                # append flag; later text must not be rewritten.
                cmd = cmd.replace('-A', '-I', 1)
            if isinstance(position, int):
                if rule_chain.startswith(("ACL_INBOUND", "ACL_OUTBOUND")):
                    # ACL rules are inserted in order, right before the DROP-all:
                    # use the chain's current rule counter (at least 1).
                    index = max(self.chain.get_count(rule_chain), 1)
                else:
                    index = position
                cmd = cmd.replace("-A %s" % rule_chain, "-I %s %s" % (rule_chain, index), 1)

            ret = CsHelper.execute2("iptables -t %s %s" % (new_rule.get_table(), cmd))
            # Failures are tolerated on purpose: callers are known to add chains
            # without checking they exist, and some rules are deleted twice
            # (removeFromLoadBalancerRule / deleteLoadBalancerRule).
            # TODO: fix in the framework; for now only log, including what the
            # command printed so the failure can be diagnosed.
            if ret.returncode != 0:
                logging.debug("iptables command got failed ... continuing: %s",
                              self._command_output(ret))
            processed.add(fw_key)
            self.chain.add_rule(rule_chain)
        self.del_standard()
        self.get_unseen()

    def add_chain(self, rule):
        """Create the chain of *rule* on the host if it is not already known."""
        table, chain = rule.get_table(), rule.get_chain()
        if self.has_chain(table, chain):
            return
        if chain:
            CsHelper.execute("iptables -t %s -N %s" % (table, chain))
        # Registered even when empty, so the counter exists for add_rule().
        self.chain.add(table, chain)

    def del_standard(self):
        """Forget the standard rules so get_unseen() does not delete them.

        The standard rules are read from ``/etc/iptables/iptables-<type>``, where
        ``<type>`` comes from ``CsCmdLine("cmdline").get_type()``. If that file
        cannot be read, nothing is removed.
        """
        router_type = CsCmdLine("cmdline").get_type()
        path = "/etc/iptables/iptables-%s" % router_type
        try:
            with open(path) as standard_rules:
                table = ''
                for line in standard_rules:
                    if line.startswith('*'):  # table header
                        table = line[1:].strip()
                    if line.startswith('-A'):  # rule
                        self.del_rule(table, line.strip())
        except IOError:
            # Nothing can be done; keep every known rule.
            logging.debug("Exception in del_standard, returning")

    def del_rule(self, table, rule):
        """Parse rule text *rule* in *table* and forget matching known rules."""
        target = CsNetfilter()
        target.parse(rule)
        target.set_table(table)
        self.delete(target)

    def delete(self, rule):
        """Remove every rule equal to *rule* from the known list (in place).

        The rule is not removed from the host.
        """
        self.rules[:] = [x for x in self.rules if not x == rule]

    def add_ip6_chain(self, address_family, table, chain, hook, action):
        """Create an nft chain, optionally as a base chain with hook and policy.

        With a *hook* the chain is a ``type filter`` chain at priority 0, and
        *action* (if given) becomes its policy. Input/output hooked chains also
        get a rule accepting the listed ICMPv6 types.
        """
        chain_policy = ""
        if hook:
            chain_policy = "type filter hook %s priority 0;" % hook
        if chain_policy and action:
            chain_policy = "%s policy %s;" % (chain_policy, action)
        CsHelper.execute("nft add chain %s %s %s '{ %s }'" % (address_family, table, chain, chain_policy))
        if hook in ("input", "output"):
            CsHelper.execute(
                "nft add rule %s %s %s icmpv6 type { echo-request, echo-reply, "
                "nd-neighbor-solicit, nd-router-advert, nd-neighbor-advert } accept"
                % (address_family, table, chain))

    def apply_ip6_rules(self, rules, type):
        """Create the ip6 table and default chains, then apply *rules* with nft.

        :param rules: sequence of mappings with keys ``chain``, ``type`` and,
            optionally, ``rule``. Entries of type ``"chain"`` create a chain
            (``rule`` is used as its policy); any other type adds ``rule`` to
            ``chain``. Nothing is done when *rules* is empty or None.
        :param type: ``"acl"`` selects the ``ip6_acl`` table; anything else the
            ``ip6_firewall`` table. (The name shadows the builtin but is kept for
            backward compatibility.)
        """
        if not rules:
            return
        address_family = 'ip6'
        if type == "acl":
            table = 'ip6_acl'
            default_chains = (("acl_input", "input", "drop"),
                              ("acl_forward", "forward", "accept"))
        else:
            table = 'ip6_firewall'
            default_chains = (("fw_input", "input", "drop"),
                              ("fw_forward", "forward", "accept"))

        CsHelper.execute("nft add table %s %s" % (address_family, table))
        for chain_name, hook, action in default_chains:
            self.add_ip6_chain(address_family, table, chain_name, hook, action)

        for fw in rules:
            chain = fw['chain']
            rule_type = fw['type']
            rule = fw['rule'] if 'rule' in fw else None
            if rule_type == "chain":
                # Derive the hook from the chain name; "output" is checked first.
                hook = ""
                if "output" in chain:
                    hook = "output"
                elif "input" in chain:
                    hook = "input"
                self.add_ip6_chain(address_family, table, chain, hook, rule)
            else:
                logging.info("Add: rule=%s in address_family=%s table=%s, chain=%s", rule, address_family, table, chain)
                CsHelper.execute("nft add rule %s %s %s %s" % (address_family, table, chain, rule))
class CsNetfilter(object):
    """A single iptables rule, parsed into an option -> value dictionary.

    Two instances compare equal when table, chain and the parsed option
    dictionary are equal, so the same rule written with options in a
    different order still matches.
    """

    def __init__(self):
        self.rule = {}      # option -> value, e.g. {'-A': 'INPUT', '-j': 'ACCEPT'}
        self.table = ''     # table name; set_table('') stores 'filter'
        self.chain = ''     # chain name, taken from '-A <chain>' by parse()
        self.seen = False   # becomes True via mark_seen()
        self.count = 0      # counter/position supplied through set_count()

    def parse(self, rule):
        """Parse the rule text *rule* into ``self.rule``.

        Also sets ``self.chain`` when the text contains an ``-A <chain>`` pair.
        """
        self.rule = self.__convert_to_dict(rule)

    def unseen(self):
        """Return True until mark_seen() has been called."""
        return not self.seen

    def mark_seen(self):
        """Flag this rule as matched."""
        self.seen = True

    def __convert_to_dict(self, rule):
        """Normalise *rule* and turn its ``option value`` tokens into a dict.

        The textual rewrites make equivalent spellings produce the same dict.
        An odd trailing token (an option without value) is dropped.
        """
        rule = rule.lstrip()
        rule = rule.replace('! -', '!_-')      # keep a negation glued to its option
        rule = rule.replace('-p all', '')      # "-p all" is treated as no protocol match
        rule = rule.replace('bootpc', '68')    # service name -> port number
        # Split this flag, otherwise there is an odd number of tokens.
        rule = rule.replace('--checksum-fill', '--checksum fill')
        # -m can appear twice in a rule; keep the state match under its own key.
        rule = rule.replace('-m state', '-m2 state')
        rule = rule.replace('ESTABLISHED,RELATED', 'RELATED,ESTABLISHED')
        # split() collapses any run of whitespace (such as the gap left by
        # removing "-p all"); splitting on a single space would produce empty
        # tokens and misalign every option/value pair after them.
        tokens = rule.split()
        parsed = dict(zip(tokens[0::2], tokens[1::2]))
        if "-A" in parsed:
            self.chain = parsed["-A"]
        return parsed

    def set_table(self, table):
        """Set the table name; an empty name means 'filter'."""
        if table == '':
            table = "filter"
        self.table = table

    def get_table(self):
        return self.table

    def set_chain(self, chain):
        self.chain = chain

    def set_count(self, count=0):
        self.count = count

    def get_count(self):
        return self.count

    def get_chain(self):
        return self.chain

    def get_rule(self):
        return self.rule

    def to_str(self, delete=False):
        """Convert the rule back into a syntactically correct iptables argument string.

        Only the options listed in ``order`` are emitted, in that order (order
        matters to iptables). With *delete* the ``-A`` flag becomes ``-D``.
        NOTE(review): options not in ``order`` are silently omitted.
        """
        order = ['-A', '-s', '-d', '!_-d', '-i', '!_-i', '-p', '-m', '-m2', '--icmp-type', '--state',
                 '--dport', '--destination-port', '-o', '!_-o', '-j', '--set-xmark', '--checksum',
                 '--to-source', '--to-destination', '--mark']
        parts = []
        for key in order:
            if key not in self.rule:
                continue
            flag = key.replace('-m2', '-m').replace('!_-', '! -')
            if delete:
                flag = flag.replace('-A', '-D')
            parts.append("%s %s" % (flag, self.rule[key]))
        return " ".join(parts).replace("--checksum fill", "--checksum-fill")

    def __eq__(self, rule):
        if not isinstance(rule, CsNetfilter):
            return NotImplemented
        return (rule.get_table() == self.get_table()
                and rule.get_chain() == self.get_chain()
                and rule.get_rule() == self.get_rule())

    # Mutable and compared by value: instances are deliberately unhashable.
    __hash__ = None