2using System.Collections.Generic;
55 public SGDSolver(
CudaDnn<T> cuda,
Log log,
SolverParameter p,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXDatabaseBase db,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
56 : base(cuda, log, p, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws)
99 for (
int i = 0; i < colNetParams.
Count; i++)
101 List<int> rgShape = colNetParams[i].shape();
133 if (nIterationOverride == -1)
201 string strOut =
"Iteration " +
m_nIter.ToString() +
", lr = " + dfRate.ToString() +
", Loss = " +
m_dfSmoothedLoss.ToString();
210 for (
int i = 0; i <
m_net.learnable_parameters.Count; i++)
258 state.history.Add(blob.
ToProto());
276 if (!colNetParams[param_id].DiffExists)
280 m_cuda.scal(colNetParams[param_id].count(), dfAccumNormalization, colNetParams[param_id].mutable_gpu_diff);
291 if (!colNetParams[param_id].DiffExists)
294 List<double?> rgNetParamWeightDecay =
m_net.params_weight_decay;
296 double dfLocalDecay = dfWeightDecay * rgNetParamWeightDecay[param_id].GetValueOrDefault(0);
298 if (dfLocalDecay > 0)
304 m_cuda.axpy(colNetParams[param_id].count(), dfLocalDecay, colNetParams[param_id].gpu_data, colNetParams[param_id].mutable_gpu_diff);
308 m_cuda.sign(colNetParams[param_id].count(), colNetParams[param_id].gpu_data,
m_colTemp[param_id].mutable_gpu_data);
309 m_cuda.axpy(colNetParams[param_id].count(), dfLocalDecay,
m_colTemp[param_id].gpu_data, colNetParams[param_id].mutable_gpu_diff);
325 if (!colNetParams[param_id].DiffExists)
328 List<double?> net_params_lr =
m_net.params_lr;
330 T fLocalRate =
Utility.ConvertVal<T>(dfRate * net_params_lr[param_id].GetValueOrDefault(0));
334 m_cuda.sgd_update(colNetParams[param_id].count(), colNetParams[param_id].mutable_gpu_diff,
m_colHistory[param_id].mutable_gpu_data, fMomentum, fLocalRate);
344 if (dfClipGradients < 0)
348 double dfSumsqDiff = 0;
350 for (
int i = 0; i < colNetParams.
Count; i++)
352 if (colNetParams[i].DiffExists)
353 dfSumsqDiff +=
Utility.ConvertVal<T>(colNetParams[i].sumsq_diff());
356 double dfL2NormDiff = Math.Sqrt(dfSumsqDiff);
358 if (dfL2NormDiff > dfClipGradients)
360 double dfScaleFactor = dfClipGradients / dfL2NormDiff;
363 m_log.
WriteLine(
"Gradient clipping: scaling down gradients (L2 norm " + dfL2NormDiff.ToString() +
" > " + dfClipGradients.ToString() +
") by scale factor " + dfScaleFactor.ToString());
365 for (
int i = 0; i < colNetParams.
Count; i++)
367 if (colNetParams[i].DiffExists)
368 colNetParams[i].scale_diff(
Utility.ConvertVal<T>(dfScaleFactor));
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
The Log class provides general output in text form.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
The Utility class provides general utility functions.
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
BlobProto ToProto(bool bWriteDiff=false)
Writes the Blob to a new BlobProto.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int stepsize
The stepsize for learning rate policy 'step'.
int max_iter
The maximum number of iterations.
string regularization_type
Specifies the regularization type (default = 'L2').
string lr_policy
The learning rate decay policy.
double power
The 'power' parameter to compute the learning rate.
bool enable_clip_gradient_status
Optionally, enable status output when gradients are clipped (default = true)
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
double gamma
Specifies the 'gamma' parameter to compute the 'step', 'exp', 'inv', and 'sigmoid' learning policy (d...
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
double weight_decay
Specifies the weight decay (default = 0.0005).
List< int > stepvalue
The step values for learning rate policy 'multistep'.
double momentum
Specifies the momentum value - used by all solvers EXCEPT the 'AdaGrad' and 'RMSProp' solvers....
double base_lr
The base learning rate (default = 0.01).
double clip_gradients
Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, whenever their actual L2 norm...
The SolverState specifies the state of a given solver.
int iter
The current iteration.
List< BlobProto > history
The history for SGD solvers.
int current_step
The current step for learning rate.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
virtual void ComputeUpdateValue(int param_id, double dfRate, int nIterationOverride=-1)
Compute the SGD update value that will be applied to a learnable blobs in the training Net.
BlobCollection< T > m_colHistory
History maintains the historical momentum data.
BlobCollection< T > history
Returns the history BlobCollection containing historical momentum data.
override void dispose()
Releases all resources (GPU and Host) used by the Solver.
override double ApplyUpdate(int nIterationOverride=-1)
Compute the update values and apply them to the training Net.
void PreSolve()
Runs the pre-solve which prepares the Solver to start Solving.
override void RestoreSolverState(byte[] rgState)
Restore the state of the Solver.
virtual void Normalize(int param_id)
Normalize a learnable Blob of the training Net.
SGDSolver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The SGDSolver constructor.
virtual void ClipGradients()
Clip the gradients of all learnable blobs in the training Net.
override byte[] SnapshotSolverState()
Take a snapshot of the Solver state.
BlobCollection< T > m_colTemp
Update maintains update related data and is not needed in snapshots.
double GetLearningRate(int nIterationOverride=-1)
Return the current learning rate.
virtual void Regularize(int param_id)
Regularize a learnable Blob of the training net.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
double m_dfSmoothedLoss
Specifies the smoothed loss protected for derived classes to use.
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
double? m_dfIterAccuracy
Specifies the iteration accuracy calculated when a blob exists with the name 'accuracy'.
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
int m_nIter
Specifies the current iteration.
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
Net< T > m_net
Specifies the training Net.
int m_nCurrentStep
Specifies the current step.
Log m_log
Specifies the Log for output.
The IXDatabaseBase interface defines the general interface to the in-memory database.
The IXPersist interface is used by the CaffeControl to load and save weights.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
The MyCaffe.common namespace contains common MyCaffe classes.
The MyCaffe.db.image namespace contains all image database related classes.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...