can't infer with an "exclude_lm_head" model #1166

Open
busishengui opened this issue Jan 6, 2025 · 1 comment
@busishengui

Describe the bug
Can't infer with an "exclude_lm_head" model.

To Reproduce

python3 builder.py -m llama3.2-3b -o row_llama3.2-3b-onnx-int4 -p int4 -e cpu --extra_options int4_block_size=128 int4_accuracy_level=4 int4_op_types_to_quantize=MatMul/Gather exclude_lm_head=1
auto model = OgaModel::Create("llama3.2-3b-onnx-int4");
auto tokenizer = OgaTokenizer::Create(*model);
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 128);
std::string query = "tell me hello";
auto seq = OgaSequences::Create();
tokenizer->Encode(query.c_str(), *seq);
params->SetInputSequences(*seq);
auto generator = OgaGenerator::Create(*model, *params);

Then you will see the bug:

terminate called after throwing an instance of 'std::runtime_error'
  what():  Model output was not found: logits
Aborted

onnxruntime-genai version: 0.5.2
OS: Linux

@kunal-vaishnavi (Contributor)

If you are using exclude_lm_head, the ONNX model's output will be the last hidden states instead of logits. When running this model, you then have to convert the hidden states to logits yourself, since the generation loop in ONNX Runtime GenAI operates on logits. Support for this approach is currently limited since there have not been many requests for it.
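
A minimal sketch of that conversion, assuming the last token's hidden state and the original model's lm_head weight matrix are available as plain float buffers (the function name, shapes, and the idea of exporting the lm_head weights separately are illustrative assumptions, not part of the ONNX Runtime GenAI API):

#include <cstddef>
#include <vector>

// hidden:   last hidden state of the final token, [hidden_size]
// lm_head:  lm_head weights in row-major order, [vocab_size x hidden_size]
// returns:  logits over the vocabulary, [vocab_size]
std::vector<float> HiddenToLogits(const std::vector<float>& hidden,
                                  const std::vector<float>& lm_head,
                                  std::size_t hidden_size,
                                  std::size_t vocab_size) {
  std::vector<float> logits(vocab_size, 0.0f);
  for (std::size_t v = 0; v < vocab_size; ++v) {
    const float* row = lm_head.data() + v * hidden_size;
    float sum = 0.0f;
    for (std::size_t h = 0; h < hidden_size; ++h) {
      sum += row[h] * hidden[h];  // dot product of the hidden state with one vocab row
    }
    logits[v] = sum;
  }
  return logits;
}

The resulting logits would then be argmax'd or sampled to pick the next token, which is the step the built-in generation loop normally performs for you.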

An alternative approach is to have both the last hidden states and the logits as outputs in the ONNX model. You can achieve that by using include_hidden_states in the extra_options (see example usage here).
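
For example, the original builder.py command could be adjusted along these lines (the include_hidden_states=true value format is an assumption; check the linked example for the exact syntax):

python3 builder.py -m llama3.2-3b -o row_llama3.2-3b-onnx-int4 -p int4 -e cpu --extra_options int4_block_size=128 int4_accuracy_level=4 int4_op_types_to_quantize=MatMul/Gather include_hidden_states=true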
