kedro_onnx.io
DataSets and DataSet definitions for ONNX models.
Submodules
Package Contents
Classes
| Class | Description |
|---|---|
| OnnxDataSet | Loads and saves ONNX models. |
| OnnxSaveModel | Object to store an ONNX model and kwargs for the converter function. |
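The examples below save either a model wrapped together with converter kwargs or a bare model. Here is a minimal sketch of such a wrapper. It assumes the wrapper is a plain dataclass with `model` and `kwargs` fields, as the `OnnxSaveModel(model=..., kwargs=...)` call in the example suggests. `SaveModelSketch` and `unwrap` are hypothetical names and are not part of kedro_onnx.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class SaveModelSketch:
    """Hypothetical stand-in for OnnxSaveModel: a model plus converter kwargs."""

    model: Any
    kwargs: Dict[str, Any] = field(default_factory=dict)


def unwrap(data: Any) -> Tuple[Any, Dict[str, Any]]:
    """Split a save payload into (model, converter kwargs).

    A wrapped model yields a copy of its kwargs; a bare model yields {}.
    """
    if isinstance(data, SaveModelSketch):
        return data.model, dict(data.kwargs)
    return data, {}


# Usage: the wrapper carries kwargs, a bare object does not.
wrapped = SaveModelSketch(model="regressor",
                          kwargs={"initial_types": [("input", [None, 1])]})
print(unwrap(wrapped))      # ('regressor', {'initial_types': [('input', [None, 1])]})
print(unwrap("regressor"))  # ('regressor', {})
```

Keeping the kwargs next to the model means the dataset itself never needs converter-specific constructor arguments.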
- class kedro_onnx.io.OnnxDataSet(filepath: str, backend: kedro_onnx.typing.OnnxFrameworks = 'onnx', load_args: Dict[str, Any] = None, version: kedro.io.core.Version = None, credentials: Dict[str, Any] = None, fs_args: Dict[str, Any] = None)
Bases: FsspecDataSet[object, kedro_onnx.typing.ModelProto]
Loads and saves ONNX models.
- backend
ONNX backend to use.
Type: OnnxFrameworks
Example
>>> from kedro_onnx.io import OnnxDataSet, OnnxSaveModel
>>> from kedro_onnx.io import FloatTensorType
>>> from sklearn.linear_model import LinearRegression
>>>
>>> path = fs.path('test.onnx')
>>> data_set = OnnxDataSet(path, backend='sklearn')
>>>
>>> model = LinearRegression()
>>> model = model.fit([[1], [2], [3]], [2, 4, 6])
>>>
>>> save_model = OnnxSaveModel(model=model,
...     kwargs={'initial_types': (
...         ('input', FloatTensorType([None, 1])),)})
>>> data_set.save(save_model)
>>> onnx_model = data_set.load()
>>> onnx_model.producer_name
'skl2onnx'
>>>
>>> from kedro_onnx.inference import run
>>> run(onnx_model, [[4]])
array([[8.]], dtype=float32)
For some backends, you may have to specify additional kwargs. In the example above, we used the sklearn backend, which requires the initial_types kwarg. For more information, see the skl2onnx documentation. Saving a bare model without these kwargs fails:
>>> data_set.save(model)
Traceback (most recent call last):
...
kedro.io.core.DataSetError: You need to specify `initial_types` for `sklearn` backend. Use the `kedro_onnx.OnnxSaveModel` `kwargs` to specify additional arguments to the conversion function.
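The error above comes from a check that the backend-specific kwargs are present before conversion. Here is a minimal sketch of such a check. It assumes a per-backend table of required kwargs. `REQUIRED_KWARGS` and `validate_kwargs` are hypothetical names, and `ValueError` stands in for kedro's `DataSetError`. This page documents only the sklearn requirement (initial_types), so the other backends are left out rather than guessed.

```python
from typing import Any, Dict, Tuple

# Hypothetical table: backend name -> kwargs its converter needs.
# Only the sklearn entry is documented on this page.
REQUIRED_KWARGS: Dict[str, Tuple[str, ...]] = {
    "sklearn": ("initial_types",),
}


def validate_kwargs(backend: str, kwargs: Dict[str, Any]) -> None:
    """Raise if a kwarg required by ``backend`` is missing from ``kwargs``."""
    for key in REQUIRED_KWARGS.get(backend, ()):
        if key not in kwargs:
            # kedro_onnx raises kedro.io.core.DataSetError; ValueError stands in here.
            raise ValueError(f"You need to specify `{key}` for `{backend}` backend.")


validate_kwargs("sklearn", {"initial_types": [("input", [None, 1])]})  # passes
try:
    validate_kwargs("sklearn", {})
except ValueError as err:
    print(err)  # You need to specify `initial_types` for `sklearn` backend.
```

A lookup table keeps the per-backend rules in one place, whereas the listed `_validate_<backend>` methods suggest the library uses one method per backend instead.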
You can also use the onnx backend to save a model that is already in ONNX format:
>>> data_set = OnnxDataSet(path, backend='onnx')
>>> data_set.save(onnx_model)  # already an onnx model
>>> onnx_model = data_set.load()
>>> onnx_model.producer_name
'skl2onnx'
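With the onnx backend there is nothing to convert, so saving can pass the model straight through. Here is a minimal sketch of a converter dispatch with that passthrough. It assumes other backends are looked up in a registry of converter callables. `convert_sketch`, `CONVERTERS`, and the toy `fake_sklearn_converter` are hypothetical, and none of them call skl2onnx or onnx.

```python
from typing import Any, Callable, Dict

Converter = Callable[[Any, Dict[str, Any]], Any]


def fake_sklearn_converter(model: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Toy converter: tags the model instead of producing a real ModelProto."""
    return {"producer_name": "fake-skl2onnx", "source": model, "kwargs": kwargs}


# Hypothetical registry of backend name -> conversion function.
CONVERTERS: Dict[str, Converter] = {"sklearn": fake_sklearn_converter}


def convert_sketch(model: Any, backend: str, kwargs: Dict[str, Any]) -> Any:
    """Return an ONNX-like model for ``backend``; 'onnx' is a passthrough."""
    if backend == "onnx":
        # Already an ONNX model: store it unchanged.
        return model
    try:
        converter = CONVERTERS[backend]
    except KeyError:
        raise ValueError(f"Unsupported backend: {backend!r}") from None
    return converter(model, kwargs)


already_onnx = {"producer_name": "skl2onnx"}
print(convert_sketch(already_onnx, "onnx", {}) is already_onnx)          # True
print(convert_sketch("regressor", "sklearn", {})["producer_name"])      # fake-skl2onnx
```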
- _describe() → Dict[str, Any]
- _load_fp(fp: io.IOBase) → kedro_onnx.typing.ModelProto
- _validate_kwarg(kwargs: dict, key: str)
- _validate_sklearn(model: Any, kwargs: dict)
- _validate_lightbgm(model: Any, kwargs: dict)
- _validate_sparkml(model: Any, kwargs: dict)
- _validate_xgboost(model: Any, kwargs: dict)
- _validate(model: Any, kwargs: dict)
- _convert(model: Any, kwargs: Any) → kedro_onnx.typing.ModelProto
- _save_fp(fp: io.IOBase, data: Union[OnnxSaveModel, onnx.ModelProto, Any]) → None
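The private hooks listed above suggest the order of work on save. First the payload is unwrapped and `_validate` checks the kwargs. Then `_convert` produces an ONNX model, which is written to the file object `fp` (presumably opened by the `FsspecDataSet` base class). `_load_fp` reads it back. Here is a minimal sketch of that flow against an in-memory `io.BytesIO`. Every function name is hypothetical, and the JSON encoding is a stand-in chosen so the sketch runs without onnx installed. A real implementation would write protobuf bytes, for example `model.SerializeToString()` for an `onnx.ModelProto`.

```python
import io
import json
from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-ins for the _validate / _convert hooks.
Validator = Callable[[Any, Dict[str, Any]], None]
Converter = Callable[[Any, Dict[str, Any]], Dict[str, Any]]


def save_fp_sketch(fp: io.IOBase, payload: Tuple[Any, Dict[str, Any]],
                   validate: Validator, convert: Converter) -> None:
    """Validate, convert, then serialize the converted model into ``fp``."""
    model, kwargs = payload
    validate(model, kwargs)             # mirrors _validate
    onnx_like = convert(model, kwargs)  # mirrors _convert
    # A real implementation would write protobuf bytes; JSON keeps this runnable.
    fp.write(json.dumps(onnx_like).encode("utf-8"))


def load_fp_sketch(fp: io.IOBase) -> Dict[str, Any]:
    """Read back what save_fp_sketch wrote (mirrors _load_fp)."""
    return json.loads(fp.read().decode("utf-8"))


def no_op_validate(model: Any, kwargs: Dict[str, Any]) -> None:
    """Accept everything; a real validator would check backend kwargs."""


def tag_convert(model: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Toy conversion producing a JSON-serializable dict."""
    return {"producer_name": "sketch", "model": str(model)}


buffer = io.BytesIO()
save_fp_sketch(buffer, ("regressor", {}), no_op_validate, tag_convert)
buffer.seek(0)
print(load_fp_sketch(buffer))  # {'producer_name': 'sketch', 'model': 'regressor'}
```

Validating before converting means a missing kwarg fails fast, before any bytes reach the target file.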