Skip to content

Commit

Permalink
优化 SM3 实现
Browse files Browse the repository at this point in the history
  • Loading branch information
deatil committed May 13, 2024
1 parent 4b5139c commit ca05e9a
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 207 deletions.
27 changes: 12 additions & 15 deletions hash/sm3/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (

const (
chunk = 64
magic256 = "sm3\x03"
marshaledSize = len(magic256) + 8*4 + chunk + 8
magic = "sm3\x03"
marshaledSize = len(magic) + 8*4 + chunk + 8
)

func (this *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic256...)
b = append(b, magic...)

b = appendUint32(b, this.s[0])
b = appendUint32(b, this.s[1])
Expand All @@ -24,26 +24,24 @@ func (this *digest) MarshalBinary() ([]byte, error) {
b = appendUint32(b, this.s[6])
b = appendUint32(b, this.s[7])

b = append(b, this.x[:this.len]...)
b = append(b, this.x[:this.nx]...)

length := (this.nx * BlockSize) + uint64(this.len)

b = b[:len(b) + len(this.x) - int(this.len)]
b = appendUint64(b, length)
b = b[:len(b) + len(this.x) - int(this.nx)]
b = appendUint64(b, this.len)

return b, nil
}

func (this *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic256) || (string(b[:len(magic256)]) != magic256) {
return errors.New("sm3: invalid hash state identifier")
if len(b) < len(magic) || (string(b[:len(magic)]) != magic) {
return errors.New("go-hahs/sm3: invalid hash state identifier")
}

if len(b) != marshaledSize {
return errors.New("sm3: invalid hash state size")
return errors.New("go-hahs/sm3: invalid hash state size")
}

b = b[len(magic256):]
b = b[len(magic):]

b, this.s[0] = consumeUint32(b)
b, this.s[1] = consumeUint32(b)
Expand All @@ -57,11 +55,10 @@ func (this *digest) UnmarshalBinary(b []byte) error {
b = b[copy(this.x[:], b):]

var length uint64

b, length = consumeUint64(b)

this.len = int(length % chunk)
this.nx = length / chunk
this.nx = int(length % chunk)
this.len = length

return nil
}
Expand Down
36 changes: 36 additions & 0 deletions hash/sm3/binary_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sm3

import (
"bytes"
"testing"
)

func Test_MarshalBinary(t *testing.T) {
msg := []byte("test-dd1111111dddddddatatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-data")

h := new(digest)
h.Reset()

h.Write(msg)
dst := h.Sum(nil)
if len(dst) == 0 {
t.Error("Hash make error")
}

bs, _ := h.MarshalBinary()

h.Reset()
err := h.UnmarshalBinary(bs)
if err != nil {
t.Fatal(err)
}

newdst := h.Sum(nil)
if len(newdst) == 0 {
t.Error("newHash make error")
}

if !bytes.Equal(newdst, dst) {
t.Errorf("Hash MarshalBinary error, got %x, want %x", newdst, dst)
}
}
91 changes: 39 additions & 52 deletions hash/sm3/digest.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const BlockSize = 64
type digest struct {
s [8]uint32
x [BlockSize]byte
nx uint64
len int
nx int
len uint64
}

func newDigest() *digest {
Expand All @@ -20,64 +20,45 @@ func newDigest() *digest {
return d
}

func (this *digest) Size() int {
return Size
}

func (this *digest) BlockSize() int {
return BlockSize
}

func (this *digest) Reset() {
this.s = iv
this.x = [BlockSize]byte{}

this.nx = 0
this.len = 0
this.nx = 0
}

// Write is the interface for IO Writer
func (this *digest) Write(data []byte) (nn int, err error) {
nn = len(data)

var blocks int

dataLen := len(data)
func (this *digest) Size() int {
return Size
}

this.len &= 0x3f
if this.len > 0 {
var left int = BlockSize - this.len
func (this *digest) BlockSize() int {
return BlockSize
}

if dataLen < left {
copy(this.x[this.len:], data)
this.len += dataLen
// Write is the interface for IO Writer
func (this *digest) Write(p []byte) (nn int, err error) {
nn = len(p)

return
} else {
copy(this.x[this.len:], data[:left])
compressBlocks(this.s[:], this.x[:], 1)
this.len += uint64(nn)

this.nx++
this.nx &= 0x3f

data = data[left:]
dataLen -= left
}
}
plen := len(p)
for this.nx + plen >= BlockSize {
copy(this.x[this.nx:], p)

blocks = dataLen / BlockSize
if blocks > 0 {
compressBlocks(this.s[:], data, blocks)
this.processBlock(this.x[:])

this.nx += uint64(blocks)
xx := BlockSize - this.nx
plen -= xx

data = data[BlockSize * blocks:]
dataLen -= BlockSize * blocks
p = p[xx:]
this.nx = 0
}

this.len = dataLen
if dataLen > 0 {
copy(this.x[:], data)
}
copy(this.x[this.nx:], p)
this.nx += plen

return
}
Expand All @@ -92,23 +73,25 @@ func (this *digest) Sum(in []byte) []byte {
func (this *digest) checkSum() [Size]byte {
var i int32

this.len &= 0x3f
this.x[this.len] = 0x80
this.nx &= 0x3f
this.x[this.nx] = 0x80

zeros := make([]byte, BlockSize)

if this.len <= BlockSize - 9 {
copy(this.x[this.len + 1:BlockSize - 8], zeros)
if this.nx <= BlockSize - 9 {
copy(this.x[this.nx + 1:BlockSize - 8], zeros)
} else {
copy(this.x[this.len + 1:BlockSize], zeros)
compressBlocks(this.s[:], this.x[:], 1)
copy(this.x[this.nx + 1:BlockSize], zeros)
this.processBlock(this.x[:])
copy(this.x[:BlockSize - 8], zeros)
}

PUTU32(this.x[56:], uint32(this.nx >> 23))
PUTU32(this.x[60:], uint32((this.nx << 9) + uint64(this.len << 3)))
bcount := this.len / BlockSize

compressBlocks(this.s[:], this.x[:], 1)
PUTU32(this.x[56:], uint32(bcount >> 23))
PUTU32(this.x[60:], uint32((bcount << 9) + (uint64(this.nx) << 3)))

this.processBlock(this.x[:])

var digest [Size]byte
for i = 0; i < 8; i++ {
Expand All @@ -117,3 +100,7 @@ func (this *digest) checkSum() [Size]byte {

return digest
}

func (this *digest) processBlock(data []byte) {
compressBlocks(this.s[:], data)
}
19 changes: 19 additions & 0 deletions hash/sm3/sbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,22 @@ var iv = [8]uint32{
0x7380166F, 0x4914B2B9, 0x172442D7, 0xDA8A0600,
0xA96F30BC, 0x163138AA, 0xE38DEE4D, 0xB0FB0E4E,
}

var keys = [64]uint32{
0x79cc4519, 0xf3988a32, 0xe7311465, 0xce6228cb,
0x9cc45197, 0x3988a32f, 0x7311465e, 0xe6228cbc,
0xcc451979, 0x988a32f3, 0x311465e7, 0x6228cbce,
0xc451979c, 0x88a32f39, 0x11465e73, 0x228cbce6,
0x9d8a7a87, 0x3b14f50f, 0x7629ea1e, 0xec53d43c,
0xd8a7a879, 0xb14f50f3, 0x629ea1e7, 0xc53d43ce,
0x8a7a879d, 0x14f50f3b, 0x29ea1e76, 0x53d43cec,
0xa7a879d8, 0x4f50f3b1, 0x9ea1e762, 0x3d43cec5,
0x7a879d8a, 0xf50f3b14, 0xea1e7629, 0xd43cec53,
0xa879d8a7, 0x50f3b14f, 0xa1e7629e, 0x43cec53d,
0x879d8a7a, 0x0f3b14f5, 0x1e7629ea, 0x3cec53d4,
0x79d8a7a8, 0xf3b14f50, 0xe7629ea1, 0xcec53d43,
0x9d8a7a87, 0x3b14f50f, 0x7629ea1e, 0xec53d43c,
0xd8a7a879, 0xb14f50f3, 0x629ea1e7, 0xc53d43ce,
0x8a7a879d, 0x14f50f3b, 0x29ea1e76, 0x53d43cec,
0xa7a879d8, 0x4f50f3b1, 0x9ea1e762, 0x3d43cec5,
}
11 changes: 4 additions & 7 deletions hash/sm3/sm3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@ import (
"hash"
)

// New returns a new hash.Hash computing the MD2 checksum.
// New returns a new hash.Hash computing the SM3 checksum.
func New() hash.Hash {
d := new(digest)
d.Reset()
return d
return newDigest()
}

// Sum returns the SM3 checksum of the data.
func Sum(data []byte) (sum [Size]byte) {
var h digest
h.Reset()
h := New()
h.Write(data)

hash := h.Sum(nil)

copy(sum[:], hash)
return
}
57 changes: 7 additions & 50 deletions hash/sm3/sm3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@ import (
"crypto/hmac"
)

func Test_Hash(t *testing.T) {
msg := []byte("test-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-data")

h := New()
h.Write(msg)
dst := h.Sum(nil)

if len(dst) == 0 {
t.Error("Hash make error")
}
}

type sm3Test struct {
out string
in string
Expand All @@ -33,14 +21,14 @@ var golden = []sm3Test{
{"520472cafdaf21d994c5849492ba802459472b5206503389fc81ff73adbec1b4", "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabc"},
}

func TestGolden(t *testing.T) {
func Test_Check(t *testing.T) {
for i := 0; i < len(golden); i++ {
g := golden[i]
h := Sum([]byte(g.in))
s := fmt.Sprintf("%x", h)
sum := fmt.Sprintf("%x", h)

if s != g.out {
t.Fatalf("SM3 function: sm3(%s) = %s want %s", g.in, s, g.out)
if sum != g.out {
t.Fatalf("Sum: got %s, want %s", sum, g.out)
}

c := New()
Expand All @@ -53,47 +41,16 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in[len(g.in)/2:])
}

s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sm3[%d](%s) = %s want %s", j, g.in, s, g.out)
sum := fmt.Sprintf("%x", c.Sum(nil))
if sum != g.out {
t.Fatalf("New: got %s, want %s", sum, g.out)
}

c.Reset()
}
}
}

func Test_MarshalBinary(t *testing.T) {
msg := []byte("test-dd1111111dddddddatatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-datatest-data")

h := new(digest)
h.Reset()

h.Write(msg)
dst := h.Sum(nil)
if len(dst) == 0 {
t.Error("Hash make error")
}

bs, _ := h.MarshalBinary()

h.Reset()

err := h.UnmarshalBinary(bs)
if err != nil {
t.Fatal(err)
}

newdst := h.Sum(nil)
if len(newdst) == 0 {
t.Error("newHash make error")
}

if string(newdst) != string(dst) {
t.Error("Hash MarshalBinary error")
}
}

func Test_HmacSM3(t *testing.T) {
key := []byte("1234567812345678")
msg := []byte("abc")
Expand Down
Loading

0 comments on commit ca05e9a

Please sign in to comment.