package x import ( "crypto/rand" "crypto/rsa" "sync" "errors" "io" "math/big" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" "xdx.jelly/xgcl/gmath" "xdx.jelly/xgcl/grand" ) // // PublicKey ::= INTEGER // // PrivateKey ::= SEQUENCE{ // lambda INTEGER, // N INTEGER // } // // Cipher ::= INTEGER type PublicKey struct { N *big.Int N2 *big.Int mu sync.Mutex // guard rn rn *big.Int } type PrivateKey struct { PublicKey lambda *big.Int // option lambdaInv *big.Int p, q *big.Int } func (k *PrivateKey) Public() *PublicKey { if k.PublicKey.N2 == nil { k.PublicKey.N2 = new(big.Int).Mul(k.PublicKey.N, k.PublicKey.N) } if k.rn == nil { k.Precompute(grand.Reader) } return &k.PublicKey } // Cipher the member variable D is for gcl's internal use. // You should never use it except you know what you're doing. type Cipher struct { D *big.Int } func newPrivateKey(p *big.Int, q *big.Int, r *big.Int, s *big.Int) *PrivateKey { sk := &PrivateKey{ PublicKey: PublicKey{ N: new(big.Int).Mul(p, q), }, lambda: new(big.Int), lambdaInv: new(big.Int), } N := sk.PublicKey.N N.Mul(N, r).Mul(N, s) pk := sk.Public() // lambda = LCM(p-1, q-1) or just (p-1)*(q-1) // lcm(sk.lambda, new(big.Int).Sub(p, gmath.BigInt1), new(big.Int).Sub(q, gmath.BigInt1)) sk.lambda.Mul(new(big.Int).Sub(p, gmath.BigInt1), new(big.Int).Sub(q, gmath.BigInt1)) sk.lambda.Mul(sk.lambda, new(big.Int).Sub(r, gmath.BigInt1)) sk.lambda.Mul(sk.lambda, new(big.Int).Sub(s, gmath.BigInt1)) sk.lambdaInv.ModInverse(sk.lambda, pk.N) return sk } // return l = lcm(a,b) func lcm(l, a, b *big.Int) { d := new(big.Int).GCD(nil, nil, a, b) l.Mul(a, b) l.Div(l, d) } // GenerateKey 生成同态密钥. bits 应为2048. // rnd_opt 可选, 忽略则使用grand.Reader. func GenerateKey(bits int, rnd_opt ...io.Reader) (*PrivateKey, *PublicKey, error) { var rnd io.Reader if len(rnd_opt) > 0 { rnd = rnd_opt[0] } else { rnd = grand.Reader } priKey := &PrivateKey{} rsaKey1, err := rsa.GenerateKey(rnd, bits) if err != nil { return nil, nil, err } rsaKey2, err := rsa.GenerateKey(rnd, bits) if err != nil { return nil, nil, err } // check rsaKey1 and rsaKey2 has no gcd > 1. priKey = newPrivateKey(rsaKey1.Primes[0], rsaKey1.Primes[1], rsaKey2.Primes[0], rsaKey2.Primes[1]) return priKey, &priKey.PublicKey, nil } func (k *PrivateKey) Decrypt(c *Cipher) (*big.Int, error) { if c.D == nil || c.D.Cmp(k.N2) >= 0 { return nil, errors.New("invalid cipher") } if k.lambdaInv == nil { k.lambdaInv = new(big.Int).ModInverse(k.lambda, k.N) } pk := k.Public() d := c.D L := new(big.Int) L.Exp(d, k.lambda, pk.N2) // L = c^lambda mod n^2 L.Sub(L, gmath.BigInt1) L.Div(L, pk.N) // L = (c^lambda - 1)/n L.Mul(L, k.lambdaInv) L.Mod(L, pk.N) return L, nil } func (k *PublicKey) Bits() int { return k.N.BitLen() } // Precompute computes r^N mod N^2 for a random r. func (k *PublicKey) Precompute(rnd io.Reader) error { if k.rn != nil { return nil } k.mu.Lock() defer k.mu.Unlock() if k.N2 == nil { k.N2 = new(big.Int).Mul(k.N, k.N) } if k.rn != nil { return nil } buf := make([]byte, k.N.BitLen()/2) _, err := rnd.Read(buf) if err != nil { return err } buf[0] |= 1 // at least one r := new(big.Int).SetBytes(buf) r.Exp(r, k.N, k.N2) // r = r^n mod n^2 k.rn = r return nil } func (k *PublicKey) Encrypt(m *big.Int, rnd io.Reader) (*Cipher, error) { if true { k.Precompute(grand.Reader) rn := new(big.Int).SetBytes(grand.GetRandom(12)) rn.Exp(k.rn, rn, k.N2) d := new(big.Int) d.Mul(k.N, m) // d = n * m d.Add(d, gmath.BigInt1) // d = nm + 1 d.Mul(d, rn) // c = (nm+1)*r^n mod n^2 d.Mod(d, k.N2) return &Cipher{D: d}, nil } else { // r should select from (1, N) var r *big.Int var err error for { r, err = rand.Int(rnd, k.N) if err != nil { return nil, err } // should we check gcd(r,n) == 1? if r.Sign() > 0 { break } } d := new(big.Int) d.Mul(k.N, m) // d = n * m d.Add(d, gmath.BigInt1) // d = nm + 1 r.Exp(r, k.N, k.N2) // r = r^n mod n^2 // r.SetUint64(1) d.Mul(d, r) // c = (nm+1)*r^n mod n^2 d.Mod(d, k.N2) return &Cipher{D: d}, nil } } func Encrypt(m *big.Int, publicKey *PublicKey, rnd io.Reader) (*Cipher, error) { return publicKey.Encrypt(m, rnd) } func Decrypt(c *Cipher, key *PrivateKey) (*big.Int, error) { return key.Decrypt(c) } func (c *Cipher) Marshal() ([]byte, error) { if c.D == nil { return nil, errors.New("empty cipher") } var b cryptobyte.Builder b.AddASN1BigInt(c.D) return b.Bytes() } func (c *Cipher) Unmarshal(b []byte) error { if c.D == nil { c.D = new(big.Int) } input := cryptobyte.String(b) if !input.ReadASN1Integer(c.D) { return errors.New("parse ASN.1 cipher failed") } return nil } // round up n to the nearest power of 2. func roundup(n uint64) uint64 { n-- n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 n |= n >> 32 return n + 1 } func (k *PublicKey) Size() uint { return uint(roundup(uint64(k.N.BitLen()))) } // Marshal 编码公钥, PublicKey ::= INTEGER func (k *PublicKey) Marshal() ([]byte, error) { if k.N == nil { return nil, errors.New("empty public key") } var b cryptobyte.Builder b.AddASN1BigInt(k.N) return b.Bytes() } func (k *PublicKey) Unmarshal(b []byte) error { if k.N == nil || k.N2 == nil { k.N = new(big.Int) k.N2 = new(big.Int) } input := cryptobyte.String(b) if !input.ReadASN1Integer(k.N) { return errors.New("parse ASN.1 public key failed") } k.N2.Mul(k.N, k.N) return nil } // MarshalExt 编码公钥,包括预计算数据, PublicKey ::= INTEGER func (k *PublicKey) MarshalExt() ([]byte, error) { if k.N == nil { return nil, errors.New("empty public key") } k.Precompute(grand.Reader) var b cryptobyte.Builder b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1BigInt(k.N) b.AddASN1BigInt(k.rn) }) return b.Bytes() } func (k *PublicKey) UnmarshalExt(b []byte) error { if k.N == nil || k.N2 == nil { k.N = new(big.Int) k.N2 = new(big.Int) } if k.rn == nil { k.rn = new(big.Int) } input := cryptobyte.String(b) var inner cryptobyte.String if !input.ReadASN1(&inner, asn1.SEQUENCE) || !inner.ReadASN1Integer(k.N) || !inner.ReadASN1Integer(k.rn) { return errors.New("read ASN.1 private key failed") } k.N2.Mul(k.N, k.N) return nil } // Marshal 编码私钥 func (k *PrivateKey) Marshal() ([]byte, error) { if k.lambda == nil { return nil, errors.New("empty private key") } var b cryptobyte.Builder b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1BigInt(k.lambda) b.AddASN1BigInt(k.N) }) return b.Bytes() } // Unmarshal 解码私钥 func (k *PrivateKey) Unmarshal(b []byte) error { if k.lambda == nil || k.N == nil { k.lambda = new(big.Int) k.N = new(big.Int) } input := cryptobyte.String(b) var inner cryptobyte.String if !input.ReadASN1(&inner, asn1.SEQUENCE) || !inner.ReadASN1Integer(k.lambda) || !inner.ReadASN1Integer(k.N) { return errors.New("read ASN.1 private key failed") } if k.lambdaInv == nil { k.lambdaInv = new(big.Int) } k.lambdaInv.ModInverse(k.lambda, k.N) _ = k.Public() return nil }