using System.Collections.Generic;
78 m_nTopK = (int)p.
top_k;
85 m_nAxis = colBottom[0].CanonicalAxisIndex(p.
axis.Value);
86 m_log.
CHECK_GE(m_nAxis.Value, 0,
"axis must not be less than zero.");
87 m_log.
CHECK_LE(m_nAxis.Value, colBottom[0].num_axes,
"axis must be less than or equal to the dimension of the axis.");
88 m_log.
CHECK_LE(m_nTopK, colBottom[0].shape(m_nAxis.Value),
"top_k must be less than or equal to the dimension of the axis.");
92 m_log.
CHECK_LE(m_nTopK, colBottom[0].count(1),
"top_k must be less than or equal to the dimension of the flattened bottom blob per instance.");
103 int nNumTopAxes = colBottom[0].num_axes;
110 if (m_nAxis.HasValue)
113 rgShape =
Utility.Clone<
int>(colBottom[0].shape());
114 rgShape[m_nAxis.Value] = m_nTopK;
118 rgShape[0] = colBottom[0].shape(0);
120 rgShape[2] = m_nTopK;
151 forward_gpu(colBottom, colTop);
153 forward_cpu(colBottom, colTop);
158 int nAxis = m_nAxis.GetValueOrDefault(1);
159 int nOuterNum = colBottom[0].count(0, nAxis);
160 int nChannels = colBottom[0].count(nAxis);
164 m_log.
WriteLine(
"WARNING: The gpu implementation of argmax only supports TopK = 1.");
167 m_log.
WriteLine(
"WARNING: Currently the gpu implementation of argmax does not support output of the max values.");
170 m_cuda.channel_max(colBottom[0].count(), nOuterNum, nChannels, nInnerNum, colBottom[0].gpu_data, colTop[0].mutable_gpu_data,
true);
172 m_cuda.channel_min(colBottom[0].count(), nOuterNum, nChannels, nInnerNum, colBottom[0].gpu_data, colTop[0].mutable_gpu_data,
true);
177 double[] rgBottomData =
convertD(colBottom[0].update_cpu_data());
178 double[] rgTopData =
convertD(colTop[0].mutable_cpu_data);
182 if (m_nAxis.HasValue)
184 nDim = colBottom[0].shape(m_nAxis.Value);
186 nAxisDist = colBottom[0].count(m_nAxis.Value) / nDim;
190 nDim = colBottom[0].count(1);
194 int nNum = colBottom[0].count() / nDim;
196 for (
int i = 0; i < nNum; i++)
198 List<KeyValuePair<double, int>> rgBottomDataPair =
new List<KeyValuePair<double, int>>();
200 for (
int j = 0; j < nDim; j++)
202 int nIdx = (i / nAxisDist * nDim + j) * nAxisDist + i % nAxisDist;
203 rgBottomDataPair.
Add(
new KeyValuePair<double, int>(rgBottomData[nIdx], j));
207 rgBottomDataPair.Sort(
new Comparison<KeyValuePair<double, int>>(sortDataItemsDescending));
209 rgBottomDataPair.Sort(
new Comparison<KeyValuePair<double, int>>(sortDataItemsAscending));
211 for (
int j = 0; j < m_nTopK; j++)
215 if (m_nAxis.HasValue)
218 int nIdx = (i / nAxisDist * m_nTopK + j) * nAxisDist + i % nAxisDist;
219 rgTopData[nAxisDist] = rgBottomDataPair[j].Key;
224 int nIdx1 = 2 * i * m_nTopK + j;
225 rgTopData[nIdx1] = rgBottomDataPair[j].Value;
226 int nIdx2 = 2 * i * m_nTopK + m_nTopK + j;
227 rgTopData[nIdx2] = rgBottomDataPair[j].Key;
233 int nIdx = (i / nAxisDist * m_nTopK + j) * nAxisDist + i % nAxisDist;
234 rgTopData[nIdx] = rgBottomDataPair[j].Value;
239 colTop[0].mutable_cpu_data =
convert(rgTopData);
242 private int sortDataItemsAscending(KeyValuePair<double, int> a, KeyValuePair<double, int> b)
253 private int sortDataItemsDescending(KeyValuePair<double, int> a, KeyValuePair<double, int> b)
267 new NotImplementedException();
The Log class provides general output in text form.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
The Utility class provides general utility functions.
static List< int > Create(int nCount, int nStart, int nInc)
Create a new List and fill it with values starting with start and incrementing by inc.
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
The ArgMaxLayer computes the index of the K max values for each datum across all dimensions ....
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the bottom (input) and top (output) blobs.
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Forward computation. When 'enable_cuda_impl' = true (default = false) the GPU version is run.
override int ExactNumTopBlobs
Returns the exact number of top blobs required: argmax
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
override int ExactNumBottomBlobs
Returns the exact number of bottom blobs required: input
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Not implemented.
ArgMaxLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
Constructor.
An interface for the units of computation which can be composed into a Net.
Log m_log
Specifies the Log for output.
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
void convert(BlobCollection< T > col)
Convert a collection of blobs from / to half size.
double convertD(T df)
Converts a generic to a double value.
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
LayerParameter.LayerType m_type
Specifies the Layer type.
Specifies the parameters for the ArgMaxLayer
bool enable_cuda_impl
Specifies to use the low-level full cuda implementation of ArgMax (default = false).
uint top_k
When computing accuracy, count as correct by comparing the true label to the top_k scoring classes....
COMPARE_OPERATOR operation
Specifies the operation to use (default = MAX).
bool out_max_val
If true produce pairs (argmax, maxval)
COMPARE_OPERATOR
Defines the compare operator to use (max or min, default = max).
int? axis
The axis along which to maximize – may be negative to index from the end (e.g., -1 for the last axis)...
Specifies the base parameter for all layers.
ArgMaxParameter argmax_param
Returns the parameter set when initialized with LayerType.ARGMAX
LayerType
Specifies the layer type.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
The MyCaffe.common namespace contains common MyCaffe classes.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...