TrainableModel#
- class TrainableModel(neural_network, loss='squared_error', optimizer=None, warm_start=False, initial_point=None, callback=None)[source]#
Bases: SerializableModelMixin
Base class for ML models that defines a scikit-learn-like interface for estimators.
- Parameters:
neural_network (NeuralNetwork) – An instance of a quantum neural network. If the neural network has a one-dimensional output, i.e., neural_network.output_shape=(1,), then it is expected to return values in [-1, +1] and it can only be used for binary classification. If the output is multi-dimensional, it is assumed that the result is a probability distribution, i.e., that the entries are non-negative and sum up to one. Then there are two options: either one-hot encoding or not. In the case of one-hot encoding, each probability vector produced by the neural network is considered as one sample and the loss function is applied to the whole vector. Otherwise, each entry of the probability vector is considered as an individual sample and the loss function is applied to the index and weighted with the corresponding probability.
loss (str | Loss) – A target loss function to be used in training. Default is squared_error, i.e. L2 loss. Can be given either as a string for 'absolute_error' (i.e. L1 loss), 'squared_error', 'cross_entropy', or as a loss function implementing the Loss interface.
optimizer (Optimizer | Minimizer | None) – An instance of an optimizer or a callable to be used in training. Refer to Minimizer for more information on the callable protocol. When None, defaults to SLSQP.
warm_start (bool) – Use weights from previous fit to start next fit.
initial_point (np.ndarray) – Initial point for the optimizer to start from.
callback (Callable[[np.ndarray, float], None] | None) – A reference to a user’s callback function that has two parameters and returns None. The callback can access intermediate data during training. On each iteration, the optimizer invokes the callback and passes the current weights as an array and the computed value of the objective function being optimized as a float. This allows tracking how well the optimization / training process is going, as shown in the constructor sketch below.
- Raises:
QiskitMachineLearningError – unknown loss, invalid neural network
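TrainableModel is abstract, so in practice its constructor is reached through a concrete subclass. The following is a minimal, illustrative sketch only: the subclass NeuralNetworkClassifier, the EstimatorQNN, and the one-qubit circuit are assumptions of this example (they are not documented on this page), and it assumes qiskit and qiskit-machine-learning are installed.

import numpy as np
from qiskit.circuit import Parameter, QuantumCircuit
from qiskit_machine_learning.neural_networks import EstimatorQNN
from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier

# One input parameter and one trainable weight. The QNN output is
# one-dimensional, so it returns values in [-1, +1] and the model can only
# be used for binary classification, as described above.
x = Parameter("x")
w = Parameter("w")
qc = QuantumCircuit(1)
qc.ry(x, 0)
qc.rz(w, 0)
qnn = EstimatorQNN(circuit=qc, input_params=[x], weight_params=[w])

# Callback with the documented signature: current weights as an array and
# the current objective value as a float; it returns None.
objective_values = []
def callback(weights, objective_value):
    objective_values.append(objective_value)

model = NeuralNetworkClassifier(
    neural_network=qnn,
    loss="squared_error",          # default L2 loss
    optimizer=None,                # None falls back to SLSQP
    warm_start=False,
    initial_point=np.array([0.1]),
    callback=callback,
)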
Attributes
- callback#
Return the callback.
- fit_result#
Returns the result object from the optimization procedure. Please refer to the documentation of the OptimizerResult class for more details.
- Raises:
QiskitMachineLearningError – If the model has not been fit.
- initial_point#
Returns the current initial point.
- loss#
Returns the loss.
- neural_network#
Returns the underlying neural network.
- optimizer#
Returns an optimizer to be used in training.
- warm_start#
Returns the warm start flag.
- weights#
Returns the trained weights as a numpy array. The weights can also be queried by calling model.fit_result.x, but in that case their representation depends on the optimizer used (see the short example below).
- Raises:
QiskitMachineLearningError – If the model has not been fit.
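As a short illustration of the weights and fit_result attributes, assuming model is an already-fitted instance of a TrainableModel subclass (for example, trained as in the fit() sketch further below):

import numpy as np

trained = model.weights             # always a plain numpy array
raw = model.fit_result.x            # same weights, in the optimizer's own representation
final_value = model.fit_result.fun  # final value of the objective function
print(np.asarray(raw), trained, final_value)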
Methods
- fit(X, y)[source]#
Fit the model to data matrix X and target(s) y.
- Parameters:
X (ndarray) – The input data.
y (ndarray) – The target values.
- Returns:
The trained model.
- Return type:
self
- Raises:
QiskitMachineLearningError – In case of invalid data (e.g. incompatible with network)
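A usage sketch of fit(), continuing the hypothetical one-feature classifier from the constructor example above; the data values and the {-1, +1} label encoding are assumptions for illustration:

import numpy as np

X = np.array([[0.1], [0.4], [0.8], [1.2]])   # four samples, one feature each
y = np.array([1, 1, -1, -1])                 # binary targets for the 1D-output QNN

model.fit(X, y)   # trains in place and returns the trained model (self)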
- classmethod load(file_name)#
Loads a model from the file. If the loaded model is not an instance of the class whose method was called, then a warning is raised. Nevertheless, the loaded model may be a valid model.
- abstract predict(X)[source]#
Predict using the network specified to the model.
- Parameters:
X (ndarray) – The input data.
- Raises:
QiskitMachineLearningError – Model needs to be fit to some training data first
- Returns:
The predicted classes.
- Return type:
ndarray
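A usage sketch of predict(), again assuming the hypothetical fitted classifier from the examples above:

import numpy as np

X_new = np.array([[0.2], [1.0]])   # same single-feature layout as used in fit()
y_pred = model.predict(X_new)      # ndarray of predicted class labels
print(y_pred)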
- save(file_name)#
Saves this model to the specified file. Internally, the model is serialized via dill. All parameters are saved, including a primitive instance that is referenced by internal objects. That means that if a model is loaded from a file and is used, for instance, for inference, the same primitive will be used even if a cloud primitive was used.
- Parameters:
file_name (str) – a file name or path where to save the model.
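A save/load round trip, assuming the same hypothetical classifier as above; the file name is arbitrary, and load() is called on the concrete subclass so that the instance check does not warn:

model.save("trained_model.dill")   # serialized with dill, including referenced primitives

from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier
restored = NeuralNetworkClassifier.load("trained_model.dill")
print(restored.predict(X_new))     # the restored model behaves like the original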