diff --git a/include/cybozu/sha2.hpp b/include/cybozu/sha2.hpp index 1830936..335a897 100644 --- a/include/cybozu/sha2.hpp +++ b/include/cybozu/sha2.hpp @@ -145,24 +145,22 @@ inline uint64_t rot64(uint64_t x, int s) template struct Common { - void term(const char *buf, size_t bufSize) + void term(uint8_t *buf, size_t bufSize) { assert(bufSize < T::blockSize_); T& self = static_cast(*this); const uint64_t totalSize = self.totalSize_ + bufSize; - uint8_t last[T::blockSize_]; - memcpy(last, buf, bufSize); - last[bufSize] = uint8_t(0x80); /* top bit = 1 */ - memset(&last[bufSize + 1], 0, T::blockSize_ - bufSize - 1); + buf[bufSize] = uint8_t(0x80); /* top bit = 1 */ + memset(&buf[bufSize + 1], 0, T::blockSize_ - bufSize - 1); if (bufSize >= T::blockSize_ - T::msgLenByte_) { - self.round(reinterpret_cast(last)); - memset(last, 0, sizeof(last)); // clear stack + self.round(buf); + memset(buf, 0, T::blockSize_ - 8); // clear stack } - cybozu::Set64bitAsBE(&last[T::blockSize_ - 8], totalSize * 8); - self.round(reinterpret_cast(last)); + cybozu::Set64bitAsBE(&buf[T::blockSize_ - 8], totalSize * 8); + self.round(buf); } - void inner_update(const char *buf, size_t bufSize) + void inner_update(const uint8_t *buf, size_t bufSize) { T& self = static_cast(*this); if (bufSize == 0) return; @@ -203,15 +201,35 @@ private: static const size_t msgLenByte_ = 8; uint64_t totalSize_; size_t roundBufSize_; - char roundBuf_[blockSize_]; + uint8_t roundBuf_[blockSize_]; uint32_t h_[hSize_]; static const size_t outByteSize_ = hSize_ * sizeof(uint32_t); const uint32_t *k_; + template + void round1(uint32_t *s, uint32_t *w, int i) + { + using namespace sha2_local; + uint32_t e = s[i4]; + uint32_t h = s[i7]; + h += rot32(e, 6) ^ rot32(e, 11) ^ rot32(e, 25); + uint32_t f = s[i5]; + uint32_t g = s[i6]; + h += g ^ (e & (f ^ g)); + h += k_[i]; + h += w[i]; + s[i3] += h; + uint32_t a = s[i0]; + uint32_t b = s[i1]; + uint32_t c = s[i2]; + h += rot32(a, 2) ^ rot32(a, 13) ^ rot32(a, 22); + h += ((a | b) & c) | (a & b); + s[i7] = h; + } /** @param buf [in] buffer(64byte) */ - void round(const char *buf) + void round(const uint8_t *buf) { using namespace sha2_local; uint32_t w[64]; @@ -225,38 +243,23 @@ private: uint32_t s1 = rot32(t, 17) ^ rot32(t, 19) ^ (t >> 10); w[i] = w[i - 16] + s0 + w[i - 7] + s1; } - uint32_t a = h_[0]; - uint32_t b = h_[1]; - uint32_t c = h_[2]; - uint32_t d = h_[3]; - uint32_t e = h_[4]; - uint32_t f = h_[5]; - uint32_t g = h_[6]; - uint32_t h = h_[7]; - for (int i = 0; i < 64; i++) { - uint32_t s1 = rot32(e, 6) ^ rot32(e, 11) ^ rot32(e, 25); - uint32_t ch = g ^ (e & (f ^ g)); - uint32_t t1 = h + s1 + ch + k_[i] + w[i]; - uint32_t s0 = rot32(a, 2) ^ rot32(a, 13) ^ rot32(a, 22); - uint32_t maj = ((a | b) & c) | (a & b); - uint32_t t2 = s0 + maj; - h = g; - g = f; - f = e; - e = d + t1; - d = c; - c = b; - b = a; - a = t1 + t2; + uint32_t s[8]; + for (int i = 0; i < 8; i++) { + s[i] = h_[i]; + } + for (int i = 0; i < 64; i += 8) { + round1<0, 1, 2, 3, 4, 5, 6, 7>(s, w, i + 0); + round1<7, 0, 1, 2, 3, 4, 5, 6>(s, w, i + 1); + round1<6, 7, 0, 1, 2, 3, 4, 5>(s, w, i + 2); + round1<5, 6, 7, 0, 1, 2, 3, 4>(s, w, i + 3); + round1<4, 5, 6, 7, 0, 1, 2, 3>(s, w, i + 4); + round1<3, 4, 5, 6, 7, 0, 1, 2>(s, w, i + 5); + round1<2, 3, 4, 5, 6, 7, 0, 1>(s, w, i + 6); + round1<1, 2, 3, 4, 5, 6, 7, 0>(s, w, i + 7); + } + for (int i = 0; i < 8; i++) { + h_[i] += s[i]; } - h_[0] += a; - h_[1] += b; - h_[2] += c; - h_[3] += d; - h_[4] += e; - h_[5] += f; - h_[6] += g; - h_[7] += h; totalSize_ += blockSize_; } public: @@ -290,7 +293,7 @@ public: } void update(const void *buf, size_t bufSize) { - inner_update(reinterpret_cast(buf), bufSize); + inner_update(reinterpret_cast(buf), bufSize); } size_t digest(void *md, size_t mdSize, const void *buf, size_t bufSize) { @@ -329,7 +332,7 @@ private: static const size_t msgLenByte_ = 16; uint64_t totalSize_; size_t roundBufSize_; - char roundBuf_[blockSize_]; + uint8_t roundBuf_[blockSize_]; uint64_t h_[hSize_]; static const size_t outByteSize_ = hSize_ * sizeof(uint64_t); const uint64_t *k_; @@ -359,7 +362,7 @@ private: /** @param buf [in] buffer(64byte) */ - void round(const char *buf) + void round(const uint8_t *buf) { using namespace sha2_local; uint64_t w[80]; @@ -431,7 +434,7 @@ public: } void update(const void *buf, size_t bufSize) { - inner_update(reinterpret_cast(buf), bufSize); + inner_update(reinterpret_cast(buf), bufSize); } size_t digest(void *md, size_t mdSize, const void *buf, size_t bufSize) {