3using System.Collections.Generic;
8using System.Threading.Tasks;
22 bool m_bFailOnFirstTry =
false;
23 const string m_strWeightMyCaffeTag =
"mycaffe.ai";
33 m_bFailOnFirstTry = bFailOnFirstTry;
42 get {
return m_strWeightMyCaffeTag; }
51 public bool IsMyCaffe(
byte[] rgWeights, out
string strVer)
55 if (rgWeights ==
null || rgWeights.Length < 10)
58 string strCaffeNet = Encoding.ASCII.GetString(rgWeights, rgWeights.Length - 10, 10);
59 if (strCaffeNet == m_strWeightMyCaffeTag)
61 long lCaffeStart = BitConverter.ToInt64(rgWeights, rgWeights.Length - (10 +
sizeof(
long)));
62 strVer = Encoding.ASCII.GetString(rgWeights, (
int)lCaffeStart + 10, 5);
77 FieldDescriptor fd = FieldDescriptor.CreateSolverStateFieldDesc();
78 ProtoBufWriter writer =
new ProtoBufWriter(m_log);
80 m_log.WriteLine(
"Saving state...");
82 writer.WriteField(fd,
"iter",
new int[] { state.
iter });
83 writer.WriteField(fd,
"current_step",
new int[] { state.
current_step });
87 writer.WriteField(fd,
"start",
new int[] { state.
start });
88 writer.WriteField(fd,
"end",
new int[] { state.
end });
91 for (
int i = 0; i < state.
history.Count; i++)
93 writer.WriteField(fd,
"history", saveBlobProto(fd.FindFirstChild(
"history"), state.
history[i]));
98 for (
int i = 0; i < state.
s_history.Count; i++)
100 writer.WriteField(fd,
"s_history", saveBlobProto(fd.FindFirstChild(
"s_history"), state.
s_history[i]));
103 writer.WriteField(fd,
"gradients", saveBlobProto(fd.FindFirstChild(
"gradient"), state.
gradients));
104 writer.WriteField(fd,
"direction", saveBlobProto(fd.FindFirstChild(
"direction"), state.
direction));
107 return writer.GetBytes(
true);
119 FieldDescriptor fd = FieldDescriptor.CreateSolverStateFieldDesc();
120 ProtoBufReader reader =
new ProtoBufReader(rgState);
121 ProtoBufFieldCollection fields = reader.ReadFields(fd,
false);
122 Stopwatch sw =
new Stopwatch();
124 m_log.WriteLine(
"Loading the Solver state...");
126 if (fields ==
null || fields.Count == 0)
134 ProtoBufField pbIter = fields.FindFirstChild(
"iter");
135 state.
iter = (pbIter ==
null || pbIter.IntValues ==
null || pbIter.IntValues.Length == 0) ? 0 : pbIter.IntValues[0];
137 ProtoBufField pbCurStep = fields.FindFirstChild(
"current_step");
138 state.
current_step = (pbCurStep ==
null || pbCurStep.IntValues ==
null || pbCurStep.IntValues.Length == 0) ? 1 : pbCurStep.IntValues[0];
142 ProtoBufField pbStart = fields.FindFirstChild(
"start");
143 state.
start = (pbStart ==
null || pbStart.IntValues ==
null || pbStart.IntValues.Length == 0) ? 0 : pbStart.IntValues[0];
145 ProtoBufField pbEnd = fields.FindFirstChild(
"end");
146 state.
end = (pbEnd ==
null || pbEnd.IntValues ==
null || pbEnd.IntValues.Length == 0) ? 1 : pbEnd.IntValues[0];
149 ProtoBufFieldCollection col = fields.FindAllChildren(
"history");
150 if (col !=
null && col.Count > 0)
152 FieldDescriptor fdHist = fd.FindFirstChild(
"history");
154 for (
int i = 0; i < col.Count; i++)
162 ProtoBufFieldCollection colS = fields.FindAllChildren(
"s_history");
163 if (colS !=
null && colS.Count > 0)
165 FieldDescriptor fdHist = fd.FindFirstChild(
"s_history");
167 for (
int i = 0; i < colS.Count; i++)
173 ProtoBufField pbGrad = fields.FindFirstChild(
"gradients");
176 FieldDescriptor fdGrad = fd.FindFirstChild(
"gradients");
180 ProtoBufField pbDir = fields.FindFirstChild(
"direction");
183 FieldDescriptor fdDir = fd.FindFirstChild(
"direction");
209 public BlobCollection<T> LoadWeights(
byte[] rgWeights, List<string> rgExpectedShapes,
BlobCollection<T> colBlobs,
bool bSizeToFit, out
bool bLoadedDiffs, List<string> inputWtInfo =
null, List<string> targetWtInfo =
null,
string strSkipBlobType =
null)
212 m_log.WriteLine(
"Attempting to load the weights in Caffe model format...");
217 colBlob1 = loadFromCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
218 if (colBlob1 !=
null)
220 m_log.WriteLine(
"Weights loaded in Caffe model format.");
224 if (m_bFailOnFirstTry)
225 throw new Exception(
"Failed to load the weights from the caffe model.");
227 else if (strVer ==
"v.1.0")
229 m_log.FAIL(
"Loading weights with 'depreciated' native v.1.0 format...");
232 m_log.WriteLine(
"Attempting to load weights in MyCaffe model format...");
233 colBlob1 = loadFromMyCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
234 if (colBlob1 !=
null)
236 m_log.WriteLine(
"Weights loaded in MyCaffe model format.");
240 if (m_bFailOnFirstTry)
241 throw new Exception(
"Failed to load the weights from the MyCaffe model.");
243 m_log.FAIL(
"Loading weights with 'depreciated' native format...");
257 return loadInfoFromCaffe(rgWeights);
259 return loadInfoFromMyCaffe(rgWeights);
271 foreach (
Blob<T> b
in colBlobs)
293 FieldDescriptor fd = FieldDescriptor.CreateNetworkParamFieldDesc();
294 ProtoBufWriter writer =
new ProtoBufWriter(m_log);
295 Dictionary<string, BlobCollection<T>> rgLayers =
new Dictionary<string, BlobCollection<T>>();
297 foreach (
Blob<T> blob
in colBlobs)
301 string strLayer = (string)blob.
Tag;
302 if (strLayer ==
null || strLayer.Length == 0)
303 throw new Exception(
"Invalid blob specification - missing layer name.");
305 if (!rgLayers.ContainsKey(strLayer))
308 rgLayers[strLayer].Add(blob);
312 writer.WriteField(fd,
"name",
"");
316 m_log.WriteLine(
"Saving layer '" + kv.Key +
"'...");
317 writer.WriteField(fd,
"LayerParameter", saveLayerParameter(fd.FindFirstChild(
"LayerParameter"), kv.Key, kv.Value));
322 long lCaffeNetStart = writer.Length;
323 byte[] rgPad =
new byte[256];
325 using (BinaryWriter bw =
new BinaryWriter(writer.Stream))
328 lCaffeNetStart += rgPad.Length;
331 byte[] rgCaffeNet = Encoding.ASCII.GetBytes(strCaffeNet);
333 string strVer =
"1.0.1";
334 byte[] rgV = Encoding.ASCII.GetBytes(strVer);
335 byte[] rgVer =
new byte[32];
336 Array.Copy(rgV, rgVer, rgV.Length);
338 bw.Write(rgCaffeNet);
340 bw.Write(lCaffeNetStart);
341 bw.Write(rgCaffeNet);
344 return writer.GetBytes(
false);
355 FieldDescriptor fd = FieldDescriptor.CreateBlobProtoDesc(nFieldId);
356 ProtoBufReader reader =
new ProtoBufReader(rg);
357 ProtoBufFieldCollection fields = reader.ReadFields(fd,
false);
359 if (fields ==
null || fields.Count == 0)
362 for (
int i = 0; i < fields.Count; i++)
364 ProtoBufField field = fields[i];
365 field.LoadSubFields(0, 4);
368 List<int> rgShape =
new List<int>();
370 ProtoBufField pbShape = fields.FindFirstChild(
"shape");
373 if (pbShape.Type != ProtoBufField.TYPE.ARRAY)
374 throw new Exception(
"Invalid proto buf: invalid type 'shape'");
376 ProtoBufField pbDim = pbShape.Array.FindFirstChild(
"dim");
377 if (pbDim ==
null || pbDim.Type != ProtoBufField.TYPE.LONG_ARRAY)
378 throw new Exception(
"Invalid proto buf: missing 'dim' type.");
380 for (
int i = 0; i < pbDim.LongValues.Length; i++)
382 rgShape.Add((
int)pbDim.LongValues[i]);
387 ProtoBufField pbNum = fields.FindFirstChild(
"num");
390 if (pbNum.Type != ProtoBufField.TYPE.BIT32)
391 throw new Exception(
"Invalid proto buf: invalid type 'num'");
393 rgShape.Add(pbNum.IntValue);
395 ProtoBufField pbChannels = fields.FindFirstChild(
"channels");
396 if (pbChannels !=
null)
398 if (pbChannels.Type != ProtoBufField.TYPE.BIT32)
399 throw new Exception(
"Invalid proto buf: invalid type 'channels'");
401 rgShape.Add(pbChannels.IntValue);
403 ProtoBufField pbHeight = fields.FindFirstChild(
"height");
404 if (pbHeight !=
null)
406 if (pbHeight.Type != ProtoBufField.TYPE.BIT32)
407 throw new Exception(
"Invalid proto buf: invalid type 'height'");
409 rgShape.Add(pbHeight.IntValue);
411 ProtoBufField pbWidth = fields.FindFirstChild(
"width");
414 if (pbWidth.Type != ProtoBufField.TYPE.BIT32)
415 throw new Exception(
"Invalid proto buf: invalid type 'width'");
417 rgShape.Add(pbWidth.IntValue);
424 ProtoBufField pbData = fields.FindFirstChild(
"data");
427 pbData = fields.FindFirstChild(
"double_data");
429 throw new Exception(
"Invalid proto buf: missing 'data' or 'double_data'");
434 if (pbData.Type == ProtoBufField.TYPE.FLOAT_ARRAY)
435 proto.
data =
new List<float>(pbData.FloatValues);
436 else if (pbData.Type == ProtoBufField.TYPE.DOUBLE_ARRAY)
437 proto.
double_data =
new List<double>(pbData.DoubleValues);
439 throw new Exception(
"Invalid proto buf: invalid data type '" + pbData.Type.ToString() +
"'.");
454 using (FileStream fs =
new FileStream(strFile, FileMode.Open, FileAccess.Read))
456 using (BinaryReader br =
new BinaryReader(fs))
458 rgBytes = br.ReadBytes((
int)fs.Length);
465 private byte[] saveLayerParameter(FieldDescriptor fd,
string strName,
BlobCollection<T> col)
467 ProtoBufWriter writer =
new common.ProtoBufWriter(m_log);
469 writer.WriteField(fd,
"name", strName);
473 writer.WriteField(fd,
"blobs", saveBlobProto(fd.FindFirstChild(
"blobs"), blob));
474 m_log.WriteLine(
" - saved blob '" + blob.
Name +
"'");
477 return writer.GetBytes();
480 private byte[] saveBlobProto(FieldDescriptor fd,
BlobProto bp)
482 ProtoBufWriter writer =
new ProtoBufWriter(m_log);
484 writer.WriteField(fd,
"shape", saveBlobShape(fd.FindFirstChild(
"shape"), bp.
shape.
dim));
487 writer.WriteField(fd,
"double_data", bp.
double_data.ToArray());
489 writer.WriteField(fd,
"data", bp.
data.ToArray());
491 return writer.GetBytes();
494 private byte[] saveBlobProto(FieldDescriptor fd, Blob<T> blob)
496 ProtoBufWriter writer =
new ProtoBufWriter(m_log);
498 writer.WriteField(fd,
"shape", saveBlobShape(fd.FindFirstChild(
"shape"), blob.shape()));
500 T[] rg = blob.update_cpu_data();
502 if (typeof(T) == typeof(
double))
504 double[] rgD = (
double[])Convert.ChangeType(rg, typeof(
double[]));
505 writer.WriteField(fd,
"double_data", rgD);
509 float[] rgD = (
float[])Convert.ChangeType(rg, typeof(
float[]));
510 writer.WriteField(fd,
"data", rgD);
513 return writer.GetBytes();
516 private byte[] saveBlobShape(FieldDescriptor fd, List<int> rg)
518 ProtoBufWriter writer =
new ProtoBufWriter(m_log);
519 List<long> rgLong =
new List<long>();
521 for (
int i = 0; i < rg.Count; i++)
526 writer.WriteField(fd,
"dim", rgLong.ToArray());
528 return writer.GetBytes();
531 private BlobCollection<T> loadFromMyCaffe(
byte[] rgWeights, List<string> rgExpectedShapes, BlobCollection<T> colBlobs,
bool bSizeToFit, out
bool bLoadedDiffs, List<string> inputWtInfo =
null, List<string> targetWtInfo =
null,
string strSkipBlobType =
null)
533 BlobCollection<T> colBlobs1 = loadFromCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
537 private WeightInfo<T> loadInfoFromMyCaffe(
byte[] rgWeights)
539 return loadInfoFromCaffe(rgWeights);
542 private BlobCollection<T> loadFromCaffe(
byte[] rgWeights, List<string> rgExpectedShapes, BlobCollection<T> colBlobs,
bool bSizeToFit, out
bool bLoadedDiffs, List<string> inputWtInfo =
null, List<string> targetWtInfo =
null,
string strSkipBlobType =
null)
544 FieldDescriptor fd = FieldDescriptor.CreateNetworkParamFieldDesc();
545 ProtoBufReader reader =
new ProtoBufReader(rgWeights);
546 ProtoBufFieldCollection fields = reader.ReadFields(fd,
true);
547 Stopwatch sw =
new Stopwatch();
548 BlobName name =
new BlobName();
550 bLoadedDiffs =
false;
552 if (fields ==
null || fields.Count == 0)
557 for (
int i=0; i<fields.Count; i++)
559 ProtoBufField field = fields[i];
560 field.LoadSubFields(0, 4);
562 if (sw.Elapsed.TotalMilliseconds > 1000)
564 m_log.Progress = (double)i / (
double)fields.Count;
565 m_log.WriteLine(
"(" + m_log.Progress.ToString(
"P") +
") loading fields...");
575 ProtoBufFieldCollection colFieldBlobs =
new common.ProtoBufFieldCollection();
578 for (
int i = 0; i < fields.Count; i++)
580 if (fields[i].FieldDesc !=
null)
582 if (fields[i].FieldDesc.Name ==
"LayerParameter")
584 ProtoBufField pbName = fields[i].Array.FindFirstChild(
"name");
585 ProtoBufFieldCollection col = fields[i].Array.FindAllChildren(
"blobs");
586 string strName = (pbName !=
null) ? pbName.StringValue : (
"layer_" + nLayerIdx.ToString());
588 if (col !=
null && col.Count > 0)
591 colFieldBlobs.AddRange(col);
596 else if (fields[i].FieldDesc.Name ==
"V1LayerParameter")
598 ProtoBufField pbName = fields[i].Array.FindFirstChild(
"name");
599 ProtoBufFieldCollection col = fields[i].Array.FindAllChildren(
"blobs");
600 string strName = (pbName !=
null) ? pbName.StringValue : (
"layer_" + nLayerIdx.ToString());
602 if (col !=
null && col.Count > 0)
606 colFieldBlobs.AddRange(col);
620 m_log.WriteLine(
"Loading the weights...");
622 if (colBlobs.Count != colFieldBlobs.Count)
623 m_log.WriteLine(
"The number of learnable blobs within the weights does not match the number within the network, attempting to load by size...");
630 List<long> rgBlobShape =
null;
632 while (nFieldIdx < colFieldBlobs.Count && nBlobIdx < colBlobs.Count)
634 Blob<T> blob = colBlobs[nBlobIdx];
635 string strName = name.GetName(blob.Name);
637 if (targetWtInfo !=
null)
639 while (strName != targetWtInfo[nTargetIdx] && nBlobIdx < colBlobs.Count)
643 if (nBlobIdx == colBlobs.Count)
646 blob = colBlobs[nBlobIdx];
647 strName = name.GetName(blob.Name);
650 if (nBlobIdx == colBlobs.Count)
651 m_log.WriteError(
new Exception(
"Could not find the target blob '" + targetWtInfo[nTargetIdx] +
"'!"));
656 string strShapeB = rgExpectedShapes[nBlobIdx];
658 string strShapeW =
"";
660 bool bResizeNeeded =
false;
661 bool bMisSized =
false;
666 while (nFieldIdx < colFieldBlobs.Count)
670 ProtoBufField pbName = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"name");
671 if (pbName !=
null && pbName.Type == ProtoBufField.TYPE.STRING)
673 strName = pbName.StringValue;
677 ProtoBufField pbType = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"type");
678 if (pbType !=
null && pbType.Type == ProtoBufField.TYPE.STRING)
679 strName = pbType.StringValue +
"_" + nFieldIdx.ToString();
681 strName =
"blob_" + nFieldIdx.ToString();
684 if (inputWtInfo ==
null || strName == inputWtInfo[nInfoIdx])
688 ProtoBufField pbShape = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"shape");
689 if (pbShape !=
null && pbShape.Type == ProtoBufField.TYPE.ARRAY)
691 ProtoBufField pbDim = pbShape.Array.FindFirstChild(
"dim");
692 if (pbDim !=
null && pbDim.Type == ProtoBufField.TYPE.LONG_ARRAY)
694 strShapeW = createShapeString(pbDim.LongValues, out lCount);
696 if (compareShapes(strShapeB, strShapeW) || bSizeToFitWts)
698 rgBlobShape =
new List<long>(pbDim.LongValues);
699 bResizeNeeded = bSizeToFitWts;
703 if (bSizeToFit && compareShapes(strShapeB, strShapeW, 2))
705 rgBlobShape =
new List<long>(pbDim.LongValues);
715 ProtoBufField pbNum = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"num");
716 if (pbNum !=
null && pbNum.Type == ProtoBufField.TYPE.BIT32)
718 List<long> rgShape =
new List<long>();
719 rgShape.Add(pbNum.IntValue);
721 ProtoBufField pbChannels = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"channels");
722 if (pbChannels !=
null && pbChannels.Type == ProtoBufField.TYPE.BIT32)
724 rgShape.Add(pbChannels.IntValue);
726 ProtoBufField pbHeight = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"height");
727 if (pbHeight !=
null && pbHeight.Type == ProtoBufField.TYPE.BIT32)
729 rgShape.Add(pbHeight.IntValue);
731 ProtoBufField pbWidth = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"width");
732 if (pbWidth !=
null && pbWidth.Type == ProtoBufField.TYPE.BIT32)
734 rgShape.Add(pbWidth.IntValue);
739 strShapeW = createShapeString(rgShape.ToArray(), out lCount);
741 if (compareShapes(strShapeB, strShapeW) || (bSizeToFit || bSizeToFitWts))
743 rgBlobShape = rgShape;
747 if ((bSizeToFit || bSizeToFitWts) && compareShapes(strShapeB, strShapeW, 2))
749 rgBlobShape = rgShape;
750 bResizeNeeded =
true;
763 if (nFieldIdx == colFieldBlobs.Count)
772 if (!bMisSized && (strSkipBlobType ==
null || blob.type.ToString() != strSkipBlobType))
774 ProtoBufField pbData = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"data");
775 FieldDescriptor.TYPE type = FieldDescriptor.TYPE.FLOAT;
779 pbData = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"double_data");
780 type = FieldDescriptor.TYPE.DOUBLE;
781 lDataCount = pbData.DoubleValues.Length;
785 lDataCount = pbData.FloatValues.Length;
788 if (pbData ==
null || (lDataCount != lCount && !bSizeToFit && !bSizeToFitWts))
789 m_log.FAIL(
"Could not find the weights matching the data size '" + strShapeB +
"'!");
796 List<int> rgNewShape = parseShape(strShapeW);
798 while (rgNewShape.Count < rgBlobShape.Count)
803 blob.Reshape(rgNewShape);
805 for (
int i = 0; i < rgNewShape.Count; i++)
807 rgBlobShape[i] = rgNewShape[i];
811 T[] rgData = copyData(pbData, type, lDataCount, rgBlobShape);
815 Blob<T> blobTemp =
new Blob<T>(blob.Cuda, blob.Log,
false,
false);
816 blobTemp.ReshapeLike(blob);
817 blobTemp.mutable_cpu_data = rgData;
818 blob.CopyFrom(blobTemp);
823 blob.mutable_cpu_data = rgData;
828 if (bSizeToFit && !compareShapes(strShapeB, strShapeW, 4))
829 m_log.FAIL(
"Could not find the weights matching the first two items of the shape '" + strShapeB +
"'!");
831 T[] rgData = copyData(pbData, type, lDataCount, rgBlobShape);
835 Blob<T> blobTemp =
new Blob<T>(blob.Cuda, blob.Log,
false,
false);
836 blobTemp.ReshapeLike(blob);
837 blobTemp.mutable_cpu_data = rgData;
838 blob.CopyFrom(blobTemp);
843 blob.mutable_cpu_data = rgData;
846 if (bSizeToFit && bResizeNeeded)
848 List<int> rgNewShape = parseShape(strShapeB);
849 Blob<T> blobResized = blob.Resize(rgNewShape);
851 colBlobs[nBlobIdx] = blobResized;
855 blob.Tag = colFieldBlobs[nFieldIdx].Tag;
857 m_log.WriteLine(
"(" + m_log.Progress.ToString(
"P") +
") loaded blob '" + colBlobs[nBlobIdx].Name +
"' size = " + strShapeB);
861 m_log.WriteLine(
"WARNING: did NOT load blob '" + colBlobs[nBlobIdx].Name +
"' size = " + strShapeB);
864 m_log.Progress = (double)nBlobIdx / (
double)colBlobs.Count;
869 if ((targetWtInfo !=
null && nTargetIdx == targetWtInfo.Count) ||
870 (inputWtInfo !=
null && nInfoIdx == inputWtInfo.Count))
877 private WeightInfo<T> loadInfoFromCaffe(
byte[] rgWeights)
879 WeightInfo<T> info =
new common.WeightInfo<T>();
880 FieldDescriptor fd = FieldDescriptor.CreateNetworkParamFieldDesc();
881 ProtoBufReader reader =
new ProtoBufReader(rgWeights);
882 ProtoBufFieldCollection fields = reader.ReadFields(fd,
true);
883 Stopwatch sw =
new Stopwatch();
885 if (fields ==
null || fields.Count == 0)
890 for (
int i = 0; i < fields.Count; i++)
892 ProtoBufField field = fields[i];
893 field.LoadSubFields(0, 4);
895 if (sw.Elapsed.TotalMilliseconds > 1000)
897 m_log.Progress = (double)i / (
double)fields.Count;
898 m_log.WriteLine(
"(" + m_log.Progress.ToString(
"P") +
") loading fields...");
908 ProtoBufFieldCollection colFieldBlobs =
new common.ProtoBufFieldCollection();
911 for (
int i = 0; i < fields.Count; i++)
913 if (fields[i].FieldDesc !=
null)
915 if (fields[i].FieldDesc.Name ==
"LayerParameter")
917 ProtoBufField pbName = fields[i].Array.FindFirstChild(
"name");
918 ProtoBufFieldCollection col = fields[i].Array.FindAllChildren(
"blobs");
919 string strName = (pbName !=
null) ? pbName.StringValue : (
"layer_" + nLayerIdx.ToString());
921 if (col !=
null && col.Count > 0)
924 colFieldBlobs.AddRange(col);
929 else if (fields[i].FieldDesc.Name ==
"V1LayerParameter")
931 ProtoBufField pbName = fields[i].Array.FindFirstChild(
"name");
932 ProtoBufFieldCollection col = fields[i].Array.FindAllChildren(
"blobs");
933 string strName = (pbName !=
null) ? pbName.StringValue : (
"layer_" + nLayerIdx.ToString());
935 if (col !=
null && col.Count > 0)
939 colFieldBlobs.AddRange(col);
956 while (nFieldIdx < colFieldBlobs.Count)
958 string strName =
null;
960 ProtoBufField pbName = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"name");
961 if (pbName !=
null && pbName.Type == ProtoBufField.TYPE.STRING)
963 strName = pbName.StringValue;
967 ProtoBufField pbType = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"type");
968 if (pbType !=
null && pbType.Type == ProtoBufField.TYPE.STRING)
969 strName = pbType.StringValue +
"_" + nFieldIdx.ToString();
971 strName =
"blob_" + nFieldIdx.ToString();
974 List<int> rgShape =
new List<int>();
976 ProtoBufField pbShape = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"shape");
977 if (pbShape !=
null && pbShape.Type == ProtoBufField.TYPE.ARRAY)
979 ProtoBufField pbDim = pbShape.Array.FindFirstChild(
"dim");
980 if (pbDim !=
null && pbDim.Type == ProtoBufField.TYPE.LONG_ARRAY)
982 for (
int i = 0; i < pbDim.LongValues.Length; i++)
984 rgShape.Add((
int)pbDim.LongValues[i]);
990 ProtoBufField pbNum = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"num");
991 if (pbNum !=
null && pbNum.Type == ProtoBufField.TYPE.BIT32)
993 rgShape.Add(pbNum.IntValue);
995 ProtoBufField pbChannels = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"channels");
996 if (pbChannels !=
null && pbChannels.Type == ProtoBufField.TYPE.BIT32)
998 rgShape.Add(pbChannels.IntValue);
1000 ProtoBufField pbHeight = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"height");
1001 if (pbHeight !=
null && pbHeight.Type == ProtoBufField.TYPE.BIT32)
1003 rgShape.Add(pbHeight.IntValue);
1005 ProtoBufField pbWidth = colFieldBlobs[nFieldIdx].Array.FindFirstChild(
"width");
1006 if (pbWidth !=
null && pbWidth.Type == ProtoBufField.TYPE.BIT32)
1008 rgShape.Add(pbWidth.IntValue);
1015 info.AddBlob(strName, rgShape,
BLOB_TYPE.UNKNOWN);
1023 private T[] copyData(ProtoBufField pb, FieldDescriptor.TYPE type,
long lCount, List<long> rgBlobShape)
1025 T[] rgData =
new T[lCount];
1027 if (type == FieldDescriptor.TYPE.FLOAT)
1028 Array.Copy(pb.FloatValues, rgData, lCount);
1031 if (typeof(T) == typeof(
double))
1032 Array.Copy(pb.DoubleValues, rgData, lCount);
1040 private List<int> parseShape(
string strShape,
int nCount =
int.MaxValue)
1042 List<int> rg1 =
new List<int>();
1043 string[] rgstr1 = strShape.Split(
' ');
1045 for (
int i = 0; i < rgstr1.Length - 1 && i < nCount; i++)
1047 int nVal =
int.Parse(rgstr1[i]);
1056 private bool compareShapes(
string strA,
string strB,
int nCount =
int.MaxValue)
1061 List<int> rg1 = parseShape(strA, nCount);
1062 List<int> rg2 = parseShape(strB, nCount);
1064 if (rg1.Count != rg2.Count)
1075 for (
int i = 0; i < rg1.Count; i++)
1077 if (rg1[i] != rg2[i])
1084 private string createShapeString(
long[] rg, out
long lCount)
1089 for (
int i = 0; i < rg.Length; i++)
1093 str += rg[i].ToString();
1099 str +=
"(" + rg.Length.ToString() +
")";
1105 class ProtoBufWriter : IDisposable
1107 MemoryStream m_ms =
null;
1108 CodedOutputStream m_strm =
null;
1109 bool m_bOwnStream =
true;
1111 static int m_nUnknownFieldID = 5000;
1112 Dictionary<string, int> m_rgUnknownFields =
new Dictionary<string, int>();
1114 public ProtoBufWriter(
Log log)
1117 m_ms =
new MemoryStream();
1118 m_strm =
new CodedOutputStream(m_ms);
1121 public ProtoBufWriter(
Log log, CodedOutputStream strm)
1124 m_bOwnStream =
false;
1127 public void Dispose()
1129 if (m_strm !=
null && m_bOwnStream)
1144 get {
return (
int)m_ms.Length; }
1147 public byte[] GetBytes(
bool bFlush =
true)
1149 if (m_strm !=
null && bFlush)
1152 byte[] rg = m_ms.ToArray();
1161 public MemoryStream Stream
1163 get {
return m_ms; }
1166 private int getFieldId(FieldDescriptor fd,
string strName, out FieldDescriptor.TYPE type)
1168 type = FieldDescriptor.TYPE.UNKNOWN;
1170 fd = fd.FindFirstChild(strName);
1177 if (m_rgUnknownFields.ContainsKey(strName))
1178 return m_rgUnknownFields[strName];
1180 int nId = m_nUnknownFieldID;
1181 m_nUnknownFieldID++;
1183 m_rgUnknownFields.Add(strName, nId);
1188 public void WriteField(FieldDescriptor fd,
string strName,
string strVal)
1190 FieldDescriptor.TYPE type;
1191 int nFieldId = getFieldId(fd, strName, out type);
1192 uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1194 m_strm.WriteUInt32(tag);
1195 m_strm.WriteString(strVal);
1198 public void WriteField(FieldDescriptor fd,
string strName,
byte[] rg)
1200 FieldDescriptor.TYPE type;
1201 int nFieldId = getFieldId(fd, strName, out type);
1202 uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1204 m_strm.WriteUInt32(tag);
1205 m_strm.WriteBytes(ByteString.CopyFrom(rg));
1208 public void WriteField(FieldDescriptor fd,
string strName,
double dfVal)
1210 FieldDescriptor.TYPE type;
1211 int nFieldId = getFieldId(fd, strName, out type);
1216 case FieldDescriptor.TYPE.DOUBLE:
1217 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.Fixed64);
1218 m_strm.WriteUInt32(tag);
1219 m_strm.WriteDouble(dfVal);
1222 case FieldDescriptor.TYPE.FLOAT:
1223 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.Fixed32);
1224 m_strm.WriteUInt32(tag);
1225 m_strm.WriteFloat((
float)dfVal);
1228 case FieldDescriptor.TYPE.LONG:
1229 case FieldDescriptor.TYPE.ULONG:
1230 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.Fixed64);
1231 m_strm.WriteUInt32(tag);
1232 m_strm.WriteFixed64((ulong)(
long)dfVal);
1235 case FieldDescriptor.TYPE.INT:
1236 case FieldDescriptor.TYPE.UINT:
1237 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.Fixed32);
1238 m_strm.WriteUInt32(tag);
1239 m_strm.WriteFixed32((uint)(
int)dfVal);
1243 throw new Exception(
"Unknown type '" + type.ToString() +
"'");
1247 public void WriteField(FieldDescriptor fd,
string strName,
long[] rgVal)
1249 FieldDescriptor.TYPE type;
1250 int nFieldId = getFieldId(fd, strName, out type);
1253 if (type != FieldDescriptor.TYPE.LONG &&
1254 type != FieldDescriptor.TYPE.ULONG)
1255 throw new Exception(
"Invalid type '" + type.ToString() +
"'");
1257 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1258 m_strm.WriteUInt32(tag);
1260 ProtoBufWriter pbWriter =
new ProtoBufWriter(m_log);
1261 byte[] rg = pbWriter.WriteArray(type, rgVal);
1262 m_strm.WriteBytes(ByteString.CopyFrom(rg));
1265 public byte[] WriteArray(FieldDescriptor.TYPE type,
long[] rgVal)
1267 for (
int i = 0; i < rgVal.Length; i++)
1269 if (type == FieldDescriptor.TYPE.ULONG)
1270 m_strm.WriteUInt64((uint)rgVal[i]);
1272 m_strm.WriteInt64(rgVal[i]);
1278 public void WriteField(FieldDescriptor fd,
string strName,
int[] rgVal)
1280 FieldDescriptor.TYPE type;
1281 int nFieldId = getFieldId(fd, strName, out type);
1284 if (type != FieldDescriptor.TYPE.INT &&
1285 type != FieldDescriptor.TYPE.UINT)
1286 throw new Exception(
"Invalid type '" + type.ToString() +
"'");
1288 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1289 m_strm.WriteUInt32(tag);
1291 ProtoBufWriter pbWriter =
new ProtoBufWriter(m_log);
1292 byte[] rg = pbWriter.WriteArray(type, rgVal);
1293 m_strm.WriteBytes(ByteString.CopyFrom(rg));
1296 public byte[] WriteArray(FieldDescriptor.TYPE type,
int[] rgVal)
1298 for (
int i = 0; i < rgVal.Length; i++)
1300 if (type == FieldDescriptor.TYPE.UINT)
1301 m_strm.WriteUInt32((uint)rgVal[i]);
1303 m_strm.WriteInt64(rgVal[i]);
1309 public void WriteField(FieldDescriptor fd,
string strName,
double[] rgVal)
1311 FieldDescriptor.TYPE type;
1312 int nFieldId = getFieldId(fd, strName, out type);
1315 if (type != FieldDescriptor.TYPE.DOUBLE &&
1316 type != FieldDescriptor.TYPE.FLOAT)
1317 throw new Exception(
"Invalid type '" + type.ToString() +
"'");
1319 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1320 m_strm.WriteUInt32(tag);
1322 ProtoBufWriter pbWriter =
new ProtoBufWriter(m_log);
1323 byte[] rg = pbWriter.WriteArray(type, rgVal);
1324 m_strm.WriteBytes(ByteString.CopyFrom(rg));
1327 public byte[] WriteArray(FieldDescriptor.TYPE type,
double[] rgVal)
1329 for (
int i = 0; i < rgVal.Length; i++)
1331 m_strm.WriteDouble(rgVal[i]);
1337 public void WriteField(FieldDescriptor fd,
string strName,
float[] rgVal)
1339 FieldDescriptor.TYPE type;
1340 int nFieldId = getFieldId(fd, strName, out type);
1343 if (type != FieldDescriptor.TYPE.DOUBLE &&
1344 type != FieldDescriptor.TYPE.FLOAT)
1345 throw new Exception(
"Invalid type '" + type.ToString() +
"'");
1347 tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1348 m_strm.WriteUInt32(tag);
1350 ProtoBufWriter pbWriter =
new ProtoBufWriter(m_log);
1351 byte[] rg = pbWriter.WriteArray(type, rgVal);
1352 m_strm.WriteBytes(ByteString.CopyFrom(rg));
1355 public byte[] WriteArray(FieldDescriptor.TYPE type,
float[] rgVal)
1357 for (
int i = 0; i < rgVal.Length; i++)
1359 m_strm.WriteFloat(rgVal[i]);
1366 class ProtoBufReader : IDisposable
1368 CodedInputStream m_strm =
null;
1369 bool m_bOwnStream =
true;
1371 public ProtoBufReader(
byte[] rg)
1373 m_strm =
new CodedInputStream(rg);
1376 public ProtoBufReader(CodedInputStream strm)
1379 m_bOwnStream =
false;
1382 public void Dispose()
1384 if (m_strm !=
null && m_bOwnStream)
1391 public ProtoBufFieldCollection ReadFields(FieldDescriptor fd,
bool bFirstRead)
1393 ProtoBufFieldCollection fields =
new common.ProtoBufFieldCollection();
1394 ProtoBufField field = ReadField(fd, bFirstRead);
1396 while (field !=
null)
1398 if (field.Length > 0 || (field.Type != ProtoBufField.TYPE.BYTES && field.Type != ProtoBufField.TYPE.STRING))
1401 field = ReadField(fd, bFirstRead);
1407 public ProtoBufField ReadField(FieldDescriptor fd,
bool bFirstRead)
1412 uint tag = m_strm.ReadUInt32();
1413 int nField = WireFormat.GetTagFieldNumber(tag);
1418 int nWireFmt = (int)WireFormat.GetTagWireType(tag);
1419 if (bFirstRead && nWireFmt != (
int)WireFormat.WireType.LengthDelimited)
1423 fd = fd.FindFirstChild(nField);
1425 ProtoBufField field =
new ProtoBufField(m_strm, nField, fd);
1426 if (!field.Load((WireFormat.WireType)nWireFmt))
1433 class ProtoBufFieldCollection : IEnumerable<ProtoBufField>
1435 List<ProtoBufField> m_rgFields =
new List<ProtoBufField>();
1437 public ProtoBufFieldCollection()
1443 get {
return m_rgFields.Count; }
1446 public ProtoBufField
this[
int nIdx]
1448 get {
return m_rgFields[nIdx]; }
1451 public void SetTag(
string str)
1453 foreach (ProtoBufField field
in m_rgFields)
1459 public void SetLegacy(
bool bLegacy)
1461 foreach (ProtoBufField field
in m_rgFields)
1463 field.Legacy = bLegacy;
1467 public void Add(ProtoBufField p)
1472 public void AddRange(ProtoBufFieldCollection col)
1474 m_rgFields.AddRange(col.m_rgFields);
1477 public ProtoBufFieldCollection FindAllChildren(
string strName)
1479 ProtoBufFieldCollection col =
new common.ProtoBufFieldCollection();
1481 foreach (ProtoBufField field
in m_rgFields)
1483 if (field.FieldDesc !=
null && field.FieldDesc.Name == strName)
1490 public ProtoBufField FindFirstChild(
string strName)
1492 foreach (ProtoBufField field
in m_rgFields)
1494 if (field.FieldDesc !=
null && field.FieldDesc.Name == strName)
1501 public IEnumerator<ProtoBufField> GetEnumerator()
1503 return m_rgFields.GetEnumerator();
1506 IEnumerator IEnumerable.GetEnumerator()
1508 return m_rgFields.GetEnumerator();
1515 FieldDescriptor m_fd;
1516 CodedInputStream m_strm;
1523 int[] m_rgnVal =
null;
1524 long[] m_rglVal =
null;
1525 float[] m_rgfVal =
null;
1526 double[] m_rgdfVal =
null;
1527 string m_strTag =
null;
1528 bool m_bLegacy =
false;
1530 TYPE m_type = TYPE.BYTES;
1531 ProtoBufFieldCollection m_col =
new ProtoBufFieldCollection();
1533 WireFormat.WireType m_wireType;
1548 public ProtoBufField(CodedInputStream strm,
int nField, FieldDescriptor fd)
1555 public bool Load(WireFormat.WireType wireType)
1557 m_wireType = wireType;
1561 case WireFormat.WireType.Varint:
1562 m_lVal = m_strm.ReadInt32();
1563 m_nVal = (int)m_lVal;
1564 m_type = TYPE.BIT32;
1567 case WireFormat.WireType.LengthDelimited:
1568 ByteString bs = m_strm.ReadBytes();
1571 m_rgBytes = bs.ToByteArray();
1573 if (m_fd ==
null || m_fd.Type == FieldDescriptor.TYPE.STRING)
1574 m_strVal = getString(m_rgBytes, out m_type);
1576 if (m_type == TYPE.BYTES && m_fd !=
null && m_fd.Type != FieldDescriptor.TYPE.FIELDDESC)
1580 case FieldDescriptor.TYPE.INT:
1581 case FieldDescriptor.TYPE.UINT:
1582 m_rgnVal = readIntArray(m_rgBytes, m_fd.Type);
1583 m_type = TYPE.INT_ARRAY;
1586 case FieldDescriptor.TYPE.LONG:
1587 case FieldDescriptor.TYPE.ULONG:
1588 m_rglVal = readLongArray(m_rgBytes, m_fd.Type);
1589 m_type = TYPE.LONG_ARRAY;
1592 case FieldDescriptor.TYPE.FLOAT:
1593 m_rgfVal = readFloatArray(m_rgBytes);
1594 m_type = TYPE.FLOAT_ARRAY;
1597 case FieldDescriptor.TYPE.DOUBLE:
1598 m_rgdfVal = readDoubleArray(m_rgBytes);
1599 m_type = TYPE.DOUBLE_ARRAY;
1606 case WireFormat.WireType.Fixed32:
1607 float fVal = m_strm.ReadFloat();
1609 m_fVal = (float)fVal;
1610 m_type = TYPE.BIT32;
1613 case WireFormat.WireType.Fixed64:
1614 double dfVal = m_strm.ReadDouble();
1615 m_lVal = (long)dfVal;
1616 m_dfVal = (double)dfVal;
1617 m_type = TYPE.BIT64;
1627 private int[] readIntArray(
byte[] rgBytes, FieldDescriptor.TYPE type)
1629 CodedInputStream strm =
new CodedInputStream(rgBytes);
1630 List<int> rg =
new List<int>();
1632 while (!strm.IsAtEnd)
1634 int lVal = (type == FieldDescriptor.TYPE.INT) ? (
int)strm.ReadInt32() : (int)strm.ReadUInt32();
1638 return rg.ToArray();
1641 private long[] readLongArray(
byte[] rgBytes, FieldDescriptor.TYPE type)
1643 CodedInputStream strm =
new CodedInputStream(rgBytes);
1644 List<long> rg =
new List<long>();
1646 while (!strm.IsAtEnd)
1648 long lVal = (type == FieldDescriptor.TYPE.LONG) ? (
long)strm.ReadInt64() : (long)strm.ReadUInt64();
1652 return rg.ToArray();
1655 private float[] readFloatArray(
byte[] rgBytes)
1657 int nCount = rgBytes.Length /
sizeof(float);
1658 int nErr = rgBytes.Length %
sizeof(float);
1661 throw new Exception(
"Invalid " + m_fd.Type.ToString() +
" data - not aligned.");
1663 CodedInputStream strm =
new CodedInputStream(rgBytes);
1664 float[] rg =
new float[nCount];
1666 for (
int i = 0; i < nCount; i++)
1668 rg[i] = strm.ReadFloat();
1674 private double[] readDoubleArray(
byte[] rgBytes)
1676 int nCount = rgBytes.Length /
sizeof(double);
1677 int nErr = rgBytes.Length %
sizeof(double);
1680 throw new Exception(
"Invalid " + m_fd.Type.ToString() +
" data - not aligned.");
1682 CodedInputStream strm =
new CodedInputStream(rgBytes);
1683 double[] rg =
new double[nCount];
1685 for (
int i = 0; i < nCount; i++)
1687 rg[i] = strm.ReadDouble();
1693 public void LoadSubFields(
int nDepth = 0,
int nMaxDepth =
int.MaxValue, List<KeyValuePair<int, string>> rgIgnore =
null)
1695 ProtoBufFieldCollection col =
null;
1697 if (m_type == TYPE.BYTES)
1699 ProtoBufReader reader =
new common.ProtoBufReader(m_rgBytes);
1700 col = reader.ReadFields(m_fd,
false);
1702 m_type = TYPE.ARRAY;
1704 else if (m_type == TYPE.ARRAY)
1709 if (col !=
null && col.Count > 0)
1713 if (nDepth < nMaxDepth)
1715 if (rgIgnore !=
null)
1717 foreach (KeyValuePair<int, string> kv
in rgIgnore)
1719 if (kv.Key <= m_col.Count &&
1720 m_col[kv.Key].Type == TYPE.STRING &&
1721 m_col[kv.Key].StringValue == kv.Value)
1726 foreach (ProtoBufField field
in m_col)
1728 field.LoadSubFields(nDepth, nMaxDepth);
1734 private string getString(
byte[] rg, out TYPE type)
1736 string strOut =
null;
1740 for (
int i = 0; i < rg.Length; i++)
1742 char ch = (char)rg[i];
1743 if (
char.IsControl(ch))
1754 private byte[] getBytes(
string str, out TYPE type)
1756 byte[] rg =
new byte[str.Length];
1760 for (
int i = 0; i < str.Length; i++)
1762 rg[i] = (byte)str[i];
1764 if (
char.IsControl(str[i]))
1773 get {
return m_bLegacy; }
1774 set { m_bLegacy = value; }
1779 get {
return m_strTag; }
1780 set { m_strTag = value; }
1785 get {
return m_rgBytes; }
1790 get {
return (m_rgBytes ==
null) ? 0 : m_rgBytes.Length; }
1795 get {
return m_type; }
1798 public string StringValue
1800 get {
return m_strVal; }
1803 public long LongValue
1805 get {
return m_lVal; }
1808 public long[] LongValues
1810 get {
return m_rglVal; }
1815 get {
return m_nVal; }
1818 public int[] IntValues
1820 get {
return m_rgnVal; }
1823 public float FloatValue
1825 get {
return m_fVal; }
1828 public float[] FloatValues
1830 get {
return m_rgfVal; }
1833 public double DoubleValue
1835 get {
return m_dfVal; }
1838 public double[] DoubleValues
1840 get {
return m_rgdfVal; }
1843 public ProtoBufFieldCollection Array
1845 get {
return m_col; }
1850 get {
return m_nField; }
1853 public FieldDescriptor FieldDesc
1855 get {
return m_fd; }
1858 public override string ToString()
1860 string strName = (m_fd ==
null) ?
"NO FLDESC!" : m_fd.Name;
1861 string str = strName +
"(" + m_nField.ToString() +
")[" + m_wireType.ToString() +
"] " + m_type.ToString() +
": ";
1863 if (m_type == TYPE.STRING)
1864 return str + m_strVal;
1866 if (m_type == TYPE.BIT32)
1867 return str + m_nVal.ToString() +
" (float = " + m_fVal.ToString() +
")";
1869 if (m_type == TYPE.BIT64)
1870 return str + m_lVal.ToString() +
" (double = " + m_dfVal.ToString() +
")";
1872 if (m_type == TYPE.ARRAY)
1873 return str +
" Count = " + m_col.Count.ToString();
1875 return str +
" bytes = " + ((m_rgBytes ==
null) ?
"0" : m_rgBytes.Length.ToString());
1879#pragma warning disable 1591
1881 public class FieldDescriptor
1883 List<FieldDescriptor> m_rgChildren =
new List<FieldDescriptor>();
1885 string m_strName =
"";
1886 TYPE m_type = TYPE.UNKNOWN;
1902 public FieldDescriptor(
int nField,
string strName, TYPE type, List<FieldDescriptor> rgChildren =
null)
1904 m_nFieldID = nField;
1905 m_strName = strName;
1908 if (rgChildren !=
null)
1909 m_rgChildren = rgChildren;
1912 public FieldDescriptor FindFirstChild(
int nFieldId)
1914 foreach (FieldDescriptor fd
in m_rgChildren)
1916 if (fd.FieldId == nFieldId)
1923 public FieldDescriptor FindFirstChild(
string strName)
1925 foreach (FieldDescriptor fd
in m_rgChildren)
1927 if (fd.Name == strName)
1936 get {
return m_nFieldID; }
1941 get {
return m_strName; }
1946 get {
return m_type; }
1949 public List<FieldDescriptor> Children
1951 get {
return m_rgChildren; }
1954 public override string ToString()
1956 return m_strName +
" (" + m_nFieldID.ToString() +
") - " + m_type.ToString();
1959 public static FieldDescriptor CreateSolverStateFieldDesc()
1961 return new common.FieldDescriptor(0,
"SolverState", TYPE.FIELDDESC, loadSolverState());
1964 public static FieldDescriptor CreateNetworkParamFieldDesc()
1966 return new common.FieldDescriptor(0,
"NetParameter", TYPE.FIELDDESC, loadNetParameter());
1969 public static FieldDescriptor CreateBlobProtoDesc(
int nFieldId)
1971 return new FieldDescriptor(nFieldId,
"BlobProto", TYPE.FIELDDESC, loadBlobProto());
1974 private static List<FieldDescriptor> loadSolverState()
1976 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
1977 rgF.Add(
new FieldDescriptor(1,
"iter", TYPE.INT));
1978 rgF.Add(
new FieldDescriptor(3,
"history", TYPE.FIELDDESC, loadBlobProto()));
1979 rgF.Add(
new FieldDescriptor(4,
"current_step", TYPE.INT));
1983 private static List<FieldDescriptor> loadNetParameter()
1985 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
1986 rgF.Add(
new FieldDescriptor(1,
"name", TYPE.STRING));
1987 rgF.Add(
new FieldDescriptor(100,
"LayerParameter", TYPE.FIELDDESC, loadLayerParameter()));
1988 rgF.Add(
new FieldDescriptor(2,
"V1LayerParameter", TYPE.FIELDDESC, loadV1LayerParameter()));
1992 private static List<FieldDescriptor> loadLayerParameter()
1994 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
1995 rgF.Add(
new FieldDescriptor(1,
"name", FieldDescriptor.TYPE.STRING));
1996 rgF.Add(
new FieldDescriptor(2,
"type", FieldDescriptor.TYPE.STRING));
1997 rgF.Add(
new FieldDescriptor(3,
"bottom", FieldDescriptor.TYPE.STRING));
1998 rgF.Add(
new FieldDescriptor(4,
"top", FieldDescriptor.TYPE.STRING));
1999 rgF.Add(
new FieldDescriptor(10,
"phase", FieldDescriptor.TYPE.INT));
2000 rgF.Add(
new FieldDescriptor(5,
"loss_weight", FieldDescriptor.TYPE.FLOAT));
2001 rgF.Add(
new FieldDescriptor(6,
"param", FieldDescriptor.TYPE.FIELDDESC, loadParamSpec()));
2002 rgF.Add(
new FieldDescriptor(7,
"blobs", FieldDescriptor.TYPE.FIELDDESC, loadBlobProto()));
2003 rgF.Add(
new FieldDescriptor(11,
"prop_down", FieldDescriptor.TYPE.BOOL));
2004 rgF.Add(
new FieldDescriptor(8,
"include", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
2005 rgF.Add(
new FieldDescriptor(9,
"exclude", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
2006 rgF.Add(
new FieldDescriptor(100,
LayerParameter.
LayerType.TRANSFORM.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2007 rgF.Add(
new FieldDescriptor(101,
LayerParameter.
LayerType.LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2009 rgF.Add(
new FieldDescriptor(102,
LayerParameter.
LayerType.ACCURACY.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2010 rgF.Add(
new FieldDescriptor(103,
LayerParameter.
LayerType.ARGMAX.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2011 rgF.Add(
new FieldDescriptor(139,
LayerParameter.
LayerType.BATCHNORM.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2012 rgF.Add(
new FieldDescriptor(141,
LayerParameter.
LayerType.BIAS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2013 rgF.Add(
new FieldDescriptor(104,
LayerParameter.
LayerType.CONCAT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2014 rgF.Add(
new FieldDescriptor(105,
LayerParameter.
LayerType.CONTRASTIVE_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2015 rgF.Add(
new FieldDescriptor(106,
LayerParameter.
LayerType.CONVOLUTION.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC, loadConvolutionParam()));
2016 rgF.Add(
new FieldDescriptor(144,
LayerParameter.
LayerType.CROP.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2017 rgF.Add(
new FieldDescriptor(107,
LayerParameter.
LayerType.DATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2018 rgF.Add(
new FieldDescriptor(108,
LayerParameter.
LayerType.DROPOUT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2019 rgF.Add(
new FieldDescriptor(109,
LayerParameter.
LayerType.DUMMYDATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2020 rgF.Add(
new FieldDescriptor(110,
LayerParameter.
LayerType.ELTWISE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2021 rgF.Add(
new FieldDescriptor(140,
LayerParameter.
LayerType.ELU.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2022 rgF.Add(
new FieldDescriptor(137,
LayerParameter.
LayerType.EMBED.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2023 rgF.Add(
new FieldDescriptor(111,
LayerParameter.
LayerType.EXP.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2024 rgF.Add(
new FieldDescriptor(135,
LayerParameter.
LayerType.FLATTEN.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2025 rgF.Add(
new FieldDescriptor(112,
"hdf5_input_param", FieldDescriptor.TYPE.FIELDDESC));
2026 rgF.Add(
new FieldDescriptor(113,
"hdf5_output_param", FieldDescriptor.TYPE.FIELDDESC));
2027 rgF.Add(
new FieldDescriptor(114,
LayerParameter.
LayerType.HINGE_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2028 rgF.Add(
new FieldDescriptor(115,
"image_data_param", FieldDescriptor.TYPE.FIELDDESC));
2029 rgF.Add(
new FieldDescriptor(116,
LayerParameter.
LayerType.INFOGAIN_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2030 rgF.Add(
new FieldDescriptor(117,
LayerParameter.
LayerType.INNERPRODUCT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2031 rgF.Add(
new FieldDescriptor(143,
LayerParameter.
LayerType.INPUT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2032 rgF.Add(
new FieldDescriptor(134,
LayerParameter.
LayerType.LOG.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2033 rgF.Add(
new FieldDescriptor(118,
LayerParameter.
LayerType.LRN.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2034 rgF.Add(
new FieldDescriptor(119,
LayerParameter.
LayerType.MEMORYDATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2035 rgF.Add(
new FieldDescriptor(120,
LayerParameter.
LayerType.MVN.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2036 rgF.Add(
new FieldDescriptor(121,
LayerParameter.
LayerType.PARAMETER.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2037 rgF.Add(
new FieldDescriptor(121,
LayerParameter.
LayerType.POOLING.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2038 rgF.Add(
new FieldDescriptor(122,
LayerParameter.
LayerType.POWER.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2039 rgF.Add(
new FieldDescriptor(131,
LayerParameter.
LayerType.PRELU.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2040 rgF.Add(
new FieldDescriptor(130,
"python_param", FieldDescriptor.TYPE.FIELDDESC));
2041 rgF.Add(
new FieldDescriptor(146,
LayerParameter.
LayerType.RECURRENT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2042 rgF.Add(
new FieldDescriptor(136,
LayerParameter.
LayerType.REDUCTION.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2043 rgF.Add(
new FieldDescriptor(123,
LayerParameter.
LayerType.RELU.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2044 rgF.Add(
new FieldDescriptor(133,
LayerParameter.
LayerType.RESHAPE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2045 rgF.Add(
new FieldDescriptor(142,
LayerParameter.
LayerType.SCALE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2046 rgF.Add(
new FieldDescriptor(142,
LayerParameter.
LayerType.SCALAR.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2047 rgF.Add(
new FieldDescriptor(124,
LayerParameter.
LayerType.SIGMOID.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2048 rgF.Add(
new FieldDescriptor(125,
LayerParameter.
LayerType.SOFTMAX.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2049 rgF.Add(
new FieldDescriptor(132,
LayerParameter.
LayerType.SPP.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2050 rgF.Add(
new FieldDescriptor(126,
LayerParameter.
LayerType.SLICE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2051 rgF.Add(
new FieldDescriptor(127,
LayerParameter.
LayerType.TANH.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2052 rgF.Add(
new FieldDescriptor(128,
LayerParameter.
LayerType.THRESHOLD.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2053 rgF.Add(
new FieldDescriptor(138,
LayerParameter.
LayerType.TILE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2054 rgF.Add(
new FieldDescriptor(129,
"window_data_param", FieldDescriptor.TYPE.FIELDDESC));
2059 private static List<FieldDescriptor> loadV1LayerParameter()
2061 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2062 rgF.Add(
new FieldDescriptor(2,
"bottom", FieldDescriptor.TYPE.STRING));
2063 rgF.Add(
new FieldDescriptor(3,
"top", FieldDescriptor.TYPE.STRING));
2064 rgF.Add(
new FieldDescriptor(4,
"name", FieldDescriptor.TYPE.STRING));
2065 rgF.Add(
new FieldDescriptor(32,
"include", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
2066 rgF.Add(
new FieldDescriptor(33,
"exclude", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
2067 rgF.Add(
new FieldDescriptor(5,
"type", FieldDescriptor.TYPE.INT));
2068 rgF.Add(
new FieldDescriptor(6,
"blobs", FieldDescriptor.TYPE.FIELDDESC, loadBlobProto()));
2069 rgF.Add(
new FieldDescriptor(1001,
"param", FieldDescriptor.TYPE.FIELDDESC, loadParamSpec()));
2070 rgF.Add(
new FieldDescriptor(1002,
"blob_share_mode", FieldDescriptor.TYPE.INT));
2071 rgF.Add(
new FieldDescriptor(7,
"blobs_lr", FieldDescriptor.TYPE.FLOAT));
2072 rgF.Add(
new FieldDescriptor(8,
"weight_decay", FieldDescriptor.TYPE.FLOAT));
2073 rgF.Add(
new FieldDescriptor(35,
"loss_weight", FieldDescriptor.TYPE.FLOAT));
2075 rgF.Add(
new FieldDescriptor(27,
LayerParameter.
LayerType.ACCURACY.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2076 rgF.Add(
new FieldDescriptor(23,
LayerParameter.
LayerType.ARGMAX.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2077 rgF.Add(
new FieldDescriptor(9,
LayerParameter.
LayerType.CONCAT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2078 rgF.Add(
new FieldDescriptor(40,
LayerParameter.
LayerType.CONTRASTIVE_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2079 rgF.Add(
new FieldDescriptor(10,
LayerParameter.
LayerType.CONVOLUTION.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC, loadConvolutionParam()));
2080 rgF.Add(
new FieldDescriptor(11,
LayerParameter.
LayerType.DATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2081 rgF.Add(
new FieldDescriptor(12,
LayerParameter.
LayerType.DROPOUT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2082 rgF.Add(
new FieldDescriptor(26,
LayerParameter.
LayerType.DUMMYDATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2083 rgF.Add(
new FieldDescriptor(24,
LayerParameter.
LayerType.ELTWISE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2084 rgF.Add(
new FieldDescriptor(41,
LayerParameter.
LayerType.EXP.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2085 rgF.Add(
new FieldDescriptor(13,
"hdf5_input_param", FieldDescriptor.TYPE.FIELDDESC));
2086 rgF.Add(
new FieldDescriptor(14,
"hdf5_output_param", FieldDescriptor.TYPE.FIELDDESC));
2087 rgF.Add(
new FieldDescriptor(29,
LayerParameter.
LayerType.HINGE_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2088 rgF.Add(
new FieldDescriptor(15,
"image_data_param", FieldDescriptor.TYPE.FIELDDESC));
2089 rgF.Add(
new FieldDescriptor(16,
LayerParameter.
LayerType.INFOGAIN_LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2090 rgF.Add(
new FieldDescriptor(17,
LayerParameter.
LayerType.INNERPRODUCT.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2091 rgF.Add(
new FieldDescriptor(18,
LayerParameter.
LayerType.LRN.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2092 rgF.Add(
new FieldDescriptor(22,
LayerParameter.
LayerType.MEMORYDATA.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2093 rgF.Add(
new FieldDescriptor(34,
LayerParameter.
LayerType.MVN.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2094 rgF.Add(
new FieldDescriptor(19,
LayerParameter.
LayerType.POOLING.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2095 rgF.Add(
new FieldDescriptor(21,
LayerParameter.
LayerType.POWER.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2096 rgF.Add(
new FieldDescriptor(30,
LayerParameter.
LayerType.RELU.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2097 rgF.Add(
new FieldDescriptor(38,
LayerParameter.
LayerType.SIGMOID.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2098 rgF.Add(
new FieldDescriptor(39,
LayerParameter.
LayerType.SOFTMAX.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2099 rgF.Add(
new FieldDescriptor(31,
LayerParameter.
LayerType.SLICE.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2100 rgF.Add(
new FieldDescriptor(37,
LayerParameter.
LayerType.TANH.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2101 rgF.Add(
new FieldDescriptor(25,
LayerParameter.
LayerType.THRESHOLD.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2102 rgF.Add(
new FieldDescriptor(20,
"window_data_param", FieldDescriptor.TYPE.FIELDDESC));
2103 rgF.Add(
new FieldDescriptor(36,
LayerParameter.
LayerType.TRANSFORM.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2104 rgF.Add(
new FieldDescriptor(42,
LayerParameter.
LayerType.LOSS.ToString() +
"_param", FieldDescriptor.TYPE.FIELDDESC));
2109 private static List<FieldDescriptor> loadParamSpec()
2111 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2112 rgF.Add(
new FieldDescriptor(1,
"name", FieldDescriptor.TYPE.STRING));
2113 rgF.Add(
new FieldDescriptor(2,
"share_mode", FieldDescriptor.TYPE.INT));
2114 rgF.Add(
new FieldDescriptor(3,
"lr_mult", FieldDescriptor.TYPE.FLOAT));
2115 rgF.Add(
new FieldDescriptor(4,
"decay_mult", FieldDescriptor.TYPE.FLOAT));
2119 private static List<FieldDescriptor> loadBlobShape()
2121 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2122 rgF.Add(
new FieldDescriptor(1,
"dim", FieldDescriptor.TYPE.LONG));
2126 private static List<FieldDescriptor> loadBlobProto()
2128 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2129 rgF.Add(
new FieldDescriptor(7,
"shape", FieldDescriptor.TYPE.FIELDDESC, loadBlobShape()));
2130 rgF.Add(
new FieldDescriptor(5,
"data", FieldDescriptor.TYPE.FLOAT));
2131 rgF.Add(
new FieldDescriptor(6,
"diff", FieldDescriptor.TYPE.FLOAT));
2132 rgF.Add(
new FieldDescriptor(8,
"double_data", FieldDescriptor.TYPE.DOUBLE));
2133 rgF.Add(
new FieldDescriptor(9,
"double_diff", FieldDescriptor.TYPE.DOUBLE));
2134 rgF.Add(
new FieldDescriptor(1,
"num", FieldDescriptor.TYPE.INT));
2135 rgF.Add(
new FieldDescriptor(2,
"channels", FieldDescriptor.TYPE.INT));
2136 rgF.Add(
new FieldDescriptor(3,
"height", FieldDescriptor.TYPE.INT));
2137 rgF.Add(
new FieldDescriptor(4,
"width", FieldDescriptor.TYPE.INT));
2141 private static List<FieldDescriptor> loadNetStateRule()
2143 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2144 rgF.Add(
new FieldDescriptor(1,
"phase", FieldDescriptor.TYPE.INT));
2145 rgF.Add(
new FieldDescriptor(2,
"min_level", FieldDescriptor.TYPE.INT));
2146 rgF.Add(
new FieldDescriptor(3,
"max_level", FieldDescriptor.TYPE.INT));
2147 rgF.Add(
new FieldDescriptor(4,
"stage", FieldDescriptor.TYPE.STRING));
2148 rgF.Add(
new FieldDescriptor(5,
"not_stage", FieldDescriptor.TYPE.STRING));
2152 private static List<FieldDescriptor> loadFillerParam()
2154 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2155 rgF.Add(
new FieldDescriptor(1,
"type", FieldDescriptor.TYPE.STRING));
2156 rgF.Add(
new FieldDescriptor(2,
"value", FieldDescriptor.TYPE.FLOAT));
2157 rgF.Add(
new FieldDescriptor(3,
"min", FieldDescriptor.TYPE.FLOAT));
2158 rgF.Add(
new FieldDescriptor(4,
"max", FieldDescriptor.TYPE.FLOAT));
2159 rgF.Add(
new FieldDescriptor(5,
"mean", FieldDescriptor.TYPE.FLOAT));
2160 rgF.Add(
new FieldDescriptor(6,
"std", FieldDescriptor.TYPE.FLOAT));
2161 rgF.Add(
new FieldDescriptor(7,
"sparse", FieldDescriptor.TYPE.INT));
2162 rgF.Add(
new FieldDescriptor(8,
"variance_norm", FieldDescriptor.TYPE.INT));
2166 private static List<FieldDescriptor> loadConvolutionParam()
2168 List<FieldDescriptor> rgF =
new List<common.FieldDescriptor>();
2169 rgF.Add(
new FieldDescriptor(1,
"num_output", FieldDescriptor.TYPE.UINT));
2170 rgF.Add(
new FieldDescriptor(2,
"bias_term", FieldDescriptor.TYPE.BOOL));
2171 rgF.Add(
new FieldDescriptor(3,
"pad", FieldDescriptor.TYPE.UINT));
2172 rgF.Add(
new FieldDescriptor(4,
"kernel_size", FieldDescriptor.TYPE.UINT));
2173 rgF.Add(
new FieldDescriptor(6,
"stride", FieldDescriptor.TYPE.UINT));
2174 rgF.Add(
new FieldDescriptor(18,
"dilation", FieldDescriptor.TYPE.UINT));
2175 rgF.Add(
new FieldDescriptor(9,
"pad_h", FieldDescriptor.TYPE.UINT));
2176 rgF.Add(
new FieldDescriptor(10,
"pad_w", FieldDescriptor.TYPE.UINT));
2177 rgF.Add(
new FieldDescriptor(11,
"kernel_h", FieldDescriptor.TYPE.UINT));
2178 rgF.Add(
new FieldDescriptor(12,
"kernel_w", FieldDescriptor.TYPE.UINT));
2179 rgF.Add(
new FieldDescriptor(13,
"stride_h", FieldDescriptor.TYPE.UINT));
2180 rgF.Add(
new FieldDescriptor(14,
"stride_w", FieldDescriptor.TYPE.UINT));
2181 rgF.Add(
new FieldDescriptor(5,
"group", FieldDescriptor.TYPE.UINT));
2182 rgF.Add(
new FieldDescriptor(7,
"weight_filler", FieldDescriptor.TYPE.FIELDDESC, loadFillerParam()));
2183 rgF.Add(
new FieldDescriptor(8,
"bias_filler", FieldDescriptor.TYPE.FIELDDESC, loadFillerParam()));
2184 rgF.Add(
new FieldDescriptor(15,
"engine", FieldDescriptor.TYPE.INT));
2185 rgF.Add(
new FieldDescriptor(16,
"axis", FieldDescriptor.TYPE.INT));
2186 rgF.Add(
new FieldDescriptor(17,
"force_nd", FieldDescriptor.TYPE.BOOL));
2191#pragma warning restore 1591
The Log class provides general output in text form.
The Utility class provides general utility functions.
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
The BlobCollection contains a list of Blobs.
The Blob is the main holder of data that moves through the Layers of the Net.
object Tag
Returns a user defined object associated with the Blob.
bool reshape_when_sharing
When true, this Blob is reshaped to the source when sharing the source data (default = false).
string Name
Get/set the name of the Blob.
The PersistCaffe class is used to load and save weight files in the .caffemodel format.
PersistCaffe(Log log, bool bFailOnFirstTry)
The PersistCaffe constructor.
BlobProto LoadBlobProto(string strFile, int nFieldId)
The LoadBlobProto function loads a BlobProto from a file.
SolverState LoadSolverState(byte[] rgState, SolverParameter.SolverType type=SolverParameter.SolverType.SGD)
Load the solver state from a byte array.
BlobCollection< T > LoadWeights(byte[] rgWeights, List< string > rgExpectedShapes, BlobCollection< T > colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into a BlobCollection
WeightInfo< T > LoadWeightInfo(BlobCollection< T > colBlobs)
Returns the weight information describing the weights contained within the Blob collection.
string MyCaffeTag
This tag is used to mark the ending section of each weight file with 'MyCaffe'-specific information...
BlobProto LoadBlobProto(byte[] rg, int nFieldId)
The LoadBlobProto function loads a BlobProto from a proto buffer.
byte[] SaveWeights(BlobCollection< T > colBlobs, bool bSaveDiffs=false)
Save the weights to a byte array.
bool IsMyCaffe(byte[] rgWeights, out string strVer)
This method returns whether or not the weights have been marked as 'mycaffe.ai'.
byte[] SaveSolverState(SolverState state, SolverParameter.SolverType type=SolverParameter.SolverType.SGD)
Save the solver state to a byte array.
WeightInfo< T > LoadWeightInfo(byte[] rgWeights)
Returns the weight information describing the weights contained within the weight bytes.
The WeightInfo class describes the weights of a given weight set including the blob names and sizes of...
WeightInfo()
The constructor.
The BlobProto contains the description of a blob.
List< float > data
Get/set the data as a List of float.
BlobShape shape
Specifies the shape of the Blob.
List< double > double_data
Get/set the data as a List of double.
BlobProto()
Constructor for the BlobProto.
List< int > dim
The blob shape dimensions.
Specifies the base parameter for all layers.
LayerType
Specifies the layer type.
The SolverParameter is a parameter for the solver, specifying the train and test networks.
SolverType
Defines the type of solver.
The SolverState specifies the state of a given solver.
int end
Specifies the end used by L-BFGS.
BlobProto gradients
Gradients used with L-BFGS state.
int iter
The current iteration.
List< BlobProto > history
The history for SGD solvers.
int start
Specifies the start used by L-BFGS.
int current_step
The current step for learning rate.
List< BlobProto > s_history
S history used with L-BFGS state.
BlobProto direction
Direction used with L-BFGS state.
The IXPersist interface is used by the CaffeControl to load and save weights.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
The MyCaffe.common namespace contains common MyCaffe classes.
@ FLOAT
Specifies the single type.
@ DOUBLE
Specifies the double type.
BLOB_TYPE
Defines the type of data held by a given Blob.
@ UNKNOWN
The blob is an unknown type.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...