5using System.Collections.Generic;
8using System.Threading.Tasks;
12 class DataTemporalLayerInfo : LayerInfo
14 static int m_nGenerationCount = 0;
16 public DataTemporalLayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
/// <summary>
/// Appends the Python source text for the requested generation stage of a DataTemporal layer.
/// </summary>
/// <param name="gen">Generation stage; the visible branches handle CLASSES, DEFINITION, INITWEIGHTS and FORWARD.</param>
20 public override string Generate(GENERATE gen)
// Guard on the layer's include list: either no include rules, or the first rule is not for Phase.TRAIN.
// NOTE(review): the statement this guard controls is not visible here - confirm what it does.
22 if (m_layer.include.Count == 0 || m_layer.include[0].phase !=
Phase.TRAIN)
26 if (gen == GENERATE.CLASSES)
// The DataTemporal class text is requested only while the static generation counter is still zero.
28 if (m_nGenerationCount == 0)
29 strCode += generateDataTemporalClass(m_layer);
32 if (gen == GENERATE.DEFINITION)
// Emits: self.<layer name> = DataTemporal()
34 strCode +=
" self." + m_layer.name +
" = DataTemporal()" + Environment.NewLine;
// No text is appended for INITWEIGHTS in the visible lines.
36 else if (gen == GENERATE.INITWEIGHTS)
39 else if (gen == GENERATE.FORWARD)
// Emits: <outputs> = self.<layer name>(<inputs>)
40 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
49 strCode +=
"class DataTemporal(nn.Module):" + Environment.NewLine;
50 strCode +=
" def __init__(self):" + Environment.NewLine;
51 strCode +=
" super(DataTemporal, self).__init__()" + Environment.NewLine;
52 strCode += Environment.NewLine;
53 strCode +=
" def forward(self) -> torch.Tensor:" + Environment.NewLine;
54 strCode +=
" return None" + Environment.NewLine;
55 strCode += Environment.NewLine;
61 class ChannelEmbeddingLayerInfo : LayerInfo
63 static int m_nGenerationCreditCount = 0;
64 static int m_nGenerationCount = 0;
66 public ChannelEmbeddingLayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
/// <summary>
/// Appends the Python source text for the requested generation stage of a ChannelEmbedding layer.
/// </summary>
/// <param name="gen">Generation stage: CREDITS, CLASSES, DEFINITION, INITWEIGHTS or FORWARD.</param>
70 public override string Generate(GENERATE gen)
73 if (gen == GENERATE.CREDITS)
75 strCode += generateCredits();
77 else if (gen == GENERATE.CLASSES)
// The helper classes are requested only while the static generation counter is still zero.
79 if (m_nGenerationCount == 0)
81 strCode += generateNullTransformationClass(m_layer);
82 strCode += generateTimeDistributedClass(m_layer);
83 strCode += generateNumericalTransformationClass(m_layer);
84 strCode += generateCategoricalTransformationClass(m_layer);
85 strCode += generateChannelEmbeddingClass(m_layer);
90 else if (gen == GENERATE.DEFINITION)
// Time distribution is hard-coded off; the bool prints as "False", which is also a valid Python literal.
92 bool bTimeDistributed =
false;
// Cardinalities are rendered as a bracketed list, e.g. "[...]", for use as a Python list literal.
93 string strCardinalities =
Utility.ToString<
int>(m_layer.categorical_trans_param.cardinalities, -1, -1,
"[",
"]");
95 strCode +=
" self.time_distributed = " + bTimeDistributed.ToString() + Environment.NewLine;
96 strCode +=
" self.categorical_cardinalities = " + strCardinalities + Environment.NewLine;
// ChannelEmbedding is a class emitted by this generator (not a member of torch.nn), so it is
// referenced by its bare name, and the keyword must match the emitted constructor's
// 'time_distribute' parameter.
97 strCode +=
" self." + m_layer.name +
" = ChannelEmbedding(state_size=" + m_layer.numeric_trans_param.state_size.ToString() +
",num_numeric=" + m_layer.numeric_trans_param.num_input.ToString() +
",num_categorical=" + m_layer.categorical_trans_param.num_input.ToString() +
",categorical_cardinalities=self.categorical_cardinalities,time_distribute=self.time_distributed)" + Environment.NewLine;
99 else if (gen == GENERATE.INITWEIGHTS)
// Emits: self.<layer name>.init_weights()
101 strCode +=
" self." + m_layer.name +
".init_weights()" + Environment.NewLine;
103 else if (gen == GENERATE.FORWARD)
// Emits: <outputs> = self.<layer name>(<inputs>)
105 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
111 public static string generateCredits()
115 if (m_nGenerationCreditCount > 0)
118 strCode +=
"# NullTransform, TimeDistributed, NumericalTransformation, CategoricalTransformation, ChannelEmbedding Layers:" + Environment.NewLine;
119 strCode +=
"# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
120 strCode +=
"# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
122 m_nGenerationCreditCount++;
131 strCode +=
"class NullTransformation(nn.Module):" + Environment.NewLine;
132 strCode +=
" def __init__(self):" + Environment.NewLine;
133 strCode +=
" super(NullTransformation, self).__init__()" + Environment.NewLine;
134 strCode += Environment.NewLine;
135 strCode +=
" def forward(self, x: torch.Tensor):" + Environment.NewLine;
136 strCode +=
" return []" + Environment.NewLine;
137 strCode += Environment.NewLine;
146 strCode +=
"class TimeDistributed(nn.Module):" + Environment.NewLine;
147 strCode +=
" def __init__(self, module: nn.Module, batch_first: bool = True, return_reshaped: bool = True):" + Environment.NewLine;
148 strCode +=
" super(TimeDistributed, self).__init__()" + Environment.NewLine;
149 strCode +=
" self.module = module" + Environment.NewLine;
150 strCode +=
" self.batch_first = batch_first" + Environment.NewLine;
151 strCode +=
" self.return_reshaped = return_reshaped" + Environment.NewLine;
152 strCode += Environment.NewLine;
153 strCode +=
" def forward(self, x: torch.Tensor):" + Environment.NewLine;
154 strCode +=
" if len(x.shape) <= 2:" + Environment.NewLine;
155 strCode +=
" return self.module(x)" + Environment.NewLine;
156 strCode += Environment.NewLine;
157 strCode +=
" x_reshape = x.contiguous().view(-1, x.shape[-1])" + Environment.NewLine;
158 strCode +=
" y = self.module(x_reshape)" + Environment.NewLine;
159 strCode += Environment.NewLine;
160 strCode +=
" if self.return_reshaped:" + Environment.NewLine;
161 strCode +=
" y = y.contiguous().view(x.shape[0], -1, y.shape[-1])" + Environment.NewLine;
162 strCode +=
" else:" + Environment.NewLine;
163 strCode +=
" y = y.contiguous().view(-1, x.shape[1], y.shape[-1])" + Environment.NewLine;
164 strCode += Environment.NewLine;
165 strCode +=
" return y" + Environment.NewLine;
166 strCode += Environment.NewLine;
171 private string generateNumericalTransformationClass(
LayerParameter p)
178 strCode += generateNumericalTransformationClassComments();
180 strCode +=
"class NumericalTransformation(nn.Module):" + Environment.NewLine;
181 strCode +=
" def __init__(self, num_inputs: int, state_size: int):" + Environment.NewLine;
182 strCode +=
" super(NumericalTransformation, self).__init__()" + Environment.NewLine;
183 strCode +=
" self.num_inputs = num_inputs" + Environment.NewLine;
184 strCode +=
" self.state_size = state_size" + Environment.NewLine;
185 strCode += Environment.NewLine;
186 strCode +=
" self.numeric_projection_layers = nn.ModuleList()" + Environment.NewLine;
187 strCode +=
" for i in range(num_inputs):" + Environment.NewLine;
188 strCode +=
" self.numeric_projection_layers.append(nn.Linear(1, self.state_size))" + Environment.NewLine;
189 strCode += Environment.NewLine;
190 strCode +=
" def init_weights(self):" + Environment.NewLine;
191 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
192 strCode += initWeights(
" ",
"self.numeric_projection_layers[i]",
true, wtfiller, biasfiller);
193 strCode += Environment.NewLine;
194 strCode +=
" def forward(self, x: torch.Tensor):" + Environment.NewLine;
195 strCode +=
" projections = []" + Environment.NewLine;
196 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
197 strCode +=
" x1 = x[:,[i]]" + Environment.NewLine;
198 strCode +=
" x2 = self.numeric_projection_layers[i](x1)" + Environment.NewLine;
199 strCode +=
" projections.append(x2)" + Environment.NewLine;
200 strCode +=
" return projections" + Environment.NewLine;
201 strCode += Environment.NewLine;
/// <summary>
/// Builds the Python comment header that documents the emitted NumericalTransformation class.
/// </summary>
206 private string generateNumericalTransformationClassComments()
210 strCode += LayerInfo.generateBar();
211 strCode +=
"# The NumericalTransformation class transforms the numerical input into a set of projections for each input." + Environment.NewLine;
212 strCode +=
"# Each input is projected using a dedicated linear layer to a vector within the state_size, which is" + Environment.NewLine;
213 strCode +=
"# output as a list of length num_inputs that contains each embedding." + Environment.NewLine;
214 strCode +=
"#" + Environment.NewLine;
215 strCode +=
"# Parameters" + Environment.NewLine;
216 strCode +=
"# ----------" + Environment.NewLine;
// The documented name matches the 'num_inputs' parameter of the emitted constructor.
217 strCode +=
"# num_inputs : int" + Environment.NewLine;
218 strCode +=
"# The number of numerical inputs." + Environment.NewLine;
219 strCode +=
"# state_size : int" + Environment.NewLine;
220 strCode +=
"# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;
225 private string generateCategoricalTransformationClass(
LayerParameter p)
231 strCode += generateCategoricalTransformationClassComments();
233 strCode +=
"class CategoricalTransformation(nn.Module):" + Environment.NewLine;
234 strCode +=
" def __init__(self, num_inputs: int, state_size: int, cardinalities: List[int]):" + Environment.NewLine;
235 strCode +=
" super(CategoricalTransformation, self).__init__()" + Environment.NewLine;
236 strCode +=
" self.num_inputs = num_inputs" + Environment.NewLine;
237 strCode +=
" self.state_size = state_size" + Environment.NewLine;
238 strCode +=
" self.cardinalities = cardinalities" + Environment.NewLine;
239 strCode += Environment.NewLine;
240 strCode +=
" self.categorical_embedding_layers = nn.ModuleList()" + Environment.NewLine;
241 strCode +=
" for i, cardinality in enumerate(self.cardinalities):" + Environment.NewLine;
242 strCode +=
" self.categorical_embedding_layers.append(nn.Embedding(cardinality, self.state_size))" + Environment.NewLine;
243 strCode += Environment.NewLine;
244 strCode +=
" def init_weights(self):" + Environment.NewLine;
245 strCode +=
" for i, cardinality in enumerate(self.cardinalities):" + Environment.NewLine;
246 strCode += initWeights(
" ",
"self.categorical_embedding_layers[i]",
false, wtfiller,
null);
247 strCode += Environment.NewLine;
248 strCode +=
" def forward(self, x: torch.Tensor):" + Environment.NewLine;
249 strCode +=
" embeddings = []" + Environment.NewLine;
250 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
251 strCode +=
" x1 = x[:,i]" + Environment.NewLine;
252 strCode +=
" x2 = self.categorical_embedding_layers[i](x1)" + Environment.NewLine;
253 strCode +=
" embeddings.append(x2)" + Environment.NewLine;
254 strCode +=
" return embeddings" + Environment.NewLine;
255 strCode += Environment.NewLine;
/// <summary>
/// Builds the Python comment header that documents the emitted CategoricalTransformation class.
/// </summary>
260 private string generateCategoricalTransformationClassComments()
264 strCode += LayerInfo.generateBar();
265 strCode +=
"# The CategoricalTransformation class transforms the categorical input into a set of embeddings for each input." + Environment.NewLine;
266 strCode +=
"# Each input is projected using a dedicated embedding layer to a vector within the state_size, which is" + Environment.NewLine;
267 strCode +=
"# output as a list of length num_inputs that contains each embedding." + Environment.NewLine;
268 strCode +=
"#" + Environment.NewLine;
269 strCode +=
"# Parameters" + Environment.NewLine;
270 strCode +=
"# ----------" + Environment.NewLine;
// The documented names match the 'num_inputs' and 'cardinalities' parameters of the emitted constructor.
271 strCode +=
"# num_inputs : int" + Environment.NewLine;
272 strCode +=
"# The number of categorical inputs." + Environment.NewLine;
273 strCode +=
"# state_size : int" + Environment.NewLine;
274 strCode +=
"# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;
275 strCode +=
"# cardinalities : List[int]" + Environment.NewLine;
276 strCode +=
"# The cardinality of each categorical input." + Environment.NewLine;
286 strCode += generateChannelEmbeddingClassComments();
288 strCode +=
"class ChannelEmbedding(nn.Module):" + Environment.NewLine;
289 strCode +=
" def __init__(self, state_size: int, num_numeric: int, num_categorical: int, categorical_cardinalities: List[int], time_distribute: Optional[bool] = False):" + Environment.NewLine;
290 strCode +=
" super(ChannelEmbedding, self).__init__()" + Environment.NewLine;
291 strCode +=
" self.state_size = state_size" + Environment.NewLine;
292 strCode +=
" self.num_numeric = num_numeric" + Environment.NewLine;
293 strCode +=
" self.num_categorical = num_categorical" + Environment.NewLine;
294 strCode +=
" self.categorical_cardinalities = categorical_cardinalities" + Environment.NewLine;
295 strCode += Environment.NewLine;
296 strCode +=
" if num_numeric > 0:" + Environment.NewLine;
297 strCode +=
" if time_distribute:" + Environment.NewLine;
298 strCode +=
" self.numerical_transformation = TimeDistributed(NumericalTransformation(num_inputs=self.num_numeric, state_size=self.state_size))" + Environment.NewLine;
299 strCode +=
" else:" + Environment.NewLine;
300 strCode +=
" self.numerical_transformation = NumericalTransformation(num_inputs=self.num_numeric, state_size=self.state_size)" + Environment.NewLine;
301 strCode +=
" else:" + Environment.NewLine;
302 strCode +=
" self.numerical_transformation = NullTransformation()" + Environment.NewLine;
303 strCode += Environment.NewLine;
304 strCode +=
" if num_categorical > 0:" + Environment.NewLine;
305 strCode +=
" if time_distribute:" + Environment.NewLine;
306 strCode +=
" self.categorical_transformation = TimeDistributed(CategoricalTransformation(num_inputs=self.num_categorical, state_size=self.state_size, cardinalities=self.categorical_cardinalities))" + Environment.NewLine;
307 strCode +=
" else:" + Environment.NewLine;
308 strCode +=
" self.categorical_transformation = CategoricalTransformation(num_inputs=self.num_categorical, state_size=self.state_size, cardinalities=self.categorical_cardinalities)" + Environment.NewLine;
309 strCode +=
" else:" + Environment.NewLine;
310 strCode +=
" self.categorical_transformation = NullTransformation()" + Environment.NewLine;
311 strCode += Environment.NewLine;
// Emits ChannelEmbedding.init_weights(), which forwards to both sub-transformations.
// The attribute name must be 'numerical_transformation', the name assigned in the emitted __init__.
// NOTE(review): the emitted NullTransformation and TimeDistributed classes show no init_weights()
// method, so this call presumably still fails when num_numeric or num_categorical is 0, or when
// time_distribute is True - confirm and guard if those configurations are used.
312 strCode +=
" def init_weights(self):" + Environment.NewLine;
313 strCode +=
" self.numerical_transformation.init_weights()" + Environment.NewLine;
314 strCode +=
" self.categorical_transformation.init_weights()" + Environment.NewLine;
315 strCode += Environment.NewLine;
316 strCode +=
" def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
317 strCode +=
" batch_shape = x_num.shape if x_num.nelement() > 0 else x_cat.shape" + Environment.NewLine;
318 strCode +=
" processed_num = self.numerical_transformation(x_num)" + Environment.NewLine;
319 strCode +=
" processed_cat = self.categorical_transformation(x_cat)" + Environment.NewLine;
320 strCode +=
" merged_transformations = torch.cat(processed_num + processed_cat, dim=1)" + Environment.NewLine;
321 strCode += Environment.NewLine;
322 strCode +=
" return merged_transformations" + Environment.NewLine;
323 strCode += Environment.NewLine;
328 private string generateChannelEmbeddingClassComments()
332 strCode += LayerInfo.generateBar();
333 strCode +=
"# The ChannelEmbedding class handles the transformation/embedding of the input channel of numeric and categorical tensors." + Environment.NewLine;
334 strCode +=
"# A NumericalTransformation class is used to process the numerical inputs, and" + Environment.NewLine;
335 strCode +=
"# a CategoricalTransformation class is used to process the categorical inputs." + Environment.NewLine;
336 strCode +=
"#" + Environment.NewLine;
337 strCode +=
"# Parameters" + Environment.NewLine;
338 strCode +=
"# ----------" + Environment.NewLine;
339 strCode +=
"# state_size : int" + Environment.NewLine;
340 strCode +=
"# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;
341 strCode +=
"# num_numeric : int" + Environment.NewLine;
342 strCode +=
"# The number of numeric inputs." + Environment.NewLine;
343 strCode +=
"# num_categorical : int" + Environment.NewLine;
344 strCode +=
"# The number of categorical inputs." + Environment.NewLine;
345 strCode +=
"# categorical_cardinalities : List[int]" + Environment.NewLine;
346 strCode +=
"# The cardinality of each categorical input." + Environment.NewLine;
347 strCode +=
"# time_distribute : Optional[bool] = False" + Environment.NewLine;
348 strCode +=
"# When True, the TimeDistributed transformation is applied to each time step of the input tensor." + Environment.NewLine;
354 class GLULayerInfo : LayerInfo
356 static int m_nGenerationCreditCount = 0;
357 static int m_nGenerationCount = 0;
359 public GLULayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
/// <summary>
/// Appends the Python source text for the requested generation stage of a GLU layer.
/// </summary>
/// <param name="gen">Generation stage: CREDITS, CLASSES, DEFINITION, INITWEIGHTS or FORWARD.</param>
363 public override string Generate(GENERATE gen)
366 if (gen == GENERATE.CREDITS)
368 strCode += generateGluCredits();
370 else if (gen == GENERATE.CLASSES)
// The GLU class text is built from this layer's bias flag and weight/bias fillers.
372 strCode += generateGluClass(m_layer.glu_param.bias_term, m_layer.glu_param.weight_filler, m_layer.glu_param.bias_filler, m_bAddComments);
// This starts a new if-chain rather than continuing the one above.
374 if (gen == GENERATE.DEFINITION)
// Emits: self.<layer name> = GLU(input_dim=<input_dim>)
376 strCode +=
" self." + m_layer.name +
" = GLU(input_dim=" + m_layer.glu_param.input_dim.ToString() +
")" + Environment.NewLine;
378 else if (gen == GENERATE.INITWEIGHTS)
// Emits: self.<layer name>.init_weights()
380 strCode +=
" self." + m_layer.name +
".init_weights()" + Environment.NewLine;
382 else if (gen == GENERATE.FORWARD)
// Emits: <outputs> = self.<layer name>(<inputs>)
383 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
388 public static string generateGluCredits()
392 if (m_nGenerationCreditCount > 0)
395 strCode +=
"# GLU Layer:" + Environment.NewLine;
396 strCode +=
"# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
397 strCode +=
"# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
399 m_nGenerationCreditCount++;
408 if (m_nGenerationCount == 0)
411 strCode += generateGluClassComments();
412 strCode +=
"class GLU(nn.Module):" + Environment.NewLine;
413 strCode +=
" def __init__(self, input_dim: int):" + Environment.NewLine;
414 strCode +=
" super(GLU, self).__init__()" + Environment.NewLine;
415 strCode +=
" self.input_dim = input_dim" + Environment.NewLine;
416 strCode +=
" self.fc1 = nn.Linear(self.input_dim, self.input_dim)" + Environment.NewLine;
417 strCode +=
" self.fc2 = nn.Linear(self.input_dim, self.input_dim)" + Environment.NewLine;
418 strCode +=
" self.sigmoid = nn.Sigmoid()" + Environment.NewLine;
419 strCode += Environment.NewLine;
420 strCode +=
" def init_weights(self):" + Environment.NewLine;
421 strCode += initWeights(
" ",
"self.fc1", bBiasTerm, wtFiller, biasFiller);
422 strCode += initWeights(
" ",
"self.fc2", bBiasTerm, wtFiller, biasFiller);
423 strCode += Environment.NewLine;
// Emits GLU.forward(). Both linear projections are applied to the layer input: fc1 feeds the
// sigmoid gate and fc2 produces the values being gated, i.e. sigmoid(fc1(x)) * fc2(x).
// The gate projection is kept in its own variable so that 'x' still holds the input when fc2 runs.
424 strCode +=
" def forward(self, x: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
425 strCode +=
" gate = self.fc1(x)" + Environment.NewLine;
426 strCode +=
" sig = self.sigmoid(gate)" + Environment.NewLine;
427 strCode +=
" x = self.fc2(x)" + Environment.NewLine;
428 strCode +=
" return torch.mul(sig, x)" + Environment.NewLine;
429 strCode += Environment.NewLine;
432 m_nGenerationCount++;
/// <summary>
/// Builds the Python comment header that documents the emitted GLU class.
/// </summary>
437 private static string generateGluClassComments()
441 strCode += LayerInfo.generateBar();
442 strCode +=
"# The GLU class defines the Gated Linear Unit (GLU) layer, as described" + Environment.NewLine;
443 strCode +=
"# in Dauphin, Yann N., et al. \"Language modeling with gated convolutional networks.\" arXiv preprint arXiv:1612.08083 (2016)." + Environment.NewLine;
444 strCode +=
"# <https://arxiv.org/pdf/1612.08083.pdf>" + Environment.NewLine;
445 strCode +=
"#" + Environment.NewLine;
446 strCode +=
"# The output of this layer is a linear projection (X * W + b) modulated by the gates Sigmoid(X * V + c)." + Environment.NewLine;
447 strCode +=
"# These gates multiply each element of the matrix X * W and control the information passed on in the hierarchy." + Environment.NewLine;
448 strCode +=
"# This unit is a simplified gating mechanism for non-deterministic gates that reduce the vanishing gradient problem," + Environment.NewLine;
449 strCode +=
"# by having linear units coupled to the gates. This retains the non-linear capabilities of the network while allowing" + Environment.NewLine;
450 strCode +=
"# the gradient to propagate through the linear unit without scaling." + Environment.NewLine;
451 strCode +=
"#" + Environment.NewLine;
452 strCode +=
"# Parameters" + Environment.NewLine;
453 strCode +=
"# ----------" + Environment.NewLine;
454 strCode +=
"# input_dim : int" + Environment.NewLine;
455 strCode +=
"# The embedding width/dimension of the input." + Environment.NewLine;
461 class GRNLayerInfo : LayerInfo
463 static int m_nGenerationCreditCount = 0;
464 static int m_nGenerationCount = 0;
466 public GRNLayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
/// <summary>
/// Appends the Python source text for the requested generation stage of a GRN layer.
/// </summary>
/// <param name="gen">Generation stage: CREDITS, CLASSES, DEFINITION, INITWEIGHTS or FORWARD.</param>
470 public override string Generate(GENERATE gen)
473 if (gen == GENERATE.CREDITS)
475 strCode += generateCredits();
477 else if (gen == GENERATE.CLASSES)
479 strCode += generateGrnClass(m_layer);
// This starts a new if-chain rather than continuing the one above.
481 if (gen == GENERATE.DEFINITION)
// Emits: self.<layer name> = GRN(input_dim=..., hidden_dim=..., output_dim=..., dropout=..., ...)
// The dropout ratio is formatted with the invariant culture so the emitted Python literal always
// uses '.' as the decimal separator, regardless of the machine's regional settings.
483 strCode +=
" self." + m_layer.name +
" = GRN(input_dim=" + m_layer.grn_param.input_dim.ToString() +
", hidden_dim=" + m_layer.grn_param.hidden_dim.ToString() +
", output_dim=" + m_layer.grn_param.output_dim.ToString() +
", dropout=" + m_layer.grn_param.dropout_ratio.ToString(System.Globalization.CultureInfo.InvariantCulture) +
", context_dim=" + m_layer.grn_param.context_dim.ToString() +
", batch_first=" + m_layer.grn_param.batch_first.ToString() +
", activation=\'" + m_layer.grn_param.activation.ToString() +
"\')" + Environment.NewLine;
485 else if (gen == GENERATE.INITWEIGHTS)
// Emits: self.<layer name>.init_weights()
487 strCode +=
" self." + m_layer.name +
".init_weights()" + Environment.NewLine;
489 else if (gen == GENERATE.FORWARD)
// Emits: <outputs> = self.<layer name>(<inputs>)
490 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
/// <summary>
/// Builds the Python comment lines crediting the original GLU/GRN implementation and its license.
/// </summary>
495 public static string generateCredits()
// The static counter is only read here and is incremented once, after the credit text is built.
// NOTE(review): the statement controlled by this check is not visible here - presumably an early return.
499 if (m_nGenerationCreditCount > 0)
502 strCode +=
"# GLU, GRN Layers:" + Environment.NewLine;
503 strCode +=
"# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
504 strCode +=
"# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
506 m_nGenerationCreditCount++;
515 if (m_nGenerationCount > 0)
518 strCode += GLULayerInfo.generateGluClass(
true, p.
grn_param.weight_filler, p.
grn_param.bias_filler, m_bAddComments);
521 strCode += generateClassComments(p);
// Emits the GRN class header and constructor signature.
// The 'activation' annotation uses Python's built-in 'str' type; 'string' is not a Python type name.
523 strCode +=
"class GRN(nn.Module):" + Environment.NewLine;
524 strCode +=
" def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[float] = 0.05, context_dim: Optional[int] = None, batch_first: Optional[bool] = True, activation: Optional[str] = 'ELU'):" + Environment.NewLine;
525 strCode +=
" super(GRN, self).__init__()" + Environment.NewLine;
526 strCode +=
" self.input_dim = input_dim" + Environment.NewLine;
527 strCode +=
" self.hidden_dim = hidden_dim" + Environment.NewLine;
528 strCode +=
" self.output_dim = output_dim" + Environment.NewLine;
529 strCode +=
" self.context_dim = context_dim" + Environment.NewLine;
530 strCode +=
" self.batch_first = batch_first" + Environment.NewLine;
531 strCode +=
" self.dropout = dropout" + Environment.NewLine;
532 strCode += Environment.NewLine;
533 strCode +=
" self.project_residual: bool = self.input_dim != self.output_dim" + Environment.NewLine;
534 strCode +=
" if self.project_residual:" + Environment.NewLine;
535 strCode +=
" self.skip_layer = TimeDistributed(nn.Linear(self.input_dim, self.output_dim))" + Environment.NewLine;
536 strCode += Environment.NewLine;
537 strCode +=
" self.fc1 = TimeDistributed(nn.Linear(self.input_dim, self.hidden_dim), batch_first=batch_first)" + Environment.NewLine;
538 strCode += Environment.NewLine;
539 strCode +=
" if self.context_dim is not None:" + Environment.NewLine;
540 strCode +=
" self.context_projection = TimeDistributed(nn.Linear(self.context_dim, self.hidden_dim, bias=False), batch_first=batch_first)" + Environment.NewLine;
541 strCode += Environment.NewLine;
542 strCode +=
" if activation == 'RELU':" + Environment.NewLine;
543 strCode +=
" self.activation = nn.ReLU()" + Environment.NewLine;
544 strCode +=
" else:" + Environment.NewLine;
545 strCode +=
" self.activation = nn.ELU()" + Environment.NewLine;
546 strCode += Environment.NewLine;
547 strCode +=
" self.fc2 = TimeDistributed(nn.Linear(self.hidden_dim, self.output_dim), batch_first=batch_first)" + Environment.NewLine;
548 strCode += Environment.NewLine;
549 strCode +=
" self.dropout = nn.Dropout(self.dropout)" + Environment.NewLine;
550 strCode +=
" self.gate = TimeDistributed(GLU(self.output_dim), batch_first=batch_first)" + Environment.NewLine;
551 strCode +=
" self.layernorm = TimeDistributed(nn.LayerNorm(self.output_dim), batch_first=batch_first)" + Environment.NewLine;
552 strCode += Environment.NewLine;
553 strCode +=
" def init_weights(self):" + Environment.NewLine;
554 strCode += initWeights(
"",
"self.skip_layer",
true, m_layer.grn_param.weight_filler, m_layer.grn_param.bias_filler);
555 strCode += initWeights(
"",
"self.fc1",
true, m_layer.grn_param.weight_filler, m_layer.grn_param.bias_filler);
556 strCode += initWeights(
"",
"self.context_projection",
true, m_layer.grn_param.weight_filler, m_layer.grn_param.bias_filler);
557 strCode += initWeights(
"",
"self.fc2",
true, m_layer.grn_param.weight_filler, m_layer.grn_param.bias_filler);
558 strCode +=
" self.gate.init_weights()" + Environment.NewLine;
559 strCode += Environment.NewLine;
560 strCode +=
" def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
561 strCode +=
" if self.project_residual:" + Environment.NewLine;
562 strCode +=
" residual = self.skip_layer(x)" + Environment.NewLine;
563 strCode +=
" else:" + Environment.NewLine;
564 strCode +=
" residual = x" + Environment.NewLine;
565 strCode += Environment.NewLine;
566 strCode +=
" x = self.fc1(x)" + Environment.NewLine;
567 strCode +=
" if context is not None:" + Environment.NewLine;
568 strCode +=
" context = self.context_projection(context)" + Environment.NewLine;
569 strCode +=
" x = x + context" + Environment.NewLine;
570 strCode += Environment.NewLine;
571 strCode +=
" x = self.activation(x)" + Environment.NewLine;
572 strCode +=
" x = self.fc2(x)" + Environment.NewLine;
573 strCode +=
" x = self.dropout(x)" + Environment.NewLine;
574 strCode +=
" x = self.gate(x)" + Environment.NewLine;
575 strCode +=
" x = x + residual" + Environment.NewLine;
576 strCode +=
" x = self.layernorm(x)" + Environment.NewLine;
577 strCode +=
" return x" + Environment.NewLine;
578 strCode += Environment.NewLine;
580 m_nGenerationCount++;
// Builds the Python comment header that documents the emitted GRN class and its constructor parameters.
589 strCode += LayerInfo.generateBar();
590 strCode +=
"# The GRN class defines the Gated Residual Network layer." + Environment.NewLine;
591 strCode +=
"#" + Environment.NewLine;
592 strCode +=
"# The primary input consists of the input (x) and an optional context vector (c)." + Environment.NewLine;
593 strCode +=
"# A GLU is used for controlling the extent to which the module contributes to the original input (x)," + Environment.NewLine;
594 strCode +=
"# potentially skipping over the layer entirely as the GLU outputs could be close to zero, therefore" + Environment.NewLine;
595 strCode +=
"# suppressing the non-linear contribution. When no context vector is used, the GRN treats the context" + Environment.NewLine;
596 strCode +=
"# input as zero. During training, dropout is applied before the gating layer." + Environment.NewLine;
597 strCode +=
"#" + Environment.NewLine;
598 strCode +=
"# Parameters" + Environment.NewLine;
599 strCode +=
"# ----------" + Environment.NewLine;
600 strCode +=
"# input_dim : int" + Environment.NewLine;
601 strCode +=
"# The embedding width/dimension of the input." + Environment.NewLine;
602 strCode +=
"# hidden_dim : int" + Environment.NewLine;
603 strCode +=
"# The intermediate embedding width." + Environment.NewLine;
604 strCode +=
"# output_dim : int" + Environment.NewLine;
605 strCode +=
"# The embedding width of the output." + Environment.NewLine;
606 strCode +=
"# dropout : Optional[float]" + Environment.NewLine;
607 strCode +=
"# The dropout ratio to use." + Environment.NewLine;
608 strCode +=
"# context_dim : int" + Environment.NewLine;
609 strCode +=
"# The embedding width/dimension of the context signal expected to be fed as an auxiliary input." + Environment.NewLine;
610 strCode +=
"# batch_first : bool" + Environment.NewLine;
611 strCode +=
"# When true, the first dimension of the input and output is the batch size." + Environment.NewLine;
616 class VarSelNetLayerInfo : LayerInfo
618 static int m_nGenerationCreditCount = 0;
619 static int m_nGenerationCount = 0;
621 public VarSelNetLayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
/// <summary>
/// Appends the Python source text for the requested generation stage of a VarSelNet layer.
/// </summary>
/// <param name="gen">Generation stage: CREDITS, CLASSES, DEFINITION, INITWEIGHTS or FORWARD.</param>
625 public override string Generate(GENERATE gen)
628 if (gen == GENERATE.CREDITS)
630 strCode += generateCredits();
632 else if (gen == GENERATE.CLASSES)
634 strCode += generateVarSelNetClass(m_layer);
// This starts a new if-chain rather than continuing the one above.
636 if (gen == GENERATE.DEFINITION)
// Emits: self.<layer name> = VarSelNet(input_dim=..., num_inputs=..., hidden_dim=..., dropout=..., ...)
// The dropout ratio is formatted with the invariant culture so the emitted Python literal always
// uses '.' as the decimal separator, regardless of the machine's regional settings.
638 strCode +=
" self." + m_layer.name +
" = VarSelNet(input_dim=" + m_layer.varselnet_param.input_dim.ToString() +
", num_inputs=" + m_layer.varselnet_param.num_inputs.ToString() +
", hidden_dim=" + m_layer.varselnet_param.hidden_dim.ToString() +
", dropout=" + m_layer.varselnet_param.dropout_ratio.ToString(System.Globalization.CultureInfo.InvariantCulture) +
", context_dim=" + m_layer.varselnet_param.context_dim.ToString() +
", batch_first=" + m_layer.varselnet_param.batch_first.ToString() +
")" + Environment.NewLine;
640 else if (gen == GENERATE.INITWEIGHTS)
// Emits: self.<layer name>.init_weights()
642 strCode +=
" self." + m_layer.name +
".init_weights()" + Environment.NewLine;
644 else if (gen == GENERATE.FORWARD)
// Emits: <outputs> = self.<layer name>(<inputs>)
645 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
/// <summary>
/// Builds the Python comment lines crediting the original VarSelNet implementation and its license.
/// </summary>
650 public static string generateCredits()
// NOTE(review): the statement controlled by this check is not visible here - presumably an early return.
654 if (m_nGenerationCreditCount > 0)
// The credit header names the layer as it is emitted elsewhere in this class: VarSelNet.
657 strCode +=
"# VarSelNet Layer:" + Environment.NewLine;
658 strCode +=
"# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
659 strCode +=
"# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
661 m_nGenerationCreditCount++;
670 if (m_nGenerationCount > 0)
674 strCode += generateClassComments();
676 strCode +=
"class VarSelNet(nn.Module):" + Environment.NewLine;
677 strCode +=
" def __init__(self, input_dim: int, num_inputs: int, hidden_dim: int, dropout: float, context_dim: Optional[int] = None, batch_first: Optional[bool] = True):" + Environment.NewLine;
678 strCode +=
" super(VarSelNet, self).__init__()" + Environment.NewLine;
679 strCode += Environment.NewLine;
680 strCode +=
" self.hidden_dim = hidden_dim" + Environment.NewLine;
681 strCode +=
" self.input_dim = input_dim" + Environment.NewLine;
682 strCode +=
" self.num_inputs = num_inputs" + Environment.NewLine;
683 strCode +=
" self.dropout = dropout" + Environment.NewLine;
684 strCode +=
" self.context_dim = context_dim" + Environment.NewLine;
685 strCode += Environment.NewLine;
686 strCode +=
" self.flattened_grn = GRN(input_dim=self.num_inputs * self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.num_inputs, dropout=self.dropout, context_dim=self.context_dim, batch_first=batch_first)" + Environment.NewLine;
687 strCode +=
" self.softmax = nn.Softmax(dim=1)" + Environment.NewLine;
688 strCode += Environment.NewLine;
689 strCode +=
" self.single_variable_grns = nn.ModuleList()" + Environment.NewLine;
690 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
691 strCode +=
" self.single_variable_grns.append(GRN(input_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.hidden_dim, dropout=self.dropout, batch_first=batch_first))" + Environment.NewLine;
692 strCode += Environment.NewLine;
693 strCode +=
" def init_weights(self):" + Environment.NewLine;
694 strCode +=
" self.flattened_grn.init_weights()" + Environment.NewLine;
695 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
696 strCode +=
" self.single_variable_grns[i].init_weights()" + Environment.NewLine;
697 strCode += Environment.NewLine;
698 strCode +=
" def forward(self, flattened_embedding: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
699 strCode +=
" sparse_weights = self.flattened_grn(flattened_embedding, context)" + Environment.NewLine;
700 strCode +=
" sparse_weights = self.softmax(sparse_weights)" + Environment.NewLine;
701 strCode += Environment.NewLine;
702 strCode +=
" processed_inputs = []" + Environment.NewLine;
703 strCode +=
" for i in range(self.num_inputs):" + Environment.NewLine;
704 strCode +=
"# processed_inputs.append(self.single_variable_grns[i](flattened_embedding[..., i * self.input_dim:(i + 1) * self.input_dim]))" + Environment.NewLine;
705 strCode +=
" processed_inputs.append(self.single_variable_grns[i](flattened_embedding[:, i * self.input_dim:(i + 1) * self.input_dim]))" + Environment.NewLine;
706 strCode +=
" processed_inputs = torch.stack(processed_inputs, dim=1)" + Environment.NewLine;
707 strCode += Environment.NewLine;
708 strCode +=
" outputs = processed_inputs * sparse_weights.transpose(1, 2)" + Environment.NewLine;
709 strCode +=
" outputs = outputs.sum(axis=-1)" + Environment.NewLine;
710 strCode += Environment.NewLine;
711 strCode +=
" return outputs, sparse_weights" + Environment.NewLine;
712 strCode += Environment.NewLine;
714 m_nGenerationCount++;
/// <summary>
/// Builds the Python comment banner that is emitted above the generated VarSelNet class.
/// </summary>
/// <remarks>
/// The banner starts with the separator returned by LayerInfo.generateBar() and is followed by
/// '#' comment lines describing the class and the parameters of its constructor.
/// </remarks>
/// <returns>The banner text, one Environment.NewLine terminated line per comment line.</returns>
private string generateClassComments()
{
    string strCode = "";

    strCode += LayerInfo.generateBar();
    strCode += "# The VarSelNet class handles the fact that the relevant and specific contribution of each input" + Environment.NewLine;
    strCode += "# variable to the output is unknown. This class enables instance-wise variable selection, and is applied" + Environment.NewLine;
    strCode += "# to both the static covariates and time-dependent covariates. In addition to providing insights into which" + Environment.NewLine;
    strCode += "# variables are most important, this class also enables the model to remove any unnecessary noisy inputs that" + Environment.NewLine;
    strCode += "# could negatively impact performance." + Environment.NewLine;
    strCode += "#" + Environment.NewLine;
    strCode += "# Parameters" + Environment.NewLine;
    strCode += "# ----------" + Environment.NewLine;
    strCode += "# input_dim : int" + Environment.NewLine;
    strCode += "#     The attribute/embedding dimension of the input, associated with the 'state_size' of the model." + Environment.NewLine;
    // The generated constructor names this argument 'num_inputs', so the emitted documentation must match it.
    strCode += "# num_inputs : int" + Environment.NewLine;
    strCode += "#     The number of input variables, including both numeric and categorical inputs." + Environment.NewLine;
    strCode += "# hidden_dim : int" + Environment.NewLine;
    strCode += "#     The embedding width of the output." + Environment.NewLine;
    strCode += "# dropout : float" + Environment.NewLine;
    strCode += "#     The dropout rate associated with the 'GRN' classes." + Environment.NewLine;
    strCode += "# context_dim : int" + Environment.NewLine;
    strCode += "#     The embedding width of the context signal expected to be fed as an auxiliary input." + Environment.NewLine;
    strCode += "# batch_first : bool" + Environment.NewLine;
    strCode += "#     When True, the first dimension of the input and output tensors represent the batch size." + Environment.NewLine;

    return strCode;
}
/// <summary>
/// Generates the Python (PyTorch) source fragments for a GateAddNorm layer: the credits banner,
/// the 'GateAddNorm' class definition, and the per-layer definition, weight-initialization and
/// forward-call statements.
/// </summary>
class GateAddNormLayerInfo : LayerInfo
{
    // Static counters shared by every instance: the credits and the class definition are only
    // emitted while the matching counter is still zero.
    static int m_nGenerationCreditCount = 0;
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter, passed through to the LayerInfo base class.</param>
    /// <param name="inputs">Specifies the input variables, passed through to the LayerInfo base class.</param>
    public GateAddNormLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generates the Python code for the requested generation stage.
    /// </summary>
    /// <param name="gen">Specifies which fragment to generate.</param>
    /// <returns>The generated code, or an empty string when the stage is not handled here.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        // NOTE(review): the per-layer statements below are emitted with 8 leading spaces, i.e. at
        // method-body depth inside a generated class - confirm against the enclosing model template.
        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateGateAddNormClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            // Instantiate the GateAddNorm class generated by this layer info (not the GLU class);
            // the 'dropout' argument of the generated constructor is left at its default of None.
            strCode += "        self." + m_layer.name + " = GateAddNorm(input_dim=" + m_layer.glu_param.input_dim.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += "        self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generates the credits banner for the GateAddNorm layer.
    /// </summary>
    /// <returns>The credits comment lines on the first call, and an empty string afterwards.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# GateAddNorm Layer:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generates the Python 'GateAddNorm' class definition, preceded by its comment banner.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (not read by the emitted text).</param>
    /// <returns>The class source on the first call, and an empty string afterwards.</returns>
    private string generateGateAddNormClass(LayerParameter p)
    {
        string strCode = "";

        if (m_nGenerationCount > 0)
            return strCode;

        strCode += generateClassComments();

        // The emitted class references the 'TimeDistributed' and 'GLU' classes, which this
        // method does not generate; they must be emitted elsewhere in the generated file.
        strCode += "class GateAddNorm(nn.Module):" + Environment.NewLine;
        strCode += "    def __init__(self, input_dim: int, dropout: Optional[float] = None):" + Environment.NewLine;
        strCode += "        super(GateAddNorm, self).__init__()" + Environment.NewLine;
        strCode += "        self.input_dim = input_dim" + Environment.NewLine;
        strCode += "        self.dropout = dropout" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        if self.dropout is not None:" + Environment.NewLine;
        strCode += "            self.dropout_layer = nn.Dropout(p=self.dropout)" + Environment.NewLine;
        // The gate and layer norm are used unconditionally by forward(), so they are created
        // outside of the dropout check.
        strCode += "        self.gate = TimeDistributed(GLU(self.input_dim), batch_first=True)" + Environment.NewLine;
        strCode += "        self.layernorm = TimeDistributed(nn.LayerNorm(self.input_dim), batch_first=True)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "    def init_weights(self):" + Environment.NewLine;
        strCode += "        self.gate.init_weights()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
        strCode += "        if self.dropout is not None:" + Environment.NewLine;
        strCode += "            x = self.dropout_layer(x)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        x = self.gate(x)" + Environment.NewLine;
        strCode += "        if residual is not None:" + Environment.NewLine;
        strCode += "            x = x + residual" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        x = self.layernorm(x)" + Environment.NewLine;
        strCode += "        return x" + Environment.NewLine;
        strCode += Environment.NewLine;

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Builds the Python comment banner that is emitted above the generated GateAddNorm class.
    /// </summary>
    /// <returns>The separator from LayerInfo.generateBar() followed by the '#' comment lines.</returns>
    private string generateClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The GateAddNorm class performs a dropout, residual connection and layer normalization." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# input_dim : int" + Environment.NewLine;
        strCode += "#     The attribute/embedding dimension of the input, associated with the 'state_size' of the model." + Environment.NewLine;
        strCode += "# dropout : float" + Environment.NewLine;
        strCode += "#     The optional dropout rate applied to the input before gating; when None, no dropout is applied." + Environment.NewLine;

        return strCode;
    }
}
/// <summary>
/// Generates the Python (PyTorch) source fragments for a MultiheadAttentionInterp layer: the
/// credits banner, the 'MultiheadAttentionInterp' class definition, and the per-layer definition,
/// weight-initialization and forward-call statements.
/// </summary>
class MultiheadAttentionInterpLayerInfo : LayerInfo
{
    // Static counters shared by every instance: the credits and the class definition are only
    // emitted while the matching counter is still zero.
    static int m_nGenerationCreditCount = 0;
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter, passed through to the LayerInfo base class.</param>
    /// <param name="inputs">Specifies the input variables, passed through to the LayerInfo base class.</param>
    public MultiheadAttentionInterpLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generates the Python code for the requested generation stage.
    /// </summary>
    /// <param name="gen">Specifies which fragment to generate.</param>
    /// <returns>The generated code, or an empty string when the stage is not handled here.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        // NOTE(review): the per-layer statements below are emitted with 8 leading spaces, i.e. at
        // method-body depth inside a generated class - confirm against the enclosing model template.
        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateMultiheadAttentionInterpClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            strCode += "        self." + m_layer.name + " = MultiheadAttentionInterp(embed_dim=" + m_layer.multihead_attention_interp_param.embed_dim.ToString() + ", num_heads=" + m_layer.multihead_attention_interp_param.num_heads.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += "        self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += "        " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generates the credits banner for the MultiheadAttentionInterp layer.
    /// </summary>
    /// <returns>The credits comment lines on the first call, and an empty string afterwards.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# MultiheadAttentionInterp Layer:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;

        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generates the Python 'MultiheadAttentionInterp' class definition, preceded by its comment banner.
    /// </summary>
    /// <param name="p">Specifies the layer parameter; the fillers are read from m_layer, which is what the caller passes in.</param>
    /// <returns>The class source on the first call, and an empty string afterwards.</returns>
    private string generateMultiheadAttentionInterpClass(LayerParameter p)
    {
        string strCode = "";

        if (m_nGenerationCount > 0)
            return strCode;

        // Indentation handed to initWeights for statements inside the emitted init_weights() method.
        // NOTE(review): assumes initWeights prefixes each emitted line with this string - confirm.
        string strIndent = "        ";
        var weightFiller = m_layer.multihead_attention_interp_param.weight_filler;
        var biasFiller = m_layer.multihead_attention_interp_param.bias_filler;

        strCode += generateClassComments();

        strCode += "class MultiheadAttentionInterp(nn.Module):" + Environment.NewLine;
        strCode += "    def __init__(self, embed_dim: int, num_heads: int):" + Environment.NewLine;
        strCode += "        super(MultiheadAttentionInterp, self).__init__()" + Environment.NewLine;
        strCode += "        self.d_model = embed_dim" + Environment.NewLine;
        strCode += "        self.num_heads = num_heads" + Environment.NewLine;
        strCode += "        self.all_heads_dim = embed_dim * num_heads" + Environment.NewLine;
        strCode += Environment.NewLine;
        // The emitted constructor only stores 'embed_dim' as 'self.d_model', so the projections
        // must read 'self.d_model'; reading an unassigned attribute would raise an AttributeError.
        strCode += "        self.w_q = nn.Linear(self.d_model, self.all_heads_dim)" + Environment.NewLine;
        strCode += "        self.w_k = nn.Linear(self.d_model, self.all_heads_dim)" + Environment.NewLine;
        strCode += "        self.w_v = nn.Linear(self.d_model, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        self.out = nn.Linear(self.d_model, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "    def init_weights(self):" + Environment.NewLine;
        strCode += initWeights(strIndent, "self.w_q", true, weightFiller, biasFiller);
        strCode += initWeights(strIndent, "self.w_k", true, weightFiller, biasFiller);
        strCode += initWeights(strIndent, "self.w_v", true, weightFiller, biasFiller);
        strCode += initWeights(strIndent, "self.out", true, weightFiller, biasFiller);
        strCode += Environment.NewLine;
        // forward() returns both the outputs and the attention scores, so it is annotated as a tuple
        // ('Tuple' is already used by the emitted attention() signature below).
        strCode += "    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:" + Environment.NewLine;
        strCode += "        num_samples = q.size(0)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        q_proj = self.w_q(q).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        strCode += "        k_proj = self.w_k(k).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        // The single 'values' projection is repeated so that every head shares the same values.
        strCode += "        v_proj = self.w_v(v).repeat(1, 1, self.num_heads).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        q_proj = q_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += "        k_proj = k_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += "        v_proj = v_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        attn_outputs_all_heads, attn_scores_all_heads = self.attention(q_proj, k_proj, v_proj, mask)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        attn_scores = attn_scores_all_heads.mean(dim=1)" + Environment.NewLine;
        strCode += "        attn_outputs = attn_outputs_all_heads.mean(dim=1)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        attn_outputs = self.out(attn_outputs)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        return attn_outputs, attn_scores" + Environment.NewLine;
        strCode += Environment.NewLine;
        // NOTE(review): the emitted attention does not divide q.k^T by sqrt(d) as scaled dot-product
        // attention normally does - confirm whether omitting the scale is intentional.
        strCode += "    def attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:" + Environment.NewLine;
        strCode += "        scores = torch.matmul(q, k.transpose(-2, -1))" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        if mask is not None:" + Environment.NewLine;
        strCode += "            scores = scores.masked_fill(mask, -1e9)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        scores = F.softmax(scores, dim=-1)" + Environment.NewLine;
        strCode += "        outputs = torch.matmul(scores, v)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += "        return outputs, scores" + Environment.NewLine;
        strCode += Environment.NewLine;

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Builds the Python comment banner that is emitted above the generated MultiheadAttentionInterp class.
    /// </summary>
    /// <returns>The separator from LayerInfo.generateBar() followed by the '#' comment lines.</returns>
    private string generateClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The MultiheadAttentionInterp class learns long-term relationships across different time-steps." + Environment.NewLine;
        strCode += "# A multi-head attention is modified to enhance explainability. With traditional multi-head attention" + Environment.NewLine;
        strCode += "# the 'values' signal is shared for all heads and an additive aggregation is employed across all heads." + Environment.NewLine;
        strCode += "# However, according to the paper, each head can learn different temporal patterns, while attending to a common set" + Environment.NewLine;
        strCode += "# of input features which can be interpreted as a simple ensemble over attention weights in a combined matrix, which" + Environment.NewLine;
        strCode += "# compared to the original multi-head attention matrix, yields an increased representational capacity." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# embed_dim : int" + Environment.NewLine;
        strCode += "#     The dimensions associated with the 'state_size' of the model, corresponding to the input and output." + Environment.NewLine;
        // The generated constructor declares 'num_heads: int', so the emitted documentation says int as well.
        strCode += "# num_heads : int" + Environment.NewLine;
        strCode += "#     The number of heads used by the multi-head attention component." + Environment.NewLine;

        return strCode;
    }
}
// Generates the Python (PyTorch) source fragments for a reshape-temporal layer: the
// 'ReshapeTemporalBefore' / 'ReshapeTemporalAfter' class definitions and the per-layer
// definition and forward-call statements.
996 class ReshapeTemporalLayerInfo : LayerInfo
// Static counters (shared by every instance), one per generated class; each class text is only
// appended while its counter is still zero, and the counter is incremented afterwards.
998 static int m_nGenerationCountBefore = 0;
999 static int m_nGenerationCountAfter = 0;
// The constructor passes the layer parameter and the input variables through to the LayerInfo base class.
1001 public ReshapeTemporalLayerInfo(
LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
// Generates the code for the generation stage given in 'gen', accumulating it in strCode.
1005 public override string Generate(GENERATE gen)
1007 string strCode =
"";
// CLASSES: append the Python class definition(s) built by generateReshapeTemporalClass.
1008 if (gen == GENERATE.CLASSES)
1010 strCode += generateReshapeTemporalClass(m_layer);
// DEFINITION: instantiate the class that matches reshape_temporal_param.mode as 'self.<layer name>';
// a mode other than BEFORE or AFTER appends nothing.
1012 if (gen == GENERATE.DEFINITION)
1014 if (m_layer.reshape_temporal_param.mode == param.tft.ReshapeTemporalParameter.MODE.BEFORE)
1015 strCode +=
" self." + m_layer.name +
" = ReshapeTemporalBefore()" + Environment.NewLine;
1016 else if (m_layer.reshape_temporal_param.mode == param.tft.ReshapeTemporalParameter.MODE.AFTER)
1017 strCode +=
" self." + m_layer.name +
" = ReshapeTemporalAfter()" + Environment.NewLine;
// INITWEIGHTS: no statement is visible for this stage - the generated classes below define no init_weights().
1019 else if (gen == GENERATE.INITWEIGHTS)
// FORWARD: append '<outputs> = self.<layer name>(<inputs>)'.
1022 else if (gen == GENERATE.FORWARD)
1023 strCode +=
" " + m_outputs.AsText +
" = self." + m_layer.name +
"(" + m_inputs.AsText +
")" + Environment.NewLine;
// Builds the Python class text for the reshape-temporal helper classes.
// NOTE(review): whether each class is additionally selected by reshape_temporal_param.mode cannot be
// told from the statements shown here - confirm.
1030 string strCode =
"";
// Append the 'ReshapeTemporalBefore' class only while it has not been generated yet.
1034 if (m_nGenerationCountBefore == 0)
1036 strCode +=
"class ReshapeTemporalBefore(nn.Module):" + Environment.NewLine;
1037 strCode +=
" def __init__(self):" + Environment.NewLine;
1038 strCode +=
" super(ReshapeTemporalBefore, self).__init__()" + Environment.NewLine;
1039 strCode += Environment.NewLine;
// The emitted forward() body is only 'pass', so the generated method returns None despite its
// torch.Tensor annotation. NOTE(review): presumably a placeholder - confirm.
1040 strCode +=
" def forward(self, x: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
1041 strCode +=
" pass" + Environment.NewLine;
1042 strCode += Environment.NewLine;
1044 m_nGenerationCountBefore++;
// Append the 'ReshapeTemporalAfter' class only while it has not been generated yet.
1048 if (m_nGenerationCountAfter == 0)
1050 strCode +=
"class ReshapeTemporalAfter(nn.Module):" + Environment.NewLine;
1051 strCode +=
" def __init__(self):" + Environment.NewLine;
1052 strCode +=
" super(ReshapeTemporalAfter, self).__init__()" + Environment.NewLine;
1053 strCode += Environment.NewLine;
// As with ReshapeTemporalBefore, the emitted forward() body is only 'pass'.
1054 strCode +=
" def forward(self, x: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
1055 strCode +=
" pass" + Environment.NewLine;
1056 strCode += Environment.NewLine;
1058 m_nGenerationCountAfter++;
The Utility class provides general utility functions.
Specifies the filler parameters used to create each Filler.
Specifies the base parameter for all layers.
GrnParameter grn_param
Returns the parameter set when initialized with LayerType.GLU
ReshapeTemporalParameter reshape_temporal_param
Returns the parameter set when initialized with LayerType.RESHAPE_TEMPORAL
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...