Files
2026-05-27 23:03:00 +08:00

286 lines
7.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ssss
import (
"bytes"
"encoding/asn1"
"io"
"math/big"
"xdx.jelly/xgcl/gerrors"
"xdx.jelly/xgcl/gmath"
)
// MaxSecretLength is the largest secret, in bytes, that Split accepts:
// 1024 bits / 8 = 128 bytes. Longer secrets are rejected with ErrSecretTooLarge.
const MaxSecretLength = 1024 / 8
// SharedSlice is the data structure that carries one share of a split secret.
// Split marshals one SharedSlice per share with encoding/asn1 and Restore
// unmarshals them again.
//
//	SharedSlice ::= SEQUENCE {
//	    Version    INTEGER
//	    Threshold  INTEGER
//	    Length     INTEGER
//	    Degree     INTEGER
//	    X          INTEGER
//	    Y          INTEGER
//	    Token      [0] IMPLICIT OCTET STRING OPTIONAL
//	}
//
// The field order and the struct tags define the encoding; do not reorder.
//
// NOTE(review): the encoding/asn1 documentation says "default:x" is only used
// when "optional" is also present, so the tag on Version looks like a no-op and
// Version is presumably always encoded — confirm this is intended.
type SharedSlice struct {
Version int `asn1:"default:0"` // format version; Split always writes 0
Threshold int // threshold t of the (t, n) scheme
Length int // length of the original secret in bytes
Degree int // security parameter that Split passed to newGF2x for this share
X int // x value of the share
Y *big.Int // f(x) value of the share
Token []byte `asn1:"optional,implicit,tag:0"` // optional random value; identical across the n shares of one secret
}
// Split divides secret into nShares shares under a (threshold, nShares)
// threshold scheme. Restore needs at least threshold of the returned shares to
// recover the secret.
//
// Parameters:
//   - secret: the value to split, at most MaxSecretLength bytes; longer input
//     yields ErrSecretTooLarge.
//   - threshold, nShares: the (t, n) pair, e.g. (3, 5). They must satisfy
//     1 <= threshold <= nShares, otherwise an error is returned.
//   - securityParam: requested security parameter in bits. It is first raised
//     to len(secret)*8 if it is smaller and then rounded up to one of 128, 256,
//     512 or 1024 (anything above 512 selects 1024). Passing 0 therefore picks
//     the smallest value that fits the secret, e.g. 128 for a 16-byte secret.
//   - rand: source of randomness for the polynomial coefficients and the token.
//
// Each element of the result is the ASN.1 encoding of one SharedSlice. The
// result is ordered by the share's X value, so element i holds X = i+1. All
// shares of one call carry the same 16-byte random Token; if rand cannot supply
// 16 bytes the Token is omitted rather than failing the split.
//
// Example, splitting a 16-byte SM4 key with a (3, 5) threshold:
//
//	shares, err := Split(secret, 3, 5, 0, rand.Reader)
func Split(secret []byte, threshold, nShares int, securityParam int, rand io.Reader) ([][]byte, error) {
	if len(secret) > MaxSecretLength {
		return nil, gerrors.WithAnnotatingf(ErrSecretTooLarge, "secret is limits to %d bytes", MaxSecretLength)
	}
	// A threshold below 1 would make split index an empty (or negatively
	// sized) coefficient slice and panic; a threshold above nShares would
	// produce shares that can never be combined.
	if threshold < 1 || threshold > nShares {
		return nil, gerrors.WithStack(splitArgError("threshold must satisfy 1 <= threshold <= nShares"))
	}

	// Normalise the security parameter: never smaller than the secret, then
	// rounded up to the next supported size.
	if securityParam < len(secret)*8 {
		securityParam = len(secret) * 8
	}
	switch {
	case securityParam <= 128:
		securityParam = 128
	case securityParam <= 256:
		securityParam = 256
	case securityParam <= 512:
		securityParam = 512
	default:
		securityParam = 1024
	}

	f := newGF2x(securityParam)
	iSecret := new(big.Int).SetBytes(secret)
	iShares, err := split(rand, threshold, nShares, iSecret, f)
	if err != nil {
		return nil, err
	}

	// The token is best effort. io.ReadFull keeps reading after a short read,
	// which a plain Read is allowed to return even when more data is coming.
	token := make([]byte, 16)
	if _, err := io.ReadFull(rand, token); err != nil {
		token = nil
	}

	// split keys its result by x = 1..nShares; walking that range instead of
	// ranging over the map gives a deterministic output order.
	shares := make([][]byte, 0, nShares)
	for x := 1; x <= nShares; x++ {
		slice := SharedSlice{
			Version:   0,
			Threshold: threshold,
			Length:    len(secret),
			Degree:    securityParam,
			X:         x,
			Y:         iShares[x],
			Token:     token,
		}
		b, err := asn1.Marshal(slice)
		if err != nil {
			return nil, gerrors.WithStack(err)
		}
		shares = append(shares, b)
	}
	return shares, nil
}

// splitArgError is the error type Split uses to report unusable arguments.
type splitArgError string

// Error implements the error interface.
func (e splitArgError) Error() string { return string(e) }
// Restore 秘密恢复,输入threshold个秘密分享值,返回 Secret.
// tShares 包含至少t个分拆值。否则返回错误。
func Restore(tShares [][]byte) ([]byte, error) {
iShares := make(map[int]*big.Int)
t := -1
length := -1
var token []byte
securityParam := -1
for _, s := range tShares {
slice := new(SharedSlice)
_, err := asn1.Unmarshal(s, slice)
if err != nil {
return nil, gerrors.ChainErrors(ErrInvalidInput, err)
}
if t == -1 {
t = slice.Threshold
} else if t != slice.Threshold {
return nil, gerrors.WithAnnotating(ErrBadShares, "share slices are not compatible")
}
if length == -1 {
length = slice.Length
} else if length != slice.Length {
return nil, gerrors.WithAnnotating(ErrBadShares, "share slices are not compatible")
}
if token == nil {
token = slice.Token
} else if !bytes.Equal(token, slice.Token) {
return nil, gerrors.WithAnnotating(ErrBadShares, "share slices are not compatible")
}
if securityParam == -1 {
securityParam = slice.Degree
} else if securityParam != slice.Degree {
return nil, gerrors.WithAnnotating(ErrBadShares, "share slices are not compatible")
}
if securityParam < slice.Y.BitLen() {
return nil, gerrors.WithAnnotating(ErrBadShares, "Share slices are broken")
}
iShares[slice.X] = slice.Y
}
f := newGF2x(securityParam)
iSecret, err := restore(t, iShares, f)
if err != nil {
return nil, gerrors.ChainErrors(ErrRestoreFailed, err)
}
secret := make([]byte, length)
if length*8 < iSecret.BitLen() {
return nil, gerrors.WithAnnotating(ErrBadShares, "Share slices are broken")
}
if err = gmath.FillBytes(iSecret, secret); err != nil {
return nil, gerrors.WithAnnotating(ErrBadShares, "Share slices are broken")
}
return secret, nil
}
// horner stores in y the value at x of the polynomial
//
//	F(X) = X^t + coeff[t-1]*X^(t-1) + ... + coeff[1]*X + coeff[0]
//
// where t = len(coeff) and all arithmetic is done with f.add and f.mul.
// coeff must not be empty.
func horner(y, x *big.Int, coeff []*big.Int, f *gf2x) {
	// Seed the accumulator with x: the implicit leading coefficient 1,
	// already multiplied by x once.
	y.Set(x)
	// Fold in the remaining coefficients from the highest index down to 1.
	for k := len(coeff); k > 1; k-- {
		f.add(y, y, coeff[k-1])
		f.mul(y, y, x)
	}
	// The constant term is added without a further multiplication.
	f.add(y, y, coeff[0])
}
// split returns the points F(1), F(2), ..., F(n) keyed by their x value.
//
// F is the polynomial evaluated by horner: its constant term is secret, its
// coefficients 1..t-1 are drawn from f.rand(rand), and horner adds a leading
// X^t term, so F(0) = secret. A failure of f.rand is returned wrapped with
// gerrors.WithStack.
func split(rand io.Reader, t, n int, secret *big.Int, f *gf2x) (map[int]*big.Int, error) {
	coeff := make([]*big.Int, t)
	coeff[0] = secret
	for idx := 1; idx < t; idx++ {
		c, err := f.rand(rand)
		if err != nil {
			return nil, gerrors.WithStack(err)
		}
		coeff[idx] = c
	}

	points := make(map[int]*big.Int)
	x := new(big.Int)
	for k := 1; k <= n; k++ {
		x.SetInt64(int64(k))
		// horner overwrites its destination, so each point can be evaluated
		// straight into its own big.Int.
		fx := new(big.Int)
		horner(fx, x, coeff, f)
		points[k] = fx
	}
	return points, nil
}
// restore recovers the constant term of the polynomial from t of the given
// points. shares maps each x to F(x); when more than t points are supplied,
// the first t yielded by map iteration are used.
//
// It returns ErrNeedMoreShares when fewer than t points are available and
// passes on any error from restore_secret.
func restore(t int, shares map[int]*big.Int, f *gf2x) (*big.Int, error) {
	if t > len(shares) {
		return nil, ErrNeedMoreShares
	}

	// Allocate a t x t matrix and a t-vector, every entry a fresh zero.
	A := make([][]*big.Int, t)
	y := make([]*big.Int, t)
	for r := range A {
		A[r] = make([]*big.Int, t)
		for c := range A[r] {
			A[r][c] = new(big.Int)
		}
		y[r] = new(big.Int)
	}

	xv := new(big.Int)
	col := 0
	for x, fx := range shares {
		xv.SetInt64(int64(x))
		y[col].Set(fx)

		// Fill the column bottom-up with 1, x, x^2, ..., x^(t-1).
		A[t-1][col].SetInt64(1)
		for r := t - 2; r >= 0; r-- {
			f.mul(A[r][col], A[r+1][col], xv)
		}

		// xv becomes x^t; adding it to F(x) cancels the leading X^t term that
		// horner adds, leaving only the terms with unknown coefficients.
		f.mul(xv, xv, A[0][col])
		f.add(y[col], y[col], xv)

		col++
		if col >= t {
			break
		}
	}

	/*
		Now, for the chosen points x_0 .. x_{t-1}, A =
			x_0^{t-1}  x_1^{t-1}  ...  x_{t-1}^{t-1}
			...........
			x_0^2      x_1^2      ...  x_{t-1}^2
			x_0        x_1        ...  x_{t-1}
			1          1          ...  1
		and
			(y[0], ..., y[t-1]) = (coeff[t-1], ..., coeff[0]) * A.
		restore_secret solves for the last unknown, coeff[0] (the secret), and
		leaves it in y[t-1].
	*/
	if err := restore_secret(A, y, f); err != nil {
		return nil, err
	}
	return y[t-1], nil
}
// restore_secret solves the system y = a * A (a is a row vector of unknowns)
// for its last unknown and leaves that value in y[len(A)-1]. A and y are
// modified in place.
//
// It eliminates column by column without dividing: each later column is
// scaled by the pivot and combined with the pivot column. Only the final
// diagonal entry is inverted, through f.invert.
//
// It returns ErrSharesMayBeTheSame when no non-zero pivot can be found, and
// ErrBadShares chained with the cause when f.invert fails.
func restore_secret(A [][]*big.Int, y []*big.Int, f *gf2x) error {
	size := len(A)
	tmp := new(big.Int)
	for p := 0; p < size; p++ {
		if A[p][p].Sign() == 0 {
			// Look to the right for a column with a non-zero entry in row p.
			c := p + 1
			for c < size && A[p][c].Sign() == 0 {
				c++
			}
			if c == size {
				return ErrSharesMayBeTheSame
			}
			// Swap columns p and c; rows above p are no longer read.
			for r := p; r < size; r++ {
				A[r][p], A[r][c] = A[r][c], A[r][p]
			}
			y[p], y[c] = y[c], y[p]
		}

		pivot := A[p][p]
		for c := p + 1; c < size; c++ {
			factor := A[p][c]
			if factor.Sign() == 0 {
				continue
			}
			// column c <- column c * pivot + column p * factor, applied to
			// the rows below p and to the matching entry of y.
			for r := p + 1; r < size; r++ {
				f.mul(tmp, A[r][p], factor)
				f.mul(A[r][c], A[r][c], pivot)
				f.add(A[r][c], A[r][c], tmp)
			}
			f.mul(tmp, y[p], factor)
			f.mul(y[c], y[c], pivot)
			f.add(y[c], y[c], tmp)
		}
	}

	last := A[size-1][size-1]
	if last.Sign() == 0 {
		return ErrSharesMayBeTheSame
	}
	if err := f.invert(tmp, last); err != nil {
		return gerrors.ChainErrors(ErrBadShares, err)
	}
	f.mul(y[size-1], y[size-1], tmp)
	return nil
}