12using System.Collections.Generic;
18using System.Threading.Tasks;
52 m_icallback = icallback;
54 m_properties = properties;
83 if (m_mycaffe !=
null)
89 m_icallback.OnShutdown();
94 private void wait(
int nWait)
99 while (nTotalWait < nWait)
101 m_icallback.OnWait(
new WaitArgs(nWaitInc));
102 nTotalWait += nWaitInc;
131 byte[] rgResults = agent.
Run(nN, out type);
146 string strProp = m_properties.
ToString();
149 strProp +=
"EnableNumSkip=False;";
173 agent.
Run(
Phase.TRAIN, nN, type, step);
191 float m_fGamma = 0.99f;
192 bool m_bUseRawInput =
true;
193 double m_dfBetaStart = 0.4;
194 int m_nBetaFrames = 1000;
195 int m_nMemorySize = 10000;
196 float m_fPriorityAlpha = 0.6f;
197 int m_nUpdateTargetFreq = 1000;
211 m_icallback = icallback;
212 m_brain =
new Brain<T>(mycaffe, properties, random, phase);
213 m_properties = properties;
/// <summary>
/// Compute the beta value used when sampling from the replay memory, annealed
/// linearly from m_dfBetaStart toward a maximum of 1.0 over m_nBetaFrames frames.
/// </summary>
/// <param name="nFrameIdx">Specifies the current frame (iteration) index.</param>
/// <returns>The annealed beta value, capped at 1.0, is returned.</returns>
private double beta_by_frame(int nFrameIdx)
{
    double dfAnnealed = m_dfBetaStart + nFrameIdx * (1.0 - m_dfBetaStart) / m_nBetaFrames;
    return Math.Min(1.0, dfAnnealed);
}
240 m_icallback.OnGetData(args);
/// <summary>
/// Report the current training status back to the parent via the OnUpdateStatus callback.
/// </summary>
/// <param name="nIteration">Specifies the current iteration.</param>
/// <param name="nEpisodeCount">Specifies the current episode count.</param>
/// <param name="dfRewardSum">Specifies the reward accumulated over the episode.</param>
/// <param name="dfRunningReward">Specifies the running (smoothed) reward.</param>
/// <param name="dfLoss">Specifies the current loss value.</param>
/// <param name="dfLearningRate">Specifies the current learning rate.</param>
/// <param name="bModelUpdated">Specifies whether the target model was updated.</param>
private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward, double dfLoss, double dfLearningRate, bool bModelUpdated)
{
    m_icallback.OnUpdateStatus(new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, 0, 0, dfLoss, dfLearningRate, bModelUpdated));
}
257 public byte[]
Run(
int nIterations, out
string type)
260 if (icallback ==
null)
261 throw new Exception(
"The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
265 List<float> rgResults =
new List<float>();
268 while (!m_brain.
Cancel.
WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
278 rgResults.Add(action);
283 s = getData(
Phase.RUN, action, nIteration);
293 private bool isAtIteration(
int nN,
ITERATOR_TYPE type,
int nIteration,
int nEpisode)
329 double dfRunningReward = 0;
330 double dfEpisodeReward = 0;
332 bool bDifferent =
false;
334 StateBase state = getData(phase, -1, -1);
341 while (!m_brain.
Cancel.
WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
347 StateBase state_next = getData(phase, action, nIteration);
352 m_brain.
Log.
WriteLine(
"WARNING: The current state is the same as the previous state!");
355 iMemory.
Add(
new MemoryItem(state, x, action, state_next, x_next, state_next.
Reward, state_next.
Done, nIteration, nEpisode));
356 dfEpisodeReward += state_next.
Reward;
361 double dfBeta = beta_by_frame(nIteration);
364 iMemory.
Update(rgSamples);
366 if (nIteration % m_nUpdateTargetFreq == 0)
373 dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;
376 updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward, 0, 0, m_brain.
GetModelUpdated());
378 state = getData(phase, -1, -1);
379 x = m_brain.
Preprocess(state, m_bUseRawInput, out bDifferent,
true);
412 Blob<T> m_blobNextQValue =
null;
413 Blob<T> m_blobExpectedQValue =
null;
418 bool m_bUseAcceleratedTraining =
false;
419 double m_dfLearningRate;
420 int m_nMiniBatch = 1;
421 float m_fGamma = 0.99f;
422 int m_nBatchSize = 32;
424 int m_nActionCount = 2;
425 bool m_bModelUpdated =
false;
427 Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>> m_rgStyle =
new Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>>();
428 List<SimpleDatum> m_rgX =
new List<SimpleDatum>();
429 float[] m_rgOverlay =
null;
444 m_netTarget =
new Net<T>(m_mycaffe.Cuda, m_mycaffe.Log, m_netOutput.
net_param, m_mycaffe.CancelEvent,
null, phase);
445 m_properties = properties;
450 m_mycaffe.Log.FAIL(
"Missing the expected input 'data' blob!");
452 m_nBatchSize = data.
num;
456 m_mycaffe.Log.FAIL(
"Missing the expected input 'logits' blob!");
461 if (m_transformer ==
null)
464 int nC = m_mycaffe.CurrentProject.Dataset.TrainingSource.Channels;
465 int nH = m_mycaffe.CurrentProject.Dataset.TrainingSource.Height;
466 int nW = m_mycaffe.CurrentProject.Dataset.TrainingSource.Width;
467 m_transformer =
new DataTransformer<T>(m_mycaffe.Cuda, m_mycaffe.Log, trans_param, phase, nC, nH, nW);
469 m_blobActions =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log,
false);
470 m_blobQValue =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
471 m_blobNextQValue =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
472 m_blobExpectedQValue =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
473 m_blobDone =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log,
false);
474 m_blobLoss =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
475 m_blobWeights =
new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log,
false);
480 if (m_memLoss ==
null)
481 m_mycaffe.Log.FAIL(
"Missing the expected MEMORY_LOSS layer!");
485 m_dfLearningRate = dfRate.Value;
488 m_bUseAcceleratedTraining = properties.
GetPropertyAsBool(
"UseAcceleratedTraining",
false);
490 if (m_nMiniBatch > 1)
493 m_colAccumulatedGradients.
SetDiff(0);
497 private void dispose(ref
Blob<T> b)
511 dispose(ref m_blobActions);
512 dispose(ref m_blobQValue);
513 dispose(ref m_blobNextQValue);
514 dispose(ref m_blobExpectedQValue);
515 dispose(ref m_blobDone);
516 dispose(ref m_blobLoss);
517 dispose(ref m_blobWeights);
519 if (m_colAccumulatedGradients !=
null)
521 m_colAccumulatedGradients.
Dispose();
522 m_colAccumulatedGradients =
null;
525 if (m_netTarget !=
null)
537 foreach (KeyValuePair<Color, Tuple<Brush, Brush, Pen, Brush>> kv
in m_rgStyle)
539 kv.Value.Item1.Dispose();
540 kv.Value.Item2.Dispose();
541 kv.Value.Item3.Dispose();
542 kv.Value.Item4.Dispose();
556 bool bReset = (nAction == -1) ?
true :
false;
557 return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction,
true,
false,
false,
this);
565 get {
return m_nBatchSize; }
573 get {
return m_mycaffe.
Log; }
603 if (m_sdLast ==
null)
606 bDifferent = sd.
Sub(m_sdLast);
630 setData(m_netOutput, sd, sdClip);
635 throw new Exception(
"Missing expected 'logits' blob!");
647 bool bModelUpdated = m_bModelUpdated;
648 m_bModelUpdated =
false;
649 return bModelUpdated;
657 m_mycaffe.Log.Enable =
false;
660 m_mycaffe.Log.Enable =
true;
661 m_bModelUpdated =
true;
672 m_rgSamples = rgSamples;
674 if (m_nActionCount != nActionCount)
675 throw new Exception(
"The logit output of '" + m_nActionCount.ToString() +
"' does not match the action count of '" + nActionCount.ToString() +
"'!");
678 m_mycaffe.Log.Enable =
false;
679 setNextStateData(m_netTarget, rgSamples);
682 setCurrentStateData(m_netOutput, rgSamples);
683 m_memLoss.
OnGetLoss += m_memLoss_ComputeTdLoss;
685 if (m_nMiniBatch == 1)
691 m_solver.
Step(1,
TRAIN_STEP.NONE,
true, m_bUseAcceleratedTraining,
true,
true);
694 if (nIteration % m_nMiniBatch == 0)
697 m_colAccumulatedGradients.
SetDiff(0);
698 m_dfLearningRate = m_solver.
ApplyUpdate(nIteration);
703 m_memLoss.
OnGetLoss -= m_memLoss_ComputeTdLoss;
704 m_mycaffe.Log.Enable =
true;
706 resetNoise(m_netOutput);
707 resetNoise(m_netTarget);
729 reduce_sum_axis1(m_blobQValue);
732 m_blobNextQValue.
CopyFrom(next_q_values,
false,
true);
733 reduce_argmax_axis1(m_blobNextQValue);
745 m_mycaffe.Cuda.mul_scalar(m_blobExpectedQValue.
count(), m_fGamma, m_blobExpectedQValue.
mutable_gpu_diff);
746 m_mycaffe.Cuda.add(m_blobExpectedQValue.
count(), m_blobExpectedQValue.
gpu_diff, m_blobExpectedQValue.
gpu_data, m_blobExpectedQValue.
gpu_data);
763 for (
int i = 0; i < rgPrios.Length; i++)
775 double dfGradient = 1.0;
780 dfGradient /= m_blobLoss.
count();
781 m_blobLoss.
SetDiff(dfGradient);
793 mul(m_blobLoss, m_blobActions, e.
Bottom[0]);
795 e.
Loss = reduce_mean(m_blobLoss,
false);
799 private void resetNoise(
Net<T> net)
815 float[] rgResult =
new float[rgActions.Length];
817 for (
int i = 0; i < actions.
num; i++)
819 float fPred = rgVal[i];
821 for (
int j = 0; j < actions.
channels; j++)
823 int nIdx = (i * actions.
channels) + j;
824 rgResult[nIdx] = rgActions[nIdx] * fPred;
831 private float reduce_mean(
Blob<T> b,
bool bDiff)
834 float fSum = rg.Sum(p => p);
835 return fSum / rg.Length;
838 private void reduce_sum_axis1(
Blob<T> b)
840 int nNum = b.
shape(0);
841 int nActions = b.
shape(1);
842 int nInnerCount = b.
count(2);
844 float[] rgSum =
new float[nNum * nInnerCount];
846 for (
int i = 0; i < nNum; i++)
848 for (
int j = 0; j < nInnerCount; j++)
852 for (
int k = 0; k < nActions; k++)
854 int nIdx = (i * nActions * nInnerCount) + (k * nInnerCount);
855 fSum += rg[nIdx + j];
858 int nIdxR = i * nInnerCount;
859 rgSum[nIdxR + j] = fSum;
863 b.
Reshape(nNum, nInnerCount, 1, 1);
867 private void reduce_argmax_axis1(
Blob<T> b)
869 int nNum = b.
shape(0);
870 int nActions = b.
shape(1);
871 int nInnerCount = b.
count(2);
873 float[] rgMax =
new float[nNum * nInnerCount];
875 for (
int i = 0; i < nNum; i++)
877 for (
int j = 0; j < nInnerCount; j++)
879 float fMax = -
float.MaxValue;
881 for (
int k = 0; k < nActions; k++)
883 int nIdx = (i * nActions * nInnerCount) + (k * nInnerCount);
884 fMax = Math.Max(fMax, rg[nIdx + j]);
887 int nIdxR = i * nInnerCount;
888 rgMax[nIdxR + j] = fMax;
892 b.
Reshape(nNum, nInnerCount, 1, 1);
/// <summary>
/// Extract the probability slice belonging to a single sample from a flattened
/// batch of probabilities and return the index of its maximum value.
/// </summary>
/// <param name="rgProb">Specifies the flattened probabilities for the whole batch.</param>
/// <param name="nActionCount">Specifies the number of actions per sample.</param>
/// <param name="nSampleIdx">Specifies the index of the sample within the batch.</param>
/// <returns>The index of the maximum probability within the sample's slice is returned.</returns>
private int argmax(float[] rgProb, int nActionCount, int nSampleIdx)
{
    float[] rgfSlice = new float[nActionCount];
    Array.Copy(rgProb, nSampleIdx * nActionCount, rgfSlice, 0, nActionCount);
    return argmax(rgfSlice);
}
909 private int argmax(
float[] rgfAprob)
911 double fMax = -
float.MaxValue;
914 for (
int i = 0; i < rgfAprob.Length; i++)
916 if (rgfAprob[i] == fMax)
921 else if (fMax < rgfAprob[i])
939 setData(net, rgData, rgClip);
948 SimpleDatum[] rgClip = (rgClip0 !=
null) ? rgClip0.ToArray() :
null;
950 setData(net, rgData, rgClip);
959 SimpleDatum[] rgClip = (rgClip1 !=
null) ? rgClip1.ToArray() :
null;
961 setData(net, rgData, rgClip);
969 m_transformer.
Transform(rgData, data, m_mycaffe.Cuda, m_mycaffe.Log);
977 clip.
Reshape(rgClip.Length, rgClip[0].
Channels, rgClip[0].Height, rgClip[0].Width);
978 m_transformer.
Transform(rgClip, clip, m_mycaffe.Cuda, m_mycaffe.Log,
true);
996 if (m_rgOverlay ==
null)
1003 int nWid1 = nWid / m_rgOverlay.Length;
1008 float fMax = -
float.MaxValue;
1010 float fMin1 = m_rgOverlay.Min(p => p);
1011 float fMax1 = m_rgOverlay.Max(p => p);
1013 for (
int i=0; i<m_rgOverlay.Length; i++)
1015 if (fMin1 < 0 || fMax1 > 1)
1016 m_rgOverlay[i] = (m_rgOverlay[i] - fMin1) / (fMax1 - fMin1);
1018 if (m_rgOverlay[i] > fMax)
1020 fMax = m_rgOverlay[i];
1025 for (
int i = 0; i < m_rgOverlay.Length; i++)
1027 drawProbabilities(g, nX, nY, nWid1, nHt1, i, m_rgOverlay[i], fMin1, fMax1, clrMap.
GetColor(i + 1), (i == nMaxIdx) ?
true :
false);
1033 private void drawProbabilities(Graphics g,
int nX,
int nY,
int nWid,
int nHt,
int nAction,
float fProb,
float fMin,
float fMax, Color clr,
bool bMax)
1038 m_font =
new Font(
"Century Gothic", 9.0f);
1040 if (!m_rgStyle.ContainsKey(clr))
1042 Color clr1 = Color.FromArgb(128, clr);
1043 Brush br1 =
new SolidBrush(clr1);
1044 Color clr2 = Color.FromArgb(64, clr);
1045 Pen pen =
new Pen(clr2, 1.0f);
1046 Brush br2 =
new SolidBrush(clr2);
1047 Brush brBright =
new SolidBrush(clr);
1048 m_rgStyle.Add(clr,
new Tuple<Brush, Brush, Pen, Brush>(br1, br2, pen, brBright));
1051 Brush brBack = m_rgStyle[clr].Item1;
1052 Brush brFront = m_rgStyle[clr].Item2;
1053 Brush brTop = m_rgStyle[clr].Item4;
1054 Pen penLine = m_rgStyle[clr].Item3;
1056 if (fMin != 0 || fMax != 0)
1058 str =
"Action " + nAction.ToString() +
" (" + fProb.ToString(
"N7") +
")";
1062 str =
"Action " + nAction.ToString() +
" - No Probabilities";
1065 SizeF sz = g.MeasureString(str, m_font);
1067 int nY1 = (int)(nY + (nHt - sz.Height));
1068 int nX1 = (int)(nX + (nWid / 2) - (sz.Width / 2));
1069 g.DrawString(str, m_font, (bMax) ? brTop : brFront,
new Point(nX1, nY1));
1071 if (fMin != 0 || fMax != 0)
1075 nHt -= (int)sz.Height;
1077 float fHt = nHt * fProb;
1078 float fHt1 = nHt - fHt;
1079 RectangleF rc1 =
new RectangleF(fX, nY + fHt1, fWid, fHt);
1080 g.FillRectangle(brBack, rc1);
1081 g.DrawRectangle(penLine, rc1.X, rc1.Y, rc1.Width, rc1.Height);
1091 if (File.Exists(strFile))
1092 File.Delete(strFile);
1094 using (StreamWriter sw =
new StreamWriter(strFile))
1107 for (
int i = 0; i < layer.
blobs.Count; i++)
1109 float[] rgf =
Utility.ConvertVecF<T>(layer.
blobs[i].mutable_cpu_data);
1110 string strLine =
"";
1112 for (
int j = 0; j < rgf.Length; j++)
1114 strLine += rgf[j].ToString() +
",";
1117 sw.WriteLine(strLine.TrimEnd(
','));
1123 string strLine =
"";
1125 for (
int j = 0; j < rgf.Length; j++)
1127 strLine += rgf[j].ToString() +
",";
1130 sw.WriteLine(strLine.TrimEnd(
','));
1140 if (!File.Exists(strFile))
1143 using (StreamReader sr =
new StreamReader(strFile))
1156 for (
int i = 0; i < layer.
blobs.Count; i++)
1158 List<float> rgf =
new List<float>();
1159 string strLine = sr.ReadLine();
1160 string[] rgstr = strLine.Split(
',');
1162 for (
int j = 0; j < rgstr.Length; j++)
1172 List<float> rgf =
new List<float>();
1173 string strLine = sr.ReadLine();
1174 string[] rgstr = strLine.Split(
',');
1176 for (
int j = 0; j < rgstr.Length; j++)
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
Solver< T > GetInternalSolver()
Get the internal solver.
ProjectEx CurrentProject
Returns the name of the currently loaded project.
The BaseParameter class is the base class for all other parameter classes.
static float ParseFloat(string strVal)
Parse float values using the US culture if the decimal separator = '.', then using the native culture...
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
void Reset()
Resets the event clearing any signaled state.
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
CancelEvent()
The CancelEvent constructor.
void Set()
Sets the event to the signaled state.
The ColorMapper maps a value within a number range, to a Color within a color scheme.
Color GetColor(double dfVal)
Find the color using a binary search algorithm.
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
double NextDouble()
Returns a random double within the range .
The Log class provides general output in text form.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Log(string strSrc)
The Log constructor.
double? GetSolverSettingAsNumeric(string strParam)
Get a setting from the solver descriptor as a double value.
Specifies a key-value pair of properties.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Returns a property as a double value.
Returns a property as an double value.
override string ToString()
Returns the string representation of the properties.
The SimpleDatum class holds a data input within host memory.
float GetDataAtF(int nIdx)
Returns the item at a specified index in the float type.
bool Sub(SimpleDatum sd, bool bSetNegativeToZero=false)
Subtract the data of another SimpleDatum from this one, so this = this - sd.
void Zero()
Zero out all data in the datum but keep the size and other settings.
DateTime TimeStamp
Get/set the Timestamp.
object Tag
Specifies user data associated with the SimpleDatum.
int Channels
Return the number of channels of the data.
int Index
Returns the index of the SimpleDatum.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The BlobCollection contains a list of Blobs.
void Dispose()
Release all resource used by the collection and its Blobs.
void Accumulate(CudaDnn< T > cuda, BlobCollection< T > src, bool bAccumulateDiff)
Accumulate the diffs from one BlobCollection into another.
void SetDiff(double df)
Set all blob diff to the value specified.
The Blob is the main holder of data that moves through the Layers of the Net.
int channels
DEPRECATED; legacy shape accessor channels: use shape(1) instead.
int height
DEPRECATED; legacy shape accessor height: use shape(2) instead.
long mutable_gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
T[] mutable_cpu_diff
Get diff from the GPU and bring it over to the host, or Set diff from the Host and send it over to th...
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
int width
DEPRECATED; legacy shape accessor width: use shape(3) instead.
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
int count()
Returns the total number of items in the Blob.
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
long gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
void SetDiff(double dfVal, int nIdx=-1)
Either sets all of the diff items in the Blob to a given value, or alternatively only sets a single i...
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
List< Layer< T > > layers
Returns the layers.
double ForwardFromTo(int nStart=0, int nEnd=int.MaxValue)
The FromTo variant of forward and backward operates on the (topological) ordering by which the net is ...
void CopyInternalBlobsTo(Net< T > dstNet)
Copy the internal blobs from one net to another.
void CopyTrainedLayersTo(Net< T > dstNet)
Copies the trained layer of this Net to another Net.
Layer< T > FindLastLayer(LayerParameter.LayerType type)
Find the last layer with the matching type.
Layer< T > layer_by_name(string strLayer, bool bThrowExceptionOnError=true)
Returns a Layer given its name.
virtual void Dispose(bool bDisposing)
Releases all resources (GPU and Host) used by the Net.
void ClearParamDiffs()
Zero out the diffs of all net parameters. This should be run before Backward.
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
NetParameter net_param
Returns the net parameter.
Blob< T > blob_by_name(string strName, bool bThrowExceptionOnError=true)
Returns a blob given its name.
The ResultCollection contains the result of a given CaffeControl::Run.
The InnerProductLayer, also known as a 'fully-connected' layer, computes the inner product with a set ...
An interface for the units of computation which can be composed into a Net.
LayerParameter.LayerType type
Returns the LayerType of this Layer.
LayerParameter layer_param
Returns the LayerParameter for this Layer.
BlobCollection< T > blobs
Returns the collection of learnable parameter Blobs for the Layer.
BlobCollection< T > internal_blobs
Returns the collection of internal Blobs used by the Layer.
The MemoryLossLayerGetLossArgs class is passed to the OnGetLoss event.
bool EnableLossUpdate
Get/set enabling the loss update within the backpropagation pass.
double Loss
Get/set the externally calculated total loss.
BlobCollection< T > Bottom
Specifies the bottom passed in during the forward pass.
The MemoryLossLayer provides a method of performing a custom loss functionality. Similar to the Memor...
EventHandler< MemoryLossLayerGetLossArgs< T > > OnGetLoss
The OnGetLoss event fires during each forward pass. The value returned is saved, and applied on the b...
bool enable_noise
Enable/disable noise in the inner-product layer (default = false).
Specifies the base parameter for all layers.
List< double > loss_weight
Specifies the loss weight.
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
The ConvertOutputArgs is passed to the OnConvertOutput event.
byte[] RawOutput
Specifies the raw output byte stream.
string RawType
Specifies the type of the raw output byte stream.
The GetDataArgs is passed to the OnGetData event to retrieve data.
StateBase State
Specifies the state data of the observations.
The GetStatusArgs is passed to the OnGetStatus event.
The InitializeArgs is passed to the OnInitialize event.
The OverlayArgs is passed to the OnOverlay event, optionally fired just before displaying a gym image...
Bitmap DisplayImage
Get/set the display image.
The StateBase is the base class for the state of each observation - this is defined by actual trainer...
bool Done
Get/set whether the state is done or not.
double Reward
Get/set the reward of the state.
SimpleDatum Data
Returns other data associated with the state.
int ActionCount
Returns the number of actions.
SimpleDatum Clip
Returns the clip data associated with the state.
The WaitArgs is passed to the OnWait event.
The MemoryCollectionFactory is used to create various memory collection types.
static IMemoryCollection CreateMemory(MEMTYPE type, int nMax, float fAlpha=0, string strFile=null)
CreateMemory creates the memory collection type based on the MEMTYPE parameter.
The memory collection stores a set of memory items.
float[] GetInvertedDoneAsOneHotVector()
Returns the inverted done (1 - done) values as a one-hot vector.
List< SimpleDatum > GetNextStateClip()
Returns the list of clip items associated with the next state.
double[] Priorities
Get/set the priorities associated with the collection (if any).
List< SimpleDatum > GetCurrentStateData()
Returns the list of data items associated with the current state.
float[] GetActionsAsOneHotVector(int nActionCount)
Returns the action items as a set of one-hot vectors.
List< SimpleDatum > GetCurrentStateClip()
Returns the list of clip items associated with the current state.
float[] GetRewards()
Returns the rewards as a vector.
List< SimpleDatum > GetNextStateData()
Returns the list of data items associated with the next state.
The MemoryItem stores the information about a given cycle.
The Brain uses the instance of MyCaffe (e.g. the open project) to run new actions and train the netwo...
Brain(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
The constructor.
void UpdateTargetModel()
The UpdateTargetModel transfers the trained layers from the active Net to the target Net.
int BatchSize
Returns the batch size defined by the model.
void SaveWeights(string strFile)
Save the weight and bias values to file.
Log Log
Returns the output log.
CancelEvent Cancel
Returns the Cancel event used to cancel all MyCaffe tasks.
void OnOverlay(OverlayArgs e)
The OnOverlay callback is called just before displaying the gym image, thus allowing for an overlay t...
void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
Train the model at the current iteration.
GetDataArgs getDataArgs(Phase phase, int nAction)
Returns the GetDataArgs used to retrieve new data from the environment implemented by derived parent ...
bool GetModelUpdated()
Get whether or not the model has been updated.
int act(SimpleDatum sd, SimpleDatum sdClip, int nActionCount)
Returns the action from running the model. The action returned is either randomly selected (when usin...
SimpleDatum Preprocess(StateBase s, bool bUseRawInput, out bool bDifferent, bool bReset=false)
Preprocesses the data.
void Dispose()
Release all resources used by the Brain.
void LoadWeights(string strFile)
Load the weight and bias values from file.
The DqnAgent both builds episodes from the environment and trains on them using the Brain.
void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
The Run method provides the main loop that performs the following steps: 1.) get state 2....
void Dispose()
Release all resources used.
DqnAgent(IxTrainerCallback icallback, MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
The constructor.
byte[] Run(int nIterations, out string type)
Run the action on a set number of iterations and return the results with no training.
The TrainerNoisyDqn implements the Noisy-DQN algorithm as described by Google Dopamine DQNAgent,...
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the environment after the delay.
TrainerNoisyDqn(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
The constructor.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
void Dispose()
Release all resources used.
bool Shutdown(int nWait)
Shutdown the trainer.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a set of iterations and return the results.
bool Initialize()
Initialize the trainer.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network's ou...
The IxTrainerGetDataCallback interface is called right after rendering the output image and just befo...
The IxTrainerRL interface is implemented by each RL Trainer.
The IMemoryCollection interface is implemented by all memory collection types.
void Update(MemoryCollection rgSamples)
Updates the memory collection - currently only used by the Prioritized memory collection to update it...
int Count
Returns the number of items in the memory collection.
void CleanUp()
Performs final clean-up tasks.
void Add(MemoryItem m)
Add a new item to the memory collection.
MemoryCollection GetSamples(CryptoRandom random, int nCount, double dfBeta)
Retrieve a set of samples from the collection.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.common namespace contains common MyCaffe classes.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.data namespace contains dataset creators used to create common testing datasets such as M...
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
MEMTYPE
Specifies the type of memory collection to use.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...