2using System.Collections.Generic;
6using System.Runtime.InteropServices;
9using System.Threading.Tasks;
35 bool m_bUsePreloadData =
true;
53 m_icallback = icallback;
55 m_properties = properties;
57 m_rgVocabulary = rgVocabulary;
79 private void wait(
int nWait)
84 while (nTotalWait < nWait)
86 m_icallback.OnWait(
new WaitArgs(nWaitInc));
87 nTotalWait += nWaitInc;
98 if (m_mycaffe !=
null)
104 m_icallback.OnShutdown();
109 private void updateStatus(
int nIteration,
int nMaxIteration,
double dfAccuracy,
double dfLoss,
double dfLearningRate)
111 GetStatusArgs args =
new GetStatusArgs(0, nIteration, nIteration, nMaxIteration, dfAccuracy, 0, 0, 0, dfLoss, dfLearningRate);
112 m_icallback.OnUpdateStatus(args);
115 private float computeAccuracy(List<Tuple<float, float>> rg,
float fThreshold)
119 for (
int i=0; i<rg.Count; i++)
121 float fDiff = Math.Abs(rg[i].Item1 - rg[i].Item2);
123 if (fDiff < fThreshold)
127 return (
float)nMatch / (float)rg.Count;
176 return run(nN, type, step,
Phase.TRAIN);
187 if (m_getDataTrainArgs ==
null)
191 m_icallback.OnGetData(m_getDataTrainArgs);
194 m_getDataTrainArgs.
Action = 0;
195 m_getDataTrainArgs.
Reset =
false;
202 throw new Exception(
"Missing expected input layer!");
206 throw new Exception(
"Expected batch size of 1!");
211 string strVal = m_properties.
GetProperty(
"BlobNames");
212 string[] rgstrVal = strVal.Split(
'|');
213 Dictionary<string, string> rgstrValMap =
new Dictionary<string, string>();
215 foreach (
string strVal1
in rgstrVal)
217 string[] rgstrVal2 = strVal1.Split(
'~');
218 if (rgstrVal2.Length != 2)
219 throw new Exception(
"Invalid BlobNames property, expected 'name=blobname'!");
221 rgstrValMap.Add(rgstrVal2[0], rgstrVal2[1]);
225 if (rgstrValMap.ContainsKey(
"x"))
229 if (rgstrValMap.ContainsKey(
"tt"))
233 if (rgstrValMap.ContainsKey(
"mask"))
237 if (rgstrValMap.ContainsKey(
"target"))
241 if (rgstrValMap.ContainsKey(
"xhat"))
245 throw new Exception(
"The 'x' blob was not found in the 'BlobNames' property!");
247 throw new Exception(
"The 'tt' blob was not found in the 'BlobNames' property!");
248 if (blobMask ==
null)
249 throw new Exception(
"The 'mask' blob was not found in the 'BlobNames' property!");
250 if (blobTarget ==
null)
251 throw new Exception(
"The 'target' blob was not found in the 'BlobNames' property!");
252 if (blobXhat ==
null)
253 throw new Exception(
"The 'xhat' blob was not found in the 'BlobNames' property!");
255 if (blobX.
count() != nInputDim)
256 throw new Exception(
"The 'x' blob must have a count of '" + nInputDim.ToString() +
"'!");
257 if (blobTt.
count() != nInputDim)
258 throw new Exception(
"The 'tt' blob must have a count of '" + nInputDim.ToString() +
"'!");
259 if (blobMask.
count() != nInputDim)
260 throw new Exception(
"The 'mask' blob must have a count of '" + nInputDim.ToString() +
"'!");
261 if (blobTarget.
count() != nOutputDim)
262 throw new Exception(
"The 'target' blob must have a count of '" + nOutputDim.ToString() +
"'!");
263 if (blobXhat.
count() != nOutputDim)
264 throw new Exception(
"The 'xhat' blob must have a count of '" + nOutputDim.ToString() +
"'!");
266 float[] rgInput =
new float[nInputDim];
267 float[] rgTimeSteps =
new float[nInputDim];
268 float[] rgMask =
new float[nInputDim];
269 float[] rgTarget =
new float[nOutputDim];
271 List<Tuple<float, float>> rgAccHistory =
new List<Tuple<float, float>>();
273 for (
int i = 0; i < nN; i++)
276 double dfAccuracy = 0;
277 float fPredictedY = 0;
280 m_icallback.OnGetData(m_getDataTrainArgs);
288 List<DataPoint> rgHistory = m_getDataTrainArgs.
State.
History;
289 DataPoint dpLast = (rgHistory.Count > 0) ? rgHistory.Last() :
null;
296 if (rgHistory.Count >= nInputDim)
298 for (
int j = 0; j < nInputDim; j++)
300 int nIdx = rgHistory.Count - nInputDim + j;
301 rgInput[j] = rgHistory[nIdx].Inputs[0];
302 rgTimeSteps[j] = rgHistory[nIdx].Time;
303 rgMask[j] = rgHistory[nIdx].Mask[0];
304 rgTarget[0] = rgHistory[nIdx].Target;
314 if (phase ==
Phase.TRAIN)
321 fPredictedY = rgOutput[0];
323 prop.
SetProperty(
"override_prediction", fPredictedY.ToString());
330 rgAccHistory.Add(
new Tuple<float, float>(fTargetY, fPredictedY));
331 if (rgAccHistory.Count > 100)
332 rgAccHistory.RemoveAt(0);
334 dfAccuracy = computeAccuracy(rgAccHistory, 0.005f);
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
Solver< T > GetInternalSolver()
Get the internal solver.
Log Log
Returns the Log (for output) used.
The BucketCollection contains a set of Buckets.
void Reset()
Resets the event clearing any signaled state.
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
void Set()
Sets the event to the signaled state.
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The Blob is the main holder of data that moves through the Layers of the Net.
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
int count()
Returns the total number of items in the Blob.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
List< Layer< T > > layers
Returns the layers.
BlobCollection< T > Forward()
Run forward with the input Blobs already fed separately.
void Backward(int nStart=int.MaxValue, int nEnd=0)
The network backward should take no input and output, since it solely computes the gradient w....
Blob< T > blob_by_name(string strName, bool bThrowExceptionOnError=true)
Returns a blob given its name.
The DataPoint contains the data used when training.
float Target
Returns the target value.
LayerParameter layer_param
Returns the LayerParameter for this Layer.
InputParameter input_param
Returns the parameter set when initialized with LayerType.INPUT
double base_lr
The base learning rate (default = 0.01).
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
SolverParameter parameter
Returns the SolverParameter used.
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
The GetDataArgs is passed to the OnGetData event to retrieve data.
CancelEvent CancelEvent
Returns the cancel event.
PropertySet ExtraProperties
Get/set extra properties.
bool Reset
Returns whether or not to reset the observation environment.
int Action
Returns the action to run. If less than zero, this parameter is ignored.
StateBase State
Specifies the state data of the observations.
The GetStatusArgs is passed to the OnGetStatus event.
The InitializeArgs is passed to the OnInitialize event.
List< DataPoint > History
Get/set the data history (if any exists).
The WaitArgs is passed to the OnWait event.
The TrainerRNNSimple implements a very simple RNN trainer inspired by adepierre's GitHub site referen...
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
TrainerRNNSimple(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
The constructor.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a single cycle on the environment after the delay.
void Dispose()
Releases all resources used.
bool Initialize()
Initialize the trainer.
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
float[] Run(int nN, PropertySet runProp)
Run a single cycle on the environment after the delay.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
The IxTrainerRL interface is implemented by each RL Trainer.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.common namespace contains common MyCaffe classes.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.gym namespace contains all classes related to the Gyms supported by MyCaffe.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...