# blob: a5e3a7caf9bcef009ab96452ceccf91637679230 [file] [log] [blame]
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# TODO: Test edge cases: unequal number of audio-video timestamps (should still work and add the average over all audio/video samples)
import shutil
import unittest
from systemds.scuro.modality.joined import JoinCondition
from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
from systemds.scuro.representations.resnet import ResNet
from tests.scuro.data_generator import setup_data
from systemds.scuro.dataloader.audio_loader import AudioLoader
from systemds.scuro.dataloader.video_loader import VideoLoader
from systemds.scuro.modality.type import ModalityType
class TestMultimodalJoin(unittest.TestCase):
    """Tests joining a video and an audio modality on their timestamps.

    Every test loads both modalities, converts the audio to a mel
    spectrogram, joins the two on ``timestamp``, applies a ResNet
    representation, windows the result with a mean aggregation and finally
    concatenates both sides (see ``_join``).
    """

    # Shared fixtures; the ones used by the tests are populated in setUpClass.
    test_file_path = None
    mods = None
    # NOTE(review): text, audio, video and indizes are not referenced anywhere
    # in this class; they are kept only so the class attributes stay unchanged.
    text = None
    audio = None
    video = None
    data_generator = None
    num_instances = 0
    indizes = []

    @classmethod
    def setUpClass(cls):
        """Create the shared video/audio test data once for all tests."""
        cls.test_file_path = "join_test_data"
        cls.num_instances = 4
        cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO]
        # Presumably writes the generated files below test_file_path, which is
        # the directory removed again in tearDownClass.
        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)

    @classmethod
    def tearDownClass(cls):
        """Remove the test data directory after all tests have run."""
        print("Cleaning up test data")
        shutil.rmtree(cls.test_file_path)

    def test_video_audio_join(self):
        self._execute_va_join()

    def test_chunked_video_audio_join(self):
        self._execute_va_join(2)

    def test_video_chunked_audio_join(self):
        self._execute_va_join(None, 2)

    def test_chunked_video_chunked_audio_join(self):
        self._execute_va_join(2, 2)

    def test_audio_video_join(self):
        # Audio has a much higher frequency than video, hence we would need to
        # duplicate or interpolate frames to match them to the audio frequency
        self._execute_av_join()

    # TODO
    # def test_chunked_audio_video_join(self):
    #     self._execute_av_join(2)

    # TODO
    # def test_chunked_audio_chunked_video_join(self):
    #     self._execute_av_join(2, 2)

    def _execute_va_join(self, l_chunk_size=None, r_chunk_size=None):
        """Join video (left) with audio (right).

        Args:
            l_chunk_size: Chunk size for the left (video) loader, or None.
            r_chunk_size: Chunk size for the right (audio) loader, or None.
        """
        video, audio = self._prepare_data(l_chunk_size, r_chunk_size)
        self._join(video, audio, 2)

    def _execute_av_join(self, l_chunk_size=None, r_chunk_size=None):
        """Join audio (left) with video (right).

        Args:
            l_chunk_size: Chunk size for the left (audio) loader, or None.
            r_chunk_size: Chunk size for the right (video) loader, or None.
        """
        # _prepare_data takes its chunk sizes in (video, audio) order, whereas
        # this method receives them in join order (audio, video). Swap them so
        # that each chunk size reaches the loader of the intended modality.
        video, audio = self._prepare_data(r_chunk_size, l_chunk_size)
        self._join(audio, video, 2)

    def _prepare_data(self, l_chunk_size=None, r_chunk_size=None):
        """Build the video modality and the mel-spectrogram audio modality.

        Args:
            l_chunk_size: Chunk size passed to the video loader, or None.
            r_chunk_size: Chunk size passed to the audio loader, or None.

        Returns:
            A ``(video, mel_audio)`` tuple, always in this order.
        """
        video_data_loader = VideoLoader(
            self.data_generator.get_modality_path(ModalityType.VIDEO),
            self.data_generator.indices,
            chunk_size=l_chunk_size,
        )
        video = UnimodalModality(video_data_loader)

        audio_data_loader = AudioLoader(
            self.data_generator.get_modality_path(ModalityType.AUDIO),
            self.data_generator.indices,
            r_chunk_size,
        )
        audio = UnimodalModality(audio_data_loader)
        mel_audio = audio.apply_representation(MelSpectrogram())

        return video, mel_audio

    def _join(self, left_modality, right_modality, window_size):
        """Run the join pipeline and verify the shape of its result.

        Args:
            left_modality: Modality on the left side of the join.
            right_modality: Modality on the right side of the join.
            window_size: Window size passed to the "mean" windowing step.

        Returns:
            The modality produced by join -> ResNet -> window -> concat.
        """
        resnet_modality = (
            left_modality.join(
                right_modality, JoinCondition("timestamp", "timestamp", "<")
            )
            .apply_representation(ResNet(layer="layer1.0.conv2", model_name="ResNet50"))
            .window(window_size, "mean")
            .combine("concat")
        )

        # unittest assertion methods are used instead of bare ``assert`` so the
        # checks are not stripped under ``python -O`` and failures report the
        # offending values.
        self.assertIsNotNone(resnet_modality.left_modality)
        self.assertIsNotNone(resnet_modality.right_modality)
        self.assertEqual(len(resnet_modality.left_modality.data), self.num_instances)
        self.assertEqual(len(resnet_modality.right_modality.data), self.num_instances)
        self.assertIsNotNone(resnet_modality.data)

        return resnet_modality
# Run the tests through unittest's command-line runner when this file is
# executed directly as a script.
if __name__ == "__main__":
    unittest.main()