linux/arch/arm64/crypto/aes-glue.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
   4 *
   5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/neon.h>
   9#include <asm/hwcap.h>
  10#include <asm/simd.h>
  11#include <crypto/aes.h>
  12#include <crypto/ctr.h>
  13#include <crypto/sha2.h>
  14#include <crypto/internal/hash.h>
  15#include <crypto/internal/simd.h>
  16#include <crypto/internal/skcipher.h>
  17#include <crypto/scatterwalk.h>
  18#include <linux/module.h>
  19#include <linux/cpufeature.h>
  20#include <crypto/xts.h>
  21
  22#include "aes-ce-setkey.h"
  23
   /*
    * Build-flavour selection. This file is compiled either with
    * USE_V8_CRYPTO_EXTENSIONS defined ("ce" driver names, priority 300) or
    * without it ("neon" driver names, priority 200). The generic aes_* names
    * used by the glue code below are remapped to the matching ce_* or neon_*
    * assembler entry points.
    * Only the CE flavour remaps aes_expandkey. The NEON flavour uses whatever
    * aes_expandkey the included headers declare (presumably the generic AES
    * library from <crypto/aes.h> -- TODO confirm).
    */
   24#ifdef USE_V8_CRYPTO_EXTENSIONS
   25#define MODE                    "ce"
   26#define PRIO                    300
   27#define aes_expandkey           ce_aes_expandkey
   28#define aes_ecb_encrypt         ce_aes_ecb_encrypt
   29#define aes_ecb_decrypt         ce_aes_ecb_decrypt
   30#define aes_cbc_encrypt         ce_aes_cbc_encrypt
   31#define aes_cbc_decrypt         ce_aes_cbc_decrypt
   32#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
   33#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
   34#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
   35#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
   36#define aes_ctr_encrypt         ce_aes_ctr_encrypt
   37#define aes_xctr_encrypt        ce_aes_xctr_encrypt
   38#define aes_xts_encrypt         ce_aes_xts_encrypt
   39#define aes_xts_decrypt         ce_aes_xts_decrypt
   40#define aes_mac_update          ce_aes_mac_update
   41MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
   42#else
   43#define MODE                    "neon"
   44#define PRIO                    200
   45#define aes_ecb_encrypt         neon_aes_ecb_encrypt
   46#define aes_ecb_decrypt         neon_aes_ecb_decrypt
   47#define aes_cbc_encrypt         neon_aes_cbc_encrypt
   48#define aes_cbc_decrypt         neon_aes_cbc_decrypt
   49#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
   50#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
   51#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
   52#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
   53#define aes_ctr_encrypt         neon_aes_ctr_encrypt
   54#define aes_xctr_encrypt        neon_aes_xctr_encrypt
   55#define aes_xts_encrypt         neon_aes_xts_encrypt
   56#define aes_xts_decrypt         neon_aes_xts_decrypt
   57#define aes_mac_update          neon_aes_mac_update
   58MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
   59#endif
   /*
    * The plain ecb/cbc/ctr/xts/xctr aliases are emitted under the same
    * condition that guards the corresponding entries of aes_algs[] below.
    * NOTE(review): when that condition is false, these modes are presumably
    * provided by the bit-sliced driver (CONFIG_CRYPTO_AES_ARM64_BS) -- confirm.
    */
   60#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
   61MODULE_ALIAS_CRYPTO("ecb(aes)");
   62MODULE_ALIAS_CRYPTO("cbc(aes)");
   63MODULE_ALIAS_CRYPTO("ctr(aes)");
   64MODULE_ALIAS_CRYPTO("xts(aes)");
   65MODULE_ALIAS_CRYPTO("xctr(aes)");
   66#endif
   /* Aliases for the modes that are registered in every build flavour. */
   67MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
   68MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
   69MODULE_ALIAS_CRYPTO("cmac(aes)");
   70MODULE_ALIAS_CRYPTO("xcbc(aes)");
   71MODULE_ALIAS_CRYPTO("cbcmac(aes)");
   72
   73MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
   74MODULE_LICENSE("GPL v2");
  75
   76/* defined in aes-modes.S */
   /*
    * Conventions as used by the glue code in this file:
    *  - @rounds is always computed by the callers as 6 + key_length / 4.
    *  - @blocks counts whole AES_BLOCK_SIZE blocks.
    *  - @bytes is a byte count that may include a trailing partial block
    *    (see the ctr, xctr, xts and cts callers).
    *  - @iv / @ctr are non-const for the CBC, CTR, XCTR, XTS and ESSIV
    *    routines. They are presumably updated in place with the chaining
    *    value for the next call -- TODO confirm against aes-modes.S. The CTS
    *    routines take a const IV.
    */
   77asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
   78                                int rounds, int blocks);
   79asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
   80                                int rounds, int blocks);
   81
   82asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
   83                                int rounds, int blocks, u8 iv[]);
   84asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
   85                                int rounds, int blocks, u8 iv[]);
   86
   /* Final 17..32 bytes of a cts(cbc(aes)) request, in a single call. */
   87asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
   88                                int rounds, int bytes, u8 const iv[]);
   89asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
   90                                int rounds, int bytes, u8 const iv[]);
   91
   92asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
   93                                int rounds, int bytes, u8 ctr[]);
   94
   /* @byte_ctr: number of bytes already processed in this request. */
   95asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
   96                                 int rounds, int bytes, u8 ctr[], int byte_ctr);
   97
   /*
    * @rk1: data key schedule. @rk2: tweak key schedule (always the encryption
    * schedule, even for decryption). @first: 1 on the first call for a
    * request, 0 afterwards.
    */
   98asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
   99                                int rounds, int bytes, u32 const rk2[], u8 iv[],
  100                                int first);
  101asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
  102                                int rounds, int bytes, u32 const rk2[], u8 iv[],
  103                                int first);
  104
   /*
    * @rk1: data key schedule. @rk2: encryption schedule of the key derived
    * from the SHA-256 digest of the user key (see essiv_cbc_set_key()).
    */
  105asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
  106                                      int rounds, int blocks, u8 iv[],
  107                                      u32 const rk2[]);
  108asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
  109                                      int rounds, int blocks, u8 iv[],
  110                                      u32 const rk2[]);
  111
   /*
    * Used by the MAC (cmac/xcbc/cbcmac) code, which is not shown here.
    * NOTE(review): the meaning of enc_before/enc_after and of the return
    * value is not visible in this section -- check the callers.
    */
  112asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
  113                              int blocks, u8 dg[], int enc_before,
  114                              int enc_after);
 115
  /* Per-tfm context for xts(aes); both halves are filled by xts_set_key(). */
  116struct crypto_aes_xts_ctx {
  117        struct crypto_aes_ctx key1;                /* data key: first half of the XTS key */
  118        struct crypto_aes_ctx __aligned(8) key2;   /* tweak key: second half of the XTS key */
  119};
  120
  /* Per-tfm context for essiv(cbc(aes),sha256). */
  121struct crypto_aes_essiv_cbc_ctx {
  122        struct crypto_aes_ctx key1;                /* expanded directly from the user key */
  123        struct crypto_aes_ctx __aligned(8) key2;   /* expanded from the SHA-256 digest of the user key */
  124        struct crypto_shash *hash;                 /* "sha256" shash, owned by init/exit_tfm */
  125};
  126
  /*
   * Per-tfm context for the MAC algorithms (used outside this section).
   * NOTE(review): consts[] is a flexible array; its size is presumably
   * reserved through the MAC algs' cra_ctxsize -- confirm.
   */
  127struct mac_tfm_ctx {
  128        struct crypto_aes_ctx key;
  129        u8 __aligned(8) consts[];
  130};
  131
  /*
   * Per-request MAC state (used outside this section).
   * Presumably @len is the buffered byte count and @dg the running
   * digest -- TODO confirm.
   */
  132struct mac_desc_ctx {
  133        unsigned int len;
  134        u8 dg[AES_BLOCK_SIZE];
  135};
 136
 137static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 138                               unsigned int key_len)
 139{
 140        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 141
 142        return aes_expandkey(ctx, in_key, key_len);
 143}
 144
 145static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
 146                                      const u8 *in_key, unsigned int key_len)
 147{
 148        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 149        int ret;
 150
 151        ret = xts_verify_key(tfm, in_key, key_len);
 152        if (ret)
 153                return ret;
 154
 155        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
 156        if (!ret)
 157                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
 158                                    key_len / 2);
 159        return ret;
 160}
 161
 162static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
 163                                            const u8 *in_key,
 164                                            unsigned int key_len)
 165{
 166        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 167        u8 digest[SHA256_DIGEST_SIZE];
 168        int ret;
 169
 170        ret = aes_expandkey(&ctx->key1, in_key, key_len);
 171        if (ret)
 172                return ret;
 173
 174        crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
 175
 176        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
 177}
 178
 179static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
 180{
 181        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 182        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 183        int err, rounds = 6 + ctx->key_length / 4;
 184        struct skcipher_walk walk;
 185        unsigned int blocks;
 186
 187        err = skcipher_walk_virt(&walk, req, false);
 188
 189        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 190                kernel_neon_begin();
 191                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 192                                ctx->key_enc, rounds, blocks);
 193                kernel_neon_end();
 194                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 195        }
 196        return err;
 197}
 198
 199static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
 200{
 201        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 202        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 203        int err, rounds = 6 + ctx->key_length / 4;
 204        struct skcipher_walk walk;
 205        unsigned int blocks;
 206
 207        err = skcipher_walk_virt(&walk, req, false);
 208
 209        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 210                kernel_neon_begin();
 211                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 212                                ctx->key_dec, rounds, blocks);
 213                kernel_neon_end();
 214                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 215        }
 216        return err;
 217}
 218
 219static int cbc_encrypt_walk(struct skcipher_request *req,
 220                            struct skcipher_walk *walk)
 221{
 222        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 223        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 224        int err = 0, rounds = 6 + ctx->key_length / 4;
 225        unsigned int blocks;
 226
 227        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 228                kernel_neon_begin();
 229                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
 230                                ctx->key_enc, rounds, blocks, walk->iv);
 231                kernel_neon_end();
 232                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 233        }
 234        return err;
 235}
 236
 237static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
 238{
 239        struct skcipher_walk walk;
 240        int err;
 241
 242        err = skcipher_walk_virt(&walk, req, false);
 243        if (err)
 244                return err;
 245        return cbc_encrypt_walk(req, &walk);
 246}
 247
 248static int cbc_decrypt_walk(struct skcipher_request *req,
 249                            struct skcipher_walk *walk)
 250{
 251        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 252        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 253        int err = 0, rounds = 6 + ctx->key_length / 4;
 254        unsigned int blocks;
 255
 256        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 257                kernel_neon_begin();
 258                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
 259                                ctx->key_dec, rounds, blocks, walk->iv);
 260                kernel_neon_end();
 261                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 262        }
 263        return err;
 264}
 265
 266static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
 267{
 268        struct skcipher_walk walk;
 269        int err;
 270
 271        err = skcipher_walk_virt(&walk, req, false);
 272        if (err)
 273                return err;
 274        return cbc_decrypt_walk(req, &walk);
 275}
 276
  /*
   * cts_cbc_encrypt - ->encrypt() for cts(cbc(aes)).
   *
   * Requests shorter than AES_BLOCK_SIZE are rejected with -EINVAL. A request
   * of exactly one block is handled as plain CBC. Otherwise the work is split
   * into two phases that share req->iv:
   *  1. all but the final two (possibly partial) blocks are CBC-encrypted
   *     through a subrequest with cbc_encrypt_walk();
   *  2. the remaining 17..32 bytes are passed to aes_cbc_cts_encrypt() in a
   *     single call.
   * Phase 2 processes walk.nbytes at once and then finishes the walk with
   * skcipher_walk_done(&walk, 0). It therefore assumes the walk maps the whole
   * tail in one step.
   * NOTE(review): presumably guaranteed by ".walksize = 2 * AES_BLOCK_SIZE"
   * on the cts-cbc-aes entry of aes_algs[] -- confirm.
   *
   * Return: 0, -EINVAL for requests shorter than one block, or an error code
   * from the walk / CBC helpers.
   */
  277static int cts_cbc_encrypt(struct skcipher_request *req)
  278{
  279        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  280        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
  281        int err, rounds = 6 + ctx->key_length / 4;
             /* blocks for phase 1; <= 0 when cryptlen <= 2 * AES_BLOCK_SIZE */
  282        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
  283        struct scatterlist *src = req->src, *dst = req->dst;
  284        struct scatterlist sg_src[2], sg_dst[2];
  285        struct skcipher_request subreq;
  286        struct skcipher_walk walk;
  287
  288        skcipher_request_set_tfm(&subreq, tfm);
  289        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
  290                                      NULL, NULL);
  291
  292        if (req->cryptlen <= AES_BLOCK_SIZE) {
  293                if (req->cryptlen < AES_BLOCK_SIZE)
  294                        return -EINVAL;
                     /* exactly one block: plain CBC, no stealing needed */
  295                cbc_blocks = 1;
  296        }
  297
  298        if (cbc_blocks > 0) {
  299                skcipher_request_set_crypt(&subreq, req->src, req->dst,
  300                                           cbc_blocks * AES_BLOCK_SIZE,
  301                                           req->iv);
  302
  303                err = skcipher_walk_virt(&walk, &subreq, false) ?:
  304                      cbc_encrypt_walk(&subreq, &walk);
  305                if (err)
  306                        return err;
  307
  308                if (req->cryptlen == AES_BLOCK_SIZE)
  309                        return 0;
  310
                     /* skip the CBC-processed prefix; in-place requests reuse one sg */
  311                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
  312                if (req->dst != req->src)
  313                        dst = scatterwalk_ffwd(sg_dst, req->dst,
  314                                               subreq.cryptlen);
  315        }
  316
  317        /* handle ciphertext stealing */
  318        skcipher_request_set_crypt(&subreq, src, dst,
  319                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
  320                                   req->iv);
  321
  322        err = skcipher_walk_virt(&walk, &subreq, false);
  323        if (err)
  324                return err;
  325
  326        kernel_neon_begin();
  327        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
  328                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
  329        kernel_neon_end();
  330
  331        return skcipher_walk_done(&walk, 0);
  332}
 333
  /*
   * cts_cbc_decrypt - ->decrypt() for cts(cbc(aes)).
   *
   * Mirror of cts_cbc_encrypt() using the decryption key schedule:
   *  - requests shorter than one block are rejected with -EINVAL;
   *  - exactly one block is plain CBC;
   *  - otherwise all but the final two (possibly partial) blocks are
   *    CBC-decrypted with cbc_decrypt_walk(), and the remaining 17..32 bytes
   *    go to aes_cbc_cts_decrypt() in a single call that assumes the walk maps
   *    the whole tail at once.
   *
   * Return: 0, -EINVAL for requests shorter than one block, or an error code
   * from the walk / CBC helpers.
   */
  334static int cts_cbc_decrypt(struct skcipher_request *req)
  335{
  336        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  337        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
  338        int err, rounds = 6 + ctx->key_length / 4;
             /* blocks for the CBC phase; <= 0 when cryptlen <= 2 * AES_BLOCK_SIZE */
  339        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
  340        struct scatterlist *src = req->src, *dst = req->dst;
  341        struct scatterlist sg_src[2], sg_dst[2];
  342        struct skcipher_request subreq;
  343        struct skcipher_walk walk;
  344
  345        skcipher_request_set_tfm(&subreq, tfm);
  346        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
  347                                      NULL, NULL);
  348
  349        if (req->cryptlen <= AES_BLOCK_SIZE) {
  350                if (req->cryptlen < AES_BLOCK_SIZE)
  351                        return -EINVAL;
                     /* exactly one block: plain CBC, no stealing needed */
  352                cbc_blocks = 1;
  353        }
  354
  355        if (cbc_blocks > 0) {
  356                skcipher_request_set_crypt(&subreq, req->src, req->dst,
  357                                           cbc_blocks * AES_BLOCK_SIZE,
  358                                           req->iv);
  359
  360                err = skcipher_walk_virt(&walk, &subreq, false) ?:
  361                      cbc_decrypt_walk(&subreq, &walk);
  362                if (err)
  363                        return err;
  364
  365                if (req->cryptlen == AES_BLOCK_SIZE)
  366                        return 0;
  367
                     /* skip the CBC-processed prefix; in-place requests reuse one sg */
  368                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
  369                if (req->dst != req->src)
  370                        dst = scatterwalk_ffwd(sg_dst, req->dst,
  371                                               subreq.cryptlen);
  372        }
  373
  374        /* handle ciphertext stealing */
  375        skcipher_request_set_crypt(&subreq, src, dst,
  376                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
  377                                   req->iv);
  378
  379        err = skcipher_walk_virt(&walk, &subreq, false);
  380        if (err)
  381                return err;
  382
  383        kernel_neon_begin();
  384        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
  385                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
  386        kernel_neon_end();
  387
  388        return skcipher_walk_done(&walk, 0);
  389}
 390
 391static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
 392{
 393        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 394
 395        ctx->hash = crypto_alloc_shash("sha256", 0, 0);
 396
 397        return PTR_ERR_OR_ZERO(ctx->hash);
 398}
 399
 400static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
 401{
 402        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 403
 404        crypto_free_shash(ctx->hash);
 405}
 406
 407static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
 408{
 409        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 410        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 411        int err, rounds = 6 + ctx->key1.key_length / 4;
 412        struct skcipher_walk walk;
 413        unsigned int blocks;
 414
 415        err = skcipher_walk_virt(&walk, req, false);
 416
 417        blocks = walk.nbytes / AES_BLOCK_SIZE;
 418        if (blocks) {
 419                kernel_neon_begin();
 420                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 421                                      ctx->key1.key_enc, rounds, blocks,
 422                                      req->iv, ctx->key2.key_enc);
 423                kernel_neon_end();
 424                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 425        }
 426        return err ?: cbc_encrypt_walk(req, &walk);
 427}
 428
 429static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
 430{
 431        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 432        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 433        int err, rounds = 6 + ctx->key1.key_length / 4;
 434        struct skcipher_walk walk;
 435        unsigned int blocks;
 436
 437        err = skcipher_walk_virt(&walk, req, false);
 438
 439        blocks = walk.nbytes / AES_BLOCK_SIZE;
 440        if (blocks) {
 441                kernel_neon_begin();
 442                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 443                                      ctx->key1.key_dec, rounds, blocks,
 444                                      req->iv, ctx->key2.key_enc);
 445                kernel_neon_end();
 446                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 447        }
 448        return err ?: cbc_decrypt_walk(req, &walk);
 449}
 450
 451static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
 452{
 453        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 454        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 455        int err, rounds = 6 + ctx->key_length / 4;
 456        struct skcipher_walk walk;
 457        unsigned int byte_ctr = 0;
 458
 459        err = skcipher_walk_virt(&walk, req, false);
 460
 461        while (walk.nbytes > 0) {
 462                const u8 *src = walk.src.virt.addr;
 463                unsigned int nbytes = walk.nbytes;
 464                u8 *dst = walk.dst.virt.addr;
 465                u8 buf[AES_BLOCK_SIZE];
 466
 467                /*
 468                 * If given less than 16 bytes, we must copy the partial block
 469                 * into a temporary buffer of 16 bytes to avoid out of bounds
 470                 * reads and writes.  Furthermore, this code is somewhat unusual
 471                 * in that it expects the end of the data to be at the end of
 472                 * the temporary buffer, rather than the start of the data at
 473                 * the start of the temporary buffer.
 474                 */
 475                if (unlikely(nbytes < AES_BLOCK_SIZE))
 476                        src = dst = memcpy(buf + sizeof(buf) - nbytes,
 477                                           src, nbytes);
 478                else if (nbytes < walk.total)
 479                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 480
 481                kernel_neon_begin();
 482                aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
 483                                                 walk.iv, byte_ctr);
 484                kernel_neon_end();
 485
 486                if (unlikely(nbytes < AES_BLOCK_SIZE))
 487                        memcpy(walk.dst.virt.addr,
 488                               buf + sizeof(buf) - nbytes, nbytes);
 489                byte_ctr += nbytes;
 490
 491                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
 492        }
 493
 494        return err;
 495}
 496
 497static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
 498{
 499        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 500        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 501        int err, rounds = 6 + ctx->key_length / 4;
 502        struct skcipher_walk walk;
 503
 504        err = skcipher_walk_virt(&walk, req, false);
 505
 506        while (walk.nbytes > 0) {
 507                const u8 *src = walk.src.virt.addr;
 508                unsigned int nbytes = walk.nbytes;
 509                u8 *dst = walk.dst.virt.addr;
 510                u8 buf[AES_BLOCK_SIZE];
 511
 512                /*
 513                 * If given less than 16 bytes, we must copy the partial block
 514                 * into a temporary buffer of 16 bytes to avoid out of bounds
 515                 * reads and writes.  Furthermore, this code is somewhat unusual
 516                 * in that it expects the end of the data to be at the end of
 517                 * the temporary buffer, rather than the start of the data at
 518                 * the start of the temporary buffer.
 519                 */
 520                if (unlikely(nbytes < AES_BLOCK_SIZE))
 521                        src = dst = memcpy(buf + sizeof(buf) - nbytes,
 522                                           src, nbytes);
 523                else if (nbytes < walk.total)
 524                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 525
 526                kernel_neon_begin();
 527                aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
 528                                walk.iv);
 529                kernel_neon_end();
 530
 531                if (unlikely(nbytes < AES_BLOCK_SIZE))
 532                        memcpy(walk.dst.virt.addr,
 533                               buf + sizeof(buf) - nbytes, nbytes);
 534
 535                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
 536        }
 537
 538        return err;
 539}
 540
  /*
   * xts_encrypt - ->encrypt() for xts(aes), including ciphertext stealing.
   *
   * Requests shorter than AES_BLOCK_SIZE are rejected with -EINVAL.
   *
   * A partial tail needs special handling only when the first walk step does
   * not cover the whole request. In that case the walk is aborted and
   * restarted on a subrequest covering all but the final two (possibly
   * partial) blocks. The final AES_BLOCK_SIZE + tail bytes are then
   * processed with a single aes_xts_encrypt() call.
   *
   * In every other case, tail is cleared and the whole request goes through
   * the main loop. Its last step passes the unrounded byte count, including
   * any partial tail, to the assembler.
   *
   * @first is 1 only for the first aes_xts_encrypt() call of the request. It
   * remains 1 for the tail call if the main loop never ran.
   *
   * Return: 0, -EINVAL for too-short requests, or an error code from the
   * skcipher walk helpers.
   */
  541static int __maybe_unused xts_encrypt(struct skcipher_request *req)
  542{
  543        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  544        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
  545        int err, first, rounds = 6 + ctx->key1.key_length / 4;
  546        int tail = req->cryptlen % AES_BLOCK_SIZE;
  547        struct scatterlist sg_src[2], sg_dst[2];
  548        struct skcipher_request subreq;
  549        struct scatterlist *src, *dst;
  550        struct skcipher_walk walk;
  551
  552        if (req->cryptlen < AES_BLOCK_SIZE)
  553                return -EINVAL;
  554
  555        err = skcipher_walk_virt(&walk, req, false);
  556
  557        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
  558                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
  559                                              AES_BLOCK_SIZE) - 2;
  560
  561                skcipher_walk_abort(&walk);
  562
  563                skcipher_request_set_tfm(&subreq, tfm);
  564                skcipher_request_set_callback(&subreq,
  565                                              skcipher_request_flags(req),
  566                                              NULL, NULL);
  567                skcipher_request_set_crypt(&subreq, req->src, req->dst,
  568                                           xts_blocks * AES_BLOCK_SIZE,
  569                                           req->iv);
                     /* from here on, req aliases subreq */
  570                req = &subreq;
  571                err = skcipher_walk_virt(&walk, req, false);
  572        } else {
  573                tail = 0;
  574        }
  575
  576        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
  577                int nbytes = walk.nbytes;
  578
                     /* round down to whole blocks except on the final walk step */
  579                if (walk.nbytes < walk.total)
  580                        nbytes &= ~(AES_BLOCK_SIZE - 1);
  581
  582                kernel_neon_begin();
  583                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
  584                                ctx->key1.key_enc, rounds, nbytes,
  585                                ctx->key2.key_enc, walk.iv, first);
  586                kernel_neon_end();
  587                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
  588        }
  589
  590        if (err || likely(!tail))
  591                return err;
  592
             /* req->cryptlen is now the bulk length, so this skips past it */
  593        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
  594        if (req->dst != req->src)
  595                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
  596
  597        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
  598                                   req->iv);
  599
             /* tail != 0 only after the redirect above, so req == &subreq here */
  600        err = skcipher_walk_virt(&walk, &subreq, false);
  601        if (err)
  602                return err;
  603
  604        kernel_neon_begin();
  605        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
  606                        ctx->key1.key_enc, rounds, walk.nbytes,
  607                        ctx->key2.key_enc, walk.iv, first);
  608        kernel_neon_end();
  609
  610        return skcipher_walk_done(&walk, 0);
  611}
 612
  /*
   * xts_decrypt - ->decrypt() for xts(aes), including ciphertext stealing.
   *
   * Mirror of xts_encrypt(). key1 supplies the decryption schedule; the
   * tweak key (key2) is still passed as its encryption schedule.
   *
   * A partial tail needs special handling only when the first walk step does
   * not cover the whole request. In that case all but the final two
   * (possibly partial) blocks are processed through a subrequest. The final
   * AES_BLOCK_SIZE + tail bytes are then handled in one aes_xts_decrypt()
   * call.
   *
   * Return: 0, -EINVAL for too-short requests, or an error code from the
   * skcipher walk helpers.
   */
  613static int __maybe_unused xts_decrypt(struct skcipher_request *req)
  614{
  615        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  616        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
  617        int err, first, rounds = 6 + ctx->key1.key_length / 4;
  618        int tail = req->cryptlen % AES_BLOCK_SIZE;
  619        struct scatterlist sg_src[2], sg_dst[2];
  620        struct skcipher_request subreq;
  621        struct scatterlist *src, *dst;
  622        struct skcipher_walk walk;
  623
  624        if (req->cryptlen < AES_BLOCK_SIZE)
  625                return -EINVAL;
  626
  627        err = skcipher_walk_virt(&walk, req, false);
  628
  629        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
  630                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
  631                                              AES_BLOCK_SIZE) - 2;
  632
  633                skcipher_walk_abort(&walk);
  634
  635                skcipher_request_set_tfm(&subreq, tfm);
  636                skcipher_request_set_callback(&subreq,
  637                                              skcipher_request_flags(req),
  638                                              NULL, NULL);
  639                skcipher_request_set_crypt(&subreq, req->src, req->dst,
  640                                           xts_blocks * AES_BLOCK_SIZE,
  641                                           req->iv);
                     /* from here on, req aliases subreq */
  642                req = &subreq;
  643                err = skcipher_walk_virt(&walk, req, false);
  644        } else {
  645                tail = 0;
  646        }
  647
  648        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
  649                int nbytes = walk.nbytes;
  650
                     /* round down to whole blocks except on the final walk step */
  651                if (walk.nbytes < walk.total)
  652                        nbytes &= ~(AES_BLOCK_SIZE - 1);
  653
  654                kernel_neon_begin();
  655                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
  656                                ctx->key1.key_dec, rounds, nbytes,
  657                                ctx->key2.key_enc, walk.iv, first);
  658                kernel_neon_end();
  659                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
  660        }
  661
  662        if (err || likely(!tail))
  663                return err;
  664
             /* req->cryptlen is now the bulk length, so this skips past it */
  665        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
  666        if (req->dst != req->src)
  667                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
  668
  669        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
  670                                   req->iv);
  671
             /* tail != 0 only after the redirect above, so req == &subreq here */
  672        err = skcipher_walk_virt(&walk, &subreq, false);
  673        if (err)
  674                return err;
  675
  676
  677        kernel_neon_begin();
  678        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
  679                        ctx->key1.key_dec, rounds, walk.nbytes,
  680                        ctx->key2.key_enc, walk.iv, first);
  681        kernel_neon_end();
  682
  683        return skcipher_walk_done(&walk, 0);
  684}
 685
 686static struct skcipher_alg aes_algs[] = { {
 687#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
 688        .base = {
 689                .cra_name               = "ecb(aes)",
 690                .cra_driver_name        = "ecb-aes-" MODE,
 691                .cra_priority           = PRIO,
 692                .cra_blocksize          = AES_BLOCK_SIZE,
 693                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 694                .cra_module             = THIS_MODULE,
 695        },
 696        .min_keysize    = AES_MIN_KEY_SIZE,
 697        .max_keysize    = AES_MAX_KEY_SIZE,
 698        .setkey         = skcipher_aes_setkey,
 699        .encrypt        = ecb_encrypt,
 700        .decrypt        = ecb_decrypt,
 701}, {
 702        .base = {
 703                .cra_name               = "cbc(aes)",
 704                .cra_driver_name        = "cbc-aes-" MODE,
 705                .cra_priority           = PRIO,
 706                .cra_blocksize          = AES_BLOCK_SIZE,
 707                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 708                .cra_module             = THIS_MODULE,
 709        },
 710        .min_keysize    = AES_MIN_KEY_SIZE,
 711        .max_keysize    = AES_MAX_KEY_SIZE,
 712        .ivsize         = AES_BLOCK_SIZE,
 713        .setkey         = skcipher_aes_setkey,
 714        .encrypt        = cbc_encrypt,
 715        .decrypt        = cbc_decrypt,
 716}, {
 717        .base = {
 718                .cra_name               = "ctr(aes)",
 719                .cra_driver_name        = "ctr-aes-" MODE,
 720                .cra_priority           = PRIO,
 721                .cra_blocksize          = 1,
 722                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 723                .cra_module             = THIS_MODULE,
 724        },
 725        .min_keysize    = AES_MIN_KEY_SIZE,
 726        .max_keysize    = AES_MAX_KEY_SIZE,
 727        .ivsize         = AES_BLOCK_SIZE,
 728        .chunksize      = AES_BLOCK_SIZE,
 729        .setkey         = skcipher_aes_setkey,
 730        .encrypt        = ctr_encrypt,
 731        .decrypt        = ctr_encrypt,
 732}, {
 733        .base = {
 734                .cra_name               = "xctr(aes)",
 735                .cra_driver_name        = "xctr-aes-" MODE,
 736                .cra_priority           = PRIO,
 737                .cra_blocksize          = 1,
 738                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 739                .cra_module             = THIS_MODULE,
 740        },
 741        .min_keysize    = AES_MIN_KEY_SIZE,
 742        .max_keysize    = AES_MAX_KEY_SIZE,
 743        .ivsize         = AES_BLOCK_SIZE,
 744        .chunksize      = AES_BLOCK_SIZE,
 745        .setkey         = skcipher_aes_setkey,
 746        .encrypt        = xctr_encrypt,
 747        .decrypt        = xctr_encrypt,
 748}, {
 749        .base = {
 750                .cra_name               = "xts(aes)",
 751                .cra_driver_name        = "xts-aes-" MODE,
 752                .cra_priority           = PRIO,
 753                .cra_blocksize          = AES_BLOCK_SIZE,
 754                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
 755                .cra_module             = THIS_MODULE,
 756        },
 757        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
 758        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
 759        .ivsize         = AES_BLOCK_SIZE,
 760        .walksize       = 2 * AES_BLOCK_SIZE,
 761        .setkey         = xts_set_key,
 762        .encrypt        = xts_encrypt,
 763        .decrypt        = xts_decrypt,
 764}, {
 765#endif
 766        .base = {
 767                .cra_name               = "cts(cbc(aes))",
 768                .cra_driver_name        = "cts-cbc-aes-" MODE,
 769                .cra_priority           = PRIO,
 770                .cra_blocksize          = AES_BLOCK_SIZE,
 771                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 772                .cra_module             = THIS_MODULE,
 773        },
 774        .min_keysize    = AES_MIN_KEY_SIZE,
 775        .max_keysize    = AES_MAX_KEY_SIZE,
 776        .ivsize         = AES_BLOCK_SIZE,
 777        .walksize       = 2 * AES_BLOCK_SIZE,
 778        .setkey         = skcipher_aes_setkey,
 779        .encrypt        = cts_cbc_encrypt,
 780        .decrypt        = cts_cbc_decrypt,
 781}, {
 782        .base = {
 783                .cra_name               = "essiv(cbc(aes),sha256)",
 784                .cra_driver_name        = "essiv-cbc-aes-sha256-" MODE,
 785                .cra_priority           = PRIO + 1,
 786                .cra_blocksize          = AES_BLOCK_SIZE,
 787                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
 788                .cra_module             = THIS_MODULE,
 789        },
 790        .min_keysize    = AES_MIN_KEY_SIZE,
 791        .max_keysize    = AES_MAX_KEY_SIZE,
 792        .ivsize         = AES_BLOCK_SIZE,
 793        .setkey         = essiv_cbc_set_key,
 794        .encrypt        = essiv_cbc_encrypt,
 795        .decrypt        = essiv_cbc_decrypt,
 796        .init           = essiv_cbc_init_tfm,
 797        .exit           = essiv_cbc_exit_tfm,
 798} };
 799
 800static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 801                         unsigned int key_len)
 802{
 803        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 804
 805        return aes_expandkey(&ctx->key, in_key, key_len);
 806}
 807
 808static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
 809{
 810        u64 a = be64_to_cpu(x->a);
 811        u64 b = be64_to_cpu(x->b);
 812
 813        y->a = cpu_to_be64((a << 1) | (b >> 63));
 814        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
 815}
 816
 817static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 818                       unsigned int key_len)
 819{
 820        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 821        be128 *consts = (be128 *)ctx->consts;
 822        int rounds = 6 + key_len / 4;
 823        int err;
 824
 825        err = cbcmac_setkey(tfm, in_key, key_len);
 826        if (err)
 827                return err;
 828
 829        /* encrypt the zero vector */
 830        kernel_neon_begin();
 831        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
 832                        rounds, 1);
 833        kernel_neon_end();
 834
 835        cmac_gf128_mul_by_x(consts, consts);
 836        cmac_gf128_mul_by_x(consts + 1, consts);
 837
 838        return 0;
 839}
 840
 841static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
 842                       unsigned int key_len)
 843{
 844        static u8 const ks[3][AES_BLOCK_SIZE] = {
 845                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
 846                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
 847                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
 848        };
 849
 850        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 851        int rounds = 6 + key_len / 4;
 852        u8 key[AES_BLOCK_SIZE];
 853        int err;
 854
 855        err = cbcmac_setkey(tfm, in_key, key_len);
 856        if (err)
 857                return err;
 858
 859        kernel_neon_begin();
 860        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
 861        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
 862        kernel_neon_end();
 863
 864        return cbcmac_setkey(tfm, key, sizeof(key));
 865}
 866
 867static int mac_init(struct shash_desc *desc)
 868{
 869        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 870
 871        memset(ctx->dg, 0, AES_BLOCK_SIZE);
 872        ctx->len = 0;
 873
 874        return 0;
 875}
 876
 877static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
 878                          u8 dg[], int enc_before, int enc_after)
 879{
 880        int rounds = 6 + ctx->key_length / 4;
 881
 882        if (crypto_simd_usable()) {
 883                int rem;
 884
 885                do {
 886                        kernel_neon_begin();
 887                        rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
 888                                             dg, enc_before, enc_after);
 889                        kernel_neon_end();
 890                        in += (blocks - rem) * AES_BLOCK_SIZE;
 891                        blocks = rem;
 892                        enc_before = 0;
 893                } while (blocks);
 894        } else {
 895                if (enc_before)
 896                        aes_encrypt(ctx, dg, dg);
 897
 898                while (blocks--) {
 899                        crypto_xor(dg, in, AES_BLOCK_SIZE);
 900                        in += AES_BLOCK_SIZE;
 901
 902                        if (blocks || enc_after)
 903                                aes_encrypt(ctx, dg, dg);
 904                }
 905        }
 906}
 907
 908static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
 909{
 910        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 911        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 912
 913        while (len > 0) {
 914                unsigned int l;
 915
 916                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
 917                    (ctx->len + len) > AES_BLOCK_SIZE) {
 918
 919                        int blocks = len / AES_BLOCK_SIZE;
 920
 921                        len %= AES_BLOCK_SIZE;
 922
 923                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
 924                                      (ctx->len != 0), (len != 0));
 925
 926                        p += blocks * AES_BLOCK_SIZE;
 927
 928                        if (!len) {
 929                                ctx->len = AES_BLOCK_SIZE;
 930                                break;
 931                        }
 932                        ctx->len = 0;
 933                }
 934
 935                l = min(len, AES_BLOCK_SIZE - ctx->len);
 936
 937                if (l <= AES_BLOCK_SIZE) {
 938                        crypto_xor(ctx->dg + ctx->len, p, l);
 939                        ctx->len += l;
 940                        len -= l;
 941                        p += l;
 942                }
 943        }
 944
 945        return 0;
 946}
 947
 948static int cbcmac_final(struct shash_desc *desc, u8 *out)
 949{
 950        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 951        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 952
 953        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
 954
 955        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 956
 957        return 0;
 958}
 959
 960static int cmac_final(struct shash_desc *desc, u8 *out)
 961{
 962        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 963        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 964        u8 *consts = tctx->consts;
 965
 966        if (ctx->len != AES_BLOCK_SIZE) {
 967                ctx->dg[ctx->len] ^= 0x80;
 968                consts += AES_BLOCK_SIZE;
 969        }
 970
 971        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
 972
 973        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 974
 975        return 0;
 976}
 977
 978static struct shash_alg mac_algs[] = { {
 979        .base.cra_name          = "cmac(aes)",
 980        .base.cra_driver_name   = "cmac-aes-" MODE,
 981        .base.cra_priority      = PRIO,
 982        .base.cra_blocksize     = AES_BLOCK_SIZE,
 983        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 984                                  2 * AES_BLOCK_SIZE,
 985        .base.cra_module        = THIS_MODULE,
 986
 987        .digestsize             = AES_BLOCK_SIZE,
 988        .init                   = mac_init,
 989        .update                 = mac_update,
 990        .final                  = cmac_final,
 991        .setkey                 = cmac_setkey,
 992        .descsize               = sizeof(struct mac_desc_ctx),
 993}, {
 994        .base.cra_name          = "xcbc(aes)",
 995        .base.cra_driver_name   = "xcbc-aes-" MODE,
 996        .base.cra_priority      = PRIO,
 997        .base.cra_blocksize     = AES_BLOCK_SIZE,
 998        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 999                                  2 * AES_BLOCK_SIZE,
1000        .base.cra_module        = THIS_MODULE,
1001
1002        .digestsize             = AES_BLOCK_SIZE,
1003        .init                   = mac_init,
1004        .update                 = mac_update,
1005        .final                  = cmac_final,
1006        .setkey                 = xcbc_setkey,
1007        .descsize               = sizeof(struct mac_desc_ctx),
1008}, {
1009        .base.cra_name          = "cbcmac(aes)",
1010        .base.cra_driver_name   = "cbcmac-aes-" MODE,
1011        .base.cra_priority      = PRIO,
1012        .base.cra_blocksize     = 1,
1013        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
1014        .base.cra_module        = THIS_MODULE,
1015
1016        .digestsize             = AES_BLOCK_SIZE,
1017        .init                   = mac_init,
1018        .update                 = mac_update,
1019        .final                  = cbcmac_final,
1020        .setkey                 = cbcmac_setkey,
1021        .descsize               = sizeof(struct mac_desc_ctx),
1022} };
1023
/*
 * Module teardown: unregister everything aes_init() registered, MACs first
 * and skciphers second (the reverse of the registration order in aes_init()).
 */
static void aes_exit(void)
{
	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
1029
1030static int __init aes_init(void)
1031{
1032        int err;
1033
1034        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1035        if (err)
1036                return err;
1037
1038        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1039        if (err)
1040                goto unregister_ciphers;
1041
1042        return 0;
1043
1044unregister_ciphers:
1045        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1046        return err;
1047}
1048
#ifdef USE_V8_CRYPTO_EXTENSIONS
/*
 * Crypto Extensions build: hook aes_init() through module_cpu_feature_match()
 * so loading is tied to the CPU's AES feature (see <linux/cpufeature.h>).
 */
module_cpu_feature_match(AES, aes_init);
#else
/* Plain NEON build: aes_init() runs unconditionally at module load. */
module_init(aes_init);
/*
 * Export the NEON mode helpers for use by other modules.
 * NOTE(review): the consumers are not visible here (presumably other arm64
 * AES drivers); confirm before trimming this list.
 */
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);
1060