Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ cc_library(
],
)

cc_library(
name = "cross_entropy",
srcs = [
"gemma/cross_entropy.cc",
],
hdrs = [
"gemma/cross_entropy.h",
],
deps = [
":gemma_lib",
],
)

cc_library(
name = "args",
hdrs = ["util/args.h"],
Expand Down Expand Up @@ -141,6 +154,7 @@ cc_test(
],
deps = [
":args",
":cross_entropy",
":gemma_lib",
":ops",
"@googletest//:gtest_main",
Expand Down Expand Up @@ -190,6 +204,7 @@ cc_binary(
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ set(SOURCES
gemma/activations.h
gemma/common.cc
gemma/common.h
gemma/cross_entropy.cc
gemma/cross_entropy.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h
Expand Down
2 changes: 1 addition & 1 deletion backprop/optimize_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
Expand Down
7 changes: 5 additions & 2 deletions gemma/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <vector>

#include "compression/io.h" // Path
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
Expand Down Expand Up @@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
kv_cache, app.verbosity);
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
kv_cache, app.verbosity);
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy);
printf("Cross entropy per byte: %f [cumulative: %f]\n",
entropy / text_slice.size(), total_entropy / total_input_len);
}
Expand Down
109 changes: 109 additions & 0 deletions gemma/cross_entropy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/cross_entropy.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <regex>  // NOLINT
#include <string>
#include <utility>
#include <vector>

namespace gcpp {

namespace {
template <typename TConfig>
struct GetVocabSize {
int operator()() const { return TConfig::kVocabSize; }
};

// Decodes `token` and returns it wrapped in single quotes, with embedded
// newlines escaped as "\n" so multi-line tokens stay on one log line.
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
  std::string decoded;
  tokenizer.Decode({token}, &decoded);
  const std::string escaped =
      std::regex_replace(decoded, std::regex("\n"), "\\n");
  return "'" + escaped + "'";
}

// Prints the `k` highest-probability entries of the distribution `dist`
// (length `len`): rank, token id, decoded token and probability, one per
// line. Ties are broken in favor of the lower token id.
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
             size_t k) {
  // Guard against k > len, which would read past the end of `sorted`.
  k = std::min(k, len);
  std::vector<std::pair<float, int>> sorted(len);
  for (size_t i = 0; i < len; ++i) {
    sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
  }
  // Only the first k entries are printed; partial_sort avoids fully sorting
  // a vocabulary-sized (typically ~256K entries) array.
  std::partial_sort(
      sorted.begin(), sorted.begin() + k, sorted.end(),
      [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        if (a.first != b.first) {
          return a.first > b.first;
        }
        return a.second < b.second;
      });
  for (size_t i = 0; i < k; ++i) {
    printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
           sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
           sorted[i].first);
  }
}
} // namespace

// Scores `prompt` under the model by force-feeding each prompt token through
// generation and summing -log(p) of the observed token; returns the total
// cross entropy in bits. `kv_cache` is used as scratch state; verbosity
// levels >= 2/3/4 enable progressively more detailed logging.
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                          const std::vector<int>& prompt,
                          KVCache& kv_cache, int verbosity) {
  // Degenerate inputs: nothing to score.
  if (prompt.empty() || max_tokens == 0) return 0.0f;
  // Never index past the end of `prompt`, even if the caller over-asks;
  // `sample_token` below reads prompt[pos] for pos up to max_tokens - 1.
  max_tokens = std::min(max_tokens, prompt.size());

  auto stream_token = [](int, float) { return true; };
  auto accept_token = [](int) { return true; };

  const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
  // The first token has no context, so charge it the entropy of a uniform
  // distribution over the vocabulary.
  float cross_entropy = std::log(static_cast<float>(vocab_size));
  size_t pos = 1;
  // Custom sampler: instead of sampling from `probs`, deterministically
  // return the next prompt token and accumulate its -log(probability).
  // (Renamed the parameter from `vocab_size` to avoid shadowing the outer
  // variable of the same name.)
  std::function<int(const float*, size_t)> sample_token =
      [&](const float* probs, size_t probs_len) -> int {
    const int token = prompt[pos];
    const float prob = probs[token];
    // Clamp at -64 nats so a (near-)zero probability cannot yield -inf.
    cross_entropy -= std::max(std::log(prob), -64.0f);

    if (verbosity >= 4) {
      LogTopK(gemma.Tokenizer(), probs, probs_len, 10);
    }
    if (verbosity >= 3) {
      printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
             token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
             -std::log(prob) / std::log(2.0));
    }
    if (verbosity >= 2 && pos % 100 == 99) {
      printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
             cross_entropy / std::log(2.0) / (pos + 1));
    }
    ++pos;
    return token;
  };
  // Seed generation with only the first prompt token; the sampler above
  // replays the rest of the prompt one token per generation step.
  std::vector<int> prompt0 = {prompt[0]};
  RuntimeConfig runtime = {
      .max_tokens = max_tokens,
      .max_generated_tokens = max_tokens - 1,
      .temperature = 0.0f,
      .verbosity = verbosity,
      .gen = nullptr,
      .stream_token = stream_token,
      .accept_token = accept_token,
      .sample_func = &sample_token,
  };
  TimingInfo timing_info;

  gemma.Generate(runtime, prompt0, /*start_pos=*/0, kv_cache, timing_info,
                 /*layers_output=*/nullptr);

  // Convert the accumulated value from nats to bits.
  return cross_entropy / std::log(2.0f);
}

} // namespace gcpp
31 changes: 31 additions & 0 deletions gemma/cross_entropy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_

#include <vector>

#include "gemma/gemma.h"

namespace gcpp {

// Computes the cross entropy (in bits, summed over tokens) of the model's
// next-token predictions against `prompt`, processing at most `max_tokens`
// tokens. The first token has no context and is charged log2(vocab_size).
// `kv_cache` is used as scratch state during the forward passes; higher
// `verbosity` (>= 2/3/4) enables progressively more detailed logging.
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                          const std::vector<int>& prompt, KVCache& kv_cache,
                          int verbosity);

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
103 changes: 11 additions & 92 deletions gemma/gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
#include <array>
#include <cmath>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized);
}

static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
tokenizer.Decode({token}, &token_str);
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}

void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
float* HWY_RESTRICT dist, size_t len, size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
sorted[i].first, logits[sorted[i].second]);
}
}

} // namespace gcpp
#endif // GEMMA_ONCE

Expand Down Expand Up @@ -837,13 +810,19 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, final_activation,
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
// Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), kVocabSize);
token = SampleTopK<TConfig::kTopK>(
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
if (runtime_config.sample_func) {
token = (*runtime_config.sample_func)(activations.logits.data(),
kVocabSize);
} else {
token = SampleTopK<TConfig::kTopK>(
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
}
}
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
Expand All @@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
}
}

template <class TConfig>
void ComputeCrossEntropy(const ByteStorageT& weights_u8,
ByteStorageT& decode_u8,
const GemmaTokenizer& tokenizer, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity,
float& cross_entropy) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8);

static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
std::vector<float> logits(kVocabSize);
Softmax(activations.logits.data(), kVocabSize);
float total_entropy = 0.0f;
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
if (verbosity >= 4) {
LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
10);
}
const int token = prompt[pos];
const float prob = activations.logits[token];
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
TokenString(tokenizer, token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
total_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
total_entropy / std::log(2.0) / (pos + 1));
}
Transformer(token, pos, weights, activations, kv_cache, pool,
/*layers_output=*/nullptr);
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0, activations.x.data(),
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
memcpy(logits.data(), activations.logits.data(),
kVocabSize * sizeof(logits[0]));
Softmax(activations.logits.data(), kVocabSize);
}
cross_entropy = total_entropy / std::log(2.0);
}

} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
Expand Down Expand Up @@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

float Gemma::ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

float cross_entropy = 0.0f;
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type_, ComputeCrossEntropy,
(weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
verbosity, cross_entropy));

pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
return cross_entropy;
}

} // namespace gcpp
#endif // HWY_ONCE
8 changes: 5 additions & 3 deletions gemma/gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc is called with token. It should return False for tokens you don't
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
// CustomSampleFunc is called with the probability distribution for the next
// token, and its return value is used as the next generated token.
using CustomSampleFunc = std::function<int(const float*, size_t)>;

struct RuntimeConfig {
size_t max_tokens;
Expand All @@ -81,6 +84,7 @@ struct RuntimeConfig {
std::mt19937* gen;
const StreamFunc& stream_token;
const AcceptFunc& accept_token;
const CustomSampleFunc* sample_func = nullptr;
int eos_id = EOS_ID;
};

Expand All @@ -107,6 +111,7 @@ class Gemma {
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
~Gemma();

Model ModelType() const { return model_type_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
Expand All @@ -119,9 +124,6 @@ class Gemma {
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);

float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity);

private:
hwy::ThreadPool& pool_;

Expand Down
Loading