90 lines
2.1 KiB
Go
90 lines
2.1 KiB
Go
package sm2m
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
|
|
"xdx.jelly/xgcl/gmath"
|
|
"xdx.jelly/xgcl/sm/sm2"
|
|
)
|
|
|
|
// ErrIndataError is returned when the input bytes handed to a decryption
// step are malformed, for example too short to hold an encoded curve point.
var ErrIndataError = fmt.Errorf("data input error")
|
|
|
|
type ClientDecContext struct {
|
|
cipher sm2.Cipher
|
|
}
|
|
|
|
func NewClientDecContext() *ClientDecContext {
|
|
return &ClientDecContext{cipher: *sm2.NewCipher()}
|
|
}
|
|
|
|
// Initial 客户端解密,将密文中的C1发送给服务端
|
|
func (c *ClientDecContext) Initial(cipher *sm2.Cipher) (out []byte, err error) {
|
|
c.cipher.Set(cipher)
|
|
out = make([]byte, 0, 2*sm2.ByteSize())
|
|
out = append(out, gmath.BigIntToNByte(cipher.X, sm2.ByteSize())...)
|
|
out = append(out, gmath.BigIntToNByte(cipher.Y, sm2.ByteSize())...)
|
|
return out, nil
|
|
}
|
|
|
|
// Final 客户端解密得到密文
|
|
func (c *ClientDecContext) Final(clientKey *sm2.PrivateKey, in []byte) ([]byte, error) {
|
|
if len(in) < 2*sm2.ByteSize() {
|
|
return []byte{}, ErrIndataError
|
|
}
|
|
// d*C1 = dc^(-1)*ds^(-1)*C1 - C1
|
|
x := new(big.Int)
|
|
x.SetBytes(in[:sm2.ByteSize()])
|
|
y := new(big.Int)
|
|
y.SetBytes(in[sm2.ByteSize() : 2*sm2.ByteSize()])
|
|
cInv := new(big.Int)
|
|
cInv.Set(clientKey.D)
|
|
cInv.ModInverse(cInv, sm2.OrderN())
|
|
|
|
x, y = sm2.Curve256.ScalarMult(x, y, cInv.Bytes())
|
|
gx := new(big.Int)
|
|
gy := new(big.Int)
|
|
gx.Set(c.cipher.X)
|
|
gy.Set(c.cipher.Y)
|
|
gy.Sub(sm2.Prime(), gy)
|
|
x, y = sm2.Curve().Add(x, y, gx, gy)
|
|
|
|
return sm2.Decrypt_aux(x, y, &c.cipher)
|
|
}
|
|
|
|
func ServerDec(serverKey *sm2.PrivateKey, in []byte) ([]byte, error) {
|
|
return ServerImportKey(serverKey, in)
|
|
}
|
|
|
|
// 将加密密钥拆分为服务端和客户端的密钥分量
|
|
func SplitDecryptKey(de *sm2.PrivateKey, rnd io.Reader) (*sm2.PrivateKey, *sm2.PrivateKey, error) {
|
|
dc := new(big.Int).Add(de.D, gmath.BigInt1)
|
|
if dc.Cmp(sm2.OrderN()) == 0 {
|
|
return &sm2.PrivateKey{D: big.NewInt(0)}, &sm2.PrivateKey{D: big.NewInt(0)}, nil
|
|
}
|
|
|
|
var ds *big.Int
|
|
var err error
|
|
for {
|
|
ds, err = rand.Int(rnd, sm2.OrderN())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if ds.Sign() > 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
dc.Mul(dc, ds)
|
|
dc.Mod(dc, sm2.OrderN())
|
|
dc.ModInverse(dc, sm2.OrderN())
|
|
|
|
skc := &sm2.PrivateKey{D: dc}
|
|
sks := &sm2.PrivateKey{D: ds}
|
|
return skc, sks, nil
|
|
}
|