Source code for kedro_onnx.io.datasets

"""ONNX datasets."""
from __future__ import annotations
from dataclasses import dataclass, field
import logging
from abc import abstractmethod
from copy import deepcopy
from io import IOBase
from pathlib import PurePosixPath
from typing import Any, Dict, Union, get_args

import fsspec
from kedro.io.core import (
    AbstractVersionedDataSet, DataSetError, Version, get_filepath_str,
    get_protocol_and_path
)

from kedro_onnx.typing import IT, OT, OnnxFrameworks, ModelProto
import kedro_onnx.utils as utils
import onnx
import onnxmltools

logger = logging.getLogger(__name__)


class FsspecDataSet(AbstractVersionedDataSet[IT, OT]):
    """An abstract DataSet for creating a new DataSet using fsspec."""

    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

    def __init__(
        self,
        filepath: str,
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ) -> None:
        """Initializes ``fsspec`` targeting the given filepath.

        Args:
            filepath: Filepath in POSIX format to a file prefixed with a
                protocol like `s3://`. If a prefix is not provided, the `file`
                protocol (local filesystem) will be used. The prefix should be
                any protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            version: If specified, should be an instance of
                `kedro.io.core.Version`. If its `load` attribute is None,
                the latest version will be loaded. If its `save` attribute
                is None, the save version will be autogenerated.
            credentials: Credentials required to get access to the underlying
                filesystem. E.g. for `GCSFileSystem` it should look like
                `{"token": None}`.
            fs_args: Extra arguments to pass into the underlying filesystem
                class constructor (e.g. `{"project": "my-project"}` for
                `GCSFileSystem`), as well as to pass to the filesystem's
                `open` method through nested keys `open_args_load` and
                `open_args_save`.

        Example:
            >>> class MyDataSet(FsspecDataSet):
            ...     def _load_fp(self, fp: IOBase) -> Any:
            ...         return fp.read()
            ...     def _save_fp(self, fp: IOBase, data: Any) -> None:
            ...         fp.write(data)
            >>> path = fs.path('test.txt')
            >>> data_set = MyDataSet(path)
            >>> data_set.exists()
            False
            >>> data_set.save('abc')
            >>> data_set.load()
            'abc'
            >>> data_set.exists()
            True
            >>> data_set._release()

        Note: Here you can find all available arguments for `open`:
            https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open  # noqa: E501
        """
        _fs_args = deepcopy(fs_args) or {}
        _fs_open_args_load = _fs_args.pop("open_args_load", {})
        _fs_open_args_save = _fs_args.pop("open_args_save", {})
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        if protocol == "file":
            _fs_args.setdefault("auto_mkdir", True)
        _fs_open_args_load.setdefault("mode", "r")
        _fs_open_args_save.setdefault("mode", "w")

        self._protocol = protocol
        self._fs: fsspec.AbstractFileSystem = fsspec.filesystem(
            self._protocol, **_credentials, **_fs_args
        )

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        self._load_args.update(load_args or {})
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        self._save_args.update(save_args or {})

        self._fs_open_args_load = _fs_open_args_load
        self._fs_open_args_save = _fs_open_args_save

    def _describe(self) -> Dict[str, Any]:
        return dict(
            filepath=self._filepath,
            protocol=self._protocol,
            load_args=self._load_args,
            save_args=self._save_args,
            version=self._version,
        )

    @abstractmethod
    def _load_fp(self, fp: IOBase) -> OT:
        pass  # pragma: no cover

    def _load(self) -> OT:
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        with self._fs.open(load_path, **self._fs_open_args_load) as fp:
            return self._load_fp(fp)

    @abstractmethod
    def _save_fp(self, fp: IOBase, data: Any) -> None:
        pass  # pragma: no cover

    def _save(self, data: IT) -> None:
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, **self._fs_open_args_save) as fp:
            self._save_fp(fp, data)
        self._invalidate_cache()

    def _exists(self) -> bool:
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DataSetError:  # pragma: no cover
            return False
        return self._fs.exists(load_path)

    def _release(self) -> None:
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate underlying filesystem caches."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)

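# Illustrative sketch (not part of the original module): a minimal concrete
# subclass of FsspecDataSet. Only the two abstract hooks `_load_fp` and
# `_save_fp` need to be implemented; protocol resolution, versioning and
# cache invalidation are inherited from the base class. The class name and
# file format are hypothetical.
#
#     class TextDataSet(FsspecDataSet[str, str]):
#         def _load_fp(self, fp: IOBase) -> str:
#             # `fp` is already opened by FsspecDataSet._load with the
#             # configured `open_args_load` (text mode by default).
#             return fp.read()
#
#         def _save_fp(self, fp: IOBase, data: str) -> None:
#             # `fp` is opened by FsspecDataSet._save; writing is enough,
#             # the base class invalidates the fsspec cache afterwards.
#             fp.write(data)

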
@dataclass
class OnnxSaveModel:
    """Object to store an ONNX model and kwargs for the converter function."""

    model: Any
    kwargs: dict = field(default_factory=dict)

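# Illustrative sketch (not part of the original module): bundling a fitted
# model with converter keyword arguments. The `kwargs` dict is merged with
# the dataset's `load_args` and forwarded to the onnxmltools converter by
# OnnxDataSet._save_fp. The estimator below is a hypothetical placeholder.
#
#     from kedro_onnx.io import FloatTensorType
#
#     save_model = OnnxSaveModel(
#         model=fitted_model,  # e.g. a fitted sklearn estimator (hypothetical)
#         kwargs={"initial_types": [("input", FloatTensorType([None, 1]))]},
#     )

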
class OnnxDataSet(FsspecDataSet[object, ModelProto]):
    """Loads and saves ONNX models.

    Attributes:
        backend (OnnxFrameworks): ONNX backend to use.

    Example:
        >>> from kedro_onnx.io import OnnxDataSet, OnnxSaveModel
        >>> from kedro_onnx.io import FloatTensorType
        >>> from sklearn.linear_model import LinearRegression
        >>>
        >>> path = fs.path('test.onnx')
        >>> data_set = OnnxDataSet(path, backend='sklearn')
        >>>
        >>> model = LinearRegression()
        >>> model = model.fit([[1], [2], [3]], [2, 4, 6])
        >>>
        >>> save_model = OnnxSaveModel(model=model,
        ...     kwargs={'initial_types': (
        ...         ('input', FloatTensorType([None, 1])),)})
        >>> data_set.save(save_model)
        >>> onnx_model = data_set.load()
        >>> onnx_model.producer_name
        'skl2onnx'
        >>>
        >>> from kedro_onnx.inference import run
        >>> run(onnx_model, [[4]])
        array([[8.]], dtype=float32)

        For some backends, you may have to specify additional kwargs.
        In the example above, we used the `sklearn` backend, which requires
        the `initial_types` kwarg to be specified. For more information,
        see the `skl2onnx` documentation.

        >>> data_set.save(model)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
        Traceback (most recent call last):
        ...
        kedro.io.core.DataSetError: You need to specify `initial_types` for
        `sklearn` backend.
        Use the `kedro_onnx.OnnxSaveModel` `kwargs` to specify
        additional arguments to the conversion function.

        You can even use the bare onnx model as the backend:

        >>> data_set = OnnxDataSet(path, backend='onnx')
        >>> data_set.save(onnx_model)  # already an onnx model
        >>> onnx_model = data_set.load()
        >>> onnx_model.producer_name
        'skl2onnx'
    """

    def __init__(
        self,
        filepath: str,
        backend: OnnxFrameworks = 'onnx',
        load_args: Dict[str, Any] = None,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ) -> None:
        """Initializes OnnxDataSet.

        Args:
            filepath (str): Filepath in POSIX format to an ONNX file prefixed
                with a protocol like `s3://`. If a prefix is not provided, the
                `file` protocol (local filesystem) will be used. The prefix
                should be any protocol supported by `fsspec`.
                Note: `http(s)` doesn't support versioning.
            backend (OnnxFrameworks): ONNX backend to use. To see the list of
                supported backends, look at
                `kedro_onnx.typing.OnnxFrameworks`. Defaults to `onnx`.
            load_args (Dict[str, Any], optional): Arguments for the conversion
                function from `onnxmltools`. Defaults to None.
            version (Version, optional): If specified, should be an instance
                of `kedro.io.core.Version`. If its `load` attribute is None,
                the latest version will be loaded. If its `save` attribute is
                None, the save version will be autogenerated. Defaults to
                None.
            credentials (Dict[str, Any], optional): Credentials required to
                get access to the underlying filesystem. E.g. for
                `GCSFileSystem` it should look like `{"token": None}`.
            fs_args (Dict[str, Any], optional): Extra arguments to pass into
                the underlying filesystem class constructor (e.g.
                `{"project": "my-project"}` for `GCSFileSystem`), as well as
                to pass to the filesystem's `open` method through nested keys
                `open_args_load` and `open_args_save`.
                Note: Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open  # noqa: E501
                All defaults will be preserved, unless overwritten. The `mode`
                argument is ignored and overwritten with `rb` or `wb`.
                Defaults to None.
        """
        super().__init__(
            filepath, load_args, None, version, credentials, fs_args
        )
        self._fs_open_args_load.update({"mode": "rb"})
        self._fs_open_args_save.update({"mode": "wb"})
        if backend != 'onnx':
            assert backend in utils.onnx_converters, \
                (f"Backend {backend} is not supported. Supported backends are:"
                 f" {get_args(OnnxFrameworks)}")
            utils.check_installed(utils.onnx_converters[backend])
        self._backend = backend

    def _describe(self) -> Dict[str, Any]:
        return dict(backend=self._backend, **super()._describe())

    def _load_fp(self, fp: IOBase) -> ModelProto:
        model: ModelProto = onnx.ModelProto()
        model.ParseFromString(fp.read())
        return model

    def _validate_kwarg(self, kwargs: dict, key: str):
        if key not in kwargs:
            raise DataSetError(
                f"You need to specify `{key}` for `{self._backend}` backend.\n"
                "Use the `kedro_onnx.OnnxSaveModel` `kwargs` to specify "
                "additional arguments to the conversion function."
            )

    def _validate_sklearn(self, model: Any, kwargs: dict):
        self._validate_kwarg(kwargs, "initial_types")

    def _validate_lightbgm(self, model: Any, kwargs: dict):
        self._validate_kwarg(kwargs, "initial_types")

    def _validate_sparkml(self, model: Any, kwargs: dict):
        self._validate_kwarg(kwargs, "initial_types")
        self._validate_kwarg(kwargs, "spark_session")

    def _validate_xgboost(self, model: Any, kwargs: dict):
        self._validate_kwarg(kwargs, "initial_types")

    def _validate(self, model: Any, kwargs: dict):
        if self._backend == "sklearn":
            self._validate_sklearn(model, kwargs)
        elif self._backend == "lightgbm":
            self._validate_lightbgm(model, kwargs)
        elif self._backend == "sparkml":
            self._validate_sparkml(model, kwargs)
        elif self._backend == "xgboost":
            self._validate_xgboost(model, kwargs)

    def _convert(self, model: Any, kwargs: Any) -> ModelProto:
        convert_fn = getattr(onnxmltools, f"convert_{self._backend}")
        return convert_fn(model, **kwargs)

    def _save_fp(
        self, fp: IOBase, data: Union[OnnxSaveModel, onnx.ModelProto, Any]
    ) -> None:
        if isinstance(data, onnx.ModelProto):
            model = data
        else:
            save_model = (
                data
                if isinstance(data, OnnxSaveModel)
                else OnnxSaveModel(data)
            )
            full_kwargs = {**self._load_args, **save_model.kwargs}
            self._validate(save_model.model, full_kwargs)
            model = self._convert(save_model.model, full_kwargs)
        fp.write(model.SerializeToString())
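
# Illustrative sketch (not part of the original module): using OnnxDataSet
# through a Kedro DataCatalog. The entry name and filepath are assumptions;
# `save_model` is an OnnxSaveModel as constructed above.
#
#     from kedro.io import DataCatalog
#
#     catalog = DataCatalog({
#         "regressor": OnnxDataSet("data/06_models/regressor.onnx",
#                                  backend="sklearn"),
#     })
#     catalog.save("regressor", save_model)   # converted and written as .onnx
#     onnx_model = catalog.load("regressor")  # returns an onnx.ModelProto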