Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ static ModelConfig ConfigBaseGemmaV2() {
// Start from the shared non-SSM base configuration.
ModelConfig config = ConfigNoSSM();
// Logit soft-capping values applied in attention and at the final layer.
config.att_cap = 50.0f;
config.final_cap = 30.0f;
// End-of-sequence token ids recognized by IsEOS(). 1 is the tokenizer's
// <eos> id; 107 is presumably the <end_of_turn> id (this PR removes the
// string-literal "<end_of_turn>" special-casing in run.cc in favor of
// treating it as a secondary EOS) — TODO confirm against the Gemma
// tokenizer vocabulary.
config.eos_id = 1;
config.secondary_eos_id = 107;
return config;
}

Expand Down
4 changes: 2 additions & 2 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
}

const size_t num_queries = queries_prompt.size();
Expand Down Expand Up @@ -1615,4 +1615,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
} // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
26 changes: 5 additions & 21 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
bool end_of_turn_seen = false;

std::mt19937 gen;
InitGenerator(args, gen);
Expand Down Expand Up @@ -118,12 +117,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
// callback function invoked for each generated token.
auto stream_token = [&](int token, float) {
++abs_pos;
if (model.GetModelConfig().IsEOS(token)) {
if (app.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
return true;
}
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
Expand All @@ -132,6 +125,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cerr << "." << std::flush;
}
return true;
} else if (model.GetModelConfig().IsEOS(token)) {
if (app.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
Expand All @@ -141,13 +139,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cout << "\n\n";
}
}
if (token_text == "<end_of_turn>") {
// We don't want to show the <end_of_turn> token to the user.
// We also need to remember that we've seen it, so that we can rewind
// abs_pos appropriately. We expect EOS as the next token.
end_of_turn_seen = true;
return true;
}
std::cout << token_text << std::flush;
return true;
};
Expand Down Expand Up @@ -233,13 +224,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
HWY_ASSERT(abs_pos > 0);
abs_pos--;
}
if (end_of_turn_seen && abs_pos > 0) {
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
// more, because we will prepend it again to the prompt in
// WrapAndTokenize.
abs_pos--;
}
end_of_turn_seen = false;
}
}

Expand Down
Loading