-
Notifications
You must be signed in to change notification settings - Fork 142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow some known extra inputs in the model #1167
base: main
Are you sure you want to change the base?
Conversation
}}} {} | ||
|
||
void PresetExtraInputs::Add() { | ||
const auto input_names_vector = state_.model_.session_info_->GetInputNames(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be simpler to just do a 'session_info_->HasInput("num_logits_to_keep")' and add it? There's already a special case for onnx::Neg_ prefixes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the logic currently added is more general purpose since it checks for all inputs that have not been populated already. For the next preset extra input
, we would only need to add it to the registry.
Also, I don't want to add num_logits_to_keep if the user has supplied this input using params.set_input.
@@ -5,6 +5,46 @@ | |||
|
|||
namespace Generators { | |||
|
|||
PresetExtraInputs::PresetExtraInputs(State& state) | |||
: state_(state), | |||
registry_{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the registry always the same? Should this be a global that's created once vs per state?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think creating the registry is inexpensive and creating one over and over again per State should be ok for now.
I can make it global but would like to avoid making things global if possible.
This pull-request adds some extra inputs to a registry, so users need not be aware of such extra inputs.
In particular, this pull-request adds the
num_logits_to_keep
input to the registry.Models from huggingface/transformers that are exported using
torch.onnx.export
, can have an extra input callednum_logits_to_keep
that should be default be set to0
. Users should not need to know the input is required because it is an optimization in the torch model that does not impact the onnx model.Sometimes, the exporter will rename the input to
onnx::Neg_xyz
. This pull-request also handles such scenarios.