Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,6 @@ cc_test(
],
)

cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)

cc_library(
name = "gemma_lib",
srcs = [
Expand All @@ -70,10 +61,10 @@ cc_library(
"gemma/gemma.h",
],
deps = [
":args",
":ops",
# "//base",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
Expand All @@ -83,6 +74,25 @@ cc_library(
],
)

# Header-only library target: publishes util/args.h and declares no srcs.
# Depends on //compression:io and on Highway (@hwy//:hwy).
# NOTE(review): the header's contents are not visible here — confirm that these
# two deps match what util/args.h actually includes.
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)

# Header-only library target: publishes util/app.h and declares no srcs.
# Builds on the same-package :args and :gemma_lib targets plus Highway
# (@hwy//:hwy).
cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)

cc_test(
name = "gemma_test",
srcs = ["gemma/gemma_test.cc"],
Expand All @@ -102,16 +112,6 @@ cc_test(
],
)

cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)

cc_binary(
name = "gemma",
srcs = ["gemma/run.cc"],
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(SOURCES
compression/blob_store.h
compression/compress.h
compression/compress-inl.h
compression/io_win.cc
compression/io.cc
compression/io.h
compression/nuq.h
Expand Down
27 changes: 10 additions & 17 deletions compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ package(

cc_library(
name = "io",
srcs = ["io.cc"],
srcs = [
"io.cc",
# Placeholder for io backend, do not remove
],
hdrs = ["io.h"],
# Placeholder for io textual_hdrs, do not remove
deps = [
# Placeholder for io deps, do not remove
"@hwy//:hwy",
Expand Down Expand Up @@ -80,12 +82,8 @@ cc_library(

cc_library(
name = "sfp",
hdrs = [
"sfp.h",
],
textual_hdrs = [
"sfp-inl.h",
],
hdrs = ["sfp.h"],
textual_hdrs = ["sfp-inl.h"],
deps = [
"@hwy//:hwy",
],
Expand All @@ -112,12 +110,8 @@ cc_test(

cc_library(
name = "nuq",
hdrs = [
"nuq.h",
],
textual_hdrs = [
"nuq-inl.h",
],
hdrs = ["nuq.h"],
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
"@hwy//:hwy",
Expand Down Expand Up @@ -158,6 +152,7 @@ cc_library(
deps = [
":blob_store",
":distortion",
":io",
":nuq",
":sfp",
":stats",
Expand All @@ -170,9 +165,7 @@ cc_library(
# For internal experimentation
cc_library(
name = "analyze",
textual_hdrs = [
"analyze.h",
],
textual_hdrs = ["analyze.h"],
deps = [
":distortion",
":nuq",
Expand Down
22 changes: 12 additions & 10 deletions compression/blob_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <stdint.h>

#include <atomic>
#include <memory>
#include <vector>

#include "compression/io.h"
Expand Down Expand Up @@ -199,12 +200,13 @@ class BlobStore {
};
#pragma pack(pop)

BlobError BlobReader::Open(const char* filename) {
if (!file_.Open(filename, "r")) return __LINE__;
BlobError BlobReader::Open(const Path& filename) {
file_ = OpenFileOrNull(filename, "r");
if (!file_) return __LINE__;

// Read first part of header to get actual size.
BlobStore bs;
if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__;
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs));

Expand All @@ -216,11 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
return __LINE__;
}

return blob_store_->CheckValidity(file_.FileSize());
return blob_store_->CheckValidity(file_->FileSize());
}

BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
Expand All @@ -247,7 +249,7 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
// between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
File* pfile = &file_; // not owned
File* pfile = file_.get(); // not owned
const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached.
Expand All @@ -262,7 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
return 0;
}

BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(keys_.size() == blobs_.size());

// Concatenate blobs in memory.
Expand All @@ -273,9 +275,9 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
keys_.data(), blobs_.data(), keys_.size(), bs.get());

// Create/replace existing file.
File file;
if (!file.Open(filename, "w+")) return __LINE__;
File* pfile = &file; // not owned
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file) return __LINE__;
File* pfile = file.get(); // not owned

std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(),
Expand Down
7 changes: 4 additions & 3 deletions compression/blob_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <stddef.h>
#include <stdint.h>

#include <memory>
#include <vector>

#include "compression/io.h"
Expand Down Expand Up @@ -63,7 +64,7 @@ class BlobReader {
~BlobReader() = default;

// Opens `filename` and reads its header.
BlobError Open(const char* filename);
BlobError Open(const Path& filename);

// Enqueues read requests if `key` is found and its size matches `size`.
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
Expand All @@ -74,7 +75,7 @@ class BlobReader {
private:
BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_;
File file_;
std::unique_ptr<File> file_;
};

class BlobWriter {
Expand All @@ -85,7 +86,7 @@ class BlobWriter {
}

// Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename);
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);

private:
std::vector<hwy::uint128_t> keys_;
Expand Down
6 changes: 3 additions & 3 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,11 @@ class Compressor {
}
}

void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
err);
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
}

Expand Down
5 changes: 3 additions & 2 deletions compression/compress.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

// IWYU pragma: begin_exports
#include "compression/blob_store.h"
#include "compression/io.h"
#include "compression/nuq.h"
#include "compression/sfp.h"
// IWYU pragma: end_exports
Expand Down Expand Up @@ -166,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {

class CacheLoader {
public:
explicit CacheLoader(const char* blob_filename) {
explicit CacheLoader(const Path& blob_filename) {
err_ = reader_.Open(blob_filename);
if (err_ != 0) {
fprintf(stderr,
"Cached compressed weights does not exist yet (code %d), "
"compressing weights and creating file: %s.\n",
err_, blob_filename);
err_, blob_filename.path.c_str());
}
}

Expand Down
Loading