package sm2

import (
	"crypto/rand"
	"encoding/hex"
	"math/big"
	"strings"

	"golang.org/x/crypto/cryptobyte"
	"golang.org/x/crypto/cryptobyte/asn1"

	"xdx.jelly/xgcl/gerrors"
	"xdx.jelly/xgcl/gmath"
	"xdx.jelly/xgcl/sm/sm2/ec256"
	"xdx.jelly/xgcl/sm/sm3"
)

// Signature is an SM2 signature, the pair (r, s).
type Signature struct {
	R, S *big.Int
}

// NewSignature returns a Signature whose R and S are allocated and set to zero.
func NewSignature() *Signature {
	return &Signature{
		R: new(big.Int),
		S: new(big.Int),
	}
}

// setRS copies r and s into sig, allocating sig.R / sig.S first when sig is
// a zero value. Existing *big.Int components are updated in place, so callers
// holding those pointers observe the new values.
func (sig *Signature) setRS(r, s *big.Int) {
	if sig.R == nil {
		sig.R = new(big.Int)
	}
	if sig.S == nil {
		sig.S = new(big.Int)
	}
	sig.R.Set(r)
	sig.S.Set(s)
}

// parseSignatureRS decodes the fixed-width r || s layout (ECCRefMaxLen bytes
// each, big-endian) and checks that both components lie in [1, N-1].
// The caller must ensure len(data) >= 2*ECCRefMaxLen.
func parseSignatureRS(data []byte) (r, s *big.Int, err error) {
	r = new(big.Int).SetBytes(data[:ECCRefMaxLen])
	if r.Cmp(orderN) >= 0 || r.Sign() == 0 {
		return nil, nil, gerrors.WithAnnotating(ErrInvalidInput, "r is zero or bigger than the order N")
	}
	s = new(big.Int).SetBytes(data[ECCRefMaxLen : 2*ECCRefMaxLen])
	if s.Cmp(orderN) >= 0 || s.Sign() == 0 {
		return nil, nil, gerrors.WithAnnotating(ErrInvalidInput, "s is zero or bigger than the order N")
	}
	return r, s, nil
}

// MarshalUtil implements the gcl/util/encoding/UtilMarshaler interface.
//
// It appends r || s (each reduced mod N and padded to ECCRefMaxLen bytes) to
// data and returns the extended slice. A nil data gets a freshly allocated
// buffer. The receiver is not modified.
func (sig *Signature) MarshalUtil(data []byte) ([]byte, error) {
	if data == nil {
		data = make([]byte, 0, 2*ECCRefMaxLen)
	}
	// Reduce into temporaries so marshalling never mutates the signature.
	r := new(big.Int).Mod(sig.R, orderN)
	s := new(big.Int).Mod(sig.S, orderN)
	data = append(data, gmath.BigIntToNByte(r, ECCRefMaxLen)...)
	data = append(data, gmath.BigIntToNByte(s, ECCRefMaxLen)...)
	return data, nil
}

// UnmarshalUtil decodes the r || s layout written by MarshalUtil from the
// front of data.
//
// It returns the number of bytes consumed. Extra trailing bytes are ignored.
// On error sig is left unchanged.
func (sig *Signature) UnmarshalUtil(data []byte) (uint64, error) {
	if len(data) < 2*ECCRefMaxLen {
		return 0, gerrors.WithAnnotating(ErrInvalidInput, "input data too short")
	}
	r, s, err := parseSignatureRS(data)
	if err != nil {
		return 0, err
	}
	sig.setRS(r, s)
	return uint64(2 * ECCRefMaxLen), nil
}

// MarshalBinary implements the encoding.BinaryMarshaler interface.
//
// The encoding is r || s, each component reduced mod N and left-padded with
// zeros to ECCRefMaxLen bytes. The receiver is not modified.
func (sig *Signature) MarshalBinary() ([]byte, error) {
	data := make([]byte, 2*ECCRefMaxLen)
	rBytes := new(big.Int).Mod(sig.R, orderN).Bytes()
	copy(data[ECCRefMaxLen-len(rBytes):ECCRefMaxLen], rBytes)
	sBytes := new(big.Int).Mod(sig.S, orderN).Bytes()
	copy(data[2*ECCRefMaxLen-len(sBytes):], sBytes)
	return data, nil
}

// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
//
// data must be exactly 2*ECCRefMaxLen bytes (r || s), and both components
// must lie in [1, N-1]. On error sig is left unchanged.
func (sig *Signature) UnmarshalBinary(data []byte) error {
	if len(data) != 2*ECCRefMaxLen {
		return gerrors.WithAnnotatingf(ErrInvalidInput, "input data should be of %d bytes, but it is %d bytes", 2*ECCRefMaxLen, len(data))
	}
	r, s, err := parseSignatureRS(data)
	if err != nil {
		return err
	}
	sig.setRS(r, s)
	return nil
}

// MarshalASN1 encodes the signature as the DER structure
// SEQUENCE { r INTEGER, s INTEGER }.
func (sig *Signature) MarshalASN1() ([]byte, error) {
	var b cryptobyte.Builder
	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
		b.AddASN1BigInt(sig.R)
		b.AddASN1BigInt(sig.S)
	})
	return b.Bytes()
}

// UnmarshalASN1 decodes a DER SEQUENCE { r INTEGER, s INTEGER }.
//
// Trailing data after the sequence is rejected with ErrDecodeASN1Failed.
// The range of r and s is not checked here (Verify does that). On error sig
// is left unchanged.
func (sig *Signature) UnmarshalASN1(data []byte) error {
	r, s := new(big.Int), new(big.Int)
	var inner cryptobyte.String
	input := cryptobyte.String(data)
	if !input.ReadASN1(&inner, asn1.SEQUENCE) || !input.Empty() ||
		!inner.ReadASN1Integer(r) || !inner.ReadASN1Integer(s) || !inner.Empty() {
		return ErrDecodeASN1Failed
	}
	sig.setRS(r, s)
	return nil
}

// SetBytes sets the signature from rs = r || s, where each component
// occupies byteSize bytes. Bytes beyond 2*byteSize are ignored, and no range
// check is performed.
func (sig *Signature) SetBytes(rs []byte) error {
	if len(rs) < 2*byteSize {
		return gerrors.WithAnnotating(ErrInvalidInput, "input data too short")
	}
	if sig.R == nil {
		sig.R = new(big.Int)
	}
	if sig.S == nil {
		sig.S = new(big.Int)
	}
	sig.R.SetBytes(rs[:byteSize])
	sig.S.SetBytes(rs[byteSize : 2*byteSize])
	return nil
}

// Bytes returns r || s with each component written into byteSize bytes via
// gmath.FillBytes. Its error result is ignored, as in the original code.
func (sig *Signature) Bytes() []byte {
	buf := make([]byte, 2*byteSize)
	_ = gmath.FillBytes(sig.R, buf[:byteSize])
	_ = gmath.FillBytes(sig.S, buf[byteSize:])
	return buf
}

// String returns a human-readable, hex-encoded form of the signature.
func (sig *Signature) String() string {
	var buf strings.Builder
	buf.WriteString("r: ")
	buf.WriteString(hex.EncodeToString(gmath.BigIntToNByte(sig.R, ECCRefMaxLen)))
	buf.WriteString("\ns: ")
	buf.WriteString(hex.EncodeToString(gmath.BigIntToNByte(sig.S, ECCRefMaxLen)))
	return buf.String()
}

// update overwrites k in place with SM3(k). Sign uses it to derive a fresh
// nonce candidate when the current one has to be rejected. The backing array
// of k is reused, so the caller's buffer changes.
func update(k []byte) {
	hash := sm3.Sum(k)
	copy(k, hash[:])
}

// fermatInverse sets k = k^{-1} mod N using Fermat's little theorem
// (k^{N-2} mod N), which assumes N is prime.
//
// This has better constant-time properties than Euclid's method (implemented
// in math/big.Int.ModInverse), although math/big itself isn't strictly
// constant-time, so it's not perfect.
func fermatInverse(k, N *big.Int) {
	nMinus2 := new(big.Int).Sub(N, gmath.BigInt2)
	k.Exp(k, nMinus2, N)
}

// Sign produces an SM2 signature (GB/T 32918.2) over a pre-computed digest.
//
// e: sm3(Z || M), computed with PreComputeWithIdAndPubkeyAndMessage. Only
// the first byteSize bytes are used.
//
// k: at least byteSize bytes of caller-supplied randomness used as the nonce.
// Whenever a nonce must be rejected (r == 0, r + k == n, or s == 0),
// k[:byteSize] is rehashed in place with update, so the caller's buffer is
// modified. If k reduces to 0 mod n, a nonce is drawn from crypto/rand.
func Sign(e, k []byte, privateKey *PrivateKey) (*Signature, error) {
	if len(e) < byteSize {
		return nil, gerrors.WithAnnotatingf(ErrInvalidInput, "input e should be of %d bytes, but it is %d bytes", byteSize, len(e))
	}
	if len(k) < byteSize {
		return nil, gerrors.WithAnnotatingf(ErrInvalidInput, "input k should be of %d bytes, but it is %d bytes", byteSize, len(k))
	}
	if privateKey == nil || privateKey.D == nil {
		return nil, gerrors.WithAnnotating(ErrInvalidInput, "private key is nil")
	}

	// e stays fixed across retries; r must be recomputed from it each time.
	intE := new(big.Int).SetBytes(e[:byteSize])

	// (1+d)^{-1} mod n does not depend on the nonce, so compute it once.
	// The modular inversion is the most expensive step of signing.
	dInv := new(big.Int).Add(privateKey.D, gmath.BigInt1)
	if f, ok := sm2Curve.(interface{ Inverse(*big.Int) *big.Int }); ok {
		dInv = f.Inverse(dInv)
	} else {
		fermatInverse(dInv, orderN)
	}
	if dInv.Sign() == 0 {
		// 1 + d ≡ 0 (mod n): no valid signature can be produced.
		return nil, gerrors.WithAnnotating(ErrInvalidInput, "invalid private key: 1+d is not invertible mod N")
	}

	r := new(big.Int)
	s := new(big.Int)
	intK := new(big.Int)
	// Almost every call finishes in the first iteration. Each `continue`
	// rejects the current nonce; the post statement then rehashes k so the
	// next iteration tries a fresh candidate.
	for ; ; update(k[:byteSize]) {
		// k is byteSize bytes, so a single subtraction brings it below n.
		intK.SetBytes(k[:byteSize])
		if intK.Cmp(orderN) >= 0 {
			intK.Sub(intK, orderN)
		}
		if intK.Sign() == 0 {
			// The nonce must lie in [1, n-1]: draw from [0, n-2] and add one.
			// Assumes nMinusOne == n-1 (as its name says).
			rk, err := rand.Int(rand.Reader, nMinusOne)
			if err != nil {
				return nil, err
			}
			intK.Add(rk, gmath.BigInt1)
		}

		// ScalarBaseMult is constant-time; (x1, y1) = [k]G.
		x1, _ := sm2Curve.ScalarBaseMult(intK.Bytes())
		// r = (e + x1) mod n
		r.Add(intE, x1)
		r.Mod(r, orderN)
		if r.Sign() == 0 {
			continue
		}

		// k + r lies in [2, 2n-2]. The standard rejects k + r == n.
		intK.Add(intK, r)
		switch intK.Cmp(orderN) {
		case 0:
			continue
		case 1:
			intK.Sub(intK, orderN)
		}

		// s = (1+d)^{-1}(k - r*d) = (1+d)^{-1}(k + r) - r  (mod n)
		s.Mul(dInv, intK)
		s.Sub(s, r)
		s.Mod(s, orderN)
		if s.Sign() == 0 {
			continue
		}
		return &Signature{R: r, S: s}, nil
	}
}

// Verify checks an SM2 signature (GB/T 32918.2) over a pre-computed digest.
//
// pk: the public key. It is not checked for validity here; call
// pk.IsValid() separately to confirm the point is on the curve. An invalid
// pk is expected to make verification return false. A nil pk, or a pk with
// nil coordinates, returns false.
//
// e: sm3(Z || M), computed with PreComputeWithIdAndPubkeyAndMessage. It must
// be exactly byteSize bytes.
func Verify(e []byte, pk *PublicKey, sig *Signature) bool {
	if len(e) != byteSize || pk == nil || pk.X == nil || pk.Y == nil || sig == nil {
		return false
	}
	r, s := sig.R, sig.S
	if r == nil || s == nil ||
		r.Sign() <= 0 || s.Sign() <= 0 ||
		r.Cmp(orderN) >= 0 || s.Cmp(orderN) >= 0 {
		return false
	}

	// t = (r + s) mod n. The standard requires rejecting t == 0. Reducing
	// also keeps the scalar within byteSize bytes.
	t := new(big.Int).Add(r, s)
	t.Mod(t, orderN)
	if t.Sign() == 0 {
		return false
	}

	// (x1, y1) = [s]G + [t]P
	x1, _ := ec256.CombinedMult(pk.X, pk.Y, s.Bytes(), t.Bytes())

	// R = (e + x1) mod n must equal r.
	t.SetBytes(e)
	x1.Add(x1, t)
	x1.Mod(x1, orderN)
	return x1.Cmp(r) == 0
}