MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TFT_Layers.cs
1using MyCaffe.basecode;
2using MyCaffe.fillers;
3using MyCaffe.param;
4using System;
5using System.Collections.Generic;
6using System.Linq;
7using System.Text;
8using System.Threading.Tasks;
9
11{
/// <summary>
/// Generates the PyTorch code for a DataTemporal layer.  Code is produced only when the layer's
/// first include rule targets the TRAIN phase; otherwise every section is empty.
/// </summary>
class DataTemporalLayerInfo : LayerInfo
{
    // Counts CLASSES requests; the DataTemporal Python class is emitted on the first one only.
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public DataTemporalLayerInfo(LayerParameter layer, VariableCollection inputs)
        : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the code fragment for one section of the Python script.
    /// </summary>
    /// <param name="gen">Specifies the section being generated.</param>
    /// <returns>The generated code, or an empty string when the section does not apply.</returns>
    public override string Generate(GENERATE gen)
    {
        bool bIsTrainLayer = m_layer.include.Count > 0 && m_layer.include[0].phase == Phase.TRAIN;
        if (!bIsTrainLayer)
            return "";

        StringBuilder sb = new StringBuilder();

        switch (gen)
        {
            case GENERATE.CLASSES:
                if (m_nGenerationCount == 0)
                    sb.Append(generateDataTemporalClass(m_layer));
                m_nGenerationCount++;
                break;

            case GENERATE.DEFINITION:
                sb.Append(" self.").Append(m_layer.name).Append(" = DataTemporal()").Append(Environment.NewLine);
                break;

            case GENERATE.FORWARD:
                sb.Append(" ").Append(m_outputs.AsText)
                  .Append(" = self.").Append(m_layer.name)
                  .Append("(").Append(m_inputs.AsText).Append(")")
                  .Append(Environment.NewLine);
                break;

            // CREDITS and INITWEIGHTS produce nothing for this layer.
        }

        return sb.ToString();
    }

    /// <summary>
    /// Build the Python source of the DataTemporal placeholder module, whose forward returns None.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition followed by a blank line.</returns>
    private string generateDataTemporalClass(LayerParameter p)
    {
        string nl = Environment.NewLine;

        return "class DataTemporal(nn.Module):" + nl +
               " def __init__(self):" + nl +
               " super(DataTemporal, self).__init__()" + nl +
               nl +
               " def forward(self) -> torch.Tensor:" + nl +
               " return None" + nl +
               nl;
    }
}
60
/// <summary>
/// Generates the PyTorch code for a ChannelEmbedding layer together with the helper classes it
/// depends on: NullTransformation, TimeDistributed, NumericalTransformation and CategoricalTransformation.
/// </summary>
/// <remarks>
/// The credits and helper classes are emitted only on the first request; both are guarded by static
/// counters that this class never resets.  The generated TimeDistributed class is also referenced by
/// the generated GRN class.
/// </remarks>
class ChannelEmbeddingLayerInfo : LayerInfo
{
    static int m_nGenerationCreditCount = 0;
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public ChannelEmbeddingLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the code fragment for one section of the Python script.
    /// </summary>
    /// <param name="gen">Specifies the section being generated.</param>
    /// <returns>The generated code, or an empty string when the section does not apply.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";
        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            // The helper classes are shared by every ChannelEmbedding layer, so emit them only once.
            if (m_nGenerationCount == 0)
            {
                strCode += generateNullTransformationClass(m_layer);
                strCode += generateTimeDistributedClass(m_layer);
                strCode += generateNumericalTransformationClass(m_layer);
                strCode += generateCategoricalTransformationClass(m_layer);
                strCode += generateChannelEmbeddingClass(m_layer);
            }

            m_nGenerationCount++;
        }
        else if (gen == GENERATE.DEFINITION)
        {
            bool bTimeDistributed = false;
            string strCardinalities = Utility.ToString<int>(m_layer.categorical_trans_param.cardinalities, -1, -1, "[", "]");

            strCode += " self.time_distributed = " + bTimeDistributed.ToString() + Environment.NewLine;
            strCode += " self.categorical_cardinalities = " + strCardinalities + Environment.NewLine;
            // ChannelEmbedding is the module-level class emitted by generateChannelEmbeddingClass (it is not
            // part of torch.nn), and its constructor keyword is 'time_distribute'.
            strCode += " self." + m_layer.name + " = ChannelEmbedding(state_size=" + m_layer.numeric_trans_param.state_size.ToString() + ",num_numeric=" + m_layer.numeric_trans_param.num_input.ToString() + ",num_categorical=" + m_layer.categorical_trans_param.num_input.ToString() + ",categorical_cardinalities=self.categorical_cardinalities,time_distribute=self.time_distributed)" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generate the credit comments for the tft-torch derived classes, only on the first call.
    /// </summary>
    /// <returns>The credit comment lines, or an empty string on subsequent calls.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# NullTransform, TimeDistributed, NumericalTransformation, CategoricalTransformation, ChannelEmbedding Layers:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;

        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the NullTransformation class, which stands in for an input channel with no variables.
    /// Its forward returns an empty list, so concatenating it with another channel's list is a no-op.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition.</returns>
    private string generateNullTransformationClass(LayerParameter p)
    {
        string strCode = "";

        strCode += "class NullTransformation(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self):" + Environment.NewLine;
        strCode += " super(NullTransformation, self).__init__()" + Environment.NewLine;
        strCode += Environment.NewLine;
        // ChannelEmbedding.init_weights calls init_weights on both transformations unconditionally,
        // so the null transformation needs a (no-op) init_weights as well.
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " pass" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor):" + Environment.NewLine;
        strCode += " return []" + Environment.NewLine;
        strCode += Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the TimeDistributed class, which folds the leading (batch, time) dimensions of a 3+ dim
    /// input into one, applies the wrapped module, and optionally restores the folded dimensions.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition.</returns>
    private string generateTimeDistributedClass(LayerParameter p)
    {
        string strCode = "";

        strCode += "class TimeDistributed(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, module: nn.Module, batch_first: bool = True, return_reshaped: bool = True):" + Environment.NewLine;
        strCode += " super(TimeDistributed, self).__init__()" + Environment.NewLine;
        strCode += " self.module = module" + Environment.NewLine;
        strCode += " self.batch_first = batch_first" + Environment.NewLine;
        strCode += " self.return_reshaped = return_reshaped" + Environment.NewLine;
        strCode += Environment.NewLine;
        // Delegate weight initialization to the wrapped module when it supports it.
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " if hasattr(self.module, 'init_weights'):" + Environment.NewLine;
        strCode += " self.module.init_weights()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor):" + Environment.NewLine;
        strCode += " if len(x.shape) <= 2:" + Environment.NewLine;
        strCode += " return self.module(x)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " x_reshape = x.contiguous().view(-1, x.shape[-1])" + Environment.NewLine;
        strCode += " y = self.module(x_reshape)" + Environment.NewLine;
        strCode += Environment.NewLine;
        // Only reshape when requested: with return_reshaped=False the module output is returned as-is
        // (it may be a list, e.g. from NumericalTransformation).  batch_first selects the restored layout.
        strCode += " if self.return_reshaped:" + Environment.NewLine;
        strCode += " if self.batch_first:" + Environment.NewLine;
        strCode += " y = y.contiguous().view(x.shape[0], -1, y.shape[-1])" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " y = y.contiguous().view(-1, x.shape[1], y.shape[-1])" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " return y" + Environment.NewLine;
        strCode += Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the NumericalTransformation class, which projects each numeric input column with its
    /// own nn.Linear(1, state_size) and returns the projections as a list.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition, optionally preceded by comments.</returns>
    private string generateNumericalTransformationClass(LayerParameter p)
    {
        FillerParameter wtfiller = new FillerParameter("xavier");
        FillerParameter biasfiller = new FillerParameter("constant");
        string strCode = "";

        if (m_bAddComments)
            strCode += generateNumericalTransformationClassComments();

        strCode += "class NumericalTransformation(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, num_inputs: int, state_size: int):" + Environment.NewLine;
        strCode += " super(NumericalTransformation, self).__init__()" + Environment.NewLine;
        strCode += " self.num_inputs = num_inputs" + Environment.NewLine;
        strCode += " self.state_size = state_size" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.numeric_projection_layers = nn.ModuleList()" + Environment.NewLine;
        strCode += " for i in range(num_inputs):" + Environment.NewLine;
        strCode += " self.numeric_projection_layers.append(nn.Linear(1, self.state_size))" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
        strCode += initWeights(" ", "self.numeric_projection_layers[i]", true, wtfiller, biasfiller);
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor):" + Environment.NewLine;
        strCode += " projections = []" + Environment.NewLine;
        strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
        // x[:,[i]] keeps the column dimension so each slice matches nn.Linear(1, state_size).
        strCode += " x1 = x[:,[i]]" + Environment.NewLine;
        strCode += " x2 = self.numeric_projection_layers[i](x1)" + Environment.NewLine;
        strCode += " projections.append(x2)" + Environment.NewLine;
        strCode += " return projections" + Environment.NewLine;
        strCode += Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comments placed above the NumericalTransformation class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private string generateNumericalTransformationClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The NumericalTransformation class transforms the numerical input into a set of projections for each input." + Environment.NewLine;
        strCode += "# Each input is projected using a dedicated linear layer to a vector within the state_size, which is" + Environment.NewLine;
        strCode += "# output as a list of length num_inputs that contains each embedding." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# num_inputs : int" + Environment.NewLine;
        strCode += "# The number of numerical inputs." + Environment.NewLine;
        strCode += "# state_size : int" + Environment.NewLine;
        strCode += "# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the CategoricalTransformation class, which embeds each categorical input column with
    /// its own nn.Embedding(cardinality, state_size) and returns the embeddings as a list.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition, optionally preceded by comments.</returns>
    private string generateCategoricalTransformationClass(LayerParameter p)
    {
        FillerParameter wtfiller = new FillerParameter("xavier");
        string strCode = "";

        if (m_bAddComments)
            strCode += generateCategoricalTransformationClassComments();

        strCode += "class CategoricalTransformation(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, num_inputs: int, state_size: int, cardinalities: List[int]):" + Environment.NewLine;
        strCode += " super(CategoricalTransformation, self).__init__()" + Environment.NewLine;
        strCode += " self.num_inputs = num_inputs" + Environment.NewLine;
        strCode += " self.state_size = state_size" + Environment.NewLine;
        strCode += " self.cardinalities = cardinalities" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.categorical_embedding_layers = nn.ModuleList()" + Environment.NewLine;
        strCode += " for i, cardinality in enumerate(self.cardinalities):" + Environment.NewLine;
        strCode += " self.categorical_embedding_layers.append(nn.Embedding(cardinality, self.state_size))" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " for i, cardinality in enumerate(self.cardinalities):" + Environment.NewLine;
        // Embeddings have no bias, hence bBiasTerm = false and no bias filler.
        strCode += initWeights(" ", "self.categorical_embedding_layers[i]", false, wtfiller, null);
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor):" + Environment.NewLine;
        strCode += " embeddings = []" + Environment.NewLine;
        strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
        strCode += " x1 = x[:,i]" + Environment.NewLine;
        strCode += " x2 = self.categorical_embedding_layers[i](x1)" + Environment.NewLine;
        strCode += " embeddings.append(x2)" + Environment.NewLine;
        strCode += " return embeddings" + Environment.NewLine;
        strCode += Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comments placed above the CategoricalTransformation class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private string generateCategoricalTransformationClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The CategoricalTransformation class transforms the categorical input into a set of embeddings for each input." + Environment.NewLine;
        strCode += "# Each input is projected using a dedicated embedding layer to a vector within the state_size, which is" + Environment.NewLine;
        strCode += "# output as a list of length num_inputs that contains each embedding." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# num_inputs : int" + Environment.NewLine;
        strCode += "# The number of categorical inputs." + Environment.NewLine;
        strCode += "# state_size : int" + Environment.NewLine;
        strCode += "# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;
        strCode += "# cardinalities : List[int]" + Environment.NewLine;
        strCode += "# The cardinality of each categorical input." + Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the ChannelEmbedding class, which embeds the numeric and categorical inputs of one
    /// channel and concatenates the per-variable embeddings along dim 1.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The Python class definition, optionally preceded by comments.</returns>
    private string generateChannelEmbeddingClass(LayerParameter p)
    {
        string strCode = "";

        if (m_bAddComments)
            strCode += generateChannelEmbeddingClassComments();

        strCode += "class ChannelEmbedding(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, state_size: int, num_numeric: int, num_categorical: int, categorical_cardinalities: List[int], time_distribute: Optional[bool] = False):" + Environment.NewLine;
        strCode += " super(ChannelEmbedding, self).__init__()" + Environment.NewLine;
        strCode += " self.state_size = state_size" + Environment.NewLine;
        strCode += " self.num_numeric = num_numeric" + Environment.NewLine;
        strCode += " self.num_categorical = num_categorical" + Environment.NewLine;
        strCode += " self.categorical_cardinalities = categorical_cardinalities" + Environment.NewLine;
        strCode += " self.time_distribute = time_distribute" + Environment.NewLine;
        strCode += Environment.NewLine;
        // When time distributed, the transformations return per-variable lists, so TimeDistributed must not
        // try to reshape them (return_reshaped=False); forward() restores the leading dims after merging.
        strCode += " if num_numeric > 0:" + Environment.NewLine;
        strCode += " if time_distribute:" + Environment.NewLine;
        strCode += " self.numerical_transformation = TimeDistributed(NumericalTransformation(num_inputs=self.num_numeric, state_size=self.state_size), return_reshaped=False)" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " self.numerical_transformation = NumericalTransformation(num_inputs=self.num_numeric, state_size=self.state_size)" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " self.numerical_transformation = NullTransformation()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if num_categorical > 0:" + Environment.NewLine;
        strCode += " if time_distribute:" + Environment.NewLine;
        strCode += " self.categorical_transformation = TimeDistributed(CategoricalTransformation(num_inputs=self.num_categorical, state_size=self.state_size, cardinalities=self.categorical_cardinalities), return_reshaped=False)" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " self.categorical_transformation = CategoricalTransformation(num_inputs=self.num_categorical, state_size=self.state_size, cardinalities=self.categorical_cardinalities)" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " self.categorical_transformation = NullTransformation()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " self.numerical_transformation.init_weights()" + Environment.NewLine;
        strCode += " self.categorical_transformation.init_weights()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
        strCode += " batch_shape = x_num.shape if x_num.nelement() > 0 else x_cat.shape" + Environment.NewLine;
        strCode += " processed_num = self.numerical_transformation(x_num)" + Environment.NewLine;
        strCode += " processed_cat = self.categorical_transformation(x_cat)" + Environment.NewLine;
        strCode += " merged_transformations = torch.cat(processed_num + processed_cat, dim=1)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if self.time_distribute:" + Environment.NewLine;
        strCode += " merged_transformations = merged_transformations.view(*batch_shape[:-1], -1)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " return merged_transformations" + Environment.NewLine;
        strCode += Environment.NewLine;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comments placed above the ChannelEmbedding class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private string generateChannelEmbeddingClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The ChannelEmbedding class handles the transformation/embedding of the input channel of numeric and categorical tensors." + Environment.NewLine;
        strCode += "# A NumericalTransformation class is used to process the numerical inputs, and" + Environment.NewLine;
        strCode += "# a CategoricalTransformation class is used to process the categorical inputs." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# state_size : int" + Environment.NewLine;
        strCode += "# The state size of the model, which determines the embedding dimension/width for each input variable." + Environment.NewLine;
        strCode += "# num_numeric : int" + Environment.NewLine;
        strCode += "# The number of numeric inputs." + Environment.NewLine;
        strCode += "# num_categorical : int" + Environment.NewLine;
        strCode += "# The number of categorical inputs." + Environment.NewLine;
        strCode += "# categorical_cardinalities : List[int]" + Environment.NewLine;
        strCode += "# The cardinality of each categorical input." + Environment.NewLine;
        strCode += "# time_distribute : Optional[bool] = False" + Environment.NewLine;
        strCode += "# When True, the TimeDistributed transformation is applied to each time step of the input tensor," + Environment.NewLine;
        strCode += "# and the merged embeddings are reshaped back to the leading dimensions of the input." + Environment.NewLine;

        return strCode;
    }
}
353
/// <summary>
/// Generates the PyTorch code for a Gated Linear Unit (GLU) layer.
/// </summary>
/// <remarks>
/// The GLU class is emitted at most once (guarded by a static counter that this class never resets);
/// GRNLayerInfo also calls generateGluClass, so the counter is shared between both layer types.
/// </remarks>
class GLULayerInfo : LayerInfo
{
    static int m_nGenerationCreditCount = 0;
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public GLULayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the code fragment for one section of the Python script.
    /// </summary>
    /// <param name="gen">Specifies the section being generated.</param>
    /// <returns>The generated code, or an empty string when the section does not apply.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";
        if (gen == GENERATE.CREDITS)
        {
            strCode += generateGluCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateGluClass(m_layer.glu_param.bias_term, m_layer.glu_param.weight_filler, m_layer.glu_param.bias_filler, m_bAddComments);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            strCode += " self." + m_layer.name + " = GLU(input_dim=" + m_layer.glu_param.input_dim.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generate the GLU credit comments, only on the first call.
    /// </summary>
    /// <returns>The credit comment lines, or an empty string on subsequent calls.</returns>
    public static string generateGluCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# GLU Layer:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;

        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the GLU class, only on the first call.
    /// </summary>
    /// <param name="bBiasTerm">Specifies whether the linear layers' biases are initialized.</param>
    /// <param name="wtFiller">Specifies the weight filler used by init_weights.</param>
    /// <param name="biasFiller">Specifies the bias filler used by init_weights.</param>
    /// <param name="bAddComments">Specifies whether to emit descriptive comments above the class.</param>
    /// <returns>The Python class definition, or an empty string on subsequent calls.</returns>
    public static string generateGluClass(bool bBiasTerm, FillerParameter wtFiller, FillerParameter biasFiller, bool bAddComments)
    {
        string strCode = "";

        if (m_nGenerationCount == 0)
        {
            if (bAddComments)
                strCode += generateGluClassComments();
            strCode += "class GLU(nn.Module):" + Environment.NewLine;
            strCode += " def __init__(self, input_dim: int):" + Environment.NewLine;
            strCode += " super(GLU, self).__init__()" + Environment.NewLine;
            strCode += " self.input_dim = input_dim" + Environment.NewLine;
            strCode += " self.fc1 = nn.Linear(self.input_dim, self.input_dim)" + Environment.NewLine;
            strCode += " self.fc2 = nn.Linear(self.input_dim, self.input_dim)" + Environment.NewLine;
            strCode += " self.sigmoid = nn.Sigmoid()" + Environment.NewLine;
            strCode += Environment.NewLine;
            strCode += " def init_weights(self):" + Environment.NewLine;
            strCode += initWeights(" ", "self.fc1", bBiasTerm, wtFiller, biasFiller);
            strCode += initWeights(" ", "self.fc2", bBiasTerm, wtFiller, biasFiller);
            strCode += Environment.NewLine;
            // Both the gate (fc1) and the linear projection (fc2) must see the ORIGINAL input x:
            // output = Sigmoid(x*V + c) * (x*W + b).  Do not overwrite x with fc1(x) before fc2.
            strCode += " def forward(self, x: torch.Tensor) -> torch.Tensor:" + Environment.NewLine;
            strCode += " sig = self.sigmoid(self.fc1(x))" + Environment.NewLine;
            strCode += " x = self.fc2(x)" + Environment.NewLine;
            strCode += " return torch.mul(sig, x)" + Environment.NewLine;
            strCode += Environment.NewLine;
        }

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comments placed above the GLU class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private static string generateGluClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The GLU class defines the Gated Linear Unit (GLU) layer, as described" + Environment.NewLine;
        strCode += "# in Dauphin, Yann N., et al. \"Language modeling with gated convolutional networks.\" arXiv preprint arXiv:1612.08083 (2016)." + Environment.NewLine;
        strCode += "# <https://arxiv.org/pdf/1612.08083.pdf>" + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# The output of this layer is a linear projection (X * W + b) modulated by the gates Sigmoid(X * V + c)." + Environment.NewLine;
        strCode += "# These gates multiply each element of the matrix X * W and control the information passed on in the hierarchy." + Environment.NewLine;
        strCode += "# This unit is a simplified gating mechanism for non-deterministic gates that reduce the vanishing gradient problem," + Environment.NewLine;
        strCode += "# by having linear units coupled to the gates. This retains the non-linear capabilities of the network while allowing" + Environment.NewLine;
        strCode += "# the gradient to propagate through the linear unit without scaling." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# input_dim : int" + Environment.NewLine;
        strCode += "# The embedding width/dimension of the input." + Environment.NewLine;

        return strCode;
    }
}
460
/// <summary>
/// Generates the PyTorch code for a Gated Residual Network (GRN) layer.
/// </summary>
/// <remarks>
/// The generated GRN class depends on the GLU class (emitted here via GLULayerInfo.generateGluClass)
/// and on the TimeDistributed class, which this class does not emit (ChannelEmbeddingLayerInfo does).
/// </remarks>
class GRNLayerInfo : LayerInfo
{
    static int m_nGenerationCreditCount = 0;
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter.</param>
    /// <param name="inputs">Specifies the input variables.</param>
    public GRNLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the code fragment for one section of the Python script.
    /// </summary>
    /// <param name="gen">Specifies the section being generated.</param>
    /// <returns>The generated code, or an empty string when the section does not apply.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";
        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateGrnClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            // Format the dropout ratio with the invariant culture so the generated Python always uses '.'
            // as the decimal separator; a culture such as de-DE would otherwise emit invalid Python.
            string strDropout = m_layer.grn_param.dropout_ratio.ToString(System.Globalization.CultureInfo.InvariantCulture);

            strCode += " self." + m_layer.name + " = GRN(input_dim=" + m_layer.grn_param.input_dim.ToString() + ", hidden_dim=" + m_layer.grn_param.hidden_dim.ToString() + ", output_dim=" + m_layer.grn_param.output_dim.ToString() + ", dropout=" + strDropout + ", context_dim=" + m_layer.grn_param.context_dim.ToString() + ", batch_first=" + m_layer.grn_param.batch_first.ToString() + ", activation='" + m_layer.grn_param.activation.ToString() + "')" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generate the GLU/GRN credit comments, only on the first call.
    /// </summary>
    /// <returns>The credit comment lines, or an empty string on subsequent calls.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# GLU, GRN Layers:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;

        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the GLU and GRN classes, only on the first call.
    /// </summary>
    /// <param name="p">Specifies the layer parameter whose grn_param fillers are used.</param>
    /// <returns>The Python class definitions, or an empty string on subsequent calls.</returns>
    private string generateGrnClass(LayerParameter p)
    {
        string strCode = "";

        if (m_nGenerationCount > 0)
            return strCode;

        strCode += GLULayerInfo.generateGluClass(true, p.grn_param.weight_filler, p.grn_param.bias_filler, m_bAddComments);

        if (m_bAddComments)
            strCode += generateClassComments(p);

        strCode += "class GRN(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: Optional[float] = 0.05, context_dim: Optional[int] = None, batch_first: Optional[bool] = True, activation: Optional[str] = 'ELU'):" + Environment.NewLine;
        strCode += " super(GRN, self).__init__()" + Environment.NewLine;
        strCode += " self.input_dim = input_dim" + Environment.NewLine;
        strCode += " self.hidden_dim = hidden_dim" + Environment.NewLine;
        strCode += " self.output_dim = output_dim" + Environment.NewLine;
        strCode += " self.context_dim = context_dim" + Environment.NewLine;
        strCode += " self.batch_first = batch_first" + Environment.NewLine;
        strCode += " self.dropout = dropout" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.project_residual: bool = self.input_dim != self.output_dim" + Environment.NewLine;
        strCode += " if self.project_residual:" + Environment.NewLine;
        strCode += " self.skip_layer = TimeDistributed(nn.Linear(self.input_dim, self.output_dim))" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.fc1 = TimeDistributed(nn.Linear(self.input_dim, self.hidden_dim), batch_first=batch_first)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if self.context_dim is not None:" + Environment.NewLine;
        strCode += " self.context_projection = TimeDistributed(nn.Linear(self.context_dim, self.hidden_dim, bias=False), batch_first=batch_first)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if activation == 'RELU':" + Environment.NewLine;
        strCode += " self.activation = nn.ReLU()" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " self.activation = nn.ELU()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.fc2 = TimeDistributed(nn.Linear(self.hidden_dim, self.output_dim), batch_first=batch_first)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.dropout = nn.Dropout(self.dropout)" + Environment.NewLine;
        strCode += " self.gate = TimeDistributed(GLU(self.output_dim), batch_first=batch_first)" + Environment.NewLine;
        strCode += " self.layernorm = TimeDistributed(nn.LayerNorm(self.output_dim), batch_first=batch_first)" + Environment.NewLine;
        strCode += Environment.NewLine;
        // The linear layers are wrapped in TimeDistributed, so their parameters live on '.module'.
        // skip_layer and context_projection only exist under the same conditions used in __init__, so their
        // initialization is guarded the same way; context_projection is built with bias=False.
        // NOTE(review): the guarded calls pass " " as the extra indent, mirroring the guarded loop in
        // NumericalTransformation.init_weights - confirm against LayerInfo.initWeights' indent contract.
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " if self.project_residual:" + Environment.NewLine;
        strCode += initWeights(" ", "self.skip_layer.module", true, p.grn_param.weight_filler, p.grn_param.bias_filler);
        strCode += initWeights("", "self.fc1.module", true, p.grn_param.weight_filler, p.grn_param.bias_filler);
        strCode += " if self.context_dim is not None:" + Environment.NewLine;
        strCode += initWeights(" ", "self.context_projection.module", false, p.grn_param.weight_filler, p.grn_param.bias_filler);
        strCode += initWeights("", "self.fc2.module", true, p.grn_param.weight_filler, p.grn_param.bias_filler);
        strCode += " self.gate.module.init_weights()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
        strCode += " if self.project_residual:" + Environment.NewLine;
        strCode += " residual = self.skip_layer(x)" + Environment.NewLine;
        strCode += " else:" + Environment.NewLine;
        strCode += " residual = x" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " x = self.fc1(x)" + Environment.NewLine;
        strCode += " if context is not None:" + Environment.NewLine;
        strCode += " context = self.context_projection(context)" + Environment.NewLine;
        strCode += " x = x + context" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " x = self.activation(x)" + Environment.NewLine;
        strCode += " x = self.fc2(x)" + Environment.NewLine;
        strCode += " x = self.dropout(x)" + Environment.NewLine;
        strCode += " x = self.gate(x)" + Environment.NewLine;
        strCode += " x = x + residual" + Environment.NewLine;
        strCode += " x = self.layernorm(x)" + Environment.NewLine;
        strCode += " return x" + Environment.NewLine;
        strCode += Environment.NewLine;

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comments placed above the GRN class.
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused).</param>
    /// <returns>The comment lines.</returns>
    private string generateClassComments(LayerParameter p)
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The GRN class defines the Gated Residual Network layer." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# The inputs consist of the primary input (x) and an optional context vector (c)." + Environment.NewLine;
        strCode += "# A GLU is used for controlling the extent to which the module contributes to the original input (x)," + Environment.NewLine;
        strCode += "# potentially skipping over the layer entirely as the GLU outputs could be close to zero, therefore" + Environment.NewLine;
        strCode += "# suppressing the non-linear contribution. When no context vector is used, the GRN treats the context" + Environment.NewLine;
        strCode += "# input as zero. During training, dropout is applied before the gating layer." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# input_dim : int" + Environment.NewLine;
        strCode += "# The embedding width/dimension of the input." + Environment.NewLine;
        strCode += "# hidden_dim : int" + Environment.NewLine;
        strCode += "# The intermediate embedding width." + Environment.NewLine;
        strCode += "# output_dim : int" + Environment.NewLine;
        strCode += "# The embedding width of the output." + Environment.NewLine;
        strCode += "# dropout : Optional[float]" + Environment.NewLine;
        strCode += "# The dropout ratio to use." + Environment.NewLine;
        strCode += "# context_dim : Optional[int]" + Environment.NewLine;
        strCode += "# The embedding width/dimension of the context signal expected to be fed as an auxiliary input." + Environment.NewLine;
        strCode += "# batch_first : bool" + Environment.NewLine;
        strCode += "# When true, the first dimension of the input and output is the batch size." + Environment.NewLine;
        strCode += "# activation : Optional[str]" + Environment.NewLine;
        strCode += "# The activation applied after fc1; 'RELU' selects ReLU, any other value selects ELU." + Environment.NewLine;
        return strCode;
    }
}
615
616 class VarSelNetLayerInfo : LayerInfo
617 {
// Static guards: the credits and the VarSelNet class are each emitted only on their first request.
private static int m_nGenerationCreditCount = 0;
private static int m_nGenerationCount = 0;

/// <summary>
/// The constructor.
/// </summary>
/// <param name="layer">Specifies the layer parameter.</param>
/// <param name="inputs">Specifies the input variables.</param>
public VarSelNetLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs) { }
624
/// <summary>
/// Generate the code fragment for one section of the Python script.
/// </summary>
/// <param name="gen">Specifies the section being generated.</param>
/// <returns>The generated code, or an empty string when the section does not apply.</returns>
public override string Generate(GENERATE gen)
{
    string strCode = "";
    if (gen == GENERATE.CREDITS)
    {
        strCode += generateCredits();
    }
    else if (gen == GENERATE.CLASSES)
    {
        strCode += generateVarSelNetClass(m_layer);
    }
    else if (gen == GENERATE.DEFINITION)
    {
        // Format the dropout ratio with the invariant culture so the generated Python always uses '.'
        // as the decimal separator; a culture such as de-DE would otherwise emit invalid Python.
        string strDropout = m_layer.varselnet_param.dropout_ratio.ToString(System.Globalization.CultureInfo.InvariantCulture);

        strCode += " self." + m_layer.name + " = VarSelNet(input_dim=" + m_layer.varselnet_param.input_dim.ToString() + ", num_inputs=" + m_layer.varselnet_param.num_inputs.ToString() + ", hidden_dim=" + m_layer.varselnet_param.hidden_dim.ToString() + ", dropout=" + strDropout + ", context_dim=" + m_layer.varselnet_param.context_dim.ToString() + ", batch_first=" + m_layer.varselnet_param.batch_first.ToString() + ")" + Environment.NewLine;
    }
    else if (gen == GENERATE.INITWEIGHTS)
    {
        strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
    }
    else if (gen == GENERATE.FORWARD)
    {
        strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
    }

    return strCode;
}
649
650 public static string generateCredits()
651 {
652 string strCode = "";
653
654 if (m_nGenerationCreditCount > 0)
655 return strCode;
656
657 strCode += "# VarSetNet Layer:" + Environment.NewLine;
658 strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
659 strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
660
661 m_nGenerationCreditCount++;
662
663 return strCode;
664 }
665
666 private string generateVarSelNetClass(LayerParameter p)
667 {
668 string strCode = "";
669
670 if (m_nGenerationCount > 0)
671 return strCode;
672
673 if (m_bAddComments)
674 strCode += generateClassComments();
675
676 strCode += "class VarSelNet(nn.Module):" + Environment.NewLine;
677 strCode += " def __init__(self, input_dim: int, num_inputs: int, hidden_dim: int, dropout: float, context_dim: Optional[int] = None, batch_first: Optional[bool] = True):" + Environment.NewLine;
678 strCode += " super(VarSelNet, self).__init__()" + Environment.NewLine;
679 strCode += Environment.NewLine;
680 strCode += " self.hidden_dim = hidden_dim" + Environment.NewLine;
681 strCode += " self.input_dim = input_dim" + Environment.NewLine;
682 strCode += " self.num_inputs = num_inputs" + Environment.NewLine;
683 strCode += " self.dropout = dropout" + Environment.NewLine;
684 strCode += " self.context_dim = context_dim" + Environment.NewLine;
685 strCode += Environment.NewLine;
686 strCode += " self.flattened_grn = GRN(input_dim=self.num_inputs * self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.num_inputs, dropout=self.dropout, context_dim=self.context_dim, batch_first=batch_first)" + Environment.NewLine;
687 strCode += " self.softmax = nn.Softmax(dim=1)" + Environment.NewLine;
688 strCode += Environment.NewLine;
689 strCode += " self.single_variable_grns = nn.ModuleList()" + Environment.NewLine;
690 strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
691 strCode += " self.single_variable_grns.append(GRN(input_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.hidden_dim, dropout=self.dropout, batch_first=batch_first))" + Environment.NewLine;
692 strCode += Environment.NewLine;
693 strCode += " def init_weights(self):" + Environment.NewLine;
694 strCode += " self.flattened_grn.init_weights()" + Environment.NewLine;
695 strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
696 strCode += " self.single_variable_grns[i].init_weights()" + Environment.NewLine;
697 strCode += Environment.NewLine;
698 strCode += " def forward(self, flattened_embedding: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
699 strCode += " sparse_weights = self.flattened_grn(flattened_embedding, context)" + Environment.NewLine;
700 strCode += " sparse_weights = self.softmax(sparse_weights)" + Environment.NewLine;
701 strCode += Environment.NewLine;
702 strCode += " processed_inputs = []" + Environment.NewLine;
703 strCode += " for i in range(self.num_inputs):" + Environment.NewLine;
704 strCode += "# processed_inputs.append(self.single_variable_grns[i](flattened_embedding[..., i * self.input_dim:(i + 1) * self.input_dim]))" + Environment.NewLine;
705 strCode += " processed_inputs.append(self.single_variable_grns[i](flattened_embedding[:, i * self.input_dim:(i + 1) * self.input_dim]))" + Environment.NewLine;
706 strCode += " processed_inputs = torch.stack(processed_inputs, dim=1)" + Environment.NewLine;
707 strCode += Environment.NewLine;
708 strCode += " outputs = processed_inputs * sparse_weights.transpose(1, 2)" + Environment.NewLine;
709 strCode += " outputs = outputs.sum(axis=-1)" + Environment.NewLine;
710 strCode += Environment.NewLine;
711 strCode += " return outputs, sparse_weights" + Environment.NewLine;
712 strCode += Environment.NewLine;
713
714 m_nGenerationCount++;
715
716 return strCode;
717 }
718
719 private string generateClassComments()
720 {
721 string strCode = "";
722
723 strCode += LayerInfo.generateBar();
724 strCode += "# The VarSelNet class handles the fact that the relevant and specific contribution of each input." + Environment.NewLine;
725 strCode += "# variable to the output is unknown. This class enables instance-wise variable selection, and is applied." + Environment.NewLine;
726 strCode += "# to both the static covariates and time-dependent covariates. In addition to providing insights int which" + Environment.NewLine;
727 strCode += "# variables are most important, this class also enables the model remove any unecessary noisy inputs that" + Environment.NewLine;
728 strCode += "# could negatively impact performance." + Environment.NewLine;
729 strCode += "#" + Environment.NewLine;
730 strCode += "# Parameters" + Environment.NewLine;
731 strCode += "# ----------" + Environment.NewLine;
732 strCode += "# input_dim : int" + Environment.NewLine;
733 strCode += "# The attribute/embedding dimension of the input, associated with the 'state_size' of the model.." + Environment.NewLine;
734 strCode += "# num_input : int" + Environment.NewLine;
735 strCode += "# The number of input variables, including both numeric and categorical inputs." + Environment.NewLine;
736 strCode += "# hidden_dim : int" + Environment.NewLine;
737 strCode += "# The embedding width of the output." + Environment.NewLine;
738 strCode += "# dropout : float" + Environment.NewLine;
739 strCode += "# The dropout rate associated with the 'GRN' classes." + Environment.NewLine;
740 strCode += "# context_dim : int" + Environment.NewLine;
741 strCode += "# The embedding width of the context signal expected to be fed as an auxilary input." + Environment.NewLine;
742 strCode += "# batch_first : bool" + Environment.NewLine;
743 strCode += "# When True, the first dimension of the input and output tensors represent the batch size." + Environment.NewLine;
744
745 return strCode;
746 }
747 }
748
/// <summary>
/// The GateAddNormLayerInfo generates the PyTorch code for a GateAddNorm layer.
/// </summary>
/// <remarks>
/// The credits and the GateAddNorm class text are each emitted only once per process, as tracked by the
/// static generation counters below.
/// NOTE(review): the generated class uses the GLU and TimeDistributed Python classes, which this layer does
/// not generate itself - confirm they are always emitted by other layers when a GateAddNorm is present.
/// </remarks>
class GateAddNormLayerInfo : LayerInfo
{
    /// <summary>Number of times the credits were requested; credits are only emitted on the first request.</summary>
    static int m_nGenerationCreditCount = 0;
    /// <summary>Number of times the GateAddNorm class text was emitted; it is only emitted once.</summary>
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter of the GateAddNorm layer.</param>
    /// <param name="inputs">Specifies the input variables of the layer.</param>
    public GateAddNormLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the Python code for the requested section.
    /// </summary>
    /// <param name="gen">Specifies the section to generate (credits, classes, definition, init-weights or forward).</param>
    /// <returns>The generated code, or an empty string when this layer contributes nothing to the section.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateGateAddNormClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            // Instantiate the GateAddNorm class generated by this layer (not a bare GLU), sized by the
            // layer's GLU input dimension. The dropout argument is left at its default (None).
            strCode += " self." + m_layer.name + " = GateAddNorm(input_dim=" + m_layer.glu_param.input_dim.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generate the credits for the GateAddNorm layer (only on the first call).
    /// </summary>
    /// <returns>The credit comment lines, or an empty string when they were already generated.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# GateAddNorm Layer:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;
        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the Python GateAddNorm class (only on the first call).
    /// </summary>
    /// <param name="p">Specifies the layer parameter (unused; the class text is parameter independent).</param>
    /// <returns>The class text, or an empty string when the class was already generated.</returns>
    private string generateGateAddNormClass(LayerParameter p)
    {
        string strCode = "";

        if (m_nGenerationCount > 0)
            return strCode;

        if (m_bAddComments)
            strCode += generateClassComments();

        strCode += "class GateAddNorm(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, input_dim: int, dropout: Optional[float] = None):" + Environment.NewLine;
        strCode += " super(GateAddNorm, self).__init__()" + Environment.NewLine;
        strCode += " self.input_dim = input_dim" + Environment.NewLine;
        strCode += " self.dropout = dropout" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if self.dropout is not None:" + Environment.NewLine;
        strCode += " self.dropout_layer = nn.Dropout(p=self.dropout)" + Environment.NewLine;
        strCode += " self.gate = TimeDistributed(GLU(self.input_dim), batch_first=True)" + Environment.NewLine;
        strCode += " self.layernorm = TimeDistributed(nn.LayerNorm(self.input_dim), batch_first=True)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += " self.gate.init_weights()" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:" + Environment.NewLine;
        strCode += " if self.dropout is not None:" + Environment.NewLine;
        strCode += " x = self.dropout_layer(x)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " x = self.gate(x)" + Environment.NewLine;
        strCode += " if residual is not None:" + Environment.NewLine;
        strCode += " x = x + residual" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " x = self.layernorm(x)" + Environment.NewLine;
        strCode += " return x" + Environment.NewLine;
        strCode += Environment.NewLine;

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comment block placed above the Python GateAddNorm class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private string generateClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The GateAddNorm class performs an optional dropout, a GLU gate, an optional residual connection and layer normalization." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# input_dim : int" + Environment.NewLine;
        strCode += "# The attribute/embedding dimension of the input, associated with the 'state_size' of the model." + Environment.NewLine;
        strCode += "# dropout : Optional[float]" + Environment.NewLine;
        strCode += "# The dropout rate applied to the input before the gate; when None, no dropout is applied." + Environment.NewLine;

        return strCode;
    }
}
856
/// <summary>
/// The MultiheadAttentionInterpLayerInfo generates the PyTorch code for an interpretable multi-head attention layer.
/// </summary>
/// <remarks>
/// The credits and the MultiheadAttentionInterp class text are each emitted only once per process, as tracked
/// by the static generation counters below.
/// </remarks>
class MultiheadAttentionInterpLayerInfo : LayerInfo
{
    /// <summary>Number of times the credits were requested; credits are only emitted on the first request.</summary>
    static int m_nGenerationCreditCount = 0;
    /// <summary>Number of times the MultiheadAttentionInterp class text was emitted; it is only emitted once.</summary>
    static int m_nGenerationCount = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter of the MultiheadAttentionInterp layer.</param>
    /// <param name="inputs">Specifies the input variables of the layer.</param>
    public MultiheadAttentionInterpLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the Python code for the requested section.
    /// </summary>
    /// <param name="gen">Specifies the section to generate (credits, classes, definition, init-weights or forward).</param>
    /// <returns>The generated code, or an empty string when this layer contributes nothing to the section.</returns>
    public override string Generate(GENERATE gen)
    {
        string strCode = "";

        if (gen == GENERATE.CREDITS)
        {
            strCode += generateCredits();
        }
        else if (gen == GENERATE.CLASSES)
        {
            strCode += generateMultiheadAttentionInterpClass(m_layer);
        }
        else if (gen == GENERATE.DEFINITION)
        {
            strCode += " self." + m_layer.name + " = MultiheadAttentionInterp(embed_dim=" + m_layer.multihead_attention_interp_param.embed_dim.ToString() + ", num_heads=" + m_layer.multihead_attention_interp_param.num_heads.ToString() + ")" + Environment.NewLine;
        }
        else if (gen == GENERATE.INITWEIGHTS)
        {
            strCode += " self." + m_layer.name + ".init_weights()" + Environment.NewLine;
        }
        else if (gen == GENERATE.FORWARD)
        {
            strCode += " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;
        }

        return strCode;
    }

    /// <summary>
    /// Generate the credits for the MultiheadAttentionInterp layer (only on the first call).
    /// </summary>
    /// <returns>The credit comment lines, or an empty string when they were already generated.</returns>
    public static string generateCredits()
    {
        string strCode = "";

        if (m_nGenerationCreditCount > 0)
            return strCode;

        strCode += "# MultiheadAttentionInterp Layer:" + Environment.NewLine;
        strCode += "# original code <https://github.com/PlaytikaOSS/tft-torch/tree/main>" + Environment.NewLine;
        strCode += "# license: (MIT) <https://github.com/PlaytikaOSS/tft-torch/blob/main/LICENSE>" + Environment.NewLine;

        m_nGenerationCreditCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the Python MultiheadAttentionInterp class (only on the first call).
    /// </summary>
    /// <param name="p">Specifies the layer parameter (the fillers used by init_weights are read from m_layer).</param>
    /// <returns>The class text, or an empty string when the class was already generated.</returns>
    private string generateMultiheadAttentionInterpClass(LayerParameter p)
    {
        string strCode = "";

        if (m_nGenerationCount > 0)
            return strCode;

        if (m_bAddComments)
            strCode += generateClassComments();

        strCode += "class MultiheadAttentionInterp(nn.Module):" + Environment.NewLine;
        strCode += " def __init__(self, embed_dim: int, num_heads: int):" + Environment.NewLine;
        strCode += " super(MultiheadAttentionInterp, self).__init__()" + Environment.NewLine;
        strCode += " self.d_model = embed_dim" + Environment.NewLine;
        strCode += " self.num_heads = num_heads" + Environment.NewLine;
        strCode += " self.all_heads_dim = embed_dim * num_heads" + Environment.NewLine;
        strCode += Environment.NewLine;
        // The embedding width is stored as 'd_model'; referencing 'self.embed_dim' (never assigned) or the
        // misspelled 'self.ebmed_dim' raised an AttributeError when the module was constructed.
        strCode += " self.w_q = nn.Linear(self.d_model, self.all_heads_dim)" + Environment.NewLine;
        strCode += " self.w_k = nn.Linear(self.d_model, self.all_heads_dim)" + Environment.NewLine;
        strCode += " self.w_v = nn.Linear(self.d_model, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " self.out = nn.Linear(self.d_model, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def init_weights(self):" + Environment.NewLine;
        strCode += initWeights(" ", "self.w_q", true, m_layer.multihead_attention_interp_param.weight_filler, m_layer.multihead_attention_interp_param.bias_filler);
        strCode += initWeights(" ", "self.w_k", true, m_layer.multihead_attention_interp_param.weight_filler, m_layer.multihead_attention_interp_param.bias_filler);
        strCode += initWeights(" ", "self.w_v", true, m_layer.multihead_attention_interp_param.weight_filler, m_layer.multihead_attention_interp_param.bias_filler);
        strCode += initWeights(" ", "self.out", true, m_layer.multihead_attention_interp_param.weight_filler, m_layer.multihead_attention_interp_param.bias_filler);
        strCode += Environment.NewLine;
        strCode += " def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:" + Environment.NewLine;
        strCode += " num_samples = q.size(0)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " q_proj = self.w_q(q).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        strCode += " k_proj = self.w_k(k).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        strCode += " v_proj = self.w_v(v).repeat(1, 1, self.num_heads).view(num_samples, -1, self.num_heads, self.d_model)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " q_proj = q_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += " k_proj = k_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += " v_proj = v_proj.transpose(1, 2)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " attn_outputs_all_heads, attn_scores_all_heads = self.attention(q_proj, k_proj, v_proj, mask)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " attn_scores = attn_scores_all_heads.mean(dim=1)" + Environment.NewLine;
        strCode += " attn_outputs = attn_outputs_all_heads.mean(dim=1)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " attn_outputs = self.out(attn_outputs)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " return attn_outputs, attn_scores" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " def attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:" + Environment.NewLine;
        // Scaled dot-product, as in the credited tft-torch original (scaled by sqrt(all_heads_dim)).
        strCode += " scores = torch.matmul(q, k.transpose(-2, -1)) / (self.all_heads_dim ** 0.5)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " if mask is not None:" + Environment.NewLine;
        strCode += " scores = scores.masked_fill(mask, -1e9)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " scores = F.softmax(scores, dim=-1)" + Environment.NewLine;
        strCode += " outputs = torch.matmul(scores, v)" + Environment.NewLine;
        strCode += Environment.NewLine;
        strCode += " return outputs, scores" + Environment.NewLine;
        strCode += Environment.NewLine;

        m_nGenerationCount++;

        return strCode;
    }

    /// <summary>
    /// Generate the descriptive comment block placed above the Python MultiheadAttentionInterp class.
    /// </summary>
    /// <returns>The comment lines.</returns>
    private string generateClassComments()
    {
        string strCode = "";

        strCode += LayerInfo.generateBar();
        strCode += "# The MultiheadAttentionInterp class learns long-term relationships across different time-steps." + Environment.NewLine;
        strCode += "# A multi-head attention is modified to enhance explainability. With traditional multi-head attention" + Environment.NewLine;
        strCode += "# the 'values' signal is shared for all heads and additive aggregation is employed across all heads." + Environment.NewLine;
        strCode += "# However, according to the paper, each head can learn different temporal patterns, while attending to a common set" + Environment.NewLine;
        strCode += "# of input features which can be interpreted as a simple ensemble over attention weights in a combined matrix, which" + Environment.NewLine;
        strCode += "# compared to the original multi-head attention matrix, yields an increased representational capacity." + Environment.NewLine;
        strCode += "#" + Environment.NewLine;
        strCode += "# Parameters" + Environment.NewLine;
        strCode += "# ----------" + Environment.NewLine;
        strCode += "# embed_dim : int" + Environment.NewLine;
        strCode += "# The dimensions associated with the 'state_size' of the model, corresponding to the input and output." + Environment.NewLine;
        strCode += "# num_heads : int" + Environment.NewLine;
        strCode += "# The number of heads used by the multi-head attention component." + Environment.NewLine;

        return strCode;
    }
}
995
/// <summary>
/// The ReshapeTemporalLayerInfo generates the PyTorch code for a ReshapeTemporal layer, which comes in a
/// BEFORE and an AFTER flavor selected by reshape_temporal_param.mode.
/// </summary>
/// <remarks>
/// Each flavor's placeholder class is emitted at most once per process; the per-flavor counters below
/// record how many times the class was requested.
/// </remarks>
class ReshapeTemporalLayerInfo : LayerInfo
{
    /// <summary>Number of times the ReshapeTemporalBefore class was requested.</summary>
    static int m_nGenerationCountBefore = 0;
    /// <summary>Number of times the ReshapeTemporalAfter class was requested.</summary>
    static int m_nGenerationCountAfter = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="layer">Specifies the layer parameter of the ReshapeTemporal layer.</param>
    /// <param name="inputs">Specifies the input variables of the layer.</param>
    public ReshapeTemporalLayerInfo(LayerParameter layer, VariableCollection inputs) : base(layer, inputs)
    {
    }

    /// <summary>
    /// Generate the Python code for the requested section.
    /// </summary>
    /// <param name="gen">Specifies the section to generate.</param>
    /// <returns>The generated code; empty for sections this layer does not contribute to (e.g. init-weights).</returns>
    public override string Generate(GENERATE gen)
    {
        if (gen == GENERATE.CLASSES)
            return generateReshapeTemporalClass(m_layer);

        if (gen == GENERATE.DEFINITION)
            return generateDefinition(m_layer);

        if (gen == GENERATE.FORWARD)
            return " " + m_outputs.AsText + " = self." + m_layer.name + "(" + m_inputs.AsText + ")" + Environment.NewLine;

        // The reshape modules own no weights, so INITWEIGHTS (and any other section) produces nothing.
        return "";
    }

    /// <summary>
    /// Generate the line that instantiates the Before or After module, depending on the layer mode.
    /// </summary>
    /// <param name="p">Specifies the layer parameter.</param>
    /// <returns>The instantiation line, or an empty string when the mode is neither BEFORE nor AFTER.</returns>
    private static string generateDefinition(LayerParameter p)
    {
        string strClassName;

        if (p.reshape_temporal_param.mode == param.tft.ReshapeTemporalParameter.MODE.BEFORE)
            strClassName = "ReshapeTemporalBefore";
        else if (p.reshape_temporal_param.mode == param.tft.ReshapeTemporalParameter.MODE.AFTER)
            strClassName = "ReshapeTemporalAfter";
        else
            return "";

        return " self." + p.name + " = " + strClassName + "()" + Environment.NewLine;
    }

    /// <summary>
    /// Generate the placeholder Python class for the layer's flavor, only on the first request for that flavor.
    /// Any mode other than BEFORE is treated as the AFTER flavor.
    /// </summary>
    /// <param name="p">Specifies the layer parameter.</param>
    /// <returns>The class text, or an empty string when that flavor was already requested.</returns>
    private string generateReshapeTemporalClass(LayerParameter p)
    {
        bool bBefore = (p.reshape_temporal_param.mode == param.tft.ReshapeTemporalParameter.MODE.BEFORE);

        // Post-increment yields the count prior to this request; only the very first request emits text.
        int nPriorRequests = bBefore ? m_nGenerationCountBefore++ : m_nGenerationCountAfter++;
        if (nPriorRequests > 0)
            return "";

        return buildPlaceholderClass(bBefore ? "ReshapeTemporalBefore" : "ReshapeTemporalAfter");
    }

    /// <summary>
    /// Build the text of a placeholder nn.Module whose forward body is 'pass'.
    /// </summary>
    /// <param name="strClassName">Specifies the Python class name.</param>
    /// <returns>The class text.</returns>
    private static string buildPlaceholderClass(string strClassName)
    {
        string nl = Environment.NewLine;

        return "class " + strClassName + "(nn.Module):" + nl +
               " def __init__(self):" + nl +
               " super(" + strClassName + ", self).__init__()" + nl +
               nl +
               " def forward(self, x: torch.Tensor) -> torch.Tensor:" + nl +
               " pass" + nl +
               nl;
    }
}
1064}
The Utility class provides general utility functions.
Definition: Utility.cs:35
Specifies the filler parameters used to create each Filler.
Specifies the base parameter for all layers.
GrnParameter grn_param
Returns the parameter set when initialized with LayerType.GLU
ReshapeTemporalParameter reshape_temporal_param
Returns the parameter set when initialized with LayerType.RESHAPE_TEMPORAL
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12