/// /// Copyright (c) 2018 xdx. All rights reserved. /// /// \file: encryption.go /// /// \brief: SM2加解密 /// /// \author: xdx /// package sm2 import ( "bytes" "encoding/asn1" "encoding/hex" "math/big" "strings" "xdx.jelly/xgcl/gerrors" "xdx.jelly/xgcl/utils/objectpool" "xdx.jelly/xgcl/gmath" "xdx.jelly/xgcl/grand" "xdx.jelly/xgcl/sm/sm2/ec256" "xdx.jelly/xgcl/sm/sm3" ) // DefaultCipherLength is the default cipher length when init const DefaultCipherLength = 128 // Cipher 密文结构 // (x,y): C1 // hash: C3 // c: C2 type Cipher struct { X, Y *big.Int Hash [32]byte C []byte } // NewCipher return a new instance of Cipher func NewCipher() *Cipher { return &Cipher{ X: new(big.Int), Y: new(big.Int), C: make([]byte, 0, DefaultCipherLength), } } // Normalize reduce C1's x,y func (c *Cipher) Normalize() *Cipher { if c.X != nil { c.X.Mod(c.X, orderN) } if c.Y != nil { c.Y.Mod(c.Y, orderN) } return c } // Set c to src func (c *Cipher) Set(src *Cipher) *Cipher { c.X.Set(src.X) c.Y.Set(src.Y) c.Hash = src.Hash c.C = append(c.C[:0], src.C...) return c } // cipherASN1 helper struct for ASN.1 encoding/decoding Cipher using asn1.Marshal. type cipherASN1 struct { X, Y *big.Int Hash []byte C []byte } // MarshalASN1 marshal c to DER encoded and append to data. func (c *Cipher) MarshalASN1() ([]byte, error) { return asn1.Marshal(cipherASN1{ X: c.X, Y: c.Y, Hash: c.Hash[:], C: c.C, }) } // UnmarshalASN1 unmarshal c from DER encoded data. func (c *Cipher) UnmarshalASN1(b []byte) (rest []byte, err error) { var ca cipherASN1 rest, err = asn1.Unmarshal(b, &ca) if err != nil { return b, err } c.X = ca.X c.Y = ca.Y copy(c.Hash[:], ca.Hash) c.C = append(c.C, ca.C...) return rest, nil } // MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface func (c *Cipher) MarshalUtil(data []byte) ([]byte, error) { if data == nil { data = make([]byte, 0, 2*ECCRefMaxLen+32+4+DefaultCipherLength) } c.X.Mod(c.X, orderN) xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen) data = append(data, xBytes...) c.Y.Mod(c.Y, orderN) yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen) data = append(data, yBytes...) data = append(data, c.Hash[:]...) buf := []byte{0, 0, 0, 0} Endian.PutUint32(buf, uint32(len(c.C))) data = append(data, buf...) data = append(data, c.C...) return data, nil } // UnmarshalUtil marshal the Cipher append to data func (c *Cipher) UnmarshalUtil(data []byte) (uint64, error) { n := uint64(0) // consumed bytes. if len(data) < 2*ECCRefMaxLen+32+4 { return 0, gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4) } x := new(big.Int).SetBytes(data[:ECCRefMaxLen]) data = data[ECCRefMaxLen:] n += ECCRefMaxLen y := new(big.Int).SetBytes(data[:ECCRefMaxLen]) data = data[ECCRefMaxLen:] n += ECCRefMaxLen if x.Cmp(orderN) >= 0 || x.Sign() == 0 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N") } if y.Cmp(orderN) >= 0 || y.Sign() == 0 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N") } if !sm2Curve.IsOnCurve(x, y) { return 0, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point") } c.X.Set(x) c.Y.Set(y) copy(c.Hash[:], data[:32]) data = data[32:] n += 32 clen := Endian.Uint32(data) data = data[4:] n += 4 if len(data) < int(clen) { return 0, gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short") } c.C = append(c.C[:0], data[:clen]...) n += uint64(clen) // return the rest data return n, nil } // MarshalBinary implements the encoding.BinaryMarshaler interface // 返回字节符合GMT 0018的定义。 // x||y||m||L||c func (c *Cipher) MarshalBinary() ([]byte, error) { //data := make([]byte, 2*ECCRefMaxLen+32+4+len(c.C)) data := objectpool.GetBytes() c.X.Mod(c.X, orderN) xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen) data = append(data, xBytes...) c.Y.Mod(c.Y, orderN) yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen) data = append(data, yBytes...) data = append(data, c.Hash[:]...) buf := []byte{0, 0, 0, 0} Endian.PutUint32(buf, uint32(len(c.C))) data = append(data, buf...) data = append(data, c.C...) return data, nil } // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface // 输入字节应符合GMT 0018的定义。 func (c *Cipher) UnmarshalBinary(data []byte) error { if len(data) < 2*ECCRefMaxLen+32+4 { return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4) } x := new(big.Int).SetBytes(data[:ECCRefMaxLen]) data = data[ECCRefMaxLen:] y := new(big.Int).SetBytes(data[:ECCRefMaxLen]) data = data[ECCRefMaxLen:] if x.Cmp(orderN) >= 0 || x.Sign() == 0 { return gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N") } if y.Cmp(orderN) >= 0 || y.Sign() == 0 { return gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N") } if !sm2Curve.IsOnCurve(x, y) { return gerrors.WithAnnotating(ErrInvalidInput, "(x,y) is not on the curve") } c.X.Set(x) c.Y.Set(y) copy(c.Hash[:], data[:32]) data = data[32:] clen := Endian.Uint32(data) data = data[4:] if len(data) < int(clen) { return gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short") } c.C = append(c.C[:0], data[:clen]...) return nil } // Bytes 转换密文结构为byte切片 func (c *Cipher) Bytes() []byte { var buf bytes.Buffer buf.Write(gmath.BigIntToNByte(c.X, byteSize)) buf.Write(gmath.BigIntToNByte(c.Y, byteSize)) buf.Write(c.Hash[:]) buf.Write(c.C) return buf.Bytes() } // SetBytes 转换[]byte为Cipher func (c *Cipher) SetBytes(data []byte) error { if len(data) < 2*byteSize+32 { return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*byteSize+32) } if c.X == nil { c.X = new(big.Int) } if c.Y == nil { c.Y = new(big.Int) } c.X.SetBytes(data[:byteSize]) c.Y.SetBytes(data[byteSize : 2*byteSize]) copy(c.Hash[:], data[2*byteSize:2*byteSize+32]) c.C = append(c.C[:0], data[2*byteSize+32:]...) return nil } // String return a readable string func (c *Cipher) String() string { var buf strings.Builder buf.WriteString("x: ") buf.WriteString(hex.EncodeToString(c.X.Bytes())) buf.WriteString("\ny: ") buf.WriteString(hex.EncodeToString(c.Y.Bytes())) buf.WriteString("\nc: ") buf.WriteString(hex.EncodeToString(c.C)) buf.WriteString("\nhash: ") buf.WriteString(hex.EncodeToString(c.Hash[:])) return buf.String() // return hex.EncodeToString(c.Bytes()) } // Encrypt 加密 func Encrypt(pk *PublicKey, data, rnd []byte) (*Cipher, error) { // k := new(big.Int).SetBytes(rnd) if rnd == nil || len(rnd) < byteSize { rnd = grand.GetRandom(byteSize) } // if !pk.IsValid() { // return nil, gerrors.ERR_SM2_INVALID_PUBKEY // } var x, y *big.Int cipher := new(Cipher) xBytes := make([]byte, byteSize) yBytes := make([]byte, byteSize) var err error outer: for { cipher.X, cipher.Y = sm2Curve.ScalarBaseMult(rnd[:byteSize]) x, y = sm2Curve.ScalarMult(pk.X, pk.Y, rnd[:byteSize]) if err = gmath.FillBytes(x, xBytes); err != nil { return nil, gerrors.WithAnnotating(ErrEncFailed, "x is too big") } if err = gmath.FillBytes(y, yBytes); err != nil { return nil, gerrors.WithAnnotating(ErrEncFailed, "y is too big") } cipher.C = make([]byte, len(data)) if len(data) == 0 { break } Kdf(cipher.C, xBytes, yBytes) for _, k := range cipher.C { if k != 0 { break outer } } if _, err = grand.GenerateRandom(rnd); err != nil { return nil, gerrors.ChainErrors(gerrors.WithAnnotating(ErrEncFailed, "generate random failed"), err) } } // 两个对C的遍历,上面那个不会执行太多。 for i := range cipher.C { cipher.C[i] ^= data[i] } digest := sm3.New() // need padding 0 if x < 2*248 digest.Write(xBytes) digest.Write(data) digest.Write(yBytes) digest.Sum(cipher.Hash[:0]) return cipher, nil } // Decrypt 解密 // 返回明文 func Decrypt(sk *PrivateKey, cipher *Cipher) (plainText []byte, err error) { if !ec256.Curve256.IsOnCurve(cipher.X, cipher.Y) { return nil, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point") } x, y := ec256.Curve256.ScalarMult(cipher.X, cipher.Y, sk.Bytes()) return Decrypt_aux(x, y, cipher) } // Decrypt_aux common decrypt function func Decrypt_aux(x, y *big.Int, cipher *Cipher) (plainText []byte, err error) { plainText = make([]byte, len(cipher.C)) xBytes := make([]byte, byteSize) yBytes := make([]byte, byteSize) if err = gmath.FillBytes(x, xBytes); err != nil { return nil, gerrors.WithAnnotating(ErrDecFailed, "x is too big") } if err = gmath.FillBytes(y, yBytes); err != nil { return nil, gerrors.WithAnnotating(ErrDecFailed, "y is too big") } if len(plainText) == 0 { goto Next } if err := Kdf(plainText, xBytes, yBytes); err != nil { return nil, gerrors.ChainErrors(ErrDecFailed, err) } for _, k := range plainText { if k != 0 { goto Next } } return nil, gerrors.WithAnnotating(ErrDecFailed, "t is all zero while decryption") Next: for i := range plainText { plainText[i] ^= cipher.C[i] } digest := sm3.New() digest.Write(xBytes) digest.Write(plainText) digest.Write(yBytes) d := digest.Sum(nil) if !bytes.Equal(d, cipher.Hash[:]) { return nil, gerrors.WithAnnotating(ErrDecFailed, "mac check failed (U!=C3) while decryption") } return plainText, nil }