# Authors:
#   Trevor Perrin
#   Google - handling CertificateRequest.certificate_types
#   Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support
#   Dimitris Moraitis - Anon ciphersuites
#   Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
#
# See the LICENSE file for legal information regarding use of this file.

"""Classes representing TLS messages."""
from .utils.compat import *
from .utils.cryptomath import *
from .errors import *
from .utils.codec import *
from .constants import *
from .x509 import X509
from .x509certchain import X509CertChain
from .utils.tackwrapper import *

class RecordHeader3(object):
    """Header of an SSLv3/TLS record: content type, version and length."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = False

    def create(self, version, type, length):
        """Set the header fields and return self so calls can be chained."""
        self.version = version
        self.length = length
        self.type = type
        return self

    def write(self):
        """Serialize the header into its 5-byte wire form."""
        writer = Writer()
        fields = ((self.type, 1),
                  (self.version[0], 1),
                  (self.version[1], 1),
                  (self.length, 2))
        for value, width in fields:
            writer.add(value, width)
        return writer.bytes

    def parse(self, p):
        """Read type (1 byte), version (2 bytes) and length (2 bytes)."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self

class RecordHeader2(object):
    """Header of an SSLv2-format record."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = True

    def parse(self, p):
        """Parse the two header bytes; the record is treated as a handshake.

        :raises SyntaxError: if the first byte is not exactly 0x80.
        """
        first = p.get(1)
        if first != 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2, 0)
        # Because the first byte must be exactly 0x80, only the second byte
        # carries the length; longer SSLv2 headers are not supported, which
        # could be a problem.
        self.length = p.get(1)
        return self


class Alert(object):
    """TLS alert record: one level byte followed by one description byte."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set the alert description and level; returns self."""
        self.description = description
        self.level = level
        return self

    def parse(self, p):
        """Read exactly two bytes: level, then description."""
        p.setLengthCheck(2)
        self.level, self.description = p.get(1), p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as the level byte followed by the description byte."""
        out = Writer()
        for field in (self.level, self.description):
            out.add(field, 1)
        return out.bytes


class HandshakeMsg(object):
    """Base class for handshake messages.

    Subclasses build their body in a Writer and pass it to postWrite(),
    which prepends the handshake header.
    """

    def __init__(self, handshakeType):
        self.contentType = ContentType.handshake
        self.handshakeType = handshakeType

    def postWrite(self, w):
        """Return the handshake header (1-byte type, 3-byte length) + body."""
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body

class ClientHello(HandshakeMsg):
    """ClientHello handshake message.

    Handles both the SSLv2-compatible body (``ssl2=True``) and the regular
    SSLv3/TLS body, including the extensions listed in parse().
    """

    def __init__(self, ssl2=False):
        HandshakeMsg.__init__(self, HandshakeType.client_hello)
        self.ssl2 = ssl2
        self.client_version = (0,0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suites = []         # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []   # a list of 8-bit values
        self.srp_username = None        # a string
        self.tack = False
        self.supports_npn = False
        self.server_name = bytearray(0)
        self.channel_id = False
        self.extended_master_secret = False
        self.tb_client_params = []      # token binding key parameters offered
        self.support_signed_cert_timestamps = False
        self.status_request = False

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srpUsername=None,
               tack=False, supports_npn=False, serverName=None):
        """Fill in the fields needed to send a ClientHello; returns self.

        ``srpUsername`` and ``serverName`` are UTF-8 encoded into
        bytearrays; when either is falsy the existing attribute is left
        untouched. Compression methods are always set to ``[0]``.
        """
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        if srpUsername:
            self.srp_username = bytearray(srpUsername, "utf-8")
        self.tack = tack
        self.supports_npn = supports_npn
        if serverName:
            self.server_name = bytearray(serverName, "utf-8")
        return self

    def parse(self, p):
        """Parse a ClientHello body from parser ``p``; returns self.

        In SSLv2 mode:
        - cipher specs are 3-byte values;
        - the random is left-padded with zeros to 32 bytes;
        - compression methods are faked as ``[0]``.

        Otherwise the body is read under a 3-byte length check and any
        extensions are decoded. Unknown extensions are skipped.

        :raises SyntaxError: if an extension's payload size does not match
            its declared length, or a signed_cert_timestamps extension is
            non-empty.
        """
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            cipherSpecsLength = p.get(2)
            sessionIDLength = p.get(2)
            randomLength = p.get(2)
            self.cipher_suites = p.getFixList(3, cipherSpecsLength//3)
            self.session_id = p.getFixBytes(sessionIDLength)
            self.random = p.getFixBytes(randomLength)
            if len(self.random) < 32:
                zeroBytes = 32-len(self.random)
                self.random = bytearray(zeroBytes) + self.random
            self.compression_methods = [0]  # Fake this value

            # We're not doing a stopLengthCheck() for SSLv2, oh well..
        else:
            p.startLengthCheck(3)
            self.client_version = (p.get(1), p.get(1))
            self.random = p.getFixBytes(32)
            self.session_id = p.getVarBytes(1)
            self.cipher_suites = p.getVarList(2, 2)
            self.compression_methods = p.getVarList(1, 1)
            if not p.atLengthCheck():
                totalExtLength = p.get(2)
                soFar = 0
                while soFar != totalExtLength:
                    extType = p.get(2)
                    extLength = p.get(2)
                    # Remember where the payload starts so its consumed
                    # size can be checked against extLength afterwards.
                    index1 = p.index
                    if extType == ExtensionType.srp:
                        self.srp_username = p.getVarBytes(1)
                    elif extType == ExtensionType.cert_type:
                        self.certificate_types = p.getVarList(1, 1)
                    elif extType == ExtensionType.tack:
                        self.tack = True
                    elif extType == ExtensionType.supports_npn:
                        self.supports_npn = True
                    elif extType == ExtensionType.server_name:
                        serverNameListBytes = p.getFixBytes(extLength)
                        p2 = Parser(serverNameListBytes)
                        p2.startLengthCheck(2)
                        while True:
                            if p2.atLengthCheck():
                                break # no host_name, oh well
                            name_type = p2.get(1)
                            hostNameBytes = p2.getVarBytes(2)
                            if name_type == NameType.host_name:
                                self.server_name = hostNameBytes
                                break
                    elif extType == ExtensionType.channel_id:
                        self.channel_id = True
                    elif extType == ExtensionType.extended_master_secret:
                        self.extended_master_secret = True
                    elif extType == ExtensionType.token_binding:
                        tokenBindingBytes = p.getFixBytes(extLength)
                        p2 = Parser(tokenBindingBytes)
                        # The protocol version is encoded as the major byte
                        # followed by the minor byte (ServerHello.write emits
                        # 0 then 6), so read them in that order.
                        ver_major = p2.get(1)
                        ver_minor = p2.get(1)
                        if (ver_major, ver_minor) >= (0, 6):
                            p2.startLengthCheck(1)
                            while not p2.atLengthCheck():
                                self.tb_client_params.append(p2.get(1))
                    elif extType == ExtensionType.signed_cert_timestamps:
                        if extLength:
                            raise SyntaxError()
                        self.support_signed_cert_timestamps = True
                    elif extType == ExtensionType.status_request:
                        # Extension contents are currently ignored.
                        # According to RFC 6066, this is not strictly forbidden
                        # (although it is suboptimal):
                        # Servers that receive a client hello containing the
                        # "status_request" extension MAY return a suitable
                        # certificate status response to the client along with
                        # their certificate.  If OCSP is requested, they
                        # SHOULD use the information contained in the extension
                        # when selecting an OCSP responder and SHOULD include
                        # request_extensions in the OCSP request.
                        p.getFixBytes(extLength)
                        self.status_request = True
                    else:
                        p.getFixBytes(extLength)
                    index2 = p.index
                    if index2 - index1 != extLength:
                        raise SyntaxError("Bad length for extension_data")
                    soFar += 4 + extLength
            p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the ClientHello, including the handshake header.

        Only the cert_type (when not just x509), srp, supports_npn,
        server_name and tack extensions are emitted. The extensions block is
        omitted entirely when none of them apply.
        """
        w = Writer()
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)

        w2 = Writer() # For Extensions
        if self.certificate_types and self.certificate_types != \
                [CertificateType.x509]:
            w2.add(ExtensionType.cert_type, 2)
            w2.add(len(self.certificate_types)+1, 2)
            w2.addVarSeq(self.certificate_types, 1, 1)
        if self.srp_username:
            w2.add(ExtensionType.srp, 2)
            w2.add(len(self.srp_username)+1, 2)
            w2.addVarSeq(self.srp_username, 1, 1)
        if self.supports_npn:
            w2.add(ExtensionType.supports_npn, 2)
            w2.add(0, 2)
        if self.server_name:
            # extension length, then server_name_list length, then a single
            # host_name entry (1-byte type + 2-byte length + name).
            w2.add(ExtensionType.server_name, 2)
            w2.add(len(self.server_name)+5, 2)
            w2.add(len(self.server_name)+3, 2)
            w2.add(NameType.host_name, 1)
            w2.addVarSeq(self.server_name, 1, 2)
        if self.tack:
            w2.add(ExtensionType.tack, 2)
            w2.add(0, 2)
        if len(w2.bytes):
            w.add(len(w2.bytes), 2)
            w.bytes += w2.bytes
        return self.postWrite(w)

class BadNextProtos(Exception):
    """Raised when a next-protocol (NPN) list has an element of bad length."""

    def __init__(self, l):
        # The offending length; reported by __str__.
        self.length = l

    def __str__(self):
        template = 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256'
        return template % self.length

class ServerHello(HandshakeMsg):
    """ServerHello handshake message."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.server_hello)
        self.server_version = (0,0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.tackExt = None
        self.next_protos_advertised = None  # protocols to advertise (write)
        self.next_protos = None             # protocols received (parse)
        self.channel_id = False
        self.extended_master_secret = False
        self.tb_params = None
        self.signed_cert_timestamps = None
        self.status_request = False

    def create(self, version, random, session_id, cipher_suite,
               certificate_type, tackExt, next_protos_advertised):
        """Fill in the fields needed to send a ServerHello; returns self."""
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        self.tackExt = tackExt
        self.next_protos_advertised = next_protos_advertised
        return self

    def parse(self, p):
        """Parse a ServerHello body; returns self.

        Only the cert_type, tack (when tackpy is loaded) and supports_npn
        extensions are decoded. Other extensions are skipped.

        :raises SyntaxError: if a cert_type extension is not exactly 1 byte.
        :raises BadNextProtos: if the NPN list is truncated.
        """
        p.startLengthCheck(3)
        self.server_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            totalExtLength = p.get(2)
            soFar = 0
            while soFar != totalExtLength:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == ExtensionType.cert_type:
                    if extLength != 1:
                        raise SyntaxError()
                    self.certificate_type = p.get(1)
                elif extType == ExtensionType.tack and tackpyLoaded:
                    self.tackExt = TackExtension(p.getFixBytes(extLength))
                elif extType == ExtensionType.supports_npn:
                    self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
                else:
                    p.getFixBytes(extLength)
                soFar += 4 + extLength
        p.stopLengthCheck()
        return self

    def __parse_next_protos(self, b):
        """Split a sequence of 1-byte-length-prefixed protocol names.

        :raises BadNextProtos: carrying the declared length of an entry that
            runs past the end of ``b``.
        """
        protos = []
        index = 0
        while index < len(b):
            l = b[index]
            index += 1
            if len(b) - index < l:
                # Report the offending element's declared length, not the
                # number of bytes that happen to remain.
                raise BadNextProtos(l)
            protos.append(b[index:index + l])
            index += l
        return protos

    def __next_protos_encoded(self):
        """Encode next_protos_advertised as 1-byte-length-prefixed entries.

        :raises BadNextProtos: if an entry is empty or longer than 255 bytes.
        """
        b = bytearray()
        for e in self.next_protos_advertised:
            if len(e) > 255 or len(e) == 0:
                raise BadNextProtos(len(e))
            b += bytearray( [len(e)] ) + bytearray(e)
        return b

    def write(self):
        """Serialize the ServerHello, including the handshake header.

        The extensions block is omitted entirely when no extension applies.
        """
        w = Writer()
        w.add(self.server_version[0], 1)
        w.add(self.server_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.add(self.cipher_suite, 2)
        w.add(self.compression_method, 1)

        w2 = Writer() # For Extensions
        if self.certificate_type and self.certificate_type != \
                CertificateType.x509:
            w2.add(ExtensionType.cert_type, 2)
            w2.add(1, 2)
            w2.add(self.certificate_type, 1)
        if self.tackExt:
            b = self.tackExt.serialize()
            w2.add(ExtensionType.tack, 2)
            w2.add(len(b), 2)
            w2.bytes += b
        if self.next_protos_advertised is not None:
            encoded_next_protos_advertised = self.__next_protos_encoded()
            w2.add(ExtensionType.supports_npn, 2)
            w2.add(len(encoded_next_protos_advertised), 2)
            w2.addFixSeq(encoded_next_protos_advertised, 1)
        if self.channel_id:
            w2.add(ExtensionType.channel_id, 2)
            w2.add(0, 2)
        if self.extended_master_secret:
            w2.add(ExtensionType.extended_master_secret, 2)
            w2.add(0, 2)
        if self.tb_params:
            w2.add(ExtensionType.token_binding, 2)
            # length of extension
            w2.add(4, 2)
            # version (0, 6)
            w2.add(0, 1)
            w2.add(6, 1)
            # length of params (defined as variable length <1..2^8-1>, but in
            # this context the server can only send a single value.
            w2.add(1, 1)
            w2.add(self.tb_params, 1)
        if self.signed_cert_timestamps:
            w2.add(ExtensionType.signed_cert_timestamps, 2)
            w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2)
        if self.status_request:
            w2.add(ExtensionType.status_request, 2)
            w2.add(0, 2)
        if len(w2.bytes):
            w.add(len(w2.bytes), 2)
            w.bytes += w2.bytes
        return self.postWrite(w)

class Certificate(HandshakeMsg):
    """Certificate handshake message (only X.509 chains are supported)."""

    def __init__(self, certificateType):
        HandshakeMsg.__init__(self, HandshakeType.certificate)
        self.certificateType = certificateType
        self.certChain = None

    def create(self, certChain):
        """Set the certificate chain to send; returns self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Parse a list of 3-byte-length-prefixed DER certificates.

        certChain is only replaced when at least one certificate is present.

        :raises AssertionError: for non-X.509 certificate types.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index != chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                index += len(certBytes)+3
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the chain, including the handshake header.

        :raises AssertionError: for non-X.509 certificate types.
        """
        w = Writer()
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Serialize each certificate once and reuse the result for both
            # the total length and the entries themselves.
            encoded_certs = [cert.writeBytes() for cert in certificate_list]
            # Each entry carries its own 3-byte length prefix.
            chainLength = sum(len(certBytes) + 3
                              for certBytes in encoded_certs)
            w.add(chainLength, 3)
            for certBytes in encoded_certs:
                w.addVarSeq(certBytes, 1, 3)
        else:
            raise AssertionError()
        return self.postWrite(w)

class CertificateStatus(HandshakeMsg):
    """CertificateStatus handshake message carrying an OCSP response."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.certificate_status)
        # OCSP response bytes; populated by create() or parse(). Initialized
        # here so the attribute always exists, as in the other messages.
        self.ocsp_response = None

    def create(self, ocsp_response):
        """Set the OCSP response to send; returns self."""
        self.ocsp_response = ocsp_response
        return self

    # Defined for the sake of completeness, even though we currently only
    # support sending the status message (server-side), not requesting
    # or receiving it (client-side).
    def parse(self, p):
        """Parse an OCSP-type status with a non-empty response.

        :raises SyntaxError: if the status type is not OCSP or the response
            is empty.
        """
        p.startLengthCheck(3)
        status_type = p.get(1)
        # Only one type is specified, so hardwire it.
        if status_type != CertificateStatusType.ocsp:
            raise SyntaxError()
        ocsp_response = p.getVarBytes(3)
        if not ocsp_response:
            # Can't be empty
            raise SyntaxError()
        self.ocsp_response = ocsp_response
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as status type OCSP + 3-byte-length-prefixed response."""
        w = Writer()
        w.add(CertificateStatusType.ocsp, 1)
        w.addVarSeq(bytearray(self.ocsp_response), 1, 3)
        return self.postWrite(w)

class CertificateRequest(HandshakeMsg):
    """CertificateRequest handshake message.

    For TLS 1.2 and later (version >= (3, 3)) it also carries the list of
    supported (hash, signature) algorithm pairs.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.certificate_request)
        self.certificate_types = []
        self.certificate_authorities = []
        self.version = version
        self.supported_signature_algs = []

    def create(self, certificate_types, certificate_authorities, sig_algs):
        """Set the requested types, CA names and signature algs; returns self."""
        self.certificate_types = certificate_types
        self.certificate_authorities = certificate_authorities
        self.supported_signature_algs = sig_algs
        return self

    def parse(self, p):
        """Parse types, optional signature algs and the CA name list."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        if self.version >= (3,3):
            packed = p.getVarList(2, 2)
            # Each 16-bit value packs (hash, signature) as high/low bytes.
            self.supported_signature_algs = \
                [(value >> 8, value & 0xff) for value in packed]
        ca_list_length = p.get(2)
        consumed = 0
        authorities = self.certificate_authorities = []
        while consumed != ca_list_length:
            dn = p.getVarBytes(2)
            authorities.append(dn)
            consumed += len(dn) + 2
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the request, including the handshake header."""
        w = Writer()
        w.addVarSeq(self.certificate_types, 1, 1)
        if self.version >= (3,3):
            w.add(2 * len(self.supported_signature_algs), 2)
            for hash_alg, sig_alg in self.supported_signature_algs:
                w.add(hash_alg, 1)
                w.add(sig_alg, 1)
        # Every CA name is written with a 2-byte length prefix.
        w.add(sum(len(dn) + 2 for dn in self.certificate_authorities), 2)
        for dn in self.certificate_authorities:
            w.addVarSeq(dn, 1, 2)
        return self.postWrite(w)

class ServerKeyExchange(HandshakeMsg):
    """ServerKeyExchange handshake message for SRP, DH and ECDH suites."""

    def __init__(self, cipherSuite, version):
        HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
        self.cipherSuite = cipherSuite
        self.version = version
        # SRP params:
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = bytearray(0)
        self.srp_B = 0
        # DH params:
        self.dh_p = 0
        self.dh_g = 0
        self.dh_Ys = 0
        # ECDH params:
        self.ecdhCurve = 0
        self.ecdhPublic = bytearray(0)
        self.signature = bytearray(0)

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters; returns self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def createDH(self, dh_p, dh_g, dh_Ys):
        """Set the DH parameters; returns self."""
        self.dh_p = dh_p
        self.dh_g = dh_g
        self.dh_Ys = dh_Ys
        return self

    def createECDH(self, ecdhCurve, ecdhPublic):
        """Set the ECDH named curve and public point; returns self."""
        self.ecdhCurve = ecdhCurve
        self.ecdhPublic = ecdhPublic
        return self

    def parse(self, p):
        """Parse the message body; returns self.

        Only two cases are decoded:
        - SRP suites, plus a trailing signature for SRP+cert suites;
        - anonymous DH suites.
        NOTE(review): for any other suite nothing is read before
        stopLengthCheck(), so a non-empty body presumably fails there.
        """
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_N = bytesToNumber(p.getVarBytes(2))
            self.srp_g = bytesToNumber(p.getVarBytes(2))
            self.srp_s = p.getVarBytes(1)
            self.srp_B = bytesToNumber(p.getVarBytes(2))
            if self.cipherSuite in CipherSuite.srpCertSuites:
                self.signature = p.getVarBytes(2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_p = bytesToNumber(p.getVarBytes(2))
            self.dh_g = bytesToNumber(p.getVarBytes(2))
            self.dh_Ys = bytesToNumber(p.getVarBytes(2))
        p.stopLengthCheck()
        return self

    def write_params(self):
        """Serialize the key-exchange parameters (without the signature).

        :raises AssertionError: if the cipher suite is not an SRP, DH or
            ECDH suite.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
            w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
            w.addVarSeq(self.srp_s, 1, 1)
            w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
        elif self.cipherSuite in CipherSuite.dhAllSuites:
            w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
        elif self.cipherSuite in CipherSuite.ecdhAllSuites:
            # Only the named_curve form of the curve parameters is emitted.
            w.add(ECCurveType.named_curve, 1)
            w.add(self.ecdhCurve, 2)
            w.addVarSeq(self.ecdhPublic, 1, 1)
        else:
            # Raise explicitly: a bare ``assert`` is stripped under -O and
            # would silently return empty parameters.
            raise AssertionError()
        return w.bytes

    def write(self):
        """Serialize params (+ signature for cert suites) with the header."""
        w = Writer()
        w.bytes += self.write_params()
        if self.cipherSuite in CipherSuite.certAllSuites:
            if self.version >= (3,3):
                # TODO: Signature algorithm negotiation not supported.
                w.add(HashAlgorithm.sha1, 1)
                w.add(SignatureAlgorithm.rsa, 1)
            w.addVarSeq(self.signature, 1, 2)
        return self.postWrite(w)

    def hash(self, clientRandom, serverRandom):
        """Digest the randoms and the serialized params.

        Uses SHA-1 alone for version >= (3, 3), otherwise MD5 || SHA-1.
        Presumably this is the input to the message signature — confirm
        against callers.
        """
        data = clientRandom + serverRandom + self.write_params()
        if self.version >= (3,3):
            # TODO: Signature algorithm negotiation not supported.
            return SHA1(data)
        return MD5(data) + SHA1(data)

class ServerHelloDone(HandshakeMsg):
    """ServerHelloDone handshake message; its body carries no fields."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.server_hello_done)

    def create(self):
        return self

    def parse(self, p):
        # Open and immediately close the length check: nothing is read, so
        # presumably a non-empty body is rejected by stopLengthCheck().
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self):
        """Return just the handshake header for an empty body."""
        return self.postWrite(Writer())

class ClientKeyExchange(HandshakeMsg):
    """ClientKeyExchange handshake message for SRP, RSA, DH and ECDH."""

    def __init__(self, cipherSuite, version=None):
        HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
        self.cipherSuite = cipherSuite
        self.version = version
        self.srp_A = 0
        self.encryptedPreMasterSecret = bytearray(0)
        # Initialized so that every field read by write() always exists.
        self.dh_Yc = 0
        self.ecdh_Yc = bytearray(0)

    def createSRP(self, srp_A):
        """Set the SRP client public value; returns self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the RSA-encrypted premaster secret; returns self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def createDH(self, dh_Yc):
        """Set the DH client public value; returns self."""
        self.dh_Yc = dh_Yc
        return self

    def parse(self, p):
        """Parse the body according to the cipher suite; returns self.

        :raises AssertionError: for an unsupported suite, or an RSA suite
            with an unknown protocol version.
        """
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2), (3,3)):
                self.encryptedPreMasterSecret = p.getVarBytes(2)
            elif self.version == (3,0):
                # SSLv3 sends the encrypted secret without a length prefix,
                # so take everything left in the buffer.
                self.encryptedPreMasterSecret = \
                    p.getFixBytes(len(p.bytes)-p.index)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.dhAllSuites:
            self.dh_Yc = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.ecdhAllSuites:
            self.ecdh_Yc = p.getVarBytes(1)
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the body (mirroring parse()) with the handshake header.

        :raises AssertionError: for an unsupported suite, or an RSA suite
            with an unknown protocol version.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2), (3,3)):
                w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
            elif self.version == (3,0):
                w.addFixSeq(self.encryptedPreMasterSecret, 1)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.dhAllSuites or \
                self.cipherSuite in CipherSuite.anonSuites:
            # parse() accepts every DH suite, so write() must as well; the
            # anonSuites check keeps the previously supported path intact.
            w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
        elif self.cipherSuite in CipherSuite.ecdhAllSuites:
            # Same encoding parse() expects: 1-byte length + point bytes.
            w.addVarSeq(self.ecdh_Yc, 1, 1)
        else:
            raise AssertionError()
        return self.postWrite(w)

class CertificateVerify(HandshakeMsg):
    """CertificateVerify handshake message.

    From TLS 1.2 on (version >= (3, 3)) the signature is preceded by a
    two-byte (hash, signature) algorithm identifier.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
        self.version = version
        self.signature_algorithm = None
        self.signature = bytearray(0)

    def create(self, signature_algorithm, signature):
        """Set the algorithm pair and signature bytes; returns self."""
        self.signature_algorithm = signature_algorithm
        self.signature = signature
        return self

    def parse(self, p):
        """Read the optional algorithm pair, then the 2-byte-prefixed sig."""
        p.startLengthCheck(3)
        if self.version >= (3,3):
            hash_id = p.get(1)
            sig_id = p.get(1)
            self.signature_algorithm = (hash_id, sig_id)
        self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the message, including the handshake header."""
        out = Writer()
        if self.version >= (3,3):
            out.add(self.signature_algorithm[0], 1)
            out.add(self.signature_algorithm[1], 1)
        out.addVarSeq(self.signature, 1, 2)
        return self.postWrite(out)

class ChangeCipherSpec(object):
    """ChangeCipherSpec record: a single type byte, normally 1."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset the type byte to 1; returns self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read exactly one byte into ``type``."""
        p.setLengthCheck(1)
        self.type = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as the single type byte."""
        out = Writer()
        out.add(self.type, 1)
        return out.bytes


class NextProtocol(HandshakeMsg):
    """NextProtocol (NPN) handshake message: the selected protocol + padding."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.next_protocol)
        self.next_proto = None

    def create(self, next_proto):
        """Set the selected protocol name; returns self."""
        self.next_proto = next_proto
        return self

    def parse(self, p):
        """Read the protocol name, then read and discard the padding."""
        p.startLengthCheck(3)
        self.next_proto = p.getVarBytes(1)
        p.getVarBytes(1)  # padding, ignored
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize with padding; ``trial`` is accepted but unused here."""
        out = Writer()
        proto = self.next_proto
        out.addVarSeq(proto, 1, 1)
        # Pad so the body (two 1-byte length prefixes + name + padding) is a
        # multiple of 32 bytes; the padding is always 1..32 bytes long.
        pad = 32 - ((len(proto) + 2) % 32)
        out.addVarSeq(bytearray(pad), 1, 1)
        return self.postWrite(out)

class Finished(HandshakeMsg):
    """Finished handshake message carrying ``verify_data``.

    verify_data is 36 bytes for SSLv3 and 12 bytes for TLS 1.0-1.2.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.finished)
        self.version = version
        self.verify_data = bytearray(0)

    def create(self, verify_data):
        """Set the verify_data to send; returns self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read a fixed-size verify_data whose size depends on the version.

        :raises AssertionError: for an unknown protocol version.
        """
        p.startLengthCheck(3)
        if self.version == (3,0):
            size = 36
        elif self.version in ((3,1), (3,2), (3,3)):
            size = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(size)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize verify_data with the handshake header."""
        out = Writer()
        out.addFixSeq(self.verify_data, 1)
        return self.postWrite(out)

class EncryptedExtensions(HandshakeMsg):
    """EncryptedExtensions message; only the channel_id extension is decoded."""

    def __init__(self):
        # NOTE(review): HandshakeMsg.__init__ is not called, so contentType
        # and handshakeType are never set on instances of this class.
        self.channel_id_key = None
        self.channel_id_proof = None

    def parse(self, p):
        """Walk the extensions, decoding a 128-byte channel_id payload.

        :raises SyntaxError: if the channel_id payload is not 128 bytes.
        """
        p.startLengthCheck(3)
        consumed = 0
        while consumed != p.lengthCheck:
            ext_type = p.get(2)
            ext_len = p.get(2)
            if ext_type != ExtensionType.channel_id:
                p.getFixBytes(ext_len)
            elif ext_len == 32*4:
                # First 64 bytes are the key, next 64 the proof.
                self.channel_id_key = p.getFixBytes(64)
                self.channel_id_proof = p.getFixBytes(64)
            else:
                raise SyntaxError()
            consumed += 4 + ext_len
        p.stopLengthCheck()
        return self

class ApplicationData(object):
    """Application data record; ``bytes`` holds the raw payload."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = bytearray(0)

    def create(self, bytes):
        """Set the payload; returns self."""
        self.bytes = bytes
        return self

    def splitFirstByte(self):
        """Remove the first byte and return it as a new ApplicationData."""
        head = self.bytes[:1]
        split_off = ApplicationData().create(head)
        self.bytes = self.bytes[1:]
        return split_off

    def parse(self, p):
        """Take the parser's entire buffer as the payload."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload unchanged."""
        return self.bytes