Model
The model API provides a convenient high-level interface for training and prediction on a network described using the symbolic API.
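To make the workflow concrete, here is a minimal sketch of defining, training, and predicting with a model. The layer names, data-provider keys, and the `context` keyword are illustrative assumptions and may differ between MXNet.jl versions.

```julia
using MXNet

# Toy data: 1000 samples with 20 features each, labels in 0:9.
X = rand(Float32, 20, 1000)
y = Float32.(rand(0:9, 1000))
train_provider = mx.ArrayDataProvider(:data => X, :softmax_label => y, batch_size=100)

# Describe a small MLP with the symbolic API.
data = mx.Variable(:data)
fc1  = mx.FullyConnected(data, name=:fc1, num_hidden=64)
act1 = mx.Activation(fc1, name=:relu1, act_type=:relu)
fc2  = mx.FullyConnected(act1, name=:fc2, num_hidden=10)
mlp  = mx.SoftmaxOutput(fc2, name=:softmax)

# Wrap the architecture in a FeedForward model, train it, then predict.
model = mx.FeedForward(mlp, context=mx.cpu())
mx.fit(model, mx.SGD(), train_provider, n_epoch=10)
probs = mx.predict(model, train_provider)
```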
# MXNet.mx.AbstractModel — Type.
AbstractModel
The abstract supertype of all models in MXNet.jl.
source
# MXNet.mx.FeedForward — Type.
FeedForward
The feedforward model provides a convenient interface to train and predict on feedforward architectures like multi-layer perceptrons (MLPs), ConvNets, etc. There is no explicit handling of a time index, but it is relatively easy to implement unrolled RNNs / LSTMs under this framework (TODO: add example). For models that handle sequential data explicitly, please use TODO...
source
# MXNet.mx.FeedForward — Method.
FeedForward(arch :: SymbolicNode, ctx)
Arguments:
- `arch`: the architecture of the network constructed using the symbolic API.
- `ctx`: the devices on which this model should do computation. It could be a single `Context` or a list of `Context` objects. In the latter case, data parallelization will be used for training. If no context is provided, the default context `cpu()` will be used.
source
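As a usage sketch (the `context` keyword name is an assumption that matches recent MXNet.jl versions; your version may accept the devices positionally as in the signature above):

```julia
using MXNet

# A tiny architecture just for illustration.
arch = mx.SoftmaxOutput(mx.FullyConnected(mx.Variable(:data), num_hidden=10), name=:softmax)

# Default: computation on cpu() when no context is given.
model = mx.FeedForward(arch)

# A single explicit device.
model = mx.FeedForward(arch, context=mx.cpu())

# A list of devices: data parallelization will be used during training.
model = mx.FeedForward(arch, context=[mx.gpu(0), mx.gpu(1)])
```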
# MXNet.mx.predict — Method.
predict(self, data; overwrite=false, callback=nothing)
Predict using an existing model. The model should already be initialized, or trained, or loaded from a checkpoint. There is an overloaded method that accepts the callback as the first argument, so it is possible to write
predict(model, data) do batch_output
    # consume or write batch_output to file
end
Arguments:
- `self::FeedForward`: the model.
- `data::AbstractDataProvider`: the data to perform prediction on.
- `overwrite::Bool`: an `Executor` is initialized the first time predict is called. The memory allocation of the `Executor` depends on the mini-batch size of the test data provider. If you call predict twice with data providers of the same batch size, then the executor can potentially be re-used. So, if `overwrite` is false, we will try to re-use it, and raise an error if the batch size changed. If `overwrite` is true (the default), a new `Executor` will be created to replace the old one.
- `verbosity::Integer`: determines the verbosity of the print messages. Higher numbers lead to more verbose printing. Acceptable values are:
  - `0`: Do not print anything during prediction
  - `1`: Print allocation information during prediction
Note
Prediction is computationally much less costly than training, so the bottleneck sometimes becomes the IO for copying mini-batches of data. Since there is no concern about convergence in prediction, it is better to set the mini-batch size as large as possible (limited by your device memory) if prediction speed is a concern.
For the same reason, currently prediction will only use the first device even if multiple devices are provided to construct the model.
Note
If you perform further training after prediction, note that the weights are not automatically synchronized if `overwrite` is set to false and the old predictor is re-used. In this case, setting `overwrite` to true (the default) will re-initialize the predictor the next time you call predict and synchronize the weights again.
See also `train`, `fit`, `init_model`, and `load_checkpoint`.
source
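A brief sketch, assuming `model` and a `test_provider` data provider were constructed as in the earlier examples:

```julia
# Plain form: run the whole provider through the network and collect the outputs.
probs = mx.predict(model, test_provider)

# Callback form: handle each mini-batch of outputs as it is produced, which
# avoids holding all predictions in memory at once.
mx.predict(model, test_provider) do batch_output
    # consume or write batch_output here (e.g. append it to a file)
end
```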
# MXNet.mx._split_inputs — Method.
Get a split of `batch_size` into `n_split` pieces for data parallelization. Returns a vector of length `n_split`, with each entry a `UnitRange{Int}` indicating the slice index for that piece.
source
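This is an internal helper, but a small example clarifies the returned slices:

```julia
# Split a mini-batch of 128 samples across 4 devices.
slices = mx._split_inputs(128, 4)
# A Vector{UnitRange{Int}}; with an even split this is [1:32, 33:64, 65:96, 97:128].
```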
# MXNet.mx.fit — Method.
fit(model :: FeedForward, optimizer, data; kwargs...)
Train the `model` on `data` with the `optimizer`.
- `model::FeedForward`: the model to be trained.
- `optimizer::AbstractOptimizer`: the optimization algorithm to use.
- `data::AbstractDataProvider`: the training data provider.
- `n_epoch::Int`: default 10, the number of full data-passes to run.
- `eval_data::AbstractDataProvider`: keyword argument, default `nothing`. The data provider for the validation set.
- `eval_metric::AbstractEvalMetric`: keyword argument, default `Accuracy()`. The metric used to evaluate the training performance. If `eval_data` is provided, the same metric is also calculated on the validation set.
- `kvstore::Union{KVStore, Symbol}`: keyword argument, default `:local`. The key-value store used to synchronize gradients and parameters when multiple devices are used for training.
- `initializer::AbstractInitializer`: keyword argument, default `UniformInitializer(0.01)`.
- `force_init::Bool`: keyword argument, default false. By default, the random initialization using the provided `initializer` will be skipped if the model weights already exist, perhaps from a previous call to `train` or an explicit call to `init_model` or `load_checkpoint`. When this option is set, it will always do random initialization at the beginning of training.
- `callbacks::Vector{AbstractCallback}`: keyword argument, default `[]`. Callbacks to be invoked at each epoch or mini-batch, see `AbstractCallback`.
- `verbosity::Int`: determines the verbosity of the print messages. Higher numbers lead to more verbose printing. Acceptable values are:
  - `0`: Do not print anything during training
  - `1`: Print starting and final messages
  - `2`: Print one-time messages and a message at the start of each epoch
  - `3`: Print a summary of the training and validation accuracy for each epoch
source
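A hedged sketch of a typical call, assuming `model`, `train_provider`, and a validation provider `eval_provider` as in the earlier examples; the checkpoint prefix is illustrative, and the callbacks shown (`mx.speedometer` to log throughput, `mx.do_checkpoint` to save checkpoints) are assumptions you may want to adapt:

```julia
optimizer = mx.SGD()   # stochastic gradient descent with default hyper-parameters

mx.fit(model, optimizer, train_provider,
       n_epoch     = 20,
       eval_data   = eval_provider,             # also evaluate eval_metric on this set
       eval_metric = mx.Accuracy(),
       callbacks   = [mx.speedometer(), mx.do_checkpoint("ckpt/mlp")],
       verbosity   = 2)
```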
# MXNet.mx.init_model — Method.
init_model(self, initializer; overwrite=false, input_shapes...)
Initialize the weights in the model.
This method will be called automatically when training a model. So there is usually no need to call this method unless one needs to inspect a model with only randomly initialized weights.
Arguments:
- `self::FeedForward`: the model to be initialized.
- `initializer::AbstractInitializer`: an initializer describing how the weights should be initialized.
- `overwrite::Bool`: keyword argument, force initialization even when weights already exist.
- `input_shapes`: the shape of all data and label inputs to this model, given as keyword arguments. For example, `data=(28,28,1,100), label=(100,)`.
source
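A minimal sketch mirroring the shape example above; the keyword names must match the data and label names expected by the network and data provider, so treat `data` and `label` here as placeholders:

```julia
# Initialize (or, with overwrite=true, re-initialize) the weights for a model whose
# data input is 28x28x1 images in mini-batches of 100, without training it first.
mx.init_model(model, mx.UniformInitializer(0.01); overwrite=true,
              data=(28, 28, 1, 100), label=(100,))
```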
# MXNet.mx.load_checkpoint — Method.
load_checkpoint(prefix, epoch, ::mx.FeedForward; context)
Load a `mx.FeedForward` model from the checkpoint `prefix` and `epoch`, optionally providing a `context`.
source
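A usage sketch; the prefix and epoch are illustrative (for example, as written by a `do_checkpoint("ckpt/mlp")` callback during training), and passing `mx.FeedForward` is assumed to select this method:

```julia
# Restore the architecture and weights saved at epoch 20, placing the model on the CPU.
model = mx.load_checkpoint("ckpt/mlp", 20, mx.FeedForward, context=mx.cpu())
```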
# MXNet.mx.train — Method.
train(model :: FeedForward, ...)
Alias to `fit`.
source