package blockmode

import (
	"crypto/cipher"
	"crypto/subtle"

	"xdx.jelly/xgcl/gerrors"
	"xdx.jelly/xgcl/internal/xor"
)

// ccm implements the CCM (Counter with CBC-MAC) AEAD mode described in
// NIST SP 800-38C. It assumes the wrapped block mode operates on
// ccmBlockSize (16-byte) blocks.
type ccm struct {
	b         EcbCbcBlockMode
	nonceSize int
	tagSize   int
}

// Compile-time check that *ccm satisfies cipher.AEAD.
var _ cipher.AEAD = (*ccm)(nil)

const (
	ccmStandardNonceSize = 12
	ccmTagSize           = 16
	ccmBlockSize         = 16
)

// NewCCM returns a CCM AEAD that uses a 12-byte nonce and a 16-byte tag.
func NewCCM(mode EcbCbcBlockMode) (cipher.AEAD, error) {
	return NewCCMWithNonceAndTagSize(mode, ccmStandardNonceSize, ccmTagSize)
}

// NewCCMWithNonceAndTagSize returns a CCM AEAD with the given nonce and tag
// sizes. nonceSize must be one of 7, 8, 9, 10, 11, 12 or 13, and tagSize must
// be one of 4, 6, 8, 10, 12, 14 or 16; otherwise an error built from
// ErrInvalidInput is returned.
//
// The nonce size also bounds the message length: the message length is
// encoded in 15-nonceSize bytes, so e.g. a 13-byte nonce limits messages to
// 65535 bytes.
//
// CCM does not support streaming (Init/Update/Final) use: the additional data
// and the whole plaintext/ciphertext must be passed in a single call.
func NewCCMWithNonceAndTagSize(mode EcbCbcBlockMode, nonceSize, tagSize int) (cipher.AEAD, error) {
	if nonceSize < 7 || nonceSize > 13 {
		return nil, gerrors.WithAnnotating(ErrInvalidInput, "nonce size must be 7, 8, 9, 10, 11, 12 or 13")
	}
	switch tagSize {
	case 4, 6, 8, 10, 12, 14, 16:
	default:
		return nil, gerrors.WithAnnotating(ErrInvalidInput, "tag size must be 4, 6, 8, 10, 12, 14 or 16")
	}
	return &ccm{
		b:         mode,
		nonceSize: nonceSize,
		tagSize:   tagSize,
	}, nil
}

// NonceSize returns the nonce length, in bytes, that Seal and Open require.
func (c *ccm) NonceSize() int {
	return c.nonceSize
}

// Overhead returns the tag length, i.e. how many bytes longer a ciphertext is
// than its plaintext.
func (c *ccm) Overhead() int {
	return c.tagSize
}

// maxPlaintextLen returns the largest message length that fits in the
// q = 15-nonceSize byte length field of B_0, clamped to the int range.
func (c *ccm) maxPlaintextLen() int {
	const maxInt = int(^uint(0) >> 1)
	q := 15 - c.nonceSize
	if q >= 8 {
		return maxInt
	}
	limit := uint64(1)<<(8*uint(q)) - 1
	if limit > uint64(maxInt) {
		return maxInt
	}
	return int(limit)
}

// Seal encrypts and authenticates plaintext, authenticates additionalData,
// and appends ciphertext||tag to dst, returning the updated slice.
//
// It panics if len(nonce) != NonceSize() or if plaintext is longer than the
// nonce size allows (see NewCCMWithNonceAndTagSize).
func (c *ccm) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
	if len(nonce) != c.nonceSize {
		panic("blockmode: incorrect nonce length given to CCM")
	}
	if len(plaintext) > c.maxPlaintextLen() {
		panic("blockmode: plaintext too large for CCM nonce size")
	}

	// CBC-MAC with a zero IV over the formatted input; the last output block
	// holds the MAC value T.
	B := format(nonce, additionalData, plaintext, c.tagSize)
	iv := make([]byte, ccmBlockSize)
	// NOTE(review): the result is discarded; B is always a whole number of
	// 16-byte blocks — confirm the block mode cannot fail for such input.
	_ = c.b.CbcEncCryptBlocks(B, iv, B)
	tag := make([]byte, c.tagSize)
	copy(tag, B[len(B)-ccmBlockSize:])

	// B holds at least 1+ceil(len(plaintext)/16) blocks, so its storage can be
	// reused for the counter blocks A_1..A_m, A_0.
	N := c.deriveCounter(B[:0], nonce, len(plaintext))
	// NOTE(review): result discarded for the same reason as above.
	_ = c.b.EcbEncCryptBlocks(N, N)

	// tag = MSB_t(T) XOR MSB_t(S_0). Computed in its own buffer so the XOR
	// never works on partially overlapping regions of N.
	xor.XorBytes(tag, tag, N[len(N)-ccmBlockSize:])
	p := xor.XorBytes(N, N, plaintext)
	copy(N[p:], tag)
	return append(dst, N[:p+c.tagSize]...)
}

// Open verifies and decrypts ciphertext (ciphertext||tag) together with
// additionalData and appends the plaintext to dst.
//
// It returns an error built from ErrInvalidInput for a wrong-length nonce or
// a malformed ciphertext length, and one built from ErrAEADTagCheckFailed when
// authentication fails. On failure the bytes written into dst's storage are
// zeroed so unauthenticated plaintext is not exposed.
func (c *ccm) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	if len(nonce) != c.nonceSize {
		return dst, gerrors.WithAnnotating(ErrInvalidInput, "incorrect nonce length given to CCM")
	}
	if len(ciphertext) < c.tagSize {
		return dst, gerrors.WithAnnotating(ErrInvalidInput, "CCM ciphertext shorter than tag")
	}
	p := len(ciphertext) - c.tagSize
	if p > c.maxPlaintextLen() {
		return dst, gerrors.WithAnnotating(ErrInvalidInput, "CCM ciphertext too large for nonce size")
	}

	ctrs := c.deriveCounter(nil, nonce, p)
	// NOTE(review): result discarded; ctrs is a whole number of 16-byte blocks.
	_ = c.b.EcbEncCryptBlocks(ctrs, ctrs)
	ret, plaintext := sliceForAppend(dst, p)
	xor.XorBytes(plaintext, ciphertext[:p], ctrs[:p])

	// Recover the sender's MAC value from the received tag: T = tag XOR S_0.
	s0 := ctrs[len(ctrs)-ccmBlockSize:]
	received := make([]byte, c.tagSize)
	xor.XorBytes(received, s0, ciphertext[p:])

	B := format(nonce, additionalData, plaintext, c.tagSize)
	iv := make([]byte, ccmBlockSize)
	// NOTE(review): result discarded; B is always a whole number of blocks.
	_ = c.b.CbcEncCryptBlocks(B, iv, B)
	expected := B[len(B)-ccmBlockSize : len(B)-ccmBlockSize+c.tagSize]

	// Constant-time comparison so the check does not leak how many tag bytes
	// matched.
	if subtle.ConstantTimeCompare(received, expected) != 1 {
		for i := range plaintext {
			plaintext[i] = 0
		}
		return dst, gerrors.WithAnnotating(ErrAEADTagCheckFailed, "CCM tag check failed")
	}
	return ret, nil
}

// format builds the CBC-MAC input B_0 || encoded len(ad) || ad || zero pad ||
// plain || zero pad for a tag of t bytes (NIST SP 800-38C, Appendix A.2).
// The result is a freshly allocated whole number of 16-byte blocks.
//
// Callers must ensure 7 <= len(nonce) <= 13 and that len(plain) fits in the
// q = 15-len(nonce) byte length field; excess high-order bytes are dropped.
func format(nonce, ad, plain []byte, t int) []byte {
	n := len(nonce)
	q := 15 - n
	a := uint64(len(ad))

	// Capacity covers B_0, the longest (10-byte) AD length prefix, both
	// payloads and two worst-case paddings, so the appends below never grow.
	B := make([]byte, ccmBlockSize, 3*ccmBlockSize+10+len(ad)+len(plain))

	// Flags byte: Adata bit, M' = (t-2)/2 in bits 3..5, L' = q-1 in bits 0..2.
	if a > 0 {
		B[0] |= 0x40
	}
	B[0] |= byte(t>>1-1)<<3 | byte(q-1)
	copy(B[1:], nonce)

	// Big-endian len(plain) in the trailing q bytes of B_0.
	lenField := B[1+n:]
	v := len(plain)
	for i := q - 1; i >= 0; i-- {
		lenField[i] = byte(v)
		v >>= 8
	}

	if a > 0 {
		switch {
		case a < 0xff00: // 2^16 - 2^8: two-byte encoding
			B = append(B, byte(a>>8), byte(a))
		case a < 1<<32:
			B = append(B, 0xff, 0xfe, byte(a>>24), byte(a>>16), byte(a>>8), byte(a))
		default:
			B = append(B, 0xff, 0xff, byte(a>>56), byte(a>>48), byte(a>>40), byte(a>>32),
				byte(a>>24), byte(a>>16), byte(a>>8), byte(a))
		}
		B = append(B, ad...)
		B = ccmPadToBlock(B)
	}

	B = append(B, plain...)
	return ccmPadToBlock(B)
}

// ccmPadToBlock appends zero bytes to b until its length is a multiple of
// ccmBlockSize.
func ccmPadToBlock(b []byte) []byte {
	if r := len(b) % ccmBlockSize; r != 0 {
		b = append(b, make([]byte, ccmBlockSize-r)...)
	}
	return b
}

// deriveCounter writes the CTR counter blocks for a p-byte message into
// counterBuf, reallocating it if its capacity is too small. The returned
// slice is (m+1)*16 bytes with m = ceil(p/16): counter blocks A_1..A_m first,
// then A_0 (counter value zero) in the final block.
//
// Callers must ensure p fits in the q = 15-len(nonce) byte counter field.
func (c *ccm) deriveCounter(counterBuf []byte, nonce []byte, p int) []byte {
	m := (p + ccmBlockSize - 1) / ccmBlockSize
	n := len(nonce)
	q := 15 - n

	size := (m + 1) * ccmBlockSize
	if cap(counterBuf) < size {
		counterBuf = make([]byte, size)
	}
	ret := counterBuf[:size]

	// Working counter block, kept in the last slot so it ends up as A_0.
	ctr := ret[m*ccmBlockSize:]
	ctr[0] = byte(q - 1)
	copy(ctr[1:], nonce)
	for i := 1 + n; i < ccmBlockSize; i++ {
		ctr[i] = 0
	}
	ctr[ccmBlockSize-1] = 1

	for i := 0; i < m; i++ {
		copy(ret[i*ccmBlockSize:], ctr)
		// Big-endian increment confined to the q-byte counter field so a
		// carry can never spill into the nonce or flags bytes.
		for j := ccmBlockSize - 1; j >= ccmBlockSize-q; j-- {
			ctr[j]++
			if ctr[j] != 0 {
				break
			}
		}
	}

	// Reset the counter field to zero, turning the working block into A_0.
	for i := ccmBlockSize - q; i < ccmBlockSize; i++ {
		ctr[i] = 0
	}
	return ret
}