aboutsummaryrefslogtreecommitdiff
blob: 2c00949a732f35fbfff94e21e4a0378ea6bfba54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# vim:fileencoding=utf8:et:ts=4:sts=4:sw=4:ft=python

from django.conf import settings
from django.test import TestCase
from django.test.utils import override_settings

import base64
import socket

import paramiko

from okupy import OkupyError
from okupy.common.ssh import ssh_handler, SSHServer
from okupy.tests.vars import TEST_SSH_KEY_FOR_NO_USER


@override_settings(SSH_HANDLERS={})
class SSHUnitTests(TestCase):
    """Unit tests for the ``ssh_handler`` decorator and ``SSHServer``.

    ``ssh_handler`` registers the decorated function in
    ``settings.SSH_HANDLERS`` under the function's own name. The username
    passed to ``SSHServer.check_auth_publickey`` selects the handler; extra
    positional arguments follow ``+`` separators (e.g. ``'twoarg+1+2'``).
    Per these tests, a handler returning a string authenticates the
    request, returning None rejects it, and the returned string is what
    the server later writes to the exec/shell channel.
    """

    class _OneShot(object):
        """Callable returning *message* on its first call, None afterwards.

        Used to detect whether ``SSHServer`` re-invokes a handler or reuses
        a previously cached result.
        """

        def __init__(self, message):
            self.message = message
            self.first = True

        def __call__(self, key):
            if self.first:
                self.first = False
                return self.message
            return None

    def setUp(self):
        # override_settings(SSH_HANDLERS={}) hands the *same* dict object to
        # every test in this class, so handlers registered by earlier tests
        # would otherwise still be visible here. Start each test empty.
        settings.SSH_HANDLERS.clear()
        # TEST_SSH_KEY_FOR_NO_USER is base64-encoded RSA key data; judging by
        # its name it is not linked to any user (confirm in okupy.tests.vars).
        self._key = paramiko.RSAKey(
            data=base64.b64decode(TEST_SSH_KEY_FOR_NO_USER))
        self._server = SSHServer()

    def _require(self, condition):
        """Raise OkupyError if a test prerequisite does not hold.

        Raising (instead of asserting) reports a broken precondition as a
        test error rather than as a failure of the behaviour under test.
        """
        if not condition:
            raise OkupyError('Test prerequisite failed')

    def _require_auth(self, command):
        """Require that public-key auth with *command* as username succeeds."""
        self._require(
            self._server.check_auth_publickey(command, self._key)
            == paramiko.AUTH_SUCCESSFUL)

    def _socketpair(self):
        """Return a connected socket pair that is closed after the test."""
        s1, s2 = socket.socketpair()
        # close() is safe to call even if the server already closed s1.
        self.addCleanup(s1.close)
        self.addCleanup(s2.close)
        return s1, s2

    @staticmethod
    def _read_all(sock):
        """Read *sock* until EOF and return it with trailing whitespace
        stripped.

        read() without a size only returns at EOF, so this relies on the
        peer socket being closed (presumably by the server) after writing.
        """
        f = sock.makefile()
        try:
            return f.read().rstrip()
        finally:
            f.close()

    def test_ssh_handler_decorator_works(self):
        @ssh_handler
        def test(key):
            pass

        self.assertEqual(settings.SSH_HANDLERS.get('test'), test)

    def test_noarg_handler_works(self):
        @ssh_handler
        def noarg(key):
            return 'yay'

        self.assertEqual(
            self._server.check_auth_publickey('noarg', self._key),
            paramiko.AUTH_SUCCESSFUL)

    def test_failure_is_propagated_properly(self):
        @ssh_handler
        def failing(key):
            return None

        self.assertEqual(
            self._server.check_auth_publickey('failing', self._key),
            paramiko.AUTH_FAILED)

    def test_argument_splitting_works(self):
        @ssh_handler
        def twoarg(a, b, key):
            if a == '1' and b == '2':
                return 'yay'
            else:
                return None

        self.assertEqual(
            self._server.check_auth_publickey('twoarg+1+2', self._key),
            paramiko.AUTH_SUCCESSFUL)

    def test_default_arguments_work(self):
        @ssh_handler
        def oneortwoarg(a, b='3', key=None):
            if not key:
                raise ValueError('key must not be None')
            if a == '1' and b == '3':
                return 'yay'
            else:
                return None

        self.assertEqual(
            self._server.check_auth_publickey('oneortwoarg+1', self._key),
            paramiko.AUTH_SUCCESSFUL)

    def test_wrong_command_returns_failure(self):
        @ssh_handler
        def somehandler(key):
            return 'er?'

        self.assertEqual(
            self._server.check_auth_publickey('otherhandler', self._key),
            paramiko.AUTH_FAILED)

    def test_missing_arguments_return_failure(self):
        @ssh_handler
        def onearg(arg, key):
            return 'er?'

        self.assertEqual(
            self._server.check_auth_publickey('onearg', self._key),
            paramiko.AUTH_FAILED)

    def test_too_many_arguments_return_failure(self):
        @ssh_handler
        def onearg(arg, key):
            return 'er?'

        self.assertEqual(
            self._server.check_auth_publickey('onearg+1+2', self._key),
            paramiko.AUTH_FAILED)

    def test_typeerror_is_propagated_properly(self):
        # A TypeError raised *inside* the handler must not be mistaken for
        # an argument-count mismatch and turned into AUTH_FAILED.
        @ssh_handler
        def onearg(key):
            raise TypeError

        self.assertRaises(
            TypeError,
            self._server.check_auth_publickey, 'onearg', self._key)

    def test_result_caching_works(self):
        cache = self._OneShot('yay')

        @ssh_handler
        def cached(key):
            return cache(key)

        self._require_auth('cached')
        # The handler would now return None; success means the server
        # reused the first result instead of calling the handler again.
        self.assertEqual(
            self._server.check_auth_publickey('cached', self._key),
            paramiko.AUTH_SUCCESSFUL)

    def test_message_is_printed_to_exec_request(self):
        @ssh_handler
        def noarg(key):
            return 'test-message'

        self._require_auth('noarg')

        s1, s2 = self._socketpair()
        self.assertTrue(self._server.check_channel_exec_request(s1, ':'))
        self.assertEqual(self._read_all(s2), 'test-message')

    def test_message_is_printed_to_shell_request(self):
        @ssh_handler
        def noarg(key):
            return 'test-message'

        self._require_auth('noarg')

        s1, s2 = self._socketpair()
        self.assertTrue(self._server.check_channel_shell_request(s1))
        self.assertEqual(self._read_all(s2), 'test-message')

    def test_cache_is_invalidated_after_channel_request(self):
        cache = self._OneShot('test-message')

        @ssh_handler
        def cached(key):
            return cache(key)

        self._require_auth('cached')

        s1, s2 = self._socketpair()
        self._require(self._server.check_channel_shell_request(s1))
        self._require(self._read_all(s2) == 'test-message')

        # After the channel request consumed the cached result, the handler
        # is invoked again and now returns None.
        self.assertEqual(
            self._server.check_auth_publickey('cached', self._key),
            paramiko.AUTH_FAILED)