2using System.Collections.Generic;
9using System.Threading.Tasks;
21using System.Security.Cryptography;
91 bool m_bOwnRunNet =
true;
92 MemoryStream m_msWeights =
new MemoryStream();
97 long m_hCopyBuffer = 0;
98 string m_strStage =
null;
99 bool m_bLoadLite =
false;
100 string m_strSolver =
null;
101 string m_strModel =
null;
102 ManualResetEvent m_evtSyncUnload =
new ManualResetEvent(
false);
103 ManualResetEvent m_evtSyncMain =
new ManualResetEvent(
false);
105 bool m_bEnableVerboseStatus =
false;
106 T[] m_rgRunData =
null;
136 public MyCaffeControl(
SettingsCaffe settings,
Log log,
CancelEvent evtCancel, AutoResetEvent evtSnapshot =
null, AutoResetEvent evtForceTest =
null, ManualResetEvent evtPause =
null, List<int> rgGpuId =
null,
string strCudaPath =
"",
bool bCreateCudaDnn =
false,
ConnectInfo ci =
null)
139 m_guidUser = Guid.NewGuid();
141 InitializeComponent();
143 if (evtCancel ==
null)
144 throw new ArgumentNullException(
"The cancel event must be specified!");
146 if (evtSnapshot ==
null)
147 evtSnapshot =
new AutoResetEvent(
false);
149 if (evtForceTest ==
null)
150 evtForceTest =
new AutoResetEvent(
false);
152 if (evtPause ==
null)
153 evtPause =
new ManualResetEvent(
false);
165 string[] rgstrGpuId = settings.
GpuIds.Split(
',');
167 foreach (
string str
in rgstrGpuId)
169 string strGpuId = str.Trim(
' ',
'\t',
'\n',
'\r');
170 m_rgGpu.Add(
int.Parse(strGpuId));
193 if (m_evtSyncMain.WaitOne(0))
203 if (m_hCopyBuffer != 0)
231 if (m_msWeights !=
null)
245 m_evtSyncMain.Reset();
256 string strLocation = Assembly.GetExecutingAssembly().Location;
257 return FileVersionInfo.GetVersionInfo(strLocation);
266 get {
return m_dsCi; }
274 get {
return m_strStage; }
288 s.
GpuIds = nGpuID.ToString();
293 mycaffe.
LoadLite(
Phase.TRAIN, m_strSolver, m_strModel,
null);
307 mycaffe.m_hCopyBuffer = bDst.
CopyFrom(bSrc,
false,
false, mycaffe.m_hCopyBuffer);
329 m_hCopyBuffer = bDst.
CopyFrom(bSrc,
true,
false, m_hCopyBuffer);
349 m_hCopyBuffer = bDst.
CopyFrom(bSrc,
false,
false, m_hCopyBuffer);
378 get {
return m_bEnableVerboseStatus; }
379 set { m_bEnableVerboseStatus = value; }
387 public void Unload(
bool bUnloadImageDb =
true,
bool bIgnoreExceptions =
false)
389 if (m_solver ==
null && m_net ==
null)
392 if (m_evtSyncUnload.WaitOne(0))
395 m_evtSyncUnload.Set();
399 if (m_solver !=
null)
412 if (
m_db !=
null && bUnloadImageDb)
419 IDisposable idisp =
m_db as IDisposable;
429 catch (Exception excpt)
431 if (!bIgnoreExceptions)
436 m_evtSyncUnload.Reset();
461 m_solver.
OnTest += onTest;
470 m_solver.
OnStart += onTrainingStart;
530 if (m_solver !=
null)
546 if (m_solver !=
null)
559 if (m_solver !=
null)
575 if (m_solver !=
null)
591 if (m_solver !=
null)
617 get {
return m_cuda; }
625 get {
return m_log; }
633 get {
return m_persist; }
701 return "GPU #" + nId.ToString() +
" " + m_cuda.
GetDeviceName(nId);
753 get {
return m_lastPhaseRun; }
801 m_inputShape = shape;
839 if (type !=
null && type.
Value ==
"Input")
842 if (input_param !=
null)
848 int nNum = (rgDim.
Count > 0) ?
int.Parse(rgDim[0].Value) : 1;
849 nC = (rgDim.
Count > 1) ?
int.Parse(rgDim[1].Value) : 1;
850 nH = (rgDim.
Count > 2) ?
int.Parse(rgDim[2].Value) : 1;
851 nW = (rgDim.
Count > 3) ?
int.Parse(rgDim[3].Value) : 1;
859 if (nC == 0 && nH == 0 && nW == 0)
860 throw new Exception(
"Could not dicern the shape to use for no 'sdMean' parameter was supplied and the model does not contain an 'Input' layer!");
864 m_inputShape = shape;
884 int nNum = (bMaintainBatchSize) ? shape.
dim[0] : 1;
885 int nImageChannels = shape.
dim[1];
886 int nImageHeight = (shape.
dim.Count > 2) ? shape.
dim[2] : 1;
887 int nImageWidth = (shape.
dim.Count > 3) ? shape.
dim[3] : 1;
889 transform_param =
null;
892 bool bSkipTransformParam =
false;
893 RawProto protoModel =
ProjectEx.
CreateModelForRunning(strModel,
"data", nNum, nImageChannels, nImageHeight, nImageWidth, out protoTransform, out bSkipTransformParam, stage, bSkipLossLayer);
895 if (!bSkipTransformParam)
897 if (protoTransform !=
null)
902 if (transform_param.resize_param !=
null && transform_param.resize_param.Active)
904 shape.
dim[2] = (int)transform_param.resize_param.height;
905 shape.
dim[3] = (
int)transform_param.resize_param.width;
911 string strInput =
"";
917 if (!
string.IsNullOrEmpty(strInput1))
918 strInput += strInput1;
921 if (!
string.IsNullOrEmpty(strInput))
926 if (rgInput.Count > 0)
928 np.input =
new List<string>();
929 np.input_dim =
new List<int>();
930 np.input_shape =
new List<BlobShape>();
932 foreach (KeyValuePair<string, BlobShape> kv
in rgInput)
934 np.input.Add(kv.Key);
935 np.input_shape.Add(kv.Value);
942 np.state.phase =
Phase.RUN;
960 List<int> rgShape =
new List<int>() { 1, nC, nH, nW };
964 private Stage getStage(
string strStage)
966 if (strStage ==
Stage.RNN.ToString())
969 if (strStage ==
Stage.RL.ToString())
975 private string addStage(
string strModel,
Phase phase,
string strStage)
977 if (
string.IsNullOrEmpty(strStage))
1027 private bool verifySharedWeights()
1030 if (netTest !=
null)
1036 m_log.
WriteLine(
"WARNING: Training net has a different number of parameters than the testing net!");
1040 for (
int i = 0; i < netTrain.
parameters.Count; i++)
1044 m_log.
WriteLine(
"WARNING: Training net parameter " + i.ToString() +
" is not shared with the testing net!");
1077 m_strStage = strStage;
1081 if (db !=
null && bUseDb)
1084 throw new Exception(
"The database version in the settings (" +
m_settings.
DbVersion.ToString() +
") must match the database version (" + db.GetVersion().ToString() +
") of the 'db' parameter!");
1087 if (
m_db ==
null && bUseDb)
1100 throw new NotImplementedException(
"The temporal database is not yet supported.");
1120 if (labelSelectionOverride.HasValue)
1121 lblSel = labelSelectionOverride.Value;
1123 if (itemSelectionOverride.HasValue)
1124 imgSel = itemSelectionOverride.Value;
1135 string strType =
"images";
1145 throw new NotImplementedException(
"The temporal database is not yet supported.");
1164 throw new Exception(
"You must specify a project.");
1175 if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1183 string strSkipBlobType =
null;
1186 if (param !=
null && param.
Value ==
"True")
1195 m_solver.
OnSnapshot +=
new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1196 m_solver.
OnTrainingIteration +=
new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1197 m_solver.
OnTestingIteration +=
new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1200 verifySharedWeights();
1205#warning ImageDatabase V1 only
1206 if (phase ==
Phase.TRAIN &&
m_db !=
null)
1209 if (phase ==
Phase.TEST &&
m_db !=
null)
1213 if (phase ==
Phase.RUN && !bCreateRunNet)
1214 throw new Exception(
"You cannot opt out of creating the Run net when using the RUN phase.");
1216 if (p ==
null || !bCreateRunNet)
1233 m_log.
CHECK_EQ(nC, sdMean.
Channels,
"The mean channel count does not match the datasets channel count.");
1234 m_log.
CHECK_EQ(nH, sdMean.
Height,
"The mean height count does not match the datasets height count.");
1235 m_log.
CHECK_EQ(nW, sdMean.
Width,
"The mean width count does not match the datasets width count.");
1243 if (phase ==
Phase.RUN)
1253 else if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1259 catch (Exception excpt)
1261 m_log.
WriteLine(
"WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1263 m_bOwnRunNet =
false;
1267 catch (Exception excpt)
1298 public bool Load(
Phase phase,
string strSolver,
string strModel,
byte[] rgWeights,
DB_LABEL_SELECTION_METHOD? labelSelectionOverride =
null,
DB_ITEM_SELECTION_METHOD? itemSelectionOverride =
null,
bool bResetFirst =
false,
IXDatabaseBase db =
null,
bool bUseDb =
true,
bool bCreateRunNet =
true,
string strStage =
null,
bool bEnableMemTrace =
false)
1304 m_strStage = strStage;
1311 strModel = addStage(strModel, phase, strStage);
1318 if (db !=
null && bUseDb)
1321 throw new Exception(
"The database version in the settings (" +
m_settings.
DbVersion.ToString() +
") must match the database version (" + db.GetVersion().ToString() +
") of the 'db' parameter!");
1324 if (
m_db ==
null && bUseDb)
1326 string strType =
"images";
1344 throw new NotImplementedException(
"The temporal database is not yet implemented!");
1355 if (labelSelectionOverride.HasValue)
1356 lblSel = labelSelectionOverride.Value;
1358 if (itemSelectionOverride.HasValue)
1359 imgSel = itemSelectionOverride.Value;
1366 if (dsTarget !=
null)
1381 throw new NotImplementedException(
"The temporal database is not yet implemented!");
1399 if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1405 if (rgWeights !=
null)
1408 m_solver.
Restore(rgWeights,
null);
1411 m_solver.
OnSnapshot +=
new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1412 m_solver.
OnTrainingIteration +=
new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1413 m_solver.
OnTestingIteration +=
new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1419 if (phase ==
Phase.RUN)
1420 throw new Exception(
"You cannot opt out of creating the Run net when using the RUN phase.");
1450 if (nC == 0 || nH == 0 || nW == 0)
1451 throw new Exception(
"Unable to size the Data Transformer for there is no Mean or Project to gather the sizing information from.");
1458 if (phase ==
Phase.RUN)
1462 if (rgWeights !=
null)
1465 loadWeights(m_net, rgWeights);
1468 else if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1476 catch (Exception excpt)
1478 m_log.
WriteLine(
"WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1480 m_bOwnRunNet =
false;
1484 catch (Exception excpt)
1512 public bool LoadLite(
Phase phase,
string strSolver,
string strModel,
byte[] rgWeights =
null,
bool bResetFirst =
false,
bool bCreateRunNet =
true,
SimpleDatum sdMean =
null,
string strStage =
null,
bool bEnableMemTrace =
false)
1519 m_strSolver = strSolver;
1520 m_strModel = strModel;
1522 m_strStage = strStage;
1529 strModel = addStage(strModel, phase, strStage);
1543 if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1549 if (rgWeights !=
null)
1550 m_solver.
Restore(rgWeights,
null);
1552 m_solver.
OnSnapshot +=
new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1553 m_solver.
OnTrainingIteration +=
new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1554 m_solver.
OnTestingIteration +=
new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1560 if (phase ==
Phase.RUN)
1561 throw new Exception(
"You cannot opt out of creating the Run net when using the RUN phase.");
1576 if (nC == 0 || nH == 0 || nW == 0)
1577 throw new Exception(
"Unable to size the Data Transformer for no Mean image was provided as the 'sdMean' parameter which is used to gather the sizing information.");
1584 if (phase ==
Phase.RUN)
1588 if (rgWeights !=
null)
1591 loadWeights(m_net, rgWeights);
1594 else if (phase ==
Phase.TEST || phase ==
Phase.TRAIN)
1602 catch (Exception excpt)
1604 m_log.
WriteLine(
"WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1606 m_bOwnRunNet =
false;
1610 catch (Exception excpt)
1644 m_loadToRunShape = shape;
1655 if (bConvertToRunNet)
1675 if (transParam !=
null)
1681 throw new Exception(
"The transformer expects an image mean, yet the sdMean parameter is null!");
1694 loadWeights(m_net, rgWeights);
1696 catch (Exception excpt)
1708 string strSrc =
null;
1729 throw new Exception(
"Could not find the data source in the model!");
1735 throw new Exception(
"Could not find the data source '" + strSrc +
"' in the database.");
1742 string strTestSrc =
null;
1743 string strTrainSrc =
null;
1749 string strSrc =
null;
1761 strTrainSrc = strSrc;
1763 strTestSrc = strSrc;
1767 if (strTrainSrc !=
null && strTestSrc !=
null)
1769 if (dsPrimary ==
null || (strTrainSrc != dsPrimary.TrainingSourceName && strTestSrc != dsPrimary.TestingSourceName))
1774 if (strTrainSrc ==
null || strTestSrc ==
null)
1777 if (dsPrimary !=
null && (strTrainSrc == dsPrimary.TrainingSourceName && strTestSrc == dsPrimary.TestingSourceName))
1784 throw new Exception(
"The datset sources '" + strTestSrc +
"' and '" + strTrainSrc +
"' do not exist in the database - do you need to load them?");
1789 private void loadWeights(
Net<T> net,
byte[] rgWeights)
1804 m_log.
WriteLine(
"WARNING: The number of learnable parameters differs between the two nets!");
1819 m_log.
WriteLine(
"WARNING: The name of the blobs at index " + i.ToString() +
" differ: net1 - " + blob1.
Name +
" vs. net2 - " + blob2.
Name);
1859 void m_solver_OnSnapshot(
object sender,
SnapshotArgs e)
1876 public void Train(
int nIterationOverride = -1,
int nTrainingTimeLimitInMinutes = 0,
TRAIN_STEP step =
TRAIN_STEP.NONE,
double dfLearningRateOverride = 0,
bool bReset =
false)
1878 m_lastPhaseRun =
Phase.TRAIN;
1880 if (nIterationOverride == -1)
1890 if (dfLearningRateOverride > 0)
1897 if (nTrainingTimeLimitInMinutes > 0)
1899 m_log.
WriteLine(
"You have a training time-limit of " + nTrainingTimeLimitInMinutes.ToString(
"N0") +
" minutes. Multi-GPU training is not supported when a training time-limit is imposed.");
1909 m_solver.
Solve(-1,
null,
null, step);
1912 catch (Exception excpt)
1922 private string listToString(List<int> rg)
1926 for (
int i = 0; i < rg.Count; i++)
1928 strOut += rg[i].ToString();
1930 if (i < rg.Count - 1)
1942 public double Test(
int nIterationOverride = -1)
1944 m_lastPhaseRun =
Phase.TEST;
1946 if (nIterationOverride == -1)
1964 throw new Exception(
"Custom input is only supported by MODEL based datasets!");
1966 m_lastPhaseRun =
Phase.RUN;
1982 res =
Run(customInput, nK, dfThreshold, nMax);
1997 m_lastPhaseRun =
Phase.RUN;
2005 if (customInput.
GetProperty(
"Temporal") ==
"True")
2008 throw new Exception(
"TestManyEx currently only supports temporal testing.");
2025 Dictionary<int, int> rgMissedThreshold =
new Dictionary<int, int>();
2027 m_lastPhaseRun =
Phase.RUN;
2031 m_log.
CHECK_GT(nCount, 0,
"You must select at least 1 image to train on!");
2033 Stopwatch sw =
new Stopwatch();
2045 string strSet = (bOnTrainingSet) ?
"training" :
"test";
2046 int nCorrectCount = 0;
2047 Dictionary<int, int> rgCorrectCounts =
new Dictionary<int, int>();
2048 Dictionary<int, int> rgLabelTotals =
new Dictionary<int, int>();
2049 Dictionary<int, Dictionary<int, int>> rgDetectedCounts =
new Dictionary<int, Dictionary<int, int>>();
2055 strSet = (bOnTrainingSet) ?
"target training" :
"target test";
2060 m_log.
WriteHeader(
"Test Many (" + nCount.ToString() +
") - on " + strSet +
" '" + strSrc +
"'");
2083 if (nImageStartIdx < 0)
2086 List<SimpleDatum> rgImg =
null;
2087 if (dtImageStartTime.HasValue && dtImageStartTime.Value > DateTime.MinValue)
2089 m_log.
WriteLine(
"INFO: Starting test many at images with time " + dtImageStartTime.Value.ToString() +
" or later...");
2091 if (nCount > rgImg.Count)
2092 nCount = rgImg.Count;
2095 throw new Exception(
"No images found after time '" + dtImageStartTime.Value.ToString() +
"'. Make sure to use the LOAD_ALL image loading method when running TestMany after a specified time.");
2098 List<Tuple<SimpleDatum, ResultCollection>> rgrgResults =
new List<Tuple<SimpleDatum, ResultCollection>>();
2099 int nTotalCount = 0;
2102 int nImgCount = nCount;
2109 List<int> rgOriginalRunNetInputShape =
null;
2114 for (
int i = 0; i < nCount; i++)
2135 Trace.WriteLine(
"You should not be here.");
2136 throw new Exception(
"NO DATA!");
2142 List<int> rgIgnoreLabels =
null;
2143 if (accuracyParam !=
null && accuracyParam.ignore_labels !=
null)
2144 rgIgnoreLabels = accuracyParam.ignore_labels;
2146 List<ResultCollection> rgrgResults1 =
Run(blobData,
false,
false,
int.MaxValue, rgIgnoreLabels);
2151 rgrgResults.Add(
new Tuple<SimpleDatum, ResultCollection>(sd, rgResults));
2155 Dictionary<int, List<Result>> rgLabeledResults =
new Dictionary<int, List<Result>>();
2156 Dictionary<int, int> rgLabeledOrder =
new Dictionary<int, int>();
2161 if (!rgLabeledResults.ContainsKey(result.
Label))
2163 rgLabeledResults.Add(result.
Label,
new List<Result>());
2164 rgLabeledOrder.Add(result.
Label, nIdx);
2168 rgLabeledResults[result.
Label].Add(result);
2171 List<Tuple<int, List<Result>>> rgBestResults =
new List<Tuple<int, List<Result>>>();
2172 List<int> rgDetectedLabels = rgLabeledOrder.OrderBy(p => p.Value).Select(p => p.Key).ToList();
2182 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2183 rgCorrectCounts.
Add(nExpectedLabel, 0);
2185 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2186 rgLabelTotals.Add(nExpectedLabel, 1);
2188 rgLabelTotals[nExpectedLabel]++;
2190 if (rgDetectedLabels.Contains(nExpectedLabel))
2192 rgCorrectCounts[nExpectedLabel]++;
2207 int nExpectedLabel = sd.
Label;
2215 if (labelMapping !=
null)
2218 m_log.
FAIL(
"You can use either the LabelMappingLayer or the DataTransformer label_mapping, but not both!");
2220 nExpectedLabel = labelMapping.MapLabel(nExpectedLabel);
2223 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2224 rgCorrectCounts.Add(nExpectedLabel, 0);
2226 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2227 rgLabelTotals.Add(nExpectedLabel, 1);
2229 rgLabelTotals[nExpectedLabel]++;
2231 if (nExpectedLabel == nDetectedLabel)
2234 rgCorrectCounts[nExpectedLabel]++;
2237 if (!rgDetectedCounts.ContainsKey(nExpectedLabel))
2238 rgDetectedCounts.Add(nExpectedLabel,
new Dictionary<int, int>());
2240 if (!rgDetectedCounts[nExpectedLabel].ContainsKey(nDetectedLabel))
2241 rgDetectedCounts[nExpectedLabel].Add(nDetectedLabel, 0);
2243 rgDetectedCounts[nExpectedLabel][nDetectedLabel]++;
2249 if (!rgMissedThreshold.ContainsKey(nExpectedLabel))
2250 rgMissedThreshold.Add(nExpectedLabel, 0);
2252 rgMissedThreshold[nExpectedLabel]++;
2256 double dfPct = ((double)i / (
double)nCount);
2259 if (sw.ElapsedMilliseconds > 1000)
2261 m_log.
WriteLine(
"processing test many at " + dfPct.ToString(
"P"));
2269 if (rgOriginalRunNetInputShape !=
null)
2272 m_net.
input_blobs[0].Reshape(rgOriginalRunNetInputShape);
2277 if (blobData !=
null)
2284 double dfCorrectPct = (nTotalCount == 0) ? 0 : ((
double)nCorrectCount / (double)nTotalCount);
2287 m_log.
WriteLine(
" " + dfCorrectPct.ToString(
"P") +
" correct detections.");
2288 m_log.
WriteLine(
" " + (nTotalCount - nCorrectCount).ToString(
"N") +
" incorrect detections.");
2290 foreach (KeyValuePair<int, int> kv
in rgCorrectCounts.OrderBy(p => p.Key).ToList())
2294 foreach (KeyValuePair<int, int> kv1
in rgLabelTotals)
2296 if (kv1.Key == kv.Key)
2305 string strSecondDetection =
"";
2306 if (rgDetectedCounts.ContainsKey(kv.Key))
2308 List<KeyValuePair<int, int>> rgDetectedCountsSorted = rgDetectedCounts[kv.Key].OrderByDescending(p => p.Value).ToList();
2309 if (rgDetectedCountsSorted.Count > 1)
2311 strSecondDetection =
" (secondary detections: " + rgDetectedCountsSorted[1].Key.ToString();
2313 if (rgDetectedCountsSorted.Count > 2)
2314 strSecondDetection +=
" and " + rgDetectedCountsSorted[2].Key.ToString();
2316 strSecondDetection +=
")";
2320 dfCorrectPct = ((double)kv.Value / (
double)nCount);
2321 m_log.
WriteLine(
"Label #" + kv.Key.ToString() +
" had " + dfCorrectPct.ToString(
"P") +
" correct detections out of " + nCount.ToString(
"N0") +
" items with this label." + strSecondDetection);
2327 int nTotalBelow = 0;
2328 int nCorrectBelow = 0;
2329 int nTotalAbove = 0;
2330 int nCorrectAbove = 0;
2331 int nTotalBelowAndAbove = 0;
2332 int nCorrectBelowAndAbove = 0;
2334 List<KeyValuePair<int, int>> rgLabelTotalsList = rgLabelTotals.OrderBy(p => p.Key).ToList();
2335 List<KeyValuePair<int, int>> rgCorrectCountsList = rgCorrectCounts.OrderBy(p => p.Key).ToList();
2337 for (
int i = 0; i < rgLabelTotalsList.Count; i++)
2341 nTotalBelow += rgLabelTotalsList[i].Value;
2342 nCorrectBelow += rgCorrectCountsList[i].Value;
2343 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2344 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2346 else if (i > nMidPoint)
2348 nTotalAbove += rgLabelTotalsList[i].Value;
2349 nCorrectAbove += rgCorrectCountsList[i].Value;
2350 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2351 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2355 dfCorrectPct = (nTotalBelow == 0) ? 0 : nCorrectBelow / (
double)nTotalBelow;
2356 m_log.
WriteLine(
"Correct below midpoint of " + nMidPoint.ToString() +
" = " + dfCorrectPct.ToString(
"P"));
2357 dfCorrectPct = (nTotalAbove == 0) ? 0 : nCorrectAbove / (
double)nTotalAbove;
2358 m_log.
WriteLine(
"Correct above midpoint of " + nMidPoint.ToString() +
" = " + dfCorrectPct.ToString(
"P"));
2359 dfCorrectPct = (nTotalBelowAndAbove == 0) ? 0 : nCorrectBelowAndAbove / (
double)nTotalBelowAndAbove;
2360 m_log.
WriteLine(
"Correct below and above midpoint of " + nMidPoint.ToString() +
" = " + dfCorrectPct.ToString(
"P"));
2363 if (rgMissedThreshold.Count > 0)
2368 foreach (KeyValuePair<int, int> kv
in rgMissedThreshold)
2370 m_log.
WriteLine(
"Expected Label " + kv.Key.ToString() +
": " + kv.Value.ToString() +
" items missed threshold (" + ((double)kv.Value/nImgCount).ToString(
"P") +
").");
2374 m_log.
WriteLine(
"A total of " + nTotal.ToString() +
" items did not meet the threshold of " + dfThreshold.Value.ToString() +
", (" + ((double)nTotal / nImgCount).ToString(
"P") +
")");
2390 return Run(sd,
true, bPad);
2399 public List<ResultCollection>
Run(List<int> rgImageIdx, ref
Blob<T> blob)
2401 List<SimpleDatum> rgSd =
new List<SimpleDatum>();
2403 foreach (
int nImageIdx
in rgImageIdx)
2410 return Run(rgSd, ref blob,
false,
int.MaxValue);
2418 public List<ResultCollection>
Run(List<int> rgImageIdx)
2420 List<SimpleDatum> rgSd =
new List<SimpleDatum>();
2423 throw new Exception(
"Running on indexes requires a full Load that includes loading the dataset.");
2425 foreach (
int nImageIdx
in rgImageIdx)
2433 List<ResultCollection> rgRes =
Run(rgSd, ref blob);
2441 private int getCount(List<int> rg)
2445 foreach (
int nDim
in rg)
2465 blob.SetData(d,
true);
2481 int nCount = getCount(rgShape);
2482 blob.Reshape(rgShape);
2485 if (m_rgRunData ==
null || m_rgRunData.Length != nCount)
2486 m_rgRunData =
new T[nCount];
2489 Array.Copy(rgData, 0, m_rgRunData, 0, rgData.Length);
2491 blob.mutable_cpu_data = m_rgRunData;
2510 throw new Exception(
"The Run net has not been created!");
2527 colResults = m_solver.
TrainingNet.Forward(colBottom, out dfLoss, bPad);
2531 lastLayerType = m_net.
layers[m_net.
layers.Count - 1].type;
2532 colResults = m_net.
Forward(colBottom, out dfLoss, bPad);
2537 List<int> rgShape =
Utility.Clone<
int>(colResults[0].shape());
2539 if (rgShape[0] <= 0)
2541 colResults[0].
Reshape(rgShape);
2544 List<Result> rgResults =
new List<Result>();
2545 float[] rgData =
Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
2547 if (colResults[0].type ==
BLOB_TYPE.MULTIBBOX)
2549 int nNum = rgData.Length / 7;
2551 for (
int n = 0; n < nNum; n++)
2553 int i = (int)rgData[(n * 7)];
2554 int nLabel = (int)rgData[(n * 7) + 1];
2555 double dfScore = rgData[(n * 7) + 2];
2556 double[] rgExtra =
new double[4];
2557 rgExtra[0] = rgData[(n * 7) + 3];
2558 rgExtra[1] = rgData[(n * 7) + 4];
2559 rgExtra[2] = rgData[(n * 7) + 5];
2560 rgExtra[3] = rgData[(n * 7) + 6];
2562 rgResults.Add(
new Result(nLabel, dfScore, rgExtra));
2567 for (
int i = 0; i < rgData.Length; i++)
2569 double dfProb = rgData[i];
2570 rgResults.Add(
new Result(i, dfProb));
2578 catch (Exception excpt)
2599 public List<ResultCollection>
Run(List<SimpleDatum> rgSd, ref
Blob<T> blob,
bool bUseSolverNet =
false,
int nMax =
int.MaxValue)
2604 throw new Exception(
"The Run net has not been created!");
2606 List<ResultCollection> rgFinalResults =
new List<ResultCollection>();
2607 int nBatchSize = rgSd.Count;
2608 int nChannels = rgSd[0].Channels;
2609 int nHeight = rgSd[0].Height;
2610 int nWidth = rgSd[0].Width;
2611 List<T> rgDataInput =
new List<T>();
2614 blob =
new common.Blob<T>(m_cuda,
m_log, nBatchSize, nChannels, nHeight, nWidth,
false);
2617 for (
int i=0; i<rgSd.Count && i < nMax; i++)
2623 blob.Reshape(nCount, nChannels, nHeight, nWidth);
2624 blob.mutable_cpu_data = rgDataInput.ToArray();
2637 colResults = m_solver.
TrainingNet.Forward(colBottom, out dfLoss);
2642 lastLayerType = m_net.
layers[m_net.
layers.Count - 1].type;
2643 colResults = m_net.
Forward(colBottom, out dfLoss,
true);
2646 T[] rgDataOutput = colResults[0].update_cpu_data();
2647 int nOutputCount = rgDataOutput.Length / rgSd.
Count;
2649 for (
int i = 0; i < rgSd.Count && i < nMax; i++)
2651 List<Result> rgResults =
new List<Result>();
2653 for (
int j = 0; j < nOutputCount; j++)
2655 int nIdx = i * nOutputCount + j;
2656 double dfProb = (double)Convert.ChangeType(rgDataOutput[nIdx], typeof(
double));
2657 rgResults.Add(
new Result(j, dfProb));
2665 rgFinalResults.Add(result);
2668 return rgFinalResults;
2680 public List<ResultCollection>
Run(
Blob<T> blob,
bool bSort =
true,
bool bUseSolverNet =
false,
int nMax =
int.MaxValue, List<int> rgIgnoreLabels =
null)
2685 throw new Exception(
"The Run net has not been created!");
2687 if (
m_dataSet ==
null && (m_loadToRunShape ==
null || m_loadToRunShape.
dim.Count < 4))
2688 throw new Exception(
"Cannot determine the blob shape, you must either load with a database, or use LoadToRun before calling Run with a Blob. When using LoadToRun, the shape must have at least 4 dimensions.");
2690 List<ResultCollection> rgFinalResults =
new List<ResultCollection>();
2691 int nBatchSize = blob.
num;
2694 throw new Exception(
"The blob channels must match those of the testing dataset which has channels = " +
m_dataSet.
TestingSource.
Channels.ToString());
2701 List<int> rgShape =
m_dataTransformer.InferBlobShape(nChannels, nWidth, nHeight);
2702 nHeight = rgShape[2];
2703 nWidth = rgShape[3];
2706 if (blob.
height != nHeight)
2707 throw new Exception(
"The blob height must match those of the testing dataset which has height = " + nHeight.ToString());
2709 if (blob.
width != nWidth)
2710 throw new Exception(
"The blob width must match those of the testing dataset which as width = " + nWidth.ToString());
2724 colResults = m_solver.
TrainingNet.Forward(colBottom, out dfLoss);
2729 lastLayerType = m_net.
layers[m_net.
layers.Count - 1].type;
2730 colResults = m_net.
Forward(colBottom, out dfLoss,
true);
2733 float[] rgData =
Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
2734 int nOutputCount = rgData.Length / blob.
num;
2736 int nNum = blob.
num;
2742 for (
int n = 0; n < nNum && n < nMax; n++)
2744 List<Result> rgResults =
new List<Result>();
2746 if (colResults[0].type ==
BLOB_TYPE.MULTIBBOX)
2748 int i = (int)rgData[(n * 7)];
2749 int nLabel = (int)rgData[(n * 7) + 1];
2750 double dfScore = rgData[(n * 7) + 2];
2751 double[] rgExtra =
new double[4];
2752 rgExtra[0] = rgData[(n * 7) + 3];
2753 rgExtra[1] = rgData[(n * 7) + 4];
2754 rgExtra[2] = rgData[(n * 7) + 5];
2755 rgExtra[3] = rgData[(n * 7) + 6];
2757 rgResults.Add(
new Result(nLabel, dfScore, rgExtra));
2761 for (
int j = 0; j < nOutputCount; j++)
2763 int nIdx = n * nOutputCount + j;
2764 double dfProb = rgData[nIdx];
2766 if (rgIgnoreLabels !=
null && rgIgnoreLabels.Contains(j))
2769 dfProb =
double.MaxValue;
2774 rgResults.Add(
new Result(j, dfProb));
2782 rgFinalResults.Add(result);
2785 return rgFinalResults;
2801 throw new Exception(
"The Run net has not been created!");
2803 int nChannels = m_inputShape.
dim[1];
2805 if (typeof(T) == typeof(
double))
2820 return Run(d, bSort,
false, bPad);
2833 if (customInput !=
null)
2835 string strPhase = customInput.
GetProperty(
"Phase",
false);
2836 if (!
string.IsNullOrEmpty(strPhase))
2838 if (strPhase ==
Phase.TRAIN.ToString())
2839 phase =
Phase.TRAIN;
2840 else if (strPhase ==
Phase.TEST.ToString())
2842 else if (strPhase ==
Phase.RUN.ToString())
2847 m_log.
WriteLine(
"INFO: Running TestMany with the " + phase.ToString() +
" phase.");
2854 foreach (
Blob<T> blob
in colTop)
2856 string strName = blob.
Name;
2860 res.SetPropertyBlob(strName, rgBytes);
2861 res.SetPropertyInt(strName, (
int)blob.
type);
2877 if (customInput !=
null)
2879 string strPhase = customInput.
GetProperty(
"Phase",
false);
2880 if (!
string.IsNullOrEmpty(strPhase))
2882 if (strPhase ==
Phase.TRAIN.ToString())
2883 phase =
Phase.TRAIN;
2884 else if (strPhase ==
Phase.TEST.ToString())
2886 else if (strPhase ==
Phase.RUN.ToString())
2891 m_log.
WriteLine(
"INFO: Running TestMany with the " + phase.ToString() +
" phase.");
2912 string strInput = customInput.
GetProperty(
"InputData");
2913 string[] rgstrInput = strInput.Split(
'|');
2914 List<string> rgstrOutput =
new List<string>();
2917 foreach (
string strInput1
in rgstrInput)
2920 string strOut =
"\n";
2928 if (colBottom !=
null)
2936 if (colBottom ==
null)
2937 throw new Exception(
"At least one layer must support the 'PreprocessInput' method!");
2956 List<string> rgOutput =
new List<string>();
2957 List<Tuple<string, int, double>> res;
2959 Stopwatch sw =
new Stopwatch();
2962 for (
int i = 0; i < nMax; i++)
2964 if (layerInput.SupportsPostProcessingLogits)
2967 blobTop = m_net.
FindBlob(
"logits");
2968 if (blobTop ==
null)
2969 throw new Exception(
"Could not find the 'logits' blob!");
2970 res = layerInput.PostProcessLogitsOutput(i, blobTop, softmax, nAxis, nK);
2973 res = layerInput.PostProcessOutput(blobTop);
2975 if (!layerInput.PreProcessInput(
null, res[0].Item2, colBottom))
2978 rgOutput.Add(res[0].Item1);
2980 colTop = m_net.
Forward(colBottom, out dfLoss, layerInput.SupportsPostProcessingLogits);
2981 blobTop = colTop[0];
2984 if (sw.Elapsed.TotalMilliseconds > 1000)
2986 double dfPct = (double)nCount / nMax;
2987 m_log.
WriteLine(
"Generating response at " + dfPct.ToString(
"P") +
"...");
2992 foreach (
string str
in rgOutput)
3001 List<Tuple<double, bool, List<Tuple<string, int, double>>>> res = search.
Search(input, nK, dfThreshold, nMax);
3005 for (
int i = 0; i < res[0].Item3.Count; i++)
3007 strOut += res[0].Item3[i].Item1.ToString() +
" ";
3011 strOut = strOut.Trim();
3014 rgstrOutput.Add(strOut);
3017 string strFinal =
"";
3018 foreach (
string str
in rgstrOutput)
3020 strFinal += str +
"|";
3023 strFinal = clean(strFinal);
3035 return m_net.
Forward(colBottom, out dfLoss,
true);
3038 private string clean(
string strFinal)
3042 foreach (
char ch
in strFinal)
3063 throw new Exception(
"The GetTestImage only works with non-temporal databases.");
3072 if (strLabel ==
null || strLabel.Length == 0)
3087 throw new Exception(
"The GetTestImage only works with non-temporal databases.");
3109 throw new Exception(
"The GetTestImage only works with non-temporal databases.");
3117 if (strLabel ==
null || strLabel.Length == 0)
3118 strLabel = nLabel.ToString();
3138 throw new Exception(
"The GetTestImage only works with non-temporal databases.");
3145 if (strLabel ==
null || strLabel.Length == 0)
3146 strLabel = nLabel.ToString();
3161 throw new Exception(
"The image database is null!");
3163 if (m_solver ==
null)
3164 throw new Exception(
"The solver is null - make sure that you are loaded for training.");
3166 if (m_solver.
net ==
null)
3167 throw new Exception(
"The solver net is null - make sure that you are loaded for training.");
3169 string strSrc = m_solver.
net.GetDataSource();
3197 return m_solver.
net.SaveWeights(m_persist);
3208 bool? bLogEnabled =
null;
3218 if (m_net !=
null && m_bOwnRunNet)
3226 for (
int i = 0; i < m_solver.
net.learnable_parameters.Count; i++)
3228 Blob<T> b = m_solver.
net.learnable_parameters[i];
3232 if (bRun.
CopyFrom(b,
false,
true) != 0)
3244 m_log.
FAIL(
"Could not find the run blob '" + bRun.
Name +
"' in the solver net!");
3250 if (nCopyCount == 0)
3251 loadWeights(m_net, m_solver.
net.SaveWeights(m_persist));
3253 catch (Exception excpt)
3255 m_log.
WriteLine(
"WARNING: " + excpt.Message +
", attempting to load with legacy (slower method)...");
3256 loadWeights(m_net, m_solver.
net.SaveWeights(m_persist));
3263 m_log.
WriteLine(
"WARNING: The run weights differ from the training weights!");
3268 if (bLogEnabled.HasValue)
3280 loadWeights(m_net, rgWeights);
3284 List<string> rgExpectedShapes =
new List<string>();
3292 m_persist.
LoadWeights(rgWeights, rgExpectedShapes, m_solver.
TrainingNet.learnable_parameters,
false, out bLoadDiffs);
3306 if (cudaOverride ==
null)
3307 cudaOverride = m_cuda;
3311 loadWeights(net, rgWeights);
3330 if (phase ==
Phase.ALL)
3331 phase = m_lastPhaseRun;
3333 if (phase ==
Phase.NONE)
3336 if (phase ==
Phase.TEST)
3337 return (m_solver !=
null) ? m_solver.
TestingNet :
null;
3339 else if (phase ==
Phase.TRAIN)
3340 return (m_solver !=
null) ? m_solver.
TrainingNet :
null;
3360 m_solver.
Snapshot(
true,
false, bUpdateDatabase);
3382 string str = Properties.Resources.LICENSE;
3383 int nYear = DateTime.Now.Year;
3386 str = replaceMacro(str,
"$$YEAR$$",
"-" + nYear.ToString());
3388 str = replaceMacro(str,
"$$YEAR$$",
"");
3390 if (strOtherLicenses !=
null && strOtherLicenses.Length > 0)
3391 str = replaceMacro(str,
"$$OTHERLICENSES$$", strOtherLicenses);
3393 return fixupReturns(str);
3413 public bool VerifyCompute(
string strExtra =
null,
int nDeviceID = -1,
bool bThrowException =
true)
3416 throw new Exception(
"You must initialize the MyCaffeControl with an instance of CudaDnn<T>, or Load a new project.");
3422 if (nDeviceID == -1)
3426 string strCompute = parse(strDevName,
"compute ",
")");
3427 string[] rgstr = strCompute.Split(
'.');
3428 string strMajor = rgstr[0];
3429 string strMinor = rgstr[1];
3430 if (strMajor ==
null || strMinor ==
null)
3431 throw new Exception(
"Could not find the current device's major and minor version information!");
3433 int nMajor =
int.Parse(strMajor);
3434 int nMinor =
int.Parse(strMinor);
3436 if (nMajor < nMinMajor || (nMajor == nMinMajor && nMinor < nMinMinor))
3438 string strErr =
"The device " + nDeviceID.ToString() +
" - '" + strDevName +
" does not meet the minimum compute of '" + nMinMajor.ToString() +
"." + nMinMinor.ToString() +
"' required by the CudaDnnDll used ('" + strDll +
"')!";
3439 if (!
string.IsNullOrEmpty(strExtra))
3441 throw new Exception(strErr);
3447 private string parse(
string str,
string strT1,
string strT2)
3449 int nPos = str.IndexOf(strT1);
3453 str = str.Substring(nPos + strT1.Length);
3454 nPos = str.IndexOf(strT2);
3458 return str.Substring(0, nPos).Trim();
3461 private static string replaceMacro(
string str,
string strMacro,
string strReplacement)
3463 int nPos = str.IndexOf(strMacro);
3468 string strA = str.Substring(0, nPos);
3470 strA += strReplacement;
3471 strA += str.Substring(nPos + strMacro.Length);
3476 private static string fixupReturns(
string str)
3480 foreach (
char ch
in str)
3530 return m_cuda.
RunExtension(hExtension, lfnIdx, rgParam);
3564 return Utility.ConvertVecF<T>(rg);
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
DatasetDescriptor m_dataSet
The dataset descriptor of the dataset used in the image database.
string GetDeviceName(int nDeviceID)
Returns the device name of a given device ID.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
static FileVersionInfo Version
Get the file version of the MyCaffe assembly running.
long CreateExtension(string strExtensionDLLPath)
Create and load a new extension DLL.
ResultCollection Run(SimpleDatum d, bool bSort=true, bool bPad=true)
Run on a given Datum.
void Snapshot(bool bUpdateDatabase=true)
The Snapshot function forces a snapshot to occur.
void RemoveCancelOverrideByName(string strEvtCancel)
Remove a cancel override.
bool Load(Phase phase, string strSolver, string strModel, byte[] rgWeights, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
bool EnableTesting
Enable/disable testing. For example reinforcement learning does not use testing.
bool? EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
void LoadToRun(string strModel, byte[] rgWeights, BlobShape shape, SimpleDatum sdMean=null, TransformationParameter transParam=null, bool bForceBackward=false, bool bConvertToRunNet=true)
The LoadToRun method loads the MyCaffeControl for running only (e.g. deployment).
static void ResetDevice(int nDeviceID)
Reset the device at the given device ID.
void dispose()
Releases all GPU and Host resources used by the CaffeControl.
List< int > ActiveGpus
Returns a list of Active GPU's used by the control.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
AutoResetEvent m_evtForceSnapshot
An auto-reset event used to force a snapshot.
MyCaffeControl(SettingsCaffe settings, Log log, CancelEvent evtCancel, AutoResetEvent evtSnapshot=null, AutoResetEvent evtForceTest=null, ManualResetEvent evtPause=null, List< int > rgGpuId=null, string strCudaPath="", bool bCreateCudaDnn=false, ConnectInfo ci=null)
The MyCaffeControl constructor.
SettingsCaffe Settings
Returns the settings used to create the control.
PropertySet TestMany(PropertySet customInput)
Test on custom input data.
Bitmap GetTargetImage(int nSrcId, int nIdx, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrieves the image at a given index within the Testing data set.
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes each of the specified layers by re-running the filler (if any) specified by the layer....
ProjectEx m_project
The active project (if any).
ResultCollection Run(Bitmap img, bool bSort=true, bool bPad=true)
Run on a given bitmap image.
void SetOnTestingStartOverride(EventHandler onTestingStart)
Sets the root solver's onTestingStart event function triggered on the start of each testing pass.
static NetParameter CreateNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE, bool bSkipLossLayer=false, bool bMaintainBatchSize=false)
Creates a net parameter for the RUN phase.
int GetDeviceCount()
Returns the total number of devices installed on this computer.
int CurrentIteration
Returns the current iteration.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires each time a snap-shot is taken.
void SetOnTestOverride(EventHandler< TestArgs > onTest)
Sets the root solver's onTest event function.
ResultCollection Run(SimpleDatum d, bool bSort, bool bUseSolverNet, bool bPad=true)
Run on a given Datum.
void Train(int nIterationOverride=-1, int nTrainingTimeLimitInMinutes=0, TRAIN_STEP step=TRAIN_STEP.NONE, double dfLearningRateOverride=0, bool bReset=false)
Train the network a set number of iterations.
bool? EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
string LabelQueryEpochs
Returns a string describing the label query epochs observed during training.
void AddCancelOverride(CancelEvent evtCancel)
Adds a cancel override.
SimpleDatum GetItemMean()
Returns the item (e.g., image or temporal item) mean used by the solver network used during training.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
SettingsCaffe m_settings
The settings used to configure the control.
void UpdateRunWeights(bool bOutputStatus=false, bool bVerifyWeights=true)
Loads the weights from the training net into the Net used for running.
void CopyWeightsFrom(MyCaffeControl< T > src)
Copy the learnable parameter data from the source MyCaffeControl into this one.
Blob< T > CreateBlob(string strName)
Create an unsized blob and set its name.
void FreeExtension(long hExtension)
Free an existing extension and unload it.
byte[] GetWeights()
Retrieves the weights of the training network.
ManualResetEvent m_evtPause
A manual-reset event used to pause training.
NetParameter createNetParameterForRunning(ProjectEx p, out TransformationParameter transform_param)
Creates a net parameter for the RUN phase.
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
bool CompareWeights(Net< T > net1, Net< T > net2)
The CompareWeights method compares the weights held in two different Net objects.
string CurrentDevice
Returns the name of the current device used.
Solver< T > GetInternalSolver()
Get the internal solver.
bool VerifyCompute(string strExtra=null, int nDeviceID=-1, bool bThrowException=true)
VerifyCompute compares the current compute of the current device (or device specified) against the re...
IXDatabaseBase m_db
The image database.
void CopyGradientsFrom(MyCaffeControl< T > src)
Copy the learnable parameter diffs from the source MyCaffeControl into this one.
void UpdateWeights(byte[] rgWeights)
Loads the training Net with new weights.
List< ResultCollection > Run(Blob< T > blob, bool bSort=true, bool bUseSolverNet=false, int nMax=int.MaxValue, List< int > rgIgnoreLabels=null)
Run on a Blob of data.
string GetLicenseText(string strOtherLicenses)
Returns the license text for MyCaffe.
string ActiveLabelCounts
Returns a string describing the active label counts observed during training.
float[] RunExtensionF(long hExtension, long lfnIdx, float[] rgParam)
Run a function on an existing extension using the float base type.
BlobCollection< T > Run(BlobCollection< T > colBottom)
Run the network forward on the bottom blobs.
void PrepareImageMeans(ProjectEx prj)
Prepare the testing image mean by copying the training image mean if the testing image mean is missin...
Bitmap GetTestImage(Phase phase, int nLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
ResultCollection Run(int nImageIdx, bool bPad=true)
Run on a given image in the MyCaffeImageDatabase based on its image index.
Bitmap GetTargetImage(int nImageID, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrieves the image with a given ID.
double ApplyUpdate(int nIteration)
Directs the solver to apply the learned blob diffs to the weights using the solver's learning rate an...
AutoResetEvent m_evtForceTest
An auto-reset event used to force a test cycle.
MyCaffeControl< T > Clone(int nGpuID)
Clone the current instance of the MyCaffeControl creating a second instance.
Net< T > CreateNet(byte[] rgWeights, CudaDnn< T > cudaOverride=null)
Creates a new Net, loads the weights specified into it and returns it.
IXPersist< T > Persist
Returns the persist used to load and save weights.
BlobCollection< T > TestManyEx(PropertySet customInput)
Test on custom input data.
bool Load(Phase phase, ProjectEx p, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
DataTransformer< T > m_dataTransformer
The data transformer used to transform data.
CancelEvent m_evtCancel
The CancelEvent used to cancel training and testing operations.
void AddCancelOverrideByName(string strEvtCancel)
Adds a cancel override.
string LabelQueryHitPercents
Returns a string describing the label query hit percentages observed during training.
CudaDnn< T > Cuda
Returns the CudaDnn connection used.
ProjectEx CurrentProject
Returns the currently loaded project.
void SetOnTrainingStartOverride(EventHandler onTrainingStart)
Sets the root solver's onStart event function triggered on the start of each training pass.
Log m_log
The log used for output.
NetParameter createNetParameterForRunning(DatasetDescriptor ds, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
Blob< T > CreateDataBlob(SimpleDatum d, Blob< T > blob=null, bool bPad=true)
Create a data blob from a SimpleDatum by transforming the data and placing the results in the blob re...
Bitmap GetTestImage(Phase phase, out int nLabel, out string strLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
void Unload(bool bUnloadImageDb=true, bool bIgnoreExceptions=false)
Unload the currently loaded project, if any.
static string GetLicenseTextEx(string strOtherLicenses)
Returns the license text for MyCaffe.
string CurrentStage
Returns the stage under which the project was loaded, if any.
NetParameter createNetParameterForRunning(SimpleDatum sdMean, string strModel, out TransformationParameter transform_param, out int nC, out int nH, out int nW, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
void RemoveCancelOverride(CancelEvent evtCancel)
Remove a cancel override.
bool EnableVerboseStatus
Get/set whether or not to use verbose status. When enabled, the full status is output when loading a ...
double[] RunExtensionD(long hExtension, long lfnIdx, double[] rgParam)
Run a function on an existing extension using the double base type.
BlobCollection< T > RunModelEx(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
bool m_bDbOwner
Whether or not the control owns the image database.
List< ResultCollection > Run(List< SimpleDatum > rgSd, ref Blob< T > blob, bool bUseSolverNet=false, int nMax=int.MaxValue)
Run on a given list of Datum.
PropertySet RunModel(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
List< Tuple< SimpleDatum, ResultCollection > > TestMany(int nCount, bool bOnTrainingSet, bool bOnTargetSet=false, DB_ITEM_SELECTION_METHOD imgSelMethod=DB_ITEM_SELECTION_METHOD.RANDOM, int nImageStartIdx=0, DateTime? dtImageStartTime=null, double? dfThreshold=null)
Test on a number of images by selecting random images from the database, running them through the Run...
double Test(int nIterationOverride=-1)
Test the network a given number of iterations.
bool? EnableBreakOnFirstNaN
Enable/disable break training after first detecting a NaN.
NetParameter createNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
string m_strCudaPath
The low-level path of the underlying CudaDnn DLL.
Phase LastPhase
Returns the last phase run (TRAIN, TEST or RUN).
bool? EnableBlobDebugging
Enable/disable blob debugging.
List< ResultCollection > Run(List< int > rgImageIdx, ref Blob< T > blob)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on an existing extension.
bool? EnableSingleStep
Enable/disable single step training.
PropertySet Run(PropertySet customInput, int nK=1, double dfThreshold=0.01, int nMax=80, bool bBeamSearch=false)
Run the model on custom input data.
List< ResultCollection > Run(List< int > rgImageIdx)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
bool LoadLite(Phase phase, string strSolver, string strModel, byte[] rgWeights=null, bool bResetFirst=false, bool bCreateRunNet=true, SimpleDatum sdMean=null, string strStage=null, bool bEnableMemTrace=false)
Load a solver and model without using the MyCaffeImageDatabase.
int MaximumIteration
Returns the maximum iteration.
List< int > m_rgGpu
A list of the Device ID's used for training.
DatasetDescriptor GetDataset()
Returns the current dataset used when training and testing.
void Add(AnnotationGroupCollection col)
Add another AnnotationGroupCollection to this one.
int Count
Specifies the number of items in the collection.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
bool RemoveCancelOverride(string strName)
Remove a cancel override.
void AddCancelOverride(string strName)
Add a new cancel override.
string Name
Return the name of the cancel event.
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
void Set()
Sets the event to the signaled state.
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
The ImageData class is a helper class used to convert between Datum, other raw data,...
static Bitmap GetImage(SimpleDatum d, ColorMapper clrMap=null, List< int > rgClrOrder=null)
Converts a SimpleDatum (or Datum) into an image, optionally using a ColorMapper.
static Datum GetImageDataF(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataF function converts a Bitmap into a Datum using the float type for real data.
static Datum GetImageDataD(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataD function converts a Bitmap into a Datum using the double type for real data.
The Log class provides general output in text form.
void CHECK(bool b, string str)
Test a flag for true.
bool IsEnabled
Returns whether or not the Log is enabled.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
double Progress
Get/set the progress associated with the Log.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
void WriteHeader(string str)
Write a header as output.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
The ProjectEx class manages a project containing the solver description, model description,...
int TargetDatasetID
Get/set the dataset ID of the target dataset (if exists), otherwise return 0.
RawProto CreateModelForRunning(string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, out bool bSkipTransformParam, Stage stage=Stage.NONE, bool bSkipLossLayer=false)
Create a model description as a RawProto for running the Project.
int ID
Returns the ID of the Project in the database.
ParameterDescriptorCollection Parameters
Returns any project parameters that may exist (if any).
DatasetDescriptor Dataset
Return the descriptor of the dataset used.
string? ModelDescription
Get/set the model description script used by the Project.
byte[] WeightsState
Get/set the weight state.
DatasetDescriptor DatasetTarget
Returns the target dataset (if exists) or null if it does not.
byte[] SolverState
Get/set the solver state.
Stage Stage
Return the stage under which the project was opened.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as a double value.
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
The RawProtoCollection class is a list of RawProto objects.
int Count
Returns the number of items in the collection.
The RawProto class is used to parse and output Google prototxt file data.
string Value
Get/set the value of the node.
RawProto FindChild(string strName)
Searches for a given node.
override string ToString()
Returns the RawProto as its full prototxt string.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
RawProtoCollection FindChildren(params string[] rgstrName)
Searches for all children with a given name in this node's children.
The Result class contains a single result.
int Label
Returns the label.
The SettingsCaffe defines the settings used by the MyCaffe CaffeControl.
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot update method.
SettingsCaffe Clone()
Returns a copy of the SettingsCaffe object.
DB_LOAD_METHOD DbLoadMethod
Get/set the image database loading method.
bool ItemDbLoadDebugData
Specifies whether or not to load the debug data from file (default = false).
string GpuIds
Get/set the default GPU ID's to use when training.
int MaximumIterationOverride
Get/set the maximum iteration override. When set, this overrides the training iterations specified in...
DB_VERSION DbVersion
Get/set the version of the MyCaffeImageDatabase to use.
bool ItemDbLoadDataCriteria
Specifies whether or not to load the image criteria data from file (default = false).
int TestingIterationOverride
Get/set the testing iteration override. When set, this overrides the testing iterations specified in ...
The SimpleDatum class holds a data input within host memory.
bool GetDataValid(bool bByType=true)
Returns true if the ByteData or RealDataD or RealDataF are not null, false otherwise.
int Channels
Return the number of channels of the data.
AnnotationGroupCollection annotation_group
When using annotations, each annotation group contains an annotation for a particular class used with ...
byte[] DataCriteria
Get/set data criteria associated with the data.
DATA_FORMAT
Defines the data format of the DebugData and DataCriteria when specified.
int Width
Return the width of the data.
int ImageID
Returns the ID of the image in the database.
int Height
Return the height of the data.
DATA_FORMAT DataCriteriaFormat
Get/set the data format of the data criteria.
int Label
Return the known label of the data.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
int ID
Get/set the database ID of the item.
string Name
Get/set the name of the item.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
bool? IsGym
Returns whether or not this dataset is from a Gym.
SourceDescriptor TrainingSource
Get/set the training data source.
string? TrainingSourceName
Returns the training source name, or null if not specified.
bool? IsModelData
Returns whether or not this dataset is from the model itself.
SourceDescriptor TestingSource
Get/set the testing data source.
string? TestingSourceName
Returns the testing source name or null if not specified.
ParameterDescriptor Find(string strName)
Searches for a parameter by name in the collection.
The ParameterDescriptor class describes a parameter in the database.
override string ToString()
Creates the string representation of the descriptor.
string Value
Get/set the value of the item.
The SourceDescriptor class contains all information describing a data source.
override string ToString()
Return a string representation of the SourceDescriptor.
int Height
Returns the height of each data item in the data source.
int Width
Returns the width of each data item in the data source.
int Channels
Returns the item colors - 1 channel = black/white, 3 channels = RGB color.
The BeamSearch uses the softmax output from the network and continually runs the net on each output (...
List< Tuple< double, bool, List< Tuple< string, int, double > > > > Search(PropertySet input, int nK, double dfThreshold=0.01, int nMax=80)
Perform the beam-search.
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The Blob is the main holder of data that moves through the Layers of the Net.
int channels
DEPRECATED; legacy shape accessor channels: use shape(1) instead.
int height
DEPRECATED; legacy shape accessor height: use shape(2) instead.
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
string shape_string
Returns a string describing the Blob's shape.
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
BLOB_TYPE type
Returns the BLOB_TYPE of the Blob.
byte[] ToByteArray()
Saves this Blob to a byte array.
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
int width
DEPRECATED; legacy shape accessor width: use shape(3) instead.
T asum_data()
Compute the sum of absolute values (L1 norm) of the data.
int count()
Returns the total number of items in the Blob.
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
string Name
Get/set the name of the Blob.
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
bool Padded
Get/set the padding state of the blob.
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
int GetDeviceID()
Returns the current device id set within Cuda.
string GetRequiredCompute(out int nMinMajor, out int nMinMinor)
The GetRequiredCompute function returns the Major and Minor compute values required by the current Cu...
void sub(int n, long hA, long hB, long hY, int nAOff=0, int nBOff=0, int nYOff=0, int nB=0)
Subtracts B from A and places the result in Y.
string Path
Specifies the file path used to load the Low-Level Cuda DNN Dll file.
void FreeExtension(long hExtension)
Free an instance of an Extension.
void FreeHostBuffer(long hMem)
Free previously allocated host memory.
int GetDeviceCount()
Query the number of devices (gpu's) installed.
string GetDeviceName(int nDeviceID)
Query the name of a device.
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on the extension specified.
long CreateExtension(string strExtensionDllPath)
Create an instance of an Extension DLL.
virtual void Dispose(bool bDisposing)
Disposes this instance freeing up all of its host and GPU memory.
The NCCL class manages the multi-GPU operations using the low-level NCCL functionality provided by th...
void Run(List< int > rgGpus, int nIterationOverride=-1)
Run the root Solver and coordinate with all other Solver's participating in the multi-GPU training.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
List< Layer< T > > layers
Returns the layers.
BlobCollection< T > parameters
Returns the parameters.
BlobCollection< T > Forward()
Run forward with the input Blobs already fed separately.
BlobCollection< T > input_blobs
Returns the collection of input Blobs.
virtual void Dispose(bool bDisposing)
Releases all resources (GPU and Host) used by the Net.
void LoadWeights(byte[] rgWeights, IXPersist< T > persist, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into the Net.
NetParameter ToProto(bool bIncludeBlobs)
Writes the net to a proto.
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
byte[] SaveWeights(IXPersist< T > persist, bool bSaveDiff=false)
Save the weights to a byte array.
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
void ShareTrainedLayersWith(Net< T > srcNet, bool bEnableLog=false)
For an already initialized net, implicitly copies (i.e., using no additional memory) the pre-trained...
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes the blobs and each of the specified layers by re-running the filler (if any) specified...
The PersistCaffe class is used to load and save weight files in the .caffemodel format.
PersistCaffe(Log log, bool bFailOnFirstTry)
The PersistCaffe constructor.
BlobCollection< T > LoadWeights(byte[] rgWeights, List< string > rgExpectedShapes, BlobCollection< T > colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into a BlobCollection
The ResultCollection contains the result of a given CaffeControl::Run.
RESULT_TYPE ResultType
Returns the result type of the result data: PROBABILITIES (Sigmoid), DISTANCES (Decode),...
List< Result > ResultsSorted
Returns the original results in sorted order.
double DetectedLabelOutput
Returns the detected label output depending on the result type (distance or probability) with a defau...
static RESULT_TYPE GetResultType(LayerParameter.LayerType type)
Get the result type based on the layer-type used.
void SetLabels(List< LabelDescriptor > rgLabels)
Sets the label names in the label dictionary lookup.
int DetectedLabel
Returns the detected label depending on the result type (distance or probability) with a default type...
RESULT_TYPE
Defines the type of result.
List< Result > ResultsOriginal
Returns the original results.
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
The Database class manages the actual connection to the physical database using Entity Framework from...
The DatasetFactory manages the connection to the Database object.
SimpleDatum QueryImageMean(int nSrcId=0)
Return the SimpleDatum for the image mean from the open data source.
int GetRawImageMeanID(int nSrcId=0)
Returns the raw image ID for the image mean associated with a data source.
DatasetDescriptor LoadDataset(string strDataset, ConnectInfo ci=null)
Load a dataset descriptor from a dataset name.
bool CopyImageMean(string strSrcSrc, string strDstSrc)
Copy the raw image mean from one source to another.
SourceDescriptor LoadSource(string strSource)
Load the source descriptor from a data source name.
[V2 Image Database] The MyCaffeImageDatabase2 provides an enhanced in-memory image database used for ...
The MyCaffeImageDatabase provides an enhanced in-memory image database used for quick image retrieval...
static Tuple< DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD > GetSelectionMethod(SettingsCaffe s)
Returns the label/image selection methods based on the SettingsCaffe settings.
An interface for the units of computation which can be composed into a Net.
LayerParameter.LayerType type
Returns the LayerType of this Layer.
LayerParameter layer_param
Returns the LayerParameter for this Layer.
virtual BlobCollection< T > PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection< T > colBottom=null)
The PreprocessInput allows derivative data layers to convert a property set of input data into the bo...
Specifies the parameters for the AccuracyLayer.
Specifies the shape of a Blob.
List< int > dim
The blob shape dimensions.
string source
When used with the DATA parameter, specifies the data 'source' within the database....
DEPRECATED (use DataLayer DataLabelMappingParameter instead) Specifies the parameters for the Lab...
Specifies the base parameter for all layers.
void PrepareRunModel()
Prepare the layer settings for a run model.
LayerType type
Specifies the type of this LayerParameter.
SoftmaxParameter softmax_param
Returns the parameter set when initialized with LayerType.SOFTMAX
List< NetStateRule > include
Specifies the NetStateRule's for which this LayerParameter should be included.
AccuracyParameter accuracy_param
Returns the parameter set when initialized with LayerType.ACCURACY
string PrepareRunModelInputs()
Prepare model inputs for the run-net (if any are needed for the layer).
TransformationParameter transform_param
Returns the parameter set when initialized with LayerType.TRANSFORM
DataParameter data_param
Returns the parameter set when initialized with LayerType.DATA
LayerType
Specifies the layer type.
LabelMappingParameter labelmapping_param
Returns the parameter set when initialized with LayerType.LABELMAPPING
Specifies the parameters used to create a Net.
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
bool force_backward
Whether the network will force every layer to carry out backward operation. If set False,...
override RawProto ToProto(string strName)
Convert the parameter into a RawProto using the given name.
List< LayerParameter > layer
The layers that make up the net. Each of their configurations, including connectivity and behavior,...
static Dictionary< string, BlobShape > InputFromProto(RawProto rp)
Collect the inputs from the RawProto.
Phase phase
Specifies the Phase of the NetState.
List< string > stage
Specifies the stages of the NetState.
Specifies a NetStateRule used to determine whether a Net falls within a given include or exclude patt...
Phase phase
Set phase to require the NetState to have a particular phase (TRAIN or TEST) to meet this rule.
int axis
The axis along which to perform the softmax – may be negative to index from the end (e....
The SolverParameter is a parameter for the solver, specifying the train and test networks.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
Specifies the parameters for the DecodeLayer and the AccuracyEncodingLayer.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
int MaximumIteration
Returns the maximum training iterations.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is sent extra debugging information describing the state o...
Net< T > net
Returns the main training Net.
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Net< T > TrainingNet
Returns the training Net used by the solver.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
void Reset()
Reset the iterations of the net.
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
string LabelQueryEpochs
Return the label query epochs for the active datasource.
int TrainingIterationOverride
Get/set the training iteration override.
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
bool WeightsUpdated
Get/set when the weights have been updated.
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed NaN (and Infinity) detection is performed...
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
int TestingIterationOverride
Get/set the testing iteration override.
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. By default, iter will be zero. Pass in a non-zero iter number ...
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Net< T > TestingNet
Returns the testing Net used by the solver.
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
int CurrentIteration
Returns the current training iteration.
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
The IXDatabaseBase interface defines the general interface to the in-memory database.
SimpleDatum GetItem(int nItemID, params int[] rgSrcId)
Get the item (e.g., image or temporal item) with a given Raw Item ID.
SimpleDatum GetItemMean(int nSrcId)
Returns the item (e.g., image or temporal item) mean for a data source.
int GetSourceID(string strSrc)
Returns a data source ID given its name.
SimpleDatum QueryItem(int nSrcId, int nIdx, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? imageSelectionOverride=null, int? nLabel=null, bool bLoadDataCriteria=false, bool bLoadDebugData=false)
Query an image in a given data source.
Tuple< DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD > GetSelectionMethod()
Returns the label and image selection method used.
SimpleDatum QueryItemMean(int nSrcId)
Queries the item (e.g., image or temporal item) mean for a data source from the database on disk.
void SetSelectionMethod(DB_LABEL_SELECTION_METHOD? lbl, DB_ITEM_SELECTION_METHOD? img)
Sets the label and image selection methods.
void CleanUp(int nDsId=0, bool bForce=false)
Releases the image database, and if this is the last instance using the in-memory database,...
DB_VERSION GetVersion()
Returns the version of the MyCaffe Image Database being used.
List< SimpleDatum > GetItemsFromTime(int nSrcId, DateTime dtStart, int nQueryCount=int.MaxValue, string strFilterVal=null, int? nBoostVal=null, bool bBoostValIsExact=false)
Returns the array of items (e.g., images or temporal items) in the item set, possibly filtered with t...
The IXImageDatabase interface defines the general interface to the in-memory image database.
The IXImageDatabase2 interface defines the general interface to the in-memory image database (v2).
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
The IXMyCaffeExtension interface allows for easy extension management of the low-level software that ...
The IXMyCaffe interface contains functions used to perform MyCaffe operations that work with the MyCa...
The IXMyCaffeNoDb interface contains functions used to perform MyCaffe operations that run in a light...
The IXMyCaffeState interface contains functions related to the MyCaffeComponent state.
The IXPersist interface is used by the CaffeControl to load and save weights.
The descriptors namespace contains all descriptors used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
DB_ITEM_SELECTION_METHOD
Defines the item (e.g., image or temporal item) selection method.
Phase
Defines the Phase under which to run a Net.
DB_VERSION
Defines the image database version to use.
DB_LABEL_SELECTION_METHOD
Defines the label selection method.
Stage
Specifies the stage under which to run a custom trainer.
The MyCaffe.common namespace contains common MyCaffe classes.
DEVINIT
Specifies the initialization flags used when initializing CUDA.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
WEIGHT_TARGET
Defines the type of weight to target in re-initializations.
The MyCaffe.data namespace contains dataset creators used to create common testing datasets such as M...
The MyCaffe.db.image namespace contains all image database related classes.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param.beta parameters are used by the MyCaffe.layer.beta layers.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...