381 lines
9.3 KiB
Go
381 lines
9.3 KiB
Go
///
|
|
/// Copyright (c) 2018 xdx. All rights reserved.
|
|
///
|
|
/// \file: encryption.go
|
|
///
|
|
/// \brief: SM2加解密
|
|
///
|
|
/// \author: xdx
|
|
///
|
|
|
|
package sm2
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/asn1"
|
|
"encoding/hex"
|
|
"math/big"
|
|
"strings"
|
|
|
|
"xdx.jelly/xgcl/gerrors"
|
|
"xdx.jelly/xgcl/utils/objectpool"
|
|
|
|
"xdx.jelly/xgcl/gmath"
|
|
"xdx.jelly/xgcl/grand"
|
|
"xdx.jelly/xgcl/sm/sm2/ec256"
|
|
"xdx.jelly/xgcl/sm/sm3"
|
|
)
|
|
|
|
// DefaultCipherLength is the initial capacity, in bytes, that NewCipher
// reserves for the C2 payload of a freshly created Cipher.
const DefaultCipherLength = 1 << 7 // 128 bytes
|
|
|
|
// Cipher is an SM2 ciphertext split into its three components:
//
//   - X, Y: the coordinates of the ephemeral point C1.
//   - Hash: C3, the 32-byte digest that authenticates the plaintext.
//   - C:    C2, the masked plaintext.
type Cipher struct {
	// X is the x-coordinate of C1.
	X *big.Int
	// Y is the y-coordinate of C1.
	Y *big.Int
	// Hash holds C3.
	Hash [32]byte
	// C holds C2.
	C []byte
}
|
|
|
|
// NewCipher return a new instance of Cipher
|
|
func NewCipher() *Cipher {
|
|
return &Cipher{
|
|
X: new(big.Int),
|
|
Y: new(big.Int),
|
|
C: make([]byte, 0, DefaultCipherLength),
|
|
}
|
|
}
|
|
|
|
// Normalize reduce C1's x,y
|
|
func (c *Cipher) Normalize() *Cipher {
|
|
if c.X != nil {
|
|
c.X.Mod(c.X, orderN)
|
|
}
|
|
if c.Y != nil {
|
|
c.Y.Mod(c.Y, orderN)
|
|
}
|
|
return c
|
|
}
|
|
|
|
// Set c to src
|
|
func (c *Cipher) Set(src *Cipher) *Cipher {
|
|
c.X.Set(src.X)
|
|
c.Y.Set(src.Y)
|
|
c.Hash = src.Hash
|
|
c.C = append(c.C[:0], src.C...)
|
|
return c
|
|
}
|
|
|
|
// cipherASN1 mirrors Cipher in a shape that encoding/asn1 can marshal and
// unmarshal directly:
//
//	SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, c OCTET STRING }
//
// The field order defines the wire order and must not change.
type cipherASN1 struct {
	X    *big.Int
	Y    *big.Int
	Hash []byte
	C    []byte
}
|
|
|
|
// MarshalASN1 marshal c to DER encoded and append to data.
|
|
func (c *Cipher) MarshalASN1() ([]byte, error) {
|
|
return asn1.Marshal(cipherASN1{
|
|
X: c.X,
|
|
Y: c.Y,
|
|
Hash: c.Hash[:],
|
|
C: c.C,
|
|
})
|
|
}
|
|
|
|
// UnmarshalASN1 unmarshal c from DER encoded data.
|
|
func (c *Cipher) UnmarshalASN1(b []byte) (rest []byte, err error) {
|
|
var ca cipherASN1
|
|
|
|
rest, err = asn1.Unmarshal(b, &ca)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
c.X = ca.X
|
|
c.Y = ca.Y
|
|
copy(c.Hash[:], ca.Hash)
|
|
c.C = append(c.C, ca.C...)
|
|
return rest, nil
|
|
}
|
|
|
|
// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface
|
|
func (c *Cipher) MarshalUtil(data []byte) ([]byte, error) {
|
|
if data == nil {
|
|
data = make([]byte, 0, 2*ECCRefMaxLen+32+4+DefaultCipherLength)
|
|
}
|
|
c.X.Mod(c.X, orderN)
|
|
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
|
|
data = append(data, xBytes...)
|
|
|
|
c.Y.Mod(c.Y, orderN)
|
|
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
|
|
data = append(data, yBytes...)
|
|
data = append(data, c.Hash[:]...)
|
|
buf := []byte{0, 0, 0, 0}
|
|
Endian.PutUint32(buf, uint32(len(c.C)))
|
|
data = append(data, buf...)
|
|
data = append(data, c.C...)
|
|
return data, nil
|
|
}
|
|
|
|
// UnmarshalUtil marshal the Cipher append to data
|
|
func (c *Cipher) UnmarshalUtil(data []byte) (uint64, error) {
|
|
n := uint64(0) // consumed bytes.
|
|
if len(data) < 2*ECCRefMaxLen+32+4 {
|
|
return 0, gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
|
|
}
|
|
|
|
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
|
data = data[ECCRefMaxLen:]
|
|
n += ECCRefMaxLen
|
|
|
|
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
|
data = data[ECCRefMaxLen:]
|
|
n += ECCRefMaxLen
|
|
|
|
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
|
|
return 0, gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
|
|
}
|
|
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
|
|
return 0, gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
|
|
}
|
|
if !sm2Curve.IsOnCurve(x, y) {
|
|
return 0, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
|
|
}
|
|
|
|
c.X.Set(x)
|
|
c.Y.Set(y)
|
|
|
|
copy(c.Hash[:], data[:32])
|
|
data = data[32:]
|
|
n += 32
|
|
|
|
clen := Endian.Uint32(data)
|
|
data = data[4:]
|
|
n += 4
|
|
|
|
if len(data) < int(clen) {
|
|
return 0, gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
|
|
}
|
|
|
|
c.C = append(c.C[:0], data[:clen]...)
|
|
n += uint64(clen)
|
|
// return the rest data
|
|
return n, nil
|
|
}
|
|
|
|
// MarshalBinary implements the encoding.BinaryMarshaler interface
|
|
// 返回字节符合GMT 0018的定义。
|
|
// x||y||m||L||c
|
|
func (c *Cipher) MarshalBinary() ([]byte, error) {
|
|
//data := make([]byte, 2*ECCRefMaxLen+32+4+len(c.C))
|
|
data := objectpool.GetBytes()
|
|
c.X.Mod(c.X, orderN)
|
|
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
|
|
data = append(data, xBytes...)
|
|
|
|
c.Y.Mod(c.Y, orderN)
|
|
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
|
|
data = append(data, yBytes...)
|
|
data = append(data, c.Hash[:]...)
|
|
buf := []byte{0, 0, 0, 0}
|
|
Endian.PutUint32(buf, uint32(len(c.C)))
|
|
data = append(data, buf...)
|
|
data = append(data, c.C...)
|
|
return data, nil
|
|
}
|
|
|
|
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface
|
|
// 输入字节应符合GMT 0018的定义。
|
|
func (c *Cipher) UnmarshalBinary(data []byte) error {
|
|
if len(data) < 2*ECCRefMaxLen+32+4 {
|
|
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
|
|
}
|
|
|
|
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
|
data = data[ECCRefMaxLen:]
|
|
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
|
data = data[ECCRefMaxLen:]
|
|
|
|
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
|
|
return gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
|
|
}
|
|
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
|
|
return gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
|
|
}
|
|
if !sm2Curve.IsOnCurve(x, y) {
|
|
return gerrors.WithAnnotating(ErrInvalidInput, "(x,y) is not on the curve")
|
|
}
|
|
|
|
c.X.Set(x)
|
|
c.Y.Set(y)
|
|
|
|
copy(c.Hash[:], data[:32])
|
|
data = data[32:]
|
|
|
|
clen := Endian.Uint32(data)
|
|
data = data[4:]
|
|
|
|
if len(data) < int(clen) {
|
|
return gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
|
|
}
|
|
|
|
c.C = append(c.C[:0], data[:clen]...)
|
|
return nil
|
|
}
|
|
|
|
// Bytes 转换密文结构为byte切片
|
|
func (c *Cipher) Bytes() []byte {
|
|
var buf bytes.Buffer
|
|
buf.Write(gmath.BigIntToNByte(c.X, byteSize))
|
|
buf.Write(gmath.BigIntToNByte(c.Y, byteSize))
|
|
buf.Write(c.Hash[:])
|
|
buf.Write(c.C)
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// SetBytes 转换[]byte为Cipher
|
|
func (c *Cipher) SetBytes(data []byte) error {
|
|
if len(data) < 2*byteSize+32 {
|
|
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*byteSize+32)
|
|
}
|
|
if c.X == nil {
|
|
c.X = new(big.Int)
|
|
}
|
|
if c.Y == nil {
|
|
c.Y = new(big.Int)
|
|
}
|
|
c.X.SetBytes(data[:byteSize])
|
|
c.Y.SetBytes(data[byteSize : 2*byteSize])
|
|
copy(c.Hash[:], data[2*byteSize:2*byteSize+32])
|
|
c.C = append(c.C[:0], data[2*byteSize+32:]...)
|
|
return nil
|
|
}
|
|
|
|
// String return a readable string
|
|
func (c *Cipher) String() string {
|
|
var buf strings.Builder
|
|
buf.WriteString("x: ")
|
|
buf.WriteString(hex.EncodeToString(c.X.Bytes()))
|
|
buf.WriteString("\ny: ")
|
|
buf.WriteString(hex.EncodeToString(c.Y.Bytes()))
|
|
buf.WriteString("\nc: ")
|
|
buf.WriteString(hex.EncodeToString(c.C))
|
|
buf.WriteString("\nhash: ")
|
|
buf.WriteString(hex.EncodeToString(c.Hash[:]))
|
|
|
|
return buf.String()
|
|
// return hex.EncodeToString(c.Bytes())
|
|
}
|
|
|
|
// Encrypt 加密
|
|
func Encrypt(pk *PublicKey, data, rnd []byte) (*Cipher, error) {
|
|
// k := new(big.Int).SetBytes(rnd)
|
|
if rnd == nil || len(rnd) < byteSize {
|
|
rnd = grand.GetRandom(byteSize)
|
|
}
|
|
// if !pk.IsValid() {
|
|
// return nil, gerrors.ERR_SM2_INVALID_PUBKEY
|
|
// }
|
|
var x, y *big.Int
|
|
cipher := new(Cipher)
|
|
xBytes := make([]byte, byteSize)
|
|
yBytes := make([]byte, byteSize)
|
|
var err error
|
|
outer:
|
|
for {
|
|
cipher.X, cipher.Y = sm2Curve.ScalarBaseMult(rnd[:byteSize])
|
|
x, y = sm2Curve.ScalarMult(pk.X, pk.Y, rnd[:byteSize])
|
|
if err = gmath.FillBytes(x, xBytes); err != nil {
|
|
return nil, gerrors.WithAnnotating(ErrEncFailed, "x is too big")
|
|
}
|
|
|
|
if err = gmath.FillBytes(y, yBytes); err != nil {
|
|
return nil, gerrors.WithAnnotating(ErrEncFailed, "y is too big")
|
|
}
|
|
|
|
cipher.C = make([]byte, len(data))
|
|
if len(data) == 0 {
|
|
break
|
|
}
|
|
Kdf(cipher.C, xBytes, yBytes)
|
|
for _, k := range cipher.C {
|
|
if k != 0 {
|
|
break outer
|
|
}
|
|
}
|
|
if _, err = grand.GenerateRandom(rnd); err != nil {
|
|
return nil, gerrors.ChainErrors(gerrors.WithAnnotating(ErrEncFailed, "generate random failed"), err)
|
|
}
|
|
}
|
|
|
|
// 两个对C的遍历,上面那个不会执行太多。
|
|
for i := range cipher.C {
|
|
cipher.C[i] ^= data[i]
|
|
}
|
|
|
|
digest := sm3.New()
|
|
// need padding 0 if x < 2*248
|
|
digest.Write(xBytes)
|
|
digest.Write(data)
|
|
digest.Write(yBytes)
|
|
digest.Sum(cipher.Hash[:0])
|
|
return cipher, nil
|
|
}
|
|
|
|
// Decrypt 解密
|
|
// 返回明文
|
|
func Decrypt(sk *PrivateKey, cipher *Cipher) (plainText []byte, err error) {
|
|
|
|
if !ec256.Curve256.IsOnCurve(cipher.X, cipher.Y) {
|
|
return nil, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
|
|
}
|
|
|
|
x, y := ec256.Curve256.ScalarMult(cipher.X, cipher.Y, sk.Bytes())
|
|
|
|
return Decrypt_aux(x, y, cipher)
|
|
}
|
|
|
|
// Decrypt_aux common decrypt function
|
|
func Decrypt_aux(x, y *big.Int, cipher *Cipher) (plainText []byte, err error) {
|
|
plainText = make([]byte, len(cipher.C))
|
|
xBytes := make([]byte, byteSize)
|
|
yBytes := make([]byte, byteSize)
|
|
if err = gmath.FillBytes(x, xBytes); err != nil {
|
|
return nil, gerrors.WithAnnotating(ErrDecFailed, "x is too big")
|
|
}
|
|
|
|
if err = gmath.FillBytes(y, yBytes); err != nil {
|
|
return nil, gerrors.WithAnnotating(ErrDecFailed, "y is too big")
|
|
}
|
|
|
|
if len(plainText) == 0 {
|
|
goto Next
|
|
}
|
|
if err := Kdf(plainText, xBytes, yBytes); err != nil {
|
|
return nil, gerrors.ChainErrors(ErrDecFailed, err)
|
|
}
|
|
for _, k := range plainText {
|
|
if k != 0 {
|
|
goto Next
|
|
}
|
|
}
|
|
return nil, gerrors.WithAnnotating(ErrDecFailed, "t is all zero while decryption")
|
|
Next:
|
|
for i := range plainText {
|
|
plainText[i] ^= cipher.C[i]
|
|
}
|
|
|
|
digest := sm3.New()
|
|
digest.Write(xBytes)
|
|
digest.Write(plainText)
|
|
digest.Write(yBytes)
|
|
d := digest.Sum(nil)
|
|
if !bytes.Equal(d, cipher.Hash[:]) {
|
|
return nil, gerrors.WithAnnotating(ErrDecFailed, "mac check failed (U!=C3) while decryption")
|
|
}
|
|
return plainText, nil
|
|
}
|