2using System.Collections.Generic;
27 public abstract class Solver<T> : IDisposable
61 AutoResetEvent m_evtCompleted =
new AutoResetEvent(
false);
62 bool m_bEnableTest =
true;
63 bool m_bEnableBlobDebugging =
false;
64 bool m_bEnableBreakOnNan =
false;
65 bool m_bEnableDetailedNanDetection =
false;
66 bool m_bEnableSingleStep =
false;
77 AutoResetEvent m_evtForceSnapshot;
78 AutoResetEvent m_evtForceTest;
95 double m_dfLastAccuracy = 0;
96 double m_dfLastError =
double.MaxValue;
97 double m_dfBestAccuracy = 0;
98 double m_dfBestError =
double.MaxValue;
100 int m_nTrainingIterationOverride = -1;
101 int m_nTestingIterationOverride = -1;
103 bool m_bWeightsUpdated =
false;
104 static object m_syncGetRi =
new object();
105 Blob<T> m_blobBatchInputData =
null;
106 double m_dfAverageTestTime = 0;
108 int m_nTrainingTimeLimitInMinutes = 0;
109 long m_hWorkspaceData = 0;
110 ulong m_lWorkspaceSizeInBytes = 0;
111 bool m_bFirstNanError =
true;
112 List<double> m_rgAverageAccuracyWindow =
null;
113 bool m_bForceTest =
false;
146 public event EventHandler<TestArgs>
OnTest;
181 public Solver(
CudaDnn<T> cuda,
Log log,
SolverParameter p,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXDatabaseBase db,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
185 m_evtCancel = evtCancel;
186 m_evtForceSnapshot = evtForceSnapshot;
187 m_evtForceTest = evtForceTest;
205 m_rgAverageAccuracyWindow =
new List<double>();
208 m_rgAverageAccuracyWindow.Add(0);
238 int nTimingCount = 0;
239 double dfTotalTime = 0;
240 return fireOnTrainingIterationEvent(
false, 0, 0, ref nTimingCount, ref dfTotalTime);
243 private bool fireOnTrainingIterationEvent(
bool bFwdPassNanFree,
double dfLoss,
double dfLastLearningRate, ref
int nTimingCount, ref
double dfTotalTime)
247 string strFirstNanBlob =
null;
250 if (m_bEnableBlobDebugging)
252 dbgInfo =
TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);
254 if (m_bEnableBreakOnNan && dbgInfo !=
null)
259 if (strFirstNanBlob !=
null)
261 string strPass = (!bFwdPassNanFree) ?
"Forward" :
"Backward";
262 m_log.
WriteLine(
"First NaN detected in the '" + strType +
"' of blob '" + strFirstNanBlob +
"' after " + strPass +
" pass.");
265 string strLastNanBlob = dbgInfo.
DetectLastNaN(out strTypeLast);
267 if (strLastNanBlob != strFirstNanBlob && strType != strTypeLast)
268 m_log.
WriteLine(
"Last NaN detected in the '" + strTypeLast +
"' of blob '" + strLastNanBlob +
"' after " + strPass +
" pass.");
273 double dfTime = (nTimingCount > 0) ? (dfTotalTime / nTimingCount) : 0;
274 OnTrainingIteration(
this,
new TrainingIterationArgs<T>(
m_nIter, m_dfLastAccuracy, dfLoss,
m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated,
m_net.ActiveLabelCounts,
m_net.LabelQueryHitPercents,
m_net.LabelQueryEpochs,
m_net.BoostQueryHitPercents, dfLastLearningRate, dfTime, dbgInfo));
278 if (strFirstNanBlob !=
null)
280 m_log.
WriteLine(
"Training is now stopping at iteration " +
m_nIter.ToString(
"N0") +
" as the first NaN has been detected ('" + strFirstNanBlob +
"').");
293 get {
return m_nTrainingTimeLimitInMinutes; }
294 set { m_nTrainingTimeLimitInMinutes = value; }
302 get {
return m_snapshotWeightUpdatemMethod; }
303 set { m_snapshotWeightUpdatemMethod = value; }
332 if (m_blobBatchInputData !=
null)
334 m_blobBatchInputData.
Dispose();
335 m_blobBatchInputData =
null;
338 if (m_hWorkspaceData != 0)
340 m_cuda.DisableGhostMemory();
341 m_cuda.FreeMemory(m_hWorkspaceData);
342 m_cuda.ResetGhostMemory();
343 m_hWorkspaceData = 0;
344 m_lWorkspaceSizeInBytes = 0;
353 get {
return m_bEnableTest; }
354 set { m_bEnableTest = value; }
362 get {
return m_bEnableBlobDebugging; }
363 set { m_bEnableBlobDebugging = value; }
383 get {
return m_bEnableBreakOnNan; }
384 set { m_bEnableBreakOnNan = value; }
396 get {
return m_bEnableDetailedNanDetection; }
397 set { m_bEnableDetailedNanDetection = value; }
405 get {
return m_bEnableSingleStep; }
406 set { m_bEnableSingleStep = value; }
414 get {
return m_bWeightsUpdated; }
415 set { m_bWeightsUpdated = value; }
423 get {
return m_tag; }
424 set { m_tag = value; }
446 get {
return m_net; }
493 string field_names =
"net_param, train_net_param";
494 m_log.
CHECK_GE(num_train_nets, 1,
"SolverParameter must specify a train net using one of these fields: " + field_names);
495 m_log.
CHECK_LE(num_train_nets, 1,
"SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
500 m_log.
WriteLine(
"Creating training net specified in train_net_param.");
518 net_param.
state = net_state;
521 m_net =
new Net<T>(
m_cuda,
m_log, net_param, m_evtCancel, m_db,
Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
522 m_net.OnGetIteration += net_OnGetIteration;
524 m_blobAccuracy =
m_net.FindBlob(
"accuracy");
526 catch(Exception excpt)
528 throw new Exception(
"Initializing Training Net: " + excpt.Message);
532 private void net_OnSetWorkspace(
object sender,
WorkspaceArgs e)
543 m_cuda.DisableGhostMemory();
549 if (m_hWorkspaceData != 0)
550 m_cuda.FreeMemory(m_hWorkspaceData);
553 m_hWorkspaceData =
m_cuda.AllocMemory((
long)lCount);
556 m_cuda.ResetGhostMemory();
559 private void net_OnGetWorkspace(
object sender,
WorkspaceArgs e)
585 int num_test_nets = num_test_net_params;
587 if (num_generic_nets > 0)
598 int num_test_net_instances = num_test_nets + num_generic_net_instances;
603 if (num_test_net_instances > 0)
606 List<string> sources =
new List<string>();
607 List<NetParameter> net_params =
new List<NetParameter>();
609 for (
int i = 0; i < num_test_net_params; i++)
611 sources.Add(
"test_net_param");
619 for (
int i = 0; i < remaining_test_nets; i++)
621 sources.Add(
"net_param");
628 for (
int i = 0; i < num_test_net_instances; i++)
636 net_state.
MergeFrom(net_params[i].state);
641 net_params[i].state = net_state;
643 m_log.
WriteLine(
"Creating test net (#" + i.ToString() +
") specified by " + sources[i],
true);
650 catch (Exception excpt)
652 throw new Exception(
"Initializing Testing Nets: " + excpt.Message);
669 get {
return m_net.ActiveLabelCounts; }
677 get {
return m_net.LabelQueryHitPercents; }
685 get {
return m_net.LabelQueryEpochs; }
713 if (m_nTrainingIterationOverride > 0)
714 nIters = m_nTrainingIterationOverride;
729 if (m_nTestingIterationOverride > 0)
730 nIters = m_nTestingIterationOverride;
744 public virtual void Solve(
int nIterationOverride = -1,
byte[] rgWeights =
null,
byte[] rgState =
null,
TRAIN_STEP step =
TRAIN_STEP.NONE)
750 if (rgWeights !=
null || rgState !=
null)
757 if (nIterationOverride <= 0)
760 if (!
Step(nIterationOverride, step))
767 else if (
m_net.learnable_parameters.SnapshotRequested(
true))
785 m_net.Forward(out dfLoss);
799 if (m_blobBatchInputData !=
null)
801 m_blobBatchInputData.
Dispose();
802 m_blobBatchInputData =
null;
818 public bool Step(
int nIters,
TRAIN_STEP step =
TRAIN_STEP.NONE,
bool bZeroDiffs =
true,
bool bApplyUpdates =
true,
bool bDisableOutput =
false,
bool bDisableProgress =
false,
double? dfLossOverride =
null,
bool? bAllowSnapshot =
null)
820 Exception err =
null;
826 int stop_iter =
m_nIter + nIters;
836 m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
837 m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;
839 Stopwatch sw =
new Stopwatch();
842 Stopwatch swTimeout =
new Stopwatch();
845 while (
m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
849 m_net.ClearParamDiffs();
852 OnStart(
this,
new EventArgs());
873 double dfLossTotal = 0;
874 double? dfAccuracyTotal =
null;
877 Stopwatch swTiming =
new Stopwatch();
878 double dfTotalTime = 0;
879 int nTimingCount = 0;
880 bool bFwdPassNanFree =
true;
885 double? dfLocalAccuracy =
null;
898 bFwdPassNanFree =
m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
900 if (m_blobAccuracy !=
null)
901 dfLocalAccuracy =
Utility.ConvertVal<T>(m_blobAccuracy.
GetData(0));
904 if (
double.IsNaN(dfLocalLoss) ||
double.IsInfinity(dfLocalLoss))
906 if (m_bFirstNanError)
908 m_log.
WriteError(
new Exception(
"The local loss at iteration " +
m_nIter.ToString() +
" is invalid (NAN or INFINITY)!"));
909 m_bFirstNanError =
false;
913 if (dfLocalAccuracy.HasValue)
915 if (!dfAccuracyTotal.HasValue)
918 dfAccuracyTotal = dfAccuracyTotal + dfLocalAccuracy.Value;
921 dfLossTotal += dfLocalLoss;
924 dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
928 if (!bFwdPassNanFree)
932 dfLoss = dfLossTotal / nIterCount;
933 dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);
935 if (dfAccuracyTotal.HasValue)
941 bool bDisplay =
false;
942 if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
945 m_bFirstNanError =
true;
949 if (bDisplay && bDisplay1)
958 for (
int j = 0; j < colResult.
Count; j++)
961 int nIdx =
m_net.output_blob_indices[j];
962 string output_name =
m_net.blob_names[nIdx];
963 double loss_weight =
m_net.blob_loss_weights[nIdx];
964 double dfTotalLossWeight = 0;
965 int nResultCount = colResult[j].count();
967 for (
int k = 0; k < nResultCount; k++)
973 if (loss_weight != 0)
974 strOut +=
" (* " + loss_weight.ToString() +
" = " + (loss_weight * result_vec[k]).ToString() +
" loss)";
976 m_log.
WriteLine(
" Train net output #" + score_index.ToString() +
": " + output_name +
" = " + result_vec[k].ToString() + strOut);
981 dfTotalLossWeight += loss_weight * result_vec[k];
987 double dfAverage = dfTotalLossWeight / nResultCount;
988 m_log.
WriteLine(
" Average weighted score = " + dfAverage.ToString() +
" for '" + output_name +
"' - averaged over " + nResultCount.ToString(
"N0") +
" results.");
997 double dfLastLearningRate = 0;
999 if (step !=
TRAIN_STEP.FORWARD && bApplyUpdates)
1005 if (!bDisableProgress)
1008 bool bSnapshotTaken =
false;
1014 (m_dfLastAccuracy > m_dfBestAccuracy))))
1016 bSnapshotTaken =
true;
1019 if (m_dfLastAccuracy > m_dfBestAccuracy)
1020 m_dfBestAccuracy = m_dfLastAccuracy;
1027 fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);
1032 if (step !=
TRAIN_STEP.NONE || m_bEnableSingleStep)
1036 if (!bDisableOutput)
1037 m_log.
WriteLine(
"Single step (both) triggered - solving stopped after a single forward/backward pass.");
1041 if (!bDisableOutput)
1042 m_log.
WriteLine(
"Single step (forward) triggered - solving stopped after a single forward pass.");
1046 if (!bDisableOutput)
1047 m_log.
WriteLine(
"Single step (backward) triggered - solving stopped after a single backward pass.");
1053 if (!bSnapshotTaken)
1064 if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
1066 m_log.
WriteLine(
"A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString(
"N0") +
" minutes has been exceeded - training will now stop.");
1076 catch (Exception excpt)
1083 if (err !=
null || m_evtCancel.
WaitOne(0))
1097 public void Restore(
byte[] rgWeights,
byte[] rgState,
string strSkipBlobTypes =
null)
1099 m_net.LoadWeights(rgWeights,
m_persist,
null,
null, strSkipBlobTypes);
1101 if (rgState !=
null)
1103 m_log.
WriteLine(
"Restoring previous solver state from restore state...");
1115 public void Snapshot(
bool bForced,
bool bScheduled,
bool bUpdateDatabase =
true)
1138 private void args_OnGetWeights(
object sender,
GetBytesArgs e)
1144 private void args_OnGetState(
object sender,
GetBytesArgs e)
1161 if (dfAccuracy == 0)
1162 dfAccuracy = 0.0001;
1180 get {
return m_nTrainingIterationOverride; }
1181 set { m_nTrainingIterationOverride = value; }
1189 get {
return m_nTestingIterationOverride; }
1190 set { m_nTestingIterationOverride = value; }
1198 get {
return m_evtCompleted; }
1206 get {
return m_evtCancel; }
1230 get {
return m_net; }
1264 if (m_evtForceSnapshot ==
null)
1267 return m_evtForceSnapshot.WaitOne(0);
1278 if (m_evtForceTest ==
null)
1281 m_bForceTest = m_evtForceTest.WaitOne(0);
1282 return m_bForceTest;
1322 public double TestAll(
int nIterationOverride = -1)
1324 double dfTotalAccuracy = 0;
1325 double dfTotalTime = 0;
1326 int nTotalCount = 0;
1328 for (
int test_net_id = 0; test_net_id <
m_rgTestNets.Count; test_net_id++)
1340 dfTotalAccuracy += testOne(nIterationOverride, test_net_id);
1342 dfTotalTime += m_dfAverageTestTime;
1355 dfTotalAccuracy += testOne(nIterationOverride, 0);
1360 if (m_rgAverageAccuracyWindow !=
null)
1362 m_rgAverageAccuracyWindow.Add(dfAccuracy);
1363 m_rgAverageAccuracyWindow.RemoveAt(0);
1364 dfAccuracy = m_rgAverageAccuracyWindow.Average();
1369 double dfTime = (nTotalCount > 0) ? dfTotalTime / nTotalCount : 0;
1376 private double testOne(
int nIterationOverride = -1,
int nTestNetId = 0)
1398 Stopwatch sw =
new Stopwatch();
1404 m_log.
WriteLine(
"Iteration " +
m_nIter.ToString() +
", Testing net (#" + nTestNetId.ToString() +
")");
1415 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos =
new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1416 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos =
new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1417 Dictionary<int, Dictionary<int, int>> rgAllNumPos =
new Dictionary<int, Dictionary<int, int>>();
1421 if (nIterationOverride <= 0)
1424 int nIter = nIterationOverride;
1427 for (
int i = 0; i < nIter; i++)
1440 dfLoss += iter_loss;
1442 for (
int j = 0; j < colResult.
Count; j++)
1444 m_log.
CHECK_EQ(colResult[j].width, 5,
"The width must be = 5 for SSD.");
1446 int num_det = colResult[j].height;
1448 for (
int k = 0; k < num_det; k++)
1450 int item_id = (int)result_vec[k * 5];
1451 int nLabel = (int)result_vec[k * 5 + 1];
1456 if (!rgAllNumPos.ContainsKey(j))
1457 rgAllNumPos.
Add(j,
new Dictionary<int, int>());
1459 if (!rgAllNumPos[j].ContainsKey(nLabel))
1460 rgAllNumPos[j].Add(nLabel, (
int)result_vec[k * 5 + 2]);
1462 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1467 float fScore = (float)result_vec[k * 5 + 2];
1468 int tp = (int)result_vec[k * 5 + 3];
1469 int fp = (int)result_vec[k * 5 + 4];
1473 if (tp == 0 && fp == 0)
1476 if (!rgAllTruePos.ContainsKey(j))
1477 rgAllTruePos.Add(j,
new Dictionary<
int, List<Tuple<float, int>>>());
1479 if (!rgAllTruePos[j].ContainsKey(nLabel))
1480 rgAllTruePos[j].Add(nLabel,
new List<Tuple<float, int>>());
1482 if (!rgAllFalsePos.ContainsKey(j))
1483 rgAllFalsePos.Add(j,
new Dictionary<
int, List<Tuple<float, int>>>());
1485 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1486 rgAllFalsePos[j].Add(nLabel,
new List<Tuple<float, int>>());
1488 rgAllTruePos[j][nLabel].Add(
new Tuple<float, int>(fScore, tp));
1489 rgAllFalsePos[j][nLabel].Add(
new Tuple<float, int>(fScore, fp));
1494 if (sw.Elapsed.TotalMilliseconds > 1000)
1514 float fTotalmAP = 0;
1515 for (
int i = 0; i < rgAllTruePos.Count; i++)
1517 if (!rgAllTruePos.ContainsKey(i))
1518 m_log.
FAIL(
"Missing output_blob true_pos: " + i.ToString());
1520 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1522 if (!rgAllFalsePos.ContainsKey(i))
1523 m_log.
FAIL(
"Missing output_blob false_pos: " + i.ToString());
1525 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1527 if (!rgAllNumPos.ContainsKey(i))
1528 m_log.
FAIL(
"Missing output_blob num_pos: " + i.ToString());
1530 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1532 Dictionary<int, float> rgAPs =
new Dictionary<int, float>();
1536 foreach (KeyValuePair<int, int> kv
in rgNumPos)
1538 int nLabel = kv.Key;
1539 int nLabelNumPos = kv.Value;
1541 if (!rgTruePos.ContainsKey(nLabel))
1543 m_log.
WriteLine(
"WARNING: Missing true_pos for label: " + nLabel.ToString() +
"!");
1546 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1548 if (!rgFalsePos.ContainsKey(nLabel))
1550 m_log.
WriteLine(
"WARNING: Missing false_pos for label: " + nLabel.ToString() +
"!");
1553 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1559 if (!rgAPs.ContainsKey(nLabel))
1560 rgAPs.Add(nLabel, fAp);
1562 rgAPs[nLabel] = fAp;
1567 m_log.
WriteLine(
"class " + nLabel.ToString() +
": " + fAp.ToString());
1570 fmAP /= rgNumPos.Count;
1573 string strOutputName = test_net.
blob_names[nOutputBlobIdx];
1575 m_log.
WriteLine(
" Test net output #" + i.ToString() +
": " + strOutputName +
" = " + fmAP.ToString());
1579 return fTotalmAP / rgAllTruePos.Count;
1581 catch (Exception excpt)
1603 m_bForceTest =
false;
1608 m_log.
WriteLine(
"Iteration " +
m_nIter.ToString() +
", Testing net (#" + nTestNetId.ToString() +
")");
1619 List<double> test_score =
new List<double>();
1620 List<int> test_score_output_id =
new List<int>();
1623 if (nIterationOverride <= 0)
1626 int nIter = nIterationOverride;
1628 Stopwatch sw =
new Stopwatch();
1631 double dfTotalTiming = 0;
1633 int nAccuracyIdx = 0;
1634 int nMinRank =
int.MaxValue;
1635 bool bAccuracyValid =
false;
1636 Stopwatch swTiming =
new Stopwatch();
1638 for (
int i = 0; i < nIter; i++)
1653 dfLoss += iter_loss;
1662 test_score_output_id.Add(1);
1663 bAccuracyValid =
true;
1671 for (
int j = 0; j < colResult.
Count; j++)
1675 for (
int k = 0; k < colResult[j].count(); k++)
1677 test_score.Add(result_vec[k]);
1678 test_score_output_id.Add(j);
1683 int nRank = (int)getNumber(colResult[j].
Tag, 0);
1684 if (nRank < nMinRank)
1696 for (
int j = 0; j < colResult.
Count; j++)
1700 for (
int k = 0; k < colResult[j].count(); k++)
1702 test_score[idx] += result_vec[k];
1710 dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
1713 if (sw.ElapsedMilliseconds > 2000)
1715 double dfPct = (double)i / (
double)nIter;
1727 m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;
1741 double dfFinalScore = 0;
1745 dfFinalScore = test_score.Sum();
1746 int nTotal = test_score_output_id.Sum();
1747 dfFinalScore /= nTotal;
1751 for (
int i = 0; i < test_score.Count; i++)
1753 int nIdxTestScore = test_score_output_id[i];
1755 string output_name = test_net.
blob_names[output_blob_index];
1757 double dfMeanScore = test_score[i] / nIter;
1762 if (loss_weight != 0)
1763 strOut +=
" (* " + loss_weight.ToString() +
" = " + (loss_weight * dfMeanScore).ToString() +
" loss)";
1765 m_log.
WriteLine(
" Test net output #" + i.ToString() +
": " + output_name +
" = " + dfMeanScore.ToString() + strOut);
1768 if (i == nAccuracyIdx)
1769 dfFinalScore = dfMeanScore;
1773 if (test_score.Count == 0)
1776 return dfFinalScore;
1779 private double getNumber(
object value,
double dfDefault)
1785 return (
double)(sbyte)value;
1788 return (
double)(byte)value;
1791 return (
double)(short)value;
1793 if (value is ushort)
1794 return (
double)(ushort)value;
1797 return (
double)(int)value;
1800 return (
double)(uint)value;
1803 return (
double)(long)value;
1806 return (
double)(ulong)value;
1809 return (
double)(float)value;
1811 if (value is
double)
1812 return (
double)value;
1814 if (value is decimal)
1815 return (
double)(decimal)value;
1828 if (nAverageLoss == 0)
1839 int nIdx = (
m_nIter - nStartIter) % nAverageLoss;
1844 if (m_bWeightsUpdated)
1847 m_bWeightsUpdated =
false;
1852 if (m_dfLastError < m_dfBestError)
1853 m_dfBestError = m_dfLastError;
1889 public static SGDSolver<T> Create(
CudaDnn<T> cuda,
Log log,
ProjectEx p,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXDatabaseBase db,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
1910 return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1930 public static SGDSolver<T> Create(
CudaDnn<T> cuda,
Log log,
SolverParameter solverParam,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXDatabaseBase db,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
1934 switch (solverParam.
type)
1937 solver =
new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1941 solver =
new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1945 solver =
new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1949 solver =
new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1953 solver =
new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1957 solver =
new AdamWSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1961 solver =
new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1965 throw new NotImplementedException(
"The solver " + solverParam.
type.ToString() +
" is not implemented yet!");
1972#pragma warning disable 1591
1974 public class OutputCollection
1976 OutputDataCollection m_rgError =
new OutputDataCollection();
1977 OutputDataCollection m_rgAccuracy =
new OutputDataCollection();
1979 public OutputCollection()
1983 public OutputDataCollection Errors
1985 get {
return m_rgError; }
1988 public OutputDataCollection Accuracies
1990 get {
return m_rgAccuracy; }
1994 public class OutputDataCollection : IEnumerable<OutputData>
1996 List<OutputData> m_rgData =
new List<OutputData>();
1998 public OutputDataCollection()
2002 public List<OutputData> Data
2004 get {
return m_rgData; }
2009 get {
return m_rgData.Count; }
2012 public OutputData
this[
int nIdx]
2014 get {
return m_rgData[nIdx]; }
2015 set { m_rgData[nIdx] = value; }
2018 public void Add(
int nTotal,
string strName,
int nIdx,
double dfVal)
2020 OutputData data = Find(strName);
2024 data =
new OutputData(strName, nIdx);
2028 data.Add(nTotal, dfVal);
2031 public OutputData Find(
string strName)
2033 foreach (OutputData data
in m_rgData)
2035 if (data.Name == strName)
2042 public IEnumerator<OutputData> GetEnumerator()
2044 return m_rgData.GetEnumerator();
2047 IEnumerator IEnumerable.GetEnumerator()
2049 return m_rgData.GetEnumerator();
2053 public class OutputData
2056 double m_dfValue = 0;
2059 public OutputData(
string strName,
int nIdx)
2061 m_strName = strName;
2067 get {
return m_nIdx; }
2072 get {
return m_strName; }
2077 get {
return m_dfValue; }
2078 set { m_dfValue = value; }
2081 public void Add(
int nTotal,
double dfVal)
2083 double dfRatio = 1.0 / (double)nTotal;
2084 m_dfValue = (m_dfValue * (1.0 - dfRatio)) + (dfRatio * dfVal);
2088#pragma warning restore 1591
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
The Log class provides general output in text form.
void CHECK(bool b, string str)
Test a flag for true.
bool IsEnabled
Returns whether or not the Log is enabled.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
double Progress
Get/set the progress associated with the Log.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
void WriteError(Exception e)
Write an error as output.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
The ProjectEx class manages a project containing the solver description, model description,...
string? SolverDescription
Get/set the solver description script used by the Project.
int ID
Returns the ID of the Project in the database.
string? ModelDescription
Get/set the model description script used by the Project.
The RawProto class is used to parse and output Google prototxt file data.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The BBox class processes the NormalizedBBox data used with SSD.
void Dispose()
Clean up all resources.
float ComputeAP(List< Tuple< float, int > > rgTp, int nNumPos, List< Tuple< float, int > > rgFp, ApVersion apVersion, out List< float > rgPrec, out List< float > rgRec)
Compute the average precision given true positive and false positive vectors.
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
T GetData(int nIdx)
Returns the data at a given flat index within the Blob.
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
static ulong ConvertByteSizeToCount(ulong ulSizeInBytes)
Converts the byte size into the number of items in the base data type of float or double.
The CustomForwardBackArgs provide the arguments to the OnCustomForwardBack event within the Solver St...
double LocalLoss
Get/set the local loss of the pass.
bool FwdPassNanFree
Get/set whether or a NAN was detected in the forward pass.
The GetBytesArgs is passed along to the SnapshotArgs::OnGetWeights and SnapshotArgs::OnGetState event...
byte[] Data
Get/set the data as an array of bytes.
The GetIterationArgs is bubbled up to the solver when a layer needs to know the current training ...
void SetIteration(Phase p, int nIteration)
The SetIteration method is used to set the iteration and the phase.
The GradientsReadyArgs is sent to the Solver::OnGradientsReady event which fires at the end of each S...
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
List< string > blob_names
Returns the blob names.
List< double > blob_loss_weights
Returns the collection of blob loss weights.
string name
Returns the network name.
List< int > output_blob_indices
Returns a list of the output Blob indexes.
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
bool Forced
Get/set whether or not the snapshot was forced or not.
bool SingleStep
Get/set the Solver single step.
bool IncludeWeights
Get/set whether or not to include the weights in the snapshot.
bool Scheduled
Get/set whether or not the snapshot is a regular scheduled snapshot (e.g. not an improved accuracy or...
bool IncludeState
Get/set whether or not to include the Solver state in the snapshot.
EventHandler< GetBytesArgs > OnGetState
Specifies the OnGetState event which fires when the SnapshotArgs::UpdateState method is called.
bool UpdateDatabase
Get/set whether or not to update the database (default = true).
EventHandler< GetBytesArgs > OnGetWeights
Specifies the OnGetWeights event which fires when the SnapshotArgs::UpdateWeights method is called.
The TestArgs are passed to the Solver::OnTest event.
double Accuracy
Get/set the accuracy for the test run. When overriding the testing, the override should set the accur...
The TestResultArgs are passed to the Solver::OnTestResults event.
bool AccuracyValid
Get/set the accuracy valid flag. When not valid, the OnTestResults event is ignored.
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
The WorkspaceArgs are passed to both the Layer::OnSetWorkspace and Layer::OnGetWorkspace events.
long WorkspaceData
Get/set the handle to workspace data in GPU memory.
ulong WorkspaceSizeInBytes
Get/set the workspace memory size in bytes.
The Database class manages the actual connection to the physical database using Entity Framework from...
Specifies the parameters use to create a Net
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
int ProjectID
Specifies the ID of the project that created this net param (if any).
int solver_rank
Specifies the rank of the solver using this network.
int solver_count
Specifies the number of solvers used in a multi-gpu training session.
NetParameter Clone(bool bCloneLayers=true, int? nSolverCount=null, int? nSolverRank=null)
Creates a new copy of this instance of the parameter.
Specifies the NetState which includes the phase, level and stage for which a given Net is to run unde...
Phase phase
Specifies the Phase of the NetState.
void MergeFrom(NetState ns)
Merges another NetState with this instance.
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
bool debug_info
If true, print information about the state of the net that may help with debugging learning problems.
NetParameter train_net_param
Inline train net param, possibly combined with one or more test nets.
List< NetState > test_state
The states for the train/test nets. Must be unspecified or specified once per net.
SolverType
Defines the type of solver.
string lr_policy
The learning rate decay policy.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
ApVersion ap_version
Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (defau...
long random_seed
If non-negative, the seed with which the Solver will initialize the caffe random number generator – u...
int average_loss
Display the loss averaged over the last average_loss iterations.
int test_interval
The number of iterations between two testing phases.
bool output_average_results
Specifies to average loss results before they are output - this can be faster when there are a lot of...
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
string DebugString()
Returns a debug string for the SolverParameter.
EvaluationType
Defines the evaluation method used in the SSD algorithm.
bool snapshot_after_train
If false, don't save a snapshot after training finishes.
bool snapshot_include_weights
Specifies whether or not the snapshot includes the trained weights. The default = true.
bool test_compute_loss
Test the compute loss.
SolverParameter()
The SolverParameter constructor.
EvaluationType eval_type
Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = NONE,...
bool test_initialization
If true, run an initial test pass before the first iteration, ensuring memory availability and printi...
List< NetParameter > test_net_param
Inline test net params.
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
bool snapshot_diff
Whether to snapshot diff in the results or not. Snapshotting diff will help debugging but the final p...
bool snapshot_include_state
Specifies whether or not the snapshot includes the solver state. The default = false....
bool show_per_class_result
Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (defau...
int accuracy_average_window
Specifies the window over which to average the accuracies (default = 0 which ignores averaging).
int snapshot
Specifies the snapshot interval.
SolverType type
Specifies the solver type.
NetState train_state
The states for the train/test nets. Must be unspecified or specified once per net.
Use AdaDelta Solver which has gradient based optimization like SGD.
Use AdaGrad Solver based optimization like SGD that tries to find rarely seen features.
Use Adam Solver which uses gradient based optimization like SGD that includes 'adaptive momentum esti...
Use AdamW Solver which uses gradient based optimization like Adam with a decoupled weight decay.
Use Nesterov's accelerated gradient Solver, which is similar to SGD, but the error gradient is comput...
Use RmsProp Solver which uses gradient based optimization like SGD.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
List< Net< T > > m_rgTestNets
Specifies the testing Nets.
int TrainingIterations
Returns the current training iterations remaining.
void InitTestNets()
Initializes the Net used by the Solver for testing.
EventHandler< CustomForwardBackArgs< T > > OnCustomForwardBack
The OnCustomForwardBack allows for overriding the forward/backward operations within the solver.
int m_nSolverCount
Specifies the Solver count in a multi-GPU training session.
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
double m_dfSmoothedLoss
Specifies the smoothed loss protected for derived classes to use.
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
List< double > m_rgLosses
Specifies the Losses used to calculate the smoothed Loss.
abstract byte[] SnapshotSolverState()
Save the current solver state.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
double smoothed_loss
Returns the smoothed loss.
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
int iter
Returns the current training iteration.
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
int MaximumIteration
Returns the maximum training iterations.
double? m_dfIterAccuracy
Specifies the iteration accuracy calculated when a blob exists with the name 'accuracy'.
Solver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The Solver constructor.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is sent extra debugging information describing the state o...
SolverParameter.SolverType type
Returns the type of solver.
Net< T > net
Returns the main training Net.
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
object Tag
Returns a generic tag associated with the Solver.
double TestDetection(int nIterationOverride=-1, int nTestNetId=0)
Run an SSD detection test on a given test Net by running it through its iterations.
bool? is_root_solver
Returns whether or not this is the root solver.
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
int m_nIter
Specifies the current iteration.
Net< T > TrainingNet
Returns the training Net used by the solver.
double m_dfLearningRateOverride
Optionally, specifies a learning rate override (default = 0, which ignores this setting).
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
void InitTrainNet(Net< T > shareNet=null)
Initializes the Net used by the solver for training.
abstract void RestoreSolverState(byte[] rgState)
Restore a solver state.
void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss=0)
Update the averaged loss value.
void Init(SolverParameter p, Net< T > shareNet=null)
Initializes the Solver.
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
int solver_count
Returns the solver count in a multi-GPU session.
SolverParameter parameter
Returns the SolverParameter used.
bool forceSnapshot
Returns whether or not a snapshot has been forced.
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
EventHandler OnAborted
The OnAborted event fires after aborting a training cycle.
List< Net< T > > test_nets
Returns the testing Nets.
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
EventHandler< WorkspaceArgs > OnGetWorkspace
Specifies the OnGetWorkspace event that fires when the getWorkspace() function is called by a layer t...
double TestClassification(int nIterationOverride=-1, int nTestNetId=0)
Run a test on a given test Net by running it through its iterations.
void Reset()
Reset the iterations of the net.
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
string LabelQueryEpochs
Return the label query epochs for the active datasource.
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
EventHandler< GradientsReadyArgs > OnGradientsReady
The OnGradientsReady event fires after the gradients of a Solver are ready for distribution to other ...
EventHandler< WorkspaceArgs > OnSetWorkspace
Specifies the OnSetWorkspace event that fires when the setWorkspace() function is called by a layer t...
int? TestingIterations
Returns the current testing iterations remaining.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
bool forceTest
Returns whether or not a test has been forced.
int TrainingIterationOverride
Get/set the training iteration override.
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Net< T > m_net
Specifies the training Net.
bool WeightsUpdated
Get/set when the weights have been updated.
int m_nCurrentStep
Specifies the current step.
int solver_rank
Returns this Solver's rank in a multi-GPU session.
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed NaN (and Infinity) detection is performed...
Log m_log
Specifies the Log for output.
SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
The GetSnapshotArgs method fills out a snapshot args structure.
virtual void dispose()
Override that allows discarding of resources (GPU and Host) used by this Solver.
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
int TestingIterationOverride
Get/set the testing iteration override.
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
AutoResetEvent CompletedEvent
Returns an auto reset event that is set upon training completion.
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
CudaDnn< T > Cuda
Returns the CudaDnn instance used by the Solver.
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
int m_nSolverRank
Specifies the Solver rank of this solver, where rank == 0 is the root Solver.
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Net< T > TestingNet
Returns the testing Net used by the solver.
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
int CurrentIteration
Returns the current training iteration.
The IXDatabaseBase interface defines the general interface to the in-memory database.
The IXPersist interface is used by the CaffeControl to load and save weights.
The MyCaffe.basecode namespace contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
The MyCaffe.common namespace contains common MyCaffe classes.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.db.image namespace contains all image database related classes.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...