165 changes: 77 additions & 88 deletions gemma/gemma.cc
@@ -30,15 +30,13 @@
#ifndef GEMMA_ONCE
#define GEMMA_ONCE

#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
@@ -73,11 +71,14 @@ struct Activations {
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
static constexpr size_t kCachePosSize =
TConfig::kGemmaLayers * kCacheLayerSize;
static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

std::array<float, kBatchSize * kModelDim> x; // input
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
std::array<float, kBatchSize * kHeads * kQDim> q; // query vector
std::array<float, kBatchSize * kHeads * kQStride> q; // query vector
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
att; // attention vector
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
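The new kQStride constant encodes the interleaved layout described in the comment above: in the MHA case each head's Q, K and V vectors sit back to back in the q buffer, so K and V can later be copied straight into the KV cache. A minimal sketch of that indexing, using toy dimensions rather than the model's real sizes:

```cpp
#include <cstddef>
#include <cstdio>

// Toy dimensions for illustration only; the real values come from TConfig.
constexpr size_t kQKVDim = 4;
constexpr size_t kHeads = 2;
constexpr bool kIsMHA = true;  // kHeads == kKVHeads
constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

// Offset of `head` for token `batch_idx` within the q buffer.
constexpr size_t QOffset(size_t batch_idx, size_t head) {
  return (batch_idx * kHeads + head) * kQStride;
}

int main() {
  const size_t q = QOffset(/*batch_idx=*/1, /*head=*/0);
  // For MHA, this head's K and V immediately follow its Q vector.
  printf("Q at %zu, K at %zu, V at %zu\n", q, q + kQKVDim, q + 2 * kQKVDim);
  return 0;
}
```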
@@ -242,7 +243,7 @@ HWY_NOINLINE void GriffinRecurrent(
using D = hn::ScalableTag<float>;
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;

@@ -370,26 +371,75 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
static constexpr size_t kCachePosSize =
gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
static constexpr size_t kCacheLayerSize =
gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
using TActivations = Activations<TConfig, kBatchSize>;
constexpr size_t kQKVDim = TActivations::kQKVDim;
constexpr size_t kQStride = TActivations::kQStride;
constexpr size_t kCachePosSize = TActivations::kCachePosSize;
constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize;
constexpr size_t kModelDim = TActivations::kModelDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT const float kQueryScale =
1.0f / Sqrt(static_cast<float>(kQKVDim));
constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention

// If MHA, this also computes KV, which we copy to the KV cache below.
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
num_tokens, activations.pre_att_rms_out.data(),
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);

for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
// QKV projections:
if constexpr (!kIsMHA) {
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// TODO: requires MatMul support for offsets.
MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.data(), kv, pool);
}
}

// Positional encodings for kv:
pool.Run(
0, kKVHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kKVHeads;
const size_t batch_idx = task / kKVHeads;
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
if constexpr (kIsMHA) {
// For MHA, copy kv into the KV cache from scratch space (see above).
const float* HWY_RESTRICT q =
activations.q.data() + (batch_idx * kHeads + head) * kQStride;
// Skip past the Q part of `q`, and copy KV to `kv`.
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
}
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
});

static_assert((kHeads % kKVHeads) == 0,
"query heads must be a multiple of key-value heads");
static constexpr size_t kGroupHeads = kHeads / kKVHeads;
pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
float* HWY_RESTRICT q =
activations.q.data() + (batch_idx * kHeads + head) * kQStride;

auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
size_t thread) HWY_ATTR {
const size_t pos = batch_start + batch_idx;
// Calculate scores
float* HWY_RESTRICT head_att = activations.att.data() +
head * kSeqLen +
batch_idx * kHeads * kSeqLen;
float* HWY_RESTRICT head_att =
activations.att.data() + head * kSeqLen + batch_idx * kHeads * kSeqLen;

Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
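The projection and RoPE loops above, and the score and accumulation loops below, all locate a key/value pair with the same arithmetic: the token position is wrapped into a ring buffer of kSeqLen + kPrefillBatchSize slots, then offset by layer and KV head. A scalar sketch of that addressing, with hypothetical sizes standing in for the TConfig constants:

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical sizes for illustration; the real values come from TConfig.
constexpr size_t kSeqLen = 8;
constexpr size_t kPrefillBatchSize = 4;
constexpr size_t kQKVDim = 4;
constexpr size_t kKVHeads = 2;
constexpr size_t kLayers = 3;
// Per-position K and V for every KV head of one layer.
constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
// All layers for one cached position.
constexpr size_t kCachePosSize = kLayers * kCacheLayerSize;

// Float offset of (pos, layer, head) in the KV cache, mirroring kv_offset above.
constexpr size_t KVOffset(size_t pos, size_t layer, size_t head) {
  const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);  // ring buffer
  return cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim * 2;
}

int main() {
  // K starts at the returned offset; V follows kQKVDim floats later.
  const size_t k = KVOffset(/*pos=*/13, /*layer=*/1, /*head=*/0);
  printf("K at %zu, V at %zu\n", k, k + kQKVDim);
  return 0;
}
```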
@@ -398,8 +448,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head_offset;
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2 % kSeqLen] = score;
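Between the score loop above and the accumulation loop below, the diff elides the normalization of head_att. Assuming the usual scaled-dot-product formulation, the normalize-then-accumulate step looks roughly like this plain scalar sketch (the real code uses Highway-based helpers such as MulByConstAndAdd, shown in the diff):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of softmax(head_att) followed by the weighted sum of cached V.
void SoftmaxWeightedSum(std::vector<float>& head_att,
                        const std::vector<std::vector<float>>& v_cache,
                        std::vector<float>& att_out) {
  // Softmax with the usual max-subtraction for numerical stability.
  float max = head_att[0];
  for (float a : head_att) max = std::max(max, a);
  float sum = 0.0f;
  for (float& a : head_att) {
    a = std::exp(a - max);
    sum += a;
  }
  for (float& a : head_att) a /= sum;

  // Accumulate att_out = sum_i head_att[i] * v_cache[i], like the loop below.
  std::fill(att_out.begin(), att_out.end(), 0.0f);  // cf. hwy::ZeroBytes
  for (size_t pos2 = 0; pos2 < v_cache.size(); ++pos2) {
    for (size_t i = 0; i < att_out.size(); ++i) {
      att_out[i] += head_att[pos2] * v_cache[pos2][i];
    }
  }
}
```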
@@ -412,72 +462,11 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head_offset;
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
}
};

if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention calculates qkv using q as scratch space.
static_assert(TConfig::kInterleaveQKV);
MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
num_tokens, activations.pre_att_rms_out.data(),
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
} else {
MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
num_tokens, activations.pre_att_rms_out.data(),
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
}

for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
// QKV projections:
if constexpr (kHeads != kKVHeads) {
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// TODO: requires MatMul support for offsets.
MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.data(), kv, pool);
}
}

// Positional encodings for k:
const size_t num_kv_tasks = kKVHeads * num_tokens;
pool.Run(0, num_kv_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kKVHeads;
const size_t batch_idx = task / kKVHeads;
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
if constexpr (kHeads == kKVHeads) {
// For MHA, copy kv into the KV cache from scratch space (see above).
const float* HWY_RESTRICT q =
activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
}
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
});

static_assert((TConfig::kHeads % TConfig::kKVHeads) == 0,
"query heads must be a multiple of key-value heads");
static constexpr size_t kGroupHeads = TConfig::kHeads / TConfig::kKVHeads;
static constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
const size_t num_q_tasks = kHeads * num_tokens;
pool.Run(0, num_q_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
float* HWY_RESTRICT q = activations.q.data() + (batch_idx * kHeads + head) *
kQKVDim * kQOffsetScale;
Attn(q, head, head_offset, batch_idx, thread);
});

for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
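The head grouping above (kGroupHeads = kHeads / kKVHeads) is how grouped-query attention maps each query head onto its shared KV head; head_offset then selects that KV head's K/V slot within a cached position. A small sketch of that mapping, with illustrative head counts in place of the TConfig values:

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative head counts; real values come from TConfig.
constexpr size_t kHeads = 8;    // query heads
constexpr size_t kKVHeads = 2;  // key/value heads
static_assert(kHeads % kKVHeads == 0,
              "query heads must be a multiple of key-value heads");
constexpr size_t kGroupHeads = kHeads / kKVHeads;
constexpr size_t kQKVDim = 4;

int main() {
  for (size_t head = 0; head < kHeads; ++head) {
    // Each group of kGroupHeads query heads reads the same KV head's slot.
    const size_t kv_head = head / kGroupHeads;
    const size_t head_offset = kv_head * kQKVDim * 2;  // K and V per KV head
    printf("query head %zu -> kv head %zu (offset %zu)\n",
           head, kv_head, head_offset);
  }
  return 0;
}
```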
@@ -1012,7 +1001,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), gcpp::BOS_ID);
tokens.insert(tokens.begin(), BOS_ID);
}
return tokens;
}
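The only change in this hunk is dropping the redundant gcpp:: qualifier on BOS_ID; the behavior is unchanged: BOS is prepended only for the very first chunk of a conversation (pos == 0). A hedged sketch of that pattern with a stand-in encoder (the names below are hypothetical, not the gemma.cpp API):

```cpp
#include <string>
#include <vector>

constexpr int BOS_ID = 2;  // illustrative; the real ID comes from the tokenizer.

// Stand-in for tokenizer.Encode(); returns one dummy token per character.
std::vector<int> EncodeStub(const std::string& prompt) {
  return std::vector<int>(prompt.size(), 42);
}

std::vector<int> TokenizeWithBos(const std::string& prompt, size_t pos) {
  std::vector<int> tokens = EncodeStub(prompt);
  // Both pre-trained and instruction-tuned models expect BOS as the first token.
  if (pos == 0) {
    tokens.insert(tokens.begin(), BOS_ID);
  }
  return tokens;
}
```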