/// /// Copyright (c) 2018 xdx. All rights reserved. /// /// \file: key.go /// /// \brief: SM2密钥结构 /// /// \author: xdx /// package sm2 import ( "crypto/elliptic" "crypto/rand" "encoding/binary" "encoding/hex" "io" "math/big" "xdx.jelly/xgcl/gerrors" "xdx.jelly/xgcl/gmath" "xdx.jelly/xgcl/grand" "xdx.jelly/xgcl/internal" ) ////////////////////////////////////////////////////////// PrivateKey // PrivateKey 私钥, 使用NewPrivateKey生成 type PrivateKey struct { // it is dangerous that &PrivateKey = &PrivateKey.PublicKey // We sometimes pass &PrivateKey.PublicKey to caller and he could read the D. // It is better to use *PublicKey. // But we are along with crypto/ecdsa.PrivateKey PublicKey D *big.Int // use D (not anonymous) to hide big.Int's methods } // NewPrivateKey return a new PrivateKey instance // Equal to use PrivateKey{} func NewPrivateKey() *PrivateKey { return &PrivateKey{ D: new(big.Int), } } // Clear zero the privatekey, also it implements the gmath.Clearable interface func (k *PrivateKey) Clear() { gmath.ClearBigInt(k.D) } // Get return k.D, if k.D is nil, new one and return it func (k *PrivateKey) Get() *big.Int { if k.D == nil { k.D = new(big.Int) } return k.D } // SetString set k to s. func (k *PrivateKey) SetString(s string, base int) (*PrivateKey, bool) { if k.D == nil { k.D = new(big.Int) } if _, ok := k.D.SetString(s, base); !ok { return k, ok } k.computePublicKeyUncheck() return k, true } func (k *PrivateKey) SetBigInt(b *big.Int) *PrivateKey { if k.D == nil { k.D = new(big.Int) } k.D.Set(b) k.computePublicKeyUncheck() return k } func (k *PrivateKey) Set(x *PrivateKey) *PrivateKey { if x == nil { return k } if k.D == nil { k.D = new(big.Int) } k.D.Set(x.D) k.PublicKey.Set(&x.PublicKey) return k } // GenPublicKey generete the k.PublicKey and return it. func (k *PrivateKey) GenPublicKey() *PublicKey { k.computePublicKeyUncheck() return &k.PublicKey } func (k *PrivateKey) computePublicKeyUncheck() { k.PublicKey.Curve = Curve() k.PublicKey.X, k.PublicKey.Y = Curve().ScalarBaseMult(k.D.Bytes()) } // Random set k to a random key func (k *PrivateKey) Random(r io.Reader) *PrivateKey { N := Curve().Params().N if k.D == nil { k.D = new(big.Int) } for { d, err := rand.Int(r, N) if err != nil { return nil } if d.Sign() != 0 { k.D = d break } } k.computePublicKeyUncheck() return k } // MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface func (k *PrivateKey) MarshalUtil(data []byte) ([]byte, error) { if data == nil { data = make([]byte, 0, 4+ECCRefMaxLen) } if k.D == nil { k.D = new(big.Int) } k.D.Mod(k.D, orderN) data, tail := internal.SliceForAppend(data, 4) Endian.PutUint32(tail, uint32(byteSize<<3)) data, tail = internal.SliceForAppend(data, ECCRefMaxLen) _ = gmath.FillBytes(k.D, tail) return data, nil } // UnmarshalUtil revert of MarshalUtil func (k *PrivateKey) UnmarshalUtil(data []byte) (uint64, error) { n := uint64(0) if len(data) < 4+ECCRefMaxLen { return 0, gerrors.WithAnnotating(ErrInvalidInput, "input data too short") } if ECCStrict { bits := binary.BigEndian.Uint32(data[:4]) if bits != uint32(byteSize<<3) { return 0, gerrors.WithAnnotating(ErrInvalidInput, "may be bits are little endian") } } data = data[4:] n += 4 if k.D == nil { k.D = new(big.Int) } k.D.SetBytes(data[:ECCRefMaxLen]) if k.D.Cmp(orderN) >= 0 || k.D.Sign() == 0 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "the input private key as an integer is bigger than the order N") } n += ECCRefMaxLen k.computePublicKeyUncheck() return n, nil } // MarshalBinary implements the encoding.BinaryMarshaler interface // 返回字节符合GMT 0018的定义。 func (k *PrivateKey) MarshalBinary() ([]byte, error) { return k.MarshalUtil(nil) } // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface // 返回字节符合GMT 0018的定义。 // 注意:若返回错误,则k的值未定义。 func (k *PrivateKey) UnmarshalBinary(data []byte) error { _, err := k.UnmarshalUtil(data) return gerrors.WithStack(err) } // Bytes return the big-endian of privateKey, of byteSize bytes(aka. 32 bytes), padding 0 in the leading func (k *PrivateKey) Bytes() []byte { r := make([]byte, byteSize) k.D.Mod(k.D, orderN) _ = gmath.FillBytes(k.D, r) return r } // SetBytes set buf to k, invert of Bytes() func (k *PrivateKey) SetBytes(buf []byte) error { if len(buf) < byteSize { return gerrors.WithAnnotating(ErrInvalidInput, "input too small") } if k.D == nil { k.D = new(big.Int) } k.D.SetBytes(buf) // Mod unnessary? k.D.Mod(k.D, orderN) k.computePublicKeyUncheck() return nil } // String return a readable string func (k *PrivateKey) String() string { if k.D == nil { return "" } // return k.Int.Text(16) return hex.EncodeToString(gmath.BigIntToNByte(k.D, byteSize)) } // GenPrivateKey 生成私钥。 // rnd should be nil, io.Reader or []byte. // If rnd == nil, use package grand to generates random numbers. func GenPrivateKey(rnd any) (*PrivateKey, error) { if rnd == nil { return GenerateKey(Curve(), grand.Reader) } if rnd, ok := rnd.(io.Reader); ok { return GenerateKey(Curve(), rnd) } if b, ok := rnd.([]byte); ok { D := new(big.Int).SetBytes(b) D.Mod(D, OrderN()) if D.Sign() == 0 { // D.SetInt64(1) return nil, gerrors.WithAnnotating(ErrRandomError, "input random bytes invalid, recall with another random bytes") } sk := &PrivateKey{D: D} sk.computePublicKeyUncheck() return sk, nil } panic("GenPrivateKey: input rnd must bi nil, io.Reader or []byte") } ////////////////////////////////////////////////////////// PublicKey // PublicKey 公钥 type PublicKey struct { elliptic.Curve X, Y *big.Int } // Set k=x func (k *PublicKey) Set(x *PublicKey) *PublicKey { if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } k.X.Set(x.X) k.Y.Set(x.Y) return k } // NewPublicKey return a new PublicKey instance func NewPublicKey() *PublicKey { return &PublicKey{X: new(big.Int), Y: new(big.Int)} } func (k *PublicKey) MarshalASN1() ([]byte, error) { return elliptic.Marshal(Curve(), k.X, k.Y), nil } func (k *PublicKey) UnmarshalASN1(b []byte) error { x, y := elliptic.Unmarshal(Curve(), b) if x == nil || y == nil { return gerrors.WithAnnotating(ErrInvalidPoint, "x or y coordinates are 0") } k.Curve = Curve() k.X = x k.Y = y return nil } // MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface func (k *PublicKey) MarshalUtil(data []byte) ([]byte, error) { if data == nil { data = make([]byte, 0, 4+2*ECCRefMaxLen) } buf := []byte{0, 0, 0, 0} // 密钥位长应该是N的比特数 Endian.PutUint32(buf, uint32(byteSize)<<3) data = append(data, buf...) k.X.Mod(k.X, orderN) xBytes := gmath.BigIntToNByte(k.X, ECCRefMaxLen) data = append(data, xBytes...) k.Y.Mod(k.Y, orderN) yBytes := gmath.BigIntToNByte(k.Y, ECCRefMaxLen) data = append(data, yBytes...) return data, nil } // UnmarshalUtil implements the gcl/util/encoding/UnmarshalUtil interface func (k *PublicKey) UnmarshalUtil(data []byte) (uint64, error) { n := uint64(0) if len(data) < 4+2*ECCRefMaxLen { return 0, gerrors.WithAnnotating(ErrInvalidInput, "input too small") } if ECCStrict { bits := binary.BigEndian.Uint32(data[:4]) if bits != uint32(byteSize)<<3 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "may be bits are little endian") } } data = data[4:] n += 4 if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } x := k.X.SetBytes(data[:ECCRefMaxLen]) if x.Cmp(orderN) >= 0 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "the x coordinate is bigger than the order N") } data = data[ECCRefMaxLen:] n += ECCRefMaxLen y := k.Y.SetBytes(data[:ECCRefMaxLen]) if y.Cmp(orderN) >= 0 { return 0, gerrors.WithAnnotating(ErrInvalidInput, "the y coordinate is bigger than the order N") } data = data[ECCRefMaxLen:] //nolint n += ECCRefMaxLen // exclude the infinity point if gmath.IsBigInt0(x) && gmath.IsBigInt0(y) { return 0, gerrors.WithAnnotating(ErrInvalidInput, "the x or y coordinate is 0") } if !sm2Curve.IsOnCurve(x, y) { return 0, gerrors.WithStack(ErrInvalidPoint) } return n, nil } // MarshalBinary implements the encoding.BinaryMarshaler interface // 返回字节符合GMT 0018的定义。 func (k *PublicKey) MarshalBinary() ([]byte, error) { ret := make([]byte, 4+2*ECCRefMaxLen) // 密钥位长应该是N的比特数 Endian.PutUint32(ret[:4], uint32(byteSize)<<3) k.X.Mod(k.X, orderN) xBytes := k.X.Bytes() copy(ret[4+ECCRefMaxLen-len(xBytes):], xBytes) k.Y.Mod(k.Y, orderN) yBytes := k.Y.Bytes() copy(ret[4+2*ECCRefMaxLen-len(yBytes):], yBytes) return ret, nil } // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface // 返回字节符合GMT 0018的定义。 // 特殊情形:若data表示无穷远点,即00000100 0000...000, // 返回k.X=k.Y=0 // 若返回错误,k的值未定义 func (k *PublicKey) UnmarshalBinary(data []byte) error { if len(data) != 4+2*ECCRefMaxLen { return gerrors.WithAnnotating(ErrInvalidInput, "input data too short") } k.Curve = sm2Curve if ECCStrict { bits := binary.BigEndian.Uint32(data[:4]) if bits != uint32(byteSize)<<3 { return gerrors.WithAnnotating(ErrInvalidInput, "input bits may be little endian, use big endian instead") } } if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } data = data[4:] var sum byte for i := 0; i < ECCRefMaxLen-byteSize; i++ { sum |= data[i] } if sum != 0 { return gerrors.WithAnnotatingf(ErrInvalidInput, "the x coordinate is more than %d bits", byteSize<<3) } x := k.X.SetBytes(data[ECCRefMaxLen-byteSize : ECCRefMaxLen]) if x.Cmp(Prime()) >= 0 { return gerrors.WithAnnotatingf(ErrInvalidInput, "the x coordinate is big than the prime P") } data = data[ECCRefMaxLen:] for i := 0; i < ECCRefMaxLen-byteSize; i++ { sum |= data[i] } if sum != 0 { return gerrors.WithAnnotatingf(ErrInvalidInput, "the y coordinate is more than %d bits", byteSize<<3) } y := k.Y.SetBytes(data[ECCRefMaxLen-byteSize : ECCRefMaxLen]) if y.Cmp(Prime()) >= 0 { return gerrors.WithAnnotatingf(ErrInvalidInput, "the y coordinate is big than the prime P") } // exclude the infinity point if gmath.IsBigInt0(x) && gmath.IsBigInt0(y) { return nil } if !sm2Curve.IsOnCurve(x, y) { return gerrors.WithAnnotatingf(ErrInvalidInput, "the public key is not on the curve") } return nil } // SetToInf set k to infinity point func (k *PublicKey) SetToInf() { if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } k.X.SetInt64(0) k.Y.SetInt64(0) } // Bytes 返回[x,y]的字节表示,共64字节,大端表示,不足用0补足。 func (k *PublicKey) Bytes() []byte { k.X.Mod(k.X, orderN) k.Y.Mod(k.Y, orderN) r := make([]byte, 2*byteSize) _ = gmath.FillBytes(k.X, r[:byteSize]) _ = gmath.FillBytes(k.Y, r[byteSize:]) return r } // SetBytes . func (k *PublicKey) SetBytes(buf []byte) error { if len(buf) < 2*byteSize { return gerrors.WithAnnotating(ErrInvalidInput, "input data too short") } if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } k.X.SetBytes(buf[:byteSize]) k.Y.SetBytes(buf[byteSize : 2*byteSize]) if !k.IsValid() { return gerrors.WithAnnotatingf(ErrInvalidInput, "the public key is not on the curve") } return nil } // String return a readable string func (k *PublicKey) String() string { buf, _ := k.MarshalBinary() return hex.EncodeToString(buf) } // IsValid 返回公钥是否有效 func (k *PublicKey) IsValid() bool { return sm2Curve.IsOnCurve(k.X, k.Y) } // Normalize 公钥坐标mod n func (k *PublicKey) Normalize() { if k.X == nil { k.X = new(big.Int) } if k.Y == nil { k.Y = new(big.Int) } k.X.Mod(k.X, OrderN()) k.Y.Mod(k.Y, OrderN()) } // Equals 返回公钥是否相等 func (k *PublicKey) Equals(p *PublicKey) bool { k.Normalize() p.Normalize() return k.X.Cmp(p.X) == 0 && k.Y.Cmp(p.Y) == 0 } // Generate 生成公钥k = [d]G, // usage: // // pk := (&PublicKey{}).Generate(k) func (k *PublicKey) Generate(d *PrivateKey) *PublicKey { k.X, k.Y = sm2Curve.ScalarBaseMult(d.Bytes()) k.Curve = sm2Curve return k } // GenPublicKey 生成公钥。注意,返回的公钥指针是d.PublicKey func GenPublicKey(d *PrivateKey) *PublicKey { // We do't check that d < N // d.PublicKey.Curve = Curve() // d.PublicKey.X, d.PublicKey.Y = d.PublicKey.Curve.ScalarBaseMult(d.Bytes()) // return &d.PublicKey return d.GenPublicKey() } // GenerateKeyPairs return a key pair func GenerateKeyPairs(r io.Reader) (*PrivateKey, *PublicKey, error) { // b := make([]byte, byteSize) // if _, err := r.Read(b); err != nil { // return nil, nil, err // } sk, err := GenPrivateKey(r) if err != nil { return nil, nil, gerrors.WithAnnotating(err, "GenerateKeyPairs failed") } pk := GenPublicKey(sk) return sk, pk, nil } // VerifyKeyPair verify if pk = [sk]·G func VerifyKeyPair(sk *PrivateKey, pk *PublicKey) bool { x, y := Curve().ScalarBaseMult(sk.Bytes()) return pk.X.Cmp(x) == 0 && pk.Y.Cmp(y) == 0 }