diff --git a/fips140/hash.go b/fips140/hash.go new file mode 100644 index 0000000..482ddf1 --- /dev/null +++ b/fips140/hash.go @@ -0,0 +1,32 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fips140 + +import "io" + +// Hash is the common interface implemented by all hash functions. It is a copy +// of [hash.Hash] from the standard library, to avoid depending on security +// definitions from outside of the module. +type Hash interface { + // Write (via the embedded io.Writer interface) adds more data to the + // running hash. It never returns an error. + io.Writer + + // Sum appends the current hash to b and returns the resulting slice. + // It does not change the underlying hash state. + Sum(b []byte) []byte + + // Reset resets the Hash to its initial state. + Reset() + + // Size returns the number of bytes Sum will return. + Size() int + + // BlockSize returns the hash's underlying block size. + // The Write method must be able to accept any amount + // of data, but it may operate more efficiently if all writes + // are a multiple of the block size. + BlockSize() int +} \ No newline at end of file diff --git a/handshake_client.go b/handshake_client.go index 0b1172a..fddee7a 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -24,6 +24,7 @@ import ( "github.com/xtls/reality/hpke" "github.com/xtls/reality/mlkem768" + "github.com/xtls/reality/tls13" ) type clientHandshakeState struct { @@ -325,7 +326,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if err := transcriptMsg(hello, transcript); err != nil { return err } - earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) + earlyTrafficSecret := earlySecret.ClientEarlyTrafficSecret(transcript) c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) } @@ -384,7 +385,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } func (c *Conn) loadSession(hello *clientHelloMsg) ( - session *SessionState, earlySecret, binderKey []byte, err error) { + session *SessionState, earlySecret *tls13.EarlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { return nil, nil, nil, nil } @@ -511,8 +512,8 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - earlySecret = cipherSuite.extract(session.secret, nil) - binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + earlySecret = tls13.NewEarlySecret(cipherSuite.hash.New, session.secret) + binderKey = earlySecret.ResumptionBinderKey() transcript := cipherSuite.hash.New() if err := computeAndUpdatePSK(hello, binderKey, transcript, cipherSuite.finishedHash); err != nil { return nil, nil, nil, err diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 15efbd1..31509fa 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -17,6 +17,8 @@ import ( "time" "github.com/xtls/reality/mlkem768" + "github.com/xtls/reality/tls13" + "golang.org/x/crypto/hkdf" ) type clientHandshakeStateTLS13 struct { @@ -27,7 +29,7 @@ type clientHandshakeStateTLS13 struct { keyShareKeys *keySharePrivateKeys session *SessionState - earlySecret []byte + earlySecret *tls13.EarlySecret binderKey []byte certReq *certificateRequestMsgTLS13 @@ -35,7 +37,7 @@ type clientHandshakeStateTLS13 struct { sentDummyCCS bool suite *cipherSuiteTLS13 transcript hash.Hash - masterSecret []byte + masterSecret *tls13.MasterSecret trafficSecret []byte // client_application_traffic_secret_0 echContext *echContext } @@ -89,8 +91,8 @@ func (hs *clientHandshakeStateTLS13) handshake() error { confTranscript.Write(hs.serverHello.original[:30]) confTranscript.Write(make([]byte, 8)) confTranscript.Write(hs.serverHello.original[38:]) - acceptConfirmation := hs.suite.expandLabel( - hs.suite.extract(hs.echContext.innerHello.random, nil), + acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New, + hkdf.Extract(hs.suite.hash.New, hs.echContext.innerHello.random, nil), "ech accept confirmation", confTranscript.Sum(nil), 8, @@ -266,8 +268,8 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { copy(hrrHello, hs.serverHello.original) hrrHello = bytes.Replace(hrrHello, hs.serverHello.encryptedClientHello, make([]byte, 8), 1) confTranscript.Write(hrrHello) - acceptConfirmation := hs.suite.expandLabel( - hs.suite.extract(hs.echContext.innerHello.random, nil), + acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New, + hkdf.Extract(hs.suite.hash.New, hs.echContext.innerHello.random, nil), "hrr ech accept confirmation", confTranscript.Sum(nil), 8, @@ -511,17 +513,14 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { earlySecret := hs.earlySecret if !hs.usingPSK { - earlySecret = hs.suite.extract(nil, nil) + earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, nil) } - handshakeSecret := hs.suite.extract(sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) + handshakeSecret := earlySecret.HandshakeSecret(sharedKey) - clientSecret := hs.suite.deriveSecret(handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) + clientSecret := handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) - serverSecret := hs.suite.deriveSecret(handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) + serverSecret := handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) if c.quic != nil { @@ -543,8 +542,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { return err } - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + hs.masterSecret = handshakeSecret.MasterSecret() return nil } @@ -732,10 +730,8 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { // Derive secrets that take context through the server Finished. - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) + hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript) + serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) @@ -842,8 +838,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) + c.resumptionSecret = hs.masterSecret.ResumptionMasterSecret(hs.transcript) } if c.quic != nil { @@ -887,7 +882,7 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { return c.sendAlert(alertInternalError) } - psk := cipherSuite.expandLabel(c.resumptionSecret, "resumption", + psk := tls13.ExpandLabel(cipherSuite.hash.New, c.resumptionSecret, "resumption", msg.nonce, cipherSuite.hash.Size()) session := c.sessionState() diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 72ed000..f97381e 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -23,6 +23,7 @@ import ( "time" "github.com/xtls/reality/mlkem768" + "github.com/xtls/reality/tls13" ) // maxClientPSKIdentities is the number of client PSK identities the server will @@ -41,10 +42,10 @@ type serverHandshakeStateTLS13 struct { suite *cipherSuiteTLS13 cert *Certificate sigAlg SignatureScheme - earlySecret []byte + earlySecret *tls13.EarlySecret sharedKey []byte - handshakeSecret []byte - masterSecret []byte + handshakeSecret *tls13.HandshakeSecret + masterSecret *tls13.MasterSecret trafficSecret []byte // client_application_traffic_secret_0 transcript hash.Hash clientFinished []byte @@ -435,8 +436,8 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { } } - hs.earlySecret = hs.suite.extract(sessionState.secret, nil) - binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + hs.earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, sessionState.secret) + binderKey := hs.earlySecret.ResumptionBinderKey() // Clone the transcript in case a HelloRetryRequest was recorded. transcript := cloneHash(hs.transcript, hs.suite.hash) if transcript == nil { @@ -464,7 +465,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { if err := transcriptMsg(hs.clientHello, transcript); err != nil { return err } - earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript) + earlyTrafficSecret := hs.earlySecret.ClientEarlyTrafficSecret(transcript) c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret) } @@ -702,16 +703,13 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { earlySecret := hs.earlySecret if earlySecret == nil { - earlySecret = hs.suite.extract(nil, nil) + earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, nil) } - hs.handshakeSecret = hs.suite.extract(hs.sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) + hs.handshakeSecret = earlySecret.HandshakeSecret(hs.sharedKey) - clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) + clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript) c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) - serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) + serverSecret := hs.handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) if c.quic != nil { @@ -836,13 +834,10 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { // Derive secrets that take context through the server Finished. - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + hs.masterSecret = hs.handshakeSecret.MasterSecret() - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) + hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript) + serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript) c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) if c.quic != nil { @@ -908,8 +903,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { return err } - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) + c.resumptionSecret = hs.masterSecret.ResumptionMasterSecret(hs.transcript) if !hs.shouldSendSessionTickets() { return nil @@ -924,7 +918,7 @@ func (c *Conn) sendSessionTicket(earlyData bool, extra [][]byte) error { } // ticket_nonce, which must be unique per connection, is always left at // zero because we only ever send one ticket per connection. - psk := suite.expandLabel(c.resumptionSecret, "resumption", + psk := tls13.ExpandLabel(suite.hash.New, c.resumptionSecret, "resumption", nil, suite.hash.Size()) m := new(newSessionTicketMsgTLS13) diff --git a/hkdf/hkdf.go b/hkdf/hkdf.go new file mode 100644 index 0000000..0143eb8 --- /dev/null +++ b/hkdf/hkdf.go @@ -0,0 +1,56 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hkdf + +import ( + "github.com/xtls/reality/fips140" + "github.com/xtls/reality/hmac" +) + +func Extract[H fips140.Hash](h func() H, secret, salt []byte) []byte { + if len(secret) < 112/8 { + // fips140.RecordNonApproved() + } + if salt == nil { + salt = make([]byte, h().Size()) + } + extractor := hmac.New(h, salt) + hmac.MarkAsUsedInKDF(extractor) + extractor.Write(secret) + + return extractor.Sum(nil) +} + +func Expand[H fips140.Hash](h func() H, pseudorandomKey []byte, info string, keyLen int) []byte { + out := make([]byte, 0, keyLen) + expander := hmac.New(h, pseudorandomKey) + hmac.MarkAsUsedInKDF(expander) + var counter uint8 + var buf []byte + + for len(out) < keyLen { + counter++ + if counter == 0 { + panic("hkdf: counter overflow") + } + if counter > 1 { + expander.Reset() + } + expander.Write(buf) + expander.Write([]byte(info)) + expander.Write([]byte{counter}) + buf = expander.Sum(buf[:0]) + remain := keyLen - len(out) + remain = min(remain, len(buf)) + out = append(out, buf[:remain]...) + } + + return out +} + +func Key[H fips140.Hash](h func() H, secret, salt []byte, info string, keyLen int) []byte { + prk := Extract(h, secret, salt) + return Expand(h, prk, info, keyLen) +} \ No newline at end of file diff --git a/hmac/hmac.go b/hmac/hmac.go new file mode 100644 index 0000000..b4cc25e --- /dev/null +++ b/hmac/hmac.go @@ -0,0 +1,172 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hmac implements HMAC according to [FIPS 198-1]. +// +// [FIPS 198-1]: https://doi.org/10.6028/NIST.FIPS.198-1 +package hmac + +import ( + "github.com/xtls/reality/fips140" + "github.com/xtls/reality/sha256" + "github.com/xtls/reality/sha3" + "github.com/xtls/reality/sha512" +) + +// key is zero padded to the block size of the hash function +// ipad = 0x36 byte repeated for key length +// opad = 0x5c byte repeated for key length +// hmac = H([key ^ opad] H([key ^ ipad] text)) + +// marshalable is the combination of encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. Their method definitions are repeated here to +// avoid a dependency on the encoding package. +type marshalable interface { + MarshalBinary() ([]byte, error) + UnmarshalBinary([]byte) error +} + +type HMAC struct { + opad, ipad []byte + outer, inner fips140.Hash + + // If marshaled is true, then opad and ipad do not contain a padded + // copy of the key, but rather the marshaled state of outer/inner after + // opad/ipad has been fed into it. + marshaled bool + + // forHKDF and keyLen are stored to inform the service indicator decision. + forHKDF bool + keyLen int +} + +func (h *HMAC) Sum(in []byte) []byte { + // Per FIPS 140-3 IG C.M, key lengths below 112 bits are only allowed for + // legacy use (i.e. verification only) and we don't support that. However, + // HKDF uses the HMAC key for the salt, which is allowed to be shorter. + if h.keyLen < 112/8 && !h.forHKDF { + // fips140.RecordNonApproved() + } + switch h.inner.(type) { + case *sha256.Digest, *sha512.Digest, *sha3.Digest: + default: + // fips140.RecordNonApproved() + } + + origLen := len(in) + in = h.inner.Sum(in) + + if h.marshaled { + if err := h.outer.(marshalable).UnmarshalBinary(h.opad); err != nil { + panic(err) + } + } else { + h.outer.Reset() + h.outer.Write(h.opad) + } + h.outer.Write(in[origLen:]) + return h.outer.Sum(in[:origLen]) +} + +func (h *HMAC) Write(p []byte) (n int, err error) { + return h.inner.Write(p) +} + +func (h *HMAC) Size() int { return h.outer.Size() } +func (h *HMAC) BlockSize() int { return h.inner.BlockSize() } + +func (h *HMAC) Reset() { + if h.marshaled { + if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { + panic(err) + } + return + } + + h.inner.Reset() + h.inner.Write(h.ipad) + + // If the underlying hash is marshalable, we can save some time by saving a + // copy of the hash state now, and restoring it on future calls to Reset and + // Sum instead of writing ipad/opad every time. + // + // We do this on Reset to avoid slowing down the common single-use case. + // + // This is allowed by FIPS 198-1, Section 6: "Conceptually, the intermediate + // results of the compression function on the B-byte blocks (K0 ⊕ ipad) and + // (K0 ⊕ opad) can be precomputed once, at the time of generation of the key + // K, or before its first use. These intermediate results can be stored and + // then used to initialize H each time that a message needs to be + // authenticated using the same key. [...] These stored intermediate values + // shall be treated and protected in the same manner as secret keys." + marshalableInner, innerOK := h.inner.(marshalable) + if !innerOK { + return + } + marshalableOuter, outerOK := h.outer.(marshalable) + if !outerOK { + return + } + + imarshal, err := marshalableInner.MarshalBinary() + if err != nil { + return + } + + h.outer.Reset() + h.outer.Write(h.opad) + omarshal, err := marshalableOuter.MarshalBinary() + if err != nil { + return + } + + // Marshaling succeeded; save the marshaled state for later + h.ipad = imarshal + h.opad = omarshal + h.marshaled = true +} + +// New returns a new HMAC hash using the given [fips140.Hash] type and key. +func New[H fips140.Hash](h func() H, key []byte) *HMAC { + hm := &HMAC{keyLen: len(key)} + hm.outer = h() + hm.inner = h() + unique := true + func() { + defer func() { + // The comparison might panic if the underlying types are not comparable. + _ = recover() + }() + if hm.outer == hm.inner { + unique = false + } + }() + if !unique { + panic("crypto/hmac: hash generation function does not produce unique values") + } + blocksize := hm.inner.BlockSize() + hm.ipad = make([]byte, blocksize) + hm.opad = make([]byte, blocksize) + if len(key) > blocksize { + // If key is too big, hash it. + hm.outer.Write(key) + key = hm.outer.Sum(nil) + } + copy(hm.ipad, key) + copy(hm.opad, key) + for i := range hm.ipad { + hm.ipad[i] ^= 0x36 + } + for i := range hm.opad { + hm.opad[i] ^= 0x5c + } + hm.inner.Write(hm.ipad) + + return hm +} + +// MarkAsUsedInKDF records that this HMAC instance is used as part of a KDF. +func MarkAsUsedInKDF(h *HMAC) { + h.forHKDF = true +} \ No newline at end of file diff --git a/key_schedule.go b/key_schedule.go index 63636b5..0210be9 100644 --- a/key_schedule.go +++ b/key_schedule.go @@ -8,93 +8,28 @@ import ( "crypto/ecdh" "crypto/hmac" "errors" - "fmt" "hash" "io" - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/hkdf" "golang.org/x/crypto/sha3" "github.com/xtls/reality/mlkem768" + "github.com/xtls/reality/tls13" ) // This file contains the functions necessary to compute the TLS 1.3 key // schedule. See RFC 8446, Section 7. -const ( - resumptionBinderLabel = "res binder" - clientEarlyTrafficLabel = "c e traffic" - clientHandshakeTrafficLabel = "c hs traffic" - serverHandshakeTrafficLabel = "s hs traffic" - clientApplicationTrafficLabel = "c ap traffic" - serverApplicationTrafficLabel = "s ap traffic" - exporterLabel = "exp master" - resumptionLabel = "res master" - trafficUpdateLabel = "traffic upd" -) - -// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { - var hkdfLabel cryptobyte.Builder - hkdfLabel.AddUint16(uint16(length)) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte("tls13 ")) - b.AddBytes([]byte(label)) - }) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(context) - }) - hkdfLabelBytes, err := hkdfLabel.Bytes() - if err != nil { - // Rather than calling BytesOrPanic, we explicitly handle this error, in - // order to provide a reasonable error message. It should be basically - // impossible for this to panic, and routing errors back through the - // tree rooted in this function is quite painful. The labels are fixed - // size, and the context is either a fixed-length computed hash, or - // parsed from a field which has the same length limitation. As such, an - // error here is likely to only be caused during development. - // - // NOTE: another reasonable approach here might be to return a - // randomized slice if we encounter an error, which would break the - // connection, but avoid panicking. This would perhaps be safer but - // significantly more confusing to users. - panic(fmt.Errorf("failed to construct HKDF label: %s", err)) - } - out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) - if err != nil || n != length { - panic("tls: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} - -// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - if transcript == nil { - transcript = c.hash.New() - } - return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) -} - -// extract implements HKDF-Extract with the cipher suite hash. -func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { - if newSecret == nil { - newSecret = make([]byte, c.hash.Size()) - } - return hkdf.Extract(c.hash.New, newSecret, currentSecret) -} - // nextTrafficSecret generates the next traffic secret, given the current one, // according to RFC 8446, Section 7.2. func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { - return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) + return tls13.ExpandLabel(c.hash.New, trafficSecret, "traffic upd", nil, c.hash.Size()) } // trafficKey generates traffic keys according to RFC 8446, Section 7.3. func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { - key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) - iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + key = tls13.ExpandLabel(c.hash.New, trafficSecret, "key", nil, c.keyLen) + iv = tls13.ExpandLabel(c.hash.New, trafficSecret, "iv", nil, aeadNonceLength) return } @@ -102,7 +37,7 @@ func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { // to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey // selection. func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { - finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + finishedKey := tls13.ExpandLabel(c.hash.New, baseKey, "finished", nil, c.hash.Size()) verifyData := hmac.New(c.hash.New, finishedKey) verifyData.Write(transcript.Sum(nil)) return verifyData.Sum(nil) @@ -110,13 +45,10 @@ func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) [] // exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to // RFC 8446, Section 7.5. -func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { - expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) +func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := s.ExporterMasterSecret(transcript) return func(label string, context []byte, length int) ([]byte, error) { - secret := c.deriveSecret(expMasterSecret, label, nil) - h := c.hash.New() - h.Write(context) - return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + return expMasterSecret.Exporter(label, context, length), nil } } diff --git a/sha256/sha256.go b/sha256/sha256.go new file mode 100644 index 0000000..d0e15b7 --- /dev/null +++ b/sha256/sha256.go @@ -0,0 +1,231 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sha256 implements the SHA-224 and SHA-256 hash algorithms as defined +// in FIPS 180-4. +package sha256 + +import ( + //"github.com/xtls/reality" + "github.com/xtls/reality/byteorder" + "errors" +) + +// The size of a SHA-256 checksum in bytes. +const size = 32 + +// The size of a SHA-224 checksum in bytes. +const size224 = 28 + +// The block size of SHA-256 and SHA-224 in bytes. +const blockSize = 64 + +const ( + chunk = 64 + init0 = 0x6A09E667 + init1 = 0xBB67AE85 + init2 = 0x3C6EF372 + init3 = 0xA54FF53A + init4 = 0x510E527F + init5 = 0x9B05688C + init6 = 0x1F83D9AB + init7 = 0x5BE0CD19 + init0_224 = 0xC1059ED8 + init1_224 = 0x367CD507 + init2_224 = 0x3070DD17 + init3_224 = 0xF70E5939 + init4_224 = 0xFFC00B31 + init5_224 = 0x68581511 + init6_224 = 0x64F98FA7 + init7_224 = 0xBEFA4FA4 +) + +// Digest is a SHA-224 or SHA-256 [hash.Hash] implementation. +type Digest struct { + h [8]uint32 + x [chunk]byte + nx int + len uint64 + is224 bool // mark if this digest is SHA-224 +} + +const ( + magic224 = "sha\x02" + magic256 = "sha\x03" + marshaledSize = len(magic256) + 8*4 + chunk + 8 +) + +func (d *Digest) MarshalBinary() ([]byte, error) { + return d.AppendBinary(make([]byte, 0, marshaledSize)) +} + +func (d *Digest) AppendBinary(b []byte) ([]byte, error) { + if d.is224 { + b = append(b, magic224...) + } else { + b = append(b, magic256...) + } + b = byteorder.BEAppendUint32(b, d.h[0]) + b = byteorder.BEAppendUint32(b, d.h[1]) + b = byteorder.BEAppendUint32(b, d.h[2]) + b = byteorder.BEAppendUint32(b, d.h[3]) + b = byteorder.BEAppendUint32(b, d.h[4]) + b = byteorder.BEAppendUint32(b, d.h[5]) + b = byteorder.BEAppendUint32(b, d.h[6]) + b = byteorder.BEAppendUint32(b, d.h[7]) + b = append(b, d.x[:d.nx]...) + b = append(b, make([]byte, len(d.x)-d.nx)...) + b = byteorder.BEAppendUint64(b, d.len) + return b, nil +} + +func (d *Digest) UnmarshalBinary(b []byte) error { + if len(b) < len(magic224) || (d.is224 && string(b[:len(magic224)]) != magic224) || (!d.is224 && string(b[:len(magic256)]) != magic256) { + return errors.New("crypto/sha256: invalid hash state identifier") + } + if len(b) != marshaledSize { + return errors.New("crypto/sha256: invalid hash state size") + } + b = b[len(magic224):] + b, d.h[0] = consumeUint32(b) + b, d.h[1] = consumeUint32(b) + b, d.h[2] = consumeUint32(b) + b, d.h[3] = consumeUint32(b) + b, d.h[4] = consumeUint32(b) + b, d.h[5] = consumeUint32(b) + b, d.h[6] = consumeUint32(b) + b, d.h[7] = consumeUint32(b) + b = b[copy(d.x[:], b):] + b, d.len = consumeUint64(b) + d.nx = int(d.len % chunk) + return nil +} + +func consumeUint64(b []byte) ([]byte, uint64) { + return b[8:], byteorder.BEUint64(b) +} + +func consumeUint32(b []byte) ([]byte, uint32) { + return b[4:], byteorder.BEUint32(b) +} + +func (d *Digest) Reset() { + if !d.is224 { + d.h[0] = init0 + d.h[1] = init1 + d.h[2] = init2 + d.h[3] = init3 + d.h[4] = init4 + d.h[5] = init5 + d.h[6] = init6 + d.h[7] = init7 + } else { + d.h[0] = init0_224 + d.h[1] = init1_224 + d.h[2] = init2_224 + d.h[3] = init3_224 + d.h[4] = init4_224 + d.h[5] = init5_224 + d.h[6] = init6_224 + d.h[7] = init7_224 + } + d.nx = 0 + d.len = 0 +} + +// New returns a new Digest computing the SHA-256 hash. +func New() *Digest { + d := new(Digest) + d.Reset() + return d +} + +// New224 returns a new Digest computing the SHA-224 hash. +func New224() *Digest { + d := new(Digest) + d.is224 = true + d.Reset() + return d +} + +func (d *Digest) Size() int { + if !d.is224 { + return size + } + return size224 +} + +func (d *Digest) BlockSize() int { return blockSize } + +func (d *Digest) Write(p []byte) (nn int, err error) { + nn = len(p) + d.len += uint64(nn) + if d.nx > 0 { + n := copy(d.x[d.nx:], p) + d.nx += n + if d.nx == chunk { + block(d, d.x[:]) + d.nx = 0 + } + p = p[n:] + } + if len(p) >= chunk { + n := len(p) &^ (chunk - 1) + block(d, p[:n]) + p = p[n:] + } + if len(p) > 0 { + d.nx = copy(d.x[:], p) + } + return +} + +func (d *Digest) Sum(in []byte) []byte { + //fips140.RecordApproved() + // Make a copy of d so that caller can keep writing and summing. + d0 := *d + hash := d0.checkSum() + if d0.is224 { + return append(in, hash[:size224]...) + } + return append(in, hash[:]...) +} + +func (d *Digest) checkSum() [size]byte { + len := d.len + // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. + var tmp [64 + 8]byte // padding + length buffer + tmp[0] = 0x80 + var t uint64 + if len%64 < 56 { + t = 56 - len%64 + } else { + t = 64 + 56 - len%64 + } + + // Length in bits. + len <<= 3 + padlen := tmp[:t+8] + byteorder.BEPutUint64(padlen[t+0:], len) + d.Write(padlen) + + if d.nx != 0 { + panic("d.nx != 0") + } + + var digest [size]byte + + byteorder.BEPutUint32(digest[0:], d.h[0]) + byteorder.BEPutUint32(digest[4:], d.h[1]) + byteorder.BEPutUint32(digest[8:], d.h[2]) + byteorder.BEPutUint32(digest[12:], d.h[3]) + byteorder.BEPutUint32(digest[16:], d.h[4]) + byteorder.BEPutUint32(digest[20:], d.h[5]) + byteorder.BEPutUint32(digest[24:], d.h[6]) + if !d.is224 { + byteorder.BEPutUint32(digest[28:], d.h[7]) + } + + return digest +} \ No newline at end of file diff --git a/sha256/sha256block.go b/sha256/sha256block.go new file mode 100644 index 0000000..55fae63 --- /dev/null +++ b/sha256/sha256block.go @@ -0,0 +1,128 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// SHA256 block step. +// In its own file so that a faster assembly or C version +// can be substituted easily. + +package sha256 + +import "math/bits" + +var _K = [...]uint32{ + 0x428a2f98, + 0x71374491, + 0xb5c0fbcf, + 0xe9b5dba5, + 0x3956c25b, + 0x59f111f1, + 0x923f82a4, + 0xab1c5ed5, + 0xd807aa98, + 0x12835b01, + 0x243185be, + 0x550c7dc3, + 0x72be5d74, + 0x80deb1fe, + 0x9bdc06a7, + 0xc19bf174, + 0xe49b69c1, + 0xefbe4786, + 0x0fc19dc6, + 0x240ca1cc, + 0x2de92c6f, + 0x4a7484aa, + 0x5cb0a9dc, + 0x76f988da, + 0x983e5152, + 0xa831c66d, + 0xb00327c8, + 0xbf597fc7, + 0xc6e00bf3, + 0xd5a79147, + 0x06ca6351, + 0x14292967, + 0x27b70a85, + 0x2e1b2138, + 0x4d2c6dfc, + 0x53380d13, + 0x650a7354, + 0x766a0abb, + 0x81c2c92e, + 0x92722c85, + 0xa2bfe8a1, + 0xa81a664b, + 0xc24b8b70, + 0xc76c51a3, + 0xd192e819, + 0xd6990624, + 0xf40e3585, + 0x106aa070, + 0x19a4c116, + 0x1e376c08, + 0x2748774c, + 0x34b0bcb5, + 0x391c0cb3, + 0x4ed8aa4a, + 0x5b9cca4f, + 0x682e6ff3, + 0x748f82ee, + 0x78a5636f, + 0x84c87814, + 0x8cc70208, + 0x90befffa, + 0xa4506ceb, + 0xbef9a3f7, + 0xc67178f2, +} + +func blockGeneric(dig *Digest, p []byte) { + var w [64]uint32 + h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] + for len(p) >= chunk { + // Can interlace the computation of w with the + // rounds below if needed for speed. + for i := 0; i < 16; i++ { + j := i * 4 + w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) + } + for i := 16; i < 64; i++ { + v1 := w[i-2] + t1 := (bits.RotateLeft32(v1, -17)) ^ (bits.RotateLeft32(v1, -19)) ^ (v1 >> 10) + v2 := w[i-15] + t2 := (bits.RotateLeft32(v2, -7)) ^ (bits.RotateLeft32(v2, -18)) ^ (v2 >> 3) + w[i] = t1 + w[i-7] + t2 + w[i-16] + } + + a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7 + + for i := 0; i < 64; i++ { + t1 := h + ((bits.RotateLeft32(e, -6)) ^ (bits.RotateLeft32(e, -11)) ^ (bits.RotateLeft32(e, -25))) + ((e & f) ^ (^e & g)) + _K[i] + w[i] + + t2 := ((bits.RotateLeft32(a, -2)) ^ (bits.RotateLeft32(a, -13)) ^ (bits.RotateLeft32(a, -22))) + ((a & b) ^ (a & c) ^ (b & c)) + + h = g + g = f + f = e + e = d + t1 + d = c + c = b + b = a + a = t1 + t2 + } + + h0 += a + h1 += b + h2 += c + h3 += d + h4 += e + h5 += f + h6 += g + h7 += h + + p = p[chunk:] + } + + dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7 +} \ No newline at end of file diff --git a/sha256/sha256block_noasm.go b/sha256/sha256block_noasm.go new file mode 100644 index 0000000..119a998 --- /dev/null +++ b/sha256/sha256block_noasm.go @@ -0,0 +1,9 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha256 + +func block(dig *Digest, p []byte) { + blockGeneric(dig, p) +} \ No newline at end of file diff --git a/sha3/hashes.go b/sha3/hashes.go new file mode 100644 index 0000000..b719ca1 --- /dev/null +++ b/sha3/hashes.go @@ -0,0 +1,59 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha3 + +// New224 returns a new Digest computing the SHA3-224 hash. +func New224() *Digest { + return &Digest{rate: rateK448, outputLen: 28, dsbyte: dsbyteSHA3} +} + +// New256 returns a new Digest computing the SHA3-256 hash. +func New256() *Digest { + return &Digest{rate: rateK512, outputLen: 32, dsbyte: dsbyteSHA3} +} + +// New384 returns a new Digest computing the SHA3-384 hash. +func New384() *Digest { + return &Digest{rate: rateK768, outputLen: 48, dsbyte: dsbyteSHA3} +} + +// New512 returns a new Digest computing the SHA3-512 hash. +func New512() *Digest { + return &Digest{rate: rateK1024, outputLen: 64, dsbyte: dsbyteSHA3} +} + +// TODO(fips): do this in the stdlib crypto/sha3 package. +// +// crypto.RegisterHash(crypto.SHA3_224, New224) +// crypto.RegisterHash(crypto.SHA3_256, New256) +// crypto.RegisterHash(crypto.SHA3_384, New384) +// crypto.RegisterHash(crypto.SHA3_512, New512) + +const ( + dsbyteSHA3 = 0b00000110 + dsbyteKeccak = 0b00000001 + dsbyteShake = 0b00011111 + dsbyteCShake = 0b00000100 + + // rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in + // bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits. + rateK256 = (1600 - 256) / 8 + rateK448 = (1600 - 448) / 8 + rateK512 = (1600 - 512) / 8 + rateK768 = (1600 - 768) / 8 + rateK1024 = (1600 - 1024) / 8 +) + +// NewLegacyKeccak256 returns a new Digest computing the legacy, non-standard +// Keccak-256 hash. +func NewLegacyKeccak256() *Digest { + return &Digest{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak} +} + +// NewLegacyKeccak512 returns a new Digest computing the legacy, non-standard +// Keccak-512 hash. +func NewLegacyKeccak512() *Digest { + return &Digest{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak} +} \ No newline at end of file diff --git a/sha3/keccakf.go b/sha3/keccakf.go new file mode 100644 index 0000000..1a6ec83 --- /dev/null +++ b/sha3/keccakf.go @@ -0,0 +1,434 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha3 + +import ( + "encoding/binary" + + "github.com/xtls/reality/byteorder" + //"crypto/internal/fips140deps/cpu" + "math/bits" + "unsafe" +) + +// rc stores the round constants for use in the ι step. +var rc = [24]uint64{ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, +} + +// keccakF1600Generic applies the Keccak permutation. +func keccakF1600Generic(da *[200]byte) { + var a *[25]uint64 + //if cpu.BigEndian { + if binary.NativeEndian.Uint16([]byte{0x12, 0x34}) != uint16(0x3412) { + a = new([25]uint64) + for i := range a { + a[i] = byteorder.LEUint64(da[i*8:]) + } + defer func() { + for i := range a { + byteorder.LEPutUint64(da[i*8:], a[i]) + } + }() + } else { + a = (*[25]uint64)(unsafe.Pointer(da)) + } + + // Implementation translated from Keccak-inplace.c + // in the keccak reference code. + var t, bc0, bc1, bc2, bc3, bc4, d0, d1, d2, d3, d4 uint64 + + for i := 0; i < 24; i += 4 { + // Combines the 5 steps in each round into 2 steps. + // Unrolls 4 rounds per loop and spreads some steps across rounds. + + // Round 1 + bc0 = a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20] + bc1 = a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21] + bc2 = a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22] + bc3 = a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23] + bc4 = a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24] + d0 = bc4 ^ (bc1<<1 | bc1>>63) + d1 = bc0 ^ (bc2<<1 | bc2>>63) + d2 = bc1 ^ (bc3<<1 | bc3>>63) + d3 = bc2 ^ (bc4<<1 | bc4>>63) + d4 = bc3 ^ (bc0<<1 | bc0>>63) + + bc0 = a[0] ^ d0 + t = a[6] ^ d1 + bc1 = bits.RotateLeft64(t, 44) + t = a[12] ^ d2 + bc2 = bits.RotateLeft64(t, 43) + t = a[18] ^ d3 + bc3 = bits.RotateLeft64(t, 21) + t = a[24] ^ d4 + bc4 = bits.RotateLeft64(t, 14) + a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i] + a[6] = bc1 ^ (bc3 &^ bc2) + a[12] = bc2 ^ (bc4 &^ bc3) + a[18] = bc3 ^ (bc0 &^ bc4) + a[24] = bc4 ^ (bc1 &^ bc0) + + t = a[10] ^ d0 + bc2 = bits.RotateLeft64(t, 3) + t = a[16] ^ d1 + bc3 = bits.RotateLeft64(t, 45) + t = a[22] ^ d2 + bc4 = bits.RotateLeft64(t, 61) + t = a[3] ^ d3 + bc0 = bits.RotateLeft64(t, 28) + t = a[9] ^ d4 + bc1 = bits.RotateLeft64(t, 20) + a[10] = bc0 ^ (bc2 &^ bc1) + a[16] = bc1 ^ (bc3 &^ bc2) + a[22] = bc2 ^ (bc4 &^ bc3) + a[3] = bc3 ^ (bc0 &^ bc4) + a[9] = bc4 ^ (bc1 &^ bc0) + + t = a[20] ^ d0 + bc4 = bits.RotateLeft64(t, 18) + t = a[1] ^ d1 + bc0 = bits.RotateLeft64(t, 1) + t = a[7] ^ d2 + bc1 = bits.RotateLeft64(t, 6) + t = a[13] ^ d3 + bc2 = bits.RotateLeft64(t, 25) + t = a[19] ^ d4 + bc3 = bits.RotateLeft64(t, 8) + a[20] = bc0 ^ (bc2 &^ bc1) + a[1] = bc1 ^ (bc3 &^ bc2) + a[7] = bc2 ^ (bc4 &^ bc3) + a[13] = bc3 ^ (bc0 &^ bc4) + a[19] = bc4 ^ (bc1 &^ bc0) + + t = a[5] ^ d0 + bc1 = bits.RotateLeft64(t, 36) + t = a[11] ^ d1 + bc2 = bits.RotateLeft64(t, 10) + t = a[17] ^ d2 + bc3 = bits.RotateLeft64(t, 15) + t = a[23] ^ d3 + bc4 = bits.RotateLeft64(t, 56) + t = a[4] ^ d4 + bc0 = bits.RotateLeft64(t, 27) + a[5] = bc0 ^ (bc2 &^ bc1) + a[11] = bc1 ^ (bc3 &^ bc2) + a[17] = bc2 ^ (bc4 &^ bc3) + a[23] = bc3 ^ (bc0 &^ bc4) + a[4] = bc4 ^ (bc1 &^ bc0) + + t = a[15] ^ d0 + bc3 = bits.RotateLeft64(t, 41) + t = a[21] ^ d1 + bc4 = bits.RotateLeft64(t, 2) + t = a[2] ^ d2 + bc0 = bits.RotateLeft64(t, 62) + t = a[8] ^ d3 + bc1 = bits.RotateLeft64(t, 55) + t = a[14] ^ d4 + bc2 = bits.RotateLeft64(t, 39) + a[15] = bc0 ^ (bc2 &^ bc1) + a[21] = bc1 ^ (bc3 &^ bc2) + a[2] = bc2 ^ (bc4 &^ bc3) + a[8] = bc3 ^ (bc0 &^ bc4) + a[14] = bc4 ^ (bc1 &^ bc0) + + // Round 2 + bc0 = a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20] + bc1 = a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21] + bc2 = a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22] + bc3 = a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23] + bc4 = a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24] + d0 = bc4 ^ (bc1<<1 | bc1>>63) + d1 = bc0 ^ (bc2<<1 | bc2>>63) + d2 = bc1 ^ (bc3<<1 | bc3>>63) + d3 = bc2 ^ (bc4<<1 | bc4>>63) + d4 = bc3 ^ (bc0<<1 | bc0>>63) + + bc0 = a[0] ^ d0 + t = a[16] ^ d1 + bc1 = bits.RotateLeft64(t, 44) + t = a[7] ^ d2 + bc2 = bits.RotateLeft64(t, 43) + t = a[23] ^ d3 + bc3 = bits.RotateLeft64(t, 21) + t = a[14] ^ d4 + bc4 = bits.RotateLeft64(t, 14) + a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+1] + a[16] = bc1 ^ (bc3 &^ bc2) + a[7] = bc2 ^ (bc4 &^ bc3) + a[23] = bc3 ^ (bc0 &^ bc4) + a[14] = bc4 ^ (bc1 &^ bc0) + + t = a[20] ^ d0 + bc2 = bits.RotateLeft64(t, 3) + t = a[11] ^ d1 + bc3 = bits.RotateLeft64(t, 45) + t = a[2] ^ d2 + bc4 = bits.RotateLeft64(t, 61) + t = a[18] ^ d3 + bc0 = bits.RotateLeft64(t, 28) + t = a[9] ^ d4 + bc1 = bits.RotateLeft64(t, 20) + a[20] = bc0 ^ (bc2 &^ bc1) + a[11] = bc1 ^ (bc3 &^ bc2) + a[2] = bc2 ^ (bc4 &^ bc3) + a[18] = bc3 ^ (bc0 &^ bc4) + a[9] = bc4 ^ (bc1 &^ bc0) + + t = a[15] ^ d0 + bc4 = bits.RotateLeft64(t, 18) + t = a[6] ^ d1 + bc0 = bits.RotateLeft64(t, 1) + t = a[22] ^ d2 + bc1 = bits.RotateLeft64(t, 6) + t = a[13] ^ d3 + bc2 = bits.RotateLeft64(t, 25) + t = a[4] ^ d4 + bc3 = bits.RotateLeft64(t, 8) + a[15] = bc0 ^ (bc2 &^ bc1) + a[6] = bc1 ^ (bc3 &^ bc2) + a[22] = bc2 ^ (bc4 &^ bc3) + a[13] = bc3 ^ (bc0 &^ bc4) + a[4] = bc4 ^ (bc1 &^ bc0) + + t = a[10] ^ d0 + bc1 = bits.RotateLeft64(t, 36) + t = a[1] ^ d1 + bc2 = bits.RotateLeft64(t, 10) + t = a[17] ^ d2 + bc3 = bits.RotateLeft64(t, 15) + t = a[8] ^ d3 + bc4 = bits.RotateLeft64(t, 56) + t = a[24] ^ d4 + bc0 = bits.RotateLeft64(t, 27) + a[10] = bc0 ^ (bc2 &^ bc1) + a[1] = bc1 ^ (bc3 &^ bc2) + a[17] = bc2 ^ (bc4 &^ bc3) + a[8] = bc3 ^ (bc0 &^ bc4) + a[24] = bc4 ^ (bc1 &^ bc0) + + t = a[5] ^ d0 + bc3 = bits.RotateLeft64(t, 41) + t = a[21] ^ d1 + bc4 = bits.RotateLeft64(t, 2) + t = a[12] ^ d2 + bc0 = bits.RotateLeft64(t, 62) + t = a[3] ^ d3 + bc1 = bits.RotateLeft64(t, 55) + t = a[19] ^ d4 + bc2 = bits.RotateLeft64(t, 39) + a[5] = bc0 ^ (bc2 &^ bc1) + a[21] = bc1 ^ (bc3 &^ bc2) + a[12] = bc2 ^ (bc4 &^ bc3) + a[3] = bc3 ^ (bc0 &^ bc4) + a[19] = bc4 ^ (bc1 &^ bc0) + + // Round 3 + bc0 = a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20] + bc1 = a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21] + bc2 = a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22] + bc3 = a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23] + bc4 = a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24] + d0 = bc4 ^ (bc1<<1 | bc1>>63) + d1 = bc0 ^ (bc2<<1 | bc2>>63) + d2 = bc1 ^ (bc3<<1 | bc3>>63) + d3 = bc2 ^ (bc4<<1 | bc4>>63) + d4 = bc3 ^ (bc0<<1 | bc0>>63) + + bc0 = a[0] ^ d0 + t = a[11] ^ d1 + bc1 = bits.RotateLeft64(t, 44) + t = a[22] ^ d2 + bc2 = bits.RotateLeft64(t, 43) + t = a[8] ^ d3 + bc3 = bits.RotateLeft64(t, 21) + t = a[19] ^ d4 + bc4 = bits.RotateLeft64(t, 14) + a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+2] + a[11] = bc1 ^ (bc3 &^ bc2) + a[22] = bc2 ^ (bc4 &^ bc3) + a[8] = bc3 ^ (bc0 &^ bc4) + a[19] = bc4 ^ (bc1 &^ bc0) + + t = a[15] ^ d0 + bc2 = bits.RotateLeft64(t, 3) + t = a[1] ^ d1 + bc3 = bits.RotateLeft64(t, 45) + t = a[12] ^ d2 + bc4 = bits.RotateLeft64(t, 61) + t = a[23] ^ d3 + bc0 = bits.RotateLeft64(t, 28) + t = a[9] ^ d4 + bc1 = bits.RotateLeft64(t, 20) + a[15] = bc0 ^ (bc2 &^ bc1) + a[1] = bc1 ^ (bc3 &^ bc2) + a[12] = bc2 ^ (bc4 &^ bc3) + a[23] = bc3 ^ (bc0 &^ bc4) + a[9] = bc4 ^ (bc1 &^ bc0) + + t = a[5] ^ d0 + bc4 = bits.RotateLeft64(t, 18) + t = a[16] ^ d1 + bc0 = bits.RotateLeft64(t, 1) + t = a[2] ^ d2 + bc1 = bits.RotateLeft64(t, 6) + t = a[13] ^ d3 + bc2 = bits.RotateLeft64(t, 25) + t = a[24] ^ d4 + bc3 = bits.RotateLeft64(t, 8) + a[5] = bc0 ^ (bc2 &^ bc1) + a[16] = bc1 ^ (bc3 &^ bc2) + a[2] = bc2 ^ (bc4 &^ bc3) + a[13] = bc3 ^ (bc0 &^ bc4) + a[24] = bc4 ^ (bc1 &^ bc0) + + t = a[20] ^ d0 + bc1 = bits.RotateLeft64(t, 36) + t = a[6] ^ d1 + bc2 = bits.RotateLeft64(t, 10) + t = a[17] ^ d2 + bc3 = bits.RotateLeft64(t, 15) + t = a[3] ^ d3 + bc4 = bits.RotateLeft64(t, 56) + t = a[14] ^ d4 + bc0 = bits.RotateLeft64(t, 27) + a[20] = bc0 ^ (bc2 &^ bc1) + a[6] = bc1 ^ (bc3 &^ bc2) + a[17] = bc2 ^ (bc4 &^ bc3) + a[3] = bc3 ^ (bc0 &^ bc4) + a[14] = bc4 ^ (bc1 &^ bc0) + + t = a[10] ^ d0 + bc3 = bits.RotateLeft64(t, 41) + t = a[21] ^ d1 + bc4 = bits.RotateLeft64(t, 2) + t = a[7] ^ d2 + bc0 = bits.RotateLeft64(t, 62) + t = a[18] ^ d3 + bc1 = bits.RotateLeft64(t, 55) + t = a[4] ^ d4 + bc2 = bits.RotateLeft64(t, 39) + a[10] = bc0 ^ (bc2 &^ bc1) + a[21] = bc1 ^ (bc3 &^ bc2) + a[7] = bc2 ^ (bc4 &^ bc3) + a[18] = bc3 ^ (bc0 &^ bc4) + a[4] = bc4 ^ (bc1 &^ bc0) + + // Round 4 + bc0 = a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20] + bc1 = a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21] + bc2 = a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22] + bc3 = a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23] + bc4 = a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24] + d0 = bc4 ^ (bc1<<1 | bc1>>63) + d1 = bc0 ^ (bc2<<1 | bc2>>63) + d2 = bc1 ^ (bc3<<1 | bc3>>63) + d3 = bc2 ^ (bc4<<1 | bc4>>63) + d4 = bc3 ^ (bc0<<1 | bc0>>63) + + bc0 = a[0] ^ d0 + t = a[1] ^ d1 + bc1 = bits.RotateLeft64(t, 44) + t = a[2] ^ d2 + bc2 = bits.RotateLeft64(t, 43) + t = a[3] ^ d3 + bc3 = bits.RotateLeft64(t, 21) + t = a[4] ^ d4 + bc4 = bits.RotateLeft64(t, 14) + a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+3] + a[1] = bc1 ^ (bc3 &^ bc2) + a[2] = bc2 ^ (bc4 &^ bc3) + a[3] = bc3 ^ (bc0 &^ bc4) + a[4] = bc4 ^ (bc1 &^ bc0) + + t = a[5] ^ d0 + bc2 = bits.RotateLeft64(t, 3) + t = a[6] ^ d1 + bc3 = bits.RotateLeft64(t, 45) + t = a[7] ^ d2 + bc4 = bits.RotateLeft64(t, 61) + t = a[8] ^ d3 + bc0 = bits.RotateLeft64(t, 28) + t = a[9] ^ d4 + bc1 = bits.RotateLeft64(t, 20) + a[5] = bc0 ^ (bc2 &^ bc1) + a[6] = bc1 ^ (bc3 &^ bc2) + a[7] = bc2 ^ (bc4 &^ bc3) + a[8] = bc3 ^ (bc0 &^ bc4) + a[9] = bc4 ^ (bc1 &^ bc0) + + t = a[10] ^ d0 + bc4 = bits.RotateLeft64(t, 18) + t = a[11] ^ d1 + bc0 = bits.RotateLeft64(t, 1) + t = a[12] ^ d2 + bc1 = bits.RotateLeft64(t, 6) + t = a[13] ^ d3 + bc2 = bits.RotateLeft64(t, 25) + t = a[14] ^ d4 + bc3 = bits.RotateLeft64(t, 8) + a[10] = bc0 ^ (bc2 &^ bc1) + a[11] = bc1 ^ (bc3 &^ bc2) + a[12] = bc2 ^ (bc4 &^ bc3) + a[13] = bc3 ^ (bc0 &^ bc4) + a[14] = bc4 ^ (bc1 &^ bc0) + + t = a[15] ^ d0 + bc1 = bits.RotateLeft64(t, 36) + t = a[16] ^ d1 + bc2 = bits.RotateLeft64(t, 10) + t = a[17] ^ d2 + bc3 = bits.RotateLeft64(t, 15) + t = a[18] ^ d3 + bc4 = bits.RotateLeft64(t, 56) + t = a[19] ^ d4 + bc0 = bits.RotateLeft64(t, 27) + a[15] = bc0 ^ (bc2 &^ bc1) + a[16] = bc1 ^ (bc3 &^ bc2) + a[17] = bc2 ^ (bc4 &^ bc3) + a[18] = bc3 ^ (bc0 &^ bc4) + a[19] = bc4 ^ (bc1 &^ bc0) + + t = a[20] ^ d0 + bc3 = bits.RotateLeft64(t, 41) + t = a[21] ^ d1 + bc4 = bits.RotateLeft64(t, 2) + t = a[22] ^ d2 + bc0 = bits.RotateLeft64(t, 62) + t = a[23] ^ d3 + bc1 = bits.RotateLeft64(t, 55) + t = a[24] ^ d4 + bc2 = bits.RotateLeft64(t, 39) + a[20] = bc0 ^ (bc2 &^ bc1) + a[21] = bc1 ^ (bc3 &^ bc2) + a[22] = bc2 ^ (bc4 &^ bc3) + a[23] = bc3 ^ (bc0 &^ bc4) + a[24] = bc4 ^ (bc1 &^ bc0) + } +} \ No newline at end of file diff --git a/sha3/sha3.go b/sha3/sha3.go new file mode 100644 index 0000000..7e4345d --- /dev/null +++ b/sha3/sha3.go @@ -0,0 +1,235 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sha3 implements the SHA-3 fixed-output-length hash functions and +// the SHAKE variable-output-length functions defined by [FIPS 202], as well as +// the cSHAKE extendable-output-length functions defined by [SP 800-185]. +// +// [FIPS 202]: https://doi.org/10.6028/NIST.FIPS.202 +// [SP 800-185]: https://doi.org/10.6028/NIST.SP.800-185 +package sha3 + +import ( + // "github.com/xtls/reality/fips140" + "github.com/xtls/reality/subtle" + "errors" +) + +// spongeDirection indicates the direction bytes are flowing through the sponge. +type spongeDirection int + +const ( + // spongeAbsorbing indicates that the sponge is absorbing input. + spongeAbsorbing spongeDirection = iota + // spongeSqueezing indicates that the sponge is being squeezed. + spongeSqueezing +) + +type Digest struct { + a [1600 / 8]byte // main state of the hash + + // a[n:rate] is the buffer. If absorbing, it's the remaining space to XOR + // into before running the permutation. If squeezing, it's the remaining + // output to produce before running the permutation. + n, rate int + + // dsbyte contains the "domain separation" bits and the first bit of + // the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the + // SHA-3 and SHAKE functions by appending bitstrings to the message. + // Using a little-endian bit-ordering convention, these are "01" for SHA-3 + // and "1111" for SHAKE, or 00000010b and 00001111b, respectively. Then the + // padding rule from section 5.1 is applied to pad the message to a multiple + // of the rate, which involves adding a "1" bit, zero or more "0" bits, and + // a final "1" bit. We merge the first "1" bit from the padding into dsbyte, + // giving 00000110b (0x06) and 00011111b (0x1f). + // [1] http://csrc.nist.gov/publications/drafts/fips-202/fips_202_draft.pdf + // "Draft FIPS 202: SHA-3 Standard: Permutation-Based Hash and + // Extendable-Output Functions (May 2014)" + dsbyte byte + + outputLen int // the default output size in bytes + state spongeDirection // whether the sponge is absorbing or squeezing +} + +// BlockSize returns the rate of sponge underlying this hash function. +func (d *Digest) BlockSize() int { return d.rate } + +// Size returns the output size of the hash function in bytes. +func (d *Digest) Size() int { return d.outputLen } + +// Reset resets the Digest to its initial state. +func (d *Digest) Reset() { + // Zero the permutation's state. + for i := range d.a { + d.a[i] = 0 + } + d.state = spongeAbsorbing + d.n = 0 +} + +func (d *Digest) Clone() *Digest { + ret := *d + return &ret +} + +// permute applies the KeccakF-1600 permutation. +func (d *Digest) permute() { + keccakF1600(&d.a) + d.n = 0 +} + +// padAndPermute appends the domain separation bits in dsbyte, applies +// the multi-bitrate 10..1 padding rule, and permutes the state. +func (d *Digest) padAndPermute() { + // Pad with this instance's domain-separator bits. We know that there's + // at least one byte of space in the sponge because, if it were full, + // permute would have been called to empty it. dsbyte also contains the + // first one bit for the padding. See the comment in the state struct. + d.a[d.n] ^= d.dsbyte + // This adds the final one bit for the padding. Because of the way that + // bits are numbered from the LSB upwards, the final bit is the MSB of + // the last byte. + d.a[d.rate-1] ^= 0x80 + // Apply the permutation + d.permute() + d.state = spongeSqueezing +} + +// Write absorbs more data into the hash's state. +func (d *Digest) Write(p []byte) (n int, err error) { return d.write(p) } +func (d *Digest) writeGeneric(p []byte) (n int, err error) { + if d.state != spongeAbsorbing { + panic("sha3: Write after Read") + } + + n = len(p) + + for len(p) > 0 { + x := subtle.XORBytes(d.a[d.n:d.rate], d.a[d.n:d.rate], p) + d.n += x + p = p[x:] + + // If the sponge is full, apply the permutation. + if d.n == d.rate { + d.permute() + } + } + + return +} + +// read squeezes an arbitrary number of bytes from the sponge. +func (d *Digest) readGeneric(out []byte) (n int, err error) { + // If we're still absorbing, pad and apply the permutation. + if d.state == spongeAbsorbing { + d.padAndPermute() + } + + n = len(out) + + // Now, do the squeezing. + for len(out) > 0 { + // Apply the permutation if we've squeezed the sponge dry. + if d.n == d.rate { + d.permute() + } + + x := copy(out, d.a[d.n:d.rate]) + d.n += x + out = out[x:] + } + + return +} + +// Sum appends the current hash to b and returns the resulting slice. +// It does not change the underlying hash state. +func (d *Digest) Sum(b []byte) []byte { + // fips140.RecordApproved() + return d.sum(b) +} + +func (d *Digest) sumGeneric(b []byte) []byte { + if d.state != spongeAbsorbing { + panic("sha3: Sum after Read") + } + + // Make a copy of the original hash so that caller can keep writing + // and summing. + dup := d.Clone() + hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation + dup.read(hash) + return append(b, hash...) +} + +const ( + magicSHA3 = "sha\x08" + magicShake = "sha\x09" + magicCShake = "sha\x0a" + magicKeccak = "sha\x0b" + // magic || rate || main state || n || sponge direction + marshaledSize = len(magicSHA3) + 1 + 200 + 1 + 1 +) + +func (d *Digest) MarshalBinary() ([]byte, error) { + return d.AppendBinary(make([]byte, 0, marshaledSize)) +} + +func (d *Digest) AppendBinary(b []byte) ([]byte, error) { + switch d.dsbyte { + case dsbyteSHA3: + b = append(b, magicSHA3...) + case dsbyteShake: + b = append(b, magicShake...) + case dsbyteCShake: + b = append(b, magicCShake...) + case dsbyteKeccak: + b = append(b, magicKeccak...) + default: + panic("unknown dsbyte") + } + // rate is at most 168, and n is at most rate. + b = append(b, byte(d.rate)) + b = append(b, d.a[:]...) + b = append(b, byte(d.n), byte(d.state)) + return b, nil +} + +func (d *Digest) UnmarshalBinary(b []byte) error { + if len(b) != marshaledSize { + return errors.New("sha3: invalid hash state") + } + + magic := string(b[:len(magicSHA3)]) + b = b[len(magicSHA3):] + switch { + case magic == magicSHA3 && d.dsbyte == dsbyteSHA3: + case magic == magicShake && d.dsbyte == dsbyteShake: + case magic == magicCShake && d.dsbyte == dsbyteCShake: + case magic == magicKeccak && d.dsbyte == dsbyteKeccak: + default: + return errors.New("sha3: invalid hash state identifier") + } + + rate := int(b[0]) + b = b[1:] + if rate != d.rate { + return errors.New("sha3: invalid hash state function") + } + + copy(d.a[:], b) + b = b[len(d.a):] + + n, state := int(b[0]), spongeDirection(b[1]) + if n > d.rate { + return errors.New("sha3: invalid hash state") + } + d.n = n + if state != spongeAbsorbing && state != spongeSqueezing { + return errors.New("sha3: invalid hash state") + } + d.state = state + + return nil +} \ No newline at end of file diff --git a/sha3/sha3_noasm.go b/sha3/sha3_noasm.go new file mode 100644 index 0000000..eba7eca --- /dev/null +++ b/sha3/sha3_noasm.go @@ -0,0 +1,19 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha3 + +func keccakF1600(a *[200]byte) { + keccakF1600Generic(a) +} + +func (d *Digest) write(p []byte) (n int, err error) { + return d.writeGeneric(p) +} +func (d *Digest) read(out []byte) (n int, err error) { + return d.readGeneric(out) +} +func (d *Digest) sum(b []byte) []byte { + return d.sumGeneric(b) +} \ No newline at end of file diff --git a/sha512/sha512.go b/sha512/sha512.go new file mode 100644 index 0000000..9fbdb56 --- /dev/null +++ b/sha512/sha512.go @@ -0,0 +1,301 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sha512 implements the SHA-384, SHA-512, SHA-512/224, and SHA-512/256 +// hash algorithms as defined in FIPS 180-4. +package sha512 + +import ( + // "github.com/xtls/reality/fips140" + "github.com/xtls/reality/byteorder" + "errors" +) + +const ( + // size512 is the size, in bytes, of a SHA-512 checksum. + size512 = 64 + + // size224 is the size, in bytes, of a SHA-512/224 checksum. + size224 = 28 + + // size256 is the size, in bytes, of a SHA-512/256 checksum. + size256 = 32 + + // size384 is the size, in bytes, of a SHA-384 checksum. + size384 = 48 + + // blockSize is the block size, in bytes, of the SHA-512/224, + // SHA-512/256, SHA-384 and SHA-512 hash functions. + blockSize = 128 +) + +const ( + chunk = 128 + init0 = 0x6a09e667f3bcc908 + init1 = 0xbb67ae8584caa73b + init2 = 0x3c6ef372fe94f82b + init3 = 0xa54ff53a5f1d36f1 + init4 = 0x510e527fade682d1 + init5 = 0x9b05688c2b3e6c1f + init6 = 0x1f83d9abfb41bd6b + init7 = 0x5be0cd19137e2179 + init0_224 = 0x8c3d37c819544da2 + init1_224 = 0x73e1996689dcd4d6 + init2_224 = 0x1dfab7ae32ff9c82 + init3_224 = 0x679dd514582f9fcf + init4_224 = 0x0f6d2b697bd44da8 + init5_224 = 0x77e36f7304c48942 + init6_224 = 0x3f9d85a86a1d36c8 + init7_224 = 0x1112e6ad91d692a1 + init0_256 = 0x22312194fc2bf72c + init1_256 = 0x9f555fa3c84c64c2 + init2_256 = 0x2393b86b6f53b151 + init3_256 = 0x963877195940eabd + init4_256 = 0x96283ee2a88effe3 + init5_256 = 0xbe5e1e2553863992 + init6_256 = 0x2b0199fc2c85b8aa + init7_256 = 0x0eb72ddc81c52ca2 + init0_384 = 0xcbbb9d5dc1059ed8 + init1_384 = 0x629a292a367cd507 + init2_384 = 0x9159015a3070dd17 + init3_384 = 0x152fecd8f70e5939 + init4_384 = 0x67332667ffc00b31 + init5_384 = 0x8eb44a8768581511 + init6_384 = 0xdb0c2e0d64f98fa7 + init7_384 = 0x47b5481dbefa4fa4 +) + +// Digest is a SHA-384, SHA-512, SHA-512/224, or SHA-512/256 [hash.Hash] +// implementation. +type Digest struct { + h [8]uint64 + x [chunk]byte + nx int + len uint64 + size int // size224, size256, size384, or size512 +} + +func (d *Digest) Reset() { + switch d.size { + case size384: + d.h[0] = init0_384 + d.h[1] = init1_384 + d.h[2] = init2_384 + d.h[3] = init3_384 + d.h[4] = init4_384 + d.h[5] = init5_384 + d.h[6] = init6_384 + d.h[7] = init7_384 + case size224: + d.h[0] = init0_224 + d.h[1] = init1_224 + d.h[2] = init2_224 + d.h[3] = init3_224 + d.h[4] = init4_224 + d.h[5] = init5_224 + d.h[6] = init6_224 + d.h[7] = init7_224 + case size256: + d.h[0] = init0_256 + d.h[1] = init1_256 + d.h[2] = init2_256 + d.h[3] = init3_256 + d.h[4] = init4_256 + d.h[5] = init5_256 + d.h[6] = init6_256 + d.h[7] = init7_256 + case size512: + d.h[0] = init0 + d.h[1] = init1 + d.h[2] = init2 + d.h[3] = init3 + d.h[4] = init4 + d.h[5] = init5 + d.h[6] = init6 + d.h[7] = init7 + default: + panic("unknown size") + } + d.nx = 0 + d.len = 0 +} + +const ( + magic384 = "sha\x04" + magic512_224 = "sha\x05" + magic512_256 = "sha\x06" + magic512 = "sha\x07" + marshaledSize = len(magic512) + 8*8 + chunk + 8 +) + +func (d *Digest) MarshalBinary() ([]byte, error) { + return d.AppendBinary(make([]byte, 0, marshaledSize)) +} + +func (d *Digest) AppendBinary(b []byte) ([]byte, error) { + switch d.size { + case size384: + b = append(b, magic384...) + case size224: + b = append(b, magic512_224...) + case size256: + b = append(b, magic512_256...) + case size512: + b = append(b, magic512...) + default: + panic("unknown size") + } + b = byteorder.BEAppendUint64(b, d.h[0]) + b = byteorder.BEAppendUint64(b, d.h[1]) + b = byteorder.BEAppendUint64(b, d.h[2]) + b = byteorder.BEAppendUint64(b, d.h[3]) + b = byteorder.BEAppendUint64(b, d.h[4]) + b = byteorder.BEAppendUint64(b, d.h[5]) + b = byteorder.BEAppendUint64(b, d.h[6]) + b = byteorder.BEAppendUint64(b, d.h[7]) + b = append(b, d.x[:d.nx]...) + b = append(b, make([]byte, len(d.x)-d.nx)...) + b = byteorder.BEAppendUint64(b, d.len) + return b, nil +} + +func (d *Digest) UnmarshalBinary(b []byte) error { + if len(b) < len(magic512) { + return errors.New("crypto/sha512: invalid hash state identifier") + } + switch { + case d.size == size384 && string(b[:len(magic384)]) == magic384: + case d.size == size224 && string(b[:len(magic512_224)]) == magic512_224: + case d.size == size256 && string(b[:len(magic512_256)]) == magic512_256: + case d.size == size512 && string(b[:len(magic512)]) == magic512: + default: + return errors.New("crypto/sha512: invalid hash state identifier") + } + if len(b) != marshaledSize { + return errors.New("crypto/sha512: invalid hash state size") + } + b = b[len(magic512):] + b, d.h[0] = consumeUint64(b) + b, d.h[1] = consumeUint64(b) + b, d.h[2] = consumeUint64(b) + b, d.h[3] = consumeUint64(b) + b, d.h[4] = consumeUint64(b) + b, d.h[5] = consumeUint64(b) + b, d.h[6] = consumeUint64(b) + b, d.h[7] = consumeUint64(b) + b = b[copy(d.x[:], b):] + b, d.len = consumeUint64(b) + d.nx = int(d.len % chunk) + return nil +} + +func consumeUint64(b []byte) ([]byte, uint64) { + return b[8:], byteorder.BEUint64(b) +} + +// New returns a new Digest computing the SHA-512 hash. +func New() *Digest { + d := &Digest{size: size512} + d.Reset() + return d +} + +// New512_224 returns a new Digest computing the SHA-512/224 hash. +func New512_224() *Digest { + d := &Digest{size: size224} + d.Reset() + return d +} + +// New512_256 returns a new Digest computing the SHA-512/256 hash. +func New512_256() *Digest { + d := &Digest{size: size256} + d.Reset() + return d +} + +// New384 returns a new Digest computing the SHA-384 hash. +func New384() *Digest { + d := &Digest{size: size384} + d.Reset() + return d +} + +func (d *Digest) Size() int { + return d.size +} + +func (d *Digest) BlockSize() int { return blockSize } + +func (d *Digest) Write(p []byte) (nn int, err error) { + nn = len(p) + d.len += uint64(nn) + if d.nx > 0 { + n := copy(d.x[d.nx:], p) + d.nx += n + if d.nx == chunk { + block(d, d.x[:]) + d.nx = 0 + } + p = p[n:] + } + if len(p) >= chunk { + n := len(p) &^ (chunk - 1) + block(d, p[:n]) + p = p[n:] + } + if len(p) > 0 { + d.nx = copy(d.x[:], p) + } + return +} + +func (d *Digest) Sum(in []byte) []byte { + // fips140.RecordApproved() + // Make a copy of d so that caller can keep writing and summing. + d0 := new(Digest) + *d0 = *d + hash := d0.checkSum() + return append(in, hash[:d.size]...) +} + +func (d *Digest) checkSum() [size512]byte { + // Padding. Add a 1 bit and 0 bits until 112 bytes mod 128. + len := d.len + var tmp [128 + 16]byte // padding + length buffer + tmp[0] = 0x80 + var t uint64 + if len%128 < 112 { + t = 112 - len%128 + } else { + t = 128 + 112 - len%128 + } + + // Length in bits. + len <<= 3 + padlen := tmp[:t+16] + // Upper 64 bits are always zero, because len variable has type uint64, + // and tmp is already zeroed at that index, so we can skip updating it. + // byteorder.BEPutUint64(padlen[t+0:], 0) + byteorder.BEPutUint64(padlen[t+8:], len) + d.Write(padlen) + + if d.nx != 0 { + panic("d.nx != 0") + } + + var digest [size512]byte + byteorder.BEPutUint64(digest[0:], d.h[0]) + byteorder.BEPutUint64(digest[8:], d.h[1]) + byteorder.BEPutUint64(digest[16:], d.h[2]) + byteorder.BEPutUint64(digest[24:], d.h[3]) + byteorder.BEPutUint64(digest[32:], d.h[4]) + byteorder.BEPutUint64(digest[40:], d.h[5]) + if d.size != size384 { + byteorder.BEPutUint64(digest[48:], d.h[6]) + byteorder.BEPutUint64(digest[56:], d.h[7]) + } + + return digest +} \ No newline at end of file diff --git a/sha512/sha512block.go b/sha512/sha512block.go new file mode 100644 index 0000000..feb6827 --- /dev/null +++ b/sha512/sha512block.go @@ -0,0 +1,144 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// SHA512 block step. +// In its own file so that a faster assembly or C version +// can be substituted easily. + +package sha512 + +import "math/bits" + +var _K = [...]uint64{ + 0x428a2f98d728ae22, + 0x7137449123ef65cd, + 0xb5c0fbcfec4d3b2f, + 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, + 0x59f111f1b605d019, + 0x923f82a4af194f9b, + 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, + 0x12835b0145706fbe, + 0x243185be4ee4b28c, + 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, + 0x80deb1fe3b1696b1, + 0x9bdc06a725c71235, + 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, + 0xefbe4786384f25e3, + 0x0fc19dc68b8cd5b5, + 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, + 0x4a7484aa6ea6e483, + 0x5cb0a9dcbd41fbd4, + 0x76f988da831153b5, + 0x983e5152ee66dfab, + 0xa831c66d2db43210, + 0xb00327c898fb213f, + 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, + 0xd5a79147930aa725, + 0x06ca6351e003826f, + 0x142929670a0e6e70, + 0x27b70a8546d22ffc, + 0x2e1b21385c26c926, + 0x4d2c6dfc5ac42aed, + 0x53380d139d95b3df, + 0x650a73548baf63de, + 0x766a0abb3c77b2a8, + 0x81c2c92e47edaee6, + 0x92722c851482353b, + 0xa2bfe8a14cf10364, + 0xa81a664bbc423001, + 0xc24b8b70d0f89791, + 0xc76c51a30654be30, + 0xd192e819d6ef5218, + 0xd69906245565a910, + 0xf40e35855771202a, + 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, + 0x1e376c085141ab53, + 0x2748774cdf8eeb99, + 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, + 0x4ed8aa4ae3418acb, + 0x5b9cca4f7763e373, + 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, + 0x78a5636f43172f60, + 0x84c87814a1f0ab72, + 0x8cc702081a6439ec, + 0x90befffa23631e28, + 0xa4506cebde82bde9, + 0xbef9a3f7b2c67915, + 0xc67178f2e372532b, + 0xca273eceea26619c, + 0xd186b8c721c0c207, + 0xeada7dd6cde0eb1e, + 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, + 0x0a637dc5a2c898a6, + 0x113f9804bef90dae, + 0x1b710b35131c471b, + 0x28db77f523047d84, + 0x32caab7b40c72493, + 0x3c9ebe0a15c9bebc, + 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, + 0x597f299cfc657e2a, + 0x5fcb6fab3ad6faec, + 0x6c44198c4a475817, +} + +func blockGeneric(dig *Digest, p []byte) { + var w [80]uint64 + h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] + for len(p) >= chunk { + for i := 0; i < 16; i++ { + j := i * 8 + w[i] = uint64(p[j])<<56 | uint64(p[j+1])<<48 | uint64(p[j+2])<<40 | uint64(p[j+3])<<32 | + uint64(p[j+4])<<24 | uint64(p[j+5])<<16 | uint64(p[j+6])<<8 | uint64(p[j+7]) + } + for i := 16; i < 80; i++ { + v1 := w[i-2] + t1 := bits.RotateLeft64(v1, -19) ^ bits.RotateLeft64(v1, -61) ^ (v1 >> 6) + v2 := w[i-15] + t2 := bits.RotateLeft64(v2, -1) ^ bits.RotateLeft64(v2, -8) ^ (v2 >> 7) + + w[i] = t1 + w[i-7] + t2 + w[i-16] + } + + a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7 + + for i := 0; i < 80; i++ { + t1 := h + (bits.RotateLeft64(e, -14) ^ bits.RotateLeft64(e, -18) ^ bits.RotateLeft64(e, -41)) + ((e & f) ^ (^e & g)) + _K[i] + w[i] + + t2 := (bits.RotateLeft64(a, -28) ^ bits.RotateLeft64(a, -34) ^ bits.RotateLeft64(a, -39)) + ((a & b) ^ (a & c) ^ (b & c)) + + h = g + g = f + f = e + e = d + t1 + d = c + c = b + b = a + a = t1 + t2 + } + + h0 += a + h1 += b + h2 += c + h3 += d + h4 += e + h5 += f + h6 += g + h7 += h + + p = p[chunk:] + } + + dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7 +} \ No newline at end of file diff --git a/sha512/sha512block_noasm.go b/sha512/sha512block_noasm.go new file mode 100644 index 0000000..1a3ada7 --- /dev/null +++ b/sha512/sha512block_noasm.go @@ -0,0 +1,9 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha512 + +func block(dig *Digest, p []byte) { + blockGeneric(dig, p) +} \ No newline at end of file diff --git a/tls13/tls13.go b/tls13/tls13.go new file mode 100644 index 0000000..90e8ad2 --- /dev/null +++ b/tls13/tls13.go @@ -0,0 +1,178 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls13 implements the TLS 1.3 Key Schedule as specified in RFC 8446, +// Section 7.1 and allowed by FIPS 140-3 IG 2.4.B Resolution 7. +package tls13 + +import ( + "github.com/xtls/reality/byteorder" + "github.com/xtls/reality/fips140" + "github.com/xtls/reality/hkdf" +) + +// We don't set the service indicator in this package but we delegate that to +// the underlying functions because the TLS 1.3 KDF does not have a standard of +// its own. + +// ExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func ExpandLabel[H fips140.Hash](hash func() H, secret []byte, label string, context []byte, length int) []byte { + if len("tls13 ")+len(label) > 255 || len(context) > 255 { + // It should be impossible for this to panic: labels are fixed strings, + // and context is either a fixed-length computed hash, or parsed from a + // field which has the same length limitation. + // + // Another reasonable approach might be to return a randomized slice if + // we encounter an error, which would break the connection, but avoid + // panicking. This would perhaps be safer but significantly more + // confusing to users. + panic("tls13: label or context too long") + } + hkdfLabel := make([]byte, 0, 2+1+len("tls13 ")+len(label)+1+len(context)) + hkdfLabel = byteorder.BEAppendUint16(hkdfLabel, uint16(length)) + hkdfLabel = append(hkdfLabel, byte(len("tls13 ")+len(label))) + hkdfLabel = append(hkdfLabel, "tls13 "...) + hkdfLabel = append(hkdfLabel, label...) + hkdfLabel = append(hkdfLabel, byte(len(context))) + hkdfLabel = append(hkdfLabel, context...) + return hkdf.Expand(hash, secret, string(hkdfLabel), length) +} + +func extract[H fips140.Hash](hash func() H, newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, hash().Size()) + } + return hkdf.Extract(hash, newSecret, currentSecret) +} + +func deriveSecret[H fips140.Hash](hash func() H, secret []byte, label string, transcript fips140.Hash) []byte { + if transcript == nil { + transcript = hash() + } + return ExpandLabel(hash, secret, label, transcript.Sum(nil), transcript.Size()) +} + +const ( + resumptionBinderLabel = "res binder" + clientEarlyTrafficLabel = "c e traffic" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + earlyExporterLabel = "e exp master" + exporterLabel = "exp master" + resumptionLabel = "res master" +) + +type EarlySecret struct { + secret []byte + hash func() fips140.Hash +} + +func NewEarlySecret[H fips140.Hash](hash func() H, psk []byte) *EarlySecret { + return &EarlySecret{ + secret: extract(hash, psk, nil), + hash: func() fips140.Hash { return hash() }, + } +} + +func (s *EarlySecret) ResumptionBinderKey() []byte { + return deriveSecret(s.hash, s.secret, resumptionBinderLabel, nil) +} + +// ClientEarlyTrafficSecret derives the client_early_traffic_secret from the +// early secret and the transcript up to the ClientHello. +func (s *EarlySecret) ClientEarlyTrafficSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, clientEarlyTrafficLabel, transcript) +} + +type HandshakeSecret struct { + secret []byte + hash func() fips140.Hash +} + +func (s *EarlySecret) HandshakeSecret(sharedSecret []byte) *HandshakeSecret { + derived := deriveSecret(s.hash, s.secret, "derived", nil) + return &HandshakeSecret{ + secret: extract(s.hash, sharedSecret, derived), + hash: s.hash, + } +} + +// ClientHandshakeTrafficSecret derives the client_handshake_traffic_secret from +// the handshake secret and the transcript up to the ServerHello. +func (s *HandshakeSecret) ClientHandshakeTrafficSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, clientHandshakeTrafficLabel, transcript) +} + +// ServerHandshakeTrafficSecret derives the server_handshake_traffic_secret from +// the handshake secret and the transcript up to the ServerHello. +func (s *HandshakeSecret) ServerHandshakeTrafficSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, serverHandshakeTrafficLabel, transcript) +} + +type MasterSecret struct { + secret []byte + hash func() fips140.Hash +} + +func (s *HandshakeSecret) MasterSecret() *MasterSecret { + derived := deriveSecret(s.hash, s.secret, "derived", nil) + return &MasterSecret{ + secret: extract(s.hash, nil, derived), + hash: s.hash, + } +} + +// ClientApplicationTrafficSecret derives the client_application_traffic_secret_0 +// from the master secret and the transcript up to the server Finished. +func (s *MasterSecret) ClientApplicationTrafficSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, clientApplicationTrafficLabel, transcript) +} + +// ServerApplicationTrafficSecret derives the server_application_traffic_secret_0 +// from the master secret and the transcript up to the server Finished. +func (s *MasterSecret) ServerApplicationTrafficSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, serverApplicationTrafficLabel, transcript) +} + +// ResumptionMasterSecret derives the resumption_master_secret from the master secret +// and the transcript up to the client Finished. +func (s *MasterSecret) ResumptionMasterSecret(transcript fips140.Hash) []byte { + return deriveSecret(s.hash, s.secret, resumptionLabel, transcript) +} + +type ExporterMasterSecret struct { + secret []byte + hash func() fips140.Hash +} + +// ExporterMasterSecret derives the exporter_master_secret from the master secret +// and the transcript up to the server Finished. +func (s *MasterSecret) ExporterMasterSecret(transcript fips140.Hash) *ExporterMasterSecret { + return &ExporterMasterSecret{ + secret: deriveSecret(s.hash, s.secret, exporterLabel, transcript), + hash: s.hash, + } +} + +// EarlyExporterMasterSecret derives the exporter_master_secret from the early secret +// and the transcript up to the ClientHello. +func (s *EarlySecret) EarlyExporterMasterSecret(transcript fips140.Hash) *ExporterMasterSecret { + return &ExporterMasterSecret{ + secret: deriveSecret(s.hash, s.secret, earlyExporterLabel, transcript), + hash: s.hash, + } +} + +func (s *ExporterMasterSecret) Exporter(label string, context []byte, length int) []byte { + secret := deriveSecret(s.hash, s.secret, label, nil) + h := s.hash() + h.Write(context) + return ExpandLabel(s.hash, secret, "exporter", h.Sum(nil), length) +} + +func TestingOnlyExporterSecret(s *ExporterMasterSecret) []byte { + return s.secret +} \ No newline at end of file