using System.Collections.Generic;
using System.Net.Http.Headers;
int m_nNumRepeatCount;
int m_nForcedRepeatCount;
List<int> m_rgShape = new List<int>(4);
Dictionary<string, List<int>> m_rgShapes = new Dictionary<string, List<int>>();
Blob<T> m_blobTimeDistributedContext;
dispose(ref m_blobTimeDistributedContext);
if (m_blobTimeDistributedContext != null)
    col.Add(m_blobTimeDistributedContext);
private void replicate_along_time_fwd(Blob<T> bBtm, Blob<T> bTop, int nTimeSteps, bool bTemporalRepeat, bool bReshapeOnly = false)
    // ... batch-first branch: shape (num_samples, time_steps, inner)
    m_rgShape.Add(m_nNumSamples);
    m_rgShape.Add(nTimeSteps);
    m_rgShape.Add(bBtm.shape(1));
    // ...
    int nInnerNum = bBtm.count(1);
    for (int i = 0; i < nTimeSteps; i++)
    // ... time-first branch: shape (time_steps, num_samples, inner)
    m_rgShape.Add(nTimeSteps);
    m_rgShape.Add(m_nNumSamples);
    m_rgShape.Add(bBtm.shape(1));
    // ...
    int nInnerNum = bBtm.count(1);
    for (int i = 0; i < nTimeSteps; i++)
    // ...
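Conceptually, replicate_along_time_fwd tiles a static context so that every time step sees the same vector; the two branches above differ only in whether the batch or the time axis comes first. Below is a minimal CPU sketch of the batch-first case, assuming plain float arrays in place of blob memory (the helper ReplicateAlongTime is hypothetical and not part of the layer, which performs this work on the GPU):

using System;

// Hypothetical CPU sketch of batch-first replication: 'src' holds nNum
// samples of nDim values (bBtm.count(1) in the excerpt); 'dst' receives
// an (nNum, nTimeSteps, nDim) layout in which each sample's context is
// repeated nTimeSteps times.
static float[] ReplicateAlongTime(float[] src, int nNum, int nDim, int nTimeSteps)
{
    float[] dst = new float[nNum * nTimeSteps * nDim];

    for (int n = 0; n < nNum; n++)
    {
        for (int t = 0; t < nTimeSteps; t++)
            Array.Copy(src, n * nDim, dst, (n * nTimeSteps + t) * nDim, nDim);
    }

    return dst;
}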
private void replicate_along_time_bwd(Blob<T> bBtm, Blob<T> bTop, int nTimeSteps, bool bTemporalRepeat)
    // ...
    int nInnerNum = bBtm.count(1);
private void stack_time_steps_along_batch_fwd(Blob<T> bBtm, Blob<T> bTop, bool bResizeOnly = false)
    // ...
    m_rgShape.Add(bBtm.count(2));
private void stack_time_steps_along_batch_bwd(Blob<T> bBtm, Blob<T> bTop)
    // ...
    bBtm.CopyFrom(bTop, true, false, 0, true);  // the diff flows straight back through the reshape
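Because blob memory is row-major, stacking the time steps of an (N, T, D) blob along the batch axis to get (N*T, D) changes only the shape descriptor, which is why the backward pass above can simply copy the diff back. A small sketch of the shape bookkeeping, using a hypothetical helper over a List<int> shape (not the layer's own code):

using System.Collections.Generic;

// Hypothetical helper: fold the time axis of an (N, T, ...) shape into the
// batch axis.  The inner size is the product of all axes from index 2 on,
// matching bBtm.count(2) in the excerpt; the memory layout is unchanged.
static List<int> StackTimeStepsAlongBatch(List<int> rgShape)
{
    int nNum = rgShape[0];
    int nSteps = rgShape[1];
    int nInner = 1;

    for (int i = 2; i < rgShape.Count; i++)
        nInner *= rgShape[i];

    return new List<int> { nNum * nSteps, nInner };
}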
m_nNumSamples = colBottom[0].num;
// ...
if (m_nForcedRepeatCount >= 0)
    m_nNumRepeatCount = m_nForcedRepeatCount;
// ... else:
    m_nNumRepeatCount = colBottom[0].shape(1);
// ...
if (colBottom.Count > 1)
// ...
    m_blobWork.Name = "work";
// ...
    if (m_nNumRepeatCount > 0)
// ...
        m_blobTimeDistributedContext.Name = m_param.name + ".tdctx";
        replicate_along_time_fwd(colBottom[1], m_blobTimeDistributedContext, m_nNumRepeatCount, m_nForcedRepeatCount < 0, true);
        stack_time_steps_along_batch_fwd(m_blobTimeDistributedContext, colTop[1], true);
// ... else:
        stack_time_steps_along_batch_fwd(colBottom[1], colTop[1], true);
// ...
    colTop[1].SetParameter("num_samples", m_nNumSamples);
    colTop[1].SetParameter("num_temporal_steps", m_nNumRepeatCount);
    colTop[1].SetParameter("forced_temporal_steps", m_nForcedRepeatCount);
// ...
stack_time_steps_along_batch_fwd(colBottom[0], colTop[0], true);
colTop[0].SetParameter("num_samples", m_nNumSamples);
colTop[0].SetParameter("num_temporal_steps", colBottom[0].shape(1));
m_nNumSamples = (int)colBottom[0].GetParameter("num_samples");
int nTemporalSteps = (int)colBottom[0].GetParameter("num_temporal_steps");
// ...
int nCount = colBottom[0].count();
int nDim = m_nNumSamples * nTemporalSteps;
// ...
m_rgShape.Add(m_nNumSamples);
m_rgShape.Add(nTemporalSteps);
m_rgShape.Add(nCount / nDim);
// ...
m_log.CHECK_GT(colTop.Count, nIdx, "There must be at least " + (nIdx + 1).ToString() + " tops for the enable clip output!");
// ...
m_rgShape.Add(m_nNumSamples);
m_rgShape.Add(nTemporalSteps);
colTop[nIdx].Reshape(m_rgShape);
// ...
if (colBottom.Count > 1)
// ...
    m_nNumRepeatCount = (int)colBottom[1].GetParameter("num_temporal_steps");
    m_nForcedRepeatCount = (int)colBottom[1].GetParameter("forced_temporal_steps");
// ...
    m_log.CHECK_GT(colTop.Count, nIdx, "There must be at least " + (nIdx + 1).ToString() + " tops for the enable clip output!");
    nCount = colBottom[1].count();
// ...
    if (m_nForcedRepeatCount >= 0)
// ...
        m_rgShape.Add(m_nNumSamples);
        m_rgShape.Add(nCount / nDim);
// ... else:
        m_rgShape.Add(m_nNumSamples);
        if (m_nNumRepeatCount > 0)
            m_rgShape.Add(m_nNumRepeatCount);
        m_rgShape.Add(nCount / nDim);
// ...
    colTop[nIdx].Reshape(m_rgShape);
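The SetParameter/GetParameter pairs above are how a BEFORE-mode instance records the original (num_samples, num_temporal_steps) geometry on its tops so that a downstream AFTER-mode instance can undo the stacking. A hedged illustration of the round trip with made-up sizes (the parameter names are the ones used in the excerpt):

// Hypothetical round trip on a Blob<float> named 'blob'.
// The BEFORE-mode layer tags its output with the pre-stacking geometry...
blob.SetParameter("num_samples", 16);
blob.SetParameter("num_temporal_steps", 30);

// ...and the paired AFTER-mode layer reads the tags back to rebuild the
// (num_samples, num_temporal_steps, inner) shape from the stacked data.
int nNumSamples = (int)blob.GetParameter("num_samples");            // 16
int nTemporalSteps = (int)blob.GetParameter("num_temporal_steps");  // 30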
if (colBottom.Count > 1)
// ...
    if (m_nNumRepeatCount > 0)
// ...
        replicate_along_time_fwd(colBottom[1], m_blobTimeDistributedContext, m_nNumRepeatCount, m_nForcedRepeatCount < 0, true);
// ...
        stack_time_steps_along_batch_fwd(m_blobTimeDistributedContext, colTop[1], true);
// ... else:
        stack_time_steps_along_batch_fwd(colBottom[1], colTop[1], true);
// ...
stack_time_steps_along_batch_fwd(colBottom[0], colTop[0], true);
// ...
int nTemporalSteps = (int)colBottom[0].GetParameter("num_temporal_steps");
int nCount = colBottom[0].count();
int nDim = m_nNumSamples * nTemporalSteps;
// ...
m_rgShape.Add(m_nNumSamples);
m_rgShape.Add(nTemporalSteps);
m_rgShape.Add(nCount / nDim);
// ...
m_log.CHECK_GT(colTop.Count, nIdx, "There must be at least " + (nIdx + 1).ToString() + " tops for the enable clip output!");
// ...
m_rgShape.Add(m_nNumSamples);
m_rgShape.Add(nTemporalSteps);
colTop[nIdx].Reshape(m_rgShape);
// ...
if (colBottom.Count > 1)
// ...
    m_nNumRepeatCount = (int)colBottom[1].GetParameter("num_temporal_steps");
    m_nForcedRepeatCount = (int)colBottom[1].GetParameter("forced_temporal_steps");
// ...
    m_log.CHECK_GT(colTop.Count, nIdx, "There must be at least " + (nIdx + 1).ToString() + " tops for the enable clip output!");
    nCount = colBottom[1].count();
// ...
    if (m_nForcedRepeatCount >= 0)
// ...
        m_rgShape.Add(m_nNumSamples);
        m_rgShape.Add(nCount / nDim);
// ... else:
        m_rgShape.Add(m_nNumSamples);
        if (m_nNumRepeatCount > 0)
            m_rgShape.Add(m_nNumRepeatCount);
        m_rgShape.Add(nCount / nDim);
// ...
    colTop[nIdx].Reshape(m_rgShape);
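For concreteness, the shape arithmetic above works out as follows with hypothetical sizes:

// Hypothetical sizes: 16 samples, 30 time steps, a bottom count of 30720.
int nDim = 16 * 30;          // 480 = num_samples * num_temporal_steps
int nInner = 30720 / nDim;   // 64 values per step -> top reshaped to (16, 30, 64)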
if (colBottom.Count > 1)
// ...
    if (m_nNumRepeatCount > 0)
// ...
        replicate_along_time_fwd(colBottom[1], m_blobTimeDistributedContext, m_nNumRepeatCount, m_nForcedRepeatCount < 0);
        stack_time_steps_along_batch_fwd(m_blobTimeDistributedContext, colTop[1]);
// ... else:
        stack_time_steps_along_batch_fwd(colBottom[1], colTop[1]);
// ...
stack_time_steps_along_batch_fwd(colBottom[0], colTop[0]);
// ...
colTop[0].CopyFrom(colBottom[0], false, false, 0, true);
// ...
colTop[nIdx].CopyFrom(colBottom[1], false, false, 0, true);
stack_time_steps_along_batch_bwd(colBottom[0], colTop[0]);
// ...
if (colBottom.Count > 1)
// ...
    if (m_nNumRepeatCount > 0)
// ...
        stack_time_steps_along_batch_bwd(m_blobTimeDistributedContext, colTop[1]);
        replicate_along_time_bwd(colBottom[1], m_blobTimeDistributedContext, m_nNumRepeatCount, m_nForcedRepeatCount < 0);
// ... else:
        stack_time_steps_along_batch_bwd(colBottom[1], colTop[1]);
// ...
colBottom[0].CopyFrom(colTop[0], true, false, 0, true);
// ...
if (colBottom.Count > 1)
    colBottom[1].CopyFrom(colTop[nIdx], true, false, 0, true);
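Where the forward pass tiled the static context across the time steps, the backward pass has to fold the per-step diff slices back into a single (N, D) gradient; the standard gradient of a tiling op is the sum over the replicas. A plain-CPU sketch under that assumption (hypothetical arrays; the excerpt does not show whether the layer sums on the GPU or handles this differently):

// Hypothetical CPU sketch: accumulate the diff of each of the nTimeSteps
// replicas back into one (nNum, nDim) bottom diff -- the inverse of the
// batch-first tiling shown in replicate_along_time_fwd.
static float[] ReplicateAlongTimeBackward(float[] topDiff, int nNum, int nDim, int nTimeSteps)
{
    float[] btmDiff = new float[nNum * nDim];

    for (int n = 0; n < nNum; n++)
        for (int t = 0; t < nTimeSteps; t++)
            for (int d = 0; d < nDim; d++)
                btmDiff[n * nDim + d] += topDiff[(n * nTimeSteps + t) * nDim + d];

    return btmDiff;
}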
The Log class provides general output in text form.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
void SetData(double df)
Set all blob data to the value specified.
int Count
Returns the number of items in the collection.
void Clear(bool bDispose=false)
Remove all items from the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
void CopyFrom(BlobCollection< T > bSrc, bool bCopyDiff=false)
Copy the data or diff from another BlobCollection into this one.
The Blob is the main holder of data that moves through the Layers of the Net.
long mutable_gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
int count()
Returns the total number of items in the Blob.
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
string Name
Get/set the name of the Blob.
long gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
An interface for the units of computation which can be composed into a Net.
Log m_log
Specifies the Log for output.
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
BlobCollection< T > m_colInternalBlobs
Specifies internal blobs used by the layer.
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
LayerParameter.LayerType m_type
Specifies the Layer type.
The ReshapeTemporalLayer reshapes temporal data before and after processing by layers such as the Variable Selection Network.
override int MinBottomBlobs
Returns the min number of required bottom (input) Blobs: temporal_rep
override int MinTopBlobs
Returns the min number of required top (output) Blobs: temporal_selection_output
override void setup_internal_blobs(BlobCollection< T > col)
Derivative layers should add all internal blobs to the 'col' provided.
ReshapeTemporalLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
The constructor.
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Computes the error gradient w.r.t. the stacked embedding numeric and categorical value inputs.
override int MaxTopBlobs
Returns the max number of required top (output) Blobs: temporal_selection_output,...
override void dispose()
Releases all GPU and host resources used by the Layer.
override int MaxBottomBlobs
Returns the max number of required bottom (input) Blobs: temporal_rep, static_selection
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Forward computation.
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the top (output) blobs.
Specifies the base parameter for all layers.
string name
Specifies the name of this LayerParameter.
ReshapeTemporalParameter reshape_temporal_param
Returns the parameter set when initialized with LayerType.RESHAPE_TEMPORAL
LayerType
Specifies the layer type.
Specifies the parameters for the ReshapeTemporalLayer.
bool enable_weight_output
Specifies to output the weights for the data output in the AFTER mode.
int forced_repeat_count
Specifies the forced repeat steps for bottom(1). A value of -1 specifies to use the temporal axis as the ...
bool enable_clip_output
Specifies to output the clip for the data output in the AFTER mode.
MODE
Defines the mode of operation.
ReshapeTemporalParameter()
Constructor for the parameter.
MODE mode
Specifies the mode of operation.
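Putting the parameter fields together, a hedged construction sketch; the field names follow the documentation above, while the exact MODE value spelling (BEFORE here) is an assumption:

// Hypothetical setup of a RESHAPE_TEMPORAL layer.
LayerParameter p = new LayerParameter(LayerParameter.LayerType.RESHAPE_TEMPORAL);
p.reshape_temporal_param.mode = ReshapeTemporalParameter.MODE.BEFORE;  // assumed enum value
p.reshape_temporal_param.forced_repeat_count = -1;   // -1: use the temporal axis of bottom(1)
p.reshape_temporal_param.enable_clip_output = true;  // emits the clip top (used in the AFTER mode)
ReshapeTemporalLayer<float> layer = new ReshapeTemporalLayer<float>(cuda, log, p);  // 'cuda' and 'log' assumed available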
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
The MyCaffe.common namespace contains common MyCaffe classes.
DIR
Defines the direction of data flow.
The MyCaffe.layers.tft namespace contains all TFT related layers.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...