2using System.Collections.Generic;
8using System.Threading.Tasks;
33 bool m_bUsePreloadData =
true;
50 m_icallback = icallback;
52 m_properties = properties;
54 m_rgVocabulary = rgVocabulary;
76 private void wait(
int nWait)
81 while (nTotalWait < nWait)
83 m_icallback.OnWait(
new WaitArgs(nWaitInc));
84 nTotalWait += nWaitInc;
95 if (m_mycaffe !=
null)
101 m_icallback.OnShutdown();
116 Agent<T> agent =
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random,
Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
117 float[] rgResults = agent.Run(nN);
133 Agent<T> agent =
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random,
Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
134 byte[] rgResults = agent.Run(nN, out type);
151 Agent<T> agent =
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random,
Phase.TEST, m_rgVocabulary, m_bUsePreloadData);
170 Agent<T> agent =
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random,
Phase.TRAIN, m_rgVocabulary, m_bUsePreloadData);
171 agent.Run(
Phase.TRAIN, nN, type, step);
179 class Agent<T> : IDisposable
188 m_icallback = icallback;
189 m_brain =
new Brain<T>(mycaffe, properties, random, icallback as
IxTrainerCallbackRNN, phase, rgVocabulary, bUsePreloadData, runProp);
190 m_properties = properties;
194 public void Dispose()
203 private StateBase getData(
Phase phase,
int nAction)
205 GetDataArgs args = m_brain.getDataArgs(phase, 0, nAction,
true);
206 m_icallback.OnGetData(args);
225 throw new Exception(
"The TrainerRNN only supports the ITERATION type.");
227 StateBase s = getData(phase, -1);
229 while (!m_brain.Cancel.WaitOne(0) && !s.Done)
231 if (phase ==
Phase.TEST)
233 else if (phase ==
Phase.TRAIN)
234 m_brain.Train(s, nN, step);
236 s = getData(phase, 1);
245 public float[] Run(
int nN)
247 return m_brain.Run(nN);
256 public byte[] Run(
int nN, out
string type)
258 float[] rgResults = m_brain.Run(nN);
260 ConvertOutputArgs args =
new ConvertOutputArgs(nN, rgResults);
261 IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
262 if (icallback ==
null)
263 throw new Exception(
"The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
265 icallback.OnConvertOutput(args);
268 return args.RawOutput;
272 class Brain<T> : IDisposable
274 IxTrainerCallbackRNN m_icallback;
275 MyCaffeControl<T> m_mycaffe;
284 int m_nSequenceLength;
285 int m_nSequenceLengthLabel;
287 int m_nVocabSize = 1;
293 double m_dfRunTemperature = 0;
294 double m_dfTestTemperature = 0;
295 byte[] m_rgTestData =
null;
296 byte[] m_rgTrainData =
null;
297 float[] m_rgfTestData =
null;
298 float[] m_rgfTrainData =
null;
299 bool m_bIsDataReal =
false;
300 Stopwatch m_sw =
new Stopwatch();
301 double m_dfLastLoss = 0;
302 double m_dfLastLearningRate = 0;
304 bool m_bUsePreloadData =
true;
305 bool m_bDisableVocabulary =
false;
308 int m_nSolverSequenceLength = -1;
310 DataCollectionPool m_dataPool =
new DataCollectionPool();
311 double m_dfScale = 1.0;
315 string strOutputBlob =
null;
318 m_runProperties = runProp;
320 m_icallback = icallback;
322 m_properties = properties;
324 m_rgVocabulary = rgVocabulary;
325 m_bUsePreloadData = bUsePreloadData;
326 m_nSolverSequenceLength = m_properties.
GetPropertyAsInt(
"SequenceLength", -1);
327 m_bDisableVocabulary = m_properties.
GetPropertyAsBool(
"DisableVocabulary",
false);
333 m_dataPool.Initialize(m_nThreads, icallback);
335 if (m_runProperties !=
null)
338 if (m_dfRunTemperature > 1.0)
339 m_dfRunTemperature = 1.0;
341 string strPhaseOnRun = m_runProperties.
GetProperty(
"PhaseOnRun",
false);
342 switch (strPhaseOnRun)
345 m_phaseOnRun =
Phase.RUN;
349 m_phaseOnRun =
Phase.TEST;
353 m_phaseOnRun =
Phase.TRAIN;
357 if (phase ==
Phase.RUN && m_phaseOnRun !=
Phase.NONE)
359 if (m_phaseOnRun !=
Phase.RUN)
360 m_mycaffe.Log.WriteLine(
"Warning: Running on the '" + m_phaseOnRun.ToString() +
"' network.");
362 strOutputBlob = m_runProperties.
GetProperty(
"OutputBlob",
false);
363 if (strOutputBlob ==
null)
364 throw new Exception(
"You must specify the 'OutputBlob' when Running with a phase other than RUN.");
368 phase = m_phaseOnRun;
372 m_net = mycaffe.GetInternalNet(phase);
375 mycaffe.Log.WriteLine(
"WARNING: Test net does not exist, set test_iteration > 0. Using TRAIN phase instead.");
376 m_net = mycaffe.GetInternalNet(
Phase.TRAIN);
407 if (lstmLayer ==
null && lstmAttentionLayer ==
null && lstmSimpleLayer ==
null)
408 throw new Exception(
"Could not find the required LSTM or LSTM_ATTENTION or LSTM_SIMPLE layer!");
410 if (m_phaseOnRun !=
Phase.NONE && m_phaseOnRun !=
Phase.RUN && strOutputBlob !=
null)
412 if ((m_blobOutput = m_net.
FindBlob(strOutputBlob)) ==
null)
413 throw new Exception(
"Could not find the 'Output' layer top named '" + strOutputBlob +
"'!");
416 if ((m_blobData = m_net.
FindBlob(
"data")) ==
null)
417 throw new Exception(
"Could not find the 'Input' layer top named 'data'!");
419 if ((m_blobClip = m_net.
FindBlob(
"clip")) ==
null)
420 throw new Exception(
"Could not find the 'Input' layer top named 'clip'!");
423 m_mycaffe.Log.CHECK(layer !=
null,
"Could not find an ending INNERPRODUCT layer!");
425 if (!m_bDisableVocabulary)
428 if (rgVocabulary !=
null)
429 m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.
Count,
"The vocabulary count = '" + rgVocabulary.
Count.ToString() +
"' and last inner product output count = '" + m_nVocabSize.ToString() +
"' - these do not match but they should!");
434 m_nSequenceLength = m_blobData.
shape(0);
435 m_nBatchSize = m_blobData.
shape(1);
440 m_nSequenceLength = m_blobData.
shape(0) / m_nBatchSize;
442 if (phase ==
Phase.RUN)
446 List<int> rgNewShape =
new List<int>() { m_nSequenceLength, 1 };
447 m_blobData.
Reshape(rgNewShape);
448 m_blobClip.
Reshape(rgNewShape);
453 m_mycaffe.Log.CHECK_EQ(m_nSequenceLength, m_blobData.
num,
"The data num must equal the sequence lengh of " + m_nSequenceLength.ToString());
455 m_rgDataInput =
new T[m_nSequenceLength * m_nBatchSize];
457 T[] rgClipInput =
new T[m_nSequenceLength * m_nBatchSize];
458 m_mycaffe.Log.CHECK_EQ(rgClipInput.Length, m_blobClip.
count(),
"The clip count must equal the sequence length * batch size: " + rgClipInput.Length.ToString());
459 m_tZero = (T)Convert.ChangeType(0, typeof(T));
460 m_tOne = (T)Convert.ChangeType(1, typeof(T));
462 for (
int i = 0; i < rgClipInput.Length; i++)
465 rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne;
467 rgClipInput[i] = (i % m_nSequenceLength == 0) ? m_tZero : m_tOne;
472 if (phase !=
Phase.RUN)
474 m_solver = mycaffe.GetInternalSolver();
475 m_solver.
OnStart += m_solver_OnStart;
481 if ((m_blobLabel = m_net.
FindBlob(
"label")) ==
null)
482 throw new Exception(
"Could not find the 'Input' layer top named 'label'!");
484 m_nSequenceLengthLabel = m_blobLabel.
count(0, 2);
485 m_rgLabelInput =
new T[m_nSequenceLengthLabel];
486 m_mycaffe.Log.CHECK_EQ(m_rgLabelInput.Length, m_blobLabel.
count(),
"The label count must equal the label sequence length * batch size: " + m_rgLabelInput.Length.ToString());
487 m_mycaffe.Log.CHECK(m_nSequenceLengthLabel == m_nSequenceLength * m_nBatchSize || m_nSequenceLengthLabel == 1,
"The label sqeuence length must be 1 or equal the length of the sequence: " + m_nSequenceLength.ToString());
493 if (m_sw.Elapsed.TotalMilliseconds > 1000)
504 if (m_sw.Elapsed.TotalMilliseconds > 1000)
511 private void dispose(ref
Blob<T> b)
520 public void Dispose()
522 if (m_dataPool !=
null)
524 m_dataPool.Shutdown();
529 private void updateStatus(
int nIteration,
int nMaxIteration,
double dfAccuracy,
double dfLoss,
double dfLearningRate)
531 GetStatusArgs args =
new GetStatusArgs(0, nIteration, nIteration, nMaxIteration, dfAccuracy, 0, 0, 0, dfLoss, dfLearningRate);
532 m_icallback.OnUpdateStatus(args);
535 public GetDataArgs getDataArgs(
Phase phase,
int nIdx,
int nAction,
bool bGetLabel =
false,
int nBatchSize = 1)
537 bool bReset = (nAction == -1) ?
true :
false;
538 return new GetDataArgs(phase, nIdx, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction,
false, bGetLabel, (nBatchSize > 1) ?
true :
false);
543 get {
return m_mycaffe.
Log; }
551 private void copyData(
SimpleDatum sd,
int nSrcOffset,
float[] rgfDst,
int nCount)
554 float[] rgfSrc = sd.GetData<
float>();
558 Array.
Copy(rgfSrc, nSrcOffset, rgfDst, 0, nCount);
562 for (
int i = 0; i < nCount; i++)
564 rgfDst[i] = rgfSrc[(nSrcOffset + i) * nDim];
569 private void getRawData(StateBase s)
571 int nTestLen = (int)(s.Data.Channels * s.TestingPercent);
572 int nTrainLen = s.Data.Channels - nTestLen;
574 if (s.Data.IsRealData)
576 m_bIsDataReal =
true;
578 m_rgfTrainData =
new float[nTrainLen];
579 copyData(s.Data, 0, m_rgfTrainData, nTrainLen);
583 m_rgfTestData = m_rgfTrainData;
587 m_rgfTestData =
new float[nTestLen];
588 copyData(s.Data, nTrainLen, m_rgfTestData, nTestLen);
593 int nDim = s.Data.Height * s.Data.Width;
595 throw new Exception(
"When training on binary data the height and width must = 1.");
597 m_bIsDataReal =
false;
598 m_rgTrainData =
new byte[nTrainLen];
599 Array.Copy(s.Data.ByteData, 0, m_rgTrainData, 0, nTrainLen);
603 m_rgTestData = m_rgTrainData;
607 m_rgTestData =
new byte[nTestLen];
608 Array.Copy(s.Data.ByteData, nTrainLen, m_rgTestData, 0, nTestLen);
613 public void Test(StateBase s,
int nIterations)
615 if (nIterations <= 0)
628 private void m_solver_OnTestStart(
object sender, EventArgs e)
640 private int getLabel(
float[] rgfScores,
int nIdx,
int nDim)
642 float[] rgfLastScores =
new float[nDim];
643 int nStartIdx = nIdx * nDim;
645 for (
int i = 0; i < nDim; i++)
647 rgfLastScores[i] = rgfScores[nStartIdx + i];
650 return getLastPrediction(rgfLastScores, m_dfTestTemperature);
659 if (nNum != m_rgLabelInput.Length)
662 int nDim = e.
Results[1].count(1);
664 float[] rgfScores =
Utility.ConvertVecF<T>(e.
Results[1].mutable_cpu_data);
665 int nCorrectCount = 0;
666 for (
int i = 0; i < m_nBatchSize; i++)
668 int nIdx = (m_lstmType ==
LayerParameter.
LayerType.LSTM_SIMPLE) ? (i * m_nSequenceLength + m_nSequenceLength - 1) : ((nNum - m_nBatchSize) + i);
669 int nExpectedLabel = (int)
Utility.ConvertVal<T>(m_rgLabelInput[nIdx]);
670 int nActualLabel = getLabel(rgfScores, nIdx, nDim);
671 bool bHandled =
false;
673 TestAccuracyUpdateArgs args =
new TestAccuracyUpdateArgs(nActualLabel, nExpectedLabel);
674 m_icallback.OnTestAccuracyUpdate(args);
684 if (nExpectedLabel == nActualLabel)
689 e.
Accuracy = (double)nCorrectCount / m_nBatchSize;
692 public void Train(StateBase s,
int nIterations,
TRAIN_STEP step)
694 if (nIterations <= 0)
699 m_solver.
Solve(nIterations,
null,
null, step);
702 private void m_solver_OnStart(
object sender, EventArgs e)
707 public void FeedNet(
bool bTrain)
716 if (m_bUsePreloadData)
718 float[] rgfData = (bTrain) ? m_rgfTrainData : m_rgfTestData;
721 for (
int i = 0; i < m_nBatchSize; i++)
723 int nCurrentValIdx = m_random.
Next(rgfData.Length - m_nSequenceLength - 1);
725 for (
int j = 0; j < m_nSequenceLength; j++)
728 double dfData = rgfData[nCurrentValIdx + j];
730 double dfLabel = rgfData[nCurrentValIdx + j + 1];
731 float fDataIdx = findIndex(dfData, out bFound);
732 float fLabelIdx = findIndex(dfLabel, out bFound);
737 nIdx = m_nBatchSize * j + i;
742 nIdx = i * m_nBatchSize + j;
744 m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
746 if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
747 m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
756 m_mycaffe.Log.CHECK_EQ(m_nBatchSize, m_nThreads,
"The 'Threads' setting of " + m_nThreads.ToString() +
" must match the batch size = " + m_nBatchSize.ToString() +
"!");
758 List<GetDataArgs> rgDataArgs =
new List<GetDataArgs>();
760 if (m_nBatchSize == 1)
762 GetDataArgs e = getDataArgs(phase, 0, 0,
true, m_nBatchSize);
763 m_icallback.OnGetData(e);
768 for (
int i = 0; i < m_nBatchSize; i++)
770 rgDataArgs.Add(getDataArgs(phase, i, 0,
true, m_nBatchSize));
773 if (!m_dataPool.Run(rgDataArgs))
774 m_mycaffe.Log.FAIL(
"Data Time Out - Failed to collect all data to build the RNN batch!");
777 double[] rgData = rgDataArgs[0].State.Data.GetData<
double>();
778 double[] rgLabel = rgDataArgs[0].State.Label.GetData<
double>();
779 double[] rgClip = rgDataArgs[0].State.Clip.GetData<
double>();
781 int nDataLen = rgData.Length;
782 int nLabelLen = rgLabel.Length;
783 int nClipLen = rgClip.Length;
784 int nDataItem = nDataLen / nLabelLen;
786 if (m_nBatchSize > 1)
788 rgData =
new double[nDataLen * m_nBatchSize];
789 rgLabel =
new double[nLabelLen * m_nBatchSize];
790 rgClip =
new double[nClipLen * m_nBatchSize];
792 for (
int i = 0; i < m_nBatchSize; i++)
794 for (
int j = 0; j < m_nSequenceLength; j++)
799 nIdx = m_nBatchSize * j + i;
804 nIdx = i * m_nBatchSize + j;
806 Array.Copy(rgDataArgs[i].State.Data.GetData<
double>(), 0, rgData, nIdx * nDataItem, nDataItem);
807 rgLabel[nIdx] = rgDataArgs[i].State.Label.GetDataAtD(j);
808 rgClip[nIdx] = rgDataArgs[i].State.Clip.GetDataAtD(j);
813 string strSolverErr =
"";
814 if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength != m_nSequenceLength)
815 strSolverErr =
"The solver parameter 'SequenceLength' length of " + m_nSolverSequenceLength.ToString() +
" must match the model sequence length of " + m_nSequenceLength.ToString() +
". ";
817 int nExpectedCount = m_blobData.
count();
818 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgData.Length, strSolverErr +
"The size of the data received ('" + rgData.Length.ToString() +
"') does mot match the expected data count of '" + nExpectedCount.ToString() +
"'!");
821 nExpectedCount = m_blobLabel.
count();
822 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgLabel.Length, strSolverErr +
"The size of the label received ('" + rgLabel.Length.ToString() +
"') does not match the expected label count of '" + nExpectedCount.ToString() +
"'!");
825 nExpectedCount = m_blobClip.
count();
826 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgClip.Length, strSolverErr +
"The size of the clip received ('" + rgClip.Length.ToString() +
"') does not match the expected clip count of '" + nExpectedCount.ToString() +
"'!");
833 byte[] rgData = (bTrain) ? m_rgTrainData : m_rgTestData;
840 for (
int i = 0; i < m_nBatchSize; i++)
842 int nCurrentCharIdx = m_random.
Next(rgData.Length - m_nSequenceLength - 2);
844 for (
int j = 0; j < m_nSequenceLength; j++)
847 byte bData = rgData[nCurrentCharIdx + j];
849 byte bLabel = rgData[nCurrentCharIdx + j + 1];
850 float fDataIdx = findIndex(bData, out bFound);
851 float fLabelIdx = findIndex(bLabel, out bFound);
856 nIdx = m_nBatchSize * j + i;
861 nIdx = i * m_nBatchSize + j;
863 m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
865 if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
866 m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
875 private float findIndex(
byte b, out
bool bFound)
879 if (m_rgVocabulary ==
null || m_bDisableVocabulary)
887 private float findIndex(
double df, out
bool bFound)
891 if (m_rgVocabulary ==
null || m_bDisableVocabulary)
897 private List<T> getInitialInput(
bool bIsReal)
899 List<T> rgInput =
new List<T>();
900 float[] rgCorrectLengthSequence =
new float[m_nSequenceLength];
902 for (
int i = 0; i < m_nSequenceLength; i++)
904 rgCorrectLengthSequence[i] = (int)m_random.
Next(m_nVocabSize);
908 bool bDataNeeded =
true;
909 if (!bIsReal && m_runProperties !=
null)
913 if (rgSeed !=
null && rgSeed.Length > 0)
915 int nLen = rgSeed.Length;
916 if (rgSeed[nLen - 1] == 0)
919 int nStart = rgCorrectLengthSequence.Length - nLen;
923 for (
int i = nStart; i < rgCorrectLengthSequence.Length; i++)
925 byte bVal = rgSeed[i - nStart];
927 int nIdx = (int)findIndex(bVal, out bFound);
930 rgCorrectLengthSequence[i] = nIdx;
937 if (bDataNeeded && m_runProperties !=
null)
939 GetDataArgs e = getDataArgs(
Phase.RUN, 0, 0,
false, m_nSequenceLength);
940 e.ExtraProperties = m_runProperties;
941 e.ExtraProperties.
SetProperty(
"DataCountRequested", m_nSequenceLength.ToString());
942 m_icallback.OnGetData(e);
944 if (e.State.Data !=
null)
946 float[] rgf = e.State.Data.GetData<
float>();
947 int nDim = e.State.Data.Height * e.State.Data.Width;
952 for (
int i = 0; i < rgCorrectLengthSequence.Length; i++)
954 float fChar = rgf[i * nDim];
957 rgCorrectLengthSequence[i] = fChar;
965 m_mycaffe.Log.WriteLine(
"WARNING: No seed data found - using random data.");
967 for (
int i = 0; i < rgCorrectLengthSequence.Length; i++)
969 rgInput.Add((T)Convert.ChangeType(rgCorrectLengthSequence[i], typeof(T)));
975 public float[] Run(
int nN)
979 Stopwatch sw =
new Stopwatch();
980 float[] rgPredictions =
new float[nN];
984 m_bIsDataReal =
true;
986 if (m_rgVocabulary !=
null)
989 m_mycaffe.Log.Enable =
false;
991 if (m_bIsDataReal && !m_bUsePreloadData)
993 string strSolverErr =
"";
995 if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength < m_nSequenceLength)
996 nLookahead = m_nSequenceLength - m_nSolverSequenceLength;
998 rgPredictions =
new float[nN * 2 * nLookahead];
1000 for (
int i = 0; i < nN; i++)
1002 GetDataArgs e = getDataArgs(
Phase.RUN, 0, 0,
true);
1003 m_icallback.OnGetData(e);
1005 int nExpectedCount = m_blobData.
count();
1006 m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Data.ItemCount, strSolverErr +
"The size of the data received ('" + e.State.Data.ItemCount.ToString() +
"') does mot match the expected data count of '" + nExpectedCount.ToString() +
"'!");
1009 if (m_blobLabel !=
null)
1011 nExpectedCount = m_blobLabel.
count();
1012 m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Label.ItemCount, strSolverErr +
"The size of the label received ('" + e.State.Label.ItemCount.ToString() +
"') does not match the expected label count of '" + nExpectedCount.ToString() +
"'!");
1018 Blob<T> blobOutput = colResults[0];
1020 if (m_blobOutput !=
null)
1021 blobOutput = m_blobOutput;
1025 for (
int j = nLookahead; j > 0; j--)
1027 float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, j);
1028 int nIdx = e.State.Label.ItemCount - j;
1029 float fActual = (float)e.State.Label.GetDataAtF(nIdx);
1031 int nIdx0 = ((nLookahead - j) * nN * 2);
1032 int nIdx1 = nIdx0 + nN;
1034 if (m_dfScale != 1.0 && m_dfScale > 0)
1035 fActual /= (float)m_dfScale;
1037 if (m_rgVocabulary ==
null || m_bDisableVocabulary)
1039 if (m_dfScale != 1.0 && m_dfScale > 0)
1040 fPrediction /= (float)m_dfScale;
1042 rgPredictions[nIdx0 + i] = fPrediction;
1043 rgPredictions[nIdx1 + i] = fActual;
1047 rgPredictions[nIdx0 + i] = (float)m_rgVocabulary.
GetValueAt((
int)fPrediction,
true);
1048 rgPredictions[nIdx1 + i] = (float)m_rgVocabulary.
GetValueAt((
int)fActual,
true);
1052 if (sw.Elapsed.TotalMilliseconds > 1000)
1054 double dfPct = (double)i / (
double)nN;
1055 m_mycaffe.Log.Enable =
true;
1056 m_mycaffe.Log.Progress = dfPct;
1057 m_mycaffe.Log.WriteLine(
"Running at " + dfPct.ToString(
"P") +
" complete...");
1058 m_mycaffe.Log.Enable =
false;
1062 if (m_mycaffe.CancelEvent.WaitOne(0))
1069 List<T> rgInput = getInitialInput(m_bIsDataReal);
1072 for (
int i = 0; i < nN; i++)
1074 T[] rgInputVector =
new T[m_blobData.
count()];
1075 for (
int j = 0; j < m_nSequenceLength; j++)
1078 nIdx = j * m_nBatchSize;
1079 rgInputVector[nIdx] = rgInput[j];
1086 Blob<T> blobOutput = colResults[0];
1088 if (m_blobOutput !=
null)
1089 blobOutput = m_blobOutput;
1093 if (blobLossBtm ==
null)
1100 if (blobLossBtm !=
null)
1101 blobOutput = blobLossBtm;
1105 float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, 1);
1108 rgInput.Add((T)Convert.ChangeType(fPrediction, typeof(T)));
1109 rgInput.RemoveAt(0);
1111 if (m_rgVocabulary ==
null || m_bDisableVocabulary)
1112 rgPredictions[i] = fPrediction;
1114 rgPredictions[i] = (float)m_rgVocabulary.
GetValueAt((
int)fPrediction);
1116 if (sw.Elapsed.TotalMilliseconds > 1000)
1118 double dfPct = (double)i / (
double)nN;
1119 m_mycaffe.Log.Enable =
true;
1120 m_mycaffe.Log.Progress = dfPct;
1121 m_mycaffe.Log.WriteLine(
"Running at " + dfPct.ToString(
"P") +
" complete...");
1122 m_mycaffe.Log.Enable =
false;
1126 if (m_mycaffe.CancelEvent.WaitOne(0))
1131 return rgPredictions;
1133 catch (Exception excpt)
1139 m_mycaffe.Log.Enable =
true;
1143 private float getLastPrediction(
float[] rgDataRaw,
BucketCollection rgVocabulary,
int nLookahead)
1146 int nOffset = (m_nSequenceLength - nLookahead) * m_nBatchSize * m_nVocabSize;
1148 if (m_bDisableVocabulary)
1149 return rgDataRaw[nOffset];
1151 float[] rgData =
new float[m_nVocabSize];
1153 for (
int i = 0; i < rgData.Length; i++)
1155 rgData[i] = rgDataRaw[nOffset + i];
1158 return getLastPrediction(rgData, m_dfRunTemperature);
1161 private int getLastPrediction(
float[] rgData,
double dfTemperature)
1163 int nIdx = m_nVocabSize - 1;
1166 if (dfTemperature == 0)
1168 nIdx = ArgMax(rgData, 0, m_nVocabSize);
1173 double[] rgAccumulatedProba =
new double[m_nVocabSize];
1174 double[] rgProba =
new double[m_nVocabSize];
1175 double dfExpoSum = 0;
1177 double dfMax = rgData.Max();
1178 for (
int i = 0; i < m_nVocabSize; i++)
1181 rgProba[i] = Math.Exp((rgData[i] - dfMax) / dfTemperature);
1182 dfExpoSum += rgProba[i];
1185 rgProba[0] /= dfExpoSum;
1186 rgAccumulatedProba[0] = rgProba[0];
1190 for (
int i = 1; i < rgProba.Length; i++)
1193 if (rgAccumulatedProba[i - 1] > dfRandom)
1199 rgProba[i] /= dfExpoSum;
1200 rgAccumulatedProba[i] = rgAccumulatedProba[i - 1] + rgProba[i];
1204 if (nIdx < 0 || nIdx > m_nVocabSize)
1205 throw new Exception(
"Invalid index - out of the vocabulary range of [0," + m_nVocabSize.ToString() +
"]");
1210 private int ArgMax(
float[] rg,
int nOffset,
int nCount)
1215 int nMaxIdx = nOffset;
1216 float fMax = rg[nOffset];
1218 for (
int i = nOffset; i < nOffset + nCount; i++)
1227 return nMaxIdx - nOffset;
1231 class DataCollectionPool
1233 List<DataCollector> m_rgCollectors =
new List<DataCollector>();
1235 public DataCollectionPool()
1239 public void Initialize(
int nThreads, IxTrainerCallback icallback)
1241 for (
int i = 0; i < nThreads; i++)
1243 m_rgCollectors.Add(
new DataCollector(icallback));
1247 public void Shutdown()
1249 foreach (DataCollector col
in m_rgCollectors)
1255 public bool Run(List<GetDataArgs> rgStartup)
1257 List<ManualResetEvent> rgWait =
new List<ManualResetEvent>();
1259 if (rgStartup.Count != m_rgCollectors.Count)
1260 throw new Exception(
"The startup count does not match the collector count.");
1262 for (
int i = 0; i < rgStartup.Count; i++)
1264 rgWait.Add(rgStartup[i].DataReady);
1265 m_rgCollectors[i].Run(rgStartup[i]);
1268 return WaitHandle.WaitAll(rgWait.ToArray(), 10000);
1274 ManualResetEvent m_evtAbort =
new ManualResetEvent(
false);
1275 AutoResetEvent m_evtRun =
new AutoResetEvent(
false);
1278 IxTrainerCallback m_icallback;
1280 public DataCollector(IxTrainerCallback icallback)
1282 m_icallback = icallback;
1283 m_thread =
new Thread(
new ThreadStart(doWork));
1287 public void CleanUp()
1292 public void Run(GetDataArgs args)
1298 private void doWork()
1301 List<WaitHandle> rgWait =
new List<WaitHandle>();
1302 rgWait.Add(m_evtAbort);
1303 rgWait.Add(m_evtRun);
1307 int nWait = WaitHandle.WaitAny(rgWait.ToArray());
1311 m_icallback.OnGetData(m_args);
1312 m_args.DataReady.Set();
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
The BucketCollection contains a set of Buckets.
int FindIndex(double dfVal)
Finds the index of the Bucket containing the value.
int Count
Returns the number of Buckets.
bool IsDataReal
Get/set whether or not the Buckets hold Real values.
double GetValueAt(int nIdx, bool bUseMidPoint=false)
Returns the average of the Bucket at a given index.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
void Reset()
Resets the event clearing any signaled state.
CancelEvent()
The CancelEvent constructor.
void Set()
Sets the event to the signaled state.
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range
double NextDouble()
Returns a random double within the range.
The Log class provides general output in text form.
Log(string strSrc)
The Log constructor.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
byte[] GetPropertyBlob(string strName, bool bThrowExceptions=true)
Returns a property blob as a byte array value.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as an double value.
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
The SimpleDatum class holds a data input within host memory.
void Copy(SimpleDatum d, bool bCopyData, int? nHeight=null, int? nWidth=null)
Copy another SimpleDatum into this one.
int Width
Return the width of the data.
int Height
Return the height of the data.
The Utility class provides general utility functions.
static string Replace(string str, char ch1, char ch2)
Replaces each instance of one character with another character in a given string.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The BlobCollection contains a list of Blobs.
The Blob is the main holder of data that moves through the Layers of the Net.
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
BLOB_TYPE type
Returns the BLOB_TYPE of the Blob.
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
T[] update_cpu_data()
Update the CPU data by transferring the GPU data over to the Host.
int count()
Returns the total number of items in the Blob.
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
List< Layer< T > > layers
Returns the layers.
void Reshape()
Reshape all layers from the bottom to the top.
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Layer< T > FindLastLayer(LayerParameter.LayerType type)
Find the last layer with the matching type.
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
Blob< T > FindLossBottomBlob()
Find the bottom blob of the Loss layer if it exists, otherwise null is returned.
The TestResultArgs are passed to the Solver::OnTestResults event.
BlobCollection< T > Results
Returns the results from the test.
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
double Accuracy
Return the accuracy of the test cycle.
int Iteration
Return the iteration of the test cycle.
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
double LearningRate
Return the current learning rate.
double SmoothedLoss
Returns the average loss after the training cycle.
The LSTMAttentionLayer adds attention to the long-short term memory layer and is used in encoder/deco...
The LSTMLayer processes sequential inputs using a 'Long Short-Term Memory' (LSTM) [1] style recurrent...
[DEPRECATED - use LSTMAttentionLayer instead with enable_attention = false] The LSTMSimpleLayer is a...
An interface for the units of computation which can be composed into a Net.
LayerParameter layer_param
Returns the LayerParameter for this Layer.
uint num_output
The number of outputs for the layer.
uint batch_size
Specifies the batch size, default = 1.
Specifies the base parameter for all layers.
LayerType type
Specifies the type of this LayerParameter.
LSTMSimpleParameter lstm_simple_param
[DEPRECATED] Returns the parameter set when initialized with LayerType.LSTM_SIMPLE
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
int MaximumIteration
Returns the maximum training iterations.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
SolverParameter parameter
Returns the SolverParameter used.
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
The InitializeArgs is passed to the OnInitialize event.
The WaitArgs is passed to the OnWait event.
The TrainerRNN implements a simple RNN trainer inspired by adepierre's GitHub site referenced.
bool Initialize()
Initialize the trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
bool Shutdown(int nWait)
Shutdown the trainer.
float[] Run(int nN, PropertySet runProp)
Run a single cycle on the environment after the delay.
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
void Dispose()
Releases all resources used.
TrainerRNN(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
The constructor.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a single cycle on the environment after the delay.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
The IxTrainerRL interface is implemented by each RL Trainer.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.common namespace contains common MyCaffe classes.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...