2using System.Collections.Generic;
8using System.Threading.Tasks;
76 double m_dfExplorationRate = 0;
77 double m_dfOptimalSelectionRate = 0;
78 double m_dfImmediateRewards = 0;
79 double m_dfGlobalRewards = 0;
80 double m_dfGlobalRewardsAve = 0;
81 double m_dfGlobalRewardsMax = -
double.MaxValue;
82 int m_nGlobalEpisodeCount = 0;
83 int m_nGlobalEpisodeMax = 0;
86 REWARD_TYPE m_rewardType = REWARD_TYPE.MAXIMUM;
87 TRAINER_TYPE m_trainerType = TRAINER_TYPE.PG_ST;
88 double m_dfAccuracy = 0;
90 int m_nIterations = -1;
93 bool m_bSnapshot =
false;
96 bool m_bUsePreloadData =
false;
97 object m_syncObj =
new object();
124 InitializeComponent();
135 InitializeComponent();
145 get {
return "MyCaffe RL/RNN Dual Trainer"; }
196 if (stage ==
Stage.RNN)
198 switch (m_trainerType)
200 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
201 return new rnn.simple.TrainerRNNSimple<
double>(mycaffe,
m_properties,
m_random,
this, m_rgVocabulary);
203 case TRAINER_TYPE.RNN_SIMPLE:
204 return new rnn.simple.TrainerRNN<
double>(mycaffe,
m_properties,
m_random,
this, m_rgVocabulary);
207 throw new Exception(
"The trainer type '" + m_trainerType.ToString() +
"' is not supported in the RNN stage!");
212 switch (m_trainerType)
214 case TRAINER_TYPE.PG_SIMPLE:
217 case TRAINER_TYPE.PG_ST:
220 case TRAINER_TYPE.PG_MT:
223 case TRAINER_TYPE.C51_ST:
226 case TRAINER_TYPE.DQN_ST:
229 case TRAINER_TYPE.DQN_SIMPLE:
233 throw new Exception(
"The trainer type '" + m_trainerType.ToString() +
"' is not supported in the RL stage!");
258 if (stage ==
Stage.RNN)
260 switch (m_trainerType)
262 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
263 return new rnn.simple.TrainerRNNSimple<
float>(mycaffe,
m_properties,
m_random,
this, m_rgVocabulary);
265 case TRAINER_TYPE.RNN_SIMPLE:
269 throw new Exception(
"The trainer type '" + m_trainerType.ToString() +
"' is not supported in the RNN stage!");
274 switch (m_trainerType)
276 case TRAINER_TYPE.PG_SIMPLE:
279 case TRAINER_TYPE.PG_ST:
282 case TRAINER_TYPE.PG_MT:
285 case TRAINER_TYPE.C51_ST:
288 case TRAINER_TYPE.DQN_ST:
291 case TRAINER_TYPE.DQN_SIMPLE:
295 throw new Exception(
"The trainer type '" + m_trainerType.ToString() +
"' is not supported in the RL stage!");
395 #region IXMyCaffeCustomTrainer Interface
402 get {
return m_stage; }
474 private void cleanup(
int nWait,
bool bCallShutdown)
478 if (m_itrainer !=
null)
497 m_icallback = icallback;
502 if (strRewardType ==
null)
503 strRewardType =
"VAL";
505 strRewardType = strRewardType.ToUpper();
507 if (strRewardType ==
"VAL" || strRewardType ==
"VALUE")
508 m_rewardType = REWARD_TYPE.VALUE;
509 else if (strRewardType ==
"AVE" || strRewardType ==
"AVERAGE")
510 m_rewardType = REWARD_TYPE.AVERAGE;
514 switch (strTrainerType)
517 m_trainerType = TRAINER_TYPE.PG_SIMPLE;
522 m_trainerType = TRAINER_TYPE.PG_ST;
528 m_trainerType = TRAINER_TYPE.PG_MT;
533 m_trainerType = TRAINER_TYPE.C51_ST;
538 m_trainerType = TRAINER_TYPE.DQN_ST;
543 m_trainerType = TRAINER_TYPE.DQN_SIMPLE;
548 m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
552 case "RNN.SUPER.SIMPLE":
553 m_trainerType = TRAINER_TYPE.RNN_SUPER_SIMPLE;
558 throw new Exception(
"Unknown trainer type '" + strTrainerType +
"'!");
562 private Stage getStage()
564 if (m_trainerType == TRAINER_TYPE.RNN_SIMPLE || m_trainerType == TRAINER_TYPE.RNN_SUPER_SIMPLE)
572 IxTrainer itrainer =
null;
579 itrainer.Initialize();
592 if (m_itrainer ==
null)
593 m_itrainer = createTrainer(mycaffe, getStage());
595 if (nIterationOverride == -1)
596 nIterationOverride = m_nIterations;
598 m_itrainer.
Test(nIterationOverride, type);
611 if (m_itrainer ==
null)
612 m_itrainer = createTrainer(mycaffe, getStage());
614 if (nIterationOverride == -1)
615 nIterationOverride = m_nIterations;
617 m_itrainer.
Train(nIterationOverride, type, step);
618 cleanup(1000,
false);
673 m_dfImmediateRewards = e.
Reward;
675 m_dfGlobalRewardsMax = Math.Max(m_dfGlobalRewardsMax, e.
TotalReward);
676 m_dfGlobalRewardsAve = (1.0 / (double)m_nThreads) * e.
TotalReward + ((m_nThreads - 1) / (
double)m_nThreads) * m_dfGlobalRewardsAve;
681 m_nGlobalEpisodeCount++;
683 m_nGlobalEpisodeCount = e.
Frames;
688 if (m_icallback !=
null)
690 Dictionary<string, double> rgValues =
new Dictionary<string, double>();
695 rgValues.Add(
"Threads", m_nThreads);
702 if (e.
Index == 0 && m_nSnapshot > 0 && m_nGlobalEpisodeCount > 0 && (m_nGlobalEpisodeCount % m_nSnapshot) == 0)
711 Thread.Sleep(e.
Wait);
726 case "GlobalAccuracy":
729 case "GlobalIteration":
732 case "GlobalMaxIterations":
733 return m_nIterations;
735 case "GlobalRewards":
738 case "GlobalEpisodeCount":
741 case "ExplorationRate":
745 throw new Exception(
"The property '" + strProp +
"' is not supported by the MyCaffeTrainerRNN.");
762 switch (m_rewardType)
764 case REWARD_TYPE.VALUE:
765 return m_dfGlobalRewards;
767 case REWARD_TYPE.AVERAGE:
768 return m_dfGlobalRewardsAve;
771 return (m_dfGlobalRewardsMax == -
double.MaxValue) ? 0 : m_dfGlobalRewardsMax;
781 get {
return m_dfImmediateRewards; }
789 get {
return m_dfLoss; }
797 get {
return m_nIteration; }
805 get {
return m_nGlobalEpisodeCount; }
813 get {
return m_nGlobalEpisodeMax; }
821 get {
return m_dfExplorationRate; }
829 get {
return m_dfOptimalSelectionRate; }
848 #region IXMyCaffeCustomTrainerRL Methods
858 if (m_itrainer ==
null)
859 m_itrainer = createTrainer(mycaffe,
Stage.RL);
862 if (itrainer ==
null)
863 throw new Exception(
"The trainer must be set to to 'C51.ST', PG.SIMPLE', 'PG.ST' or 'PG.MT' to run in reinforcement learning mode.");
880 if (m_itrainer ==
null)
881 m_itrainer = createTrainer(mycaffe,
Stage.RL);
885 if (icallback !=
null)
889 if (itrainer ==
null)
890 throw new Exception(
"The IxTrainerRL interface must be implemented.");
892 byte[] rgResults = itrainer.
Run(nN, runProp, out type);
900 #region IXMyCaffeCustomTrainerRNN
910 if (m_itrainer ==
null)
911 m_itrainer = createTrainer(mycaffe,
Stage.RNN);
914 if (itrainer ==
null)
915 throw new Exception(
"The trainer must be set to to 'RNN.SIMPLE' to run in recurrent learning mode.");
919 if (icallback !=
null)
922 float[] rgResults = itrainer.
Run(nN, runProp);
935 byte[] IXMyCaffeCustomTrainerRNN.Run(
Component mycaffe,
int nN, out
string type)
937 if (m_itrainer ==
null)
938 m_itrainer = createTrainer(mycaffe,
Stage.RNN);
940 IxTrainerRNN itrainer = m_itrainer as IxTrainerRNN;
941 if (itrainer ==
null)
942 throw new Exception(
"The trainer must be set to to 'RNN.SIMPLE' to run in recurrent learning mode.");
945 IXMyCaffeCustomTrainerCallbackRNN icallback = m_icallback as IXMyCaffeCustomTrainerCallbackRNN;
946 if (icallback !=
null)
947 runProp = icallback.GetRunProperties();
949 byte[] rgResults = itrainer.Run(nN, runProp, out type);
967 return preloaddata(log, evtCancel, nProjectID, propertyOverride, ci);
978 string IXMyCaffeCustomTrainerRNN.ResizeModel(
Log log,
string strModel,
BucketCollection rgVocabulary)
980 if (rgVocabulary ==
null || rgVocabulary.
Count == 0)
983 int nVocabCount = rgVocabulary.
Count;
985 string strEmbedName =
"";
987 string strIpName =
"";
994 strEmbedName = layer.
name;
999 strIpName = layer.
name;
1006 if (embed.
input_dim != (uint)nVocabCount)
1008 log.
WriteLine(
"WARNING: Embed layer '" + strEmbedName +
"' input dim changed from " + embed.
input_dim.ToString() +
" to " + nVocabCount.ToString() +
" to accomodate for the vocabulary count.");
1013 if (ip !=
null && ip.
num_output != (uint)nVocabCount)
1015 log.
WriteLine(
"WARNING: InnerProduct layer '" + strIpName +
"' num_output changed from " + ip.
num_output.ToString() +
" to " + nVocabCount.ToString() +
" to accomodate for the vocabulary count.");
1019 m_rgVocabulary = rgVocabulary;
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
ProjectEx CurrentProject
Returns the currently loaded project.
The BucketCollection contains a set of Buckets.
int Count
Returns the number of Buckets.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
The CryptoRandom is a random number generator that can use either the standard .NET Random object or t...
The Log class provides general output in text form.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
int OriginalID
Get/set the original project ID.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
The RawProto class is used to parse and output Google prototxt file data.
override string ToString()
Returns the RawProto as its full prototxt string.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
The ResultCollection contains the result of a given CaffeControl::Run.
Specifies the parameters used by the EmbedLayer.
uint input_dim
Specifies the input given as integers to be interpreted as one-hot vector indices with dimension num_...
Specifies the parameters for the InnerProductLayer.
uint num_output
The number of outputs for the layer.
Specifies the base parameter for all layers.
string name
Specifies the name of this LayerParameter.
LayerType type
Specifies the type of this LayerParameter.
EmbedParameter embed_param
Returns the parameter set when initialized with LayerType.EMBED
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
Specifies the parameters used to create a Net.
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
override RawProto ToProto(string strName)
Converts the parameter into a RawProto.
List< LayerParameter > layer
The layers that make up the net. Each of their configurations, including connectivity and behavior,...
The ConvertOutputArgs is passed to the OnConvertOutput event.
The GetDataArgs is passed to the OnGetData event to retrieve data.
The GetStatusArgs is passed to the OnGetStatus event.
double Loss
Returns the loss value.
double OptimalSelectionCoefficient
Returns the optimal selection coefficient.
int MaxFrames
Returns the maximum frame count.
int Iteration
Returns the number of iterations (steps) run.
int Frames
Returns the total frame count across all agents.
int NewFrameCount
Get/set the new frame count.
double ExplorationRate
Returns the current exploration rate.
double TotalReward
Returns the total rewards.
int Index
Returns the index of the caller.
double Reward
Returns the immediate reward for the current episode.
double LearningRate
Returns the current learning rate.
The InitializeArgs is passed to the OnInitialize event.
The MyCaffeTrainerDual is used to perform both reinforcement and recurrent learning training tasks ...
virtual TRAINING_CATEGORY category
Override when using a training method other than the REINFORCEMENT method (the default).
void OnUpdateStatus(GetStatusArgs e)
The OnGetStatus callback fires on each iteration within the Train method.
ConnectInfo m_dsCi
Optionally, specifies the dataset connection info, or null.
virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The preloaddata method gives the custom trainer an opportunity to pre-load any data.
virtual string get_information()
Returns information describing the specific trainer, such as the gym used, if any.
double ImmediateRewards
Returns the immediate rewards for the current training cycle as opposed to the averaged rewards.
bool IsRunningSupported
Returns whether or not Running is supported.
void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
Initializes a new custom trainer by loading the key-value pair of properties into the property set.
virtual bool getData(GetDataArgs e)
Override called by the OnGetData event fired by the Trainer to retrieve a new set of observation coll...
void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION)
Create a new trainer and use it to run a test cycle using the current 'stage' = RNN or RL.
int GlobalEpisodeMax
Returns the maximum global episode count.
virtual void testAccuracyUpdate(TestAccuracyUpdateArgs e)
Override called by the OnTestAccuracyUpdate event fired from within the Run method and is used to giv...
int GlobalEpisodeCount
Returns the global episode count.
bool IsTrainingSupported
Returns whether or not Training is supported.
byte[] Run(Component mycaffe, int nN, out string type)
Run the network using the run technique implemented by this trainer.
virtual void openUi()
Called by OpenUi, override this when a UI (via WCF) should be displayed.
double ExplorationRate
Returns the current exploration rate.
int m_nProjectID
Specifies the project ID of the project held by the instance of MyCaffe.
MyCaffeTrainerDual()
The constructor.
MyCaffeTrainerDual(IContainer container)
The constructor.
void OnWait(WaitArgs e)
The OnWait callback fires when waiting for a shutdown.
string Information
Returns information describing the trainer.
void OnTestAccuracyUpdate(TestAccuracyUpdateArgs e)
The OnTestAccuracyUpdate callback fires from within the Run method and is used to give the recipient ...
void OnInitialize(InitializeArgs e)
The OnInitialize callback fires when initializing the trainer.
bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
double OptimalSelectionRate
Returns the rate of selection from the optimal set with the highest reward (this setting is optional,...
void OpenUi()
Open the user interface for the trainer, if one exists.
virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
TRAINING_CATEGORY TrainingCategory
Returns the training category of the custom trainer (default = REINFORCEMENT).
CryptoRandom m_random
Random number generator used to get initial actions, etc.
int GlobalIteration
Returns the global iteration.
virtual IxTrainer create_trainerF(Component caffe, Stage stage)
Optionally overridden to return a new type of trainer.
virtual string name
Overridden to give the actual name of the custom trainer.
DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
double GetProperty(string strProp)
Return a property value from the trainer.
bool IsTestingSupported
Returns whether or not Testing is supported.
virtual void initialize(InitializeArgs e)
Override called by the Initialize method of the trainer.
virtual bool convertOutput(ConvertOutputArgs e)
Override called by the OnConvertOutput event fired by the Trainer to convert the network output into ...
virtual void dispose()
Override to dispose of resources used.
double GlobalLoss
Return the global loss.
double? GlobalRewards
Returns the global rewards based on the reward type specified by the 'RewardType' property.
PropertySet m_properties
Specifies the properties parsed from the key-value pair passed to the Initialize method.
string Name
Returns the name of the custom trainer. This method calls the 'name' override.
void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION, TRAIN_STEP step=TRAIN_STEP.NONE)
Create a new trainer and use it to run a training cycle using the current 'stage' = RNN or RL.
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network outp...
virtual void shutdown()
Override called from within the CleanUp method.
void CleanUp()
Releases any resources used by the component.
void OnShutdown()
The OnShutdown callback fires when shutting down the trainer.
virtual IxTrainer create_trainerD(Component caffe, Stage stage)
Optionally overridden to return a new type of trainer.
void OnGetData(GetDataArgs e)
The OnGetData callback fires from within the Train method and is used to get new observation data.
The TestAccuracyUpdateArgs are passed to the OnTestAccuracyUpdate event.
The WaitArgs is passed to the OnWait event.
int Wait
Returns the amount of time to wait in milliseconds.
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
The IXMyCaffeCustomTrainerCallback interface is used to call back to the parent running the custom tr...
void Update(TRAINING_CATEGORY cat, Dictionary< string, double > rgValues)
The Update method updates the parent with the global iteration, reward and loss.
The IXMyCaffeCustomTrainerCallbackRNN interface is used to call back to the parent running the custom...
PropertySet GetRunProperties()
The GetRunProperties method is used to query the properties used when Running, if any.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
ResultCollection RunOne(Component mycaffe, int nDelay)
Run the network using the run technique implemented by this trainer.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
float[] Run(Component mycaffe, int nN)
Run the network using the run technique implemented by this trainer.
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
The IxTrainer interface is implemented by each Trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network.
bool Test(int nN, ITERATOR_TYPE type)
Test the network.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerRL interface is implemented by each RL Trainer.
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the trainer.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a number of 'nN' samples on the trainer.
The IxTrainerRNN interface is implemented by each RNN Trainer.
float[] Run(int nN, PropertySet runProp)
Run a number of 'nN' samples on the trainer.
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
TRAINING_CATEGORY
Defines the category of training.
Stage
Specifies the stage under which to run a custom trainer.
The MyCaffe.common namespace contains common MyCaffe classes.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...