2using System.Collections.Generic;
8using System.Threading.Tasks;
47 TRAINER_TYPE m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
51 bool m_bSnapshot =
false;
53 double m_dfAccuracy = 0;
55 int m_nIterations = -1;
57 object m_syncObj =
new object();
70 InitializeComponent();
81 InitializeComponent();
89 protected virtual string name
91 get {
return "MyCaffe RNN Trainer"; }
139 switch (m_trainerType)
141 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
142 return new rnn.simple.TrainerRNNSimple<
double>(mycaffe,
m_properties, m_random,
this, m_rgVocabulary);
144 case TRAINER_TYPE.RNN_SIMPLE:
145 return new rnn.simple.TrainerRNN<
double>(mycaffe,
m_properties, m_random,
this, m_rgVocabulary);
148 throw new Exception(
"Unknown trainer type '" + m_trainerType.ToString() +
"'!");
169 switch (m_trainerType)
171 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
172 return new rnn.simple.TrainerRNNSimple<
float>(mycaffe,
m_properties, m_random,
this, m_rgVocabulary);
174 case TRAINER_TYPE.RNN_SIMPLE:
175 return new rnn.simple.TrainerRNN<
float>(mycaffe,
m_properties, m_random,
this, m_rgVocabulary);
178 throw new Exception(
"Unknown trainer type '" + m_trainerType.ToString() +
"'!");
277 #region IXMyCaffeCustomTrainer Interface
284 get {
return Stage.RNN; }
356 private void cleanup(
int nWait,
bool bCallShutdown =
false)
360 if (m_itrainer !=
null)
379 m_icallback = icallback;
384 switch (strTrainerType)
387 m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
391 throw new Exception(
"Unknown trainer type '" + strTrainerType +
"'!");
417 if (m_itrainer ==
null)
418 m_itrainer = createTrainer(mycaffe);
422 if (icallback !=
null)
425 float[] rgResults = m_itrainer.
Run(nN, runProp);
440 if (m_itrainer ==
null)
441 m_itrainer = createTrainer(mycaffe);
445 if (icallback !=
null)
448 byte[] rgResults = m_itrainer.
Run(nN, runProp, out type);
462 if (m_itrainer ==
null)
463 m_itrainer = createTrainer(mycaffe);
465 if (nIterationOverride == -1)
466 nIterationOverride = m_nIterations;
468 m_itrainer.
Test(nIterationOverride, type);
481 if (m_itrainer ==
null)
482 m_itrainer = createTrainer(mycaffe);
484 if (nIterationOverride == -1)
485 nIterationOverride = m_nIterations;
487 m_itrainer.
Train(nIterationOverride, type, step);
540 if (m_icallback !=
null)
547 Dictionary<string, double> rgValues =
new Dictionary<string, double>();
548 rgValues.Add(
"GlobalIteration", e.
Frames);
549 rgValues.Add(
"GlobalLoss", e.
Loss);
561 Thread.Sleep(e.
Wait);
580 case "GlobalAccuracy":
583 case "GlobalIteration":
586 case "GlobalMaxIterations":
587 return m_nIterations;
590 throw new Exception(
"The property '" + strProp +
"' is not supported by the MyCaffeTrainerRNN.");
621 return preloaddata(log, evtCancel, nProjectID, propertyOverride, ci);
634 if (rgVocabulary ==
null || rgVocabulary.
Count == 0)
637 int nVocabCount = rgVocabulary.
Count;
639 string strEmbedName =
"";
641 string strIpName =
"";
648 strEmbedName = layer.
name;
653 strIpName = layer.
name;
660 if (embed.input_dim != (uint)nVocabCount)
662 log.WriteLine(
"WARNING: Embed layer '" + strEmbedName +
"' input dim changed from " + embed.input_dim.ToString() +
" to " + nVocabCount.ToString() +
" to accomodate for the vocabulary count.");
663 embed.input_dim = (uint)nVocabCount;
667 if (ip.num_output != (uint)nVocabCount)
669 log.WriteLine(
"WARNING: InnerProduct layer '" + strIpName +
"' num_output changed from " + ip.num_output.ToString() +
" to " + nVocabCount.ToString() +
" to accomodate for the vocabulary count.");
673 m_rgVocabulary = rgVocabulary;
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
ProjectEx CurrentProject
Returns the currently loaded project.
The BucketCollection contains a set of Buckets.
int Count
Returns the number of Buckets.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
The Log class provides general output in text form.
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
int OriginalID
Get/set the original project ID.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
The RawProto class is used to parse and output Google prototxt file data.
override string ToString()
Returns the RawProto as its full prototxt string.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
Specifies the parameters used by the EmbedLayer.
Specifies the parameters for the InnerProductLayer.
uint num_output
The number of outputs for the layer.
Specifies the base parameter for all layers.
string name
Specifies the name of this LayerParameter.
LayerType type
Specifies the type of this LayerParameter.
EmbedParameter embed_param
Returns the parameter set when initialized with LayerType.EMBED
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
Specifies the parameters used to create a Net
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
The ConvertOutputArgs is passed to the OnConvertOutput event.
The GetDataArgs is passed to the OnGetData event to retrieve data.
The GetStatusArgs is passed to the OnGetStatus event.
double Loss
Returns the loss value.
int MaxFrames
Returns the maximum frame count.
int Frames
Returns the total frame count across all agents.
double TotalReward
Returns the total rewards.
double LearningRate
Returns the current learning rate.
The InitializeArgs is passed to the OnInitialize event.
(Deprecated - use MyCaffeTrainerDual instead.) The MyCaffeTrainerRNN is used to perform recurrent ne...
void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
Initializes a new custom trainer by loading the key-value pair of properties into the property set.
bool IsTestingSupported
Returns whether or not Testing is supported.
double GetProperty(string strProp)
Returns a specific property value.
void CleanUp()
Releases any resources used by the component.
void OnShutdown()
The OnShutdown callback fires when shutting down the trainer.
virtual void dispose()
Override to dispose of resources used.
void OnGetData(GetDataArgs e)
The OnGetData callback fires from within the Train method and is used to get new observation data.
byte[] Run(Component mycaffe, int nN, out string type)
Run the network using the run technique implemented by this trainer.
MyCaffeTrainerRNN(IContainer container)
The constructor.
void OpenUi()
Open the user interface for the trainer, if one exists.
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network outp...
void OnInitialize(InitializeArgs e)
The OnInitialize callback fires when initializing the trainer.
string ResizeModel(Log log, string strModel, BucketCollection rgVocabulary)
The ResizeModel method gives the custom trainer the opportunity to resize the model if needed.
void OnWait(WaitArgs e)
The OnWait callback fires when waiting for a shutdown.
void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION, TRAIN_STEP step=TRAIN_STEP.NONE)
Create a new trainer and use it to run a training cycle.
TRAINING_CATEGORY TrainingCategory
Returns the training category of the custom trainer (default = REINFORCEMENT).
void OnUpdateStatus(GetStatusArgs e)
The OnUpdateStatus callback fires on each iteration within the Train method.
virtual void testAccuracyUpdate(TestAccuracyUpdateArgs e)
Override called by the OnTestAccuracyUpdate event fired from within the Run method and is used to giv...
float[] Run(Component mycaffe, int nN)
Create a new trainer and use it to run a single run cycle.
bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The preloaddata method gives the custom trainer an opportunity to pre-load any data.
virtual string get_information()
Returns information describing the specific trainer, such as the gym used, if any.
MyCaffeTrainerRNN()
The constructor.
string Information
Returns information describing the trainer.
virtual TRAINING_CATEGORY category
Override when using a training method other than the RECURRENT method (the default).
void OnTestAccuracyUpdate(TestAccuracyUpdateArgs e)
The OnTestAccuracyUpdate callback fires from within the Run method and is used to give the recipient ...
PropertySet m_properties
Specifies the properties parsed from the key-value pair passed to the Initialize method.
virtual void openUi()
Called by OpenUi, override this when a UI (via WCF) should be displayed.
DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual void initialize(InitializeArgs e)
Override called by the Initialize method of the trainer.
virtual bool getData(GetDataArgs e)
Override called by the OnGetData event fired by the Trainer to retrieve a new set of observation coll...
virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual IxTrainerRNN create_trainerF(Component caffe)
Optionally overridden to return a new type of trainer.
virtual IxTrainerRNN create_trainerD(Component caffe)
Optionally overridden to return a new type of trainer.
virtual void shutdown()
Override called from within the CleanUp method.
void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION)
Create a new trainer and use it to run a test cycle.
virtual string name
Overridden to give the actual name of the custom trainer.
ConnectInfo m_dsCi
Optionally, specifies the dataset connection info, or null.
int m_nProjectID
Specifies the project ID of the project held by the instance of MyCaffe.
bool IsRunningSupported
Returns whether or not Running is supported.
string Name
Returns the name of the custom trainer. This method calls the 'name' override.
bool IsTrainingSupported
Returns whether or not Training is supported.
virtual bool convertOutput(ConvertOutputArgs e)
Override called by the OnConvertOutput event fired by the Trainer to convert the network output into ...
BucketCollection PreloadData(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The PreloadData method gives the custom trainer an opportunity to pre-load any data.
The TestAccuracyUpdateArgs are passed to the OnTestAccuracyUpdate event.
The WaitArgs is passed to the OnWait event.
int Wait
Returns the amount of time to wait in milliseconds.
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
The IXMyCaffeCustomTrainerCallback interface is used to call back to the parent running the custom tr...
void Update(TRAINING_CATEGORY cat, Dictionary< string, double > rgValues)
The Update method updates the parent with the global iteration, reward and loss.
The IXMyCaffeCustomTrainerCallbackRNN interface is used to call back to the parent running the custom...
PropertySet GetRunProperties()
The GetRunProperties method is used to query the properties used when Running, if any.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
bool Initialize()
Initialize the trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network.
bool Test(int nN, ITERATOR_TYPE type)
Test the network.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerRL interface is implemented by each RL Trainer.
float[] Run(int nN, PropertySet runProp)
Run a number of 'nN' samples on the trainer.
The descriptors namespace contains all descriptors used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
TRAINING_CATEGORY
Defines the category of training.
Stage
Specifies the stage under which to run a custom trainer.
The MyCaffe.common namespace contains common MyCaffe classes.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...