package blockmode_test import ( "bytes" "crypto/cipher" "encoding/hex" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "xdx.jelly/xgcl/grand" "xdx.jelly/xgcl/sm/sm4" "xdx.jelly/xgcl/utils/blockmode" ) type noopBlock int func (b noopBlock) BlockSize() int { return int(b) } func (noopBlock) Encrypt(dst, src []byte) { copy(dst, src) } func (noopBlock) Decrypt(dst, src []byte) { copy(dst, src) } func inc(b []byte) { for i := len(b) - 1; i >= 0; i++ { b[i]++ if b[i] != 0 { break } } } func xor(a, b []byte) { for i := range a { a[i] ^= b[i] } } func TestCTR(t *testing.T) { for size := 64; size <= 1024; size *= 2 { iv := make([]byte, size) ctr := cipher.NewCTR(noopBlock(size), iv) src := make([]byte, 1024) for i := range src { src[i] = 0xff } want := make([]byte, 1024) copy(want, src) counter := make([]byte, size) for i := 1; i < len(want)/size; i++ { inc(counter) xor(want[i*size:(i+1)*size], counter) } dst := make([]byte, 1024) ctr.XORKeyStream(dst, src) if !bytes.Equal(dst, want) { t.Errorf("for size %d\nhave %x\nwant %x", size, dst, want) } } } func FuzzSm4Ctr(f *testing.F) { iv := grand.GetRandom(16) key := grand.GetRandom(16) key, _ = hex.DecodeString("70fe9d4cdf29d1db1549a44d70bf28fb") iv, _ = hex.DecodeString("edf1631376519ddc9654cec2900060de") block, _ := sm4.NewCipher(key) mode := blockmode.Wrap(block) f.Add([]byte{}, []byte{}) f.Fuzz(func(t *testing.T, plaintext, ad []byte) { stdctr := cipher.NewCTR(block, iv) ctr, err := blockmode.NewCTR(mode, iv) if err != nil { t.Fatal(err) } stdct := make([]byte, len(plaintext)) stdctr.XORKeyStream(stdct, plaintext) ct1 := make([]byte, len(plaintext)) ctr.XORKeyStream(ct1, plaintext) if bytes.Compare(ct1, stdct) != 0 { t.Errorf("XORKeyStream failed") } if err := ctr.EncryptInit(iv); err != nil { t.Fatal(err) } ct2, err := ctr.EncryptUpdate(nil, plaintext[:len(plaintext)/2]) if err != nil { t.Fatal(err) } if ct2, err = ctr.EncryptUpdate(ct2, plaintext[len(plaintext)/2:]); err != nil { t.Fatal(err) } ct2, err = ctr.EncryptFinal(ct2) if err != nil { t.Fatal(err) } if bytes.Compare(ct2, stdct) != 0 { t.Errorf("Encrypt failed") } if err := ctr.DecryptInit(iv); err != nil { t.Fatal(err) } pt2, err := ctr.DecryptUpdate(nil, ct2[:len(ct2)/2]) if err != nil { t.Fatal(err) } if pt2, err = ctr.DecryptUpdate(pt2, ct2[len(ct2)/2:]); err != nil { t.Fatal(err) } pt2, err = ctr.DecryptFinal(pt2) if err != nil { t.Fatal(err) } if bytes.Compare(pt2, plaintext) != 0 { t.Errorf("Decrypt failed") } }) } func TestCTRData(t *testing.T) { size := 1024 key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} iv := []byte{0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF} src := make([]byte, size) b, err := sm4.NewCipher(key) assert.Nil(t, err) ctr, err := blockmode.NewCTR(blockmode.Wrap(b), iv) for i := range src { src[i] = byte(i) } ctr.EncryptInit(iv) dst, err := ctr.EncryptUpdate(nil, src) assert.Nil(t, err) dst, err = ctr.EncryptFinal(dst) assert.Nil(t, err) fmt.Println("src:") for i := range src { fmt.Printf("0x%02x, ", src[i]) if (i+1)%32 == 0 { fmt.Println("") } } fmt.Println("dst:") for i := range dst { fmt.Printf("0x%02x, ", dst[i]) if (i+1)%32 == 0 { fmt.Println("") } } } func TestCTRSpeed(t *testing.T) { size := 1024 key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} iv := []byte{0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF} src := make([]byte, size) b, err := sm4.NewCipher(key) assert.Nil(t, err) ctr, err := blockmode.NewCTR(blockmode.Wrap(b), iv) for i := range src { src[i] = byte(i) } assert.Nil(t, err) start := time.Now() times := 10000 for i := 0; i < times; i++ { ctr.EncryptInit(iv) dst, _ := ctr.EncryptUpdate(nil, src) dst, _ = ctr.EncryptFinal(dst) } end := time.Now() elapsed := end.Sub(start) t.Log("SM4 Encrypt: ", int(float64(times*len(src))/float64(1024*1024)/elapsed.Seconds()), "MBps") }