2using System.Collections.Generic;
21 DataItem m_currentData =
null;
25 float[] m_rgEncInput1;
26 float[] m_rgEncInput2;
30 float[] m_rgDecTarget;
113 get {
return m_vocab; }
121 get {
return (m_currentData ==
null) ?
new IterationInfo(
true,
true, 0) : m_currentData.IterationInfo; }
124 private static string clean(
string str)
128 foreach (
char ch
in str)
173 private static List<string> preprocess(
string str,
int nMaxLen = 0)
175 string strInput = clean(str);
176 List<string> rgstr = strInput.ToLower().Trim().Split(
' ').ToList();
180 rgstr = rgstr.Take(nMaxLen).ToList();
181 if (rgstr.Count < nMaxLen)
188 private string getPath(
string strPath)
190 string strTarget =
"$ProgramData$";
192 if (!strPath.StartsWith(strTarget))
195 string strProgData = Environment.GetFolderPath(Environment.SpecialFolder.CommonApplicationData);
196 strProgData = strProgData.TrimEnd(
'\\');
198 strPath = strProgData + strPath.Substring(strTarget.Length);
208 List<List<string>> rgrgstrInput =
new List<List<string>>();
209 List<List<string>> rgrgstrTarget =
new List<List<string>>();
214 string[] rgstrInput = File.ReadAllLines(strEncoderSrc);
215 string[] rgstrTarget = File.ReadAllLines(strDecoderSrc);
217 if (rgstrInput.Length != rgstrTarget.Length)
218 throw new Exception(
"Both the input and target files must contains the same number of lines!");
222 List<string> rgstrInput1 = preprocess(rgstrInput[i]);
223 List<string> rgstrTarget1 = preprocess(rgstrTarget[i]);
225 if (rgstrInput1 !=
null && rgstrTarget1 !=
null)
227 rgrgstrInput.Add(rgstrInput1);
228 rgrgstrTarget.Add(rgstrTarget1);
233 m_vocab.
Load(rgrgstrInput, rgrgstrTarget);
234 m_data =
new Data(rgrgstrInput, rgrgstrTarget, m_vocab);
251 if (colBottom ==
null)
258 foreach (KeyValuePair<string, BlobShape> kv
in rgInput)
267 string strEncInput = customInput.
GetProperty(
"InputData");
268 if (strEncInput ==
null)
269 throw new Exception(
"Could not find the expected input property 'InputData'!");
290 if (nDecInput.HasValue && nDecInput.Value == (
int)
SPECIAL_TOKENS.EOS)
293 List<string> rgstrInput =
null;
294 if (strEncInput !=
null)
295 rgstrInput = preprocess(strEncInput);
297 DataItem data = Data.GetInputData(m_vocab, rgstrInput, nDecInput);
300 m_log.
CHECK_EQ(colBottom.
Count, 4,
"The bottom collection must have 3 items: dec_input, enc_input, enc_inputr, enc_clip");
302 m_log.
CHECK_EQ(colBottom.
Count, 3,
"The bottom collection must have 3 items: dec_input, enc_input | enc_inputr, enc_clip");
307 colBottom[nBtmIdx].
Reshape(
new List<int>() { 1, 1, 1 });
312 colBottom[nBtmIdx].
Reshape(
new List<int>() { nT, 1, 1 });
318 colBottom[nBtmIdx].
Reshape(
new List<int>() { nT, 1, 1 });
322 colBottom[nBtmIdx].
Reshape(
new List<int>() { nT, 1 });
324 float[] rgEncInput =
null;
325 float[] rgEncInputR =
null;
326 float[] rgEncClip =
null;
327 float[] rgDecInput =
new float[1];
329 if (data.EncoderInput !=
null)
331 rgEncInput =
new float[nT];
332 rgEncInputR =
new float[nT];
333 rgEncClip =
new float[nT];
335 for (
int i = 0; i < nT && i < data.EncoderInput.Count; i++)
337 rgEncInput[i] = data.EncoderInput[i];
338 rgEncInputR[i] = data.EncoderInputReverse[i];
339 rgEncClip[i] = (i == 0) ? 0 : 1;
343 rgDecInput[0] = data.DecoderInput;
346 colBottom[nBtmIdx].mutable_cpu_data =
convert(rgDecInput);
351 if (rgEncInput !=
null)
352 colBottom[nBtmIdx].mutable_cpu_data =
convert(rgEncInput);
358 if (rgEncInputR !=
null)
359 colBottom[nBtmIdx].mutable_cpu_data =
convert(rgEncInputR);
363 if (rgEncClip !=
null)
364 colBottom[nBtmIdx].mutable_cpu_data =
convert(rgEncClip);
380 List<Tuple<string, int, double>> rgRes =
new List<Tuple<string, int, double>>();
383 double dfProb = blobSoftmax.
GetMaxData(out lPos);
385 rgRes.Add(
new Tuple<string, int, double>(m_vocab.
IndexToWord((
int)lPos), (
int)lPos, dfProb));
391 for (
int i = 1; i < nK; i++)
393 blobSoftmax.
SetData(-1000000000, (
int)lPos);
397 if (strWord.Length > 0)
398 rgRes.Add(
new Tuple<string, int, double>(strWord, (
int)lPos, dfProb));
432 m_log.
CHECK_EQ(colTop.
Count, 7,
"When normal and reverse encoder output used, there must be 7 tops: dec, dclip, enc, encr, eclip, vocabcount, dectgt (only valid on TEST | TRAIN)");
434 m_log.
CHECK_EQ(colTop.
Count, 6,
"When normal or reverse encoder output used, there must be 6 tops: dec, dclip, enc | encr, eclip, vocabcount, dectgt (only valid on TEST | TRAIN)");
436 m_log.
FAIL(
"You must specify to enable either normal, reverse or both encoder inputs.");
450 reshape(colTop,
true);
482 reshape(colTop,
false);
489 List<int> rgTopShape =
new List<int>() { nT, nBatchSize, 1 };
494 colTop[nTopIdx].
Reshape(
new List<int>() { 1, nBatchSize, 1 });
499 colTop[nTopIdx].
Reshape(
new List<int>() { 1, nBatchSize });
506 colTop[nTopIdx].
Reshape(rgTopShape);
514 colTop[nTopIdx].
Reshape(rgTopShape);
520 colTop[nTopIdx].
Reshape(
new List<int>() { nT, nBatchSize });
524 colTop[nTopIdx].
Reshape(
new List<int>() { 1 });
531 colTop[nTopIdx].
Reshape(
new List<int>() { 1, nBatchSize, 1 });
547 Array.Clear(m_rgDecInput, 0, m_rgDecInput.Length);
549 Array.Clear(m_rgDecTarget, 0, m_rgDecTarget.Length);
550 Array.Clear(m_rgDecClip, 0, m_rgDecClip.Length);
551 Array.Clear(m_rgEncInput1, 0, m_rgEncInput1.Length);
552 Array.Clear(m_rgEncInput2, 0, m_rgEncInput2.Length);
553 Array.Clear(m_rgEncClip, 0, m_rgEncClip.Length);
559 for (
int i = 0; i < nBatch; i++)
571 for (
int j = 0; j < nT && j < m_currentData.EncoderInput.Count; j++)
573 m_rgEncInput1[nIdx + j] = m_currentData.EncoderInput[j];
574 m_rgEncInput2[nIdx + j] = m_currentData.EncoderInputReverse[j];
575 m_rgEncClip[nIdx + j] = (j == 0) ? 0 : 1;
578 m_rgDecClip[i] = m_currentData.DecoderClip;
579 m_rgDecInput[i] = m_currentData.DecoderInput;
580 m_rgDecTarget[i] = m_currentData.DecoderTarget;
583 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgDecInput);
586 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgDecClip);
591 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgEncInput1);
597 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgEncInput2);
601 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgEncClip);
606 colTop[nTopIdx].mutable_cpu_data =
convert(m_rgDecTarget);
612 float fDecInput =
convertF(colBottom[nBtmIdx].GetData(0));
619 colTop[nTopIdx].
SetData(fDecInput, 0);
623 colTop[nTopIdx].
SetData((fDecInput == 1) ? 0 : 1, 0);
628 colTop[nTopIdx].
CopyFrom(colBottom[nBtmIdx]);
635 colTop[nTopIdx].
CopyFrom(colBottom[nBtmIdx]);
641 colTop[nTopIdx].
CopyFrom(colBottom[nBtmIdx]);
654#pragma warning disable 1591
658 Random m_random =
new Random((
int)DateTime.Now.Ticks);
659 List<List<string>> m_rgInput;
660 List<List<string>> m_rgOutput;
661 int m_nCurrentSequence = -1;
662 int m_nCurrentOutputIdx = 0;
663 int m_nSequenceIdx = 0;
665 int m_nIterations = 0;
666 int m_nOutputCount = 0;
669 public Data(List<List<string>> rgInput, List<List<string>> rgOutput,
Vocabulary vocab)
673 m_rgOutput = rgOutput;
678 get {
return m_vocab; }
681 public int VocabularyCount
686 public static DataItem GetInputData(
Vocabulary vocab, List<string> rgstrInput,
int? nDecInput =
null)
688 List<int> rgInput =
null;
690 if (rgstrInput !=
null)
692 rgInput =
new List<int>();
693 foreach (
string str
in rgstrInput)
701 if (!nDecInput.HasValue)
707 return new DataItem(rgInput, nDecInput.Value, -1, nClip,
false,
true, 0);
710 public DataItem GetNextData(
bool bShuffle)
714 bool bNewSequence =
false;
715 bool bNewEpoch =
false;
717 if (m_nCurrentSequence == -1)
724 m_nCurrentSequence = m_random.Next(m_rgInput.Count);
728 m_nCurrentSequence = m_nSequenceIdx;
730 if (m_nSequenceIdx == m_rgOutput.Count)
734 m_nOutputCount = m_rgOutput[m_nCurrentSequence].Count;
737 if (m_nIterations == m_rgOutput.Count)
744 List<string> rgstrInput = m_rgInput[m_nCurrentSequence];
745 List<int> rgInput =
new List<int>();
746 foreach (
string str
in rgstrInput)
753 if (m_nCurrentOutputIdx < m_rgOutput[m_nCurrentSequence].Count)
755 string strTarget = m_rgOutput[m_nCurrentSequence][m_nCurrentOutputIdx];
759 DataItem data =
new DataItem(rgInput, m_nIxInput, nIxTarget, nDecClip, bNewEpoch, bNewSequence, m_nOutputCount);
760 m_nIxInput = nIxTarget;
762 m_nCurrentOutputIdx++;
764 if (m_nCurrentOutputIdx == m_rgOutput[m_nCurrentSequence].Count)
766 m_nCurrentSequence = -1;
767 m_nCurrentOutputIdx = 0;
779 List<int> m_rgInputReverse;
784 public DataItem(List<int> rgInput,
int nIxInput,
int nIxTarget,
int nDecClip,
bool bNewEpoch,
bool bNewSequence,
int nOutputCount)
787 m_nIxInput = nIxInput;
788 m_nIxTarget = nIxTarget;
789 m_nDecClip = nDecClip;
790 m_iter =
new IterationInfo(bNewEpoch, bNewSequence, nOutputCount);
791 m_rgInputReverse =
new List<int>();
795 for (
int i = rgInput.Count - 1; i >= 0; i--)
797 m_rgInputReverse.Add(rgInput[i]);
802 m_rgInputReverse =
null;
806 public List<int> EncoderInput
808 get {
return m_rgInput; }
811 public List<int> EncoderInputReverse
813 get {
return m_rgInputReverse; }
816 public int DecoderInput
818 get {
return m_nIxInput; }
821 public int DecoderTarget
823 get {
return m_nIxTarget; }
826 public int DecoderClip
828 get {
return m_nDecClip; }
833 get {
return m_iter; }
837#pragma warning restore 1591
856 m_bNewEpoch = bNewEpoch;
857 m_bNewSequence = bNewSequence;
858 m_nOutputCount = nOutputCount;
866 get {
return m_bNewEpoch; }
874 get {
return m_bNewSequence; }
882 get {
return m_nOutputCount; }
891 Dictionary<string, int> m_rgDictionary =
new Dictionary<string, int>();
892 Dictionary<string, int> m_rgWordToIndex =
new Dictionary<string, int>();
893 Dictionary<int, string> m_rgIndexToWord =
new Dictionary<int, string>();
894 List<string> m_rgstrVocabulary =
new List<string>();
910 if (!m_rgWordToIndex.ContainsKey(strWord))
911 throw new Exception(
"I do not know the word '" + strWord +
"'!");
913 return m_rgWordToIndex[strWord];
923 if (!m_rgIndexToWord.ContainsKey(nIdx))
926 return m_rgIndexToWord[nIdx];
934 get {
return m_rgstrVocabulary.Count; }
942 public void Load(List<List<string>> rgrgstrInput, List<List<string>> rgrgstrTarget)
944 m_rgDictionary =
new Dictionary<string, int>();
947 for (
int i = 0; i < rgrgstrInput.Count; i++)
949 for (
int j = 0; j < rgrgstrInput[i].Count; j++)
951 string strWord = rgrgstrInput[i][j];
953 if (!m_rgDictionary.ContainsKey(strWord))
954 m_rgDictionary.Add(strWord, 1);
956 m_rgDictionary[strWord]++;
959 for (
int j = 0; j < rgrgstrTarget[i].Count; j++)
961 string strWord = rgrgstrTarget[i][j];
963 if (!m_rgDictionary.ContainsKey(strWord))
964 m_rgDictionary.Add(strWord, 1);
966 m_rgDictionary[strWord]++;
974 foreach (KeyValuePair<string, int> kv
in m_rgDictionary)
979 m_rgWordToIndex[kv.Key] = nIdx;
980 m_rgIndexToWord[nIdx] = kv.Key;
981 m_rgstrVocabulary.Add(kv.Key);
1012 get {
return m_vocab; }
1020 get {
return m_iter; }
The Log class provides general output in text form.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Specifies a key-value pair of properties.
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
The RawProto class is used to parse and output Google prototxt file data.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
The BlobCollection contains a list of Blobs.
void SetData(double df)
Set all blob data to the value specified.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
void CopyFrom(BlobCollection< T > bSrc, bool bCopyDiff=false)
Copy the data or diff from another BlobCollection into this one.
The Blob is the main holder of data that moves through the Layers of the Net.
int channels
DEPRECATED; legacy shape accessor for channels: use shape(1) instead.
void SetData(T[] rgData, int nCount=-1, bool bSetCount=true)
Sets a number of items within the Blob's data.
long mutable_gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
int count()
Returns the total number of items in the Blob.
string Name
Get/set the name of the Blob.
long gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
void SetDiff(double dfVal, int nIdx=-1)
Either sets all of the diff items in the Blob to a given value, or alternatively only sets a single i...
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
double GetMaxData(out long lPos)
Returns the maximum data and the position where the maximum is located in the data.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
An interface for the units of computation which can be composed into a Net.
Log m_log
Specifies the Log for output.
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
void convert(BlobCollection< T > col)
Convert a collection of blobs from / to half size.
float convertF(T df)
Converts a generic to a float value.
Phase m_phase
Specifies the Phase under which the Layer is run.
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
LayerParameter.LayerType m_type
Specifies the Layer type.
The IterationInfo class contains information about each iteration.
int OutputCount
Returns the output count of the current sequence.
IterationInfo(bool bNewEpoch, bool bNewSequence, int nOutputCount)
The constructor.
bool NewEpoch
Returns whether or not the current iteration is in a new epoch.
bool NewSequence
Returns whether or not the current iteration is in a new sequence.
Defines the arguments passed to the OnGetData event.
OnGetDataArgs(Vocabulary vocab, IterationInfo iter)
The constructor.
The Vocabulary object manages the overall word dictionary and word to index and index to word mapping...
int WordToIndex(string strWord)
The WordToIndex method maps a word to its corresponding index value.
int VocabularCount
Returns the number of words in the vocabulary.
void Load(List< List< string > > rgrgstrInput, List< List< string > > rgrgstrTarget)
Loads the word to index mappings.
string IndexToWord(int nIdx)
The IndexToWord method maps an index value to its corresponding word.
Vocabulary()
The constructor.
The TextDataLayer loads data from text data files for an encoder/decoder type model....
override string PostProcessOutput(int nIdx)
Convert the index to the word.
Vocabulary Vocabulary
Returns the vocabulary of the data sources.
override int? MinBottomBlobs
When running in TRAIN or TEST phase, returns 0 because data layers have no bottom (input) Blobs....
TextDataLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
The TextDataLayer constructor.
EventHandler< OnGetDataArgs > OnGetData
The OnGetData event is called during each forward pass after getting the training data for the pass...
override void dispose()
Release all internal blobs.
override BlobCollection< T > PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection< T > colBottom=null)
PreProcessInput allows derived data layers to convert a property set of input data into the bo...
override int MaxTopBlobs
Returns the maximum number of required top (output) Blobs: dec, dclip, enc, encr, eclip,...
IterationInfo? IterationInfo
Returns information on the current iteration.
void Next()
Proceeds to the next data item. When shuffling, the next item is randomly selected.
bool Skip()
Skip to the next data input.
override bool SupportsPreProcessing
Should return true when the pre-processing methods are overridden.
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Not implemented - data Layers do not perform backward.
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Run the Forward computation, which fills the data into the top (output) Blobs.
override int? MaxBottomBlobs
When running in TRAIN or TEST phase, returns 0 because data layers have no bottom (input) Blobs....
override List< Tuple< string, int, double > > PostProcessOutput(Blob< T > blobSoftmax, int nK=1)
Convert the maximum index within the softmax into the word index, then convert the word index back in...
override bool SupportsPostProcessing
Should return true when the post-processing methods are overridden.
override bool PreProcessInput(string strEncInput, int? nDecInput, BlobCollection< T > colBottom)
Preprocess the input data for the RUN phase.
override int MinTopBlobs
Returns the minimum number of required top (output) Blobs: dec, dclip, enc, eclip,...
void PreProcessInputFiles(TextDataParameter p)
Load the input and target files and convert each into a list of lines each containing a list of words...
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Data layers have no bottoms, so reshaping is trivial.
Specifies the base parameter for all layers.
int solver_count
Returns the number of Solvers participating in a multi-GPU session for which the Solver using this La...
TextDataParameter text_data_param
Returns the parameter set when initialized with LayerType.TEXT_DATA
int solver_rank
Returns the SolverRank of the Solver using this LayerParameter (if any).
string PrepareRunModelInputs()
Prepare model inputs for the run-net (if any are needed for the layer).
TransformationParameter transform_param
Returns the parameter set when initialized with LayerType.TRANSFORM
Phase phase
Specifies the Phase for which this LayerParameter is run.
LayerType
Specifies the layer type.
override string ToString()
Returns a string representation of the LayerParameter.
Specifies the parameters used to create a Net.
static Dictionary< string, BlobShape > InputFromProto(RawProto rp)
Collect the inputs from the RawProto.
Specifies the parameters for the Text data layer.
bool enable_reverse_encoder_output
When enabled, the reverse ordered encoder data is output (default = true).
uint sample_size
Specifies the sample size to select from the data sources.
uint time_steps
Specifies the maximum length for each encoder input.
bool shuffle
Specifies whether or not to shuffle the data.
bool enable_normal_encoder_output
When enabled, the normal ordered encoder data is output (default = true).
string decoder_source
Specifies the decoder data source.
virtual uint batch_size
Specifies the batch size.
string encoder_source
Specifies the encoder data source.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
SPECIAL_TOKENS
Specifies the special tokens.
The MyCaffe.common namespace contains common MyCaffe classes.
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers.beta namespace contains all beta stage layers.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...