2using System.Collections.Generic;
 
    8using System.Threading.Tasks;
 
   33        bool m_bUsePreloadData = 
true;
 
   50            m_icallback = icallback;
 
   52            m_properties = properties;
 
   54            m_rgVocabulary = rgVocabulary;            
 
   76        private void wait(
int nWait)
             // Loops until the running total nTotalWait reaches nWait. On each pass the
             // callback's OnWait is invoked with a WaitArgs built from nWaitInc, and
             // nWaitInc is added to nTotalWait.
             // NOTE(review): the declarations and initial values of nWaitInc and nTotalWait
             // are not visible in this listing; presumably OnWait performs the actual delay -
             // confirm against the callback implementation.
 
   81            while (nTotalWait < nWait)
 
   83                m_icallback.OnWait(
new WaitArgs(nWaitInc));
 
   84                nTotalWait += nWaitInc;
 
   95            if (m_mycaffe != 
null)
 
  101            m_icallback.OnShutdown();
 
  116            Agent<T> agent = 
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, 
Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
 
  117            float[] rgResults = agent.Run(nN);
 
  133            Agent<T> agent = 
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, 
Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
 
  134            byte[] rgResults = agent.Run(nN, out type);
 
  151            Agent<T> agent = 
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, 
Phase.TEST, m_rgVocabulary, m_bUsePreloadData);
 
  170            Agent<T> agent = 
new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, 
Phase.TRAIN, m_rgVocabulary, m_bUsePreloadData);
 
  171            agent.Run(
Phase.TRAIN, nN, type, step);
 
  179    class Agent<T> : IDisposable 
 
  188            m_icallback = icallback;
 
  189            m_brain = 
new Brain<T>(mycaffe, properties, random, icallback as 
IxTrainerCallbackRNN, phase, rgVocabulary, bUsePreloadData, runProp);
 
  190            m_properties = properties;
 
  194        public void Dispose()
 
  203        private StateBase getData(
Phase phase, 
int nAction)
 
  205            GetDataArgs args = m_brain.getDataArgs(phase, 0, nAction, 
true);
 
  206            m_icallback.OnGetData(args);
 
  225                throw new Exception(
"The TrainerRNN only supports the ITERATION type.");
 
  227            StateBase s = getData(phase, -1);
 
  229            while (!m_brain.Cancel.WaitOne(0) && !s.Done)
 
  231                if (phase == 
Phase.TEST)
 
  233                else if (phase == 
Phase.TRAIN)
 
  234                    m_brain.Train(s, nN, step);
 
  236                s = getData(phase, 1);
 
  245        public float[] Run(
int nN)
 
  247            return m_brain.Run(nN);
 
  256        public byte[] Run(
int nN, out 
string type)
             // Runs the brain for nN and returns the results converted by the callback:
             // the float results are wrapped in a ConvertOutputArgs, handed to
             // IxTrainerCallbackRNN.OnConvertOutput, and args.RawOutput is returned.
             // Throws an Exception when m_icallback does not implement IxTrainerCallbackRNN.
             // NOTE(review): the assignment of the 'type' out parameter is not visible in this
             // listing (presumably taken from args after conversion - confirm).
 
             // NOTE(review): the brain is run before the callback interface is validated below,
             // so the run is performed even when the exception is then thrown.
  258            float[] rgResults = m_brain.Run(nN);
 
  260            ConvertOutputArgs args = 
new ConvertOutputArgs(nN, rgResults);
 
  261            IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
 
  262            if (icallback == 
null)
 
  263                throw new Exception(
"The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
 
  265            icallback.OnConvertOutput(args);
 
  268            return args.RawOutput;
 
  272    class Brain<T> : IDisposable 
 
  274        IxTrainerCallbackRNN m_icallback;
 
  275        MyCaffeControl<T> m_mycaffe;
 
  284        int m_nSequenceLength;
 
  285        int m_nSequenceLengthLabel;
 
  287        int m_nVocabSize = 1;
 
  293        double m_dfRunTemperature = 0;
 
  294        double m_dfTestTemperature = 0;
 
  295        byte[] m_rgTestData = 
null;
 
  296        byte[] m_rgTrainData = 
null;
 
  297        float[] m_rgfTestData = 
null;
 
  298        float[] m_rgfTrainData = 
null;
 
  299        bool m_bIsDataReal = 
false;
 
  300        Stopwatch m_sw = 
new Stopwatch();
 
  301        double m_dfLastLoss = 0;
 
  302        double m_dfLastLearningRate = 0;
 
  304        bool m_bUsePreloadData = 
true;
 
  305        bool m_bDisableVocabulary = 
false;
 
  308        int m_nSolverSequenceLength = -1;
 
  310        DataCollectionPool m_dataPool = 
new DataCollectionPool();
 
  311        double m_dfScale = 1.0;
 
  315            string strOutputBlob = 
null;
 
  318                m_runProperties = runProp;
 
  320            m_icallback = icallback;
 
  322            m_properties = properties;
 
  324            m_rgVocabulary = rgVocabulary;
 
  325            m_bUsePreloadData = bUsePreloadData;
 
  326            m_nSolverSequenceLength = m_properties.
GetPropertyAsInt(
"SequenceLength", -1);
 
  327            m_bDisableVocabulary = m_properties.
GetPropertyAsBool(
"DisableVocabulary", 
false);
 
  333                m_dataPool.Initialize(m_nThreads, icallback);
 
  335            if (m_runProperties != 
null)
 
  338                if (m_dfRunTemperature > 1.0)
 
  339                    m_dfRunTemperature = 1.0;
 
  341                string strPhaseOnRun = m_runProperties.
GetProperty(
"PhaseOnRun", 
false);
 
  342                switch (strPhaseOnRun)
 
  345                        m_phaseOnRun = 
Phase.RUN;
 
  349                        m_phaseOnRun = 
Phase.TEST;
 
  353                        m_phaseOnRun = 
Phase.TRAIN;
 
  357                if (phase == 
Phase.RUN && m_phaseOnRun != 
Phase.NONE)
 
  359                    if (m_phaseOnRun != 
Phase.RUN)
 
  360                        m_mycaffe.Log.WriteLine(
"Warning: Running on the '" + m_phaseOnRun.ToString() + 
"' network.");
 
  362                    strOutputBlob = m_runProperties.
GetProperty(
"OutputBlob", 
false);
 
  363                    if (strOutputBlob == 
null)
 
  364                        throw new Exception(
"You must specify the 'OutputBlob' when Running with a phase other than RUN.");
 
  368                    phase = m_phaseOnRun;
 
  372            m_net = mycaffe.GetInternalNet(phase);
 
  375                mycaffe.Log.WriteLine(
"WARNING: Test net does not exist, set test_iteration > 0.  Using TRAIN phase instead.");
 
  376                m_net = mycaffe.GetInternalNet(
Phase.TRAIN);
 
  407            if (lstmLayer == 
null && lstmAttentionLayer == 
null && lstmSimpleLayer == 
null)
 
  408                throw new Exception(
"Could not find the required LSTM or LSTM_ATTENTION or LSTM_SIMPLE layer!");
 
  410            if (m_phaseOnRun != 
Phase.NONE && m_phaseOnRun != 
Phase.RUN && strOutputBlob != 
null)
 
  412                if ((m_blobOutput = m_net.
FindBlob(strOutputBlob)) == 
null)
 
  413                    throw new Exception(
"Could not find the 'Output' layer top named '" + strOutputBlob + 
"'!");
 
  416            if ((m_blobData = m_net.
FindBlob(
"data")) == 
null)
 
  417                throw new Exception(
"Could not find the 'Input' layer top named 'data'!");
 
  419            if ((m_blobClip = m_net.
FindBlob(
"clip")) == 
null)
 
  420                throw new Exception(
"Could not find the 'Input' layer top named 'clip'!");
 
  423            m_mycaffe.Log.CHECK(layer != 
null, 
"Could not find an ending INNERPRODUCT layer!");
 
  425            if (!m_bDisableVocabulary)
 
  428                if (rgVocabulary != 
null)
 
  429                    m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.
Count, 
"The vocabulary count = '" + rgVocabulary.
Count.ToString() + 
"' and last inner product output count = '" + m_nVocabSize.ToString() + 
"' - these do not match but they should!");
 
  434                m_nSequenceLength = m_blobData.
shape(0);
 
  435                m_nBatchSize = m_blobData.
shape(1);
 
  440                m_nSequenceLength = m_blobData.
shape(0) / m_nBatchSize;
 
  442                if (phase == 
Phase.RUN)
 
  446                    List<int> rgNewShape = 
new List<int>() { m_nSequenceLength, 1 };
 
  447                    m_blobData.
Reshape(rgNewShape);
 
  448                    m_blobClip.
Reshape(rgNewShape);
 
  453            m_mycaffe.Log.CHECK_EQ(m_nSequenceLength, m_blobData.
num, 
"The data num must equal the sequence lengh of " + m_nSequenceLength.ToString());
 
  455            m_rgDataInput = 
new T[m_nSequenceLength * m_nBatchSize];
 
  457            T[] rgClipInput = 
new T[m_nSequenceLength * m_nBatchSize];
 
  458            m_mycaffe.Log.CHECK_EQ(rgClipInput.Length, m_blobClip.
count(), 
"The clip count must equal the sequence length * batch size: " + rgClipInput.Length.ToString());
 
  459            m_tZero = (T)Convert.ChangeType(0, typeof(T));
 
  460            m_tOne = (T)Convert.ChangeType(1, typeof(T));
 
  462            for (
int i = 0; i < rgClipInput.Length; i++)
 
  465                    rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne;
 
  467                    rgClipInput[i] = (i % m_nSequenceLength == 0) ? m_tZero : m_tOne;
 
  472            if (phase != 
Phase.RUN)
 
  474                m_solver = mycaffe.GetInternalSolver();
 
  475                m_solver.
OnStart += m_solver_OnStart;
 
  481                if ((m_blobLabel = m_net.
FindBlob(
"label")) == 
null)
 
  482                    throw new Exception(
"Could not find the 'Input' layer top named 'label'!");
 
  484                m_nSequenceLengthLabel = m_blobLabel.
count(0, 2);
 
  485                m_rgLabelInput = 
new T[m_nSequenceLengthLabel];
 
  486                m_mycaffe.Log.CHECK_EQ(m_rgLabelInput.Length, m_blobLabel.
count(), 
"The label count must equal the label sequence length * batch size: " + m_rgLabelInput.Length.ToString());
 
  487                m_mycaffe.Log.CHECK(m_nSequenceLengthLabel == m_nSequenceLength * m_nBatchSize || m_nSequenceLengthLabel == 1, 
"The label sqeuence length must be 1 or equal the length of the sequence: " + m_nSequenceLength.ToString());
 
  493            if (m_sw.Elapsed.TotalMilliseconds > 1000)
 
  504            if (m_sw.Elapsed.TotalMilliseconds > 1000)
 
  511        private void dispose(ref 
Blob<T> b)
 
  520        public void Dispose()
 
  522            if (m_dataPool != 
null)
 
  524                m_dataPool.Shutdown();
 
  529        private void updateStatus(
int nIteration, 
int nMaxIteration, 
double dfAccuracy, 
double dfLoss, 
double dfLearningRate)
 
  531            GetStatusArgs args = 
new GetStatusArgs(0, nIteration, nIteration, nMaxIteration, dfAccuracy, 0, 0, 0, dfLoss, dfLearningRate);
 
  532            m_icallback.OnUpdateStatus(args);
 
  535        public GetDataArgs getDataArgs(
Phase phase, 
int nIdx, 
int nAction, 
bool bGetLabel = 
false, 
int nBatchSize = 1)
 
  537            bool bReset = (nAction == -1) ? 
true : 
false;
 
  538            return new GetDataArgs(phase, nIdx, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, 
false, bGetLabel, (nBatchSize > 1) ? 
true : 
false);
 
  543            get { 
return m_mycaffe.
Log; }
 
  551        private void copyData(
SimpleDatum sd, 
int nSrcOffset, 
float[] rgfDst, 
int nCount)
             // Copies nCount float values from the datum 'sd' into rgfDst, starting at rgfDst[0].
             // Two copy paths are visible:
             //   - a contiguous Array.Copy starting at source index nSrcOffset, and
             //   - a strided copy that reads source element (nSrcOffset + i) * nDim for each i.
             // NOTE(review): the declaration of nDim and the condition selecting between the two
             // paths are not visible here (presumably nDim == 1 selects the contiguous path - confirm).
 
  554            float[] rgfSrc = sd.GetData<
float>();
 
  558                Array.
Copy(rgfSrc, nSrcOffset, rgfDst, 0, nCount);
 
  562                for (
int i = 0; i < nCount; i++)
 
                     // Strided read: only one value out of every nDim source values is taken.
  564                    rgfDst[i] = rgfSrc[(nSrcOffset + i) * nDim];
 
  569        private void getRawData(StateBase s)
             // Splits the data held in s.Data into training and testing buffers.
             // The test length is Channels * TestingPercent truncated to an int; the training
             // length is the remainder. Real-valued data fills the float buffers
             // (m_rgfTrainData / m_rgfTestData) via copyData; otherwise the byte buffers
             // (m_rgTrainData / m_rgTestData) are filled from s.Data.ByteData.
             // m_bIsDataReal records which of the two paths was taken.
             // NOTE(review): in both paths the test buffer is either aliased to the training
             // buffer or freshly allocated; the condition choosing between these is not visible
             // (presumably nTestLen == 0 selects the alias - confirm).
 
  571            int nTestLen = (int)(s.Data.Channels * s.TestingPercent);
 
  572            int nTrainLen = s.Data.Channels - nTestLen;
 
  574            if (s.Data.IsRealData)
 
  576                m_bIsDataReal = 
true;
 
  578                m_rgfTrainData = 
new float[nTrainLen];
 
  579                copyData(s.Data, 0, m_rgfTrainData, nTrainLen);
 
  583                    m_rgfTestData = m_rgfTrainData;
 
  587                    m_rgfTestData = 
new float[nTestLen];
 
                     // The test portion starts right after the training portion.
  588                    copyData(s.Data, nTrainLen, m_rgfTestData, nTestLen);
 
  593                int nDim = s.Data.Height * s.Data.Width;
 
                     // NOTE(review): the guard for this throw is not visible; per the message it
                     // presumably fires when nDim != 1.
  595                    throw new Exception(
"When training on binary data the height and width must = 1.");
 
  597                m_bIsDataReal = 
false;
 
  598                m_rgTrainData = 
new byte[nTrainLen];
 
  599                Array.Copy(s.Data.ByteData, 0, m_rgTrainData, 0, nTrainLen);
 
  603                    m_rgTestData = m_rgTrainData;
 
  607                    m_rgTestData = 
new byte[nTestLen];
 
  608                    Array.Copy(s.Data.ByteData, nTrainLen, m_rgTestData, 0, nTestLen);
 
  613        public void Test(StateBase s, 
int nIterations)
 
  615            if (nIterations <= 0)
 
  628        private void m_solver_OnTestStart(
object sender, EventArgs e)
 
  640        private int getLabel(
float[] rgfScores, 
int nIdx, 
int nDim)
             // Extracts the nDim scores of item nIdx from the flat score array (the slice
             // starting at nIdx * nDim) and returns the label chosen by
             // getLastPrediction using the test temperature m_dfTestTemperature.
             // NOTE(review): getLastPrediction(float[], double) iterates over m_nVocabSize
             // entries, so this assumes nDim == m_nVocabSize - confirm.
 
  642            float[] rgfLastScores = 
new float[nDim];
 
  643            int nStartIdx = nIdx * nDim;
 
  645            for (
int i = 0; i < nDim; i++)
 
  647                rgfLastScores[i] = rgfScores[nStartIdx + i];
 
  650            return getLastPrediction(rgfLastScores, m_dfTestTemperature);
 
  659            if (nNum != m_rgLabelInput.Length)
 
  662            int nDim = e.
Results[1].count(1);
 
  664            float[] rgfScores = 
Utility.ConvertVecF<T>(e.
Results[1].mutable_cpu_data);
 
  665            int nCorrectCount = 0;
 
  666            for (
int i = 0; i < m_nBatchSize; i++)
 
  668                int nIdx = (m_lstmType == 
LayerParameter.
LayerType.LSTM_SIMPLE) ? (i * m_nSequenceLength + m_nSequenceLength - 1) : ((nNum - m_nBatchSize) + i);
 
  669                int nExpectedLabel = (int)
Utility.ConvertVal<T>(m_rgLabelInput[nIdx]);
 
  670                int nActualLabel = getLabel(rgfScores, nIdx, nDim);
 
  671                bool bHandled = 
false;
 
  673                TestAccuracyUpdateArgs args = 
new TestAccuracyUpdateArgs(nActualLabel, nExpectedLabel);
 
  674                m_icallback.OnTestAccuracyUpdate(args);
 
  684                    if (nExpectedLabel == nActualLabel)
 
  689            e.
Accuracy = (double)nCorrectCount / m_nBatchSize;
 
  692        public void Train(StateBase s, 
int nIterations, 
TRAIN_STEP step)
 
  694            if (nIterations <= 0)
 
  699            m_solver.
Solve(nIterations, 
null, 
null, step);
 
  702        private void m_solver_OnStart(
object sender, EventArgs e)
 
  707        public void FeedNet(
bool bTrain)
             // Builds the next batch of network inputs. Three paths are visible:
             //   1. preloaded float data (m_bUsePreloadData): random windows of
             //      m_nSequenceLength values are drawn from the train or test float buffer;
             //   2. callback supplied data: one GetDataArgs per batch item is filled via the
             //      callback (directly, or through m_dataPool when the batch size is > 1), and
             //      the received data/label/clip sizes are checked against the blobs;
             //   3. preloaded byte data: random windows are drawn from the train or test byte buffer.
             // In paths 1 and 3 the label for each value is the value that follows it in the buffer.
             // NOTE(review): the outer condition separating paths 1/2 from path 3 is not visible
             // (presumably m_bIsDataReal), nor is the code that pushes the built arrays into the
             // blobs - confirm against the full source.
 
  716                if (m_bUsePreloadData)
 
  718                    float[] rgfData = (bTrain) ? m_rgfTrainData : m_rgfTestData;
 
  721                    for (
int i = 0; i < m_nBatchSize; i++)
 
                         // NOTE(review): the upper bound here subtracts 1, while the byte path
                         // further down subtracts 2 for the same access pattern - confirm which is intended.
  723                        int nCurrentValIdx = m_random.
Next(rgfData.Length - m_nSequenceLength - 1);
 
  725                        for (
int j = 0; j < m_nSequenceLength; j++)
 
  728                            double dfData = rgfData[nCurrentValIdx + j];
 
                             // The label is the value one step ahead of the data value.
  730                            double dfLabel = rgfData[nCurrentValIdx + j + 1]; 
 
  731                            float fDataIdx = findIndex(dfData, out bFound);
 
  732                            float fLabelIdx = findIndex(dfLabel, out bFound);
 
                             // Sequence-major layout: item i of time step j.
  737                                nIdx = m_nBatchSize * j + i;
 
                             // NOTE(review): for a batch-major layout one would expect
                             // i * m_nSequenceLength + j (the constructor's clip setup uses
                             // i % m_nSequenceLength and the test handler uses i * m_nSequenceLength + ...).
                             // With i * m_nBatchSize + j, indexes overlap or exceed the buffer
                             // unless m_nBatchSize == m_nSequenceLength - confirm.
  742                                nIdx = i * m_nBatchSize + j;
 
  744                            m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
 
                             // A label is stored for every step when the label blob covers the whole
                             // sequence, otherwise only for the last step of each sequence.
  746                            if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
 
  747                                m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
 
  756                    m_mycaffe.Log.CHECK_EQ(m_nBatchSize, m_nThreads, 
"The 'Threads' setting of " + m_nThreads.ToString() + 
" must match the batch size = " + m_nBatchSize.ToString() + 
"!");
 
  758                    List<GetDataArgs> rgDataArgs = 
new List<GetDataArgs>();
 
  760                    if (m_nBatchSize == 1)
 
  762                        GetDataArgs e = getDataArgs(phase, 0, 0, 
true, m_nBatchSize);
 
                         // NOTE(review): rgDataArgs[0] is read below, so 'e' is presumably added to
                         // rgDataArgs on a line not visible here - confirm.
  763                        m_icallback.OnGetData(e);
 
  768                        for (
int i = 0; i < m_nBatchSize; i++)
 
  770                            rgDataArgs.Add(getDataArgs(phase, i, 0, 
true, m_nBatchSize));
 
                         // The pool fills all arguments concurrently; a false return is a time-out.
  773                        if (!m_dataPool.Run(rgDataArgs))
 
  774                            m_mycaffe.Log.FAIL(
"Data Time Out - Failed to collect all data to build the RNN batch!");
 
  777                    double[] rgData = rgDataArgs[0].State.Data.GetData<
double>();
 
  778                    double[] rgLabel = rgDataArgs[0].State.Label.GetData<
double>();
 
  779                    double[] rgClip = rgDataArgs[0].State.Clip.GetData<
double>();
 
  781                    int nDataLen = rgData.Length;
 
  782                    int nLabelLen = rgLabel.Length;
 
  783                    int nClipLen = rgClip.Length;
 
                     // Number of data values per label.
                     // NOTE(review): divides by zero when the label array is empty.
  784                    int nDataItem = nDataLen / nLabelLen;
 
  786                    if (m_nBatchSize > 1)
 
                         // Re-allocate the arrays to hold all batch items and interleave them.
  788                        rgData = 
new double[nDataLen * m_nBatchSize];
 
  789                        rgLabel = 
new double[nLabelLen * m_nBatchSize];
 
  790                        rgClip = 
new double[nClipLen * m_nBatchSize];
 
  792                        for (
int i = 0; i < m_nBatchSize; i++)
 
  794                            for (
int j = 0; j < m_nSequenceLength; j++)
 
  799                                    nIdx = m_nBatchSize * j + i;
 
                                 // NOTE(review): same index concern as in the preload path above
                                 // (i * m_nBatchSize + j vs. i * m_nSequenceLength + j) - confirm.
  804                                    nIdx = i * m_nBatchSize + j;
 
                                 // NOTE(review): the source offset is always 0 inside the j loop, so each
                                 // time step copies the first nDataItem values of item i, whereas the label
                                 // and clip below are read at position j. Looks like the source offset
                                 // should depend on j (e.g. j * nDataItem) - confirm.
  806                                Array.Copy(rgDataArgs[i].State.Data.GetData<
double>(), 0, rgData, nIdx * nDataItem, nDataItem);
 
  807                                rgLabel[nIdx] = rgDataArgs[i].State.Label.GetDataAtD(j);
 
  808                                rgClip[nIdx] = rgDataArgs[i].State.Clip.GetDataAtD(j);
 
                     // When the solver 'SequenceLength' property disagrees with the model, prefix the
                     // size-mismatch messages below with a hint pointing at that setting.
  813                    string strSolverErr = 
"";
 
  814                    if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength != m_nSequenceLength)
 
  815                        strSolverErr = 
"The solver parameter 'SequenceLength' length of " + m_nSolverSequenceLength.ToString() + 
" must match the model sequence length of " + m_nSequenceLength.ToString() + 
".  ";
 
  817                    int nExpectedCount = m_blobData.
count();
 
  818                    m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgData.Length, strSolverErr + 
"The size of the data received ('" + rgData.Length.ToString() + 
"') does mot match the expected data count of '" + nExpectedCount.ToString() + 
"'!");
 
  821                    nExpectedCount = m_blobLabel.
count();
 
  822                    m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgLabel.Length, strSolverErr + 
"The size of the label received ('" + rgLabel.Length.ToString() + 
"') does not match the expected label count of '" + nExpectedCount.ToString() + 
"'!");
 
  825                    nExpectedCount = m_blobClip.
count();
 
  826                    m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgClip.Length, strSolverErr + 
"The size of the clip received ('" + rgClip.Length.ToString() + 
"') does not match the expected clip count of '" + nExpectedCount.ToString() + 
"'!");
 
                 // Byte data path: same windowing scheme as the preloaded float path.
  833                byte[] rgData = (bTrain) ? m_rgTrainData : m_rgTestData;
 
  840                for (
int i = 0; i < m_nBatchSize; i++)
 
  842                    int nCurrentCharIdx = m_random.
Next(rgData.Length - m_nSequenceLength - 2);
 
  844                    for (
int j = 0; j < m_nSequenceLength; j++)
 
  847                        byte bData = rgData[nCurrentCharIdx + j];
 
                         // The label is the byte one step ahead of the data byte.
  849                        byte bLabel = rgData[nCurrentCharIdx + j + 1]; 
 
  850                        float fDataIdx = findIndex(bData, out bFound);
 
  851                        float fLabelIdx = findIndex(bLabel, out bFound);
 
  856                            nIdx = m_nBatchSize * j + i;
 
                         // NOTE(review): same index concern as above (i * m_nBatchSize + j) - confirm.
  861                            nIdx = i * m_nBatchSize + j;
 
  863                        m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
 
  865                        if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
 
  866                            m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
 
  875        private float findIndex(
byte b, out 
bool bFound)
 
  879            if (m_rgVocabulary == 
null || m_bDisableVocabulary)
 
  887        private float findIndex(
double df, out 
bool bFound)
 
  891            if (m_rgVocabulary == 
null || m_bDisableVocabulary)
 
  897        private List<T> getInitialInput(
bool bIsReal)
 
  899            List<T> rgInput = 
new List<T>();
 
  900            float[] rgCorrectLengthSequence = 
new float[m_nSequenceLength];
 
  902            for (
int i = 0; i < m_nSequenceLength; i++)
 
  904                rgCorrectLengthSequence[i] = (int)m_random.
Next(m_nVocabSize);
 
  908            bool bDataNeeded = 
true;
 
  909            if (!bIsReal && m_runProperties != 
null)
 
  913                if (rgSeed != 
null && rgSeed.Length > 0)
 
  915                    int nLen = rgSeed.Length;
 
  916                    if (rgSeed[nLen - 1] == 0)
 
  919                    int nStart = rgCorrectLengthSequence.Length - nLen;
 
  923                    for (
int i = nStart; i < rgCorrectLengthSequence.Length; i++)
 
  925                        byte bVal = rgSeed[i - nStart];
 
  927                        int nIdx = (int)findIndex(bVal, out bFound);
 
  930                            rgCorrectLengthSequence[i] = nIdx;
 
  937            if (bDataNeeded && m_runProperties != 
null)
 
  939                GetDataArgs e = getDataArgs(
Phase.RUN, 0, 0, 
false, m_nSequenceLength);
 
  940                e.ExtraProperties = m_runProperties;
 
  941                e.ExtraProperties.
SetProperty(
"DataCountRequested", m_nSequenceLength.ToString());
 
  942                m_icallback.OnGetData(e);
 
  944                if (e.State.Data != 
null)
 
  946                    float[] rgf = e.State.Data.GetData<
float>();
 
  947                    int nDim = e.State.Data.Height * e.State.Data.Width;
 
  952                    for (
int i = 0; i < rgCorrectLengthSequence.Length; i++)
 
  954                        float fChar = rgf[i * nDim];
 
  957                        rgCorrectLengthSequence[i] = fChar;
 
  965                m_mycaffe.Log.WriteLine(
"WARNING: No seed data found - using random data.");
 
  967            for (
int i = 0; i < rgCorrectLengthSequence.Length; i++)
 
  969                rgInput.Add((T)Convert.ChangeType(rgCorrectLengthSequence[i], typeof(T)));
 
  975        public float[] Run(
int nN)
 
  979                Stopwatch sw = 
new Stopwatch();
 
  980                float[] rgPredictions = 
new float[nN];
 
  984                m_bIsDataReal = 
true;
 
  986                if (m_rgVocabulary != 
null)
 
  989                m_mycaffe.Log.Enable = 
false;
 
  991                if (m_bIsDataReal && !m_bUsePreloadData)
 
  993                    string strSolverErr = 
"";
 
  995                    if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength < m_nSequenceLength)
 
  996                        nLookahead = m_nSequenceLength - m_nSolverSequenceLength;
 
  998                    rgPredictions = 
new float[nN * 2 * nLookahead];
 
 1000                    for (
int i = 0; i < nN; i++)
 
 1002                        GetDataArgs e = getDataArgs(
Phase.RUN, 0, 0, 
true);
 
 1003                        m_icallback.OnGetData(e);
 
 1005                        int nExpectedCount = m_blobData.
count();
 
 1006                        m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Data.ItemCount, strSolverErr + 
"The size of the data received ('" + e.State.Data.ItemCount.ToString() + 
"') does mot match the expected data count of '" + nExpectedCount.ToString() + 
"'!");
 
 1009                        if (m_blobLabel != 
null)
 
 1011                            nExpectedCount = m_blobLabel.
count();
 
 1012                            m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Label.ItemCount, strSolverErr + 
"The size of the label received ('" + e.State.Label.ItemCount.ToString() + 
"') does not match the expected label count of '" + nExpectedCount.ToString() + 
"'!");
 
 1018                        Blob<T> blobOutput = colResults[0];
 
 1020                        if (m_blobOutput != 
null)
 
 1021                            blobOutput = m_blobOutput;
 
 1025                        for (
int j = nLookahead; j > 0; j--)
 
 1027                            float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, j);
 
 1028                            int nIdx = e.State.Label.ItemCount - j;
 
 1029                            float fActual = (float)e.State.Label.GetDataAtF(nIdx);
 
 1031                            int nIdx0 = ((nLookahead - j) * nN * 2);
 
 1032                            int nIdx1 = nIdx0 + nN;
 
 1034                            if (m_dfScale != 1.0 && m_dfScale > 0)
 
 1035                                fActual /= (float)m_dfScale;
 
 1037                            if (m_rgVocabulary == 
null || m_bDisableVocabulary)
 
 1039                                if (m_dfScale != 1.0 && m_dfScale > 0)
 
 1040                                    fPrediction /= (float)m_dfScale;
 
 1042                                rgPredictions[nIdx0 + i] = fPrediction;
 
 1043                                rgPredictions[nIdx1 + i] = fActual;
 
 1047                                rgPredictions[nIdx0 + i] = (float)m_rgVocabulary.
GetValueAt((
int)fPrediction, 
true);
 
 1048                                rgPredictions[nIdx1 + i] = (float)m_rgVocabulary.
GetValueAt((
int)fActual, 
true);
 
 1052                        if (sw.Elapsed.TotalMilliseconds > 1000)
 
 1054                            double dfPct = (double)i / (
double)nN;
 
 1055                            m_mycaffe.Log.Enable = 
true;
 
 1056                            m_mycaffe.Log.Progress = dfPct;
 
 1057                            m_mycaffe.Log.WriteLine(
"Running at " + dfPct.ToString(
"P") + 
" complete...");
 
 1058                            m_mycaffe.Log.Enable = 
false;
 
 1062                        if (m_mycaffe.CancelEvent.WaitOne(0))
 
 1069                    List<T> rgInput = getInitialInput(m_bIsDataReal);
 
 1072                    for (
int i = 0; i < nN; i++)
 
 1074                        T[] rgInputVector = 
new T[m_blobData.
count()];
 
 1075                        for (
int j = 0; j < m_nSequenceLength; j++)
 
 1078                            nIdx = j * m_nBatchSize;
 
 1079                            rgInputVector[nIdx] = rgInput[j];
 
 1086                        Blob<T> blobOutput = colResults[0];
 
 1088                        if (m_blobOutput != 
null)
 
 1089                            blobOutput = m_blobOutput;
 
 1093                            if (blobLossBtm == 
null)
 
 1100                            if (blobLossBtm != 
null)
 
 1101                                blobOutput = blobLossBtm;
 
 1105                        float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, 1);
 
 1108                        rgInput.Add((T)Convert.ChangeType(fPrediction, typeof(T)));
 
 1109                        rgInput.RemoveAt(0);
 
 1111                        if (m_rgVocabulary == 
null || m_bDisableVocabulary)
 
 1112                            rgPredictions[i] = fPrediction;
 
 1114                            rgPredictions[i] = (float)m_rgVocabulary.
GetValueAt((
int)fPrediction);
 
 1116                        if (sw.Elapsed.TotalMilliseconds > 1000)
 
 1118                            double dfPct = (double)i / (
double)nN;
 
 1119                            m_mycaffe.Log.Enable = 
true;
 
 1120                            m_mycaffe.Log.Progress = dfPct;
 
 1121                            m_mycaffe.Log.WriteLine(
"Running at " + dfPct.ToString(
"P") + 
" complete...");
 
 1122                            m_mycaffe.Log.Enable = 
false;
 
 1126                        if (m_mycaffe.CancelEvent.WaitOne(0))
 
 1131                return rgPredictions;
 
 1133            catch (Exception excpt)
 
 1139                m_mycaffe.Log.Enable = 
true;
 
 1143        private float getLastPrediction(
float[] rgDataRaw, 
BucketCollection rgVocabulary, 
int nLookahead)
             // Picks the prediction located nLookahead steps before the end of the sequence in
             // the flat output array rgDataRaw.
             // The offset is (m_nSequenceLength - nLookahead) * m_nBatchSize * m_nVocabSize.
             // With the vocabulary disabled the raw value at that offset is returned directly;
             // otherwise the m_nVocabSize scores at that offset are copied and the index chosen by
             // getLastPrediction(float[], double) with the run temperature is returned.
             // NOTE(review): the rgVocabulary parameter is not referenced in the visible lines.
 
 1146            int nOffset = (m_nSequenceLength - nLookahead) * m_nBatchSize * m_nVocabSize;
 
 1148            if (m_bDisableVocabulary)
 
 1149                return rgDataRaw[nOffset];
 
 1151            float[] rgData = 
new float[m_nVocabSize];
 
 1153            for (
int i = 0; i < rgData.Length; i++)
 
 1155                rgData[i] = rgDataRaw[nOffset + i];
 
 1158            return getLastPrediction(rgData, m_dfRunTemperature);
 
 1161        private int getLastPrediction(
float[] rgData, 
double dfTemperature)
             // Chooses a vocabulary index from the scores in rgData.
             // A temperature of 0 selects the arg-max of the first m_nVocabSize scores.
             // Otherwise the scores are turned into probabilities with
             // exp((score - max) / temperature) normalized by the sum, and an index is chosen by
             // comparing the accumulated probabilities against dfRandom.
             // Throws an Exception when the resulting index fails the range check at the end.
             // NOTE(review): the origin of dfRandom and the statements executed when the
             // accumulated probability exceeds it are not visible in this listing.
 
             // Default to the last index; this remains the result when no branch below assigns nIdx.
 1163            int nIdx = m_nVocabSize - 1;
 
 1166            if (dfTemperature == 0)
 
 1168                nIdx = ArgMax(rgData, 0, m_nVocabSize);
 
 1173                double[] rgAccumulatedProba = 
new double[m_nVocabSize];
 
 1174                double[] rgProba = 
new double[m_nVocabSize];
 
 1175                double dfExpoSum = 0;
 
                 // Subtracting the maximum keeps every exponent <= 0.
 1177                double dfMax = rgData.Max();
 
 1178                for (
int i = 0; i < m_nVocabSize; i++)
 
 1181                    rgProba[i] = Math.Exp((rgData[i] - dfMax) / dfTemperature);
 
 1182                    dfExpoSum += rgProba[i];
 
 1185                rgProba[0] /= dfExpoSum;
 
 1186                rgAccumulatedProba[0] = rgProba[0];
 
 1190                for (
int i = 1; i < rgProba.Length; i++)
 
 1193                    if (rgAccumulatedProba[i - 1] > dfRandom)
 
                     // Probabilities are normalized lazily, one entry per pass, while accumulating.
 1199                    rgProba[i] /= dfExpoSum;
 
 1200                    rgAccumulatedProba[i] = rgAccumulatedProba[i - 1] + rgProba[i];
 
             // NOTE(review): this check lets nIdx == m_nVocabSize through, yet the arrays above
             // have m_nVocabSize entries (valid indexes 0 .. m_nVocabSize - 1). Looks like an
             // off-by-one ('>=' expected, and an exclusive upper bound in the message) - confirm.
 1204            if (nIdx < 0 || nIdx > m_nVocabSize)
 
 1205                throw new Exception(
"Invalid index - out of the vocabulary range of [0," + m_nVocabSize.ToString() + 
"]");
 
 1210        private int ArgMax(
float[] rg, 
int nOffset, 
int nCount)
             // Scans rg[nOffset .. nOffset + nCount - 1] and returns nMaxIdx relative to nOffset
             // (0 means the element at nOffset). nMaxIdx and fMax start at the element at nOffset.
             // NOTE(review): the comparison/update statements inside the loop are not visible in
             // this listing; presumably they track the largest value - confirm.
             // NOTE(review): rg[nOffset] is read unconditionally, so an empty range still requires
             // nOffset to be a valid index.
 
 1215            int nMaxIdx = nOffset;
 
 1216            float fMax = rg[nOffset];
 
 1218            for (
int i = nOffset; i < nOffset + nCount; i++)
 
 1227            return nMaxIdx - nOffset;
 
 1231    class DataCollectionPool 
 
 1233        List<DataCollector> m_rgCollectors = 
new List<DataCollector>();
 
 1235        public DataCollectionPool()
 
 1239        public void Initialize(
int nThreads, IxTrainerCallback icallback)
 
 1241            for (
int i = 0; i < nThreads; i++)
 
 1243                m_rgCollectors.Add(
new DataCollector(icallback));
 
 1247        public void Shutdown()
 
 1249            foreach (DataCollector col 
in m_rgCollectors)
 
 1255        public bool Run(List<GetDataArgs> rgStartup)
             // Hands one GetDataArgs to each collector and blocks until every argument's
             // DataReady event is signaled, or until 10000 ms have elapsed.
             // Returns the result of WaitHandle.WaitAll: true when all events were signaled,
             // false on time-out.
             // Throws an Exception when the number of arguments differs from the number of collectors.
             // NOTE(review): WaitHandle.WaitAll is documented to reject more than 64 handles and
             // multiple handles on an STA thread - confirm that the thread count used with
             // Initialize and the calling thread respect this.
 
 1257            List<ManualResetEvent> rgWait = 
new List<ManualResetEvent>();
 
 1259            if (rgStartup.Count != m_rgCollectors.Count)
 
 1260                throw new Exception(
"The startup count does not match the collector count.");
 
 1262            for (
int i = 0; i < rgStartup.Count; i++)
 
                 // Register the event before starting the collector that will signal it.
 1264                rgWait.Add(rgStartup[i].DataReady);
 
 1265                m_rgCollectors[i].Run(rgStartup[i]);
 
 1268            return WaitHandle.WaitAll(rgWait.ToArray(), 10000);
 
 1274        ManualResetEvent m_evtAbort = 
new ManualResetEvent(
false);
 
 1275        AutoResetEvent m_evtRun = 
new AutoResetEvent(
false);
 
 1278        IxTrainerCallback m_icallback;
 
 1280        public DataCollector(IxTrainerCallback icallback)
 
 1282            m_icallback = icallback;
 
 1283            m_thread = 
new Thread(
new ThreadStart(doWork));
 
 1287        public void CleanUp()
 
 1292        public void Run(GetDataArgs args)
 
 1298        private void doWork()
 
 1301            List<WaitHandle> rgWait = 
new List<WaitHandle>();
 
 1302            rgWait.Add(m_evtAbort);
 
 1303            rgWait.Add(m_evtRun);
 
 1307                int nWait = WaitHandle.WaitAny(rgWait.ToArray());
 
 1311                m_icallback.OnGetData(m_args);
 
 1312                m_args.DataReady.Set();
 
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
The BucketCollection contains a set of Buckets.
int FindIndex(double dfVal)
Finds the index of the Bucket containing the value.
int Count
Returns the number of Buckets.
bool IsDataReal
Get/set whether or not the Buckets hold Real values.
double GetValueAt(int nIdx, bool bUseMidPoint=false)
Returns the average of the Bucket at a given index.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
void Reset()
Resets the event clearing any signaled state.
CancelEvent()
The CancelEvent constructor.
void Set()
Sets the event to the signaled state.
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range [nMinVal, nMaxVal].
double NextDouble()
Returns a random double within the range [0, 1].
The Log class provides general output in text form.
Log(string strSrc)
The Log constructor.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
byte[] GetPropertyBlob(string strName, bool bThrowExceptions=true)
Returns a property blob as a byte array value.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as a double value.
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
The SimpleDatum class holds a data input within host memory.
void Copy(SimpleDatum d, bool bCopyData, int? nHeight=null, int? nWidth=null)
Copy another SimpleDatum into this one.
int Width
Return the width of the data.
int Height
Return the height of the data.
The Utility class provides general utility functions.
static string Replace(string str, char ch1, char ch2)
Replaces each instance of one character with another character in a given string.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The BlobCollection contains a list of Blobs.
The Blob is the main holder of data that moves through the Layers of the Net.
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
BLOB_TYPE type
Returns the BLOB_TYPE of the Blob.
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
T[] update_cpu_data()
Update the CPU data by transferring the GPU data over to the Host.
int count()
Returns the total number of items in the Blob.
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
List< Layer< T > > layers
Returns the layers.
void Reshape()
Reshape all layers from the bottom to the top.
BlobCollection< T > Forward()
Run forward with the input Blobs already fed separately.
Layer< T > FindLastLayer(LayerParameter.LayerType type)
Find the last layer with the matching type.
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
Blob< T > FindLossBottomBlob()
Find the bottom blob of the Loss layer if it exists, otherwise null is returned.
The TestResultArgs are passed to the Solver::OnTestResults event.
BlobCollection< T > Results
Returns the results from the test.
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
double Accuracy
Return the accuracy of the test cycle.
int Iteration
Return the iteration of the test cycle.
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
double LearningRate
Return the current learning rate.
double SmoothedLoss
Returns the average loss after the training cycle.
The LSTMAttentionLayer adds attention to the long-short term memory layer and is used in encoder/deco...
The LSTMLayer processes sequential inputs using a 'Long Short-Term Memory' (LSTM) [1] style recurrent...
[DEPRECATED - use LSTMAttentionLayer instead with enable_attention = false] The LSTMSimpleLayer is a...
An interface for the units of computation which can be composed into a Net.
LayerParameter layer_param
Returns the LayerParameter for this Layer.
uint num_output
The number of outputs for the layer.
uint batch_size
Specifies the batch size, default = 1.
Specifies the base parameter for all layers.
LayerType type
Specifies the type of this LayerParameter.
LSTMSimpleParameter lstm_simple_param
[DEPRECATED] Returns the parameter set when initialized with LayerType.LSTM_SIMPLE
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
int MaximumIteration
Returns the maximum training iterations.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
SolverParameter parameter
Returns the SolverParameter used.
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. By default, iter will be zero. Pass in a non-zero iter number ...
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
The InitializeArgs is passed to the OnInitialize event.
The WaitArgs is passed to the OnWait event.
The TrainerRNN implements a simple RNN trainer inspired by adepierre's GitHub site referenced.
bool Initialize()
Initialize the trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
bool Shutdown(int nWait)
Shutdown the trainer.
float[] Run(int nN, PropertySet runProp)
Run a single cycle on the environment after the delay.
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
void Dispose()
Releases all resources used.
TrainerRNN(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
The constructor.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a single cycle on the environment after the delay.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
The IxTrainerRL interface is implemented by each RL Trainer.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.common namespace contains common MyCaffe classes.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...