A Discrete-Event Network Simulator
API
thompson-sampling-wifi-manager.cc
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 IITP RAS
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License version 2 as
6  * published by the Free Software Foundation;
7  *
8  * This program is distributed in the hope that it will be useful,
9  * but WITHOUT ANY WARRANTY; without even the implied warranty of
10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11  * GNU General Public License for more details.
12  *
13  * You should have received a copy of the GNU General Public License
14  * along with this program; if not, write to the Free Software
15  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
16  *
17  * Author: Alexander Krotov <krotov@iitp.ru>
18  */
19 
21 
22 #include "ns3/core-module.h"
23 #include "ns3/double.h"
24 #include "ns3/log.h"
25 #include "ns3/packet.h"
26 #include "ns3/wifi-phy.h"
27 
28 #include <cstdint>
29 #include <cstdlib>
30 #include <fstream>
31 #include <iostream>
32 #include <string>
33 
34 namespace ns3
35 {
36 
41 struct RateStats
42 {
44  uint16_t channelWidth;
45  uint8_t nss;
46 
47  double success{0.0};
48  double fails{0.0};
50 };
51 
59 {
60  size_t m_nextMode;
61  size_t m_lastMode;
62 
63  std::vector<RateStats> m_mcsStats;
64 };
65 
67 
68 NS_LOG_COMPONENT_DEFINE("ThompsonSamplingWifiManager");
69 
70 TypeId
72 {
73  static TypeId tid =
74  TypeId("ns3::ThompsonSamplingWifiManager")
76  .SetGroupName("Wifi")
77  .AddConstructor<ThompsonSamplingWifiManager>()
78  .AddAttribute(
79  "Decay",
80  "Exponential decay coefficient, Hz; zero is a valid value for static scenarios",
81  DoubleValue(1.0),
83  MakeDoubleChecker<double>(0.0))
84  .AddTraceSource("Rate",
85  "Traced value for rate changes (b/s)",
87  "ns3::TracedValueCallback::Uint64");
88  return tid;
89 }
90 
92  : m_currentRate{0}
93 {
94  NS_LOG_FUNCTION(this);
95 
96  m_gammaRandomVariable = CreateObject<GammaRandomVariable>();
97 }
98 
100 {
101  NS_LOG_FUNCTION(this);
102 }
103 
106 {
107  NS_LOG_FUNCTION(this);
108  auto station = new ThompsonSamplingWifiRemoteStation();
109  station->m_nextMode = 0;
110  station->m_lastMode = 0;
111  return station;
112 }
113 
114 void
116 {
117  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
118  if (!station->m_mcsStats.empty())
119  {
120  return;
121  }
122 
123  // Add HT, VHT or HE MCSes
124  for (const auto& mode : GetPhy()->GetMcsList())
125  {
126  for (uint16_t j = 20; j <= GetPhy()->GetChannelWidth(); j *= 2)
127  {
128  WifiModulationClass modulationClass = WIFI_MOD_CLASS_HT;
129  if (GetVhtSupported())
130  {
131  modulationClass = WIFI_MOD_CLASS_VHT;
132  }
133  if (GetHeSupported())
134  {
135  modulationClass = WIFI_MOD_CLASS_HE;
136  }
137  if (mode.GetModulationClass() == modulationClass)
138  {
139  for (uint8_t k = 1; k <= GetPhy()->GetMaxSupportedTxSpatialStreams(); k++)
140  {
141  if (mode.IsAllowed(j, k))
142  {
143  RateStats stats;
144  stats.mode = mode;
145  stats.channelWidth = j;
146  stats.nss = k;
147 
148  station->m_mcsStats.push_back(stats);
149  }
150  }
151  }
152  }
153  }
154 
155  if (station->m_mcsStats.empty())
156  {
157  // Add legacy non-HT modes.
158  for (uint8_t i = 0; i < GetNSupported(station); i++)
159  {
160  RateStats stats;
161  stats.mode = GetSupported(station, i);
164  {
165  stats.channelWidth = 22;
166  }
167  else
168  {
169  stats.channelWidth = 20;
170  }
171  stats.nss = 1;
172  station->m_mcsStats.push_back(stats);
173  }
174  }
175 
176  NS_ASSERT_MSG(!station->m_mcsStats.empty(), "No usable MCS found");
177 
178  UpdateNextMode(st);
179 }
180 
181 void
183 {
184  NS_LOG_FUNCTION(this << station << rxSnr << txMode);
185 }
186 
187 void
189 {
190  NS_LOG_FUNCTION(this << station);
191 }
192 
193 void
195 {
196  NS_LOG_FUNCTION(this << st);
197  InitializeStation(st);
198  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
199  Decay(st, station->m_lastMode);
200  station->m_mcsStats.at(station->m_lastMode).fails++;
201  UpdateNextMode(st);
202 }
203 
204 void
206  double ctsSnr,
207  WifiMode ctsMode,
208  double rtsSnr)
209 {
210  NS_LOG_FUNCTION(this << st << ctsSnr << ctsMode.GetUniqueName() << rtsSnr);
211 }
212 
213 void
215 {
216  InitializeStation(st);
217  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
218 
219  double maxThroughput = 0.0;
220  double frameSuccessRate = 1.0;
221 
222  NS_ASSERT(!station->m_mcsStats.empty());
223 
224  // Use the most robust MCS if frameSuccessRate is 0 for all MCS.
225  station->m_nextMode = 0;
226 
227  for (uint32_t i = 0; i < station->m_mcsStats.size(); i++)
228  {
229  Decay(st, i);
230  const WifiMode mode{station->m_mcsStats.at(i).mode};
231 
232  uint16_t guardInterval = GetModeGuardInterval(st, mode);
233  double rate = mode.GetDataRate(station->m_mcsStats.at(i).channelWidth,
234  guardInterval,
235  station->m_mcsStats.at(i).nss);
236 
237  // Thompson sampling
238  frameSuccessRate = SampleBetaVariable(1.0 + station->m_mcsStats.at(i).success,
239  1.0 + station->m_mcsStats.at(i).fails);
240  NS_LOG_DEBUG("Draw"
241  << " success=" << station->m_mcsStats.at(i).success
242  << " fails=" << station->m_mcsStats.at(i).fails
243  << " frameSuccessRate=" << frameSuccessRate << " mode=" << mode);
244  if (frameSuccessRate * rate > maxThroughput)
245  {
246  maxThroughput = frameSuccessRate * rate;
247  station->m_nextMode = i;
248  }
249  }
250 }
251 
252 void
254  double ackSnr,
255  WifiMode ackMode,
256  double dataSnr,
257  uint16_t dataChannelWidth,
258  uint8_t dataNss)
259 {
260  NS_LOG_FUNCTION(this << st << ackSnr << ackMode.GetUniqueName() << dataSnr);
261  InitializeStation(st);
262  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
263  Decay(st, station->m_lastMode);
264  station->m_mcsStats.at(station->m_lastMode).success++;
265  UpdateNextMode(st);
266 }
267 
268 void
270  uint16_t nSuccessfulMpdus,
271  uint16_t nFailedMpdus,
272  double rxSnr,
273  double dataSnr,
274  uint16_t dataChannelWidth,
275  uint8_t dataNss)
276 {
277  NS_LOG_FUNCTION(this << st << nSuccessfulMpdus << nFailedMpdus << rxSnr << dataSnr);
278  InitializeStation(st);
279  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
280 
281  Decay(st, station->m_lastMode);
282  station->m_mcsStats.at(station->m_lastMode).success += nSuccessfulMpdus;
283  station->m_mcsStats.at(station->m_lastMode).fails += nFailedMpdus;
284 
285  UpdateNextMode(st);
286 }
287 
288 void
290 {
291  NS_LOG_FUNCTION(this << station);
292 }
293 
294 void
296 {
297  NS_LOG_FUNCTION(this << station);
298 }
299 
300 uint16_t
302 {
304  {
306  }
307  else if ((mode.GetModulationClass() == WIFI_MOD_CLASS_HT) ||
309  {
310  return std::max<uint16_t>(GetShortGuardIntervalSupported(st) ? 400 : 800,
311  GetShortGuardIntervalSupported() ? 400 : 800);
312  }
313  else
314  {
315  return 800;
316  }
317 }
318 
321 {
322  NS_LOG_FUNCTION(this << st << allowedWidth);
323  InitializeStation(st);
324  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
325 
326  auto& stats = station->m_mcsStats.at(station->m_nextMode);
327  WifiMode mode = stats.mode;
328  uint16_t channelWidth = std::min(stats.channelWidth, allowedWidth);
329  uint8_t nss = stats.nss;
330  uint16_t guardInterval = GetModeGuardInterval(st, mode);
331 
332  station->m_lastMode = station->m_nextMode;
333 
334  NS_LOG_DEBUG("Using"
335  << " mode=" << mode << " channelWidth=" << channelWidth << " nss=" << +nss
336  << " guardInterval=" << guardInterval);
337 
338  uint64_t rate = mode.GetDataRate(channelWidth, guardInterval, nss);
339  if (m_currentRate != rate)
340  {
341  NS_LOG_DEBUG("New datarate: " << rate);
342  m_currentRate = rate;
343  }
344 
345  return WifiTxVector(
346  mode,
349  GetModeGuardInterval(st, mode),
351  nss,
352  0, // NESS
353  GetPhy()->GetTxBandwidth(mode, channelWidth),
354  GetAggregation(station),
355  false);
356 }
357 
360 {
361  NS_LOG_FUNCTION(this << st);
362  InitializeStation(st);
363  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
364 
365  // Use the most robust MCS for the control channel.
366  auto& stats = station->m_mcsStats.at(0);
367  WifiMode mode = stats.mode;
368  uint8_t nss = stats.nss;
369 
370  // Make sure control frames are sent using 1 spatial stream.
371  NS_ASSERT(nss == 1);
372 
373  return WifiTxVector(
374  mode,
377  GetModeGuardInterval(st, mode),
379  nss,
380  0, // NESS
381  GetPhy()->GetTxBandwidth(mode, stats.channelWidth),
382  GetAggregation(station),
383  false);
384 }
385 
386 double
387 ThompsonSamplingWifiManager::SampleBetaVariable(uint64_t alpha, uint64_t beta) const
388 {
389  double X = m_gammaRandomVariable->GetValue(alpha, 1.0);
390  double Y = m_gammaRandomVariable->GetValue(beta, 1.0);
391  return X / (X + Y);
392 }
393 
394 void
396 {
397  NS_LOG_FUNCTION(this << st << i);
398  InitializeStation(st);
399  auto station = static_cast<ThompsonSamplingWifiRemoteStation*>(st);
400 
401  Time now = Simulator::Now();
402  auto& stats = station->m_mcsStats.at(i);
403  if (now > stats.lastDecay)
404  {
405  const double coefficient = std::exp(m_decay * (stats.lastDecay - now).GetSeconds());
406 
407  stats.success *= coefficient;
408  stats.fails *= coefficient;
409  stats.lastDecay = now;
410  }
411 }
412 
413 int64_t
415 {
416  NS_LOG_FUNCTION(this << stream);
417  m_gammaRandomVariable->SetStream(stream);
418  return 1;
419 }
420 
421 } // namespace ns3
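#define min(a, b)
Definition: 80211b.c:41 (Doxygen cross-reference artifact: this file calls std::min, not this macro)
#define max(a, b)
Definition: 80211b.c:42 (Doxygen cross-reference artifact: this file calls std::max, not this macro)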
This class can be used to hold variables of floating point type such as 'double' or 'float'.
Definition: double.h:42
static Time Now()
Return the current simulation virtual time.
Definition: simulator.cc:208
Thompson Sampling rate control algorithm.
uint16_t GetModeGuardInterval(WifiRemoteStation *st, WifiMode mode) const
Returns guard interval in nanoseconds for the given mode.
void DoReportRxOk(WifiRemoteStation *station, double rxSnr, WifiMode txMode) override
This method is a pure virtual method that must be implemented by the sub-class.
void InitializeStation(WifiRemoteStation *station) const
Initializes station rate tables.
void DoReportDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportDataOk(WifiRemoteStation *station, double ackSnr, WifiMode ackMode, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportAmpduTxStatus(WifiRemoteStation *station, uint16_t nSuccessfulMpdus, uint16_t nFailedMpdus, double rxSnr, double dataSnr, uint16_t dataChannelWidth, uint8_t dataNss) override
Typically called per A-MPDU, either when a Block ACK was successfully received or when a BlockAckTime...
TracedValue< uint64_t > m_currentRate
Trace rate changes.
double SampleBetaVariable(uint64_t alpha, uint64_t beta) const
Sample beta random variable with given parameters.
WifiRemoteStation * DoCreateStation() const override
Ptr< GammaRandomVariable > m_gammaRandomVariable
Variable used to sample beta-distributed random variables.
void DoReportFinalRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void DoReportFinalDataFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
static TypeId GetTypeId()
Get the type ID.
WifiTxVector DoGetDataTxVector(WifiRemoteStation *station, uint16_t allowedWidth) override
double m_decay
Exponential decay coefficient, Hz.
void DoReportRtsFailed(WifiRemoteStation *station) override
This method is a pure virtual method that must be implemented by the sub-class.
void UpdateNextMode(WifiRemoteStation *station) const
Draws a new MCS and related parameters to try next time for this station.
int64_t AssignStreams(int64_t stream) override
Assign a fixed random variable stream number to the random variables used by this model.
WifiTxVector DoGetRtsTxVector(WifiRemoteStation *station) override
void DoReportRtsOk(WifiRemoteStation *station, double ctsSnr, WifiMode ctsMode, double rtsSnr) override
This method is a pure virtual method that must be implemented by the sub-class.
void Decay(WifiRemoteStation *st, size_t i) const
Applies exponential decay to MCS statistics.
Simulation virtual time values and global simulation resolution.
Definition: nstime.h:105
a unique identifier for an interface.
Definition: type-id.h:59
TypeId SetParent(TypeId tid)
Set the parent TypeId.
Definition: type-id.cc:931
represent a single transmission mode
Definition: wifi-mode.h:51
std::string GetUniqueName() const
Definition: wifi-mode.cc:148
WifiModulationClass GetModulationClass() const
Definition: wifi-mode.cc:185
uint64_t GetDataRate(uint16_t channelWidth, uint16_t guardInterval, uint8_t nss) const
Definition: wifi-mode.cc:122
uint16_t GetChannelWidth() const
Definition: wifi-phy.cc:1051
uint8_t GetMaxSupportedTxSpatialStreams() const
Definition: wifi-phy.cc:1301
std::list< WifiMode > GetMcsList() const
The WifiPhy::GetMcsList() method is used (e.g., by a WifiRemoteStationManager) to determine the set o...
Definition: wifi-phy.cc:2003
hold a list of per-remote-station state.
uint8_t GetNSupported(const WifiRemoteStation *station) const
Return the number of modes supported by the given station.
Ptr< WifiPhy > GetPhy() const
Return the WifiPhy.
uint16_t GetGuardInterval() const
Return the supported HE guard interval duration (in nanoseconds).
bool GetAggregation(const WifiRemoteStation *station) const
Return whether the given station supports A-MPDU.
bool GetShortGuardIntervalSupported() const
Return whether the device has SGI support enabled.
bool GetVhtSupported() const
Return whether the device has VHT capability support enabled.
bool GetShortPreambleEnabled() const
Return whether the device uses short PHY preambles.
WifiMode GetSupported(const WifiRemoteStation *station, uint8_t i) const
Return the mode associated with the specified station at the specified index.
bool GetHeSupported() const
Return whether the device has HE capability support enabled.
This class mimics the TXVECTOR which is to be passed to the PHY in order to define the parameters whi...
#define NS_ASSERT(condition)
At runtime, in debugging builds, if this condition is not true, the program prints the source file,...
Definition: assert.h:66
#define NS_ASSERT_MSG(condition, message)
At runtime, in debugging builds, if this condition is not true, the program prints the message to out...
Definition: assert.h:86
#define NS_LOG_COMPONENT_DEFINE(name)
Define a Log component with a specific name.
Definition: log.h:202
#define NS_LOG_DEBUG(msg)
Use NS_LOG to output a message of level LOG_DEBUG.
Definition: log.h:268
#define NS_LOG_FUNCTION(parameters)
If log level LOG_FUNCTION is enabled, this macro will output all input parameters separated by ",...
#define NS_OBJECT_ENSURE_REGISTERED(type)
Register an Object subclass with the TypeId system.
Definition: object-base.h:46
Ptr< const TraceSourceAccessor > MakeTraceSourceAccessor(T a)
Create a TraceSourceAccessor which will control access to the underlying trace source.
WifiModulationClass
This enumeration defines the modulation classes per (Table 10-6 "Modulation classes"; IEEE 802....
@ WIFI_MOD_CLASS_HR_DSSS
HR/DSSS (Clause 16)
@ WIFI_MOD_CLASS_HT
HT (Clause 19)
@ WIFI_MOD_CLASS_VHT
VHT (Clause 22)
@ WIFI_MOD_CLASS_HE
HE (Clause 27)
@ WIFI_MOD_CLASS_DSSS
DSSS (Clause 15)
Every class exported by the ns3 library is enclosed in the ns3 namespace.
Ptr< const AttributeAccessor > MakeDoubleAccessor(T1 a1)
Definition: double.h:43
WifiPreamble GetPreambleForTransmission(WifiModulationClass modulation, bool useShortPreamble)
Return the preamble to be used for the transmission.
A structure containing parameters of a single rate and its statistics.
uint16_t channelWidth
channel width in MHz
uint8_t nss
Number of spatial streams.
double success
averaged number of successful transmissions
double fails
averaged number of failed transmissions
Time lastDecay
last time exponential decay was applied to this rate
Holds station state and collected statistics.
size_t m_nextMode
Mode to select for the next transmission.
std::vector< RateStats > m_mcsStats
Collected statistics.
size_t m_lastMode
Most recently used mode, used to write statistics.
hold per-remote-station state.