2using System.Collections.Generic;
14 [TypeConverter(typeof(ExpandableObjectConverter))]
19 double m_dfAttnDropout = 0.1;
20 double m_dfResidDropout = 0.1;
21 uint m_nBlockSize = 128;
25 bool m_bEnableLayerNormCudaImplementation =
false;
35 CAUSAL_SELF_ATTENTION = 0,
79 get {
return m_bEnableLayerNormCudaImplementation; }
80 set { m_bEnableLayerNormCudaImplementation = value; }
88 get {
return m_activation; }
89 set { m_activation = value; }
97 get {
return m_type; }
98 set { m_type = value; }
104 [Description(
"Specifies number of layers (transformer blocks) used.")]
107 get {
return m_nLayers; }
108 set { m_nLayers = value; }
114 [Description(
"Specifies number of heads used.")]
117 get {
return m_nHeads; }
118 set { m_nHeads = value; }
126 get {
return m_nEmbed; }
127 set { m_nEmbed = value; }
135 get {
return m_nBlockSize; }
136 set { m_nBlockSize = value; }
144 get {
return m_dfAttnDropout; }
145 set { m_dfAttnDropout = value; }
153 get {
return m_dfResidDropout; }
154 set { m_dfResidDropout = value; }
158 public override object Load(
System.IO.BinaryReader br,
bool bNewInstance =
true)
202 rgChildren.
Add(
"layers",
layers.ToString());
203 rgChildren.
Add(
"heads",
heads.ToString());
204 rgChildren.
Add(
"embed",
embed.ToString());
212 return new RawProto(strName,
"", rgChildren);
225 if ((strVal = rp.
FindValue(
"layers")) !=
null)
226 p.
layers = uint.Parse(strVal);
228 if ((strVal = rp.
FindValue(
"heads")) !=
null)
229 p.
heads = uint.Parse(strVal);
231 if ((strVal = rp.
FindValue(
"embed")) !=
null)
232 p.
embed = uint.Parse(strVal);
234 if ((strVal = rp.
FindValue(
"block_size")) !=
null)
237 if ((strVal = rp.
FindValue(
"attn_dropout")) !=
null)
240 if ((strVal = rp.
FindValue(
"resid_dropout")) !=
null)
243 if ((strVal = rp.
FindValue(
"activation")) !=
null)
247 else if (strVal ==
ACTIVATION.GELU_BERT.ToString())
253 if ((strVal = rp.
FindValue(
"block_type")) !=
null)
255 if (strVal ==
BLOCK_TYPE.CAUSAL_SELF_ATTENTION.ToString())
257 else if (strVal ==
BLOCK_TYPE.ENCODER.ToString())
259 else if (strVal ==
BLOCK_TYPE.DECODER.ToString())
263 if ((strVal = rp.
FindValue(
"enable_ln_cuda_impl")) !=
null)
The RawProtoCollection class is a list of RawProto objects.
void Add(RawProto p)
Adds a RawProto to the collection.
The RawProto class is used to parse and output Google prototxt file data.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
string FindValue(string strName)
Searches for the value of a node within this node's children.
The LayerParameterBase is the base class for all other layer-specific parameters.
The MyCaffe.basecode namespace contains all generic types used throughout MyCaffe.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-source project.