A Discrete-Event Network Simulator
API
thompson-sampling-wifi-manager.cc
Go to the documentation of this file.
1 /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
2 /*
3  * Copyright (c) 2021 IITP RAS
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation;
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17  *
18  * Author: Alexander Krotov <krotov@iitp.ru>
19  */
20 
21 #include "ns3/log.h"
22 #include "ns3/double.h"
23 #include "ns3/core-module.h"
24 #include "ns3/packet.h"
25 
26 #include "ns3/wifi-phy.h"
27 
29 
30 #include <cstdint>
31 #include <cstdlib>
32 #include <fstream>
33 #include <iostream>
34 #include <string>
35 
36 namespace ns3 {
37 
// Per-rate statistics kept by the Thompson sampling manager: one entry per
// (mode, channel width, number of spatial streams) combination.
// NOTE(review): this listing omits the `mode` and `lastDecay` members
// (source lines 43 and 49), which the code below uses (stats.mode,
// stats.lastDecay).
42 struct RateStats {
// Channel width in MHz.
44  uint16_t channelWidth;
// Number of spatial streams.
45  uint8_t nss;
46 
// Decay-weighted count of successful transmissions (scaled down by Decay).
47  double success{0.0};
// Decay-weighted count of failed transmissions (scaled down by Decay).
48  double fails{0.0};
50 };
51 
// Per-remote-station state and collected rate statistics.
// NOTE(review): the struct header (source line 58) is not shown in this listing.
59 {
// Index into m_mcsStats of the entry to use for the next data transmission;
// redrawn by UpdateNextMode.
60  size_t m_nextMode;
// Index into m_mcsStats of the entry used by the most recent data
// transmission; transmission feedback is credited to this entry.
61  size_t m_lastMode;
62 
// Rate table, built once per station by InitializeStation.
63  std::vector<RateStats> m_mcsStats;
64 };
65 
67 
// Defines the "ThompsonSamplingWifiManager" log component used by the NS_LOG_*
// macros in this file.
68 NS_LOG_COMPONENT_DEFINE ("ThompsonSamplingWifiManager");
69 
// ThompsonSamplingWifiManager::GetTypeId: registers the type with the ns-3
// TypeId system (group "Wifi", default constructor). It exposes:
//  - attribute "Decay" (double, default 1.0, checked to be >= 0.0): the
//    exponential decay coefficient in Hz. A value of 0 makes the decay
//    coefficient exp(0) = 1, so statistics never fade;
//  - trace source "Rate": the traced data rate in b/s.
// NOTE(review): the listing omits source lines 71, 74, 80 and 84 (the
// signature, parent type, attribute accessor and trace-source accessor).
70 TypeId
72 {
73  static TypeId tid = TypeId ("ns3::ThompsonSamplingWifiManager")
75  .SetGroupName ("Wifi")
76  .AddConstructor<ThompsonSamplingWifiManager> ()
77  .AddAttribute ("Decay",
78  "Exponential decay coefficient, Hz; zero is a valid value for static scenarios",
79  DoubleValue (1.0),
81  MakeDoubleChecker<double> (0.0))
82  .AddTraceSource ("Rate",
83  "Traced value for rate changes (b/s)",
85  "ns3::TracedValueCallback::Uint64")
86  ;
87  return tid;
88 }
89 
// Constructor: starts with a traced rate of 0 and creates the gamma random
// variable from which SampleBetaVariable builds Beta samples.
91  : m_currentRate{0}
92 {
93  NS_LOG_FUNCTION (this);
94 
95  m_gammaRandomVariable = CreateObject<GammaRandomVariable> ();
96 }
97 
// Presumably the destructor; its signature (source line 98) is not shown in
// this listing. It only logs the call and releases nothing explicitly.
99 {
100  NS_LOG_FUNCTION (this);
101 }
102 
// DoCreateStation: creates the per-station state and starts both mode indices
// at 0. The allocation itself (source line 107) is not shown in this listing.
// The rate table is not filled here; InitializeStation builds it on first use.
105 {
106  NS_LOG_FUNCTION (this);
108  station->m_nextMode = 0;
109  station->m_lastMode = 0;
110  return station;
111 }
112 
// InitializeStation: builds the station's rate table once (later calls return
// immediately) and then draws the first mode with UpdateNextMode.
//  - MCS-capable PHY: adds one entry for each combination of
//      * an MCS of a single modulation class (HE if supported, else VHT if
//        supported, else HT),
//      * a channel width of 20, 40, 80, ... MHz, up to the PHY channel width,
//      * a number of spatial streams from 1 to the PHY's maximum TX streams,
//    keeping only combinations for which mode.IsAllowed (width, nss) holds.
//  - If no entry was added, every legacy mode supported by the station is
//    added with nss = 1 and a channel width of 22 or 20 MHz.
// Entry order matters: DoGetRtsTxVector uses entry 0 as its robust mode.
// Asserts that at least one entry was found.
// NOTE(review): the signature (source line 114) and the legacy-width
// condition (source lines 161-162) are not shown in this listing.
113 void
115 {
116  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
// Already initialized: the table is built only once per station.
117  if (!station->m_mcsStats.empty ())
118  {
119  return;
120  }
121 
122  // Add HT, VHT or HE MCSes
123  for (const auto &mode : GetPhy ()->GetMcsList ())
124  {
125  for (uint16_t j = 20; j <= GetPhy ()->GetChannelWidth (); j *= 2)
126  {
// Select the newest modulation class the device supports (HE > VHT > HT).
// NOTE(review): this does not depend on the loop variables and could be
// computed once before the loops.
127  WifiModulationClass modulationClass = WIFI_MOD_CLASS_HT;
128  if (GetVhtSupported ())
129  {
130  modulationClass = WIFI_MOD_CLASS_VHT;
131  }
132  if (GetHeSupported ())
133  {
134  modulationClass = WIFI_MOD_CLASS_HE;
135  }
136  if (mode.GetModulationClass () == modulationClass)
137  {
138  for (uint8_t k = 1; k <= GetPhy ()->GetMaxSupportedTxSpatialStreams (); k++)
139  {
140  if (mode.IsAllowed (j, k))
141  {
142  RateStats stats;
143  stats.mode = mode;
144  stats.channelWidth = j;
145  stats.nss = k;
146 
147  station->m_mcsStats.push_back (stats);
148  }
149  }
150  }
151  }
152  }
153 
154  if (station->m_mcsStats.empty ())
155  {
156  // Add legacy non-HT modes.
157  for (uint8_t i = 0; i < GetNSupported (station); i++)
158  {
159  RateStats stats;
// Presumably 22 MHz for DSSS/HR-DSSS modes and 20 MHz otherwise; the branch
// condition (source lines 161-162) is not shown — TODO confirm.
160  stats.mode = GetSupported (station, i);
163  {
164  stats.channelWidth = 22;
165  }
166  else
167  {
168  stats.channelWidth = 20;
169  }
170  stats.nss = 1;
171  station->m_mcsStats.push_back (stats);
172  }
173  }
174 
175  NS_ASSERT_MSG (!station->m_mcsStats.empty (), "No usable MCS found");
176 
// The table is non-empty at this point, so the InitializeStation call inside
// UpdateNextMode returns immediately.
177  UpdateNextMode (st);
178 }
179 
// DoReportRxOk: received frames do not affect rate statistics; this only logs.
// NOTE(review): the signature (source line 181) is not shown in this listing.
180 void
182 {
183  NS_LOG_FUNCTION (this << station << rxSnr << txMode);
184 }
185 
// Presumably DoReportRtsFailed; its signature (source line 187) is not shown
// in this listing. RTS failures do not affect rate statistics; this only logs.
186 void
188 {
189  NS_LOG_FUNCTION (this << station);
190 }
191 
// DoReportDataFailed: decays the statistics of the mode used for the last data
// transmission, counts one failure against it, then draws the next mode.
// NOTE(review): the signature (source line 193) is not shown in this listing.
192 void
194 {
195  NS_LOG_FUNCTION (this << st);
196  InitializeStation (st);
197  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
// Decay first so the new failure is added at full weight.
198  Decay (st, station->m_lastMode);
199  station->m_mcsStats.at (station->m_lastMode).fails++;
200  UpdateNextMode (st);
201 }
202 
// DoReportRtsOk: RTS/CTS successes do not affect rate statistics; this only
// logs.
// NOTE(review): the first signature line (source line 204) is not shown in
// this listing.
203 void
205  double rtsSnr)
206 {
207  NS_LOG_FUNCTION (this << st << ctsSnr << ctsMode.GetUniqueName () << rtsSnr);
208 }
209 
// UpdateNextMode: one Thompson sampling step. For every entry in the rate
// table it:
//  - applies decay,
//  - draws a success probability from Beta(1 + success, 1 + fails),
//  - scores the entry as that probability times the entry's nominal data rate.
// The entry with the highest score becomes m_nextMode. Entry 0 is kept when no
// score exceeds 0.
// NOTE(review): SampleBetaVariable is declared with uint64_t parameters, so the
// fractional (decayed) counts passed here are truncated toward zero; consider
// double parameters there.
// NOTE(review): the signature (source line 211) is not shown in this listing.
210 void
212 {
213  InitializeStation (st);
214  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
215 
216  double maxThroughput = 0.0;
217  double frameSuccessRate = 1.0;
218 
219  NS_ASSERT (!station->m_mcsStats.empty ());
220 
221  // Use the most robust MCS if frameSuccessRate is 0 for all MCS.
222  station->m_nextMode = 0;
223 
224  for (uint32_t i = 0; i < station->m_mcsStats.size (); i++)
225  {
226  Decay (st, i);
227  const WifiMode mode{station->m_mcsStats.at (i).mode};
228 
// Nominal PHY rate of this (mode, width, GI, nss) combination.
229  uint16_t guardInterval = GetModeGuardInterval (st, mode);
230  double rate = mode.GetDataRate (station->m_mcsStats.at (i).channelWidth,
231  guardInterval,
232  station->m_mcsStats.at (i).nss);
233 
234  // Thompson sampling
// The +1 terms are a uniform Beta(1, 1) prior, so untried entries still get
// sampled.
235  frameSuccessRate = SampleBetaVariable (1.0 + station->m_mcsStats.at (i).success,
236  1.0 + station->m_mcsStats.at (i).fails);
237  NS_LOG_DEBUG ("Draw"
238  << " success=" << station->m_mcsStats.at (i).success
239  << " fails=" << station->m_mcsStats.at (i).fails
240  << " frameSuccessRate=" << frameSuccessRate
241  << " mode=" << mode);
242  if (frameSuccessRate * rate > maxThroughput)
243  {
244  maxThroughput = frameSuccessRate * rate;
245  station->m_nextMode = i;
246  }
247  }
248 }
249 
// DoReportDataOk: decays the statistics of the mode used for the last data
// transmission, counts one success against it, then draws the next mode.
// dataChannelWidth and dataNss are not read; the credit always goes to
// m_lastMode.
// NOTE(review): the first signature line (source line 251) is not shown in
// this listing.
250 void
252  double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss)
253 {
254  NS_LOG_FUNCTION (this << st << ackSnr << ackMode.GetUniqueName () << dataSnr);
255  InitializeStation (st);
256  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
257  Decay (st, station->m_lastMode);
258  station->m_mcsStats.at (station->m_lastMode).success++;
259  UpdateNextMode (st);
260 }
261 
// DoReportAmpduTxStatus: credits an A-MPDU outcome to the mode used last. After
// decaying that entry, it adds nSuccessfulMpdus successes and nFailedMpdus
// failures in a single update, then draws the next mode. The SNR and
// data-width/nss arguments are only logged or ignored.
// NOTE(review): the first signature line (source line 263) is not shown in
// this listing.
262 void
264  uint16_t nFailedMpdus, double rxSnr, double dataSnr,
265  uint16_t dataChannelWidth, uint8_t dataNss)
266 {
267  NS_LOG_FUNCTION (this << st << nSuccessfulMpdus << nFailedMpdus << rxSnr << dataSnr);
268  InitializeStation (st);
269  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
270 
271  Decay (st, station->m_lastMode);
272  station->m_mcsStats.at (station->m_lastMode).success += nSuccessfulMpdus;
273  station->m_mcsStats.at (station->m_lastMode).fails += nFailedMpdus;
274 
275  UpdateNextMode (st);
276 }
277 
// Presumably DoReportFinalRtsFailed; its signature (source line 279) is not
// shown in this listing. It does not affect rate statistics and only logs.
278 void
280 {
281  NS_LOG_FUNCTION (this << station);
282 }
283 
// Presumably DoReportFinalDataFailed; its signature (source line 285) is not
// shown in this listing. It does not affect rate statistics and only logs.
// Individual data failures are counted in DoReportDataFailed.
284 void
286 {
287  NS_LOG_FUNCTION (this << station);
288 }
289 
// GetModeGuardInterval: returns the guard interval, in nanoseconds, to use with
// `mode` for station `st`:
//  - HE: the larger of the station's and this device's HE guard interval;
//  - HT (and presumably VHT, via source line 298, which is not shown):
//    400 ns only if both the station and this device support a short guard
//    interval, otherwise 800 ns;
//  - any other modulation class: 800 ns.
// NOTE(review): the signature (source line 291) is not shown in this listing.
290 uint16_t
292 {
293  if (mode.GetModulationClass () == WIFI_MOD_CLASS_HE)
294  {
295  return std::max (GetGuardInterval (st), GetGuardInterval ());
296  }
297  else if ((mode.GetModulationClass () == WIFI_MOD_CLASS_HT) ||
299  {
300  return std::max<uint16_t> (GetShortGuardIntervalSupported (st) ? 400 : 800,
301  GetShortGuardIntervalSupported () ? 400 : 800);
302  }
303  else
304  {
305  return 800;
306  }
307 }
308 
// DoGetDataTxVector: builds the TX vector for the next data frame from the
// entry selected by UpdateNextMode. It also:
//  - caps the channel width at the PHY's current channel width;
//  - records the entry as m_lastMode, so the resulting feedback is credited to
//    it;
//  - updates the traced m_currentRate when the nominal data rate changes.
// NOTE(review): GetModeGuardInterval is evaluated a second time in the
// WifiTxVector arguments, although the local `guardInterval` already holds that
// value.
// NOTE(review): the signature (source lines 309-310) and some constructor
// arguments (source lines 339-341, 343) are not shown in this listing.
311 {
312  NS_LOG_FUNCTION (this << st);
313  InitializeStation (st);
314  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
315 
316  auto &stats = station->m_mcsStats.at (station->m_nextMode);
317  WifiMode mode = stats.mode;
318  uint16_t channelWidth = std::min (stats.channelWidth, GetPhy ()->GetChannelWidth ());
319  uint8_t nss = stats.nss;
320  uint16_t guardInterval = GetModeGuardInterval (st, mode);
321 
// Feedback for this frame (DoReportDataOk/DataFailed/AmpduTxStatus) is credited
// to this entry.
322  station->m_lastMode = station->m_nextMode;
323 
324  NS_LOG_DEBUG ("Using"
325  << " mode=" << mode
326  << " channelWidth=" << channelWidth
327  << " nss=" << +nss
328  << " guardInterval=" << guardInterval);
329 
330  uint64_t rate = mode.GetDataRate (channelWidth, guardInterval, nss);
// m_currentRate is a TracedValue; assign only on change.
331  if (m_currentRate != rate)
332  {
333  NS_LOG_DEBUG ("New datarate: " << rate);
334  m_currentRate = rate;
335  }
336 
337  return WifiTxVector (
338  mode,
342  GetModeGuardInterval (st, mode),
344  nss,
345  0, // NESS
346  GetChannelWidthForTransmission (mode, channelWidth),
347  GetAggregation (station),
348  false);
349 }
350 
// DoGetRtsTxVector: RTS frames always use rate-table entry 0, the first
// combination InitializeStation added (its loops start at 20 MHz and one
// spatial stream). Entry 0 is treated as the most robust mode. The channel
// width is capped at the PHY's current width, and a single spatial stream is
// asserted.
// NOTE(review): the signature (source lines 351-352) and some constructor
// arguments (source lines 369, 371) are not shown in this listing.
353 {
354  NS_LOG_FUNCTION (this << st);
355  InitializeStation (st);
356  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
357 
358  // Use the most robust MCS for the control channel.
359  auto &stats = station->m_mcsStats.at (0);
360  WifiMode mode = stats.mode;
361  uint16_t channelWidth = std::min (stats.channelWidth, GetPhy ()->GetChannelWidth ());
362  uint8_t nss = stats.nss;
363 
364  // Make sure control frames are sent using 1 spatial stream.
365  NS_ASSERT (nss == 1);
366 
367  return WifiTxVector (
368  mode, GetDefaultTxPowerLevel (),
370  GetModeGuardInterval (st, mode),
372  nss,
373  0, // NESS
374  GetChannelWidthForTransmission (mode, channelWidth),
375  GetAggregation (station),
376  false);
377 }
378 
// SampleBetaVariable: draws X ~ Gamma(alpha, 1.0) and Y ~ Gamma(beta, 1.0) and
// returns X / (X + Y), which is Beta(alpha, beta) distributed.
// NOTE(review): per its declaration, alpha and beta are uint64_t, while
// UpdateNextMode passes fractional decayed counts (1.0 + success, 1.0 + fails);
// those are truncated.
// NOTE(review): the signature (source line 380) is not shown in this listing.
379 double
381 {
382  double X = m_gammaRandomVariable->GetValue (alpha, 1.0);
383  double Y = m_gammaRandomVariable->GetValue (beta, 1.0);
384  return X / (X + Y);
385 }
386 
// Decay: scales entry i's success and failure counts by
// exp(-m_decay * seconds elapsed since that entry was last decayed), then
// records the current simulation time. It does nothing if no simulated time
// has passed. With m_decay == 0 the coefficient is 1 and counts never fade.
// NOTE(review): the signature (source line 388) is not shown in this listing.
387 void
389 {
390  NS_LOG_FUNCTION (this << st << i);
391  InitializeStation (st);
392  auto station = static_cast<ThompsonSamplingWifiRemoteStation *> (st);
393 
394  Time now = Simulator::Now ();
395  auto &stats = station->m_mcsStats.at (i);
396  if (now > stats.lastDecay)
397  {
// (lastDecay - now) is negative, so the coefficient is in (0, 1].
398  const double coefficient =
399  std::exp (m_decay * (stats.lastDecay - now).GetSeconds ());
400 
401  stats.success *= coefficient;
402  stats.fails *= coefficient;
403  stats.lastDecay = now;
404  }
405 }
406 
// AssignStreams: fixes the RNG stream of the gamma random variable, the only
// random variable visible in this file. Returns 1, the number of streams
// consumed.
// NOTE(review): the signature (source line 408) is not shown in this listing.
407 int64_t
409 {
410  NS_LOG_FUNCTION (this << stream);
411  m_gammaRandomVariable->SetStream (stream);
412  return 1;
413 }
414 
415 } //namespace ns3
#define min(a, b)
Definition: 80211b.c:42
#define max(a, b)
Definition: 80211b.c:43
This class can be used to hold variables of floating point type such as 'double' or 'float'.
Definition: double.h:41
static Time Now(void)
Return the current simulation virtual time.
Definition: simulator.cc:195
Thompson Sampling rate control algorithm.
uint16_t GetModeGuardInterval(WifiRemoteStation *st, WifiMode mode) const
Returns the guard interval, in nanoseconds, for the given mode.
void DoReportRxOk(WifiRemoteStation *station, double rxSnr, WifiMode txMode) override
This method is a pure virtual method that must be implemented by the sub-class.
void InitializeStation(WifiRemoteStation *station) const
Initializes station rate tables.
void DoReportDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportDataOk(WifiRemoteStation *station, double ackSnr, WifiMode ackMode, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportAmpduTxStatus(WifiRemoteStation *station, uint16_t nSuccessfulMpdus, uint16_t nFailedMpdus, double rxSnr, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
Typically called per A-MPDU, either when a Block ACK was successfully received or when a BlockAckTime...
TracedValue< uint64_t > m_currentRate
Trace rate changes.
double SampleBetaVariable(uint64_t alpha, uint64_t beta) const
Sample a beta random variable with the given parameters.
WifiRemoteStation * DoCreateStation() const override
Ptr< GammaRandomVariable > m_gammaRandomVariable
Variable used to sample beta-distributed random variables.
void DoReportFinalRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
static TypeId GetTypeId(void)
Get the type ID.
WifiTxVector DoGetDataTxVector(WifiRemoteStation *station) override
void DoReportFinalDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
double m_decay
Exponential decay coefficient, Hz.
void DoReportRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void UpdateNextMode(WifiRemoteStation *station) const
Draws a new MCS and related parameters to try next time for this station.
int64_t AssignStreams(int64_t stream) override
Assign a fixed random variable stream number to the random variables used by this model.
WifiTxVector DoGetRtsTxVector(WifiRemoteStation *station) override
void DoReportRtsOk(WifiRemoteStation *station, double ctsSnr, WifiMode ctsMode, double rtsSnr) override
This method is a pure virtual method that must be implemented by the sub-class.
void Decay(WifiRemoteStation *st, size_t i) const
Applies exponential decay to MCS statistics.
Simulation virtual time values and global simulation resolution.
Definition: nstime.h:103
A unique identifier for an interface.
Definition: type-id.h:59
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition: type-id.cc:922
Represents a single transmission mode.
Definition: wifi-mode.h:48
WifiModulationClass GetModulationClass() const
Definition: wifi-mode.cc:177
std::string GetUniqueName(void) const
Definition: wifi-mode.cc:140
uint64_t GetDataRate(uint16_t channelWidth, uint16_t guardInterval, uint8_t nss) const
Definition: wifi-mode.cc:114
std::list< WifiMode > GetMcsList(void) const
The WifiPhy::GetMcsList() method is used (e.g., by a WifiRemoteStationManager) to determine the set o...
Definition: wifi-phy.cc:1767
uint8_t GetMaxSupportedTxSpatialStreams(void) const
Definition: wifi-phy.cc:1120
uint16_t GetChannelWidth(void) const
Definition: wifi-phy.cc:918
Holds a list of per-remote-station state.
uint16_t GetChannelWidth(const WifiRemoteStation *station) const
Return the channel width supported by the station.
bool GetVhtSupported(void) const
Return whether the device has VHT capability support enabled.
Ptr< WifiPhy > GetPhy(void) const
Return the WifiPhy.
uint8_t GetNSupported(const WifiRemoteStation *station) const
Return the number of modes supported by the given station.
bool GetAggregation(const WifiRemoteStation *station) const
Return whether the given station supports A-MPDU.
bool GetShortPreambleEnabled(void) const
Return whether the device uses short PHY preambles.
bool GetHeSupported(void) const
Return whether the device has HE capability support enabled.
bool GetShortGuardIntervalSupported(void) const
Return whether the device has SGI support enabled.
WifiMode GetSupported(const WifiRemoteStation *station, uint8_t i) const
Return the mode associated with the specified station at the specified index.
uint16_t GetGuardInterval(void) const
Return the supported HE guard interval duration (in nanoseconds).
This class mimics the TXVECTOR which is to be passed to the PHY in order to define the parameters whi...
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file,...
Definition: assert.h:67
#define NS_ASSERT_MSG(condition, message)
At runtime, in debugging builds, if this condition is not true, the program prints the message to out...
Definition: assert.h:88
Ptr< const AttributeAccessor > MakeDoubleAccessor(T1 a1)
Create an AttributeAccessor for a class data member, or a lone class get functor or set method.
Definition: double.h:42
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition: log.h:205
#define NS_LOG_DEBUG(msg)
Use NS_LOG to output a message of level LOG_DEBUG.
Definition: log.h:273
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by ",...
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition: object-base.h:45
Ptr< const TraceSourceAccessor > MakeTraceSourceAccessor(T a)
Create a TraceSourceAccessor which will control access to the underlying trace source.
WifiModulationClass
This enumeration defines the modulation classes per (Table 10-6 "Modulation classes"; IEEE 802....
@ WIFI_MOD_CLASS_HR_DSSS
HR/DSSS (Clause 16)
@ WIFI_MOD_CLASS_HT
HT (Clause 19)
@ WIFI_MOD_CLASS_VHT
VHT (Clause 22)
@ WIFI_MOD_CLASS_HE
HE (Clause 27)
@ WIFI_MOD_CLASS_DSSS
DSSS (Clause 15)
Every class exported by the ns3 library is enclosed in the ns3 namespace.
uint16_t GetChannelWidthForTransmission(WifiMode mode, uint16_t maxAllowedChannelWidth)
Return the channel width that is allowed based on the selected mode and the given maximum channel wid...
WifiPreamble GetPreambleForTransmission(WifiModulationClass modulation, bool useShortPreamble)
Return the preamble to be used for the transmission.
float alpha
Plot alpha value (transparency)
A structure containing parameters of a single rate and its statistics.
uint16_t channelWidth
channel width in MHz
uint8_t nss
Number of spatial streams.
double success
averaged number of successful transmissions
double fails
averaged number of failed transmissions
Time lastDecay
last time exponential decay was applied to this rate
Holds station state and collected statistics.
size_t m_nextMode
Mode to select for the next transmission.
std::vector< RateStats > m_mcsStats
Collected statistics.
size_t m_lastMode
Most recently used mode, used to write statistics.
Holds per-remote-station state.