// aes2.txt  change from aes.txt  to kid syntax   16 May 2004
// aes4.txt  change to init-array  7 nov 2005
//  change to init, avoid imm line  14 Oct 2009

// As in matmul, the 4x4 AES state is stored in a flat array[16] instead of a 2-D array:
// element (row i, column j) is at index n = i + 4*j (column-major, as print_state reads it)

// state: the 16-byte AES state, encrypted in place by cipher.
// Byte n sits at row (n & 3), column (n >> 2) -- see print_state.
// Initial value is the FIPS-197 Appendix B example plaintext
// 32 43 f6 a8 88 5a 30 8d 31 31 98 a2 e0 37 07 34 (in decimal below).
state = array
  50 67 246 168 136 90 48 141 49 49
  152 162 224 55 7 52

// key: the 16-byte AES-128 cipher key; key_expansion copies it into roundkey[0..15].
// Value is the FIPS-197 Appendix B example key
// 2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c (in decimal below).
key = array
  43 126 21 22 40 174 210 166 171 247
  21 136 9 207 79 60

//roundkey = array 176	// roundkey -- now allocated in init (11 round keys x 16 bytes)

// sbox_table: the 256-entry AES S-box, stored TRANSPOSED.
// S(b) lives at index (b >> 4) + 16 * (b & 15), i.e. high nibble + 16 * low nibble,
// which is the lookup byte_sub and key_expansion perform.
// Example: sbox_table[0] = 99 = S(0x00), sbox_table[1] = 202 = S(0x10).
sbox_table = array
  99 202 183 4 9 83 208 81 205 96
  224 231 186 112 225 140 124 130 253 199
  131 209 239 163 12 129 50 200 120 62
  248 161 119 201 147 35 44 0 170 64
  19 79 58 55 37 181 152 137 123 125
  38 195 26 237 251 143 236 220 10 109
  46 102 17 13 242 250 54 24 27 32
  67 146 95 34 73 141 28 72 105 191
  107 89 63 150 110 252 77 157 151 42
  6 213 166 3 217 230 111 71 247 5
  90 177 51 56 68 144 36 78 180 246
  142 66 197 240 204 154 160 91 133 245
  23 136 92 169 198 14 148 104 48 173
  52 7 82 106 69 188 196 70 194 108
  232 97 155 65 1 212 165 18 59 203
  249 182 167 238 211 86 221 53 30 153
  103 162 229 128 214 190 2 218 126 184
  172 244 116 87 135 45 43 175 241 226
  179 57 127 33 61 20 98 234 31 185
  233 15 254 156 113 235 41 74 80 16
  100 222 145 101 75 134 206 176 215 164
  216 39 227 76 60 255 93 94 149 122
  189 193 85 84 171 114 49 178 47 88
  159 243 25 11 228 174 139 29 40 187
  118 192 21 117 132 207 168 210 115 219
  121 8 138 158 223 22

//inv_sbox_table = array 256	// inverse S-box -- now allocated in init; indexed like sbox_table (high nibble + 16 * low nibble)

// rc: AES round constants (Rcon). key_expansion XORs rc[i >> 2] into the first
// byte of every 4th key word, i = 4..40, so only rc[1]..rc[10] are read here;
// rc[0], rc[11] and rc[12] are not used by the visible code.
rc = array
  0 1 2 4 8 16 32 64 128 27 54 108 216

// Allocate the working arrays used by the cipher routines (call before key_expansion).
to init =
  roundkey = array 176	        // expanded key: 11 round keys x 16 bytes, filled by key_expansion
  inv_sbox_table = array 256	// indexed like sbox_table; NOTE(review): never filled in the visible code
  // per-column scratch for mix_column: a0 = bytes, a1 = doubled bytes, a2 = tripled bytes
  a0 = array 4
  a1 = array 4
  a2 = array 4

  roundkey_temp = array 4	// one 4-byte key word being built by key_expansion
  subword_output = array 4	// NOTE(review): not referenced anywhere in the visible code

// Print the 4x4 state one row per line; element (row r, column c) is state[r + 4*c].
to print_state | r c =
    for r 0 3
        for c 0 3
            print state[(c << 2) + r] space
        nl

// --------------------------------------------

// ShiftRows: rotate row i of the state left by i byte positions (row 0 is unchanged).
// Row i holds state[i], state[i+4], state[i+8], state[i+12] (index = row + 4*column,
// as in print_state). Each pass of the inner loop rotates the row left by one byte,
// with the 4-element rotation unrolled.
// Fix: the old version also executed `state[i+12] = state[i+16]`, which read past the
// end of the 16-byte state array and was immediately overwritten; it has been removed.
to shift_row | i k i4 i8 i12 temp =
    for i 0 3
        for k 0 i-1
            i4 = i + 4
            i8 = i + 8
            i12 = i + 12
            temp = state[i]
            state[i] = state[i4]
            state[i4] = state[i8]
            state[i8] = state[i12]
            state[i12] = temp

// --------------------------------------------

// MixColumns on each of the 4 state columns (column col is state[4*col .. 4*col+3]).
// For every byte s of the column the global scratch arrays receive:
//   a0 = s, a1 = 2*s in GF(2^8) ("xtime"), a2 = 3*s = 2*s ^ s.
// The column is then rebuilt with the AES matrix rows
//   [2 3 1 1], [1 2 3 1], [1 1 2 3], [3 1 1 2].
to mix_column | col base r s x b0 b1 b2 b3 =
    for col 0 3
        base = col << 2
        for r 0 3
            s = state[base + r]
            x = (s << 1) & 255
            // top bit was set: reduce modulo the AES polynomial (low byte 0x1b = 27)
            if (s >= 128) x = x ^ 27
            a0[r] = s
            a1[r] = x
            a2[r] = x ^ s
        // cache the plain bytes to avoid repeated indexing
        b0 = a0[0]
        b1 = a0[1]
        b2 = a0[2]
        b3 = a0[3]
        state[base] = a1[0] ^ a2[1] ^ b2 ^ b3
        state[base + 1] = b0 ^ a1[1] ^ a2[2] ^ b3
        state[base + 2] = b0 ^ b1 ^ a1[2] ^ a2[3]
        state[base + 3] = a2[0] ^ b1 ^ b2 ^ a1[3]


// --------------------------------------------
// Flat sbox_table / inv_sbox_table index x + 16 * y, where x is the high nibble and
// y the low nibble of the byte (the same expression byte_sub, inv_byte_sub and
// key_expansion compute inline).
// NOTE(review): not called anywhere in the visible code.
to index16 x y = x + (y << 4)

// Debug helper: print the S-box as a 16x16 grid (row = high nibble, column = low nibble of the input byte).
// Print S(b) for every byte b, 16 per line: line hi holds S(16*hi + lo) for lo = 0..15.
// The table is stored transposed, so the entry sits at index hi + 16*lo.
to print_sbox | hi lo =
    for hi 0 15
        for lo 0 15
            print sbox_table[hi + (lo << 4)] space
        nl

// SubBytes: replace each state byte b with S(b), looked up in the transposed
// sbox_table at index (high nibble of b) + 16 * (low nibble of b).
to byte_sub | n b =
    for n 0 15
        b = state[n]
        state[n] = sbox_table[((b >> 4) & 15) + ((b & 15) << 4)]

// Inverse SubBytes: replace each state byte b with inv_sbox_table at index
// (high nibble of b) + 16 * (low nibble of b), the same layout as sbox_table.
// NOTE(review): init allocates inv_sbox_table, but nothing in the visible code fills
// it, so this would return whatever `array 256` initialises it to. Populate the
// table before using this routine. It is not called in the visible code.
to inv_byte_sub | n b =
    for n 0 15
        b = state[n]
        state[n] = inv_sbox_table[((b >> 4) & 15) + ((b & 15) << 4)]

// --------------------------------------------
// Print all 176 bytes of the expanded key (11 round keys x 16 bytes) on one line.
to print_roundkey | n =
    n = 0
    while (n < 176)
        print roundkey[n] space
        n = n + 1

// AddRoundKey: XOR the 16-byte round key number `round` (roundkey[16*round .. 16*round+15])
// into the state, byte by byte.
to add_round_key round | n base =
    base = round << 4
    for n 0 15
        state[n] = state[n] ^ roundkey[base + n]

// AES-128 key schedule: expand the 16-byte key into 44 four-byte words
// (roundkey[0..175]). Word w = word (w-4) XOR temp, where temp is word (w-1),
// except that every 4th word first gets RotWord, SubWord and XOR with rc[w/4].
// Uses the global roundkey_temp as the 4-byte temp word.
to key_expansion | w n b first src dst =
    // words 0..3 are the cipher key itself
    for n 0 15 roundkey[n] = key[n]
    for w 4 43
        src = (w - 1) << 2
        for n 0 3 roundkey_temp[n] = roundkey[src + n]
        if ((w & 3) == 0)
            // RotWord: rotate the word left by one byte
            first = roundkey_temp[0]
            for n 0 2 roundkey_temp[n] = roundkey_temp[n + 1]
            roundkey_temp[3] = first
            // SubWord: S-box each byte (transposed table layout, as in byte_sub)
            for n 0 3
                b = roundkey_temp[n]
                roundkey_temp[n] = sbox_table[((b >> 4) & 15) + ((b & 15) << 4)]
            // round constant goes into the first byte only
            roundkey_temp[0] = roundkey_temp[0] ^ rc[w >> 2]
        dst = w << 2
        src = (w - 4) << 2
        for n 0 3
            roundkey[dst + n] = roundkey[src + n] ^ roundkey_temp[n]


// --------------------------------------------

// Encrypt `state` in place with AES-128 using the expanded `roundkey`
// (key_expansion must run first), then print the resulting state.
// Debugging: insert `nl print_state` after any step to trace intermediate states.
to cipher | round =
    // initial AddRoundKey with round key 0
    add_round_key 0
    // rounds 1..9: full rounds
    for round 1 9
        byte_sub
        shift_row
        mix_column
        add_round_key round
    // round 10: the final round omits MixColumns
    byte_sub
    shift_row
    add_round_key 10
    nl print_state

// --------------------------------------------


// -- Main --
// Entry point: allocate the working arrays, expand `key` into `roundkey`,
// then encrypt `state` in place and print it (cipher prints the final state).
to main =
  init
  key_expansion
  cipher

// main
// end

