2using System.Collections.Generic;
// Optional slice points along the slice axis, stored as uint.
// When the list is non-empty, the reshape logic requires exactly
// (number of top blobs - 1) strictly increasing points. Consecutive points
// define each top's extent along the slice axis.
// When it is empty, the slice axis is divided evenly across the top blobs.
// NOTE(review): presumably populated from slice_param.slice_point during
// setup; the assignment is not visible in this listing — confirm.
22 List<uint> m_rgSlicePoints =
new List<uint>();
// Reshape logic (the method header is not visible in this listing).
// It validates the slice axis and computes each top blob's extent along that
// axis. It then checks that the top blobs together hold exactly as many
// elements as the bottom blob.
74 int nNumAxes = colBottom[0].num_true_axes;
// The slice axis must lie in [0, nNumAxes). The first message suggests
// m_nSliceAxis may come from a uint slice_dim cast to int (assignment not
// visible here).
81 m_log.
CHECK_GE(m_nSliceAxis, 0,
"casting slice_dim from uint to int produced a negative result; slice_dim must satisfy 0 <= slice_dim < " +
Blob<T>.
MAX_BLOB_AXES.ToString());
82 m_log.
CHECK_LT(m_nSliceAxis, nNumAxes,
"slice_dim is out of range.");
// Every top starts from a copy of the bottom's shape. Only the entry at
// m_nSliceAxis is overwritten below.
89 List<int> rgTopShape =
Utility.Clone<
int>(colBottom[0].shape());
90 int bottom_slice_axis = colBottom[0].shape(m_nSliceAxis);
// m_nNumSlices counts the elements in the axes before the slice axis.
// m_nSliceSize counts the elements in the axes after it.
// NOTE(review): assumes Blob.count(start[, end]) is the product of the
// dimensions in that range — confirm.
// forward/backward use both values to size each per-top copy.
92 m_nNumSlices = colBottom[0].count(0, m_nSliceAxis);
93 m_nSliceSize = colBottom[0].count(m_nSliceAxis + 1);
// Explicit slice points: N top blobs require N-1 points.
// There cannot be more tops than entries along the slice axis.
97 if (m_rgSlicePoints.Count != 0)
99 m_log.
CHECK_EQ(m_rgSlicePoints.Count, colTop.
Count - 1,
"The slice point count is incorrect.");
100 m_log.
CHECK_LE(colTop.
Count, bottom_slice_axis,
"slice axis: " + bottom_slice_axis.ToString() +
", bottom[0] shape: '" + colBottom[0].shape_string +
"'");
// Convert the slice points into per-top extents: each extent is the
// difference between consecutive points, and the points must be strictly
// increasing.
// NOTE(review): nPrev's declaration is not visible here; presumably it
// starts at 0.
103 List<int> rgSlices =
new List<int>();
105 for (
int i = 0; i < m_rgSlicePoints.Count; i++)
107 m_log.
CHECK_GT((
int)m_rgSlicePoints[i], nPrev,
"The slice point at " + i.ToString() +
" should be greater than the previous slice point of " + nPrev.ToString());
108 rgSlices.Add((
int)m_rgSlicePoints[i] - nPrev);
109 nPrev = (int)m_rgSlicePoints[i];
// The last top receives everything after the final slice point.
112 rgSlices.Add(bottom_slice_axis - nPrev);
// Apply each top's extent along the slice axis and accumulate the element
// counts of the tops.
// NOTE(review): the statement that reshapes colTop[i] to rgTopShape is
// presumably between these lines but is not visible in this listing.
114 for (
int i = 0; i < colTop.
Count; i++)
116 rgTopShape[m_nSliceAxis] = rgSlices[i];
118 nCount += colTop[i].count();
// No slice points: the enclosing else is not visible here. The slice axis is
// split into equal parts, so the number of tops must divide it exactly.
123 m_log.
CHECK_EQ(bottom_slice_axis % colTop.
Count, 0,
"Number of top blobs (" + colTop.
Count.ToString() +
") should evenly divide input slice axis (" + bottom_slice_axis.ToString() +
")");
124 rgTopShape[m_nSliceAxis] = bottom_slice_axis / colTop.
Count;
126 for (
int i = 0; i < colTop.
Count; i++)
129 nCount += colTop[i].count();
// Sanity check: the tops together must contain exactly the bottom's
// element count.
133 m_log.
CHECK_EQ(nCount, colBottom[0].count(),
"The count (" + nCount.ToString() +
") should be the same as the bottom count (" + colBottom[0].count().ToString() +
")");
// With a single top, it shares the bottom's data and diff through
// ShareData/ShareDiff.
135 if (colTop.
Count == 1)
137 colTop[0].ShareData(colBottom[0]);
138 colTop[0].ShareDiff(colBottom[0]);
// Forward pass (the method header is not visible in this listing).
// For each top blob it calls m_cuda.slice_fwd with the bottom's GPU data
// and the top's mutable GPU data.
// With a single top, the reshape logic already shares the bottom's data with
// it. The body of this branch is not visible; presumably it is an early
// return.
153 if (colTop.
Count == 1)
156 int nOffsetSliceAxis = 0;
157 long hBottomData = colBottom[0].gpu_data;
158 int nBottomSliceAxis = colBottom[0].shape(m_nSliceAxis);
// Walk the tops in order. nOffsetSliceAxis tracks where each top's segment
// starts along the bottom's slice axis.
160 for (
int i = 0; i < colTop.
Count; i++)
162 long hTopData = colTop[i].mutable_gpu_data;
163 int nTopSliceAxis = colTop[i].shape(m_nSliceAxis);
164 int nTopSliceSize = nTopSliceAxis * m_nSliceSize;
165 int nCount = nTopSliceSize * m_nNumSlices;
// NOTE(review): presumably slice_fwd writes nCount elements of the
// bottom's segment [nOffsetSliceAxis, nOffsetSliceAxis + nTopSliceAxis)
// into hTopData; the kernel itself is not visible here.
167 m_cuda.slice_fwd(nCount, hBottomData, m_nNumSlices, m_nSliceSize, nBottomSliceAxis, nTopSliceAxis, nOffsetSliceAxis, hTopData);
168 nOffsetSliceAxis += nTopSliceAxis;
// Backward pass (the method header is not visible in this listing).
// For each top blob it calls m_cuda.slice_bwd with the top's GPU diff and
// the bottom's mutable GPU diff.
// Nothing needs to be done when no gradient is requested for the bottom, or
// when the single top shares its diff with the bottom. The body of this
// branch is not visible; presumably it is an early return.
182 if (!rgbPropagateDown[0] || colTop.
Count == 1)
185 int nOffsetSliceAxis = 0;
186 long hBottomDiff = colBottom[0].mutable_gpu_diff;
187 int nBottomSliceAxis = colBottom[0].shape(m_nSliceAxis);
// Walk the tops in the same order as the forward pass, advancing the offset
// along the bottom's slice axis.
189 for (
int i = 0; i < colTop.
Count; i++)
191 long hTopDiff = colTop[i].gpu_diff;
192 int nTopSliceAxis = colTop[i].shape(m_nSliceAxis);
193 int nTopSliceSize = nTopSliceAxis * m_nSliceSize;
194 int nCount = nTopSliceSize * m_nNumSlices;
// NOTE(review): presumably slice_bwd writes the top's diff back into the
// bottom's segment at nOffsetSliceAxis; the kernel itself is not visible
// here.
196 m_cuda.slice_bwd(nCount, hTopDiff, m_nNumSlices, m_nSliceSize, nBottomSliceAxis, nTopSliceAxis, nOffsetSliceAxis, hBottomDiff);
197 nOffsetSliceAxis += nTopSliceAxis;
The Log class provides general output in text form.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
void CHECK_LT(double df1, double df2, string str)
Test whether one number is less than another.
The Utility class provides general utility functions.
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The Blob is the main holder of data that moves through the Layers of the Net.
const int MAX_BLOB_AXES
Defines the maximum number of Axes supported by the Blob.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
An interface for the units of computation which can be composed into a Net.
Log m_log
Specifies the Log for output.
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
LayerParameter.LayerType m_type
Specifies the Layer type.
The SliceLayer takes a blob and slices it along either the num or channel dimensions outputting multi...
override int ExactNumBottomBlobs
Returns the exact number of required bottom (input) Blobs: input.
SliceLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
The SliceLayer constructor.
override int MinTopBlobs
Returns the minimum number of required top (output) Blobs: slice
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Computes the error gradient w.r.t. the inputs.
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the bottom (input) and top (output) blobs.
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Computes the forward calculation.
Specifies the base parameter for all layers.
SliceParameter slice_param
Returns the parameter set when initialized with LayerType.SLICE
LayerType
Specifies the layer type.
uint slice_dim
DEPRECATED: alias for 'axis' – does not support negative indexing.
List< uint > slice_point
Specifies optional slice points which indicate the indexes in the selected dimensions (the number of ...
int axis
Specifies the axis along which to slice – may be negative to index from the end (e....
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
The MyCaffe.common namespace contains common MyCaffe classes.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...