Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 87 additions & 51 deletions lib/std/crypto/ghash.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ pub const Ghash = struct {
pub const mac_length = 16;
pub const key_length = 16;

const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 4;
const agg_2_treshold = 5;
const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
const agg_4_treshold = 22;
const agg_8_treshold = 84;
const agg_16_treshold = 328;

const mul_algorithm = if (builtin.cpu.arch == .x86) .karatsuba else .textbook;

hx: [pc_count]Precomp,
acc: u128 = 0,

Expand All @@ -43,10 +44,10 @@ pub const Ghash = struct {
var hx: [pc_count]Precomp = undefined;
hx[0] = h;
hx[1] = gcmReduce(clsq128(hx[0])); // h^2
hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2

if (builtin.mode != .ReleaseSmall) {
hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
if (block_count >= agg_8_treshold) {
hx[4] = gcmReduce(clmul128(hx[3], h)); // h^5
hx[5] = gcmReduce(clsq128(hx[2])); // h^6 = h^3^2
Expand All @@ -69,24 +70,32 @@ pub const Ghash = struct {
return Ghash.initForBlockCount(key, math.maxInt(usize));
}

const Selector = enum { lo, hi };
const Selector = enum { lo, hi, hi_lo };

// Carryless multiplication of two 64-bit integers for x86_64.
inline fn clmulPclmul(x: u128, y: u128, comptime half: Selector) u128 {
if (half == .hi) {
const product = asm (
\\ vpclmulqdq $0x11, %[x], %[y], %[out]
: [out] "=x" (-> @Vector(2, u64)),
: [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
[y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
: [x] "x" (@bitCast(@Vector(2, u64), x)),
[y] "x" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
} else {
} else if (half == .lo) {
const product = asm (
\\ vpclmulqdq $0x00, %[x], %[y], %[out]
: [out] "=x" (-> @Vector(2, u64)),
: [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
[y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
: [x] "x" (@bitCast(@Vector(2, u64), x)),
[y] "x" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
} else {
const product = asm (
\\ vpclmulqdq $0x10, %[x], %[y], %[out]
: [out] "=x" (-> @Vector(2, u64)),
: [x] "x" (@bitCast(@Vector(2, u64), x)),
[y] "x" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
}
Expand All @@ -98,16 +107,24 @@ pub const Ghash = struct {
const product = asm (
\\ pmull2 %[out].1q, %[x].2d, %[y].2d
: [out] "=w" (-> @Vector(2, u64)),
: [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
[y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
: [x] "w" (@bitCast(@Vector(2, u64), x)),
[y] "w" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
} else if (half == .lo) {
const product = asm (
\\ pmull %[out].1q, %[x].1d, %[y].1d
: [out] "=w" (-> @Vector(2, u64)),
: [x] "w" (@bitCast(@Vector(2, u64), x)),
[y] "w" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
} else {
const product = asm (
\\ pmull %[out].1q, %[x].1d, %[y].1d
: [out] "=w" (-> @Vector(2, u64)),
: [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
[y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
: [x] "w" (@bitCast(@Vector(2, u64), x >> 64)),
[y] "w" (@bitCast(@Vector(2, u64), y)),
);
return @bitCast(u128, product);
}
Expand Down Expand Up @@ -144,38 +161,63 @@ pub const Ghash = struct {
(z3 & 0x88888888888888888888888888888888) ^ extra;
}

const I256 = struct {
hi: u128,
lo: u128,
mid: u128,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I understand what mid is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Schoolbook multiplication:

ab * cd =

    bd
+  ad
+  bc
+ ac

Here `ac` is `hi`, `bd` is `lo`, and the middle term is `ad + bc` (in carry-less arithmetic, `+` is XOR). Eventually, a shifted addition folds the middle term into the other two:

mid = ad + bc
 lo = bd + (mid << n)
 hi = ac + (mid >> n)

We perform several multiplications in a row. Instead of folding the middle term in after every multiplication, we accumulate the `lo`, `hi` and `mid` values separately and fold `mid` in only once, at the end. This is what the code does here.

};

// XOR the three components (`hi`, `lo`, `mid`) of `y` into the accumulator
// pointed to by `x`, field by field. `x` is updated in place; nothing is returned.
inline fn xor256(x: *I256, y: I256) void {
    x.hi ^= y.hi;
    x.lo ^= y.lo;
    x.mid ^= y.mid;
}

// Square a 128-bit integer in GF(2^128).
fn clsq128(x: u128) u256 {
const lo = @truncate(u64, x);
const hi = @truncate(u64, x >> 64);
const mid = lo ^ hi;
const r_lo = clmul(x, x, .lo);
const r_hi = clmul(x, x, .hi);
const r_mid = clmul(mid, mid, .lo) ^ r_lo ^ r_hi;
return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
fn clsq128(x: u128) I256 {
return .{
.hi = clmul(x, x, .hi),
.lo = clmul(x, x, .lo),
.mid = 0,
};
}

// Multiply two 128-bit integers in GF(2^128).
inline fn clmul128(x: u128, y: u128) u256 {
const x_hi = @truncate(u64, x >> 64);
const y_hi = @truncate(u64, y >> 64);
const r_lo = clmul(x, y, .lo);
const r_hi = clmul(x, y, .hi);
const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
inline fn clmul128(x: u128, y: u128) I256 {
if (mul_algorithm == .karatsuba) {
const x_hi = @truncate(u64, x >> 64);
const y_hi = @truncate(u64, y >> 64);
const r_lo = clmul(x, y, .lo);
const r_hi = clmul(x, y, .hi);
const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
return .{
.hi = r_hi,
.lo = r_lo,
.mid = r_mid,
};
} else {
return .{
.hi = clmul(x, y, .hi),
.lo = clmul(x, y, .lo),
.mid = clmul(x, y, .hi_lo) ^ clmul(y, x, .hi_lo),
};
}
}

// Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
// This is done *without reversing the bits*, using Shay Gueron's black magic demystified here:
// https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
inline fn gcmReduce(x: u256) u128 {
inline fn gcmReduce(x: I256) u128 {
const hi = x.hi ^ (x.mid >> 64);
const lo = x.lo ^ (x.mid << 64);
const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
const lo = @truncate(u128, x);
const a = clmul(lo, p64, .lo);
const b = ((lo << 64) | (lo >> 64)) ^ a;
const c = clmul(b, p64, .lo);
const d = ((b << 64) | (b >> 64)) ^ c;
return d ^ @truncate(u128, x >> 128);
return d ^ hi;
}

const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
Expand All @@ -202,7 +244,7 @@ pub const Ghash = struct {
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[15 - 0]);
comptime var j = 1;
inline while (j < 16) : (j += 1) {
u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]);
xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]));
}
acc = gcmReduce(u);
}
Expand All @@ -212,7 +254,7 @@ pub const Ghash = struct {
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[7 - 0]);
comptime var j = 1;
inline while (j < 8) : (j += 1) {
u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]);
xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]));
}
acc = gcmReduce(u);
}
Expand All @@ -222,31 +264,25 @@ pub const Ghash = struct {
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[3 - 0]);
comptime var j = 1;
inline while (j < 4) : (j += 1) {
u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]);
xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]));
}
acc = gcmReduce(u);
}
} else if (msg.len >= agg_2_treshold * block_length) {
// 2-blocks aggregated reduction
while (i + 32 <= msg.len) : (i += 32) {
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
comptime var j = 1;
inline while (j < 2) : (j += 1) {
u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]);
}
acc = gcmReduce(u);
}
// 2-blocks aggregated reduction
while (i + 32 <= msg.len) : (i += 32) {
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
comptime var j = 1;
inline while (j < 2) : (j += 1) {
xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]));
}
acc = gcmReduce(u);
}
// remaining blocks
if (i < msg.len) {
const n = (msg.len - i) / 16;
var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[n - 1 - 0]);
var j: usize = 1;
while (j < n) : (j += 1) {
u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[n - 1 - j]);
}
i += n * 16;
const u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[0]);
acc = gcmReduce(u);
i += 16;
}
assert(i == msg.len);
st.acc = acc;
Expand Down