4using System.Collections.Generic;
7using System.Drawing.Drawing2D;
11using System.Threading.Tasks;
20 string m_strName =
"Curve";
21 Dictionary<string, int> m_rgActionSpace;
24 int m_nMaxSteps =
int.MaxValue;
27 GeomGraph m_geomGraph;
29 List<GeomPolyLine> m_rgGeomPredictedLines =
new List<GeomPolyLine>();
31 float m_fXScale = 512;
32 float m_fYScale = 150;
34 float m_fInc = (float)(Math.PI * 2.0f / 360.0f);
35 float m_fMax = (float)(Math.PI * 2.0f);
36 List<DataPoint> m_rgPrevPoints =
new List<DataPoint>();
37 Dictionary<Color, Brush> m_rgBrushes =
new Dictionary<Color, Brush>();
38 Dictionary<Color, Brush> m_rgBrushesEmphasize =
new Dictionary<Color, Brush>();
39 int m_nMaxPlots = 500;
41 bool m_bRenderImage =
true;
42 List<string> m_rgstrLabels =
null;
43 List<bool> m_rgEmphasize =
null;
44 List<Color> m_rgPallete =
new List<Color>();
45 List<Color> m_rgPalleteEmphasize =
new List<Color>() { Color.Orange, Color.Green, Color.Red, Color.HotPink, Color.Tomato, Color.Lavender, Color.DarkOrange };
46 int m_nAlphaNormal = 64;
48 Random m_random =
new Random();
49 CurveState m_state =
new CurveState();
91 m_rgActionSpace =
new Dictionary<string, int>();
92 m_rgActionSpace.Add(
"MoveUp", 0);
93 m_rgActionSpace.Add(
"MoveDn", 1);
95 foreach (Color clr
in m_rgPalleteEmphasize)
97 Color clrLite = Color.FromArgb(m_nAlphaNormal, clr);
98 m_rgPallete.Add(clrLite);
100 m_rgBrushes.Add(clr,
new SolidBrush(clrLite));
101 m_rgBrushesEmphasize.Add(clr,
new SolidBrush(clr));
110 foreach (KeyValuePair<Color, Brush> kv
in m_rgBrushes)
115 foreach (KeyValuePair<Color, Brush> kv
in m_rgBrushesEmphasize)
137 if (nCurveType != -1)
153 if (properties !=
null)
164 get {
return false; }
188 get {
return m_strName; }
213 return m_rgActionSpace;
216 private void processAction(
ACTION? a,
double? dfOverride =
null)
235 public Tuple<Bitmap, SimpleDatum>
Render(
bool bShowUi,
int nWidth,
int nHeight,
bool bGetAction)
237 List<double> rgData =
new List<double>();
239 rgData.Add(m_state.X);
240 rgData.Add(m_state.Y);
241 rgData.Add(m_nSteps);
243 if (m_state.PredictedYValues.Count > 0)
244 rgData.AddRange(m_state.PredictedYValues);
248 m_rgstrLabels = m_state.PredictedYNames;
249 m_rgEmphasize = m_state.PredictedYEmphasize;
251 return Render(bShowUi, nWidth, nHeight, rgData.ToArray(), bGetAction);
263 public Tuple<Bitmap, SimpleDatum>
Render(
bool bShowUi,
int nWidth,
int nHeight,
double[] rgData,
bool bGetAction)
265 double dfX = rgData[0];
266 double dfY = rgData[1];
267 int nSteps = (int)rgData[2];
268 int nPredictedIdx = 3;
271 m_nMaxSteps = Math.Max(nSteps, m_nMaxSteps);
278 bmp =
new Bitmap(nWidth, nHeight);
279 using (Graphics g = Graphics.FromImage(bmp))
281 Rectangle rc =
new Rectangle(0, 0, bmp.Width, bmp.Height);
282 g.FillRectangle(Brushes.White, rc);
284 float fScreenWidth = g.VisibleClipBounds.Width;
285 float fScreenHeight = g.VisibleClipBounds.Height;
286 float fWorldWidth = (float)m_fXScale;
287 float fWorldHeight = (float)m_fYScale * 2;
288 float fScale = fScreenHeight / fWorldHeight;
291 float fR = fWorldWidth;
292 float fT = fWorldHeight / 2;
293 float fB = -fWorldHeight / 2;
295 if (m_geomGraph ==
null)
297 m_geomGraph =
new GeomGraph(fL, fR, fT, fB, Color.Azure, Color.SteelBlue);
298 m_geomGraph.SetLocation(0, fScale * (fWorldHeight / 2));
300 if (m_geomTargetLine ==
null)
302 m_geomTargetLine =
new GeomPolyLine(fL, fR, fT, fB, Color.Blue, Color.Blue, m_nMaxPlots);
303 m_geomTargetLine.
Polygon.Clear();
304 m_geomTargetLine.
SetLocation(0, fScale * (fWorldHeight / 2));
306 m_geomTargetLine.
Polygon.Add(
new PointF((
float)dfX, (
float)(dfY * m_fYScale)));
308 if (m_rgGeomPredictedLines.Count == 0)
310 bool bEmphasize = (m_rgEmphasize ==
null || m_rgEmphasize.Count == 0) || (m_rgEmphasize.Count > 0 && m_rgEmphasize[0]);
311 List<Color> rgClr = (bEmphasize) ? m_rgPalleteEmphasize : m_rgPallete;
313 geomPredictLine.
Polygon.Clear();
314 geomPredictLine.
SetLocation(0, fScale * (fWorldHeight / 2));
315 m_rgGeomPredictedLines.Add(geomPredictLine);
318 if (m_rgstrLabels !=
null && m_rgstrLabels.Count > 0)
320 if (m_rgEmphasize[0])
322 List<Color> rgClr = ((m_rgEmphasize ==
null || m_rgEmphasize.Count == 0) || (m_rgEmphasize.Count > 0 && m_rgEmphasize[0])) ? m_rgPalleteEmphasize : m_rgPallete;
323 m_rgGeomPredictedLines[0].SetColors(rgClr[0], rgClr[0]);
327 while (m_rgGeomPredictedLines.Count < m_rgstrLabels.Count && nIdx < m_rgPallete.Count)
329 bool bEmphasize = (m_rgEmphasize ==
null || m_rgEmphasize.Count == 0) || (m_rgEmphasize.Count > 0 && m_rgEmphasize[nIdx]);
330 List<Color> rgClr = (bEmphasize) ? m_rgPalleteEmphasize : m_rgPallete;
332 geomPredictLine.
Polygon.Clear();
333 geomPredictLine.
SetLocation(0, fScale * (fWorldHeight / 2));
334 m_rgGeomPredictedLines.Add(geomPredictLine);
339 for (
int i=0; i < m_rgGeomPredictedLines.Count; i++)
341 double dfPredicted = 0;
342 if (i+nPredictedIdx < rgData.Length)
343 dfPredicted = rgData[nPredictedIdx + i];
345 m_rgGeomPredictedLines[i].Polygon.Add(
new PointF((
float)dfX, (
float)(dfPredicted * m_fYScale)));
350 view.
RenderText(g,
"X = " + dfX.ToString(
"N02"), 10, 24);
351 view.
RenderText(g,
"Y = " + dfY.ToString(
"N02"), 10, 36);
354 if (m_rgstrLabels !=
null && m_rgstrLabels.Count > 0)
356 for (
int i = 0; i < m_rgGeomPredictedLines.Count && i < m_rgstrLabels.Count; i++)
358 bool bEmphasize = (m_rgEmphasize ==
null || m_rgEmphasize.Count == 0) || (m_rgEmphasize.Count > 0 && m_rgEmphasize[i]);
359 Dictionary<Color, Brush> rgClr = (bEmphasize) ? m_rgBrushesEmphasize : m_rgBrushes;
360 view.
RenderText(g,
"Predicted Y (" + m_rgstrLabels[i] +
") = " + rgData[nPredictedIdx + i].ToString(
"N02"), 10, nY, rgClr[m_rgPalleteEmphasize[i]]);
366 bool bEmphasize = (m_rgEmphasize ==
null || m_rgEmphasize.Count == 0) || (m_rgEmphasize.Count > 0 && m_rgEmphasize[0]);
367 Dictionary<Color, Brush> rgClr = (bEmphasize) ? m_rgBrushesEmphasize : m_rgBrushes;
368 view.
RenderText(g,
"Predicted Y = " + rgData[nPredictedIdx].ToString(
"N02"), 10, nY, rgClr[m_rgPalleteEmphasize[0]]);
372 view.
RenderText(g,
"Curve Type = " + m_curveType.ToString(), 10, nY);
379 for (
int i = 0; i < m_rgGeomPredictedLines.Count; i++)
381 view.
AddObject(m_rgGeomPredictedLines[i]);
387 sdAction = getActionData((
float)dfX, (
float)dfY, (
float)fWorldWidth, (
float)fWorldHeight, bmp);
393 return new Tuple<Bitmap, SimpleDatum>(bmp, sdAction);
396 private SimpleDatum getActionData(
float fX,
float fY,
float fWid,
float fHt, Bitmap bmpSrc)
398 double dfX = (fWid * 0.85);
399 double dfY = (bmpSrc.Height - fY) - (fHt * 0.75);
401 RectangleF rc =
new RectangleF((
float)dfX, (
float)dfY, fWid, fHt);
402 Bitmap bmp =
new Bitmap((
int)fWid, (
int)fHt);
404 using (Graphics g = Graphics.FromImage(bmp))
406 RectangleF rc1 =
new RectangleF(0, 0, (
float)fWid, (
float)fHt);
407 g.FillRectangle(Brushes.Black, rc1);
408 g.DrawImage(bmpSrc, rc1, rc, GraphicsUnit.Pixel);
424 double dfPredictedY = 0;
426 m_bRenderImage =
true;
427 m_rgPrevPoints.Clear();
431 m_state =
new CurveState(dfX, dfY,
new List<double>() { dfPredictedY });
435 bool bTraining = props.GetPropertyAsBool(
"Training",
false);
437 m_bRenderImage =
false;
439 m_fLastY = (float)props.GetPropertyAsDouble(
"TrainingStart", 0);
442 return new Tuple<State, double, bool>(m_state.Clone(), 1,
false);
445 private double randomUniform(
double dfMin,
double dfMax)
447 double dfRange = dfMax - dfMin;
448 return dfMin + (m_random.NextDouble() * dfRange);
451 private double calculateTarget(
double dfX)
456 return Math.Sin(dfX);
459 return Math.Cos(dfX);
462 float fCurve = m_fLastY + (float)(m_random.NextDouble() - 0.5) * (float)(m_random.NextDouble() * 0.20f);
471 throw new Exception(
Name +
" does not support the curve type '" + m_curveType.ToString() +
"'.");
482 public Tuple<State, double, bool>
Step(
int nAction,
bool bGetLabel,
PropertySet propExtra =
null)
484 CurveState state =
new CurveState(m_state);
486 List<double> rgOverrides =
new List<double>();
487 List<string> rgOverrideNames =
new List<string>();
488 List<bool> rgOverrideEmphasize =
new List<bool>();
489 double? dfOverride =
null;
491 m_bRenderImage =
true;
493 if (propExtra !=
null)
495 bool bTraining = propExtra.GetPropertyAsBool(
"Training",
false);
497 m_bRenderImage =
false;
499 double dfCount = propExtra.GetPropertyAsDouble(
"override_predictions", 0);
502 for (
int i = 0; i < (int)dfCount; i++)
504 double dfVal = propExtra.GetPropertyAsDouble(
"override_prediction" + i.ToString(),
double.MaxValue);
505 if (dfVal !=
double.MaxValue)
506 rgOverrides.Add(dfVal);
508 string strName = propExtra.GetProperty(
"override_prediction" + i.ToString() +
"_name");
509 if (!
string.IsNullOrEmpty(strName))
510 rgOverrideNames.Add(strName);
512 string strVal = propExtra.GetProperty(
"override_prediction" + i.ToString() +
"_emphasize");
514 if (
bool.TryParse(strVal, out bEmphasize))
515 rgOverrideEmphasize.Add(bEmphasize);
520 double dfVal = propExtra.GetPropertyAsDouble(
"override_prediction",
double.MaxValue);
521 if (dfVal !=
double.MaxValue)
522 rgOverrides.Add(dfVal);
525 if (rgOverrides.Count > 0)
526 dfOverride = rgOverrides[0];
529 processAction((
ACTION)nAction, dfOverride);
531 double dfX = state.X;
532 double dfY = state.Y;
533 double dfPredictedY = ((dfOverride.HasValue) ? dfOverride.Value : state.PredictedY);
539 dfY = calculateTarget(m_fX);
542 float[] rgInput =
new float[] { m_fX };
543 float[] rgMask =
new float[] { 1 };
544 float fTarget = (float)dfY;
545 float fTime = (float)(dfX / m_nMaxPlots);
547 DataPoint pt =
new DataPoint(rgInput, rgMask, fTarget, rgOverrides, rgOverrideNames, rgOverrideEmphasize, fTime);
548 m_rgPrevPoints.Add(pt);
550 if (m_rgPrevPoints.Count > m_nMaxPlots)
551 m_rgPrevPoints.RemoveAt(0);
553 CurveState stateOut = m_state;
554 m_state =
new CurveState(dfX, dfY, rgOverrides, rgOverrideNames, rgOverrideEmphasize, m_rgPrevPoints);
556 dfReward = 1.0 - Math.Abs(dfPredictedY - dfY);
561 m_nMaxSteps = Math.Max(m_nMaxSteps, m_nSteps);
563 stateOut.Steps = m_nSteps;
564 return new Tuple<State, double, bool>(stateOut.Clone(), dfReward,
false);
599 class GeomGraph : GeomPolygon
601 public GeomGraph(
float fL,
float fR,
float fT,
float fB, Color clrFill, Color clrBorder)
602 : base(fL, fR, fT, fB, clrFill, clrBorder)
606 public override void Render(Graphics g)
612 class CurveState : State
616 List<double> m_rgdfPredictedY =
null;
617 List<string> m_rgstrPredictedY =
null;
618 List<bool> m_rgbPredictedY =
null;
621 public const double MAX_X = 2.4;
622 public const double MAX_Y = 2.4;
624 public CurveState(
double dfX = 0,
double dfY = 0, List<double> rgdfPredictedY =
null, List<string> rgstrPredictedY =
null, List<bool> rgbPredictedY =
null, List<DataPoint> rgPoints =
null)
628 m_rgdfPredictedY = rgdfPredictedY;
629 m_rgstrPredictedY = rgstrPredictedY;
630 m_rgbPredictedY = rgbPredictedY;
632 if (rgPoints !=
null)
633 m_rgPrevPoints = rgPoints;
636 public CurveState(CurveState s)
640 m_rgdfPredictedY =
Utility.Clone<
double>(s.m_rgdfPredictedY);
641 m_rgstrPredictedY =
Utility.Clone<
string>(s.m_rgstrPredictedY);
642 m_rgbPredictedY =
Utility.Clone<
bool>(s.m_rgbPredictedY);
643 m_nSteps = s.m_nSteps;
644 m_rgPrevPoints =
new List<DataPoint>();
646 if (s.m_rgPrevPoints !=
null)
647 m_rgPrevPoints.AddRange(s.m_rgPrevPoints);
652 get {
return m_nSteps; }
653 set { m_nSteps = value; }
658 get {
return m_dfX; }
659 set { m_dfX = value; }
664 get {
return m_dfY; }
665 set { m_dfY = value; }
668 public double PredictedY
672 if (m_rgdfPredictedY ==
null || m_rgdfPredictedY.Count == 0)
675 return m_rgdfPredictedY[0];
679 public List<double> PredictedYValues
681 get {
return m_rgdfPredictedY; }
682 set { m_rgdfPredictedY = value; }
685 public List<string> PredictedYNames
687 get {
return m_rgstrPredictedY; }
688 set { m_rgstrPredictedY = value; }
691 public List<bool> PredictedYEmphasize
693 get {
return m_rgbPredictedY; }
694 set { m_rgbPredictedY = value; }
697 public override State Clone()
699 return new CurveState(
this);
702 public override SimpleDatum GetData(
bool bNormalize, out
int nDataLen)
707 double dfPredictedY = 0;
709 if (m_rgdfPredictedY !=
null && m_rgdfPredictedY.Count > 0)
710 dfPredictedY = m_rgdfPredictedY[0];
712 data.
SetPixel(0, 0, getValue(m_dfX, -MAX_X, MAX_X, bNormalize));
713 data.
SetPixel(0, 1, getValue(m_dfY, -MAX_Y, MAX_Y, bNormalize));
714 data.
SetPixel(0, 2, getValue(dfPredictedY, -MAX_Y, MAX_Y, bNormalize));
720 private double getValue(
double dfVal,
double dfMin,
double dfMax,
bool bNormalize)
725 return (dfVal - dfMin) / (dfMax - dfMin);
The ColorMapper maps a value within a numeric range to a Color within a color scheme.
The ImageData class is a helper class used to convert between Datum, other raw data,...
static Datum GetImageDataD(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataD function converts a Bitmap into a Datum using the double type for real data.
The Log class provides general output in text form.
Specifies a key-value pair of properties.
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
The SimpleDatum class holds a data input within host memory.
The Utility class provides general utility functions.
The Realmap operates similarly to a bitmap but is actually just an array of doubles.
void SetPixel(int nX, int nY, double clr)
Set a given pixel to a given color.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
The SourceDescriptor class contains all information describing a data source.
The Curve Gym provides a simulation of a continuous curve such as Sin or Cos.
DATA_TYPE[] SupportedDataType
Returns the data types supported by this gym.
double TestingPercent
Returns a testing percent of -1, which causes the default of 0.2 to be used.
void Close()
Shutdown and close the gym.
bool RequiresDisplayImage
Returns false indicating that this Gym does not require a display image.
Tuple< State, double, bool > Reset(bool bGetLabel, PropertySet props=null)
Reset the state of the gym.
CurveGym()
The constructor.
IXMyCaffeGym Clone(PropertySet properties=null)
Create a new copy of the gym.
DatasetDescriptor GetDataset(DATA_TYPE dt, Log log=null)
Returns the dataset descriptor of the dynamic dataset produced by the Gym.
string Name
Returns the gym's name.
DATA_TYPE SelectedDataType
Returns the selected data type.
CURVE_TYPE
Defines the curve types.
int UiDelay
Returns the delay to use (if any) when the user-display is visible.
void Dispose()
Release all resources used.
void Initialize(Log log, PropertySet properties)
Initialize the gym with the specified properties.
Tuple< Bitmap, SimpleDatum > Render(bool bShowUi, int nWidth, int nHeight, bool bGetAction)
Render the gym's current state on a bitmap and SimpleDatum.
Tuple< Bitmap, SimpleDatum > Render(bool bShowUi, int nWidth, int nHeight, double[] rgData, bool bGetAction)
Render the gym's specified data.
ACTION
Defines the actions to perform.
Tuple< State, double, bool > Step(int nAction, bool bGetLabel, PropertySet propExtra=null)
Step the gym one step in its simulation.
Dictionary< string, int > GetActionSpace()
Returns the action space as a dictionary of (name, action ID) pairs.
The DataPoint contains the data used when training.
virtual void SetLocation(float fX, float fY)
Sets the object location.
List< PointF > Polygon
Returns the bounds as a Polygon.
The GeomPolyLine object is used to render a spline.
The GeomView manages and renders a collection of Geometric objects.
void Render(Graphics g)
Renders the view.
void AddObject(GeomObj obj)
Add a new geometric object to the view.
void RenderText(Graphics g, string str, float fX, float fY, Brush br=null)
Render text at a location.
void RenderSteps(Graphics g, int nSteps, int nMax)
Renders the Gym step information.
The IXMyCaffeGym interface is used to interact with each Gym.
The descriptors namespace contains all descriptors used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
@ RANDOM
Randomly select the images, ignoring the input index.
GYM_TYPE
Defines the gym type (if any).
DATA_TYPE
Defines the gym data type.
The MyCaffe.gym namespace contains all classes related to the Gyms supported by MyCaffe.
GYM_SRC_TRAIN_ID
Defines the Standard GYM Training Data Source IDs.
GYM_DS_ID
Defines the Standard GYM Dataset IDs.
GYM_SRC_TEST_ID
Defines the Standard GYM Testing Data Source IDs.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...