linux/arch/arm64/lib/csum.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2// Copyright (C) 2019-2020 Arm Ltd.
   3
   4#include <linux/compiler.h>
   5#include <linux/kasan-checks.h>
   6#include <linux/kernel.h>
   7
   8#include <net/checksum.h>
   9
  10/* Looks dumb, but generates nice-ish code */
  11static u64 accumulate(u64 sum, u64 data)
  12{
  13        __uint128_t tmp = (__uint128_t)sum + data;
  14        return tmp + (tmp >> 64);
  15}
  16
  17/*
  18 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
  19 * instrumentation and call kasan explicitly.
  20 */
  21unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
  22{
  23        unsigned int offset, shift, sum;
  24        const u64 *ptr;
  25        u64 data, sum64 = 0;
  26
  27        if (unlikely(len == 0))
  28                return 0;
  29
  30        offset = (unsigned long)buff & 7;
  31        /*
  32         * This is to all intents and purposes safe, since rounding down cannot
  33         * result in a different page or cache line being accessed, and @buff
  34         * should absolutely not be pointing to anything read-sensitive. We do,
  35         * however, have to be careful not to piss off KASAN, which means using
  36         * unchecked reads to accommodate the head and tail, for which we'll
  37         * compensate with an explicit check up-front.
  38         */
  39        kasan_check_read(buff, len);
  40        ptr = (u64 *)(buff - offset);
  41        len = len + offset - 8;
  42
  43        /*
  44         * Head: zero out any excess leading bytes. Shifting back by the same
  45         * amount should be at least as fast as any other way of handling the
  46         * odd/even alignment, and means we can ignore it until the very end.
  47         */
  48        shift = offset * 8;
  49        data = *ptr++;
  50#ifdef __LITTLE_ENDIAN
  51        data = (data >> shift) << shift;
  52#else
  53        data = (data << shift) >> shift;
  54#endif
  55
  56        /*
  57         * Body: straightforward aligned loads from here on (the paired loads
  58         * underlying the quadword type still only need dword alignment). The
  59         * main loop strictly excludes the tail, so the second loop will always
  60         * run at least once.
  61         */
  62        while (unlikely(len > 64)) {
  63                __uint128_t tmp1, tmp2, tmp3, tmp4;
  64
  65                tmp1 = *(__uint128_t *)ptr;
  66                tmp2 = *(__uint128_t *)(ptr + 2);
  67                tmp3 = *(__uint128_t *)(ptr + 4);
  68                tmp4 = *(__uint128_t *)(ptr + 6);
  69
  70                len -= 64;
  71                ptr += 8;
  72
  73                /* This is the "don't dump the carry flag into a GPR" idiom */
  74                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
  75                tmp2 += (tmp2 >> 64) | (tmp2 << 64);
  76                tmp3 += (tmp3 >> 64) | (tmp3 << 64);
  77                tmp4 += (tmp4 >> 64) | (tmp4 << 64);
  78                tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
  79                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
  80                tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
  81                tmp3 += (tmp3 >> 64) | (tmp3 << 64);
  82                tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
  83                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
  84                tmp1 = ((tmp1 >> 64) << 64) | sum64;
  85                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
  86                sum64 = tmp1 >> 64;
  87        }
  88        while (len > 8) {
  89                __uint128_t tmp;
  90
  91                sum64 = accumulate(sum64, data);
  92                tmp = *(__uint128_t *)ptr;
  93
  94                len -= 16;
  95                ptr += 2;
  96
  97#ifdef __LITTLE_ENDIAN
  98                data = tmp >> 64;
  99                sum64 = accumulate(sum64, tmp);
 100#else
 101                data = tmp;
 102                sum64 = accumulate(sum64, tmp >> 64);
 103#endif
 104        }
 105        if (len > 0) {
 106                sum64 = accumulate(sum64, data);
 107                data = *ptr;
 108                len -= 8;
 109        }
 110        /*
 111         * Tail: zero any over-read bytes similarly to the head, again
 112         * preserving odd/even alignment.
 113         */
 114        shift = len * -8;
 115#ifdef __LITTLE_ENDIAN
 116        data = (data << shift) >> shift;
 117#else
 118        data = (data >> shift) << shift;
 119#endif
 120        sum64 = accumulate(sum64, data);
 121
 122        /* Finally, folding */
 123        sum64 += (sum64 >> 32) | (sum64 << 32);
 124        sum = sum64 >> 32;
 125        sum += (sum >> 16) | (sum << 16);
 126        if (offset & 1)
 127                return (u16)swab32(sum);
 128
 129        return sum >> 16;
 130}
 131
 132__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
 133                        const struct in6_addr *daddr,
 134                        __u32 len, __u8 proto, __wsum csum)
 135{
 136        __uint128_t src, dst;
 137        u64 sum = (__force u64)csum;
 138
 139        src = *(const __uint128_t *)saddr->s6_addr;
 140        dst = *(const __uint128_t *)daddr->s6_addr;
 141
 142        sum += (__force u32)htonl(len);
 143#ifdef __LITTLE_ENDIAN
 144        sum += (u32)proto << 24;
 145#else
 146        sum += proto;
 147#endif
 148        src += (src >> 64) | (src << 64);
 149        dst += (dst >> 64) | (dst << 64);
 150
 151        sum = accumulate(sum, src >> 64);
 152        sum = accumulate(sum, dst >> 64);
 153
 154        sum += ((sum >> 32) | (sum << 32));
 155        return csum_fold((__force __wsum)(sum >> 32));
 156}
 157EXPORT_SYMBOL(csum_ipv6_magic);
 158