Skip to content

Commit

Permalink
优化 SM2
Browse files Browse the repository at this point in the history
  • Loading branch information
deatil committed Jul 21, 2024
1 parent d761882 commit 7bbb4c3
Show file tree
Hide file tree
Showing 3 changed files with 404 additions and 123 deletions.
159 changes: 75 additions & 84 deletions gm/sm2/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,87 +53,86 @@ func UnmarshalSignatureASN1(sign []byte) (r, s *big.Int, err error) {
}

// 拼接编码
func marshalCipherBytes(curve elliptic.Curve, c []byte, mode Mode, h hashFunc) []byte {
byteLen := (curve.Params().BitSize + 7) / 8
hashSize := h().Size()

func marshalCipherBytes(c encryptedData, mode Mode) []byte {
// C1C3C2 密文结构: x + y + hash + CipherText
// C1C2C3 密文结构: x + y + CipherText + hash
switch mode {
case C1C2C3:
c1 := make([]byte, 2*byteLen)
c2 := make([]byte, len(c) - 2*byteLen - hashSize)
c3 := make([]byte, hashSize)
ct := []byte{0x04}
ct = append(ct, c.XCoordinate...)
ct = append(ct, c.YCoordinate...)
ct = append(ct, c.CipherText...)
ct = append(ct, c.Hash...)

copy(c1, c[0:]) // x1, y1
copy(c3, c[2*byteLen:]) // hash
copy(c2, c[2*byteLen+hashSize:]) // 密文

ct := make([]byte, 0)
ct = append(ct, c1...)
ct = append(ct, c2...)
ct = append(ct, c3...)

return append([]byte{0x04}, ct...)
return ct
case C1C3C2:
fallthrough
default:
return append([]byte{0x04}, c...)
ct := []byte{0x04}
ct = append(ct, c.XCoordinate...)
ct = append(ct, c.YCoordinate...)
ct = append(ct, c.Hash...)
ct = append(ct, c.CipherText...)

return ct
}
}

func unmarshalCipherBytes(curve elliptic.Curve, data []byte, mode Mode, h hashFunc) ([]byte, error) {
func unmarshalCipherBytes(curve elliptic.Curve, data []byte, mode Mode, h hashFunc) (encryptedData, error) {
typ := data[0]
if typ != byte(0x04) {
return nil, errors.New("cryptobin/sm2: encrypted data is error and miss prefix '4'.")
return encryptedData{}, errors.New("cryptobin/sm2: encrypted data is error and miss prefix '4'.")
}

hashSize := h().Size()

byteLen := (curve.Params().BitSize + 7) / 8
if len(data) < 2*byteLen + hashSize {
return nil, errors.New("cryptobin/sm2: encrypt data is too short.")
return encryptedData{}, errors.New("cryptobin/sm2: encrypt data is too short.")
}

switch mode {
case C1C2C3:
data = data[1:]

c1 := make([]byte, 2*byteLen)
c2 := make([]byte, len(data) - 2*byteLen - hashSize)
c3 := make([]byte, hashSize)
c1 := data[:2*byteLen]
c2 := data[2*byteLen:len(data) - hashSize]
c3 := data[len(data) - hashSize:]

copy(c1, data[0:]) // x1, y1
copy(c2, data[2*byteLen:]) // 密文
copy(c3, data[len(data) - hashSize:]) // hash

c := make([]byte, 0)
c = append(c, c1...)
c = append(c, c3...)
c = append(c, c2...)

data = c
return encryptedData{
XCoordinate: c1[:byteLen], // x分量
YCoordinate: c1[byteLen:], // y分量
Hash: c3, // hash
CipherText: c2, // cipherText
}, nil
case C1C3C2:
fallthrough
default:
data = data[1:]
}

return data, nil
c1 := data[0:2*byteLen]
c3 := data[2*byteLen:2*byteLen+hashSize]
c2 := data[2*byteLen+hashSize:]

return encryptedData{
XCoordinate: c1[:byteLen], // x分量
YCoordinate: c1[byteLen:], // y分量
Hash: c3, // hash
CipherText: c2, // cipherText
}, nil
}
}

// asn.1 编码
func marshalCipherASN1(curve elliptic.Curve, data []byte, mode Mode, h hashFunc) ([]byte, error) {
hashSize := h().Size()

func marshalCipherASN1(data encryptedData, mode Mode) ([]byte, error) {
if mode == C1C2C3 {
return marshalCipherASN1Old(curve, data, hashSize)
return marshalCipherASN1Old(data)
}

return marshalCipherASN1New(curve, data, hashSize)
return marshalCipherASN1New(data)
}

func unmarshalCipherASN1(curve elliptic.Curve, data []byte, mode Mode) ([]byte, error) {
func unmarshalCipherASN1(curve elliptic.Curve, data []byte, mode Mode) (encryptedData, error) {
if mode == C1C2C3 {
return unmarshalCipherASN1Old(curve, data)
}
Expand All @@ -151,36 +150,32 @@ type cipherASN1New struct {

// sm2 密文转 asn.1 编码格式
// sm2 密文结构: x + y + hash + CipherText
func marshalCipherASN1New(curve elliptic.Curve, data []byte, hashSize int) ([]byte, error) {
byteLen := (curve.Params().BitSize + 7) / 8

x := new(big.Int).SetBytes(data[:byteLen])
y := new(big.Int).SetBytes(data[byteLen:2*byteLen])

hash := data[2*byteLen:2*byteLen+hashSize]
cipherText := data[2*byteLen+hashSize:]

return asn1.Marshal(cipherASN1New{x, y, hash, cipherText})
func marshalCipherASN1New(data encryptedData) ([]byte, error) {
return asn1.Marshal(cipherASN1New{
XCoordinate: bytesToBigInt(data.XCoordinate),
YCoordinate: bytesToBigInt(data.YCoordinate),
HASH: data.Hash,
CipherText: data.CipherText,
})
}

// sm2 密文 asn.1 编码格式转 C1|C3|C2 拼接格式
func unmarshalCipherASN1New(curve elliptic.Curve, b []byte) ([]byte, error) {
func unmarshalCipherASN1New(curve elliptic.Curve, b []byte) (encryptedData, error) {
var data cipherASN1New
_, err := asn1.Unmarshal(b, &data)
if err != nil {
return nil, err
return encryptedData{}, err
}

xBuf := bigIntToBytes(curve, data.XCoordinate)
yBuf := bigIntToBytes(curve, data.YCoordinate)

c := []byte{}
c = append(c, xBuf...) // x分量
c = append(c, yBuf...) // y分量
c = append(c, data.HASH...) // hash
c = append(c, data.CipherText...) // cipherText
x := bigIntToBytes(curve, data.XCoordinate)
y := bigIntToBytes(curve, data.YCoordinate)

return c, nil
return encryptedData{
XCoordinate: x, // x分量
YCoordinate: y, // y分量
Hash: data.HASH, // hash
CipherText: data.CipherText, // cipherText
}, nil
}

// c1c2c3 格式
Expand All @@ -193,35 +188,31 @@ type cipherASN1Old struct {

// sm2 密文转 asn.1 编码格式
// sm2 密文结构: x + y + CipherText + hash
func marshalCipherASN1Old(curve elliptic.Curve, data []byte, hashSize int) ([]byte, error) {
byteLen := (curve.Params().BitSize + 7) / 8

x := new(big.Int).SetBytes(data[:byteLen])
y := new(big.Int).SetBytes(data[byteLen:2*byteLen])

hash := data[2*byteLen:2*byteLen+hashSize]
cipherText := data[2*byteLen+hashSize:]

return asn1.Marshal(cipherASN1Old{x, y, cipherText, hash})
func marshalCipherASN1Old(data encryptedData) ([]byte, error) {
return asn1.Marshal(cipherASN1Old{
XCoordinate: bytesToBigInt(data.XCoordinate),
YCoordinate: bytesToBigInt(data.YCoordinate),
CipherText: data.CipherText,
HASH: data.Hash,
})
}

// sm2 密文 asn.1 编码格式转 C1|C3|C2 拼接格式
func unmarshalCipherASN1Old(curve elliptic.Curve, b []byte) ([]byte, error) {
func unmarshalCipherASN1Old(curve elliptic.Curve, b []byte) (encryptedData, error) {
var data cipherASN1Old
_, err := asn1.Unmarshal(b, &data)
if err != nil {
return nil, err
return encryptedData{}, err
}

xBuf := bigIntToBytes(curve, data.XCoordinate)
yBuf := bigIntToBytes(curve, data.YCoordinate)

c := []byte{}
c = append(c, xBuf...) // x分量
c = append(c, yBuf...) // y分量
c = append(c, data.HASH...) // hash
c = append(c, data.CipherText...) // cipherText
x := bigIntToBytes(curve, data.XCoordinate)
y := bigIntToBytes(curve, data.YCoordinate)

return c, nil
return encryptedData{
XCoordinate: x, // x分量
YCoordinate: y, // y分量
CipherText: data.CipherText, // cipherText
Hash: data.HASH, // hash
}, nil
}

Loading

0 comments on commit 7bbb4c3

Please sign in to comment.