using System.Collections.Generic;
using System.Threading.Tasks;
// TrainerPG<T> constructor: cache the callback and the property settings.
m_icallback = icallback;
m_properties = properties;
private void wait(int nWait)
{
    while (nTotalWait < nWait)
    {
        m_icallback.OnWait(new WaitArgs(nWaitInc));
        nTotalWait += nWaitInc;
    }
}
// Shutdown: stop the internal MyCaffe instance, then notify the parent.
if (m_mycaffe != null)
{
    // ...
}
m_icallback.OnShutdown();
Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
// Run: create a RUN-phase agent and return its raw results.
Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN);
byte[] rgResults = agent.Run(nN, out type);
// Test: run a TRAIN-phase agent in TEST mode with number skipping disabled.
string strProp = m_properties.ToString();
strProp += "EnableNumSkip=False;";
PropertySet properties = new PropertySet(strProp);

Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, properties, m_random, Phase.TRAIN);
agent.Run(Phase.TEST, nN, type);
// Train: stepping is not supported by the simple trainer.
if (step != TRAIN_STEP.NONE)
    throw new Exception("The simple trainer does not support stepping - use the 'PG.MT' trainer instead.");

Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
agent.Run(Phase.TRAIN, nN, type);
class Agent<T> : IDisposable
{
    bool m_bAllowDiscountReset = false;
    bool m_bUseRawInput = false;

    // Agent constructor: create the Brain and read the agent options.
    m_icallback = icallback;
    m_brain = new Brain<T>(mycaffe, properties, random, phase);
    m_properties = properties;
    m_bAllowDiscountReset = properties.GetPropertyAsBool("AllowDiscountReset", false);
public void Dispose()
private StateBase getData(Phase phase, int nAction)
{
    GetDataArgs args = m_brain.getDataArgs(phase, nAction);
    m_icallback.OnGetData(args);
private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward)
{
    GetStatusArgs args = new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, 0, 0, 0, 0);
    m_icallback.OnUpdateStatus(args);
}
public byte[] Run(int nIterations, out string type)
{
    IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
    if (icallback == null)
        throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");

    StateBase s = getData(Phase.RUN, -1);
    List<float> rgResults = new List<float>();

    while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
    {
        // Preprocess the state and select the next action.
        SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
        int action = m_brain.act(x, out fAprob);

        // Record each step as a (timestamp, input value, action) triple.
        rgResults.Add(s.Data.TimeStamp.ToFileTime());
        rgResults.Add(s.Data.GetDataAtF(0));
        rgResults.Add(action);

        // Take the action and observe the next state.
        StateBase s_ = getData(Phase.RUN, action);
    }

    // Convert the packed results into the caller's native format.
    ConvertOutputArgs args = new ConvertOutputArgs(nIterations, rgResults.ToArray());
    icallback.OnConvertOutput(args);

    return args.RawOutput;
}
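The raw output packs each step as a (timestamp, input value, action) triple of floats, so a consumer can recover the per-step records by walking the array in strides of three. The sketch below is illustrative only (the DecodeResults helper is not part of MyCaffe) and assumes the OnConvertOutput callback passed the floats through unchanged:

// Illustrative only: decode the (timestamp, value, action) triples packed by Run.
// Note the timestamp loses precision when round-tripped through a float.
static void DecodeResults(float[] rgResults)
{
    for (int i = 0; i < rgResults.Length; i += 3)
    {
        DateTime dt = DateTime.FromFileTime((long)rgResults[i]);
        float fValue = rgResults[i + 1];
        int nAction = (int)rgResults[i + 2];
        Console.WriteLine(dt + " value = " + fValue.ToString("N5") + " action = " + nAction);
    }
}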
private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)
MemoryCollection m_rgMemory = new MemoryCollection();
double? dfRunningReward = null;
double dfEpisodeReward = 0;

StateBase s = getData(phase, -1);

// Models that use recurrent layers or clip data are rejected up front.
throw new Exception("The PG.SIMPLE trainer does not support recurrent layers or clip data, use the 'PG.ST' or 'PG.MT' trainer instead.");
while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
{
    // Preprocess the current state and sample an action from the policy.
    SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
    int action = m_brain.act(x, out fAprob);

    // Step the environment and accumulate the episode reward.
    StateBase s_ = getData(phase, action);
    dfEpisodeReward += s_.Reward;

    if (phase == Phase.TRAIN)
    {
        // Remember the transition for the end-of-episode policy update.
        m_rgMemory.Add(new MemoryItem(s, x, action, fAprob, (float)s_.Reward));
        // At the end of an episode, run the policy-gradient update.
        m_brain.Reshape(m_rgMemory);

        // Compute the discounted rewards that weight each step's gradient.
        float[] rgDiscountedR = m_rgMemory.GetDiscountedRewards(m_fGamma, m_bAllowDiscountReset);
        m_brain.SetDiscountedR(rgDiscountedR);

        // Collect the per-step policy gradients (dlogp values).
        float[] rgDlogp = m_rgMemory.GetPolicyGradients();
        m_brain.SetPolicyGradients(rgDlogp);

        // Feed the episode's frames to the network and train.
        List<Datum> rgData = m_rgMemory.GetData();
        m_brain.SetData(rgData);
        m_brain.Train(nIteration);
        // Track the running reward as an exponential moving average.
        if (!dfRunningReward.HasValue)
            dfRunningReward = dfEpisodeReward;
        else
            dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

        updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);

        s = getData(phase, -1);
        // The non-training phases report the same running-reward average.
        if (!dfRunningReward.HasValue)
            dfRunningReward = dfEpisodeReward;
        else
            dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

        updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);

        s = getData(phase, -1);
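The 0.99/0.01 update is an exponential moving average, so the reported running reward reflects roughly the last hundred episodes. For example, with a running reward of 10 and a new episode reward of 20, the update gives 10 * 0.99 + 20 * 0.01 = 10.1; a single good episode only nudges the average.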
class Brain<T> : IDisposable
{
    MyCaffeControl<T> m_mycaffe;
    int m_nMiniBatch = 10;

    // Brain constructor: grab the internal net and solver for the given phase.
    m_net = mycaffe.GetInternalNet(phase);
    m_solver = mycaffe.GetInternalSolver();
    m_properties = properties;
    // The simple trainer expects a raw probability output, not a Softmax layer.
    throw new Exception("The PG.SIMPLE trainer does not support the Softmax layer, use the 'PG.ST' or 'PG.MT' trainer instead.");

    if (m_memData == null)
        throw new Exception("Could not find the MemoryData Layer!");

    if (m_memLoss == null)
        throw new Exception("Could not find the MemoryLoss Layer!");

    // Hook the custom loss used to inject the policy gradient.
    m_memLoss.OnGetLoss += memLoss_OnGetLoss;

    m_blobDiscountedR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobPolicyGradient = new Blob<T>(mycaffe.Cuda, mycaffe.Log);

    int nMiniBatch = mycaffe.CurrentProject.GetBatchSize(phase);
    m_nMiniBatch = nMiniBatch;
private void dispose(ref Blob<T> b)

public void Dispose()
{
    m_memLoss.OnGetLoss -= memLoss_OnGetLoss;
    dispose(ref m_blobDiscountedR);
    dispose(ref m_blobPolicyGradient);
}
public void Reshape(MemoryCollection col)
{
    int nNum = col.Count;
    int nChannels = col[0].Data.Channels;
    int nHeight = col[0].Data.Height;
    int nWidth = col[0].Data.Width;

    // One discounted-reward and one policy-gradient value per episode step.
    m_blobDiscountedR.Reshape(nNum, 1, 1, 1);
    m_blobPolicyGradient.Reshape(nNum, 1, 1, 1);
public void SetDiscountedR(float[] rg)
{
    // Standardize the discounted rewards to zero mean and unit variance.
    double dfMean = m_blobDiscountedR.mean(rg);
    double dfStd = m_blobDiscountedR.std(dfMean, rg);
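Standardizing the discounted rewards before they are used as advantages keeps the gradient scale stable from episode to episode. A minimal CPU-side sketch of the same normalization (the Standardize helper is hypothetical, not part of MyCaffe; assumes System.Linq):

// Illustrative only: standardize rewards to zero mean and unit variance.
static float[] Standardize(float[] rg)
{
    double dfMean = rg.Average();
    double dfSumSq = rg.Sum(f => (f - dfMean) * (f - dfMean));
    double dfStd = Math.Sqrt(dfSumSq / rg.Length);

    // The small epsilon guards against a zero standard deviation.
    return rg.Select(f => (float)((f - dfMean) / (dfStd + 1e-10))).ToArray();
}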
public void SetPolicyGradients(float[] rg)

public void SetData(List<Datum> rgData)

public GetDataArgs getDataArgs(Phase phase, int nAction)
{
    bool bReset = (nAction == -1);
    return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, false);
}
get { return m_mycaffe.Log; }
public SimpleDatum Preprocess(StateBase s, bool bUseRawInput)
{
    // On the first call there is no previous frame to difference against.
    if (m_sdLast == null)
// act: feed the current frame to the network and read back the action probability.
List<Datum> rgData = new List<Datum>();
rgData.Add(new Datum(sd));

float[] rgfAprob = null;

for (int i = 0; i < res.Count; i++)
{
    // The non-loss output of the forward pass holds the action probability.
    rgfAprob = Utility.ConvertVecF<T>(res[i].update_cpu_data());
}

if (rgfAprob == null)
    throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

if (rgfAprob.Length != 1)
    throw new Exception("The simple policy gradient only supports a single data output!");

fAprob = rgfAprob[0];
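The sampling step itself is elided from this listing; given the single probability output, the action is presumably drawn from a Bernoulli distribution over two actions. A hedged sketch of that final step (the action indices are assumptions; m_random is the trainer's CryptoRandom):

// Illustrative only: sample one of two actions from the probability fAprob.
// Which index maps to which environment action depends on the model.
int action = (m_random.NextDouble() < fAprob) ? 0 : 1;
return action;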
public void Train(int nIteration)
{
    // Train quietly, applying the solver update once per mini-batch.
    m_mycaffe.Log.Enable = false;

    if (nIteration % m_nMiniBatch == 0)
    {
        // ... apply the accumulated solver update
    }

    m_mycaffe.Log.Enable = true;
}
// memLoss_OnGetLoss: modulate the policy gradient by the discounted rewards.
int nCount = m_blobPolicyGradient.count();
long hPolicyGrad = m_blobPolicyGradient.mutable_gpu_data;
long hBottomDiff = e.Bottom[0].mutable_gpu_diff;
long hDiscountedR = m_blobDiscountedR.gpu_data;

// The reported loss is the sum of squares of the policy gradient.
double dfMean = dfSumSq;

// Scale each step's gradient by its advantage, then negate so that the
// solver's descent step becomes gradient ascent on the expected reward.
m_mycaffe.Cuda.mul(nCount, hPolicyGrad, hDiscountedR, hPolicyGrad);
m_mycaffe.Cuda.copy(nCount, hPolicyGrad, hBottomDiff);
m_mycaffe.Cuda.mul_scalar(nCount, -1.0, hBottomDiff);
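On the GPU, the three CUDA calls implement diff = -(dlogp * discountedR) elementwise. A CPU-side equivalent (illustrative only, not part of MyCaffe) makes the arithmetic explicit:

// Illustrative only: CPU equivalent of the gradient modulation above.
static float[] ModulateGradient(float[] rgDlogp, float[] rgDiscountedR)
{
    float[] rgBottomDiff = new float[rgDlogp.Length];

    for (int i = 0; i < rgDlogp.Length; i++)
        rgBottomDiff[i] = -1.0f * rgDlogp[i] * rgDiscountedR[i];   // negate for gradient ascent

    return rgBottomDiff;
}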
public MemoryCollection()
public float[] GetDiscountedRewards(float fGamma, bool bAllowReset)
{
    float fRunningAdd = 0;
    float[] rgR = m_rgItems.Select(p => p.Reward).ToArray();
    float[] rgDiscountedR = new float[rgR.Length];

    // Walk backwards through the episode, accumulating discounted reward.
    for (int t = Count - 1; t >= 0; t--)
    {
        // A non-zero reward marks a game boundary; optionally reset the sum there.
        if (bAllowReset && rgR[t] != 0)
            fRunningAdd = 0;

        fRunningAdd = fRunningAdd * fGamma + rgR[t];
        rgDiscountedR[t] = fRunningAdd;
    }

    return rgDiscountedR;
}
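Worked example: with fGamma = 0.99 and rewards {0, 0, 1}, the backward pass produces {0.9801, 0.99, 1}; steps further from the eventual reward receive exponentially less credit. These values are then standardized by SetDiscountedR before use.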
public float[] GetPolicyGradients()
{
    return m_rgItems.Select(p => p.dlogps).ToArray();
}
public List<Datum> GetData()
{
    List<Datum> rgData = new List<Datum>();

    for (int i = 0; i < m_rgItems.Count; i++)
    {
        rgData.Add(new Datum(m_rgItems[i].Data));
    }

    return rgData;
}
public List<Datum> GetClip()
public MemoryItem(StateBase s, SimpleDatum x, int nAction, float fAprob, float fReward)
public StateBase State
{
    get { return m_state; }
}

public int Action
{
    get { return m_nAction; }
}

public float Reward
{
    get { return m_fReward; }
}
// dlogps: the gradient of the log-likelihood of the sampled action.
return fY - m_fAprob;
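For a Bernoulli policy where aprob is the probability of the 'positive' action, the gradient of the log-likelihood with respect to the network's pre-sigmoid output (the logit) is y - aprob, with y = 1 if that action was taken and 0 otherwise, which is exactly the value returned here. For example, taking the action when aprob = 0.3 gives dlogp = 0.7; scaled by a positive discounted reward, the update makes that action more likely in the same state.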
public override string ToString()
{
    return "action = " + m_nAction.ToString() + " reward = " + m_fReward.ToString("N2") + " aprob = " + m_fAprob.ToString("N5") + " dlogps = " + dlogps.ToString("N5");
}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
void Reset()
Resets the event clearing any signaled state.
CancelEvent()
The CancelEvent constructor.
void Set()
Sets the event to the signaled state.
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
double NextDouble()
Returns a random double within the range [0, 1).
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
The GenericList provides a base used to implement a generic list by only implementing the minimum amo...
List< T > m_rgItems
The actual list of items.
The Log class provides general output in text form.
Log(string strSrc)
The Log constructor.
Specifies a key-value pair of properties.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as a double value.
override string ToString()
Returns the string representation of the properties.
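As used by the Agent constructor above, trainer options arrive as semicolon-delimited key=value pairs. A small usage sketch (the option names are taken from this listing; the PropertySet string constructor is assumed):

// Illustrative only: reading trainer options from a PropertySet.
PropertySet properties = new PropertySet("AllowDiscountReset=False;EnableNumSkip=False;");
bool bAllowReset = properties.GetPropertyAsBool("AllowDiscountReset", false);
bool bNumSkip = properties.GetPropertyAsBool("EnableNumSkip", true);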
The SimpleDatum class holds a data input within host memory.
bool Sub(SimpleDatum sd, bool bSetNegativeToZero=false)
Subtract the data of another SimpleDatum from this one, so this = this - sd.
void Zero()
Zero out all data in the datum but keep the size and other settings.
SimpleDatum Add(SimpleDatum d)
Creates a new SimpleDatum and adds another SimpleDatum to it.
override string ToString()
Return a string representation of the SimpleDatum.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of doubles.
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
void SetData(T[] rgData, int nCount=-1, bool bSetCount=true)
Sets a number of items within the Blob's data.
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use...
double std(double? dfMean=null, float[] rgDf=null)
Calculate the standard deviation of the blob data.
double mean(float[] rgDf=null, bool bDiff=false)
Calculate the mean of the blob data.
T sumsq_data()
Calculate the sum of squares (L2 norm squared) of the data.
void NormalizeData(double? dfMean=null, double? dfStd=null)
Normalize the blob data by subtracting the mean and dividing by the standard deviation.
int count()
Returns the total number of items in the Blob.
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
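These Blob members are the ones Brain<T> uses for its reward buffers above; a condensed sketch of that pattern (shapes and names are illustrative):

// Illustrative only: the reward-buffer pattern used by Brain<T>.
Blob<T> blobR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
blobR.Reshape(nSteps, 1, 1, 1);        // one value per episode step
blobR.SetData(rgData);                 // rgData is a T[] of discounted rewards
double dfMean = blobR.mean();          // mean of the blob data
double dfStd = blobR.std(dfMean);      // standard deviation of the blob data
blobR.NormalizeData(dfMean, dfStd);    // standardize in place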
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
BlobCollection< T > Forward()
Run forward with the input Blobs already fed separately.
Layer< T > FindLayer(LayerParameter.LayerType? type, string strName)
Find the layer with the matching type, name, or both.
void ClearParamDiffs()
Zero out the diffs of all net parameters. This should be run before Backward.
The ResultCollection contains the result of a given CaffeControl::Run.
The MemoryDataLayer provides data to the Net from memory. This layer is initialized with the MyCaffe....
virtual void AddDatumVector(Datum[] rgData, Datum[] rgClip=null, int nLblAxis=1, bool bReset=false, bool bResizeBatch=false)
This method is used to add a list of Datums to the memory.
The MemoryLossLayerGetLossArgs class is passed to the OnGetLoss event.
bool EnableLossUpdate
Get/set enabling the loss update within the backpropagation pass.
double Loss
Get/set the externally calculated total loss.
BlobCollection< T > Bottom
Specifies the bottom passed in during the forward pass.
The MemoryLossLayer provides a method of performing a custom loss functionality. Similar to the Memor...
EventHandler< MemoryLossLayerGetLossArgs< T > > OnGetLoss
The OnGetLoss event fires during each forward pass. The value returned is saved, and applied on the b...
The SoftmaxLayer computes the softmax function. This layer is initialized with the MyCaffe....
Specifies the base parameter for all layers.
LayerType
Specifies the layer type.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
The InitializeArgs is passed to the OnInitialize event.
The WaitArgs is passed to the OnWait event.
The TrainerPG implements a simple Policy Gradient trainer inspired by Andrej Karpathy's blog post re...
byte[] Run(int nN, PropertySet runProp, out string type)
Run a set of iterations and return the results.
bool Initialize()
Initialize the trainer.
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
TrainerPG(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
The constructor.
bool Shutdown(int nWait)
Shutdown the trainer.
void Dispose()
Releases all resources used.
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the environment after the delay.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
The IxTrainerRL interface is implemented by each RL Trainer.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.common namespace contains common MyCaffe classes.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...