Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow some known extra inputs in the model #1167

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

baijumeswani
Copy link
Collaborator

This pull-request adds some extra inputs to a registry, so users need not be aware of such extra inputs.

In particular, this pull-request adds the num_logits_to_keep input to the registry.

Models from huggingface/transformers that are exported using torch.onnx.export, can have an extra input called num_logits_to_keep that should be default be set to 0. Users should not need to know the input is required because it is an optimization in the torch model that does not impact the onnx model.

Sometimes, the exporter will rename the input to onnx::Neg_xyz. This pull-request also handles such scenarios.

src/models/extra_inputs.cpp Outdated Show resolved Hide resolved
src/models/extra_inputs.cpp Outdated Show resolved Hide resolved
}}} {}

void PresetExtraInputs::Add() {
const auto input_names_vector = state_.model_.session_info_->GetInputNames();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be simpler to just do a 'session_info_->HasInput("num_logits_to_keep")' and add it? There's already a special case for onnx::Neg_ prefixes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic currently added is more general purpose since it checks for all inputs that have not been populated already. For the next preset extra input, we would only need to add it to the registry.

Also, I don't want to add num_logits_to_keep if the user has supplied this input using params.set_input.

@@ -5,6 +5,46 @@

namespace Generators {

PresetExtraInputs::PresetExtraInputs(State& state)
: state_(state),
registry_{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the registry always the same? Should this be a global that's created once vs per state?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think creating the registry is inexpensive and creating one over and over again per State should be ok for now.

I can make it global but would like to avoid making things global if possible.

@@ -8,6 +8,7 @@
import sysconfig
from pathlib import Path
import shutil
import tempfile

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'tempfile' is not used.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants