diff --git a/handshake_client.go b/handshake_client.go index 9cd987a..e6a9420 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -13,7 +13,6 @@ import ( "crypto/rsa" "crypto/subtle" "crypto/x509" - "encoding/binary" "errors" "fmt" "hash" @@ -22,6 +21,7 @@ import ( "strings" "time" + "github.com/xtls/reality/byteorder" "github.com/xtls/reality/hpke" "github.com/xtls/reality/mlkem" "github.com/xtls/reality/tls13" @@ -707,7 +707,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { - c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + c.curveID = CurveID(byteorder.BEUint16(skx.key[1:])) } msg, err = c.readHandshake(&hs.finishedHash) diff --git a/handshake_server.go b/handshake_server.go index dbc0551..3201520 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -12,12 +12,13 @@ import ( "crypto/rsa" "crypto/subtle" "crypto/x509" - "encoding/binary" "errors" "fmt" "hash" "io" "time" + + "github.com/xtls/reality/byteorder" ) // serverHandshakeState contains details of a server handshake in progress. @@ -579,7 +580,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { } if skx != nil { if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { - c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + c.curveID = CurveID(byteorder.BEUint16(skx.key[1:])) } if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 1dce1c7..3f0f465 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -14,7 +14,6 @@ import ( "crypto/rsa" "crypto/sha512" "crypto/x509" - "encoding/binary" "errors" "hash" "io" @@ -22,6 +21,7 @@ import ( "slices" "time" + "github.com/xtls/reality/byteorder" "github.com/xtls/reality/mlkem" "github.com/xtls/reality/tls13" ) @@ -953,7 +953,7 @@ func (c *Conn) sendSessionTicket(earlyData bool, extra [][]byte) error { if _, err := c.config.rand().Read(ageAdd); err != nil { return err } - m.ageAdd = binary.LittleEndian.Uint32(ageAdd) + m.ageAdd = byteorder.LEUint32(ageAdd) if earlyData { // RFC 9001, Section 4.6.1