using System.Collections.Generic;
using System.Threading.Tasks;
// TrainerC51<T> constructor fragments:
m_icallback = icallback;
m_properties = properties;

// Shutdown:
if (m_mycaffe != null)
{
    // ... (cancel outstanding work and wait)
}

m_icallback.OnShutdown();
private void wait(int nWait)
{
    // ... (nTotalWait and nWaitInc initialized in the elided lines)
    while (nTotalWait < nWait)
    {
        m_icallback.OnWait(new WaitArgs(nWaitInc));
        nTotalWait += nWaitInc;
    }
}
byte[] rgResults = agent.Run(nN, out type);

// ...
string strProp = m_properties.ToString();
strProp += "EnableNumSkip=False;";

// ...
agent.Run(Phase.TRAIN, nN, type, step);
// DqnAgent<T> fields:
float m_fGamma = 0.95f;
bool m_bUseRawInput = true;
int m_nMaxMemory = 50000;
int m_nTrainingUpdateFreq = 5000;
int m_nExplorationNum = 50000;
double m_dfEpsStart = 0;
double m_dfEpsEnd = 0;
double m_dfEpsDelta = 0;
double m_dfExplorationRate = 0;
STATE m_state = STATE.EXPLORING;
// DqnAgent<T> constructor:
m_icallback = icallback;
m_brain = new Brain<T>(mycaffe, properties, random, phase);
m_properties = properties;

m_nTrainingUpdateFreq = properties.GetPropertyAsInt("TrainingUpdateFreq", m_nTrainingUpdateFreq);
m_nExplorationNum = properties.GetPropertyAsInt("ExplorationNum", m_nExplorationNum);

m_dfEpsDelta = (m_dfEpsStart - m_dfEpsEnd) / m_nEpsSteps;
m_dfExplorationRate = m_dfEpsStart;
if (m_dfEpsStart < 0 || m_dfEpsStart > 1)
    throw new Exception("The 'EpsStart' value is out of range - please specify a real number in the range [0,1].");

if (m_dfEpsEnd < 0 || m_dfEpsEnd > 1)
    throw new Exception("The 'EpsEnd' value is out of range - please specify a real number in the range [0,1].");

if (m_dfEpsEnd > m_dfEpsStart)
    throw new Exception("The 'EpsEnd' value must be less than or equal to the 'EpsStart' value.");
m_icallback.OnGetData(args);

// Action selection:
case STATE.EXPLORING:
    return m_random.Next(nActionCount);

if (m_dfExplorationRate > m_dfEpsEnd)
    m_dfExplorationRate -= m_dfEpsDelta;

if (m_random.NextDouble() < m_dfExplorationRate)
    return m_random.Next(nActionCount);

return m_brain.act(sd, sdClip, nActionCount);
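// Worked example of the annealing above (illustrative values): with
// EpsStart = 1.0, EpsEnd = 0.1 and EpsSteps = 200000, m_dfEpsDelta =
// (1.0 - 0.1) / 200000 = 4.5e-6 per step, so the probability of a random
// action decays linearly from 100% to 10% and then holds at 10%.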
private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward, double dfLoss, double dfLearningRate, bool bModelUpdated)
{
    GetStatusArgs args = new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, m_dfExplorationRate, 0, dfLoss, dfLearningRate, bModelUpdated);
    m_icallback.OnUpdateStatus(args);
}
public byte[] Run(int nIterations, out string type)
{
    if (icallback == null)
        throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");

    List<float> rgResults = new List<float>();

    while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
    {
        // ... (run the model and record the chosen action)
        rgResults.Add(action);
        // ...
        s = getData(Phase.RUN, action, nIteration);
    }

    // ... (convert rgResults to the native format via the callback)
}
private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)
// Run (training loop):
MemoryEpisodeCollection rgMemory = new MemoryEpisodeCollection(m_nMaxMemory);
double? dfRunningReward = null;
double dfRewardSum = 0;
bool bDifferent = false;

while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
{
    if (nIteration > m_nExplorationNum && rgMemory.Count > m_brain.BatchSize)
        m_state = STATE.TRAINING;

    // ...
    StateBase s_ = getData(phase, action, nIteration);

    // ...
    m_brain.Log.WriteLine("WARNING: The current state is the same as the previous state!");

    // ...
    rgMemory.Add(new MemoryItem(s, x, action, s_, x_, s_.Reward, s_.Done, nIteration, nEpisode));

    if (m_state == STATE.TRAINING)
    {
        MemoryCollection rgRandomSamples = rgMemory.GetRandomSamples(m_random, m_brain.BatchSize);
        // ... (train on the sampled transitions)

        if (nIteration % m_nTrainingUpdateFreq == 0)
        {
            // ... (update the target network)
        }
    }

    // ...
    if (!dfRunningReward.HasValue)
        dfRunningReward = dfRewardSum;
    else
        dfRunningReward = dfRunningReward.Value * 0.99 + dfRewardSum * 0.01;

    updateStatus(nIteration, nEpisode, dfRewardSum, dfRunningReward.Value, 0, 0, m_brain.GetModelUpdated());

    s = getData(phase, -1, -1);
    x = m_brain.Preprocess(s, m_bUseRawInput, out bDifferent, true);
}
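// The running reward above is an exponential moving average with smoothing
// factor 0.99 - roughly the last 100 episode reward sums dominate it. For
// example, moving from an estimate of 10 after an episode scoring 20 gives
// 10 * 0.99 + 20 * 0.01 = 10.1.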
// Brain<T> fields:
Blob<T> m_blobActionBinaryLoss = null;
Blob<T> m_blobActionTarget = null;
float[] m_rgfZ = null;
float m_fGamma = 0.99f;
double m_dfVMax = 10;
double m_dfVMin = -10;
int m_nFramesPerX = 4;
int m_nStackPerX = 4;
int m_nBatchSize = 32;
int m_nMiniBatch = 1;
bool m_bUseAcceleratedTraining = false;
double m_dfLearningRate;
MemoryCollection m_rgSamples;
int m_nActionCount = 3;
bool m_bModelUpdated = false;
Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>> m_rgStyle = new Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>>();
List<SimpleDatum> m_rgX = new List<SimpleDatum>();
bool m_bNormalizeOverlay = true;
List<List<float>> m_rgOverlay = null;
// Brain<T> constructor:
m_netTarget = new Net<T>(m_mycaffe.Cuda, m_mycaffe.Log, m_net.net_param, m_mycaffe.CancelEvent, null, phase);
m_properties = properties;

if (m_transformer == null)
{
    int nC = m_mycaffe.CurrentProject.Dataset.TrainingSource.Channels;
    int nH = m_mycaffe.CurrentProject.Dataset.TrainingSource.Height;
    int nW = m_mycaffe.CurrentProject.Dataset.TrainingSource.Width;
    m_transformer = new DataTransformer<T>(m_mycaffe.Cuda, m_mycaffe.Log, trans_param, phase, nC, nH, nW);
}

m_blobZ = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
m_blobZ1 = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
m_blobQ = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
m_blobMLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
m_blobPLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
m_blobLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
m_blobActionBinaryLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
m_blobActionTarget = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
m_blobAction = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
m_blobLabel = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);

if (m_memLoss == null)
    m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");

m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

// ...
if (data == null)
    m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");

m_nBatchSize = data.num;

if (m_nMiniBatch > 1)
{
    // ... (create the accumulated-gradient collection)
    m_colAccumulatedGradients.SetDiff(0);
}
private void dispose(ref Blob<T> b)
{
    // ... (dispose the blob and set it to null)
}

// Dispose:
dispose(ref m_blobZ);
dispose(ref m_blobZ1);
dispose(ref m_blobQ);
dispose(ref m_blobMLoss);
dispose(ref m_blobPLoss);
dispose(ref m_blobActionBinaryLoss);
dispose(ref m_blobActionTarget);
dispose(ref m_blobAction);
dispose(ref m_blobLabel);

if (m_colAccumulatedGradients != null)
{
    m_colAccumulatedGradients.Dispose();
    m_colAccumulatedGradients = null;
}

if (m_softmax != null)
{
    // ... (dispose and null the softmax layer)
}

if (m_netTarget != null)
{
    // ... (dispose and null the target net)
}

foreach (KeyValuePair<Color, Tuple<Brush, Brush, Pen, Brush>> kv in m_rgStyle)
{
    kv.Value.Item1.Dispose();
    kv.Value.Item2.Dispose();
    kv.Value.Item3.Dispose();
    kv.Value.Item4.Dispose();
}
// getDataArgs:
bool bReset = (nAction == -1);
return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, true, false, false, this);
public int FrameStack
{
    get { return m_nFramesPerX; }
}

public int BatchSize
{
    get { return m_nBatchSize; }
}

public Log Log
{
    get { return m_mycaffe.Log; }
}
// Preprocess:
if (m_sdLast == null)
{
    // ...
}

bDifferent = sd.Sub(m_sdLast);

// ...
m_rgX = new List<SimpleDatum>();

for (int i = 0; i < m_nFramesPerX * m_nStackPerX; i++)
{
    // ... (seed the frame history)
}

// ...
for (int i = 0; i < m_nStackPerX; i++)
{
    int nIdx = ((m_nStackPerX - i) * m_nFramesPerX) - 1;
    rgSd[i] = m_rgX[nIdx];
}
private float[] createZArray(double dfVMin, double dfVMax, int nAtoms, out float fDeltaZ)
{
    float[] rgZ = new float[nAtoms];
    float fZ = (float)dfVMin;
    fDeltaZ = (float)((dfVMax - dfVMin) / (nAtoms - 1));

    for (int i = 0; i < nAtoms; i++)
    {
        rgZ[i] = fZ;
        fZ += fDeltaZ;
    }

    return rgZ;
}
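// Worked example: with dfVMin = -10, dfVMax = 10 (the field defaults above)
// and nAtoms = 51 (the classic C51 setting; the actual m_nAtoms value is not
// shown in this listing), fDeltaZ = 20 / 50 = 0.4, giving the fixed support
// z = { -10.0, -9.6, ..., 9.6, 10.0 }.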
private void createZ(int nNumSamples, int nActions, int nAtoms)
{
    // ...
    m_blobZ1.Reshape(1, nAtoms, 1, 1);
    m_rgfZ = createZArray(m_dfVMin, m_dfVMax, m_nAtoms, out m_fDeltaZ);
    // ... (load m_rgfZ into m_blobZ1)

    m_blobZ.Reshape(nActions, m_nBatchSize, nAtoms, 1);

    for (int i = 0; i < nActions; i++)
    {
        for (int j = 0; j < m_nBatchSize; j++)
        {
            // ... (copy m_blobZ1 into m_blobZ at nOffset)
            nOffset += m_blobZ1.count();
        }
    }

    m_blobZ.Reshape(nActions, nNumSamples, nAtoms, 1);
}
// act:
setData(m_net, sd, sdClip);
// ... (forward pass)

if (logits == null)
    throw new Exception("Missing expected 'logits' blob!");

Blob<T> actions = softmax_forward(logits, m_blobAction);
// ...
createZ(1, nActionCount, m_nAtoms);
// ...
reduce_sum_axis2(m_blobQ);
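// The expected return per action falls out of the distribution as
// Q(s, a) = sum_j z_j * p_j(s, a). Here m_blobQ presumably holds the
// elementwise product z * p computed in the elided lines, and
// reduce_sum_axis2 collapses the atom axis, leaving one Q value per action
// for the greedy selection.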
public bool GetModelUpdated()
{
    bool bModelUpdated = m_bModelUpdated;
    m_bModelUpdated = false;
    return bModelUpdated;
}

// UpdateTargetModel:
m_mycaffe.Log.Enable = false;
// ... (copy the trained layers from the active net to m_netTarget)
m_mycaffe.Log.Enable = true;
m_bModelUpdated = true;
public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
{
    m_rgSamples = rgSamples;
    m_nActionCount = nActionCount;

    m_mycaffe.Log.Enable = false;
    setData1(m_netTarget, rgSamples);
    // ... (forward the target net on the next states)

    setData1(m_net, rgSamples);
    m_memLoss.OnGetLoss += m_memLoss_OnGetLoss;
    // ... (forward the active net on the next states)
    m_memLoss.OnGetLoss -= m_memLoss_OnGetLoss;

    setData0(m_net, rgSamples);
    m_memLoss.OnGetLoss += m_memLoss_ProjectDistribution;

    if (m_nMiniBatch == 1)
    {
        m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
    }
    else
    {
        // ... (step without applying updates, accumulating gradients)
        if (nIteration % m_nMiniBatch == 0)
        {
            // ... (copy the accumulated gradients into the net)
            m_colAccumulatedGradients.SetDiff(0);
            m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
            // ...
        }
    }

    m_memLoss.OnGetLoss -= m_memLoss_ProjectDistribution;
    m_mycaffe.Log.Enable = true;
}
// m_memLoss_OnGetLoss:
int nNumSamples = m_rgSamples.Count;
// ...

if (logits == null)
    throw new Exception("Missing expected 'logits' blob!");

// ...
m_blobPLoss.Reshape(m_blobPLoss.num, m_nActionCount, m_nAtoms, 1);
m_blobLabel.Reshape(m_blobLabel.num, m_nActionCount, m_nAtoms, 1);

// ...
for (int i = 0; i < nNumSamples; i++)
{
    for (int j = 0; j < m_nActionCount; j++)
    {
        // ... (copy the per-action atom values at the offsets)
        nDstOffset += m_nAtoms;
    }

    nSrcOffset += m_nAtoms;
}

e.Loss = softmaxLoss_forward(m_blobPLoss, m_blobLabel, m_blobLoss);
softmaxLoss_backward(m_blobPLoss, m_blobLabel, m_blobLoss);
// m_memLoss_ProjectDistribution:
if (logits == null)
    throw new Exception("Missing expected 'logits' blob!");

Blob<T> actions = softmax_forward(logits, m_blobAction);
// ...

if (p_logits == null)
    throw new Exception("Missing expected 'logits' blob!");

Blob<T> p_actions = softmax_forward(p_logits, m_blobActionTarget);

int nNumSamples = m_rgSamples.Count;
createZ(nNumSamples, m_nActionCount, m_nAtoms);
// ...

m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(0), nNumSamples, "The result should have shape(0) = NumSamples which is " + nNumSamples.ToString());
m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(1), m_nActionCount, "The result should have shape(1) = Actions which is " + m_nActionCount.ToString());
m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(2), m_nAtoms, "The result should have shape(2) = Atoms which is " + m_nAtoms.ToString());

// ...
reduce_sum_axis2(m_blobQ);
m_blobQ.Reshape(nNumSamples, m_nActionCount, 1, 1);

// ... (rgQbatch and rgPbatch hold m_blobQ and the target-net distributions as float arrays)
float[] rgMBatch = new float[nNumSamples * m_nAtoms];

for (int i = 0; i < nNumSamples; i++)
{
    int nActionMax = argmax(rgQbatch, m_nActionCount, i);

    if (m_rgSamples[i].IsTerminated)
    {
        // Terminal transition: the target distribution collapses to the reward.
        double dfTz = m_rgSamples[i].Reward;
        dfTz = setBounds(dfTz, m_dfVMin, m_dfVMax);

        double dfB = (dfTz - m_dfVMin) / m_fDeltaZ;
        int nL = (int)Math.Floor(dfB);
        int nU = (int)Math.Ceiling(dfB);
        int nIdx = i * m_nAtoms;

        rgMBatch[nIdx + nL] += (float)(nU - dfB);
        rgMBatch[nIdx + nU] += (float)(dfB - nL);
    }
    else
    {
        for (int j = 0; j < m_nAtoms; j++)
        {
            // Bellman-update each atom of the support, clipped to [VMin, VMax]...
            double dfTz = m_rgSamples[i].Reward + m_fGamma * m_rgfZ[j];
            dfTz = setBounds(dfTz, m_dfVMin, m_dfVMax);

            // ...then split its probability mass between the two nearest atoms.
            double dfB = (dfTz - m_dfVMin) / m_fDeltaZ;
            int nL = (int)Math.Floor(dfB);
            int nU = (int)Math.Ceiling(dfB);
            int nIdx = i * m_nAtoms;
            int nIdxT = (i * m_nActionCount * m_nAtoms) + (nActionMax * m_nAtoms);

            rgMBatch[nIdx + nL] += rgPbatch[nIdxT + j] * (float)(nU - dfB);
            rgMBatch[nIdx + nU] += rgPbatch[nIdxT + j] * (float)(dfB - nL);
        }
    }

    // Normalize so each projected distribution sums to 1.
    float fSum = 0;

    for (int j = 0; j < m_nAtoms; j++)
    {
        fSum += rgMBatch[(i * m_nAtoms) + j];
    }

    if (fSum != 0)
    {
        for (int j = 0; j < m_nAtoms; j++)
        {
            rgMBatch[(i * m_nAtoms) + j] /= fSum;
        }
    }
}

// ...
m_blobMLoss.Reshape(nNumSamples, m_nAtoms, 1, 1);
// ...
m_blobActionBinaryLoss.Reshape(nNumSamples, m_nActionCount, m_nAtoms, 1);
m_blobActionBinaryLoss.SetData(0.0);

for (int i = 0; i < m_rgSamples.Count; i++)
{
    // Mark only the atoms of the action actually taken in each sample.
    int nAction = m_rgSamples[i].Action;
    int nIdx = (i * m_nActionCount * m_nAtoms) + (nAction * m_nAtoms);

    m_blobActionBinaryLoss.SetData(1.0, nIdx, m_nAtoms);
}
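// Worked example of the projection above (assuming 51 atoms, so m_fDeltaZ =
// 0.4 over [-10, 10]): for a non-terminal sample with reward 1.0 and source
// atom z_j = 2.0, dfTz = 1.0 + 0.99 * 2.0 = 2.98 and dfB = (2.98 + 10) / 0.4
// = 32.45, so nL = 32 and nU = 33; atom 32 receives (33 - 32.45) = 55% and
// atom 33 receives (32.45 - 32) = 45% of that atom's target probability.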
private float reduce_mean(Blob<T> b)
{
    // ... (rg holds the blob data as a float array)
    float fSum = rg.Sum(p => p);
    return fSum / rg.Length;
}
private void reduce_sum_axis1(Blob<T> b)
{
    int nNum = b.shape(0);
    int nActions = b.shape(1);
    int nAtoms = b.shape(2);
    float[] rgSum = new float[nNum * nAtoms];
    // ... (rg holds the blob data as a float array)

    for (int i = 0; i < nNum; i++)
    {
        for (int j = 0; j < nAtoms; j++)
        {
            float fSum = 0;

            for (int k = 0; k < nActions; k++)
            {
                int nIdx = (i * nActions * nAtoms) + (k * nAtoms);
                fSum += rg[nIdx + j];
            }

            int nIdxR = i * nAtoms;
            rgSum[nIdxR + j] = fSum;
        }
    }

    b.Reshape(nNum, nAtoms, 1, 1);
    // ... (rgSum written back to the blob)
}
private void reduce_sum_axis2(Blob<T> b)
{
    int nNum = b.shape(0);
    int nActions = b.shape(1);
    int nAtoms = b.shape(2);
    float[] rgSum = new float[nNum * nActions];
    // ... (rg holds the blob data as a float array)

    for (int i = 0; i < nNum; i++)
    {
        for (int j = 0; j < nActions; j++)
        {
            int nIdx = (i * nActions * nAtoms) + (j * nAtoms);
            float fSum = 0;

            for (int k = 0; k < nAtoms; k++)
            {
                fSum += rg[nIdx + k];
            }

            int nIdxR = i * nActions;
            rgSum[nIdxR + j] = fSum;
        }
    }

    b.Reshape(nNum, nActions, 1, 1); // rgSum holds nNum * nActions entries (the listing showed nAtoms here, which mismatches the buffer size)
    // ... (rgSum written back to the blob)
}
// softmaxLoss_forward:
colBottom.Add(actual);
colBottom.Add(target);

if (m_softmaxLoss == null)
{
    // ... (create the softmax cross-entropy loss layer)
    m_softmaxLoss.Setup(colBottom, colTop);
}

return m_softmaxLoss.Forward(colBottom, colTop);

// softmaxLoss_backward:
colBottom.Add(actual);
colBottom.Add(target);
// ...
m_softmaxLoss.Backward(colTop, new List<bool>() { true, false }, colBottom);

// softmax_forward:
colBottom.Add(bBottom);

if (m_softmax == null)
{
    // ... (create the softmax layer)
    m_softmax.Setup(colBottom, colTop);
}

m_softmax.Reshape(colBottom, colTop);
m_softmax.Forward(colBottom, colTop);
private double setBounds(double z, double dfMin, double dfMax)
{
    // (body elided in the listing; per its use above it clips z to [dfMin, dfMax])
    if (z > dfMax)
        return dfMax;

    if (z < dfMin)
        return dfMin;

    return z;
}
private int argmax(float[] rgProb, int nActionCount, int nSampleIdx)
{
    float[] rgfProb = new float[nActionCount];

    for (int j = 0; j < nActionCount; j++)
    {
        int nIdx = (nSampleIdx * nActionCount) + j;
        rgfProb[j] = rgProb[nIdx];
    }

    return argmax(rgfProb);
}
private int argmax(float[] rgfAprob)
{
    float fMax = -float.MaxValue;
    // ...

    for (int i = 0; i < rgfAprob.Length; i++)
    {
        if (rgfAprob[i] == fMax)
        {
            // ... (tie handling elided in the listing)
        }
        else if (fMax < rgfAprob[i])
        {
            // ... (track the new maximum and its index)
        }
    }

    // ...
}
// ...
setData(net, rgData, rgClip);

private void setData0(Net<T> net, MemoryCollection rgSamples)
{
    List<SimpleDatum> rgData0 = rgSamples.GetData0();
    List<SimpleDatum> rgClip0 = rgSamples.GetClip0();
    // ...
    SimpleDatum[] rgClip = (rgClip0 != null) ? rgClip0.ToArray() : null;
    // ...
    setData(net, rgData, rgClip);
}

private void setData1(Net<T> net, MemoryCollection rgSamples)
{
    List<SimpleDatum> rgData1 = rgSamples.GetData1();
    List<SimpleDatum> rgClip1 = rgSamples.GetClip1();
    // ...
    SimpleDatum[] rgClip = (rgClip1 != null) ? rgClip1.ToArray() : null;
    // ...
    setData(net, rgData, rgClip);
}
// setData:
m_transformer.Transform(rgData, data, m_mycaffe.Cuda, m_mycaffe.Log);
// ...
clip.Reshape(rgClip.Length, rgClip[0].Channels, rgClip[0].Height, rgClip[0].Width);
m_transformer.Transform(rgClip, clip, m_mycaffe.Cuda, m_mycaffe.Log, true);
// OnOverlay:
if (logits.num == 1)
{
    Blob<T> actions = softmax_forward(logits, m_blobAction);
    // ...
    List<List<float>> rgData = new List<List<float>>();

    for (int i = 0; i < m_nActionCount; i++)
    {
        List<float> rgProb = new List<float>();

        for (int j = 0; j < m_nAtoms; j++)
        {
            int nIdx = (i * m_nAtoms) + j;
            rgProb.Add(rgActions[nIdx]);
        }

        rgData.Add(rgProb);
    }

    m_rgOverlay = rgData;
}
if (m_rgOverlay == null)
    return;

using (Graphics g = Graphics.FromImage(e.DisplayImage))
{
    // ...
    int nWid1 = nWid / m_rgOverlay.Count;
    // ...
    float[] rgfMin = new float[m_rgOverlay.Count];
    float[] rgfMax = new float[m_rgOverlay.Count];
    float fMax = -float.MaxValue;
    float fMaxMax = -float.MaxValue;

    for (int i = 0; i < m_rgOverlay.Count; i++)
    {
        rgfMin[i] = m_rgOverlay[i].Min(p => p);
        rgfMax[i] = m_rgOverlay[i].Max(p => p);

        if (rgfMax[i] > fMax)
        {
            // ... (remember fMax and nMaxIdx for this overlay)
        }

        fMaxMax = Math.Max(fMax, fMaxMax);
    }

    // ...
    m_bNormalizeOverlay = false;

    for (int i = 0; i < m_rgOverlay.Count; i++)
    {
        drawProbabilities(g, nX, nY, nWid1, nHt1, i, m_rgOverlay[i], clrMap.GetColor(i + 1), rgfMin.Min(p => p), rgfMax.Max(p => p), i == nMaxIdx, m_bNormalizeOverlay);
    }
}
private void drawProbabilities(Graphics g, int nX, int nY, int nWid, int nHt, int nAction, List<float> rgProb, Color clr, float fMin, float fMax, bool bMax, bool bNormalize)
{
    string str;

    if (m_font == null)
        m_font = new Font("Century Gothic", 9.0f);

    if (!m_rgStyle.ContainsKey(clr))
    {
        Color clr1 = Color.FromArgb(128, clr);
        Brush br1 = new SolidBrush(clr1);
        Color clr2 = Color.FromArgb(64, clr);
        Pen pen = new Pen(clr2, 1.0f);
        Brush br2 = new SolidBrush(clr2);
        Brush brBright = new SolidBrush(clr);
        m_rgStyle.Add(clr, new Tuple<Brush, Brush, Pen, Brush>(br1, br2, pen, brBright));
    }

    Brush brBack = m_rgStyle[clr].Item1;
    Brush brFront = m_rgStyle[clr].Item2;
    Brush brTop = m_rgStyle[clr].Item4;
    Pen penLine = m_rgStyle[clr].Item3;

    if (fMin != 0 || fMax != 0)
        str = "Action " + nAction.ToString() + " (" + (fMax - fMin).ToString("N7") + ")";
    else
        str = "Action " + nAction.ToString() + " - No Probabilities";

    SizeF sz = g.MeasureString(str, m_font);
    int nY1 = (int)(nY + (nHt - sz.Height));
    int nX1 = (int)(nX + (nWid / 2) - (sz.Width / 2));
    g.DrawString(str, m_font, (bMax) ? brTop : brFront, new Point(nX1, nY1));

    if (fMin != 0 || fMax != 0)
    {
        float fX = nX;
        float fWid = nWid / (float)rgProb.Count;
        nHt -= (int)sz.Height;

        for (int i = 0; i < rgProb.Count; i++)
        {
            float fProb = rgProb[i];

            if (bNormalize)
                fProb = (fProb - fMin) / (fMax - fMin);

            float fHt = nHt * fProb;
            float fHt1 = nHt - fHt;
            RectangleF rc1 = new RectangleF(fX, nY + fHt1, fWid, fHt);
            g.FillRectangle(brBack, rc1);
            g.DrawRectangle(penLine, rc1.X, rc1.Y, rc1.Width, rc1.Height);
            fX += fWid;
        }
    }
}
class MemoryEpisodeCollection
{
    int m_nMax;
    int m_nTotalCount = 0;
    List<MemoryCollection> m_rgItems = new List<MemoryCollection>();

    public MemoryEpisodeCollection(int nMax)
    {
        m_nMax = nMax;
    }

    public int Count
    {
        get { return m_rgItems.Count; }
    }

    public void Add(MemoryItem item)
    {
        m_nTotalCount++;

        if (m_rgItems.Count == 0 || m_rgItems[m_rgItems.Count - 1].Episode != item.Episode)
        {
            MemoryCollection col = new MemoryCollection(int.MaxValue);
            col.Add(item);
            m_rgItems.Add(col);
        }
        else
        {
            m_rgItems[m_rgItems.Count - 1].Add(item);
        }

        if (m_nTotalCount > m_nMax)
        {
            // Over capacity: drop the episode with the lowest total reward.
            List<MemoryCollection> rgItems = m_rgItems.OrderBy(p => p.TotalReward).ToList();
            m_nTotalCount -= rgItems[0].Count;
            m_rgItems.Remove(rgItems[0]);
        }
    }

    public MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
    {
        MemoryCollection col = new MemoryCollection(nCount);
        List<string> rgItems = new List<string>();

        for (int i = 0; i < nCount; i++)
        {
            int nEpisode = random.Next(m_rgItems.Count);
            int nItem = random.Next(m_rgItems[nEpisode].Count);
            string strItem = nEpisode.ToString() + "_" + nItem.ToString();

            if (!rgItems.Contains(strItem))
            {
                col.Add(m_rgItems[nEpisode][nItem]);
                rgItems.Add(strItem);
            }
        }

        return col;
    }
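    // Note: duplicate draws are skipped rather than retried, so the returned
    // collection can hold slightly fewer than nCount items when collisions
    // occur; the caller samples with nCount = BatchSize and tolerates this.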
    List<StateBase> GetState1()
    {
        List<StateBase> rgItems = new List<StateBase>();

        for (int i = 0; i < m_rgItems.Count; i++)
        {
            for (int j = 0; j < m_rgItems[i].Count; j++)
            {
                rgItems.Add(m_rgItems[i][j].State1);
            }
        }

        return rgItems;
    }

    List<SimpleDatum> GetItem(ITEM item)
    {
        List<SimpleDatum> rgItems = new List<SimpleDatum>();

        for (int i = 0; i < m_rgItems.Count; i++)
        {
            // Dispatch on the ITEM value (the case labels are elided in the listing):
            // ...
            rgItems.AddRange(m_rgItems[i].GetData0());
            // ...
            rgItems.AddRange(m_rgItems[i].GetData1());
            // ...
            rgItems.AddRange(m_rgItems[i].GetClip0());
            // ...
            rgItems.AddRange(m_rgItems[i].GetClip1());
        }

        return rgItems;
    }
}
class MemoryCollection : IEnumerable<MemoryItem>
{
    int m_nMax;
    int m_nEpisode;
    double m_dfTotalReward = 0;
    List<MemoryItem> m_rgItems = new List<MemoryItem>();

    public MemoryCollection(int nMax)
    {
        m_nMax = nMax;
    }

    public int Count
    {
        get { return m_rgItems.Count; }
    }

    public MemoryItem this[int nIdx]
    {
        get { return m_rgItems[nIdx]; }
    }

    public void Add(MemoryItem item)
    {
        m_nEpisode = item.Episode;
        m_dfTotalReward += item.Reward;
        m_rgItems.Add(item);

        if (m_rgItems.Count > m_nMax)
            m_rgItems.RemoveAt(0);
    }
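    // Unlike MemoryEpisodeCollection above, which evicts the whole episode
    // with the lowest total reward, this collection evicts FIFO: the oldest
    // transition is dropped once the capacity m_nMax is exceeded.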
    // ... (reset method, name elided in the listing)
    m_dfTotalReward = 0;
    // ...

    public int Episode
    {
        get { return m_nEpisode; }
    }

    public double TotalReward
    {
        get { return m_dfTotalReward; }
    }
    public MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
    {
        MemoryCollection col = new MemoryCollection(m_nMax);
        List<int> rgIdx = new List<int>();

        while (col.Count < nCount)
        {
            int nIdx = random.Next(m_rgItems.Count);

            if (!rgIdx.Contains(nIdx))
            {
                col.Add(m_rgItems[nIdx]);
                rgIdx.Add(nIdx);
            }
        }

        return col;
    }

    public List<StateBase> GetState1()
    {
        return m_rgItems.Select(p => p.State1).ToList();
    }

    public List<SimpleDatum> GetData1()
    {
        return m_rgItems.Select(p => p.Data1).ToList();
    }

    public List<SimpleDatum> GetClip1()
    {
        if (m_rgItems[0].State1.Clip != null)
            return m_rgItems.Select(p => p.State1.Clip).ToList();

        return null;
    }

    public List<SimpleDatum> GetData0()
    {
        return m_rgItems.Select(p => p.Data0).ToList();
    }

    public List<SimpleDatum> GetClip0()
    {
        if (m_rgItems[0].State0.Clip != null)
            return m_rgItems.Select(p => p.State0.Clip).ToList();

        return null;
    }

    public IEnumerator<MemoryItem> GetEnumerator()
    {
        return m_rgItems.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return m_rgItems.GetEnumerator();
    }

    public override string ToString()
    {
        return "Episode #" + m_nEpisode.ToString() + " (" + m_rgItems.Count.ToString() + ") => " + m_dfTotalReward.ToString();
    }
}
public MemoryItem(StateBase s, SimpleDatum x, int nAction, StateBase s_, SimpleDatum x_, double dfReward, bool bTerminated, int nIteration, int nEpisode)
{
    m_state0 = s;
    m_x0 = x;
    m_state1 = s_;
    m_x1 = x_;
    m_nAction = nAction;
    m_bTerminated = bTerminated;
    m_dfReward = dfReward;
    m_nIteration = nIteration;
    m_nEpisode = nEpisode;
}

public bool IsTerminated
{
    get { return m_bTerminated; }
}

public double Reward
{
    get { return m_dfReward; }
    set { m_dfReward = value; }
}

public StateBase State0
{
    get { return m_state0; }
}

public StateBase State1
{
    get { return m_state1; }
}

public SimpleDatum Data0
{
    get { return m_x0; }
}

public SimpleDatum Data1
{
    get { return m_x1; }
}

public int Action
{
    get { return m_nAction; }
}

public int Iteration
{
    get { return m_nIteration; }
}

public int Episode
{
    get { return m_nEpisode; }
}

public override string ToString()
{
    return "episode = " + m_nEpisode.ToString() + " action = " + m_nAction.ToString() + " reward = " + m_dfReward.ToString("N2");
}

private string tostring(float[] rg)
{
    string str = "";

    for (int i = 0; i < rg.Length; i++)
    {
        str += rg[i].ToString("N5") + ",";
    }

    return str.TrimEnd(',');
}