Quantize Speech Recognition Models with OpenVINO™ Post-Training Optimization Tool ¶
This tutorial is also available as a Jupyter notebook that can be cloned directly from GitHub. See the installation guide for instructions to run this tutorial locally on Windows, Linux or macOS.
This tutorial demonstrates how to apply INT8 quantization to the Data2Vec
speech recognition model, using the Post-Training Optimization Tool API
(POT API), which is part of the OpenVINO Toolkit. This notebook uses a
data2vec-audio-base-960h PyTorch model fine-tuned on the LibriSpeech ASR
corpus. The tutorial is designed to be extendable to custom models and
datasets. It consists of the following steps:
- Download and prepare model. 
- Define data loading and accuracy validation functionality. 
- Prepare the model for quantization. 
- Run optimization pipeline. 
- Compare performance of the original and quantized models. 
Download and prepare model¶
data2vec is a framework for self-supervised representation learning for images, speech, and text as described in data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language (Baevski et al., 2022). The algorithm uses the same learning mechanism for different modalities.
 
(Figure: pre-trained pipeline)
In this tutorial, we use the data2vec-audio-base-960h model, which was
fine-tuned on 960 hours of audio from the LibriSpeech Automatic Speech
Recognition corpus and is distributed as part of Hugging Face Transformers.
Obtain PyTorch model representation¶
To instantiate the PyTorch model class, use the
Data2VecAudioForCTC.from_pretrained method with the model ID for
downloading from the Hugging Face Hub. Model weights and configuration
files are downloaded automatically on first use. Keep in mind that
downloading the files can take several minutes, depending on your
internet connection.
Additionally, we create a processor object, which is responsible for the model-specific pre- and post-processing steps.
!pip install -q soundfile librosa
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
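The processor turns a raw 16 kHz waveform into the input_values tensor that the model consumes. As a minimal sketch, assuming the processor's do_normalize option is enabled for this checkpoint, the core of that step is a per-utterance zero-mean, unit-variance normalization followed by adding a batch dimension. The function name zero_mean_unit_var_norm and the synthetic test tone are illustrative only; the tutorial itself always uses the real processor.

```python
import numpy as np


def zero_mean_unit_var_norm(waveform: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Normalize a 1-D waveform to zero mean and unit variance.

    Assumption: mirrors the per-utterance normalization that a Wav2Vec2-style
    feature extractor applies when `do_normalize` is enabled.
    """
    waveform = waveform.astype(np.float32)
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)


# One second of a synthetic 440 Hz tone at 16 kHz, with a DC offset.
sr = 16000
t = np.arange(sr) / sr
signal = 0.3 * np.sin(2 * np.pi * 440 * t) + 0.1

normalized = zero_mean_unit_var_norm(signal)
# Add a batch dimension to get the [batch, samples] layout the model expects.
input_values = normalized[np.newaxis, :]
print(input_values.shape)  # (1, 16000)
```

Normalizing each utterance independently removes DC offset and loudness differences between recordings before they reach the feature encoder.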
Convert model to OpenVINO Intermediate Representation¶
from pathlib import Path
# Set model directory
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)
from openvino.tools import mo
from openvino.runtime import serialize, Core
import torch
core = Core()
BATCH_SIZE = 1
MAX_SEQ_LENGTH = 30480
def export_model_to_onnx(model, path):
    # switch model to evaluation mode
    model.eval()
    # disallow gradient propagation for reducing memory during export
    with torch.no_grad():
        # define dummy input with specific shape
        default_input = torch.zeros([BATCH_SIZE, MAX_SEQ_LENGTH], dtype=torch.float)
        inputs = {
            "inputs": default_input
        }
        # define names for dynamic dimensions
        symbolic_names = {0: "batch_size", 1: "sequence_len"}
        # export model
        torch.onnx.export(
            model,
            (inputs["inputs"]),
            path,
            opset_version=11,
            input_names=["inputs"],
            output_names=["logits"],
            dynamic_axes={
                "inputs": symbolic_names,
                "logits": symbolic_names,
            },
        )
        print("ONNX model saved to {}".format(path))
onnx_model_path = MODEL_DIR / "data2vec-audo-base.onnx"
ir_model_path = onnx_model_path.with_suffix('.xml')
if not ir_model_path.exists():
    if not onnx_model_path.exists():
        export_model_to_onnx(model, onnx_model_path)
    ov_model = mo.convert_model(onnx_model_path, compress_to_fp16=True)
    serialize(ov_model, str(ir_model_path))
    print("IR model saved to {}".format(ir_model_path))
else:
    print("Read IR model from {}".format(ir_model_path))
    ov_model = core.read_model(ir_model_path)
ONNX model saved to model/data2vec-audo-base.onnx
IR model saved to model/data2vec-audo-base.xml
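The exported model maps a [batch, samples] waveform to [batch, frames, 32] logits, and the number of frames is much smaller than the number of samples because the convolutional feature encoder downsamples the signal. The sketch below computes the frame count for a given input length, assuming the default Data2VecAudio/Wav2Vec2 feature-encoder kernel sizes and strides from Hugging Face Transformers (these values are an assumption, not read from the exported model). For the 30480-sample input used in this tutorial it gives 95 frames, which agrees with the [1,95,32] logits shape reported by benchmark_app.

```python
def conv_output_length(num_samples: int, kernels, strides) -> int:
    """Sequence length after a stack of unpadded, undilated 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        # Standard output-length formula for a 1-D convolution without padding.
        length = (length - kernel) // stride + 1
    return length


# Assumed default feature-encoder configuration (kernel sizes and strides).
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)

print(conv_output_length(30480, CONV_KERNEL, CONV_STRIDE))  # 95
print(conv_output_length(16000, CONV_KERNEL, CONV_STRIDE))  # 49
```

The product of the strides is 5 x 2^6 = 320 samples, so each logit frame covers roughly 20 ms of 16 kHz audio.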
Prepare inference data¶
For demonstration purposes, we use a short dummy version of the
LibriSpeech dataset, patrickvonplaten/librispeech_asr_dummy, to
speed up model evaluation. As a result, the measured accuracy can differ
from the one reported in the paper. To reproduce the original accuracy,
use the librispeech_asr dataset.
!pip install -q datasets "torchmetrics>=0.11.0"
from datasets import load_dataset
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# define preprocessing function for converting audio to input values for model
def map_to_input(batch):
    preprocessed_signal = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=batch['audio']['sampling_rate'])
    input_values = preprocessed_signal.input_values
    batch['input_values'] = input_values
    return batch
# apply preprocessing function to dataset and remove audio column, to save memory as we do not need it anymore
dataset = ds.map(map_to_input, batched=False, remove_columns=["audio"])
test_sample = ds[0]["audio"]
Check model inference result¶
The code below is used for running model inference on a single sample from the dataset. It contains the following steps:
- Get the input_values tensor as model input. 
- Run model inference and obtain logits. 
- Find the token IDs with the highest logits, using argmax. 
- Decode predicted token ids, using processor. 
The ov_infer function performs the same steps for the OpenVINO model.
import numpy as np
# inference function for pytorch
def torch_infer(model, sample):
    # gradients are not needed for inference; disabling them saves memory
    with torch.no_grad():
        logits = model(torch.Tensor(sample['input_values'])).logits
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription
# inference function for openvino
def ov_infer(model, sample):
    output = model.output(0)
    logits = model(np.array(sample['input_values']))[output]
    predicted_ids = np.argmax(logits, axis=-1)
    transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
    return transcription
pt_transcription = torch_infer(model, dataset[0])
compiled_model = core.compile_model(ov_model)
ov_transcription = ov_infer(compiled_model, dataset[0])
import IPython.display as ipd
print(f"[Reference]:     {dataset[0]['text']}")
print(f"[PyTorch]:       {pt_transcription[0]}")
print(f"[OpenVINO FP16]: {ov_transcription[0]}")
ipd.Audio(test_sample["array"], rate=16000)
[Reference]:     A MAN SAID TO THE UNIVERSE SIR I EXIST
[PyTorch]:       A MAN SAID TO THE UNIVERSE SIR I EXIST
[OpenVINO FP16]: A MAN SAID TO THE UNIVERSE SIR I EXIST
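Both inference functions end with greedy CTC decoding: take the most likely token for each frame, collapse consecutive repeats, drop the blank token, and turn the word delimiter into spaces. processor.batch_decode does this for the real 32-token vocabulary; the sketch below reproduces the idea with a hypothetical six-token vocabulary, assuming (as in Wav2Vec2-style vocabularies) that index 0 is the blank/pad token and "|" is the word delimiter.

```python
from itertools import groupby

import numpy as np

# Hypothetical toy vocabulary: index 0 is assumed to be the blank/pad token,
# and "|" is assumed to be the word delimiter.
VOCAB = ["<pad>", "|", "E", "H", "I", "S"]
BLANK_ID = 0


def greedy_ctc_decode(logits: np.ndarray) -> str:
    """Greedy CTC decoding of a [frames, vocab_size] logit matrix."""
    # 1. Most likely token for every frame.
    ids = np.argmax(logits, axis=-1)
    # 2. Collapse runs of the same token, then 3. drop blank tokens.
    kept = [int(token) for token, _ in groupby(ids) if token != BLANK_ID]
    # 4. Map ids to characters; word delimiters become single spaces.
    text = "".join(VOCAB[token] for token in kept)
    return " ".join(text.replace("|", " ").split())


# Best token per frame for a fake 11-frame utterance: H H I | | S <pad> E E <pad> E
frame_ids = [3, 3, 4, 1, 1, 5, 0, 2, 2, 0, 2]
logits = np.eye(len(VOCAB))[frame_ids]  # one-hot rows stand in for real logits
print(greedy_ctc_decode(logits))  # HI SEE
```

The blank frame between the second and third E is what keeps the double letter: without it, the repeated E frames would collapse into a single E.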
Validate model accuracy on dataset¶
Word Error Rate (WER) can be used to evaluate model accuracy. WER is the ratio of word-level errors (substitutions, deletions, and insertions) in a transcript to the total number of words in the reference transcript. A lower WER means better speech recognition accuracy.
For WER calculation, we use the torchmetrics library.
from torchmetrics import WordErrorRate
from tqdm.notebook import tqdm
def compute_wer(dataset, model, infer_fn):
    wer = WordErrorRate()
    for sample in tqdm(dataset):
        # run infer function on sample
        transcription = infer_fn(model, sample)
        # update metric on sample result
        wer.update(transcription, [sample['text']])
    # finalize metric calculation
    result = wer.compute()
    return result
pt_result = compute_wer(dataset, model, torch_infer)
ov_result = compute_wer(dataset, compiled_model, ov_infer)
print(f'[PyTorch]   Word Error Rate: {pt_result:.4f}')
print(f'[OpenVino]  Word Error Rate: {ov_result:.4f}')
[PyTorch]   Word Error Rate: 0.0383
[OpenVino]  Word Error Rate: 0.0383
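To make the metric concrete, the sketch below computes WER from scratch as (S + D + I) / N, where S, D, and I are the word substitutions, deletions, and insertions in the minimum-cost alignment between hypothesis and reference, and N is the number of reference words. It is an illustrative implementation, not the torchmetrics source, and the hypothesis string is made up for the example.

```python
def word_errors(reference: str, hypothesis: str) -> int:
    """Word-level Levenshtein distance: substitutions + deletions + insertions."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev holds the DP row for the previous reference prefix.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion of ref_word
                curr[j - 1] + 1,     # insertion of hyp_word
                prev[j - 1] + cost,  # substitution (or match when cost == 0)
            )
        prev = curr
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    return word_errors(reference, hypothesis) / len(reference.split())


ref = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
hyp = "A MAN SAYS TO UNIVERSE SIR I EXIST"  # one substitution, one deletion
print(f"{wer(ref, hyp):.4f}")  # 0.2222
```

Here two errors over nine reference words give a WER of 2/9.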
Prepare quantization pipeline¶
Post-Training Optimization Tool is designed to accelerate the inference of deep learning models by converting them into a more hardware-friendly representation, using methods that do not require retraining, such as post-training quantization. For more details about the low-precision flow in OpenVINO™, refer to the Low Precision Optimization Guide.
The Python POT API provides simple interfaces for implementing custom model inference, with data loading and pre-processing on an arbitrary dataset, and for implementing custom accuracy metrics, so that POT optimization algorithms can be applied. The Python POT API is represented by the Pipeline class, which creates and configures the optimization pipeline and applies it to the model. The Pipeline class depends on the following model-specific interfaces, which should be implemented for the custom DL model:
- Engine is responsible for model inference and provides statistical data and accuracy metrics for the model.
- DataLoader is responsible for dataset loading, including data pre-processing.
- Metric is responsible for calculating the accuracy metric for the model.
The diagram below shows relationships between the classes:
 
(Figure: POT pipeline)
Define DataLoader class¶
Define a DataLoader based on the POT API: it will be used to collect
statistics for quantization and to run model evaluation. The Data2Vec
model accepts a raw waveform of the speech signal as input and produces
vocabulary class estimations (logits) as output. The dataset prepared
above for accuracy measurement also serves as the data source for
quantization. The DataLoader class encapsulates the logic for iterating
over dataset samples and returns the input data and label for a given
index through the __getitem__ method.
from openvino.tools.pot import Metric, DataLoader, IEEngine, load_model, save_model, compress_model_weights, create_pipeline
class LibriSpeechDataLoader(DataLoader):
    # Required methods
    def __init__(self, dataset, sample_limit=None):
        """Constructor
        :param dataset: preprocessed dataset with 'input_values' and 'text' fields
        :param sample_limit: optional maximum number of samples to iterate over
        """
        super().__init__({})
        self._ds = dataset
        # honor the requested limit (None means use the whole dataset)
        self.sample_limit = sample_limit
    def __len__(self):
        """Returns size of the dataset"""
        if self.sample_limit is None:
            return len(self._ds)
        return min(self.sample_limit, len(self._ds))
    def __getitem__(self, index):
        """
        Returns data and annotation at the specified index.
        Possible formats:
        data, annotation
        data, annotation, metadata
        """
        if self.sample_limit is not None and index >= self.sample_limit:
            raise StopIteration
        sample = self._ds[index]
        inputs = {'inputs': np.array(sample['input_values'])}
        label = [sample['text']]
        return inputs, label
Define Evaluation Metric class¶
In this step, the Metric interface for WER metric is implemented. To
make the metric compatible with running inside POT Pipeline, we should
inherit it from openvino.tools.pot.Metric class and override
following properties and methods: * value - returns the accuracy
metric value for the last model output. * avg_value - returns the
average accuracy metric value for all model outputs. * attributes -
returns a dictionary of metric attributes: direction - metric
growing direction (higher-better or higher-worse), type -
type of metric. * update(output, annotation) - calculates and
updates the accuracy metric value using last model output and
annotation. * reset() - resets collected accuracy metric.
class WERMetric(Metric):
    def __init__(self):
        super().__init__()
        self._name = "WER"
        # create the torchmetrics state so update() works before the first reset() call
        self.reset()
    def reset(self):
        """
        Resets collected matches
        """
        self._wer = WordErrorRate()
        self._last_result = None
    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {"direction": "higher-worse", "type": "WER"}}
    @property
    def value(self):
        """Returns accuracy metric value for the last model output."""
        return {self._name: self._last_result if self._last_result is not None else self._wer.compute().item()}
    @property
    def avg_value(self):
        """Returns accuracy metric value for all model outputs."""
        return {self._name: self._wer.compute().item()}
    def update(self, output, target):
        """
        Updates prediction matches.
        :param output: model output
        :param target: annotations
        """
        res = output[0]
        predicted_ids = np.argmax(res, axis=-1)
        predicted_transcription = processor.batch_decode(torch.from_numpy(predicted_ids))
        res = []
        for pred, gt in zip(predicted_transcription, target):
            res.append(self._wer.forward([pred], gt).item())
        self._last_result = res
        return res
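WERMetric.avg_value delegates to torchmetrics' WordErrorRate.compute(), which accumulates total word errors and total reference words across all updates. That corpus-level WER is generally not the mean of per-sample WERs. The toy accumulator below (a hypothetical CorpusWER class, independent of the POT Metric base class) mirrors the reset/update/value/avg_value split to show the difference on two made-up samples.

```python
def word_errors(reference: str, hypothesis: str) -> int:
    """Word-level Levenshtein distance (substitutions + deletions + insertions)."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
        prev = curr
    return prev[-1]


class CorpusWER:
    """Hypothetical toy accumulator, not the POT Metric base class."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._errors = 0   # total word errors over all updates
        self._words = 0    # total reference words over all updates
        self._last = None  # WER of the most recent sample

    def update(self, prediction: str, reference: str):
        errors = word_errors(reference, prediction)
        words = len(reference.split())
        self._errors += errors
        self._words += words
        self._last = errors / words

    @property
    def value(self):
        return self._last

    @property
    def avg_value(self):
        return self._errors / self._words


samples = [
    ("HELLO WORD", "HELLO WORLD"),  # 1 error out of 2 words
    ("A MAN SAID TO THE UNIVERSE SIR I EXIST",
     "A MAN SAID TO THE UNIVERSE SIR I EXIST"),  # 0 errors out of 9 words
]
metric = CorpusWER()
per_sample = []
for prediction, reference in samples:
    metric.update(prediction, reference)
    per_sample.append(metric.value)

mean_of_samples = sum(per_sample) / len(per_sample)
print(f"mean of per-sample WER: {mean_of_samples:.4f}")  # 0.2500
print(f"corpus-level WER:       {metric.avg_value:.4f}")  # 0.0909
```

Pooling weights every reference word equally, so a short utterance with one error does not dominate the result the way it does in a plain average of per-sample scores.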
Define quantization configuration and optimization pipeline¶
The code below defines a configuration for the quantization pipeline and
creates the pipeline. To keep the example minimal, the built-in IEEngine
implementation of the Engine interface from the POT API is used for model
inference. We use the DefaultQuantization algorithm with the performance
preset and a custom range estimator for activations. For information
about configuration parameters, refer to the POT documentation.
The model architecture is transformer-based, so
model_type: transformer should be selected. For better accuracy, some
layers should be kept in floating-point representation using the
ignored parameter. The layers to ignore can be selected with the
AccuracyAwareQuantization algorithm, which aims to find the layers that
have the most significant impact on accuracy drop and reverts them back
to floating-point precision. This process can be time-consuming, so the
experiment is kept out of this tutorial, and its result is reused here
with the DefaultQuantization algorithm.
NOTE: Increasing stat_subset_size gives more precise statistics at the
cost of a longer processing time; a value of 300 or more is suggested.
The dummy dataset used here contains only 73 samples, so no more than 73
samples can contribute statistics, regardless of this setting.
model_config = {"model_name": "data2vec_base", "model": ir_model_path, "weights": ir_model_path.with_suffix(".bin")}
engine_config = {"device": "CPU"}
algorithms = [
    {
        "name": "DefaultQuantization",
        "params": {
            "target_device": "ANY",
            "model_type": "transformer",
            "preset": "performance",
            "stat_subset_size": 300,
            "activations": {
                "range_estimator": {
                    "min": {
                        "aggregator": "min",
                        "type": "min"
                    },
                    "max": {
                        "aggregator": "mean",
                        "type": "quantile",
                        "outlier_prob": 0.0001
                    }
                }
            },
            "ignored": {
                "scope": [
                    "/data2vec_audio/encoder/layers.3/feed_forward/intermediate_dense/MatMul",
                    "/data2vec_audio/feature_extractor/conv_layers.2/conv/Conv",
                    "/data2vec_audio/encoder/layers.3/Add_1",
                    "/data2vec_audio/encoder/layers.2/feed_forward/intermediate_dense/MatMul",
                    "/data2vec_audio/feature_extractor/conv_layers.0/conv/Conv",
                    "/data2vec_audio/encoder/layers.4/Add_1",
                    "/data2vec_audio/encoder/layers.4/feed_forward/intermediate_dense/MatMul",
                    "/data2vec_audio/encoder/layers.4/final_layer_norm/Div",
                    "/data2vec_audio/encoder/layers.4/feed_forward/output_dense/MatMul",
                    "/data2vec_audio/encoder/layers.8/attention/MatMul_1",
                    "/data2vec_audio/feature_extractor/conv_layers.1/conv/Conv",
                    "/data2vec_audio/encoder/layers.2/Add_1",
                    "/data2vec_audio/feature_extractor/conv_layers.0/layer_norm/Div",
                    "/data2vec_audio/encoder/layers.1/feed_forward/intermediate_dense/MatMul",
                    "/data2vec_audio/encoder/layers.1/Add_1",
                    "/data2vec_audio/feature_extractor/conv_layers.3/layer_norm/Div"
                ]
            }
        }
    }
]
# Step 1: Load the model.
model = load_model(model_config=model_config)
# Step 2: Initialize the data loader.
data_loader = LibriSpeechDataLoader(dataset)
# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
metric = WERMetric()
# Step 4: Initialize the engine for metric calculation and statistics collection.
engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)
# Step 5: Create a pipeline of compression algorithms.
pipeline = create_pipeline(algo_config=algorithms, engine=engine)
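The activations range estimator in the configuration above works per calibration sample and then aggregates: the minimum uses the per-sample minimum aggregated with min, and the maximum uses a per-sample quantile aggregated with mean, where outlier_prob discards the most extreme values. The NumPy sketch below illustrates our reading of those settings on synthetic data and then snaps values to a 256-level grid. It is not POT's implementation: the activations are random, the quantile level 1 - outlier_prob is an assumption, and the grid is asymmetric for simplicity, whereas the performance preset uses symmetric quantization.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
# Synthetic stand-in for one layer's activations on 8 calibration samples.
activations = [rng.standard_normal(10_000) for _ in range(8)]

OUTLIER_PROB = 0.0001

# "min": {"type": "min", "aggregator": "min"}
# per-sample minimum, then the minimum over samples.
range_min = min(float(a.min()) for a in activations)

# "max": {"type": "quantile", "aggregator": "mean", "outlier_prob": 0.0001}
# per-sample upper quantile (assumed level: 1 - outlier_prob), then the mean over samples.
range_max = float(np.mean([np.quantile(a, 1 - OUTLIER_PROB) for a in activations]))

LEVELS = 256  # 8-bit grid
step = (range_max - range_min) / (LEVELS - 1)


def fake_quantize(x: np.ndarray) -> np.ndarray:
    """Clamp x to [range_min, range_max] and snap it to the nearest grid point."""
    clipped = np.clip(x, range_min, range_max)
    return np.round((clipped - range_min) / step) * step + range_min


quantized = fake_quantize(activations[0])
clipped_share = float(np.mean(activations[0] > range_max))
print(f"range: [{range_min:.3f}, {range_max:.3f}], step: {step:.4f}")
print(f"share of values clipped at the top: {clipped_share:.5f}")
```

Using a high quantile instead of the absolute maximum trades a tiny amount of clipping for a finer grid over the values that actually occur.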
Run model quantization¶
Now that all parts of the compression pipeline are collected, we can start quantization.
NOTE: The quantization process is time- and memory-consuming. It may take several minutes, depending on your hardware configuration.
import time
# Step 6: Run compression pipeline
print(f"Quantizing model with {algorithms[0]['params']['preset']} preset and {algorithms[0]['name']}")
start_time = time.perf_counter()
compressed_model = pipeline.run(model=model)
end_time = time.perf_counter()
print(f"Quantization finished in {end_time - start_time:.2f} seconds")
Quantizing model with performance preset and DefaultQuantization
Quantization finished in 114.44 seconds
After quantization is finished, the compressed model representation can
be saved using the save_model function.
# Step 7 (Optional): Compress model weights to quantized precision
#                    in order to reduce the size of the final .bin file.
compress_model_weights(model=compressed_model)
# Step 8: Save the compressed model to the desired path.
compressed_model_paths = save_model(model=compressed_model, save_path=MODEL_DIR, model_name="quantized_data2vec_base")
compressed_model_path = compressed_model_paths[0]["model"]
Check INT8 model inference result¶
The INT8 model is used in the same way as the original one: read it with
the core.read_model method and load it on the device with
core.compile_model. After that, the same ov_infer function can be reused
to get the inference result on the test sample.
ov_int8_model = core.read_model(compressed_model_path)
int8_compiled_model = core.compile_model(ov_int8_model)
transcription = ov_infer(int8_compiled_model, dataset[0])
print(f"[Reference]:     {dataset[0]['text']}")
print(f"[OpenVINO INT8]: {transcription[0]}")
ipd.Audio(test_sample["array"], rate=16000)
[Reference]:     A MAN SAID TO THE UNIVERSE SIR I EXIST
[OpenVINO INT8]: A MAN SAID TO THE UNIVERSE SIR I EXIST
Compare Performance of the Original and Quantized Models¶
Benchmark Tool is used to measure the inference performance of the FP16
and INT8 models.
NOTE: For more accurate performance measurements, it is recommended to run
benchmark_app in a terminal/command prompt after closing other applications. Run benchmark_app -m model.xml -d CPU to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run benchmark_app --help to see an overview of all command-line options.
# Inference FP16 model (OpenVINO IR)
! benchmark_app -m $ir_model_path -shape "[1,30480]" -d CPU -api async -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 195.95 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 26.72 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1044.55 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: torch_jit
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   NUM_STREAMS: 6
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 75.00 ms
[Step 11/11] Dumping statistics report
[ INFO ] Count:            636 iterations
[ INFO ] Duration:         15243.00 ms
[ INFO ] Latency:
[ INFO ]    Median:        143.32 ms
[ INFO ]    Average:       143.43 ms
[ INFO ]    Min:           104.70 ms
[ INFO ]    Max:           169.35 ms
[ INFO ] Throughput:   41.72 FPS
# Inference INT8 model (OpenVINO IR)
! benchmark_app -m $compressed_model_path -shape "[1,30480]" -d CPU -api async -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] CPU
[ INFO ] Build ................................. 2022.3.0-9052-9752fafe8eb-releases/2022/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 158.45 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [?,?]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [?,?,32]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[ INFO ] Reshaping model: 'inputs': [1,30480]
[ INFO ] Reshape model took 28.84 ms
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     inputs (node: inputs) : f32 / [...] / [1,30480]
[ INFO ] Model outputs:
[ INFO ]     logits (node: logits) : f32 / [...] / [1,95,32]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 905.36 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: torch_jit
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   NUM_STREAMS: 6
[ INFO ]   AFFINITY: Affinity.CORE
[ INFO ]   INFERENCE_NUM_THREADS: 24
[ INFO ]   PERF_COUNT: False
[ INFO ]   INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   PERFORMANCE_HINT_NUM_REQUESTS: 0
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'inputs'!. This input will be filled with random values!
[ INFO ] Fill input 'inputs' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).
[ INFO ] First inference took 55.33 ms
[Step 11/11] Dumping statistics report
[ INFO ] Count:            906 iterations
[ INFO ] Duration:         15107.30 ms
[ INFO ] Latency:
[ INFO ]    Median:        99.24 ms
[ INFO ]    Average:       99.67 ms
[ INFO ]    Min:           87.73 ms
[ INFO ]    Max:           124.02 ms
[ INFO ] Throughput:   59.97 FPS
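The two reports can be compared directly. The figures below are copied from the benchmark_app logs above; they come from 15-second runs with random inputs on one particular CPU, so the ratios are indicative rather than general. Both runs also report INFERENCE_PRECISION_HINT float32, which suggests that the FP16-compressed IR is executed in FP32 on this CPU.

```python
# Values copied from the two benchmark_app reports above.
fp16 = {"throughput_fps": 41.72, "median_latency_ms": 143.32}
int8 = {"throughput_fps": 59.97, "median_latency_ms": 99.24}

# Higher throughput is better and lower latency is better, so the ratios are inverted.
throughput_speedup = int8["throughput_fps"] / fp16["throughput_fps"]
latency_speedup = fp16["median_latency_ms"] / int8["median_latency_ms"]

print(f"Throughput speedup:     {throughput_speedup:.2f}x")  # 1.44x
print(f"Median latency speedup: {latency_speedup:.2f}x")  # 1.44x
```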
Compare Accuracy of the Original and Quantized Models¶
Finally, calculate WER metric for the INT8 model representation and
compare it with the FP16 result.
int8_ov_result = compute_wer(dataset, int8_compiled_model, ov_infer)
print(f'[OpenVino FP16] Word Error Rate: {ov_result:.4}')
print(f'[OpenVino INT8] Word Error Rate: {int8_ov_result:.4f}')
[OpenVino FP16] Word Error Rate: 0.03826
[OpenVino INT8] Word Error Rate: 0.0504
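Putting the two printed WER values side by side quantifies the accuracy cost of quantization, which can be weighed against the throughput gain reported by benchmark_app. The values are the rounded ones printed above, and they are measured on the 73-sample dummy dataset only.

```python
# WER values printed above for the FP16 and INT8 OpenVINO models.
fp16_wer = 0.03826
int8_wer = 0.0504

absolute_increase = int8_wer - fp16_wer
relative_increase = absolute_increase / fp16_wer

print(f"Absolute WER increase: {absolute_increase:.4f}")  # 0.0121
print(f"Relative WER increase: {relative_increase:.1%}")  # 31.7%
```

Whether roughly 1.2 extra word errors per 100 words is acceptable depends on the application; the AccuracyAwareQuantization approach described earlier is the tool for trading some of the speedup back for accuracy.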