using System;
using System.Collections.Generic;
using System.Threading.Tasks;
/// <summary>
/// Generates PyTorch code for a Caffe INNERPRODUCT (fully connected) layer,
/// emitted as an nn.Linear module.
/// </summary>
class InnerProductLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public InnerProductLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
        // A fully-connected layer collapses the spatial dimensions to 1x1.
        // NOTE(review): the statement setting Shape[1] = num_output appears to
        // have been lost in extraction; restored here — confirm against the
        // original source.
        m_outputs[0].Shape[1] = (int)layer.inner_product_param.num_output;
        m_outputs[0].Shape[2] = 1;
        m_outputs[0].Shape[3] = 1;
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        int nInFeatures = m_inputs[0].getCount(m_layer.inner_product_param.axis);
        int nOutFeatures = (int)m_layer.inner_product_param.num_output;
        string strCode = "";

        if (gen == GENERATE.DEFINITION)
        {
            strCode += "        self." + m_layer.name + " = nn.Linear(in_features=" + nInFeatures.ToString() + ", out_features=" + nOutFeatures.ToString() + ", bias=" + m_layer.inner_product_param.bias_term.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += initWeights("", m_layer.name, m_layer.inner_product_param.bias_term, m_layer.inner_product_param.weight_filler, m_layer.inner_product_param.bias_filler);
        }
        else if (gen == GENERATE.FORWARD)
        {
            // Flatten all trailing dimensions before the linear layer.
            strCode += "        " + m_inputs.AsText + " = " + m_inputs.AsText + ".view(" + m_inputs.AsText + ".size(0), -1)" + Environment.NewLine;
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe LSTM layer, emitted as a custom
/// 'LSTM' wrapper module.
/// </summary>
class LSTMLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public LSTMLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (CLASSES, DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.CLASSES)
        {
            // NOTE(review): the body that emits the custom 'LSTM' wrapper
            // class (referenced by the DEFINITION branch below) was lost in
            // extraction — restore it from the original source.
        }

        if (gen == GENERATE.DEFINITION)
        {
            // Caffe's recurrent layer uses num_output for both the input and
            // the hidden state size.
            string strStateSize = m_layer.recurrent_param.num_output.ToString();
            strCode += "        self." + m_layer.name + " = LSTM(input_dim=" + strStateSize + ", hidden_size=" + strStateSize + ", num_layers=" + m_layer.recurrent_param.num_layers.ToString() + ", dropout=" + m_layer.recurrent_param.dropout_ratio.ToString() + ", batch_first=" + m_layer.recurrent_param.batch_first.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // The recurrent layer always uses bias terms.
            strCode += initWeights("", m_layer.name, true, m_layer.recurrent_param.weight_filler, m_layer.recurrent_param.bias_filler);
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe CONCAT layer, emitted as a direct
/// torch.concat call (no module definition is required).
/// </summary>
class ConcatLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public ConcatLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
        m_outputs = m_inputs.Clone(1);
        m_outputs[0].Name = layer.top[0];

        // Sum the concatenation-axis dimension over all inputs.
        // NOTE(review): the loop body was lost in extraction; reconstructed
        // assuming Caffe's default concat axis of 1 (channels) — confirm
        // against the original source.
        int nCount = 0;
        for (int i = 0; i < m_inputs.Count; i++)
        {
            nCount += m_inputs[i].Shape[1];
        }
        m_outputs[0].Shape[1] = nCount;
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.DEFINITION)
        {
            // Concat needs no module; the definition is emitted as a comment.
            strCode += "#        self." + m_layer.name + " = Concat(" + m_inputs.AsText + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // Concat has no learnable weights.
        }
        else if (gen == GENERATE.FORWARD)
        {
            // BUGFIX: the concat dimension was hard-coded to 0; use the axis
            // specified by the layer's concat_param so channel (or other
            // non-batch) concatenations generate correctly.
            strCode += "        " + m_outputs.AsText + " = torch.concat(" + m_inputs.AsText + ", dim=" + m_layer.concat_param.axis.ToString() + ")" + Environment.NewLine;
        }

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe SPLIT layer, emitted as a custom
/// 'Split' nn.Module that clones its input into two detached tensors.
/// </summary>
class SplitLayerInfo : LayerInfo
{
    // Ensures the 'Split' helper class is emitted only once per file.
    static int m_nGenerateCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public SplitLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
        // A split produces two outputs, each with the input's shape.
        m_outputs = m_inputs.Clone();
        m_outputs[0].Name = layer.top[0];
        m_outputs.Add(layer.top[1], m_inputs[0].Shape);
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (CLASSES, DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.CLASSES)
        {
            strCode += generateSplitClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            // NOTE(review): the definition is emitted commented-out, yet the
            // FORWARD branch calls 'self.<name>' — verify against the original
            // that the generated Python defines the attribute elsewhere.
            strCode += "#        self." + m_layer.name + " = Split(" + m_inputs.AsText + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // Split has no learnable weights.
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generates the Python 'Split' helper class (emitted at most once).
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <returns>The generated class code is returned, or an empty string when already emitted.</returns>
    private string generateSplitClass(LayerParameter layer)
    {
        string strCode = "";

        // Only emit the helper class for the first Split layer encountered.
        if (m_nGenerateCount > 0)
            return strCode;

        strCode += "class Split(nn.Module):" + Environment.NewLine;
        strCode += "    def __init__(self):" + Environment.NewLine;
        strCode += "        super(Split, self).__init__()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "    def forward(self, x: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
        strCode += "        x1 = x.detach().clone()" + Environment.NewLine;
        strCode += "        x2 = x.detach().clone()" + Environment.NewLine;
        strCode += "        return x1, x2" + Environment.NewLine;
        strCode += Environment.NewLine;

        // NOTE(review): the increment of m_nGenerateCount was lost in
        // extraction; restored here so the class is generated only once.
        m_nGenerateCount++;

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe DROPOUT layer, emitted as an
/// nn.Dropout module.
/// </summary>
class DropoutLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public DropoutLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.DEFINITION)
        {
            strCode += "        self." + m_layer.name + " = nn.Dropout(p=" + m_layer.dropout_param.dropout_ratio.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // Dropout has no learnable weights.
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe SOFTMAX layer, emitted as an
/// nn.Softmax module.
/// </summary>
class SoftmaxLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public SoftmaxLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.DEFINITION)
        {
            // The Caffe softmax axis maps directly to the torch 'dim'.
            strCode += "        self." + m_layer.name + " = nn.Softmax(dim=" + m_layer.softmax_param.axis.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // Softmax has no learnable weights.
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }
}
/// <summary>
/// Generates PyTorch code for a Caffe ACCURACY layer, emitted as inline
/// running-accuracy bookkeeping rather than a module.
/// </summary>
class AccuracyLayerInfo : LayerInfo
{
    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public AccuracyLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
        // Register 'self.accuracy' as a value returned by the generated
        // forward pass (the '2' is presumably its return index — confirm
        // against the LayerInfo base class).
        m_rgstrReturnValues.Add("self.accuracy", 2);
    }

    /// <summary>
    /// Generates the Python code for the requested generation phase.
    /// </summary>
    /// <param name="gen">Specifies the generation phase (DEFINITION, INITWEIGHTS or FORWARD).</param>
    /// <returns>The generated code string is returned.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.DEFINITION)
        {
            // Accuracy needs no module; emit running-total state instead.
            strCode += "#        self." + m_layer.name + " = Accuracy(" + m_inputs.AsText + ")" + Environment.NewLine;
            strCode += "        self.accuracy_sum = 0" + Environment.NewLine;
            strCode += "        self.accuracy_count = 0" + Environment.NewLine;
            strCode += "        self.accuracy = 0" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            // Accuracy has no learnable weights.
        }
        else if (gen == GENERATE.FORWARD)
        {
            // Compare the argmax of the predictions against the label blob
            // (the layer's second bottom) and update the running accuracy.
            strCode += "        x1 = torch.argmax(" + m_inputs.AsText + ", dim=1)" + Environment.NewLine;
            strCode += "        self.accuracy_sum += torch.sum(x1 == " + m_layer.bottom[1] + ")" + Environment.NewLine;
            strCode += "        self.accuracy_count += len(x1)" + Environment.NewLine;
            strCode += "        self.accuracy = self.accuracy_sum / self.accuracy_count" + Environment.NewLine;
        }

        return strCode;
    }
}
int axis
The axis along which to concatenate — may be negative to index from the end (e.g., -1 indicates the last axis).
uint num_output
The number of outputs for the layer.
Specifies the base parameter for all layers.
List< string > top
Specifies the active top connections (in the bottom, out the top)
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
ConcatParameter concat_param
Returns the parameter set when initialized with LayerType.CONCAT
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-source project.