[PATCH 2/2] Simplify AES key schedule implementation

Jussi Kivilinna jussi.kivilinna at iki.fi
Wed Jul 27 22:16:32 CEST 2022


* cipher/rijndael-armv8-ce.c (_gcry_aes_armv8_ce_setkey): New key
schedule with simplified structure and less stack usage.
* cipher/rijndael-internal.h (RIJNDAEL_context_s): Add
'keyschedule32b'.
(keyschenc32b): New.
* cipher/rijndael-ppc-common.h (vec_u32): New.
* cipher/rijndael-ppc.c (vec_bswap32_const): Remove.
(_gcry_aes_sbox4_ppc8): Optimize for fewer emitted instructions.
(keysched_idx): New.
(_gcry_aes_ppc8_setkey): New key schedule with simplified structure.
* cipher/rijndael-tables.h (rcon): Remove.
* cipher/rijndael.c (sbox4): New.
(do_setkey): New key schedule with simplified structure and less
stack usage.
--

Signed-off-by: Jussi Kivilinna <jussi.kivilinna at iki.fi>
---
 cipher/rijndael-armv8-ce.c   | 102 +++++-----------------
 cipher/rijndael-internal.h   |  12 +--
 cipher/rijndael-ppc-common.h |   1 +
 cipher/rijndael-ppc.c        | 158 +++++++++++------------------------
 cipher/rijndael-tables.h     |   7 --
 cipher/rijndael.c            | 118 ++++++++------------------
 6 files changed, 117 insertions(+), 281 deletions(-)

diff --git a/cipher/rijndael-armv8-ce.c b/cipher/rijndael-armv8-ce.c
index e53c940e..10fb58be 100644
--- a/cipher/rijndael-armv8-ce.c
+++ b/cipher/rijndael-armv8-ce.c
@@ -128,103 +128,47 @@ typedef void (*xts_crypt_fn_t) (const void *keysched, unsigned char *outbuf,
                                 unsigned char *tweak, size_t nblocks,
                                 unsigned int nrounds);
 
+
 void
 _gcry_aes_armv8_ce_setkey (RIJNDAEL_context *ctx, const byte *key)
 {
-  union
-    {
-      PROPERLY_ALIGNED_TYPE dummy;
-      byte data[MAXKC][4];
-      u32 data32[MAXKC];
-    } tkk[2];
   unsigned int rounds = ctx->rounds;
-  int KC = rounds - 6;
-  unsigned int keylen = KC * 4;
-  unsigned int i, r, t;
+  unsigned int KC = rounds - 6;
+  u32 *W_u32 = ctx->keyschenc32b;
+  unsigned int i, j;
+  u32 W_prev;
   byte rcon = 1;
-  int j;
-#define k      tkk[0].data
-#define k_u32  tkk[0].data32
-#define tk     tkk[1].data
-#define tk_u32 tkk[1].data32
-#define W      (ctx->keyschenc)
-#define W_u32  (ctx->keyschenc32)
-
-  for (i = 0; i < keylen; i++)
-    {
-      k[i >> 2][i & 3] = key[i];
-    }
 
-  for (j = KC-1; j >= 0; j--)
+  for (i = 0; i < KC; i += 2)
     {
-      tk_u32[j] = k_u32[j];
-    }
-  r = 0;
-  t = 0;
-  /* Copy values into round key array.  */
-  for (j = 0; (j < KC) && (r < rounds + 1); )
-    {
-      for (; (j < KC) && (t < 4); j++, t++)
-        {
-          W_u32[r][t] = le_bswap32(tk_u32[j]);
-        }
-      if (t == 4)
-        {
-          r++;
-          t = 0;
-        }
+      W_u32[i + 0] = buf_get_le32(key + i * 4 + 0);
+      W_u32[i + 1] = buf_get_le32(key + i * 4 + 4);
     }
 
-  while (r < rounds + 1)
+  for (i = KC, j = KC, W_prev = W_u32[KC - 1];
+       i < 4 * (rounds + 1);
+       i += 2, j += 2)
     {
-      tk_u32[0] ^= _gcry_aes_sbox4_armv8_ce(rol(tk_u32[KC - 1], 24)) ^ rcon;
+      u32 temp0 = W_prev;
+      u32 temp1;
 
-      if (KC != 8)
+      if (j == KC)
         {
-          for (j = 1; j < KC; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
+          j = 0;
+          temp0 = _gcry_aes_sbox4_armv8_ce(rol(temp0, 24)) ^ rcon;
+          rcon = ((rcon << 1) ^ (-(rcon >> 7) & 0x1b)) & 0xff;
         }
-      else
+      else if (KC == 8 && j == 4)
         {
-          for (j = 1; j < KC/2; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
-
-          tk_u32[KC/2] ^= _gcry_aes_sbox4_armv8_ce(tk_u32[KC/2 - 1]);
-
-          for (j = KC/2 + 1; j < KC; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
+          temp0 = _gcry_aes_sbox4_armv8_ce(temp0);
         }
 
-      /* Copy values into round key array.  */
-      for (j = 0; (j < KC) && (r < rounds + 1); )
-        {
-          for (; (j < KC) && (t < 4); j++, t++)
-            {
-              W_u32[r][t] = le_bswap32(tk_u32[j]);
-            }
-          if (t == 4)
-            {
-              r++;
-              t = 0;
-            }
-        }
+      temp1 = W_u32[i - KC + 0];
 
-      rcon = (rcon << 1) ^ ((rcon >> 7) * 0x1b);
+      W_u32[i + 0] = temp0 ^ temp1;
+      W_u32[i + 1] = W_u32[i - KC + 1] ^ temp0 ^ temp1;
+      W_prev = W_u32[i + 1];
     }
-
-#undef W
-#undef tk
-#undef k
-#undef W_u32
-#undef tk_u32
-#undef k_u32
-  wipememory(&tkk, sizeof(tkk));
 }
 
 /* Make a decryption key from an encryption key. */
diff --git a/cipher/rijndael-internal.h b/cipher/rijndael-internal.h
index 30604088..52c892fd 100644
--- a/cipher/rijndael-internal.h
+++ b/cipher/rijndael-internal.h
@@ -160,6 +160,7 @@ typedef struct RIJNDAEL_context_s
     PROPERLY_ALIGNED_TYPE dummy;
     byte keyschedule[MAXROUNDS+1][4][4];
     u32 keyschedule32[MAXROUNDS+1][4];
+    u32 keyschedule32b[(MAXROUNDS+1)*4];
 #ifdef USE_PADLOCK
     /* The key as passed to the padlock engine.  It is only used if
        the padlock engine is used (USE_PADLOCK, below).  */
@@ -195,10 +196,11 @@ typedef struct RIJNDAEL_context_s
 } RIJNDAEL_context ATTR_ALIGNED_16;
 
 /* Macros defining alias for the keyschedules.  */
-#define keyschenc   u1.keyschedule
-#define keyschenc32 u1.keyschedule32
-#define keyschdec   u2.keyschedule
-#define keyschdec32 u2.keyschedule32
-#define padlockkey  u1.padlock_key
+#define keyschenc     u1.keyschedule
+#define keyschenc32   u1.keyschedule32
+#define keyschenc32b  u1.keyschedule32b
+#define keyschdec     u2.keyschedule
+#define keyschdec32   u2.keyschedule32
+#define padlockkey    u1.padlock_key
 
 #endif /* G10_RIJNDAEL_INTERNAL_H */
diff --git a/cipher/rijndael-ppc-common.h b/cipher/rijndael-ppc-common.h
index 3fa9a0b9..e4a90934 100644
--- a/cipher/rijndael-ppc-common.h
+++ b/cipher/rijndael-ppc-common.h
@@ -30,6 +30,7 @@
 
 
 typedef vector unsigned char block;
+typedef vector unsigned int vec_u32;
 
 typedef union
 {
diff --git a/cipher/rijndael-ppc.c b/cipher/rijndael-ppc.c
index f5c32361..6a32271d 100644
--- a/cipher/rijndael-ppc.c
+++ b/cipher/rijndael-ppc.c
@@ -34,10 +34,7 @@
 #include "rijndael-ppc-common.h"
 
 
-#ifdef WORDS_BIGENDIAN
-static const block vec_bswap32_const =
-  { 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12 };
-#else
+#ifndef WORDS_BIGENDIAN
 static const block vec_bswap32_const_neg =
   { ~3, ~2, ~1, ~0, ~7, ~6, ~5, ~4, ~11, ~10, ~9, ~8, ~15, ~14, ~13, ~12 };
 #endif
@@ -107,134 +104,81 @@ asm_store_be_noswap(block vec, unsigned long offset, void *ptr)
 static ASM_FUNC_ATTR_INLINE u32
 _gcry_aes_sbox4_ppc8(u32 fourbytes)
 {
-  union
-    {
-      PROPERLY_ALIGNED_TYPE dummy;
-      block data_vec;
-      u32 data32[4];
-    } u;
+  vec_u32 vec_fourbyte = { fourbytes, fourbytes, fourbytes, fourbytes };
+#ifdef WORDS_BIGENDIAN
+  return ((vec_u32)vec_sbox_be((block)vec_fourbyte))[1];
+#else
+  return ((vec_u32)vec_sbox_be((block)vec_fourbyte))[2];
+#endif
+}
+
 
-  u.data32[0] = fourbytes;
-  u.data_vec = vec_sbox_be(u.data_vec);
-  return u.data32[0];
+static ASM_FUNC_ATTR_INLINE unsigned int
+keysched_idx(unsigned int in)
+{
+#ifdef WORDS_BIGENDIAN
+  return in;
+#else
+  return (in & ~3U) | (3U - (in & 3U));
+#endif
 }
 
+
 void
 _gcry_aes_ppc8_setkey (RIJNDAEL_context *ctx, const byte *key)
 {
-  const block bige_const = asm_load_be_const();
-  union
-    {
-      PROPERLY_ALIGNED_TYPE dummy;
-      byte data[MAXKC][4];
-      u32 data32[MAXKC];
-    } tkk[2];
+  u32 tk_u32[MAXKC];
   unsigned int rounds = ctx->rounds;
-  int KC = rounds - 6;
-  unsigned int keylen = KC * 4;
-  u128_t *ekey = (u128_t *)(void *)ctx->keyschenc;
-  unsigned int i, r, t;
+  unsigned int KC = rounds - 6;
+  u32 *W_u32 = ctx->keyschenc32b;
+  unsigned int i, j;
+  u32 tk_prev;
   byte rcon = 1;
-  int j;
-#define k      tkk[0].data
-#define k_u32  tkk[0].data32
-#define tk     tkk[1].data
-#define tk_u32 tkk[1].data32
-#define W      (ctx->keyschenc)
-#define W_u32  (ctx->keyschenc32)
 
-  for (i = 0; i < keylen; i++)
+  for (i = 0; i < KC; i += 2)
     {
-      k[i >> 2][i & 3] = key[i];
+      unsigned int idx0 = keysched_idx(i + 0);
+      unsigned int idx1 = keysched_idx(i + 1);
+      tk_u32[i + 0] = buf_get_le32(key + i * 4 + 0);
+      tk_u32[i + 1] = buf_get_le32(key + i * 4 + 4);
+      W_u32[idx0] = _gcry_bswap32(tk_u32[i + 0]);
+      W_u32[idx1] = _gcry_bswap32(tk_u32[i + 1]);
     }
 
-  for (j = KC-1; j >= 0; j--)
-    {
-      tk_u32[j] = k_u32[j];
-    }
-  r = 0;
-  t = 0;
-  /* Copy values into round key array.  */
-  for (j = 0; (j < KC) && (r < rounds + 1); )
+  for (i = KC, j = KC, tk_prev = tk_u32[KC - 1];
+       i < 4 * (rounds + 1);
+       i += 2, j += 2)
     {
-      for (; (j < KC) && (t < 4); j++, t++)
-        {
-          W_u32[r][t] = le_bswap32(tk_u32[j]);
-        }
-      if (t == 4)
-        {
-          r++;
-          t = 0;
-        }
-    }
-  while (r < rounds + 1)
-    {
-      tk_u32[0] ^=
-	le_bswap32(
-	  _gcry_aes_sbox4_ppc8(rol(le_bswap32(tk_u32[KC - 1]), 24)) ^ rcon);
+      unsigned int idx0 = keysched_idx(i + 0);
+      unsigned int idx1 = keysched_idx(i + 1);
+      u32 temp0 = tk_prev;
+      u32 temp1;
 
-      if (KC != 8)
+      if (j == KC)
         {
-          for (j = 1; j < KC; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
+          j = 0;
+          temp0 = _gcry_aes_sbox4_ppc8(rol(temp0, 24)) ^ rcon;
+          rcon = ((rcon << 1) ^ (-(rcon >> 7) & 0x1b)) & 0xff;
         }
-      else
+      else if (KC == 8 && j == 4)
         {
-          for (j = 1; j < KC/2; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
-
-          tk_u32[KC/2] ^=
-	    le_bswap32(_gcry_aes_sbox4_ppc8(le_bswap32(tk_u32[KC/2 - 1])));
-
-          for (j = KC/2 + 1; j < KC; j++)
-            {
-              tk_u32[j] ^= tk_u32[j-1];
-            }
+          temp0 = _gcry_aes_sbox4_ppc8(temp0);
         }
 
-      /* Copy values into round key array.  */
-      for (j = 0; (j < KC) && (r < rounds + 1); )
-        {
-          for (; (j < KC) && (t < 4); j++, t++)
-            {
-              W_u32[r][t] = le_bswap32(tk_u32[j]);
-            }
-          if (t == 4)
-            {
-              r++;
-              t = 0;
-            }
-        }
+      temp1 = tk_u32[j + 0];
 
-      rcon = (rcon << 1) ^ (-(rcon >> 7) & 0x1b);
-    }
+      tk_u32[j + 0] = temp0 ^ temp1;
+      tk_u32[j + 1] ^= temp0 ^ temp1;
+      tk_prev = tk_u32[j + 1];
 
-  /* Store in big-endian order. */
-  for (r = 0; r <= rounds; r++)
-    {
-#ifndef WORDS_BIGENDIAN
-      VEC_STORE_BE(ekey, r, ALIGNED_LOAD (ekey, r), bige_const);
-#else
-      block rvec = ALIGNED_LOAD (ekey, r);
-      ALIGNED_STORE (ekey, r,
-                     vec_perm(rvec, rvec, vec_bswap32_const));
-      (void)bige_const;
-#endif
+      W_u32[idx0] = _gcry_bswap32(tk_u32[j + 0]);
+      W_u32[idx1] = _gcry_bswap32(tk_u32[j + 1]);
     }
 
-#undef W
-#undef tk
-#undef k
-#undef W_u32
-#undef tk_u32
-#undef k_u32
-  wipememory(&tkk, sizeof(tkk));
+  wipememory(tk_u32, sizeof(tk_u32));
 }
 
+
 void
 _gcry_aes_ppc8_prepare_decryption (RIJNDAEL_context *ctx)
 {
diff --git a/cipher/rijndael-tables.h b/cipher/rijndael-tables.h
index b54d9593..e46ce08c 100644
--- a/cipher/rijndael-tables.h
+++ b/cipher/rijndael-tables.h
@@ -218,10 +218,3 @@ static struct
 
 #define decT dec_tables.T
 #define inv_sbox dec_tables.inv_sbox
-
-static const u32 rcon[30] =
-  {
-    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c,
-    0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35,
-    0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91
-  };
diff --git a/cipher/rijndael.c b/cipher/rijndael.c
index 7e75ddd2..f3060ea5 100644
--- a/cipher/rijndael.c
+++ b/cipher/rijndael.c
@@ -422,6 +422,17 @@ static void prefetch_dec(void)
 
 
 
+static inline u32
+sbox4(u32 inb4)
+{
+  u32 out;
+  out =  (encT[(inb4 >> 0) & 0xffU] & 0xff00U) >> 8;
+  out |= (encT[(inb4 >> 8) & 0xffU] & 0xff00U) >> 0;
+  out |= (encT[(inb4 >> 16) & 0xffU] & 0xff0000U) << 0;
+  out |= (encT[(inb4 >> 24) & 0xffU] & 0xff0000U) << 8;
+  return out;
+}
+
 /* Perform the key setup.  */
 static gcry_err_code_t
 do_setkey (RIJNDAEL_context *ctx, const byte *key, const unsigned keylen,
@@ -431,8 +442,7 @@ do_setkey (RIJNDAEL_context *ctx, const byte *key, const unsigned keylen,
   static const char *selftest_failed = 0;
   void (*hw_setkey)(RIJNDAEL_context *ctx, const byte *key) = NULL;
   int rounds;
-  int i,j, r, t, rconpointer = 0;
-  int KC;
+  unsigned int KC;
   unsigned int hwfeatures;
 
   /* The on-the-fly self tests are only run in non-fips mode. In fips
@@ -662,101 +672,43 @@ do_setkey (RIJNDAEL_context *ctx, const byte *key, const unsigned keylen,
     }
   else
     {
-      const byte *sbox = ((const byte *)encT) + 1;
-      union
-        {
-          PROPERLY_ALIGNED_TYPE dummy;
-          byte data[MAXKC][4];
-          u32 data32[MAXKC];
-        } tkk[2];
-#define k      tkk[0].data
-#define k_u32  tkk[0].data32
-#define tk     tkk[1].data
-#define tk_u32 tkk[1].data32
-#define W      (ctx->keyschenc)
-#define W_u32  (ctx->keyschenc32)
+      u32 W_prev;
+      u32 *W_u32 = ctx->keyschenc32b;
+      byte rcon = 1;
+      unsigned int i, j;
 
       prefetch_enc();
 
-      for (i = 0; i < keylen; i++)
+      for (i = 0; i < KC; i += 2)
         {
-          k[i >> 2][i & 3] = key[i];
+          W_u32[i + 0] = buf_get_le32(key + i * 4 + 0);
+          W_u32[i + 1] = buf_get_le32(key + i * 4 + 4);
         }
 
-      for (j = KC-1; j >= 0; j--)
+      for (i = KC, j = KC, W_prev = W_u32[KC - 1];
+           i < 4 * (rounds + 1);
+           i += 2, j += 2)
         {
-          tk_u32[j] = k_u32[j];
-        }
-      r = 0;
-      t = 0;
-      /* Copy values into round key array.  */
-      for (j = 0; (j < KC) && (r < rounds + 1); )
-        {
-          for (; (j < KC) && (t < 4); j++, t++)
-            {
-              W_u32[r][t] = le_bswap32(tk_u32[j]);
-            }
-          if (t == 4)
-            {
-              r++;
-              t = 0;
-            }
-        }
+          u32 temp0 = W_prev;
+          u32 temp1;
 
-      while (r < rounds + 1)
-        {
-          /* While not enough round key material calculated calculate
-             new values.  */
-          tk[0][0] ^= sbox[tk[KC-1][1] * 4];
-          tk[0][1] ^= sbox[tk[KC-1][2] * 4];
-          tk[0][2] ^= sbox[tk[KC-1][3] * 4];
-          tk[0][3] ^= sbox[tk[KC-1][0] * 4];
-          tk[0][0] ^= rcon[rconpointer++];
-
-          if (KC != 8)
+          if (j == KC)
             {
-              for (j = 1; j < KC; j++)
-                {
-                  tk_u32[j] ^= tk_u32[j-1];
-                }
+              j = 0;
+              temp0 = sbox4(rol(temp0, 24)) ^ rcon;
+              rcon = ((rcon << 1) ^ (-(rcon >> 7) & 0x1b)) & 0xff;
             }
-          else
+          else if (KC == 8 && j == 4)
             {
-              for (j = 1; j < KC/2; j++)
-                {
-                  tk_u32[j] ^= tk_u32[j-1];
-                }
-              tk[KC/2][0] ^= sbox[tk[KC/2 - 1][0] * 4];
-              tk[KC/2][1] ^= sbox[tk[KC/2 - 1][1] * 4];
-              tk[KC/2][2] ^= sbox[tk[KC/2 - 1][2] * 4];
-              tk[KC/2][3] ^= sbox[tk[KC/2 - 1][3] * 4];
-              for (j = KC/2 + 1; j < KC; j++)
-                {
-                  tk_u32[j] ^= tk_u32[j-1];
-                }
+              temp0 = sbox4(temp0);
             }
 
-          /* Copy values into round key array.  */
-          for (j = 0; (j < KC) && (r < rounds + 1); )
-            {
-              for (; (j < KC) && (t < 4); j++, t++)
-                {
-                  W_u32[r][t] = le_bswap32(tk_u32[j]);
-                }
-              if (t == 4)
-                {
-                  r++;
-                  t = 0;
-                }
-            }
+          temp1 = W_u32[i - KC + 0];
+
+          W_u32[i + 0] = temp0 ^ temp1;
+          W_u32[i + 1] = W_u32[i - KC + 1] ^ temp0 ^ temp1;
+          W_prev = W_u32[i + 1];
         }
-#undef W
-#undef tk
-#undef k
-#undef W_u32
-#undef tk_u32
-#undef k_u32
-      wipememory(&tkk, sizeof(tkk));
     }
 
   return 0;
-- 
2.34.1




More information about the Gcrypt-devel mailing list