Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ cc_library(
],
)

# Split from :threading to break a circular dependency with :allocator.
cc_library(
name = "threading",
srcs = ["util/threading.cc"],
hdrs = ["util/threading.h"],
name = "topology",
srcs = ["util/topology.cc"],
hdrs = ["util/topology.h"],
deps = [
":basics",
# Placeholder for container detection, do not remove
"@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)
Expand All @@ -48,32 +46,54 @@ cc_library(
hdrs = ["util/allocator.h"],
deps = [
":basics",
":threading",
":topology",
"@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)

cc_library(
name = "test_util",
hdrs = ["util/test_util.h"],
name = "threading",
srcs = ["util/threading.cc"],
hdrs = ["util/threading.h"],
deps = [
":allocator",
":basics",
":topology",
# Placeholder for container detection, do not remove
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//:topology",
],
)

cc_test(
name = "threading_test",
srcs = ["util/threading_test.cc"],
deps = [
":allocator",
":basics",
":threading",
"@googletest//:gtest_main",
"@highway//:auto_tune",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:robust_statistics",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//:timer",
],
)

cc_library(
name = "test_util",
hdrs = ["util/test_util.h"],
deps = [
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:stats",
],
)

Expand Down Expand Up @@ -104,6 +124,7 @@ cc_library(
":allocator",
":basics",
":threading",
":topology",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
Expand All @@ -113,7 +134,6 @@ cc_library(
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//:topology",
"@highway//hwy/contrib/sort:vqsort",
],
)
Expand All @@ -128,11 +148,11 @@ cc_test(
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":ops",
":test_util",
":threading",
"@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
Expand All @@ -154,11 +174,12 @@ cc_test(
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":common",
":ops",
":test_util",
":threading",
"@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
Expand Down Expand Up @@ -405,6 +426,7 @@ cc_library(
":cross_entropy",
":gemma_lib",
":kv_cache",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
"@google_benchmark//:benchmark",
Expand Down Expand Up @@ -464,13 +486,13 @@ cc_binary(
":benchmark_helper",
":common",
":gemma_lib",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
"//compression:sfp",
"//paligemma:image",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

Expand Down Expand Up @@ -634,13 +656,12 @@ cc_test(
":backprop",
":backprop_scalar",
":common",
":gemma_lib",
":ops",
":prompt",
":sampler",
":threading",
":weights",
"@googletest//:gtest_main",
"//:threading",
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
Expand Down
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)

## Note: absl needs to be installed by sentencepiece. This will only happen if
Expand Down Expand Up @@ -108,6 +108,8 @@ set(SOURCES
util/test_util.h
util/threading.cc
util/threading.h
util/topology.cc
util/topology.h
)

if(NOT CMAKE_BUILD_TYPE)
Expand Down
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "f2209b911c74019e85d0b7a7a2833c9a2e1b7995",
commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
remote = "https://github.com/google/highway",
)

Expand Down
25 changes: 12 additions & 13 deletions backprop/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include "ops/ops.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
Expand All @@ -59,9 +58,9 @@ void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols);
MatStorageT<float> x("x", kTokens, kCols);
Expand Down Expand Up @@ -105,9 +104,9 @@ void TestMultiHeadMatMulVJP() {
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
MatStorageT<float> x("x", kTokens, kCols * kHeads);
Expand Down Expand Up @@ -150,9 +149,9 @@ void TestMultiHeadMatMulVJP() {
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
std::mt19937 gen(42);
MatStorageT<float> weights("weights", N, 1);
MatStorageT<float> x("x", K, N);
Expand Down Expand Up @@ -216,9 +215,9 @@ static ModelConfig TestConfig() {

void TestEndToEnd() {
std::mt19937 gen(42);
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 1));
Allocator::Init(pools.Topology());
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config);
Expand Down
9 changes: 5 additions & 4 deletions backprop/optimize_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@
namespace gcpp {

TEST(OptimizeTest, GradientDescent) {
NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 1));
Allocator::Init(pools.Topology());
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
MatMulEnv env(topology, pools);
hwy::ThreadPool& pool = pools.Pool();
std::mt19937 gen(42);

Expand All @@ -66,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) {
config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);

Gemma gemma(GemmaTokenizer(), info, pools);
Gemma gemma(GemmaTokenizer(), info, env);

const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
Expand Down
5 changes: 3 additions & 2 deletions compression/blob_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
if (!CompareKeys(reader1, reader2)) return;

// Single allocation, avoid initializing the memory.
NestedPools pools(0);
Allocator::Init(pools.Topology());
BoundedTopology topology;
Allocator::Init(topology);
NestedPools pools(topology);
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
size_t pos = 0;
Expand Down
11 changes: 6 additions & 5 deletions evals/benchmark_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {

GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app)
: pools_(CreatePools(app)) {
Allocator::Init(pools_.Topology());
: topology_(CreateTopology(app)),
pools_(CreatePools(topology_, app)),
env_(topology_, pools_) {
InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader;
Expand All @@ -66,7 +67,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(mutable_loader, pools_);
model_ = AllocateGemma(mutable_loader, env_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
Expand Down Expand Up @@ -236,7 +237,7 @@ std::string CacheString() {
}

void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
NestedPools& pools) {
const BoundedTopology& topology, NestedPools& pools) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
Expand All @@ -255,7 +256,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
"Compiled config : %s\n"
"Weight Type : %s\n"
"EmbedderInput Type : %s\n",
dt, cpu100, pools.TopologyString(), pools.PinString(),
dt, cpu100, topology.TopologyString(), pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
hwy::VectorBytes() * 8, CompiledConfig(),
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
Expand Down
16 changes: 7 additions & 9 deletions evals/benchmark_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "gemma/gemma.h"
#include "ops/matmul.h"
#include "util/app.h"
#include "util/threading.h"
#include "hwy/base.h"
Expand Down Expand Up @@ -105,23 +106,20 @@ class GemmaEnv {
KVCache& MutableKVCache() { return kv_caches_[0]; }

private:
// Thread pool for running inference.
NestedPools pools_;
// Random number generator.
std::mt19937 gen_;
// The model to run inference on.
BoundedTopology topology_;
NestedPools pools_; // Thread pool.
MatMulEnv env_;
std::mt19937 gen_; // Random number generator.
std::unique_ptr<Gemma> model_;
// KV caches, same number as query batch.
std::vector<KVCache> kv_caches_;
// Runtime config for inference.
std::vector<KVCache> kv_caches_; // Same number as query batch.
RuntimeConfig runtime_config_;
};

// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);

void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
NestedPools& pools);
const BoundedTopology& topology, NestedPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);

} // namespace gcpp
Expand Down
1 change: 0 additions & 1 deletion examples/hello_world/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ cc_binary(
# Placeholder for internal dep, do not remove.,
"//:app",
"//:args",
"//:common",
"//:gemma_lib",
"//:threading",
"//:tokenizer",
Expand Down
2 changes: 1 addition & 1 deletion examples/hello_world/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ project(hello_world)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
Expand Down
7 changes: 4 additions & 3 deletions examples/hello_world/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ int main(int argc, char** argv) {
}

// Instantiate model and KV Cache
gcpp::NestedPools pools = gcpp::CreatePools(app);
gcpp::Allocator::Init(pools.Topology());
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
gcpp::MatMulEnv env(topology, pools);
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
gcpp::KVCache kv_cache =
gcpp::KVCache::Create(model.GetModelConfig(),
inference.prefill_tbatch_size);
Expand Down
Loading
Loading