blob: 7afd4dc7ed80b0b3eeed4a41fb742cc3cc1bc59b [file]
import asyncio
import logging
import uuid
from base64 import b64decode
from datetime import datetime, timezone
from importlib import import_module
import marshmallow as ma
from app.objects.c_ability import Ability, AbilitySchema
from app.objects.secondclass.c_executor import Executor, ExecutorSchema
from app.objects.secondclass.c_fact import Fact, FactSchema, OriginType
from app.objects.secondclass.c_relationship import RelationshipSchema
from app.objects.secondclass.c_visibility import Visibility, VisibilitySchema
from app.utility.base_object import BaseObject
from app.utility.base_parser import PARSER_SIGNALS_FAILURE
from app.utility.base_service import BaseService
# Sentinel meaning "Link._status has never been assigned"; the Link.status setter uses it to
# tell the very first assignment apart from a genuine status transition.
NO_STATUS_SET = object()
class LinkSchema(ma.Schema):
    """Marshmallow schema that loads and dumps Link objects."""

    class Meta:
        # Input keys that do not correspond to a declared field are excluded on load.
        unknown = ma.EXCLUDE

    id = ma.fields.String(load_default='')
    paw = ma.fields.String()
    command = ma.fields.String()
    plaintext_command = ma.fields.String()
    status = ma.fields.Integer(load_default=-3)
    score = ma.fields.Integer(load_default=0)
    jitter = ma.fields.Integer(load_default=0)
    decide = ma.fields.DateTime(format=BaseObject.TIME_FORMAT)
    pin = ma.fields.Integer(load_default=0)
    pid = ma.fields.String()
    facts = ma.fields.List(ma.fields.Nested(FactSchema()))
    relationships = ma.fields.List(ma.fields.Nested(RelationshipSchema()))
    used = ma.fields.List(ma.fields.Nested(FactSchema()))
    unique = ma.fields.String()
    collect = ma.fields.DateTime(format=BaseObject.TIME_FORMAT, dump_default='')
    finish = ma.fields.String()
    ability = ma.fields.Nested(AbilitySchema())
    executor = ma.fields.Nested(ExecutorSchema())
    cleanup = ma.fields.Integer(load_default=0)
    visibility = ma.fields.Nested(VisibilitySchema())
    host = ma.fields.String(load_default=None)
    output = ma.fields.String()
    deadman = ma.fields.Boolean()
    agent_reported_time = ma.fields.DateTime(format=BaseObject.TIME_FORMAT, load_default=None)

    @ma.pre_load()
    def fix_ability(self, link, **_):
        """Replace an Ability instance under 'ability' with its dumped form before loading."""
        candidate = link.get('ability')
        if isinstance(candidate, Ability):
            link.pop('ability')
            link['ability'] = candidate.schema.dump(candidate)
        return link

    @ma.pre_load()
    def fix_executor(self, link, **_):
        """Replace an Executor instance under 'executor' with its dumped form before loading."""
        candidate = link.get('executor')
        if isinstance(candidate, Executor):
            link.pop('executor')
            link['executor'] = candidate.schema.dump(candidate)
        return link

    @ma.pre_load()
    def remove_properties(self, data, **_):
        """Drop keys from the incoming data that Link.__init__ does not accept.

        The incoming mapping is modified in place and returned.
        """
        for dropped_key in ('unique', 'decide', 'pid', 'facts', 'collect', 'finish', 'visibility', 'output',
                            'used.unique'):
            data.pop(dropped_key, None)
        return data

    @ma.post_load()
    def build_link(self, data, **kwargs):
        """Build a Link from the loaded data; yield None instead when loading with partial=True."""
        if kwargs.get('partial') is True:
            return None
        return Link(**data)

    @ma.post_dump()
    def prepare_dump(self, data, **_):
        """Omit 'agent_reported_time' from the dump when it has no value."""
        if data.get('agent_reported_time') is None:
            data.pop('agent_reported_time', None)
        return data
class Link(BaseObject):
    """One command record together with its status and the results parsed from it.

    A link carries the command text, the ability and executor it was built from,
    a status code (see ``states``), and the facts and relationships that parsers
    extracted from its result. Changing ``status`` after construction fires a
    'status_changed' event through the event service.
    """

    schema = LinkSchema()
    display_schema = LinkSchema(exclude=['jitter'])
    load_schema = LinkSchema(exclude=['decide', 'pid', 'facts', 'unique', 'collect', 'finish', 'visibility',
                                      'output', 'used.unique'])

    # Placeholders the link substitutes inside its own command (see replace_origin_link_id).
    RESERVED = dict(origin_link_id='#{origin_link_id}')

    # Exchange and queue names handed to event_svc.fire_event on a status change.
    EVENT_EXCHANGE = 'link'
    EVENT_QUEUE_STATUS_CHANGED = 'status_changed'

    @property
    def raw_command(self):
        """The command passed through ``decode_bytes``, or '' when no command is set."""
        return self.decode_bytes(self.command) if self.command else ''

    @property
    def unique(self):
        """Value of ``self.hash`` applied to the link id rendered as a string."""
        return self.hash('%s' % self.id)

    @property
    def pin(self):
        """The stored pin value."""
        return self._pin

    @pin.setter
    def pin(self, p):
        self._pin = p

    @property
    def states(self):
        """Mapping of status names to their integer codes."""
        return dict(HIGH_VIZ=-5,
                    UNTRUSTED=-4,
                    EXECUTE=-3,
                    DISCARD=-2,
                    PAUSE=-1,
                    SUCCESS=0,
                    ERROR=1,
                    TIMEOUT=124)

    @property
    def status(self):
        """Current status code of the link."""
        return self._status

    @status.setter
    def status(self, value):
        """Store the new status and fire a status-change event on a real transition."""
        previous_status = getattr(self, '_status', NO_STATUS_SET)
        self._status = value
        # The first assignment (made by __init__) and re-assigning the same value are not events.
        if previous_status is NO_STATUS_SET or previous_status == value:
            return
        self._emit_status_change_event(from_status=previous_status, to_status=value)

    @property
    def display(self):
        """Dump of the link without 'jitter', with both command fields decoded."""
        # Reuse the class-level schema (identical exclude list) instead of rebuilding it per access.
        dump = self.display_schema.dump(self)
        dump['command'] = self.decode_bytes(dump['command'])
        dump['plaintext_command'] = self.decode_bytes(dump['plaintext_command'])
        return dump

    @classmethod
    def is_global_variable(cls, variable):
        """Return True when ``variable`` is one of the RESERVED placeholder names."""
        return variable in cls.RESERVED

    def __init__(self, command='', plaintext_command='', paw='', ability=None, executor=None, status=-3, score=0,
                 jitter=0, cleanup=0, id='', pin=0, host=None, deadman=False, used=None, relationships=None,
                 agent_reported_time=None):
        """Create a link; ``decide`` is stamped with the current UTC time.

        ``used`` and ``relationships`` default to fresh empty lists when falsy.
        """
        super().__init__()
        self.id = str(id)
        self.command = command
        self.plaintext_command = plaintext_command
        self.command_hash = None
        self.paw = paw
        self.host = host
        self.cleanup = cleanup
        self.ability = ability
        self.executor = executor
        # Goes through the status setter; no event fires because no previous status exists yet.
        self.status = status
        self.score = score
        self.jitter = jitter
        self.decide = datetime.now(timezone.utc)
        self.pid = None
        self.collect = None
        self.finish = None
        self.facts = []
        self.relationships = relationships if relationships else []
        self.used = used if used else []
        self.visibility = Visibility()
        self._pin = pin
        self.output = False
        self.deadman = deadman
        self.agent_reported_time = agent_reported_time

    def __eq__(self, other):
        """Two links are equal when paw, ability id and used facts all match.

        A missing ability (the constructor default, ``None``) is compared as an
        ability id of ``None`` rather than raising AttributeError.
        NOTE: the class defines ``__eq__`` without ``__hash__``, so unless a base
        class arranges otherwise, instances are unhashable.
        """
        if not isinstance(other, Link):
            return False
        own_ability_id = getattr(self.ability, 'ability_id', None)
        other_ability_id = getattr(other.ability, 'ability_id', None)
        return other.paw == self.paw and other_ability_id == own_ability_id and other.used == self.used

    async def parse(self, operation, result):
        """Run every parser of the executor over ``result`` and record what they find.

        Nothing happens unless the link's status is SUCCESS. If a parser signals
        failure the status becomes ERROR and no relationships are stored for it.
        Any exception raised while handling one parser is logged at debug level
        and the next parser is tried.

        :param operation: owning operation, or a falsy value when there is none
        :param result: base64-encoded output handed to the parsers
        """
        if self.status != self.states['SUCCESS']:
            return
        logger = logging.getLogger('link')
        for parser in self.executor.parsers:
            source_facts = operation.source.facts if operation else []
            try:
                relationships = await self._parse_link_result(result, parser, source_facts)
                if len(relationships) > 0 and relationships[0] == PARSER_SIGNALS_FAILURE:
                    logger.debug(f'link {self.id} (ability id={self.ability.ability_id}) encountered '
                                 f'an error during execution, which was caught during parsing.')
                    self.status = self.states['ERROR']
                    relationships = []  # we didn't actually get anything out of this, so let's reset
                else:
                    await self.create_relationships(relationships, operation)
                await update_scores(operation, increment=len(relationships), used=self.used, facts=self.facts)
            except Exception as e:
                # Best effort: one broken parser must not stop the remaining parsers.
                logger.debug('error in %s while parsing ability %s: %s',
                             parser.module, self.ability.ability_id, e)

    def apply_id(self, host):
        """Assign a fresh UUID and the given host, then substitute the id into the command."""
        self.id = str(uuid.uuid4())
        self.host = host
        self.replace_origin_link_id()

    def can_ignore(self):
        """Return True when the status is DISCARD or HIGH_VIZ."""
        return self.status in [self.states['DISCARD'], self.states['HIGH_VIZ']]

    def is_finished(self):
        """Return True when the status is DISCARD, SUCCESS, ERROR or TIMEOUT."""
        return self.status in [self.states['DISCARD'], self.states['SUCCESS'],
                               self.states['ERROR'], self.states['TIMEOUT']]

    def is_valid_status(self, status):
        """Return True when ``status`` is one of the codes listed in ``states``."""
        return status in self.states.values()

    def replace_origin_link_id(self):
        """Substitute the '#{origin_link_id}' placeholder in the command with this link's id."""
        decoded_cmd = self.decode_bytes(self.command)
        self.command = self.encode_string(decoded_cmd.replace(self.RESERVED['origin_link_id'], self.id))

    def _emit_status_change_event(self, from_status, to_status):
        """Schedule a 'status_changed' event on the event loop and return the created task.

        NOTE(review): ``asyncio.get_event_loop()`` is deprecated when no loop is
        running; confirm every caller runs inside the server's event loop.
        """
        event_svc = BaseService.get_service('event_svc')
        task = asyncio.get_event_loop().create_task(
            event_svc.fire_event(
                exchange=Link.EVENT_EXCHANGE,
                queue=Link.EVENT_QUEUE_STATUS_CHANGED,
                link=self.id,
                from_status=from_status,
                to_status=to_status
            )
        )
        return task

    async def _parse_link_result(self, result, parser, source_facts):
        """Decode the base64 ``result`` as UTF-8 and run it through the parser's module."""
        blob = b64decode(result).decode('utf-8')
        parser_info = dict(module=parser.module, used_facts=self.used, mappers=parser.parserconfigs,
                           source_facts=source_facts)
        p_inst = await self._load_module('Parser', parser_info)
        return p_inst.parse(blob=blob)

    @staticmethod
    async def _load_module(module_type, module_info):
        """Import ``module_info['module']`` and instantiate its ``module_type`` attribute with ``module_info``."""
        module = import_module(module_info['module'])
        return getattr(module, module_type)(module_info)

    async def create_relationships(self, relationships, operation):
        """Save the facts of each relationship and attach the relationship to this link.

        The relationship's origin is the operation id when an operation is given,
        otherwise this link's id. It is also sent to the knowledge service when it
        has both a source trait and an edge.
        """
        for relationship in relationships:
            relationship.origin = operation.id if operation else self.id
            await self.save_fact(operation, relationship.source, relationship.score, relationship.shorthand)
            await self.save_fact(operation, relationship.target, relationship.score, relationship.shorthand)
            if all((relationship.source.trait, relationship.edge)):
                knowledge_svc_handle = BaseService.get_service('knowledge_svc')
                await knowledge_svc_handle.add_relationship(relationship)
            self.relationships.append(relationship)

    async def save_fact(self, operation, fact, score, relationship):
        """Create a new learned fact or update the matching existing one.

        Facts lacking a trait or a value are ignored. The fact's source is the
        operation's source id when that source already holds the same trait/value,
        otherwise the operation id, or this link's id when there is no operation.

        :param operation: owning operation, or a falsy value when there is none
        :param fact: fact extracted by a parser; its ``source`` attribute is overwritten
        :param score: score given to a newly created fact
        :param relationship: relationship shorthand to record on the fact; skipped when falsy
        """
        knowledge_svc_handle = BaseService.get_service('knowledge_svc')
        all_facts = await operation.all_facts() if operation else self.facts
        source = operation.id if operation else self.id
        rl = [relationship] if relationship else []
        if all([fact.trait, fact.value]):
            if operation and operation.source:
                if any([(fact.trait, fact.value) == (x.trait, x.value) for x in
                        await knowledge_svc_handle.get_facts(criteria=dict(source=operation.source.id))]):
                    source = operation.source.id
            fact.source = source  # Manual addition to ensure the check works correctly
            if not await knowledge_svc_handle.check_fact_exists(fact, all_facts):
                f_gen = Fact(trait=fact.trait, value=fact.value, source=source, score=score, collected_by=[self.paw],
                             technique_id=self.ability.technique_id, links=[self.id], relationships=rl,
                             origin_type=OriginType.LEARNED)
                self.facts.append(f_gen)
                await knowledge_svc_handle.add_fact(f_gen)
            else:
                existing_fact = (await knowledge_svc_handle.get_facts(criteria=dict(trait=fact.trait,
                                                                                    value=fact.value,
                                                                                    source=fact.source)))[0]
                if self.id not in existing_fact.links:
                    existing_fact.links.append(self.id)
                # Same guard as for new facts (``rl`` above): never record an empty relationship.
                if relationship and relationship not in existing_fact.relationships:
                    existing_fact.relationships.append(relationship)
                if self.paw not in existing_fact.collected_by and existing_fact not in self.used:
                    existing_fact.collected_by.append(self.paw)
                await knowledge_svc_handle.update_fact(criteria=dict(trait=fact.trait, value=fact.value,
                                                                     source=fact.source),
                                                       updates=dict(links=existing_fact.links,
                                                                    relationships=existing_fact.relationships,
                                                                    collected_by=existing_fact.collected_by))
                existing_local_record = [x for x in self.facts if x.trait == fact.trait and x.value == fact.value]
                if existing_local_record:
                    existing_local_record[0].links = existing_fact.links
                else:
                    self.facts.append(existing_fact)
async def update_scores(operation, increment, used, facts):
    """Raise the score of each used fact by ``increment`` and persist the new score.

    For every entry of ``used`` the first fact with the same ``unique`` value is
    looked up among the operation's facts (or among ``facts`` when no operation
    is given); its score is increased and pushed to the knowledge service.
    Used facts without a match are skipped.
    """
    knowledge_svc = BaseService.get_service('knowledge_svc')
    for used_fact in used:
        # Fetched anew for every used fact, exactly as many times as there are used facts.
        candidates = await operation.all_facts() if operation else facts
        match = next((known for known in candidates if known.unique == used_fact.unique), None)
        if match is None:
            continue
        match.score += increment
        lookup = dict(trait=match.trait, value=match.value, source=match.source)
        await knowledge_svc.update_fact(lookup, dict(score=match.score))