blob: 019770500a2e28d9ea46d1143691645bbd4868dc [file] [log] [blame]
# Copyright 2022 The Applied Research Laboratory for Intelligence and Security (ARLIS)
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import distill
from distill.segmentation.segmentation_error import SegmentationError
class Segments():
    """
    A list of Segment objects.

    Segments can be iterated, measured with len(), and subscripted either by
    numeric index or by segment name.
    """

    def __init__(self, segments=None):
        """
        Segments initialization function.

        :param segments ([]): An optional List of Segment objects to be included in the Segments abstraction.
            When omitted (or None), a new empty list is created for this instance. A list that is
            passed in is stored as-is (not copied).
        """
        # Build a fresh list per instance: a mutable default argument ([]) would be
        # shared by every Segments object created without an explicit list.
        self.segments = [] if segments is None else segments

    def __iter__(self):
        """
        Allows Segments to be iterable.

        :return: An iterator over the underlying list of Segment objects.
        """
        return iter(self.segments)

    def __len__(self):
        """
        Allows Segments to return a length.

        :return: The number of Segment objects held.
        """
        return len(self.segments)

    def __getitem__(self, item):
        """
        Allows Segments to be subscriptable by segment name or numeric index.

        :param item: A segment name (str), or an index/slice into the underlying list.
        :return: The first Segment object whose name matches, or the result of indexing the
            underlying list with item.
        :raises TypeError: If item is a str that matches no segment name (it falls through to
            list indexing, which rejects str indices).
        :raises IndexError: If a numeric index is out of range.
        """
        if isinstance(item, str):
            for segment in self.segments:
                if segment.get_segment_name() == item:
                    return segment
        return self.segments[item]

    def __setitem__(self, key, value):
        """
        Allows subscripts to be used to set items.

        :param key: A segment name (str) or an integer index. A name that is already present
            replaces that segment; a new name appends the segment.
        :param value: The Segment object to store.
        :raises TypeError: If value is not a Segment object, or key is neither a str nor an int.
        :raises SegmentationError: If a str key does not match value's segment name, or an int
            key is outside the range of the underlying list.
        """
        if not isinstance(value, distill.Segment):
            raise TypeError("Segments objects can only hold Segment objects.")
        if isinstance(key, str):
            if not value.segment_name == key:
                raise SegmentationError("Segment name in subscript must match the segment name of the Segment object.")
            segment_names = [segment.get_segment_name() for segment in self.segments]
            if key in segment_names:
                # Replace the existing segment carrying this name.
                self.segments[segment_names.index(key)] = value
            else:
                # Unknown name: add it (dict-style assignment) rather than silently dropping value.
                self.segments.append(value)
        elif isinstance(key, int):
            try:
                self.segments[key] = value
            except IndexError:
                # Covers both too-large and too-negative indices.
                raise SegmentationError(
                    "Index provided goes beyond the length of the underlying list of Segment objects.") from None
        else:
            raise TypeError("Segments subscripts must be a segment name (str) or an integer index.")

    def __str__(self):
        """
        Creates a readable string for Segments.

        :return: "Segments: [", then one line per Segment (its str()), then "]".
        """
        body = "".join(str(segment) + "\n" for segment in self.segments)
        return "Segments: [\n" + body + "]"

    def get_segment_list(self):
        """
        Returns a list of Segment objects in Segments.

        :return: A list of segment objects (the underlying list itself, not a copy).
        """
        return self.segments

    def get_segment_name_dict(self):
        """
        Returns a dictionary mapping each segment name to its Segment object. Note that segment
        names must be unique.

        :return: A dict of segment_name -> Segment.
        :raises SegmentationError: If two Segment objects share a name.
        """
        result = {}
        for segment in self.segments:
            name = segment.get_segment_name()
            if name in result:
                raise SegmentationError("Segment names must be unique")
            result[name] = segment
        return result

    def get_num_logs(self, num_logs):
        """
        Returns a new Segments object only including segments with at least the specified number of logs.

        :param num_logs: The minimum number of logs (inclusive) necessary to be included in the new Segments object.
        :return: A new Segments object that contains Segment objects with at least the specified number of logs.
        """
        segments = [segment for segment in self.segments if segment.num_logs >= num_logs]
        return Segments(segments)

    def get_segments_before(self, time):
        """
        Returns a new Segments object only including segments that have end times before the indicated time.

        :param time: An integer or datetime object that represents the time for which Segment end times should be before.
            It must be comparable with the segments' start_end_val entries.
        :return: A new Segments object that contains Segment objects that have end times prior to the time indicated.
        :raises TypeError: If time is neither an int nor a datetime.
        """
        if not isinstance(time, int) and not isinstance(time, datetime.datetime):
            raise TypeError('Time must be an integer or datetime object.')
        segments = [segment for segment in self.segments if segment.start_end_val[1] < time]
        return Segments(segments)

    def get_segments_after(self, time):
        """
        Returns a new Segments object only including segments that have start times after the indicated time.

        :param time: An integer or datetime object that represents the time for which Segment start times should be after.
            It must be comparable with the segments' start_end_val entries.
        :return: A new Segments object that contains Segment objects that have start times after the time indicated.
        :raises TypeError: If time is neither an int nor a datetime.
        """
        if not isinstance(time, int) and not isinstance(time, datetime.datetime):
            raise TypeError('Time must be an integer or datetime object.')
        segments = [segment for segment in self.segments if segment.start_end_val[0] > time]
        return Segments(segments)

    def get_segments_of_type(self, segment_type):
        """
        Returns a new Segments object that includes Segment objects of a specified type.

        :param segment_type: The type of Segment objects that should be included.
        :return: A new Segments object that contains Segment objects of the specified type.
        :raises TypeError: If segment_type is not a distill.Segment_Type.
        """
        if not isinstance(segment_type, distill.Segment_Type):
            raise TypeError("Given segment_type: " + str(segment_type) + " is not a valid segment type.")
        segments = [segment for segment in self.segments if segment.segment_type == segment_type]
        return Segments(segments)

    def append(self, item):
        """
        Adds a Segment object to the Segments object.

        :param item: The Segment object to add.
        :raises TypeError: If item is not a Segment object.
        """
        if not isinstance(item, distill.Segment):
            raise TypeError("Only Segment objects can be added to a Segments object.")
        self.segments.append(item)

    def append_segments(self, segments):
        """
        Adds all Segment objects in the given Segments object to the current Segments object.

        :param segments: A Segments object to append to the calling Segments object.
        :raises TypeError: If segments is not a Segments object, or it holds a non-Segment item
            (in which case nothing is appended).
        """
        if not isinstance(segments, distill.Segments):
            raise TypeError("Only Segments objects can be appended with append_segments.")
        # Snapshot first so appending a Segments object to itself terminates, and validate
        # everything before mutating so a bad item leaves self unchanged.
        new_segments = list(segments)
        for segment in new_segments:
            if not isinstance(segment, distill.Segment):
                raise TypeError("Only Segment objects can be added to a Segments object.")
        self.segments.extend(new_segments)

    def delete(self, segment_name):
        """
        Deletes the Segment object with the given segment_name.

        :param segment_name: The name of the Segment to delete. Only the first match is removed.
        :raises SegmentationError: If no Segment has the given name.
        """
        for index, segment in enumerate(self.segments):
            if segment.get_segment_name() == segment_name:
                del self.segments[index]
                return
        # str() so a non-str name still yields the intended SegmentationError.
        raise SegmentationError("No Segment objects with given segment name: " + str(segment_name))