172 lines
3.8 KiB
Go
172 lines
3.8 KiB
Go
// Package fpe implements Format-Preserving Encryption (FPE).
|
|
package fpe
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
)
|
|
|
|
// Numeral is one digit of a numeral string: an integer in [0, radix) that
// identifies one symbol of an Alphabet. It is an alias of uint16, so it is
// interchangeable with Char.
type Numeral = uint16
|
|
|
|
// Char is a single plaintext character code. It is 16 bits wide, so
// characters above U+FFFF cannot be represented.
type Char = uint16
|
|
|
|
// Alphabet is the interface for an alphabet.
|
|
type Alphabet interface {
|
|
// fpe实际上是可对任意的字符表进行编码,比如JPEG2000,编码范围为0xff8f的序列,且不以0xff结尾。
|
|
// 但是这里的Encode只处理utf-8的string
|
|
Encode(s string) ([]Numeral, error)
|
|
Decode(numString []Numeral) (string, error)
|
|
|
|
// 如果是二进制类的数据
|
|
// EncodeBlob(b []byte) ([]Numeral, error)
|
|
// DecodeBlob(numString []Numeral) ([]byte, error)
|
|
|
|
Radix() int
|
|
Num(X []Numeral) *big.Int
|
|
|
|
// X = Str_radix^m(C)
|
|
Str(X []Numeral, x *big.Int)
|
|
}
|
|
|
|
var (
|
|
// ASCII code 32-126, all printable characters.
|
|
Printable = NewAlphabet(" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~")
|
|
Numeric = NewAlphabet("0123456789")
|
|
Lower = NewAlphabet("abcdefghijklmnopqrstuvwxyz")
|
|
Upper = NewAlphabet("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
|
Alpha = NewAlphabet("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
|
NumAlpha = NewAlphabet("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
|
)
|
|
|
|
type FPE interface {
|
|
Encrypt(tweak []byte, X []Numeral) ([]Numeral, error)
|
|
Decrypt(tweak []byte, X []Numeral) ([]Numeral, error)
|
|
Alphabet
|
|
}
|
|
|
|
// GenericAlphabet 可以处理asicc字符集
|
|
type GenericAlphabet struct {
|
|
tblBuf [128]Char
|
|
tbl [][]Numeral
|
|
|
|
r Radix
|
|
}
|
|
|
|
// Assume the characters in alphabet are all ASCII now.
|
|
// Unicode characters not supported now.
|
|
func NewAlphabet(alphabet string) Alphabet {
|
|
r := len(alphabet)
|
|
|
|
res := &GenericAlphabet{}
|
|
res.r.Set(r)
|
|
|
|
tblBuf := res.tblBuf
|
|
for i := range tblBuf {
|
|
tblBuf[i] = 128
|
|
}
|
|
for _, d := range alphabet {
|
|
if d > 127 {
|
|
panic("only ASCII characters are supported now")
|
|
}
|
|
tblBuf[d] = Char(d)
|
|
}
|
|
tbls := make([][]Numeral, 0)
|
|
start := -1
|
|
for i, d := range tblBuf {
|
|
switch {
|
|
case d == 128:
|
|
if start >= 0 {
|
|
tbls = append(tbls, tblBuf[start:i])
|
|
start = -1
|
|
}
|
|
case d < 128:
|
|
if start < 0 {
|
|
start = i
|
|
}
|
|
}
|
|
}
|
|
res.tbl = tbls
|
|
return res
|
|
}
|
|
|
|
// numSwitch is a package-level flag that starts out false.
// NOTE(review): it is not referenced in the code visible here; confirm whether
// other files in the package use it, otherwise it can be removed.
var numSwitch bool
|
|
|
|
// Num implements Alphabet.
|
|
func (ga *GenericAlphabet) Num(X []Char) *big.Int {
|
|
return ga.r.Num(X)
|
|
}
|
|
|
|
// x should < radix^(lenX)
|
|
func (ga *GenericAlphabet) Str(X []Numeral, x *big.Int) {
|
|
ga.r.Str(X, x)
|
|
}
|
|
|
|
// FromString implements Alphabet.
|
|
func (ga *GenericAlphabet) Encode(s string) ([]Numeral, error) {
|
|
res := make([]Numeral, 0, len(s))
|
|
for _, c := range s {
|
|
if c>>16 != 0 {
|
|
return nil, fmt.Errorf("unsupported character %c", c)
|
|
}
|
|
|
|
if idx := index(Char(c), ga.tbl); idx >= 0 {
|
|
res = append(res, Numeral(idx))
|
|
} else {
|
|
return nil, errors.New("bad characters")
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// // Radix implements Alphabet.
|
|
func (ga *GenericAlphabet) Radix() int {
|
|
return int(ga.r.r)
|
|
}
|
|
|
|
// ToString implements Alphabet.
|
|
func (ga *GenericAlphabet) Decode(numString []Numeral) (string, error) {
|
|
var sb strings.Builder
|
|
radix := uint16(ga.r.r)
|
|
for _, d := range numString {
|
|
if d >= radix {
|
|
return "", errors.New("bad numeric string")
|
|
}
|
|
if _, err := sb.WriteRune(rune(char(d, ga.tbl))); err != nil {
|
|
return "", errors.New("bad rune")
|
|
}
|
|
}
|
|
return sb.String(), nil
|
|
}
|
|
|
|
// Compile-time check that *GenericAlphabet implements Alphabet.
var _ Alphabet = (*GenericAlphabet)(nil)
|
|
|
|
// tbls[i] are continuous Char
|
|
// tbls[0][0], tbls[0][1], ..., tbls[0][n0],
|
|
// tbls[1][0], tbls[1][1], ..., tbls[1][n1],
|
|
// ...
|
|
func index(c Char, tbls [][]Char) int {
|
|
n := 0
|
|
for _, tbl := range tbls {
|
|
n0 := int(c) - int(tbl[0])
|
|
if n0 >= 0 && n0 < len(tbl) {
|
|
return n + n0
|
|
}
|
|
n += len(tbl)
|
|
}
|
|
return -1
|
|
}
|
|
|
|
// Assume n < radix
|
|
func char(n Numeral, tbls [][]Char) Char {
|
|
m := int(n)
|
|
for _, tbl := range tbls {
|
|
if m < len(tbl) {
|
|
return tbl[m]
|
|
}
|
|
m -= len(tbl)
|
|
}
|
|
panic("numeric great than radix")
|
|
}
|