init: v1.0.0
This commit is contained in:
@@ -0,0 +1,380 @@
|
||||
///
|
||||
/// Copyright (c) 2018 xdx. All rights reserved.
|
||||
///
|
||||
/// \file: encryption.go
|
||||
///
|
||||
/// \brief: SM2加解密
|
||||
///
|
||||
/// \author: xdx
|
||||
///
|
||||
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"xdx.jelly/xgcl/gerrors"
|
||||
"xdx.jelly/xgcl/utils/objectpool"
|
||||
|
||||
"xdx.jelly/xgcl/gmath"
|
||||
"xdx.jelly/xgcl/grand"
|
||||
"xdx.jelly/xgcl/sm/sm2/ec256"
|
||||
"xdx.jelly/xgcl/sm/sm3"
|
||||
)
|
||||
|
||||
// DefaultCipherLength is the default cipher length when init
|
||||
const DefaultCipherLength = 128
|
||||
|
||||
// Cipher 密文结构
|
||||
// (x,y): C1
|
||||
// hash: C3
|
||||
// c: C2
|
||||
type Cipher struct {
|
||||
X, Y *big.Int
|
||||
Hash [32]byte
|
||||
C []byte
|
||||
}
|
||||
|
||||
// NewCipher return a new instance of Cipher
|
||||
func NewCipher() *Cipher {
|
||||
return &Cipher{
|
||||
X: new(big.Int),
|
||||
Y: new(big.Int),
|
||||
C: make([]byte, 0, DefaultCipherLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize reduce C1's x,y
|
||||
func (c *Cipher) Normalize() *Cipher {
|
||||
if c.X != nil {
|
||||
c.X.Mod(c.X, orderN)
|
||||
}
|
||||
if c.Y != nil {
|
||||
c.Y.Mod(c.Y, orderN)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Set c to src
|
||||
func (c *Cipher) Set(src *Cipher) *Cipher {
|
||||
c.X.Set(src.X)
|
||||
c.Y.Set(src.Y)
|
||||
c.Hash = src.Hash
|
||||
c.C = append(c.C[:0], src.C...)
|
||||
return c
|
||||
}
|
||||
|
||||
// cipherASN1 helper struct for ASN.1 encoding/decoding Cipher using asn1.Marshal.
|
||||
type cipherASN1 struct {
|
||||
X, Y *big.Int
|
||||
Hash []byte
|
||||
C []byte
|
||||
}
|
||||
|
||||
// MarshalASN1 marshal c to DER encoded and append to data.
|
||||
func (c *Cipher) MarshalASN1() ([]byte, error) {
|
||||
return asn1.Marshal(cipherASN1{
|
||||
X: c.X,
|
||||
Y: c.Y,
|
||||
Hash: c.Hash[:],
|
||||
C: c.C,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalASN1 unmarshal c from DER encoded data.
|
||||
func (c *Cipher) UnmarshalASN1(b []byte) (rest []byte, err error) {
|
||||
var ca cipherASN1
|
||||
|
||||
rest, err = asn1.Unmarshal(b, &ca)
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
c.X = ca.X
|
||||
c.Y = ca.Y
|
||||
copy(c.Hash[:], ca.Hash)
|
||||
c.C = append(c.C, ca.C...)
|
||||
return rest, nil
|
||||
}
|
||||
|
||||
// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface
|
||||
func (c *Cipher) MarshalUtil(data []byte) ([]byte, error) {
|
||||
if data == nil {
|
||||
data = make([]byte, 0, 2*ECCRefMaxLen+32+4+DefaultCipherLength)
|
||||
}
|
||||
c.X.Mod(c.X, orderN)
|
||||
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
|
||||
data = append(data, xBytes...)
|
||||
|
||||
c.Y.Mod(c.Y, orderN)
|
||||
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
|
||||
data = append(data, yBytes...)
|
||||
data = append(data, c.Hash[:]...)
|
||||
buf := []byte{0, 0, 0, 0}
|
||||
Endian.PutUint32(buf, uint32(len(c.C)))
|
||||
data = append(data, buf...)
|
||||
data = append(data, c.C...)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// UnmarshalUtil marshal the Cipher append to data
|
||||
func (c *Cipher) UnmarshalUtil(data []byte) (uint64, error) {
|
||||
n := uint64(0) // consumed bytes.
|
||||
if len(data) < 2*ECCRefMaxLen+32+4 {
|
||||
return 0, gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
|
||||
}
|
||||
|
||||
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
||||
data = data[ECCRefMaxLen:]
|
||||
n += ECCRefMaxLen
|
||||
|
||||
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
||||
data = data[ECCRefMaxLen:]
|
||||
n += ECCRefMaxLen
|
||||
|
||||
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
|
||||
}
|
||||
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
|
||||
}
|
||||
if !sm2Curve.IsOnCurve(x, y) {
|
||||
return 0, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
|
||||
}
|
||||
|
||||
c.X.Set(x)
|
||||
c.Y.Set(y)
|
||||
|
||||
copy(c.Hash[:], data[:32])
|
||||
data = data[32:]
|
||||
n += 32
|
||||
|
||||
clen := Endian.Uint32(data)
|
||||
data = data[4:]
|
||||
n += 4
|
||||
|
||||
if len(data) < int(clen) {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
|
||||
}
|
||||
|
||||
c.C = append(c.C[:0], data[:clen]...)
|
||||
n += uint64(clen)
|
||||
// return the rest data
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// MarshalBinary implements the encoding.BinaryMarshaler interface
|
||||
// 返回字节符合GMT 0018的定义。
|
||||
// x||y||m||L||c
|
||||
func (c *Cipher) MarshalBinary() ([]byte, error) {
|
||||
//data := make([]byte, 2*ECCRefMaxLen+32+4+len(c.C))
|
||||
data := objectpool.GetBytes()
|
||||
c.X.Mod(c.X, orderN)
|
||||
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
|
||||
data = append(data, xBytes...)
|
||||
|
||||
c.Y.Mod(c.Y, orderN)
|
||||
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
|
||||
data = append(data, yBytes...)
|
||||
data = append(data, c.Hash[:]...)
|
||||
buf := []byte{0, 0, 0, 0}
|
||||
Endian.PutUint32(buf, uint32(len(c.C)))
|
||||
data = append(data, buf...)
|
||||
data = append(data, c.C...)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface
|
||||
// 输入字节应符合GMT 0018的定义。
|
||||
func (c *Cipher) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < 2*ECCRefMaxLen+32+4 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
|
||||
}
|
||||
|
||||
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
||||
data = data[ECCRefMaxLen:]
|
||||
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
|
||||
data = data[ECCRefMaxLen:]
|
||||
|
||||
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
|
||||
}
|
||||
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
|
||||
}
|
||||
if !sm2Curve.IsOnCurve(x, y) {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "(x,y) is not on the curve")
|
||||
}
|
||||
|
||||
c.X.Set(x)
|
||||
c.Y.Set(y)
|
||||
|
||||
copy(c.Hash[:], data[:32])
|
||||
data = data[32:]
|
||||
|
||||
clen := Endian.Uint32(data)
|
||||
data = data[4:]
|
||||
|
||||
if len(data) < int(clen) {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
|
||||
}
|
||||
|
||||
c.C = append(c.C[:0], data[:clen]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes 转换密文结构为byte切片
|
||||
func (c *Cipher) Bytes() []byte {
|
||||
var buf bytes.Buffer
|
||||
buf.Write(gmath.BigIntToNByte(c.X, byteSize))
|
||||
buf.Write(gmath.BigIntToNByte(c.Y, byteSize))
|
||||
buf.Write(c.Hash[:])
|
||||
buf.Write(c.C)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// SetBytes 转换[]byte为Cipher
|
||||
func (c *Cipher) SetBytes(data []byte) error {
|
||||
if len(data) < 2*byteSize+32 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*byteSize+32)
|
||||
}
|
||||
if c.X == nil {
|
||||
c.X = new(big.Int)
|
||||
}
|
||||
if c.Y == nil {
|
||||
c.Y = new(big.Int)
|
||||
}
|
||||
c.X.SetBytes(data[:byteSize])
|
||||
c.Y.SetBytes(data[byteSize : 2*byteSize])
|
||||
copy(c.Hash[:], data[2*byteSize:2*byteSize+32])
|
||||
c.C = append(c.C[:0], data[2*byteSize+32:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String return a readable string
|
||||
func (c *Cipher) String() string {
|
||||
var buf strings.Builder
|
||||
buf.WriteString("x: ")
|
||||
buf.WriteString(hex.EncodeToString(c.X.Bytes()))
|
||||
buf.WriteString("\ny: ")
|
||||
buf.WriteString(hex.EncodeToString(c.Y.Bytes()))
|
||||
buf.WriteString("\nc: ")
|
||||
buf.WriteString(hex.EncodeToString(c.C))
|
||||
buf.WriteString("\nhash: ")
|
||||
buf.WriteString(hex.EncodeToString(c.Hash[:]))
|
||||
|
||||
return buf.String()
|
||||
// return hex.EncodeToString(c.Bytes())
|
||||
}
|
||||
|
||||
// Encrypt 加密
|
||||
func Encrypt(pk *PublicKey, data, rnd []byte) (*Cipher, error) {
|
||||
// k := new(big.Int).SetBytes(rnd)
|
||||
if rnd == nil || len(rnd) < byteSize {
|
||||
rnd = grand.GetRandom(byteSize)
|
||||
}
|
||||
// if !pk.IsValid() {
|
||||
// return nil, gerrors.ERR_SM2_INVALID_PUBKEY
|
||||
// }
|
||||
var x, y *big.Int
|
||||
cipher := new(Cipher)
|
||||
xBytes := make([]byte, byteSize)
|
||||
yBytes := make([]byte, byteSize)
|
||||
var err error
|
||||
outer:
|
||||
for {
|
||||
cipher.X, cipher.Y = sm2Curve.ScalarBaseMult(rnd[:byteSize])
|
||||
x, y = sm2Curve.ScalarMult(pk.X, pk.Y, rnd[:byteSize])
|
||||
if err = gmath.FillBytes(x, xBytes); err != nil {
|
||||
return nil, gerrors.WithAnnotating(ErrEncFailed, "x is too big")
|
||||
}
|
||||
|
||||
if err = gmath.FillBytes(y, yBytes); err != nil {
|
||||
return nil, gerrors.WithAnnotating(ErrEncFailed, "y is too big")
|
||||
}
|
||||
|
||||
cipher.C = make([]byte, len(data))
|
||||
if len(data) == 0 {
|
||||
break
|
||||
}
|
||||
Kdf(cipher.C, xBytes, yBytes)
|
||||
for _, k := range cipher.C {
|
||||
if k != 0 {
|
||||
break outer
|
||||
}
|
||||
}
|
||||
if _, err = grand.GenerateRandom(rnd); err != nil {
|
||||
return nil, gerrors.ChainErrors(gerrors.WithAnnotating(ErrEncFailed, "generate random failed"), err)
|
||||
}
|
||||
}
|
||||
|
||||
// 两个对C的遍历,上面那个不会执行太多。
|
||||
for i := range cipher.C {
|
||||
cipher.C[i] ^= data[i]
|
||||
}
|
||||
|
||||
digest := sm3.New()
|
||||
// need padding 0 if x < 2*248
|
||||
digest.Write(xBytes)
|
||||
digest.Write(data)
|
||||
digest.Write(yBytes)
|
||||
digest.Sum(cipher.Hash[:0])
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// Decrypt 解密
|
||||
// 返回明文
|
||||
func Decrypt(sk *PrivateKey, cipher *Cipher) (plainText []byte, err error) {
|
||||
|
||||
if !ec256.Curve256.IsOnCurve(cipher.X, cipher.Y) {
|
||||
return nil, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
|
||||
}
|
||||
|
||||
x, y := ec256.Curve256.ScalarMult(cipher.X, cipher.Y, sk.Bytes())
|
||||
|
||||
return Decrypt_aux(x, y, cipher)
|
||||
}
|
||||
|
||||
// Decrypt_aux common decrypt function
|
||||
func Decrypt_aux(x, y *big.Int, cipher *Cipher) (plainText []byte, err error) {
|
||||
plainText = make([]byte, len(cipher.C))
|
||||
xBytes := make([]byte, byteSize)
|
||||
yBytes := make([]byte, byteSize)
|
||||
if err = gmath.FillBytes(x, xBytes); err != nil {
|
||||
return nil, gerrors.WithAnnotating(ErrDecFailed, "x is too big")
|
||||
}
|
||||
|
||||
if err = gmath.FillBytes(y, yBytes); err != nil {
|
||||
return nil, gerrors.WithAnnotating(ErrDecFailed, "y is too big")
|
||||
}
|
||||
|
||||
if len(plainText) == 0 {
|
||||
goto Next
|
||||
}
|
||||
if err := Kdf(plainText, xBytes, yBytes); err != nil {
|
||||
return nil, gerrors.ChainErrors(ErrDecFailed, err)
|
||||
}
|
||||
for _, k := range plainText {
|
||||
if k != 0 {
|
||||
goto Next
|
||||
}
|
||||
}
|
||||
return nil, gerrors.WithAnnotating(ErrDecFailed, "t is all zero while decryption")
|
||||
Next:
|
||||
for i := range plainText {
|
||||
plainText[i] ^= cipher.C[i]
|
||||
}
|
||||
|
||||
digest := sm3.New()
|
||||
digest.Write(xBytes)
|
||||
digest.Write(plainText)
|
||||
digest.Write(yBytes)
|
||||
d := digest.Sum(nil)
|
||||
if !bytes.Equal(d, cipher.Hash[:]) {
|
||||
return nil, gerrors.WithAnnotating(ErrDecFailed, "mac check failed (U!=C3) while decryption")
|
||||
}
|
||||
return plainText, nil
|
||||
}
|
||||
Reference in New Issue
Block a user