package bn256

import (
	"math/big"
)

// lattice bundles the data needed to split a scalar into several short
// components: a basis, per-basis-vector multipliers, and a common divisor.
type lattice struct {
	// vectors is the lattice basis, one basis vector per row.
	vectors [][]*big.Int
	// inverse holds one multiplier per basis vector; decompose multiplies the
	// scalar by each entry and divides by det to obtain rounded coefficients.
	inverse []*big.Int
	// det is the divisor applied to scalar*inverse[i] before rounding.
	det *big.Int
}

var (
	// useLattice selects whether the lattice/GLV decomposition is used to
	// accelerate computation. In a production build it could be a constant:
	//
	//	const useLattice = true
	useLattice = true

	// half is N shifted right by one bit, i.e. floor(N/2).
	half = new(big.Int).Rsh(N, 1)

	// curveLattice is the two-dimensional lattice.
	curveLattice = &lattice{
		vectors: [][]*big.Int{
			{
				bigFromBase10("287113247090025866066532163502283840641"),
				bigFromBase10("13835058055293825813"),
			},
			{
				bigFromBase10("13835058055293825813"),
				bigFromBase10("-287113247090025866052697105446990014828"),
			},
		},
		inverse: []*big.Int{
			bigFromBase10("287113247090025866052697105446990014828"),
			bigFromBase10("13835058055293825813"),
		},
		det: bigFromBase16("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"),
	}

	// targetLattice is the four-dimensional lattice; its divisor is a copy of N.
	targetLattice = &lattice{
		vectors: [][]*big.Int{
			{
				bigFromBase10("6917529027646912907"),
				bigFromBase10("6917529027646912906"),
				bigFromBase10("6917529027646912906"),
				bigFromBase10("-13835058055293825812"),
			},
			{
				bigFromBase10("13835058055293825813"),
				bigFromBase10("-6917529027646912906"),
				bigFromBase10("-6917529027646912907"),
				bigFromBase10("-6917529027646912906"),
			},
			{
				bigFromBase10("13835058055293825812"),
				bigFromBase10("13835058055293825813"),
				bigFromBase10("13835058055293825813"),
				bigFromBase10("13835058055293825813"),
			},
			{
				bigFromBase10("6917529027646912905"),
				bigFromBase10("27670116110587651626"),
				bigFromBase10("-13835058055293825811"),
				bigFromBase10("6917529027646912905"),
			},
		},
		inverse: []*big.Int{
			bigFromBase10("95704415696675288700373269546839468391"),
			bigFromBase10("3972228441934428951444310461776851119314861197558173512586"),
			bigFromBase10("1986114220967214475722155230888425559660889363292910212746"),
			bigFromBase10("-95704415696675288686538211491545642578"),
		},
		det: new(big.Int).Set(N),
	}
)

// decompose takes a scalar mod Order as input and finds a short, positive
// decomposition of it with respect to the lattice basis.
// output: out[0] + lambda*out[1] = 0 mod n // [lambda](x,y) = (eta*x, y), eta^3 = 1 // TODO out[0] and out[1] should > 0 or ? func (l *lattice) decompose(k *big.Int) []*big.Int { n := len(l.inverse) // Calculate closest vector in lattice to with Babai's rounding. c := make([]*big.Int, n) for i := 0; i < n; i++ { c[i] = new(big.Int).Mul(k, l.inverse[i]) round(c[i], l.det) } // Transform vectors according to c and subtract . out := make([]*big.Int, n) temp := new(big.Int) for i := 0; i < n; i++ { out[i] = new(big.Int) for j := 0; j < n; j++ { temp.Mul(c[j], l.vectors[j][i]) out[i].Add(out[i], temp) } out[i].Neg(out[i]) //TODO why add out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i]) } out[0].Add(out[0], k) return out } func (l *lattice) Precompute(add func(i, j uint)) { n := uint(len(l.vectors)) total := uint(1) << uint(n) for i := uint(0); i < n; i++ { for j := uint(0); j < total; j++ { if (j>>i)&1 == 1 { add(i, j) } } } } func (l *lattice) Multi(scalar *big.Int) []uint8 { decomp := l.decompose(scalar) maxLen := 0 for _, x := range decomp { if x.BitLen() > maxLen { maxLen = x.BitLen() } } out := make([]uint8, maxLen) for j, x := range decomp { for i := 0; i < maxLen; i++ { //out[i] += uint8(x.Bit(i)) << uint(j) out[i] += uint8(x.Abs(x).Bit(i)) << uint(j) } } return out } // round sets num to num/denom rounded to the nearest integer. func round(num, denom *big.Int) { r := new(big.Int) num.DivMod(num, denom, r) if r.Cmp(half) == 1 { num.Add(num, big.NewInt(1)) } }