188 lines
4.2 KiB
Go
188 lines
4.2 KiB
Go
package blockmode_test
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/cipher"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"xdx.jelly/xgcl/grand"
|
|
"xdx.jelly/xgcl/sm/sm4"
|
|
"xdx.jelly/xgcl/utils/blockmode"
|
|
)
|
|
|
|
// noopBlock is a fake cipher.Block used to make CTR output predictable.
// Its block size is the integer value itself, and both Encrypt and Decrypt
// just copy src into dst. Under CTR mode the keystream therefore equals the
// raw counter blocks.
type noopBlock int

// BlockSize reports the configured block size.
func (n noopBlock) BlockSize() int {
	size := int(n)
	return size
}

// Encrypt copies src into dst unchanged (identity "cipher").
func (n noopBlock) Encrypt(dst, src []byte) {
	copy(dst, src)
}

// Decrypt is identical to Encrypt: the identity transform is its own inverse.
func (n noopBlock) Decrypt(dst, src []byte) {
	n.Encrypt(dst, src)
}
|
|
|
|
// inc increments b in place as a big-endian unsigned integer. It starts at
// the last (least significant) byte and carries toward the first. A slice
// of all 0xff bytes wraps around to all zeros; an empty slice is left as is.
func inc(b []byte) {
	// Walk from the least significant byte downward (i--). Walking upward
	// would index past the end as soon as a carry occurs.
	for i := len(b) - 1; i >= 0; i-- {
		b[i]++
		if b[i] != 0 {
			// No carry out of this byte; higher bytes stay unchanged.
			break
		}
	}
}
|
|
|
|
// xor XORs b into a in place over the first len(a) bytes. b must be at
// least as long as a; a shorter b causes an index-out-of-range panic.
func xor(a, b []byte) {
	for i, v := range a {
		a[i] = v ^ b[i]
	}
}
|
|
|
|
func TestCTR(t *testing.T) {
|
|
for size := 64; size <= 1024; size *= 2 {
|
|
iv := make([]byte, size)
|
|
ctr := cipher.NewCTR(noopBlock(size), iv)
|
|
src := make([]byte, 1024)
|
|
for i := range src {
|
|
src[i] = 0xff
|
|
}
|
|
want := make([]byte, 1024)
|
|
copy(want, src)
|
|
counter := make([]byte, size)
|
|
for i := 1; i < len(want)/size; i++ {
|
|
inc(counter)
|
|
xor(want[i*size:(i+1)*size], counter)
|
|
}
|
|
dst := make([]byte, 1024)
|
|
ctr.XORKeyStream(dst, src)
|
|
if !bytes.Equal(dst, want) {
|
|
t.Errorf("for size %d\nhave %x\nwant %x", size, dst, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func FuzzSm4Ctr(f *testing.F) {
|
|
iv := grand.GetRandom(16)
|
|
key := grand.GetRandom(16)
|
|
|
|
key, _ = hex.DecodeString("70fe9d4cdf29d1db1549a44d70bf28fb")
|
|
iv, _ = hex.DecodeString("edf1631376519ddc9654cec2900060de")
|
|
block, _ := sm4.NewCipher(key)
|
|
mode := blockmode.Wrap(block)
|
|
|
|
f.Add([]byte{}, []byte{})
|
|
f.Fuzz(func(t *testing.T, plaintext, ad []byte) {
|
|
stdctr := cipher.NewCTR(block, iv)
|
|
ctr, err := blockmode.NewCTR(mode, iv)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
stdct := make([]byte, len(plaintext))
|
|
stdctr.XORKeyStream(stdct, plaintext)
|
|
ct1 := make([]byte, len(plaintext))
|
|
ctr.XORKeyStream(ct1, plaintext)
|
|
|
|
if bytes.Compare(ct1, stdct) != 0 {
|
|
t.Errorf("XORKeyStream failed")
|
|
}
|
|
|
|
if err := ctr.EncryptInit(iv); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ct2, err := ctr.EncryptUpdate(nil, plaintext[:len(plaintext)/2])
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ct2, err = ctr.EncryptUpdate(ct2, plaintext[len(plaintext)/2:]); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ct2, err = ctr.EncryptFinal(ct2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if bytes.Compare(ct2, stdct) != 0 {
|
|
t.Errorf("Encrypt failed")
|
|
}
|
|
|
|
if err := ctr.DecryptInit(iv); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pt2, err := ctr.DecryptUpdate(nil, ct2[:len(ct2)/2])
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if pt2, err = ctr.DecryptUpdate(pt2, ct2[len(ct2)/2:]); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pt2, err = ctr.DecryptFinal(pt2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if bytes.Compare(pt2, plaintext) != 0 {
|
|
t.Errorf("Decrypt failed")
|
|
}
|
|
|
|
})
|
|
}
|
|
|
|
func TestCTRData(t *testing.T) {
|
|
size := 1024
|
|
key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
|
|
iv := []byte{0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF}
|
|
src := make([]byte, size)
|
|
b, err := sm4.NewCipher(key)
|
|
assert.Nil(t, err)
|
|
ctr, err := blockmode.NewCTR(blockmode.Wrap(b), iv)
|
|
|
|
for i := range src {
|
|
src[i] = byte(i)
|
|
}
|
|
|
|
ctr.EncryptInit(iv)
|
|
dst, err := ctr.EncryptUpdate(nil, src)
|
|
assert.Nil(t, err)
|
|
|
|
dst, err = ctr.EncryptFinal(dst)
|
|
assert.Nil(t, err)
|
|
fmt.Println("src:")
|
|
for i := range src {
|
|
fmt.Printf("0x%02x, ", src[i])
|
|
if (i+1)%32 == 0 {
|
|
fmt.Println("")
|
|
}
|
|
}
|
|
fmt.Println("dst:")
|
|
for i := range dst {
|
|
fmt.Printf("0x%02x, ", dst[i])
|
|
if (i+1)%32 == 0 {
|
|
fmt.Println("")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCTRSpeed(t *testing.T) {
|
|
size := 1024
|
|
key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
|
|
iv := []byte{0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF}
|
|
src := make([]byte, size)
|
|
b, err := sm4.NewCipher(key)
|
|
assert.Nil(t, err)
|
|
ctr, err := blockmode.NewCTR(blockmode.Wrap(b), iv)
|
|
|
|
for i := range src {
|
|
src[i] = byte(i)
|
|
}
|
|
|
|
assert.Nil(t, err)
|
|
|
|
start := time.Now()
|
|
times := 10000
|
|
for i := 0; i < times; i++ {
|
|
ctr.EncryptInit(iv)
|
|
dst, _ := ctr.EncryptUpdate(nil, src)
|
|
dst, _ = ctr.EncryptFinal(dst)
|
|
}
|
|
end := time.Now()
|
|
elapsed := end.Sub(start)
|
|
t.Log("SM4 Encrypt: ", int(float64(times*len(src))/float64(1024*1024)/elapsed.Seconds()), "MBps")
|
|
|
|
}
|