linux/arch/arm/crypto/aes-neonbs-glue.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Bit sliced AES using NEON instructions
   4 *
   5 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/neon.h>
   9#include <asm/simd.h>
  10#include <crypto/aes.h>
  11#include <crypto/ctr.h>
  12#include <crypto/internal/cipher.h>
  13#include <crypto/internal/simd.h>
  14#include <crypto/internal/skcipher.h>
  15#include <crypto/scatterwalk.h>
  16#include <crypto/xts.h>
  17#include <linux/module.h>
  18
  19MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
  20MODULE_LICENSE("GPL v2");
  21
  22MODULE_ALIAS_CRYPTO("ecb(aes)");
  23MODULE_ALIAS_CRYPTO("cbc(aes)-all");
  24MODULE_ALIAS_CRYPTO("ctr(aes)");
  25MODULE_ALIAS_CRYPTO("xts(aes)");
  26
  27MODULE_IMPORT_NS(CRYPTO_INTERNAL);
  28
  29asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
  30
  31asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
  32                                  int rounds, int blocks);
  33asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
  34                                  int rounds, int blocks);
  35
  36asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
  37                                  int rounds, int blocks, u8 iv[]);
  38
  39asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
  40                                  int rounds, int blocks, u8 ctr[]);
  41
  42asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
  43                                  int rounds, int blocks, u8 iv[], int);
  44asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
  45                                  int rounds, int blocks, u8 iv[], int);
  46
  47struct aesbs_ctx {
  48        int     rounds;
  49        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
  50};
  51
  52struct aesbs_cbc_ctx {
  53        struct aesbs_ctx        key;
  54        struct crypto_skcipher  *enc_tfm;
  55};
  56
  57struct aesbs_xts_ctx {
  58        struct aesbs_ctx        key;
  59        struct crypto_cipher    *cts_tfm;
  60        struct crypto_cipher    *tweak_tfm;
  61};
  62
  63struct aesbs_ctr_ctx {
  64        struct aesbs_ctx        key;            /* must be first member */
  65        struct crypto_aes_ctx   fallback;
  66};
  67
  68static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
  69                        unsigned int key_len)
  70{
  71        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
  72        struct crypto_aes_ctx rk;
  73        int err;
  74
  75        err = aes_expandkey(&rk, in_key, key_len);
  76        if (err)
  77                return err;
  78
  79        ctx->rounds = 6 + key_len / 4;
  80
  81        kernel_neon_begin();
  82        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
  83        kernel_neon_end();
  84
  85        return 0;
  86}
  87
  88static int __ecb_crypt(struct skcipher_request *req,
  89                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
  90                                  int rounds, int blocks))
  91{
  92        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  93        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
  94        struct skcipher_walk walk;
  95        int err;
  96
  97        err = skcipher_walk_virt(&walk, req, false);
  98
  99        while (walk.nbytes >= AES_BLOCK_SIZE) {
 100                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 101
 102                if (walk.nbytes < walk.total)
 103                        blocks = round_down(blocks,
 104                                            walk.stride / AES_BLOCK_SIZE);
 105
 106                kernel_neon_begin();
 107                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
 108                   ctx->rounds, blocks);
 109                kernel_neon_end();
 110                err = skcipher_walk_done(&walk,
 111                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
 112        }
 113
 114        return err;
 115}
 116
 117static int ecb_encrypt(struct skcipher_request *req)
 118{
 119        return __ecb_crypt(req, aesbs_ecb_encrypt);
 120}
 121
 122static int ecb_decrypt(struct skcipher_request *req)
 123{
 124        return __ecb_crypt(req, aesbs_ecb_decrypt);
 125}
 126
 127static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 128                            unsigned int key_len)
 129{
 130        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 131        struct crypto_aes_ctx rk;
 132        int err;
 133
 134        err = aes_expandkey(&rk, in_key, key_len);
 135        if (err)
 136                return err;
 137
 138        ctx->key.rounds = 6 + key_len / 4;
 139
 140        kernel_neon_begin();
 141        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
 142        kernel_neon_end();
 143        memzero_explicit(&rk, sizeof(rk));
 144
 145        return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
 146}
 147
 148static int cbc_encrypt(struct skcipher_request *req)
 149{
 150        struct skcipher_request *subreq = skcipher_request_ctx(req);
 151        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 152        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 153
 154        skcipher_request_set_tfm(subreq, ctx->enc_tfm);
 155        skcipher_request_set_callback(subreq,
 156                                      skcipher_request_flags(req),
 157                                      NULL, NULL);
 158        skcipher_request_set_crypt(subreq, req->src, req->dst,
 159                                   req->cryptlen, req->iv);
 160
 161        return crypto_skcipher_encrypt(subreq);
 162}
 163
 164static int cbc_decrypt(struct skcipher_request *req)
 165{
 166        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 167        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 168        struct skcipher_walk walk;
 169        int err;
 170
 171        err = skcipher_walk_virt(&walk, req, false);
 172
 173        while (walk.nbytes >= AES_BLOCK_SIZE) {
 174                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 175
 176                if (walk.nbytes < walk.total)
 177                        blocks = round_down(blocks,
 178                                            walk.stride / AES_BLOCK_SIZE);
 179
 180                kernel_neon_begin();
 181                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 182                                  ctx->key.rk, ctx->key.rounds, blocks,
 183                                  walk.iv);
 184                kernel_neon_end();
 185                err = skcipher_walk_done(&walk,
 186                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
 187        }
 188
 189        return err;
 190}
 191
 192static int cbc_init(struct crypto_skcipher *tfm)
 193{
 194        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 195        unsigned int reqsize;
 196
 197        ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
 198                                             CRYPTO_ALG_NEED_FALLBACK);
 199        if (IS_ERR(ctx->enc_tfm))
 200                return PTR_ERR(ctx->enc_tfm);
 201
 202        reqsize = sizeof(struct skcipher_request);
 203        reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
 204        crypto_skcipher_set_reqsize(tfm, reqsize);
 205
 206        return 0;
 207}
 208
 209static void cbc_exit(struct crypto_skcipher *tfm)
 210{
 211        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 212
 213        crypto_free_skcipher(ctx->enc_tfm);
 214}
 215
 216static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
 217                                 unsigned int key_len)
 218{
 219        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
 220        int err;
 221
 222        err = aes_expandkey(&ctx->fallback, in_key, key_len);
 223        if (err)
 224                return err;
 225
 226        ctx->key.rounds = 6 + key_len / 4;
 227
 228        kernel_neon_begin();
 229        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
 230        kernel_neon_end();
 231
 232        return 0;
 233}
 234
 235static int ctr_encrypt(struct skcipher_request *req)
 236{
 237        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 238        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
 239        struct skcipher_walk walk;
 240        u8 buf[AES_BLOCK_SIZE];
 241        int err;
 242
 243        err = skcipher_walk_virt(&walk, req, false);
 244
 245        while (walk.nbytes > 0) {
 246                const u8 *src = walk.src.virt.addr;
 247                u8 *dst = walk.dst.virt.addr;
 248                int bytes = walk.nbytes;
 249
 250                if (unlikely(bytes < AES_BLOCK_SIZE))
 251                        src = dst = memcpy(buf + sizeof(buf) - bytes,
 252                                           src, bytes);
 253                else if (walk.nbytes < walk.total)
 254                        bytes &= ~(8 * AES_BLOCK_SIZE - 1);
 255
 256                kernel_neon_begin();
 257                aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
 258                kernel_neon_end();
 259
 260                if (unlikely(bytes < AES_BLOCK_SIZE))
 261                        memcpy(walk.dst.virt.addr,
 262                               buf + sizeof(buf) - bytes, bytes);
 263
 264                err = skcipher_walk_done(&walk, walk.nbytes - bytes);
 265        }
 266
 267        return err;
 268}
 269
 270static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
 271{
 272        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
 273        unsigned long flags;
 274
 275        /*
 276         * Temporarily disable interrupts to avoid races where
 277         * cachelines are evicted when the CPU is interrupted
 278         * to do something else.
 279         */
 280        local_irq_save(flags);
 281        aes_encrypt(&ctx->fallback, dst, src);
 282        local_irq_restore(flags);
 283}
 284
 285static int ctr_encrypt_sync(struct skcipher_request *req)
 286{
 287        if (!crypto_simd_usable())
 288                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
 289
 290        return ctr_encrypt(req);
 291}
 292
 293static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 294                            unsigned int key_len)
 295{
 296        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 297        int err;
 298
 299        err = xts_verify_key(tfm, in_key, key_len);
 300        if (err)
 301                return err;
 302
 303        key_len /= 2;
 304        err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
 305        if (err)
 306                return err;
 307        err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
 308        if (err)
 309                return err;
 310
 311        return aesbs_setkey(tfm, in_key, key_len);
 312}
 313
 314static int xts_init(struct crypto_skcipher *tfm)
 315{
 316        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 317
 318        ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
 319        if (IS_ERR(ctx->cts_tfm))
 320                return PTR_ERR(ctx->cts_tfm);
 321
 322        ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
 323        if (IS_ERR(ctx->tweak_tfm))
 324                crypto_free_cipher(ctx->cts_tfm);
 325
 326        return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
 327}
 328
 329static void xts_exit(struct crypto_skcipher *tfm)
 330{
 331        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 332
 333        crypto_free_cipher(ctx->tweak_tfm);
 334        crypto_free_cipher(ctx->cts_tfm);
 335}
 336
 337static int __xts_crypt(struct skcipher_request *req, bool encrypt,
 338                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
 339                                  int rounds, int blocks, u8 iv[], int))
 340{
 341        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 342        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 343        int tail = req->cryptlen % AES_BLOCK_SIZE;
 344        struct skcipher_request subreq;
 345        u8 buf[2 * AES_BLOCK_SIZE];
 346        struct skcipher_walk walk;
 347        int err;
 348
 349        if (req->cryptlen < AES_BLOCK_SIZE)
 350                return -EINVAL;
 351
 352        if (unlikely(tail)) {
 353                skcipher_request_set_tfm(&subreq, tfm);
 354                skcipher_request_set_callback(&subreq,
 355                                              skcipher_request_flags(req),
 356                                              NULL, NULL);
 357                skcipher_request_set_crypt(&subreq, req->src, req->dst,
 358                                           req->cryptlen - tail, req->iv);
 359                req = &subreq;
 360        }
 361
 362        err = skcipher_walk_virt(&walk, req, true);
 363        if (err)
 364                return err;
 365
 366        crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);
 367
 368        while (walk.nbytes >= AES_BLOCK_SIZE) {
 369                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 370                int reorder_last_tweak = !encrypt && tail > 0;
 371
 372                if (walk.nbytes < walk.total) {
 373                        blocks = round_down(blocks,
 374                                            walk.stride / AES_BLOCK_SIZE);
 375                        reorder_last_tweak = 0;
 376                }
 377
 378                kernel_neon_begin();
 379                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
 380                   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
 381                kernel_neon_end();
 382                err = skcipher_walk_done(&walk,
 383                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
 384        }
 385
 386        if (err || likely(!tail))
 387                return err;
 388
 389        /* handle ciphertext stealing */
 390        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
 391                                 AES_BLOCK_SIZE, 0);
 392        memcpy(buf + AES_BLOCK_SIZE, buf, tail);
 393        scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
 394
 395        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
 396
 397        if (encrypt)
 398                crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
 399        else
 400                crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
 401
 402        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
 403
 404        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
 405                                 AES_BLOCK_SIZE + tail, 1);
 406        return 0;
 407}
 408
 409static int xts_encrypt(struct skcipher_request *req)
 410{
 411        return __xts_crypt(req, true, aesbs_xts_encrypt);
 412}
 413
 414static int xts_decrypt(struct skcipher_request *req)
 415{
 416        return __xts_crypt(req, false, aesbs_xts_decrypt);
 417}
 418
 419static struct skcipher_alg aes_algs[] = { {
 420        .base.cra_name          = "__ecb(aes)",
 421        .base.cra_driver_name   = "__ecb-aes-neonbs",
 422        .base.cra_priority      = 250,
 423        .base.cra_blocksize     = AES_BLOCK_SIZE,
 424        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
 425        .base.cra_module        = THIS_MODULE,
 426        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 427
 428        .min_keysize            = AES_MIN_KEY_SIZE,
 429        .max_keysize            = AES_MAX_KEY_SIZE,
 430        .walksize               = 8 * AES_BLOCK_SIZE,
 431        .setkey                 = aesbs_setkey,
 432        .encrypt                = ecb_encrypt,
 433        .decrypt                = ecb_decrypt,
 434}, {
 435        .base.cra_name          = "__cbc(aes)",
 436        .base.cra_driver_name   = "__cbc-aes-neonbs",
 437        .base.cra_priority      = 250,
 438        .base.cra_blocksize     = AES_BLOCK_SIZE,
 439        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
 440        .base.cra_module        = THIS_MODULE,
 441        .base.cra_flags         = CRYPTO_ALG_INTERNAL |
 442                                  CRYPTO_ALG_NEED_FALLBACK,
 443
 444        .min_keysize            = AES_MIN_KEY_SIZE,
 445        .max_keysize            = AES_MAX_KEY_SIZE,
 446        .walksize               = 8 * AES_BLOCK_SIZE,
 447        .ivsize                 = AES_BLOCK_SIZE,
 448        .setkey                 = aesbs_cbc_setkey,
 449        .encrypt                = cbc_encrypt,
 450        .decrypt                = cbc_decrypt,
 451        .init                   = cbc_init,
 452        .exit                   = cbc_exit,
 453}, {
 454        .base.cra_name          = "__ctr(aes)",
 455        .base.cra_driver_name   = "__ctr-aes-neonbs",
 456        .base.cra_priority      = 250,
 457        .base.cra_blocksize     = 1,
 458        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
 459        .base.cra_module        = THIS_MODULE,
 460        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 461
 462        .min_keysize            = AES_MIN_KEY_SIZE,
 463        .max_keysize            = AES_MAX_KEY_SIZE,
 464        .chunksize              = AES_BLOCK_SIZE,
 465        .walksize               = 8 * AES_BLOCK_SIZE,
 466        .ivsize                 = AES_BLOCK_SIZE,
 467        .setkey                 = aesbs_setkey,
 468        .encrypt                = ctr_encrypt,
 469        .decrypt                = ctr_encrypt,
 470}, {
 471        .base.cra_name          = "ctr(aes)",
 472        .base.cra_driver_name   = "ctr-aes-neonbs-sync",
 473        .base.cra_priority      = 250 - 1,
 474        .base.cra_blocksize     = 1,
 475        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
 476        .base.cra_module        = THIS_MODULE,
 477
 478        .min_keysize            = AES_MIN_KEY_SIZE,
 479        .max_keysize            = AES_MAX_KEY_SIZE,
 480        .chunksize              = AES_BLOCK_SIZE,
 481        .walksize               = 8 * AES_BLOCK_SIZE,
 482        .ivsize                 = AES_BLOCK_SIZE,
 483        .setkey                 = aesbs_ctr_setkey_sync,
 484        .encrypt                = ctr_encrypt_sync,
 485        .decrypt                = ctr_encrypt_sync,
 486}, {
 487        .base.cra_name          = "__xts(aes)",
 488        .base.cra_driver_name   = "__xts-aes-neonbs",
 489        .base.cra_priority      = 250,
 490        .base.cra_blocksize     = AES_BLOCK_SIZE,
 491        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
 492        .base.cra_module        = THIS_MODULE,
 493        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
 494
 495        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
 496        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
 497        .walksize               = 8 * AES_BLOCK_SIZE,
 498        .ivsize                 = AES_BLOCK_SIZE,
 499        .setkey                 = aesbs_xts_setkey,
 500        .encrypt                = xts_encrypt,
 501        .decrypt                = xts_decrypt,
 502        .init                   = xts_init,
 503        .exit                   = xts_exit,
 504} };
 505
 506static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
 507
 508static void aes_exit(void)
 509{
 510        int i;
 511
 512        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
 513                if (aes_simd_algs[i])
 514                        simd_skcipher_free(aes_simd_algs[i]);
 515
 516        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 517}
 518
 519static int __init aes_init(void)
 520{
 521        struct simd_skcipher_alg *simd;
 522        const char *basename;
 523        const char *algname;
 524        const char *drvname;
 525        int err;
 526        int i;
 527
 528        if (!(elf_hwcap & HWCAP_NEON))
 529                return -ENODEV;
 530
 531        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 532        if (err)
 533                return err;
 534
 535        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
 536                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
 537                        continue;
 538
 539                algname = aes_algs[i].base.cra_name + 2;
 540                drvname = aes_algs[i].base.cra_driver_name + 2;
 541                basename = aes_algs[i].base.cra_driver_name;
 542                simd = simd_skcipher_create_compat(algname, drvname, basename);
 543                err = PTR_ERR(simd);
 544                if (IS_ERR(simd))
 545                        goto unregister_simds;
 546
 547                aes_simd_algs[i] = simd;
 548        }
 549        return 0;
 550
 551unregister_simds:
 552        aes_exit();
 553        return err;
 554}
 555
 556late_initcall(aes_init);
 557module_exit(aes_exit);
 558