init: v1.0.0

This commit is contained in:
yaole
2026-05-27 23:03:00 +08:00
commit 8d97f750eb
466 changed files with 80067 additions and 0 deletions
+380
View File
@@ -0,0 +1,380 @@
///
/// Copyright (c) 2018 xdx. All rights reserved.
///
/// \file: encryption.go
///
/// \brief: SM2加解密
///
/// \author: xdx
///
package sm2
import (
"bytes"
"encoding/asn1"
"encoding/hex"
"math/big"
"strings"
"xdx.jelly/xgcl/gerrors"
"xdx.jelly/xgcl/utils/objectpool"
"xdx.jelly/xgcl/gmath"
"xdx.jelly/xgcl/grand"
"xdx.jelly/xgcl/sm/sm2/ec256"
"xdx.jelly/xgcl/sm/sm3"
)
// DefaultCipherLength is the default cipher length when init
const DefaultCipherLength = 128
// Cipher 密文结构
// (x,y): C1
// hash: C3
// c: C2
type Cipher struct {
X, Y *big.Int
Hash [32]byte
C []byte
}
// NewCipher return a new instance of Cipher
func NewCipher() *Cipher {
return &Cipher{
X: new(big.Int),
Y: new(big.Int),
C: make([]byte, 0, DefaultCipherLength),
}
}
// Normalize reduce C1's x,y
func (c *Cipher) Normalize() *Cipher {
if c.X != nil {
c.X.Mod(c.X, orderN)
}
if c.Y != nil {
c.Y.Mod(c.Y, orderN)
}
return c
}
// Set c to src
func (c *Cipher) Set(src *Cipher) *Cipher {
c.X.Set(src.X)
c.Y.Set(src.Y)
c.Hash = src.Hash
c.C = append(c.C[:0], src.C...)
return c
}
// cipherASN1 helper struct for ASN.1 encoding/decoding Cipher using asn1.Marshal.
type cipherASN1 struct {
X, Y *big.Int
Hash []byte
C []byte
}
// MarshalASN1 marshal c to DER encoded and append to data.
func (c *Cipher) MarshalASN1() ([]byte, error) {
return asn1.Marshal(cipherASN1{
X: c.X,
Y: c.Y,
Hash: c.Hash[:],
C: c.C,
})
}
// UnmarshalASN1 unmarshal c from DER encoded data.
func (c *Cipher) UnmarshalASN1(b []byte) (rest []byte, err error) {
var ca cipherASN1
rest, err = asn1.Unmarshal(b, &ca)
if err != nil {
return b, err
}
c.X = ca.X
c.Y = ca.Y
copy(c.Hash[:], ca.Hash)
c.C = append(c.C, ca.C...)
return rest, nil
}
// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface
func (c *Cipher) MarshalUtil(data []byte) ([]byte, error) {
if data == nil {
data = make([]byte, 0, 2*ECCRefMaxLen+32+4+DefaultCipherLength)
}
c.X.Mod(c.X, orderN)
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
data = append(data, xBytes...)
c.Y.Mod(c.Y, orderN)
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
data = append(data, yBytes...)
data = append(data, c.Hash[:]...)
buf := []byte{0, 0, 0, 0}
Endian.PutUint32(buf, uint32(len(c.C)))
data = append(data, buf...)
data = append(data, c.C...)
return data, nil
}
// UnmarshalUtil marshal the Cipher append to data
func (c *Cipher) UnmarshalUtil(data []byte) (uint64, error) {
n := uint64(0) // consumed bytes.
if len(data) < 2*ECCRefMaxLen+32+4 {
return 0, gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
}
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
data = data[ECCRefMaxLen:]
n += ECCRefMaxLen
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
data = data[ECCRefMaxLen:]
n += ECCRefMaxLen
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
return 0, gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
}
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
return 0, gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
}
if !sm2Curve.IsOnCurve(x, y) {
return 0, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
}
c.X.Set(x)
c.Y.Set(y)
copy(c.Hash[:], data[:32])
data = data[32:]
n += 32
clen := Endian.Uint32(data)
data = data[4:]
n += 4
if len(data) < int(clen) {
return 0, gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
}
c.C = append(c.C[:0], data[:clen]...)
n += uint64(clen)
// return the rest data
return n, nil
}
// MarshalBinary implements the encoding.BinaryMarshaler interface
// 返回字节符合GMT 0018的定义。
// x||y||m||L||c
func (c *Cipher) MarshalBinary() ([]byte, error) {
//data := make([]byte, 2*ECCRefMaxLen+32+4+len(c.C))
data := objectpool.GetBytes()
c.X.Mod(c.X, orderN)
xBytes := gmath.BigIntToNByte(c.X, ECCRefMaxLen)
data = append(data, xBytes...)
c.Y.Mod(c.Y, orderN)
yBytes := gmath.BigIntToNByte(c.Y, ECCRefMaxLen)
data = append(data, yBytes...)
data = append(data, c.Hash[:]...)
buf := []byte{0, 0, 0, 0}
Endian.PutUint32(buf, uint32(len(c.C)))
data = append(data, buf...)
data = append(data, c.C...)
return data, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface
// 输入字节应符合GMT 0018的定义。
func (c *Cipher) UnmarshalBinary(data []byte) error {
if len(data) < 2*ECCRefMaxLen+32+4 {
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*ECCRefMaxLen+32+4)
}
x := new(big.Int).SetBytes(data[:ECCRefMaxLen])
data = data[ECCRefMaxLen:]
y := new(big.Int).SetBytes(data[:ECCRefMaxLen])
data = data[ECCRefMaxLen:]
if x.Cmp(orderN) >= 0 || x.Sign() == 0 {
return gerrors.WithAnnotating(ErrInvalidInput, "x is bigger then the order N")
}
if y.Cmp(orderN) >= 0 || y.Sign() == 0 {
return gerrors.WithAnnotating(ErrInvalidInput, "y is bigger then the order N")
}
if !sm2Curve.IsOnCurve(x, y) {
return gerrors.WithAnnotating(ErrInvalidInput, "(x,y) is not on the curve")
}
c.X.Set(x)
c.Y.Set(y)
copy(c.Hash[:], data[:32])
data = data[32:]
clen := Endian.Uint32(data)
data = data[4:]
if len(data) < int(clen) {
return gerrors.WithAnnotating(ErrInvalidInput, "C2 is too short")
}
c.C = append(c.C[:0], data[:clen]...)
return nil
}
// Bytes 转换密文结构为byte切片
func (c *Cipher) Bytes() []byte {
var buf bytes.Buffer
buf.Write(gmath.BigIntToNByte(c.X, byteSize))
buf.Write(gmath.BigIntToNByte(c.Y, byteSize))
buf.Write(c.Hash[:])
buf.Write(c.C)
return buf.Bytes()
}
// SetBytes 转换[]byte为Cipher
func (c *Cipher) SetBytes(data []byte) error {
if len(data) < 2*byteSize+32 {
return gerrors.WithAnnotatingf(ErrInvalidInput, "input(%d bytes) must be at least %d bytes", len(data), 2*byteSize+32)
}
if c.X == nil {
c.X = new(big.Int)
}
if c.Y == nil {
c.Y = new(big.Int)
}
c.X.SetBytes(data[:byteSize])
c.Y.SetBytes(data[byteSize : 2*byteSize])
copy(c.Hash[:], data[2*byteSize:2*byteSize+32])
c.C = append(c.C[:0], data[2*byteSize+32:]...)
return nil
}
// String return a readable string
func (c *Cipher) String() string {
var buf strings.Builder
buf.WriteString("x: ")
buf.WriteString(hex.EncodeToString(c.X.Bytes()))
buf.WriteString("\ny: ")
buf.WriteString(hex.EncodeToString(c.Y.Bytes()))
buf.WriteString("\nc: ")
buf.WriteString(hex.EncodeToString(c.C))
buf.WriteString("\nhash: ")
buf.WriteString(hex.EncodeToString(c.Hash[:]))
return buf.String()
// return hex.EncodeToString(c.Bytes())
}
// Encrypt 加密
func Encrypt(pk *PublicKey, data, rnd []byte) (*Cipher, error) {
// k := new(big.Int).SetBytes(rnd)
if rnd == nil || len(rnd) < byteSize {
rnd = grand.GetRandom(byteSize)
}
// if !pk.IsValid() {
// return nil, gerrors.ERR_SM2_INVALID_PUBKEY
// }
var x, y *big.Int
cipher := new(Cipher)
xBytes := make([]byte, byteSize)
yBytes := make([]byte, byteSize)
var err error
outer:
for {
cipher.X, cipher.Y = sm2Curve.ScalarBaseMult(rnd[:byteSize])
x, y = sm2Curve.ScalarMult(pk.X, pk.Y, rnd[:byteSize])
if err = gmath.FillBytes(x, xBytes); err != nil {
return nil, gerrors.WithAnnotating(ErrEncFailed, "x is too big")
}
if err = gmath.FillBytes(y, yBytes); err != nil {
return nil, gerrors.WithAnnotating(ErrEncFailed, "y is too big")
}
cipher.C = make([]byte, len(data))
if len(data) == 0 {
break
}
Kdf(cipher.C, xBytes, yBytes)
for _, k := range cipher.C {
if k != 0 {
break outer
}
}
if _, err = grand.GenerateRandom(rnd); err != nil {
return nil, gerrors.ChainErrors(gerrors.WithAnnotating(ErrEncFailed, "generate random failed"), err)
}
}
// 两个对C的遍历,上面那个不会执行太多。
for i := range cipher.C {
cipher.C[i] ^= data[i]
}
digest := sm3.New()
// need padding 0 if x < 2*248
digest.Write(xBytes)
digest.Write(data)
digest.Write(yBytes)
digest.Sum(cipher.Hash[:0])
return cipher, nil
}
// Decrypt 解密
// 返回明文
func Decrypt(sk *PrivateKey, cipher *Cipher) (plainText []byte, err error) {
if !ec256.Curve256.IsOnCurve(cipher.X, cipher.Y) {
return nil, gerrors.WithAnnotating(ErrDecFailed, "C1 is not a valid curve point")
}
x, y := ec256.Curve256.ScalarMult(cipher.X, cipher.Y, sk.Bytes())
return Decrypt_aux(x, y, cipher)
}
// Decrypt_aux common decrypt function
func Decrypt_aux(x, y *big.Int, cipher *Cipher) (plainText []byte, err error) {
plainText = make([]byte, len(cipher.C))
xBytes := make([]byte, byteSize)
yBytes := make([]byte, byteSize)
if err = gmath.FillBytes(x, xBytes); err != nil {
return nil, gerrors.WithAnnotating(ErrDecFailed, "x is too big")
}
if err = gmath.FillBytes(y, yBytes); err != nil {
return nil, gerrors.WithAnnotating(ErrDecFailed, "y is too big")
}
if len(plainText) == 0 {
goto Next
}
if err := Kdf(plainText, xBytes, yBytes); err != nil {
return nil, gerrors.ChainErrors(ErrDecFailed, err)
}
for _, k := range plainText {
if k != 0 {
goto Next
}
}
return nil, gerrors.WithAnnotating(ErrDecFailed, "t is all zero while decryption")
Next:
for i := range plainText {
plainText[i] ^= cipher.C[i]
}
digest := sm3.New()
digest.Write(xBytes)
digest.Write(plainText)
digest.Write(yBytes)
d := digest.Sum(nil)
if !bytes.Equal(d, cipher.Hash[:]) {
return nil, gerrors.WithAnnotating(ErrDecFailed, "mac check failed (U!=C3) while decryption")
}
return plainText, nil
}