
onnx.onnx_cpp2py_export.checker.ValidationError when call quantize_static() in onnxruntime==1.20.1 #23268

Open
dzk9528 opened this issue Jan 7, 2025 · 4 comments
Labels
quantization issues related to quantization

Comments

dzk9528 commented Jan 7, 2025

Describe the issue

When I try to quantize a model with a larger weight size in onnxruntime 1.20.1, the following error appears:

WARNING:root:Please consider to run pre-processing before quantization. Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md 
Traceback (most recent call last):
  File "/home/engineer/tetramem/ml-experimental/quantize/ort_bug.py", line 75, in <module>
    onnxruntime.quantization.quantize(float_model, "test_quantized_model.onnx", quant_config=quant_config)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quantize.py", line 878, in quantize
    quantize_static(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quantize.py", line 693, in quantize_static
    calibrator = create_calibrator(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 1186, in create_calibrator
    calibrator = MinMaxCalibrater(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 321, in __init__
    super().__init__(
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/calibrate.py", line 208, in __init__
    self.model = load_model_with_shape_infer(model_path)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnxruntime/quantization/quant_utils.py", line 983, in load_model_with_shape_infer
    model = onnx.load(inferred_model_path.as_posix())
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/__init__.py", line 216, in load_model
    load_external_data_for_model(model, base_dir)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/external_data_helper.py", line 64, in load_external_data_for_model
    load_external_data_for_tensor(tensor, base_dir)
  File "/opt/tetramem/python_env/lib/python3.10/site-packages/onnx/external_data_helper.py", line 42, in load_external_data_for_tensor
    external_data_file_path = c_checker._resolve_external_data_location(  # type: ignore[attr-defined]
onnx.onnx_cpp2py_export.checker.ValidationError: Data of TensorProto ( tensor name: weights) should be stored in /tmp/ort.quant.i4tq_bcf/5f6cc3dc-cc94-11ef-a03a-c87f54034d33, but it doesn't exist or is not accessible.

To reproduce

  • Python 3.10
  • onnx==1.16.2
  • Test code
import typing

import numpy as np
import onnx
import onnxruntime
import onnxruntime.quantization


class RandomCalibrationDataGenerator(onnxruntime.quantization.CalibrationDataReader):
    """Generates pseudo-random calibration data for testing."""

    def __init__(self, seed: int, name: str, shape: typing.Sequence[int], length: int):
        self.rng = np.random.default_rng(seed=seed)
        self.name = name
        self.shape = shape
        self.length = length
        self.counter = 0

    def get_next(self):
        """See base class."""
        if self.counter >= self.length:
            return None

        array = self.rng.normal(size=self.shape)
        datum = {self.name: array.astype(np.float32)}

        self.counter += 1
        return datum


rng = np.random.default_rng(seed=54321)
calibration_data_reader = RandomCalibrationDataGenerator(
    seed=54322,
    name="input",
    shape=[1, 128],
    length=5,
)

weight_array = rng.normal(size=(128, 256))
weight_array = weight_array.astype(np.float32)
weight_proto = onnx.numpy_helper.from_array(weight_array, name="weights")

node = onnx.helper.make_node(
    op_type="MatMul",
    inputs=["input", "weights"],
    outputs=["output"],
    name="dense",
)

input_info = onnx.helper.make_tensor_value_info(
    "input", onnx.TensorProto.FLOAT, ["batch", 128]
)
output_info = onnx.helper.make_tensor_value_info(
    "output", onnx.TensorProto.FLOAT, ["batch", 256]
)

graph = onnx.helper.make_graph(
    nodes=[node],
    initializer=[weight_proto],
    inputs=[input_info],
    outputs=[output_info],
    name="matmul_graph",
)
float_model = onnx.helper.make_model(graph)


quant_config = onnxruntime.quantization.StaticQuantConfig(
    calibration_data_reader,
    quant_format=onnxruntime.quantization.QuantFormat.QDQ,
    weight_type=onnxruntime.quantization.QuantType.QInt8,
    per_channel=True,
    extra_options=None,
)

onnxruntime.quantization.quantize(float_model, "test_quantized_model.onnx", quant_config=quant_config)

Urgency

This is an urgent request; it directly affects our overall model quantization software product development.

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

dzk9528 changed the title from "onnx.onnx_cpp2py_export.checker.ValidationError when call quantize_static() with model.Proto in onnxruntime==1.20.1" to "onnx.onnx_cpp2py_export.checker.ValidationError when call quantize_static() in onnxruntime==1.20.1" on Jan 7, 2025
github-actions bot added the "quantization" label on Jan 7, 2025
yuslepukhin (Member) commented:

The model refers to a weight that is expected to be stored in a file, but the file is not found. Whatever was used to create the model externalized the weight and placed it in a file that is usually expected to sit next to the model.
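
For context, here is a minimal sketch of how external data normally works (reusing float_model from the repro above; the directory and file names are illustrative, not anything ORT mandates):

import os

import onnx

# Illustrative location for the model and its side file.
os.makedirs("model_dir", exist_ok=True)

# save_as_external_data writes the raw tensor bytes to a side file and
# records only its relative location in the TensorProto.
onnx.save_model(
    float_model,
    "model_dir/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.bin",
    size_threshold=0,  # externalize even small tensors for this demo
)

# onnx.load resolves "weights.bin" relative to the model file's directory,
# so the side file must still exist next to the model when it is read back.
reloaded = onnx.load("model_dir/model.onnx")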

dzk9528 (Author) commented Jan 7, 2025:

> The model refers to a weight that is expected to be stored in a file, but the file is not found. Whatever was used to create the model externalized the weight and placed it in a file that is usually expected to sit next to the model.

@yuslepukhin I think you are right. But somehow the quantizer cannot properly load the model with external data and fails even in this simple case; maybe the root cause is in the save-and-reload step?
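
If the in-memory ModelProto input is what triggers this, one possible workaround (a sketch only, not a confirmed fix) is to serialize the model yourself and pass a file path to quantize() instead of the ModelProto, reusing float_model and quant_config from the repro script above:

import onnx
import onnxruntime.quantization

# Hypothetical workaround: hand quantize() a path rather than an in-memory
# ModelProto; the file name here is illustrative.
onnx.save(float_model, "float_model.onnx")
onnxruntime.quantization.quantize(
    "float_model.onnx", "test_quantized_model.onnx", quant_config=quant_config
)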

mcollinswisc (Contributor) commented:

> The model refers to a weight that is expected to be stored in a file, but the file is not found. Whatever was used to create the model externalized the weight and placed it in a file that is usually expected to sit next to the model.

Please look at the test/repro code attached by @dzk9528. The ModelProto that is passed to onnxruntime.quantization.quantize does not have external weights.

The weights are externalized by the function save_and_reload_model_with_shape_infer defined here:
https://github.com/microsoft/onnxruntime/blob/da35cceac9cc30cf0e40c632315b4b500395111f/onnxruntime/python/tools/quantization/quant_utils.py#L989C5-L989C43
and called within the ONNXRuntime quantization here:

save_and_reload_model_with_shape_infer(model_input)

If you look at what is happening within the ONNXRuntime quantization tool:

  1. save_and_reload_model_with_shape_infer modifies its input ModelProto to use the external weights
  2. The temp directory where the external weights are stored is deleted at the end of the with statement within save_and_reload_model_with_shape_infer

So yes, the weights are externalized and then deleted. But this does not happen during model creation in user code; the broken external-weight path is introduced inside onnxruntime.quantization.quantize_static.
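
The effect can be illustrated outside the quantizer with a few lines (a sketch of the mechanism described above, not the actual ORT code; all names here are made up): externalize a model's weights into a temporary directory, let the directory be deleted, then try to materialize the tensor data.

import os
import tempfile

import numpy as np
import onnx
from onnx.external_data_helper import load_external_data_for_model

# A tiny model with a single initializer.
w = onnx.numpy_helper.from_array(np.zeros((4, 4), np.float32), name="w")
out = onnx.helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, [4, 4])
graph = onnx.helper.make_graph(
    [onnx.helper.make_node("Identity", ["w"], ["out"])],
    "g", inputs=[], outputs=[out], initializer=[w],
)
model = onnx.helper.make_model(graph)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.onnx")
    # Externalize the weights into the temp dir (step 1 above).
    onnx.save_model(model, path, save_as_external_data=True, size_threshold=0)
    # Keep only the external reference in the proto, not the bytes.
    reloaded = onnx.load(path, load_external_data=False)
# The with block has deleted tmp_dir (step 2), so resolving the reference
# now fails, mirroring the ValidationError in the traceback above.
load_external_data_for_model(reloaded, base_dir=tmp_dir)  # raises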

yuslepukhin (Member) commented:

Cc: @xiaoyu-work
