init: v1.0.0
This commit is contained in:
+543
@@ -0,0 +1,543 @@
|
||||
///
|
||||
/// Copyright (c) 2018 xdx. All rights reserved.
|
||||
///
|
||||
/// \file: key.go
|
||||
///
|
||||
/// \brief: SM2密钥结构
|
||||
///
|
||||
/// \author: xdx
|
||||
///
|
||||
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"math/big"
|
||||
|
||||
"xdx.jelly/xgcl/gerrors"
|
||||
"xdx.jelly/xgcl/gmath"
|
||||
"xdx.jelly/xgcl/grand"
|
||||
"xdx.jelly/xgcl/internal"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////// PrivateKey
|
||||
|
||||
// PrivateKey 私钥, 使用NewPrivateKey生成
|
||||
type PrivateKey struct {
|
||||
// it is dangerous that &PrivateKey = &PrivateKey.PublicKey
|
||||
// We sometimes pass &PrivateKey.PublicKey to caller and he could read the D.
|
||||
// It is better to use *PublicKey.
|
||||
// But we are along with crypto/ecdsa.PrivateKey
|
||||
PublicKey
|
||||
D *big.Int // use D (not anonymous) to hide big.Int's methods
|
||||
}
|
||||
|
||||
// NewPrivateKey return a new PrivateKey instance
|
||||
// Equal to use PrivateKey{}
|
||||
func NewPrivateKey() *PrivateKey {
|
||||
return &PrivateKey{
|
||||
D: new(big.Int),
|
||||
}
|
||||
}
|
||||
|
||||
// Clear zero the privatekey, also it implements the gmath.Clearable interface
|
||||
func (k *PrivateKey) Clear() {
|
||||
gmath.ClearBigInt(k.D)
|
||||
}
|
||||
|
||||
// Get return k.D, if k.D is nil, new one and return it
|
||||
func (k *PrivateKey) Get() *big.Int {
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
return k.D
|
||||
}
|
||||
|
||||
// SetString set k to s.
|
||||
func (k *PrivateKey) SetString(s string, base int) (*PrivateKey, bool) {
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
|
||||
if _, ok := k.D.SetString(s, base); !ok {
|
||||
return k, ok
|
||||
}
|
||||
k.computePublicKeyUncheck()
|
||||
return k, true
|
||||
}
|
||||
func (k *PrivateKey) SetBigInt(b *big.Int) *PrivateKey {
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
k.D.Set(b)
|
||||
k.computePublicKeyUncheck()
|
||||
return k
|
||||
}
|
||||
func (k *PrivateKey) Set(x *PrivateKey) *PrivateKey {
|
||||
if x == nil {
|
||||
return k
|
||||
}
|
||||
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
|
||||
k.D.Set(x.D)
|
||||
k.PublicKey.Set(&x.PublicKey)
|
||||
return k
|
||||
}
|
||||
|
||||
// GenPublicKey generete the k.PublicKey and return it.
|
||||
func (k *PrivateKey) GenPublicKey() *PublicKey {
|
||||
k.computePublicKeyUncheck()
|
||||
return &k.PublicKey
|
||||
}
|
||||
|
||||
func (k *PrivateKey) computePublicKeyUncheck() {
|
||||
k.PublicKey.Curve = Curve()
|
||||
k.PublicKey.X, k.PublicKey.Y = Curve().ScalarBaseMult(k.D.Bytes())
|
||||
}
|
||||
|
||||
// Random set k to a random key
|
||||
func (k *PrivateKey) Random(r io.Reader) *PrivateKey {
|
||||
N := Curve().Params().N
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
for {
|
||||
d, err := rand.Int(r, N)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if d.Sign() != 0 {
|
||||
k.D = d
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
k.computePublicKeyUncheck()
|
||||
return k
|
||||
}
|
||||
|
||||
// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface
|
||||
func (k *PrivateKey) MarshalUtil(data []byte) ([]byte, error) {
|
||||
if data == nil {
|
||||
data = make([]byte, 0, 4+ECCRefMaxLen)
|
||||
}
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
k.D.Mod(k.D, orderN)
|
||||
|
||||
data, tail := internal.SliceForAppend(data, 4)
|
||||
Endian.PutUint32(tail, uint32(byteSize<<3))
|
||||
|
||||
data, tail = internal.SliceForAppend(data, ECCRefMaxLen)
|
||||
_ = gmath.FillBytes(k.D, tail)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// UnmarshalUtil revert of MarshalUtil
|
||||
func (k *PrivateKey) UnmarshalUtil(data []byte) (uint64, error) {
|
||||
n := uint64(0)
|
||||
if len(data) < 4+ECCRefMaxLen {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "input data too short")
|
||||
}
|
||||
if ECCStrict {
|
||||
bits := binary.BigEndian.Uint32(data[:4])
|
||||
if bits != uint32(byteSize<<3) {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "may be bits are little endian")
|
||||
}
|
||||
}
|
||||
data = data[4:]
|
||||
n += 4
|
||||
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
|
||||
k.D.SetBytes(data[:ECCRefMaxLen])
|
||||
if k.D.Cmp(orderN) >= 0 || k.D.Sign() == 0 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "the input private key as an integer is bigger than the order N")
|
||||
}
|
||||
n += ECCRefMaxLen
|
||||
k.computePublicKeyUncheck()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// MarshalBinary implements the encoding.BinaryMarshaler interface
|
||||
// 返回字节符合GMT 0018的定义。
|
||||
func (k *PrivateKey) MarshalBinary() ([]byte, error) {
|
||||
return k.MarshalUtil(nil)
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface
|
||||
// 返回字节符合GMT 0018的定义。
|
||||
// 注意:若返回错误,则k的值未定义。
|
||||
func (k *PrivateKey) UnmarshalBinary(data []byte) error {
|
||||
_, err := k.UnmarshalUtil(data)
|
||||
return gerrors.WithStack(err)
|
||||
}
|
||||
|
||||
// Bytes return the big-endian of privateKey, of byteSize bytes(aka. 32 bytes), padding 0 in the leading
|
||||
func (k *PrivateKey) Bytes() []byte {
|
||||
r := make([]byte, byteSize)
|
||||
k.D.Mod(k.D, orderN)
|
||||
_ = gmath.FillBytes(k.D, r)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBytes set buf to k, invert of Bytes()
|
||||
func (k *PrivateKey) SetBytes(buf []byte) error {
|
||||
if len(buf) < byteSize {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "input too small")
|
||||
}
|
||||
if k.D == nil {
|
||||
k.D = new(big.Int)
|
||||
}
|
||||
k.D.SetBytes(buf)
|
||||
// Mod unnessary?
|
||||
k.D.Mod(k.D, orderN)
|
||||
|
||||
k.computePublicKeyUncheck()
|
||||
return nil
|
||||
}
|
||||
|
||||
// String return a readable string
|
||||
func (k *PrivateKey) String() string {
|
||||
if k.D == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
// return k.Int.Text(16)
|
||||
return hex.EncodeToString(gmath.BigIntToNByte(k.D, byteSize))
|
||||
}
|
||||
|
||||
// GenPrivateKey 生成私钥。
|
||||
// rnd should be nil, io.Reader or []byte.
|
||||
// If rnd == nil, use package grand to generates random numbers.
|
||||
func GenPrivateKey(rnd any) (*PrivateKey, error) {
|
||||
if rnd == nil {
|
||||
return GenerateKey(Curve(), grand.Reader)
|
||||
}
|
||||
if rnd, ok := rnd.(io.Reader); ok {
|
||||
return GenerateKey(Curve(), rnd)
|
||||
}
|
||||
if b, ok := rnd.([]byte); ok {
|
||||
D := new(big.Int).SetBytes(b)
|
||||
D.Mod(D, OrderN())
|
||||
if D.Sign() == 0 {
|
||||
// D.SetInt64(1)
|
||||
return nil, gerrors.WithAnnotating(ErrRandomError, "input random bytes invalid, recall with another random bytes")
|
||||
}
|
||||
sk := &PrivateKey{D: D}
|
||||
sk.computePublicKeyUncheck()
|
||||
return sk, nil
|
||||
}
|
||||
panic("GenPrivateKey: input rnd must bi nil, io.Reader or []byte")
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////// PublicKey
|
||||
|
||||
// PublicKey 公钥
|
||||
type PublicKey struct {
|
||||
elliptic.Curve
|
||||
X, Y *big.Int
|
||||
}
|
||||
|
||||
// Set k=x
|
||||
func (k *PublicKey) Set(x *PublicKey) *PublicKey {
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
k.X.Set(x.X)
|
||||
k.Y.Set(x.Y)
|
||||
return k
|
||||
}
|
||||
|
||||
// NewPublicKey return a new PublicKey instance
|
||||
func NewPublicKey() *PublicKey {
|
||||
return &PublicKey{X: new(big.Int), Y: new(big.Int)}
|
||||
}
|
||||
func (k *PublicKey) MarshalASN1() ([]byte, error) {
|
||||
return elliptic.Marshal(Curve(), k.X, k.Y), nil
|
||||
}
|
||||
|
||||
func (k *PublicKey) UnmarshalASN1(b []byte) error {
|
||||
x, y := elliptic.Unmarshal(Curve(), b)
|
||||
if x == nil || y == nil {
|
||||
return gerrors.WithAnnotating(ErrInvalidPoint, "x or y coordinates are 0")
|
||||
}
|
||||
k.Curve = Curve()
|
||||
k.X = x
|
||||
k.Y = y
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface
|
||||
func (k *PublicKey) MarshalUtil(data []byte) ([]byte, error) {
|
||||
if data == nil {
|
||||
data = make([]byte, 0, 4+2*ECCRefMaxLen)
|
||||
}
|
||||
buf := []byte{0, 0, 0, 0}
|
||||
// 密钥位长应该是N的比特数
|
||||
Endian.PutUint32(buf, uint32(byteSize)<<3)
|
||||
data = append(data, buf...)
|
||||
|
||||
k.X.Mod(k.X, orderN)
|
||||
xBytes := gmath.BigIntToNByte(k.X, ECCRefMaxLen)
|
||||
data = append(data, xBytes...)
|
||||
|
||||
k.Y.Mod(k.Y, orderN)
|
||||
yBytes := gmath.BigIntToNByte(k.Y, ECCRefMaxLen)
|
||||
data = append(data, yBytes...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// UnmarshalUtil implements the gcl/util/encoding/UnmarshalUtil interface
|
||||
func (k *PublicKey) UnmarshalUtil(data []byte) (uint64, error) {
|
||||
n := uint64(0)
|
||||
if len(data) < 4+2*ECCRefMaxLen {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "input too small")
|
||||
}
|
||||
if ECCStrict {
|
||||
bits := binary.BigEndian.Uint32(data[:4])
|
||||
if bits != uint32(byteSize)<<3 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "may be bits are little endian")
|
||||
}
|
||||
}
|
||||
data = data[4:]
|
||||
n += 4
|
||||
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
x := k.X.SetBytes(data[:ECCRefMaxLen])
|
||||
if x.Cmp(orderN) >= 0 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "the x coordinate is bigger than the order N")
|
||||
}
|
||||
data = data[ECCRefMaxLen:]
|
||||
n += ECCRefMaxLen
|
||||
|
||||
y := k.Y.SetBytes(data[:ECCRefMaxLen])
|
||||
if y.Cmp(orderN) >= 0 {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "the y coordinate is bigger than the order N")
|
||||
}
|
||||
data = data[ECCRefMaxLen:] //nolint
|
||||
n += ECCRefMaxLen
|
||||
|
||||
// exclude the infinity point
|
||||
if gmath.IsBigInt0(x) && gmath.IsBigInt0(y) {
|
||||
return 0, gerrors.WithAnnotating(ErrInvalidInput, "the x or y coordinate is 0")
|
||||
}
|
||||
|
||||
if !sm2Curve.IsOnCurve(x, y) {
|
||||
return 0, gerrors.WithStack(ErrInvalidPoint)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// MarshalBinary implements the encoding.BinaryMarshaler interface
|
||||
// 返回字节符合GMT 0018的定义。
|
||||
func (k *PublicKey) MarshalBinary() ([]byte, error) {
|
||||
ret := make([]byte, 4+2*ECCRefMaxLen)
|
||||
// 密钥位长应该是N的比特数
|
||||
Endian.PutUint32(ret[:4], uint32(byteSize)<<3)
|
||||
k.X.Mod(k.X, orderN)
|
||||
xBytes := k.X.Bytes()
|
||||
copy(ret[4+ECCRefMaxLen-len(xBytes):], xBytes)
|
||||
|
||||
k.Y.Mod(k.Y, orderN)
|
||||
yBytes := k.Y.Bytes()
|
||||
copy(ret[4+2*ECCRefMaxLen-len(yBytes):], yBytes)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface
|
||||
// 返回字节符合GMT 0018的定义。
|
||||
// 特殊情形:若data表示无穷远点,即00000100 0000...000,
|
||||
// 返回k.X=k.Y=0
|
||||
// 若返回错误,k的值未定义
|
||||
func (k *PublicKey) UnmarshalBinary(data []byte) error {
|
||||
if len(data) != 4+2*ECCRefMaxLen {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "input data too short")
|
||||
}
|
||||
k.Curve = sm2Curve
|
||||
if ECCStrict {
|
||||
bits := binary.BigEndian.Uint32(data[:4])
|
||||
if bits != uint32(byteSize)<<3 {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "input bits may be little endian, use big endian instead")
|
||||
}
|
||||
}
|
||||
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
data = data[4:]
|
||||
var sum byte
|
||||
for i := 0; i < ECCRefMaxLen-byteSize; i++ {
|
||||
sum |= data[i]
|
||||
}
|
||||
if sum != 0 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the x coordinate is more than %d bits", byteSize<<3)
|
||||
}
|
||||
x := k.X.SetBytes(data[ECCRefMaxLen-byteSize : ECCRefMaxLen])
|
||||
if x.Cmp(Prime()) >= 0 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the x coordinate is big than the prime P")
|
||||
}
|
||||
|
||||
data = data[ECCRefMaxLen:]
|
||||
for i := 0; i < ECCRefMaxLen-byteSize; i++ {
|
||||
sum |= data[i]
|
||||
}
|
||||
if sum != 0 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the y coordinate is more than %d bits", byteSize<<3)
|
||||
}
|
||||
y := k.Y.SetBytes(data[ECCRefMaxLen-byteSize : ECCRefMaxLen])
|
||||
if y.Cmp(Prime()) >= 0 {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the y coordinate is big than the prime P")
|
||||
}
|
||||
|
||||
// exclude the infinity point
|
||||
if gmath.IsBigInt0(x) && gmath.IsBigInt0(y) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !sm2Curve.IsOnCurve(x, y) {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the public key is not on the curve")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetToInf set k to infinity point
|
||||
func (k *PublicKey) SetToInf() {
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
k.X.SetInt64(0)
|
||||
k.Y.SetInt64(0)
|
||||
}
|
||||
|
||||
// Bytes 返回[x,y]的字节表示,共64字节,大端表示,不足用0补足。
|
||||
func (k *PublicKey) Bytes() []byte {
|
||||
k.X.Mod(k.X, orderN)
|
||||
k.Y.Mod(k.Y, orderN)
|
||||
r := make([]byte, 2*byteSize)
|
||||
_ = gmath.FillBytes(k.X, r[:byteSize])
|
||||
_ = gmath.FillBytes(k.Y, r[byteSize:])
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBytes .
|
||||
func (k *PublicKey) SetBytes(buf []byte) error {
|
||||
if len(buf) < 2*byteSize {
|
||||
return gerrors.WithAnnotating(ErrInvalidInput, "input data too short")
|
||||
}
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
k.X.SetBytes(buf[:byteSize])
|
||||
k.Y.SetBytes(buf[byteSize : 2*byteSize])
|
||||
if !k.IsValid() {
|
||||
return gerrors.WithAnnotatingf(ErrInvalidInput, "the public key is not on the curve")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// String return a readable string
|
||||
func (k *PublicKey) String() string {
|
||||
buf, _ := k.MarshalBinary()
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
// IsValid 返回公钥是否有效
|
||||
func (k *PublicKey) IsValid() bool {
|
||||
return sm2Curve.IsOnCurve(k.X, k.Y)
|
||||
}
|
||||
|
||||
// Normalize 公钥坐标mod n
|
||||
func (k *PublicKey) Normalize() {
|
||||
if k.X == nil {
|
||||
k.X = new(big.Int)
|
||||
}
|
||||
if k.Y == nil {
|
||||
k.Y = new(big.Int)
|
||||
}
|
||||
|
||||
k.X.Mod(k.X, OrderN())
|
||||
k.Y.Mod(k.Y, OrderN())
|
||||
}
|
||||
|
||||
// Equals 返回公钥是否相等
|
||||
func (k *PublicKey) Equals(p *PublicKey) bool {
|
||||
k.Normalize()
|
||||
p.Normalize()
|
||||
|
||||
return k.X.Cmp(p.X) == 0 && k.Y.Cmp(p.Y) == 0
|
||||
}
|
||||
|
||||
// Generate 生成公钥k = [d]G,
|
||||
// usage:
|
||||
//
|
||||
// pk := (&PublicKey{}).Generate(k)
|
||||
func (k *PublicKey) Generate(d *PrivateKey) *PublicKey {
|
||||
k.X, k.Y = sm2Curve.ScalarBaseMult(d.Bytes())
|
||||
k.Curve = sm2Curve
|
||||
return k
|
||||
}
|
||||
|
||||
// GenPublicKey 生成公钥。注意,返回的公钥指针是d.PublicKey
|
||||
func GenPublicKey(d *PrivateKey) *PublicKey {
|
||||
// We do't check that d < N
|
||||
// d.PublicKey.Curve = Curve()
|
||||
// d.PublicKey.X, d.PublicKey.Y = d.PublicKey.Curve.ScalarBaseMult(d.Bytes())
|
||||
// return &d.PublicKey
|
||||
return d.GenPublicKey()
|
||||
}
|
||||
|
||||
// GenerateKeyPairs return a key pair
|
||||
func GenerateKeyPairs(r io.Reader) (*PrivateKey, *PublicKey, error) {
|
||||
// b := make([]byte, byteSize)
|
||||
// if _, err := r.Read(b); err != nil {
|
||||
// return nil, nil, err
|
||||
// }
|
||||
sk, err := GenPrivateKey(r)
|
||||
if err != nil {
|
||||
return nil, nil, gerrors.WithAnnotating(err, "GenerateKeyPairs failed")
|
||||
}
|
||||
pk := GenPublicKey(sk)
|
||||
return sk, pk, nil
|
||||
}
|
||||
|
||||
// VerifyKeyPair verify if pk = [sk]·G
|
||||
func VerifyKeyPair(sk *PrivateKey, pk *PublicKey) bool {
|
||||
x, y := Curve().ScalarBaseMult(sk.Bytes())
|
||||
return pk.X.Cmp(x) == 0 && pk.Y.Cmp(y) == 0
|
||||
}
|
||||
Reference in New Issue
Block a user