| #!/usr/bin/env python |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import patchviasocket |
| |
| import getpass |
| import os |
| import socket |
| import tempfile |
| import time |
| import threading |
| import unittest |
| |
# Fake public-key contents written into the temporary key files.
KEY_DATA = "I luv\nCloudStack\n"
# Command line handed to send_to_socket; it contains a "%", which
# test_write_to_socket expects to arrive at the socket replaced by a space.
CMD_DATA = "/run/this-for-me --please=TRUE! very%quickly"
# Relative path that setUp asserts does not exist; used for the
# "missing key file" / "missing socket" error cases.
NON_EXISTING_FILE = "must-not-exist"
| |
| |
def write_key_file(data=None):
    """Write key material to a new temporary ``.sck`` file and return its path.

    :param data: text to write; defaults to ``KEY_DATA`` when omitted.
    :returns: path of the newly created file. The caller must remove it.
    """
    if data is None:
        data = KEY_DATA
    fd, tmpfile = tempfile.mkstemp(".sck")
    # Write through the descriptor mkstemp already opened. Discarding it and
    # reopening the path by name would leak one file descriptor per call.
    with os.fdopen(fd, "w") as f:
        f.write(data)
    return tmpfile
| |
| |
class SocketThread(threading.Thread):
    """Thread that listens on a fresh Unix socket and captures one client's data.

    The socket lives at ``<new temp dir>/socket``. ``run()`` accepts at most one
    connection (giving up after a short timeout) and stores what that client
    sent in :meth:`data`. The socket file and its directory are removed when
    ``run()`` finishes.
    """

    def __init__(self):
        super(SocketThread, self).__init__()
        self._data = ""
        self._folder = tempfile.mkdtemp(".sck")
        self._file = os.path.join(self._folder, "socket")
        # An Event rather than a polled bool: waiters block without spinning,
        # and run() can release them even if setting up the listener fails.
        self._ready = threading.Event()

    def data(self):
        """Return what the client sent, or "" if nobody connected."""
        return self._data

    def file(self):
        """Return the filesystem path of the listening Unix socket."""
        return self._file

    def wait_until_ready(self):
        """Block until run() is listening (or has failed to start listening)."""
        self._ready.wait()

    def run(self):
        TIMEOUT = 0.314  # Very short time for tests that don't write to socket.
        MAX_SIZE = 10 * 1024

        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            s.bind(self._file)
            s.listen(1)
            s.settimeout(TIMEOUT)
            self._ready.set()
            try:
                client, _ = s.accept()
            except socket.timeout:
                return  # Nobody connected: expected for the error-path tests.
            try:
                # Bound each recv so a client that never closes cannot hang us.
                client.settimeout(TIMEOUT)
                self._data = self._receive(client, MAX_SIZE)
            finally:
                client.close()
        finally:
            # Release waiters even if bind()/listen() raised, so a broken setup
            # fails loudly instead of hanging in wait_until_ready().
            self._ready.set()
            s.close()
            # The socket file does not exist if bind() failed; removing it
            # unconditionally would mask the original error.
            if os.path.exists(self._file):
                os.remove(self._file)
            os.rmdir(self._folder)

    @staticmethod
    def _receive(client, max_size):
        """Read from ``client`` until EOF, idle timeout, or ``max_size`` bytes.

        A single recv() on a stream socket may return only part of what was
        sent, so keep reading. The result is returned as text: on Python 3,
        recv() yields bytes, which are decoded; on Python 2 it is already str.
        """
        chunks = []
        received = 0
        while received < max_size:
            try:
                chunk = client.recv(max_size - received)
            except socket.timeout:
                break
            if not chunk:
                break  # Client closed its end.
            chunks.append(chunk)
            received += len(chunk)
        data = b"".join(chunks)
        if not isinstance(data, str):
            data = data.decode("utf-8", "replace")
        return data
| |
| |
class TestPatchViaSocket(unittest.TestCase):
    """Exercise patchviasocket.send_to_socket against a throwaway listener."""

    def setUp(self):
        # Check preconditions before creating anything: when setUp fails,
        # tearDown is not run, so files created earlier would be leaked.
        self.assertFalse(os.path.exists(NON_EXISTING_FILE))
        self.assertNotEqual("root", getpass.getuser(), "must be non-root user (to test access denied errors)")

        # addCleanup runs even if a later step of setUp fails.
        self._key_file = write_key_file()
        self.addCleanup(os.remove, self._key_file)

        self._unreadable = write_key_file()
        self.addCleanup(os.remove, self._unreadable)
        os.chmod(self._unreadable, 0)

    def _send(self, key_file, socket_file=None):
        """Call send_to_socket while a SocketThread is listening.

        :param key_file: key file path passed to send_to_socket.
        :param socket_file: socket path passed to send_to_socket; defaults to
            the listener's own socket.
        :returns: ``(exit_code, data_received_by_listener)``.
        """
        reader = SocketThread()
        reader.start()
        try:
            reader.wait_until_ready()
            if socket_file is None:
                socket_file = reader.file()
            rc = patchviasocket.send_to_socket(socket_file, key_file, CMD_DATA)
        finally:
            # In the error-path tests nothing connects, so this waits out the
            # listener's short accept() timeout.
            reader.join()
        return rc, reader.data()

    def test_write_to_socket(self):
        rc, data = self._send(self._key_file)
        self.assertEqual(0, rc)
        self.assertIn(KEY_DATA, data)
        self.assertIn(CMD_DATA.replace("%", " "), data)
        self.assertNotIn("LUV", data)
        self.assertNotIn("very%quickly", data)  # Testing substitution

    def test_host_key_error(self):
        rc, _ = self._send(NON_EXISTING_FILE)
        self.assertEqual(1, rc)

    def test_host_key_access_denied(self):
        rc, _ = self._send(self._unreadable)
        self.assertEqual(1, rc)

    def test_nonexistant_socket_error(self):
        rc, _ = self._send(self._key_file, NON_EXISTING_FILE)
        self.assertEqual(1, rc)

    def test_invalid_socket_error(self):
        # A regular file is not a socket.
        rc, _ = self._send(self._key_file, self._key_file)
        self.assertEqual(1, rc)

    def test_access_denied_socket_error(self):
        rc, _ = self._send(self._key_file, self._unreadable)
        self.assertEqual(1, rc)
| |
| |
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()