using System.Collections.Generic;
using System.Threading.Tasks;
m_icallback = icallback;
m_properties = properties;

private void wait(int nWait)
while (nTotalWait < nWait)
{
    m_icallback.OnWait(new WaitArgs(nWaitInc));
    nTotalWait += nWaitInc;
}

if (m_mycaffe != null)
    m_icallback.OnShutdown();
Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);

Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN);
byte[] rgResults = agent.Run(nN, out type);

string strProp = m_properties.ToString();
strProp += "EnableNumSkip=False;";

Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, properties, m_random, Phase.TRAIN);

Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
agent.Run(Phase.TRAIN, nN, type, step);
class Agent<T> : IDisposable

bool m_bAllowDiscountReset = false;
bool m_bUseRawInput = false;

m_icallback = icallback;
m_brain = new Brain<T>(mycaffe, properties, random, phase);
m_properties = properties;
m_bAllowDiscountReset = properties.GetPropertyAsBool("AllowDiscountReset", false);

public void Dispose()
private StateBase getData(Phase phase, int nAction)
GetDataArgs args = m_brain.getDataArgs(phase, nAction);
m_icallback.OnGetData(args);
return args.State;

private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward)
GetStatusArgs args = new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, 0, 0, 0, 0);
m_icallback.OnUpdateStatus(args);
public byte[] Run(int nIterations, out string type)
IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
if (icallback == null)
    throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");

StateBase s = getData(Phase.RUN, -1);
List<float> rgResults = new List<float>();

while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))

SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
int action = m_brain.act(x, s.Clip, out rgfAprob);

rgResults.Add(s.Data.TimeStamp.ToFileTime());
rgResults.Add((float)s.Data.GetDataAtF(0));
rgResults.Add(action);

StateBase s_ = getData(Phase.RUN, action);

ConvertOutputArgs args = new ConvertOutputArgs(nIterations, rgResults.ToArray());
icallback.OnConvertOutput(args);
return args.RawOutput;
private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)

MemoryCollection m_rgMemory = new MemoryCollection();
double? dfRunningReward = null;
double dfEpisodeReward = 0;

StateBase s = getData(phase, -1);
while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))

SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
int action = m_brain.act(x, s.Clip, out rgfAprob);

StateBase s_ = getData(phase, action);
dfEpisodeReward += s_.Reward;

if (phase == Phase.TRAIN)
    m_rgMemory.Add(new MemoryItem(s, x, action, rgfAprob, (float)s_.Reward));

m_brain.Reshape(m_rgMemory);

float[] rgDiscountedR = m_rgMemory.GetDiscountedRewards(m_fGamma, m_bAllowDiscountReset);
m_brain.SetDiscountedR(rgDiscountedR);

float[] rgfAprobSet = m_rgMemory.GetActionProbabilities();
m_brain.SetActionProbabilities(rgfAprobSet);

float[] rgfAonehotSet = m_rgMemory.GetActionOneHotVectors();
m_brain.SetActionOneHotVectors(rgfAonehotSet);

List<Datum> rgData = m_rgMemory.GetData();
List<Datum> rgClip = m_rgMemory.GetClip();
m_brain.SetData(rgData, rgClip);
m_brain.Train(nIteration, step);

if (!dfRunningReward.HasValue)
    dfRunningReward = dfEpisodeReward;
else
    dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);
s = getData(phase, -1);
if (!dfRunningReward.HasValue)
    dfRunningReward = dfEpisodeReward;
else
    dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);
s = getData(phase, -1);
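// Both branches above maintain the same statistic: the first completed
// episode seeds the running reward, and each later episode blends in at
// one percent. A minimal standalone sketch of that smoothing (the
// 'episodeRewards' sequence is hypothetical):
double? runningReward = null;
foreach (double episodeReward in episodeRewards)
{
    runningReward = runningReward.HasValue
        ? runningReward.Value * 0.99 + episodeReward * 0.01
        : episodeReward;
}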
class Brain<T> : IDisposable

MyCaffeControl<T> m_mycaffe;
bool m_bSoftmaxCeSetup = false;
int m_nMiniBatch = 10;
int m_nRecurrentSequenceLength = 0;
List<Datum> m_rgData = null;
List<Datum> m_rgClip = null;

m_net = mycaffe.GetInternalNet(phase);
m_solver = mycaffe.GetInternalSolver();
m_properties = properties;

if (m_memData == null)
    throw new Exception("Could not find the MemoryData Layer!");

if (m_memLoss == null)
    throw new Exception("Could not find the MemoryLoss Layer!");
m_memLoss.OnGetLoss += memLoss_OnGetLoss;

m_blobDiscountedR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobPolicyGradient = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobActionOneHot = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobDiscountedR1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobPolicyGradient1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobActionOneHot1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobLoss = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
m_blobAprobLogit = new Blob<T>(mycaffe.Cuda, mycaffe.Log);

if (m_softmax != null)

m_colAccumulatedGradients.SetDiff(0);

int nMiniBatch = mycaffe.CurrentProject.GetBatchSize(phase);
m_nMiniBatch = nMiniBatch;
private void dispose(ref Blob<T> b)

public void Dispose()
m_memLoss.OnGetLoss -= memLoss_OnGetLoss;
dispose(ref m_blobDiscountedR);
dispose(ref m_blobPolicyGradient);
dispose(ref m_blobActionOneHot);
dispose(ref m_blobDiscountedR1);
dispose(ref m_blobPolicyGradient1);
dispose(ref m_blobActionOneHot1);
dispose(ref m_blobLoss);
dispose(ref m_blobAprobLogit);

if (m_colAccumulatedGradients != null)
{
    m_colAccumulatedGradients.Dispose();
    m_colAccumulatedGradients = null;
}

public int RecurrentSequenceLength
get { return m_nRecurrentSequenceLength; }
public int Reshape(MemoryCollection col)
int nNum = col.Count;
int nChannels = col[0].Data.Channels;
int nHeight = col[0].Data.Height;
int nWidth = col[0].Data.Width;
int nActionProbs = 1;

nActionProbs = Math.Max(nCh, nActionProbs);

throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

m_blobDiscountedR.Reshape(nNum, nActionProbs, 1, 1);
m_blobPolicyGradient.Reshape(nNum, nActionProbs, 1, 1);
m_blobActionOneHot.Reshape(nNum, nActionProbs, 1, 1);
m_blobDiscountedR1.Reshape(nNum, nActionProbs, 1, 1);
m_blobPolicyGradient1.Reshape(nNum, nActionProbs, 1, 1);
m_blobActionOneHot1.Reshape(nNum, nActionProbs, 1, 1);
m_blobLoss.Reshape(1, 1, 1, 1);
public void SetDiscountedR(float[] rg)
double dfMean = m_blobDiscountedR.mean(rg);
double dfStd = m_blobDiscountedR.std(dfMean, rg);
int nC = m_blobDiscountedR.channels;

List<float> rgR = new List<float>();

for (int i = 0; i < rg.Length; i++)
    for (int j = 0; j < nC; j++)
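// SetDiscountedR centers and scales the returns before loading them on
// the GPU; blob.mean and blob.std above supply the statistics, and the
// inner loop replicates each return across the nC action channels. A CPU
// sketch of the same normalization, assuming plain arrays and LINQ:
float fMean = rg.Average();
float fStd = (float)Math.Sqrt(rg.Select(v => (v - fMean) * (v - fMean)).Average());
float[] rgNormalized = rg.Select(v => (v - fMean) / fStd).ToArray();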
public void SetActionProbabilities(float[] rg)

public void SetActionOneHotVectors(float[] rg)

public void SetData(List<Datum> rgData, List<Datum> rgClip)
if (m_nRecurrentSequenceLength != 1 && rgData.Count > 1 && rgClip != null)
public GetDataArgs getDataArgs(Phase phase, int nAction)
bool bReset = (nAction == -1);
return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, true);
get { return m_mycaffe.Log; }

public SimpleDatum Preprocess(StateBase s, bool bUseRawInput)

if (m_sdLast == null)
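// The m_sdLast check above is the heart of Preprocess: unless raw input
// is requested, each frame has the previous frame subtracted so the
// policy sees motion rather than a static screen. A hedged
// reconstruction of the surrounding logic (the copy constructor and
// exact control flow are assumptions; Sub and Zero are SimpleDatum APIs):
SimpleDatum sd = new SimpleDatum(s.Data, true); // assumed copy constructor
if (!bUseRawInput)
{
    if (m_sdLast == null)
        sd.Zero();        // first frame: nothing to difference against
    else
        sd.Sub(m_sdLast); // difference frame: this = this - last
}
m_sdLast = s.Data;
return sd;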
List<Datum> rgData = new List<Datum>();
rgData.Add(new Datum(sd));

List<Datum> rgClip = null;

rgClip = new List<Datum>();
rgClip.Add(new Datum(sdClip));

for (int i = 0; i < res.Count; i++)

if (m_nRecurrentSequenceLength > 1 && res[i].num > 1)
{
    int nCount = res[i].count();
    int nOutput = nCount / res[i].num;
    nStart = nCount - nOutput;

    if (nStart < 0)
        throw new Exception("The start must be zero or greater!");
}

rgfAprob = Utility.ConvertVecF<T>(res[i].update_cpu_data(), nStart);

if (rgfAprob == null)
    throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

for (int i = 0; i < rgfAprob.Length; i++)

if (rgfAprob.Length == 1)

return rgfAprob.Length - 1;
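// The trailing fragments show how act() turns the probability vector
// into an action: a uniform draw is walked through the cumulative
// probabilities (the single-output case is a binary policy handled
// separately), and the final return guards against floating-point
// rounding. A minimal reconstruction under those assumptions:
double dfRand = m_random.NextDouble(); // CryptoRandom, per the trainer
float fSum = 0;
for (int i = 0; i < rgfAprob.Length; i++)
{
    fSum += rgfAprob[i];
    if (dfRand <= fSum)
        return i;
}
return rgfAprob.Length - 1; // probabilities may sum to slightly under 1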
int nCount = dst.count();
dst.CopyFrom(src, nIdx * nCount, 0, nCount, true, false);

public void Train(int nIteration, TRAIN_STEP step)
m_mycaffe.Log.Enable = false;

if (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null)
prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

for (int i = 0; i < m_rgData.Count; i++)
{
    copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
    copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
    copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

    List<Datum> rgData1 = new List<Datum>() { m_rgData[i] };
    List<Datum> rgClip1 = new List<Datum>() { m_rgClip[i] };

    m_solver.Step(1, step, true, false, true, true);
}

m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

m_solver.Step(1, step, true, false, true, true);

m_colAccumulatedGradients.SetDiff(0);

m_mycaffe.Log.Enable = true;
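// Step is invoked with updates disabled (per the Solver.Step signature,
// the arguments after 'step' are bZeroDiffs, bApplyUpdates,
// bDisableOutput, bDisableProgress), so Train() accumulates diffs and
// applies them once per mini-batch to emulate a larger batch. A hedged
// sketch of that pattern (the exact wiring inside Train() is assumed):
m_solver.Step(1, step, true, false, true, true); // forward/backward, no update
m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

if (nIteration % m_nMiniBatch == 0)
{
    m_solver.ApplyUpdate(nIteration);     // apply the summed gradients
    m_colAccumulatedGradients.SetDiff(0); // reset for the next mini-batch
}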
private T[] unpackLabel(Datum d)

List<int> rgDataShape = e.Data.shape();
List<int> rgClipShape = e.Clip.shape();
List<int> rgLabelShape = e.Label.shape();

int nSeqLen = rgDataShape[0];
e.Data.Log.CHECK_GT(nSeqLen, 0, "The sequence length must be greater than zero!");
e.Data.Log.CHECK_EQ(nBatch, e.ClipItems.Count, "The data and clip should have the same number of items.");
e.Data.Log.CHECK_EQ(nSeqLen, rgClipShape[0], "The data and clip should have the same sequence count.");

rgDataShape[1] = nBatch;
rgClipShape[1] = nBatch;
rgLabelShape[1] = nBatch;
e.Data.Reshape(rgDataShape);
e.Clip.Reshape(rgClipShape);
e.Label.Reshape(rgLabelShape);

T[] rgRawData = new T[e.Data.count()];
T[] rgRawClip = new T[e.Clip.count()];
T[] rgRawLabel = new T[e.Label.count()];

int nDataSize = e.Data.count(2);
T[] rgDataItem = new T[nDataSize];
for (int i = 0; i < nBatch; i++)

T[] rgLabel = unpackLabel(data);

for (int j = 0; j < nSeqLen; j++)

dfClip = clip.GetDataAt<T>(j);

for (int k = 0; k < nDataSize; k++)
    rgDataItem[k] = data.GetDataAt<T>(j * nDataSize + k);

// sequence-major ordering (one branch of the original listing):
nIdx = nBatch * j + i;
// batch-major ordering (the alternative branch):
nIdx = i * nBatch + j;

Array.Copy(rgDataItem, 0, rgRawData, nIdx * nDataSize, nDataSize);
rgRawClip[nIdx] = dfClip;

if (rgLabel.Length == nSeqLen)
    rgRawLabel[nIdx] = rgLabel[j];
else if (rgLabel.Length == 1)
{
    if (j == nSeqLen - 1)
        rgRawLabel[0] = rgLabel[0];
}
else
    throw new Exception("The Solver SequenceLength parameter does not match the actual sequence length! The label length '" + rgLabel.Length.ToString() + "' must be either '1' for SINGLE labels, or the sequence length of '" + nSeqLen.ToString() + "' for MULTI labels. Stopping training.");
e.Data.mutable_cpu_data = rgRawData;
e.Clip.mutable_cpu_data = rgRawClip;
e.Label.mutable_cpu_data = rgRawLabel;
m_nRecurrentSequenceLength = nSeqLen;
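// The handler above repacks the per-item Datums into the single ordered
// tensor a recurrent layer expects; which index formula applies depends
// on the LSTM layer type. An illustration of the sequence-major case
// with hypothetical sizes nBatch = 2, nSeqLen = 3:
for (int j = 0; j < 3; j++)     // time steps
    for (int i = 0; i < 2; i++) // batch items
        Console.WriteLine("slot " + (2 * j + i) + " <- item " + i + ", step " + j);
// Every time step stores all batch items contiguously, which is the
// layout the LSTM consumes.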
int nCount = m_blobPolicyGradient.count();
long hActionOneHot = m_blobActionOneHot.gpu_data;

long hDiscountedR = m_blobDiscountedR.gpu_data;

int nDataSize = e.Bottom[0].count(1);
bool bUsingEndData = false;
if (m_nRecurrentSequenceLength > 1)
{
    List<int> rgShape = e.Bottom[0].shape();

    e.Bottom[0].Reshape(rgShape);
    e.Bottom[0].CopyFrom(m_blobAprobLogit, (m_blobAprobLogit.num - 1) * nDataSize, 0, nDataSize, true, true);
    bUsingEndData = true;
}
long hBottomDiff = e.Bottom[0].mutable_gpu_diff;

if (m_softmax != null)

colBottom.Add(m_blobActionOneHot);
colTop.Add(m_blobLoss);
colTop.Add(m_blobPolicyGradient);

if (!m_bSoftmaxCeSetup)
{
    m_softmaxCe.Setup(colBottom, colTop);
    m_bSoftmaxCeSetup = true;
}

dfLoss = m_softmaxCe.Forward(colBottom, colTop);
m_softmaxCe.Backward(colTop, new List<bool>() { true, false }, colBottom);
hPolicyGrad = colBottom[0].gpu_diff;

m_mycaffe.Cuda.add_scalar(nCount, -1.0, hActionOneHot);
m_mycaffe.Cuda.abs(nCount, hActionOneHot, hActionOneHot);
m_mycaffe.Cuda.mul_scalar(nCount, -1.0, hPolicyGrad);
m_mycaffe.Cuda.add(nCount, hActionOneHot, hPolicyGrad, hPolicyGrad);

m_mycaffe.Cuda.mul_scalar(nCount, -1.0, hPolicyGrad);

m_mycaffe.Cuda.mul(nCount, hPolicyGrad, hDiscountedR, hPolicyGrad);

if (hPolicyGrad != hBottomDiff)
    m_mycaffe.Cuda.copy(nCount, hPolicyGrad, hBottomDiff);
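// Taken together, the kernel calls above implement the standard
// REINFORCE modulation: the softmax cross-entropy diff yields
// (aprob - onehot), and multiplying by the normalized discounted return
// scales each step's gradient by how good its eventual outcome was. An
// equivalent CPU sketch (array names are illustrative):
for (int i = 0; i < nCount; i++)
    rgPolicyGrad[i] = (rgfAprob[i] - rgActionOneHot[i]) * rgDiscountedR[i];
// Actions followed by above-average returns are pushed up; the rest are
// pushed down.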
if (m_nRecurrentSequenceLength > 1 && bUsingEndData)
{
    m_blobAprobLogit.CopyFrom(e.Bottom[0], 0, (m_blobAprobLogit.num - 1) * nDataSize, nDataSize, false, true);
    e.Bottom[0].CopyFrom(m_blobAprobLogit, false, true);
    e.Bottom[0].CopyFrom(m_blobAprobLogit, true);
}
public MemoryCollection()

public float[] GetDiscountedRewards(float fGamma, bool bAllowReset)
{
    float fRunningAdd = 0;
    float[] rgR = m_rgItems.Select(p => p.Reward).ToArray();
    float[] rgDiscountedR = new float[rgR.Length];

    for (int t = Count - 1; t >= 0; t--)
    {
        if (bAllowReset && rgR[t] != 0)
            fRunningAdd = 0;

        fRunningAdd = fRunningAdd * fGamma + rgR[t];
        rgDiscountedR[t] = fRunningAdd;
    }

    return rgDiscountedR;
}
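// The recurrence runs backwards, so each step receives
// r[t] + fGamma * R[t+1], and the reset isolates each scored rally
// (the Pong-style convention that a non-zero reward ends a game).
// Worked usage with fGamma = 0.99 and rewards [0, 0, 1]:
//   t=2: reset, then 0 * 0.99f + 1 = 1.0f
//   t=1: 1.0f * 0.99f + 0 = 0.99f
//   t=0: 0.99f * 0.99f + 0 = 0.9801f
float[] rgDiscounted = memory.GetDiscountedRewards(0.99f, true); // 'memory' is illustrative
// rgDiscounted => { 0.9801f, 0.99f, 1.0f }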
public float[] GetActionProbabilities()
{
    List<float> rgfAprob = new List<float>();

    for (int i = 0; i < m_rgItems.Count; i++)
    {
        rgfAprob.AddRange(m_rgItems[i].Aprob);
    }

    return rgfAprob.ToArray();
}

public float[] GetActionOneHotVectors()
{
    List<float> rgfAonehot = new List<float>();

    for (int i = 0; i < m_rgItems.Count; i++)
    {
        float[] rgfOneHot = new float[m_rgItems[0].Aprob.Length];

        if (rgfOneHot.Length == 1)
            rgfOneHot[0] = m_rgItems[i].Action;
        else
            rgfOneHot[m_rgItems[i].Action] = 1;

        rgfAonehot.AddRange(rgfOneHot);
    }

    return rgfAonehot.ToArray();
}
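// Each chosen action becomes one appended row. For example:
float[] rgfThreeAction = new float[3];
rgfThreeAction[2] = 1;   // agent chose action 2 -> { 0, 0, 1 }

float[] rgfBinary = new float[1];
rgfBinary[0] = 1;        // single-output (binary) policy: the action
                         // value itself is the training target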
public List<Datum> GetData()
{
    List<Datum> rgData = new List<Datum>();

    for (int i = 0; i < m_rgItems.Count; i++)
    {
        rgData.Add(new Datum(m_rgItems[i].Data));
    }

    return rgData;
}

public List<Datum> GetClip()
{
    if (m_rgItems.Count == 0)
        return null;

    if (m_rgItems[0].State.Clip == null)
        return null;

    List<Datum> rgData = new List<Datum>();

    for (int i = 0; i < m_rgItems.Count; i++)
    {
        if (m_rgItems[i].State.Clip == null)
            return null;

        rgData.Add(new Datum(m_rgItems[i].State.Clip));
    }

    return rgData;
}
public MemoryItem(StateBase s, SimpleDatum x, int nAction, float[] rgfAprob, float fReward)
m_nAction = nAction;
m_rgfAprob = rgfAprob;
m_fReward = fReward;

public StateBase State
get { return m_state; }

get { return m_nAction; }

get { return m_fReward; }

public float[] Aprob
get { return m_rgfAprob; }
public override string ToString()
{
    return "action = " + m_nAction.ToString() + " reward = " + m_fReward.ToString("N2") + " aprob = " + tostring(m_rgfAprob);
}

private string tostring(float[] rg)
{
    string str = "";

    for (int i = 0; i < rg.Length; i++)
    {
        str += rg[i].ToString("N5") + ",";
    }

    str = str.TrimEnd(',');
    return str;
}