1 /*
2  * Copyright (C) 2010-2013 Felix Fietkau <nbd@openwrt.org>
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License version 2 as
6  * published by the Free Software Foundation.
7  */
8 #include <linux/netdevice.h>
9 #include <linux/types.h>
10 #include <linux/skbuff.h>
11 #include <linux/debugfs.h>
12 #include <linux/random.h>
13 #include <linux/ieee80211.h>
14 #include <net/mac80211.h>
15 #include "rate.h"
16 #include "rc80211_minstrel.h"
17 #include "rc80211_minstrel_ht.h"
18 
/* Average packet size in bytes used for all airtime estimates below */
#define AVG_PKT_SIZE	1200
/* Number of columns (independent sampling orders) in sample_table */
#define SAMPLE_COLUMNS	10
/* Weight, in percent, given to the old value by minstrel_ewma() */
#define EWMA_LEVEL		75

/* Number of bits for an average sized packet */
#define MCS_NBITS (AVG_PKT_SIZE << 3)

/* Number of symbols for a packet with (bps) bits per symbol, rounded up */
#define MCS_NSYMS(bps) ((MCS_NBITS + (bps) - 1) / (bps))

/*
 * Transmission time for a packet containing (syms) symbols.
 * All macro arguments are parenthesized so that arbitrary expressions
 * can be passed safely.
 */
#define MCS_SYMBOL_TIME(sgi, syms)					\
	((sgi) ?							\
	  (((syms) * 18 + 4) / 5) :	/* syms * 3.6 us */		\
	  ((syms) << 2)			/* syms * 4 us */		\
	)

/* Transmit duration for the raw data part of an average sized packet */
#define MCS_DURATION(streams, sgi, bps)					\
	MCS_SYMBOL_TIME(sgi, MCS_NSYMS((streams) * (bps)))
38 
39 /*
40  * Define group sort order: HT40 -> SGI -> #streams
41  */
42 #define GROUP_IDX(_streams, _sgi, _ht40)	\
43 	MINSTREL_MAX_STREAMS * 2 * _ht40 +	\
44 	MINSTREL_MAX_STREAMS * _sgi +		\
45 	_streams - 1
46 
/*
 * MCS rate information for an MCS group: a designated initializer placing
 * the group at its GROUP_IDX() slot. The duration table has one entry per
 * rate index within the group; each value passed to MCS_DURATION() is the
 * number of data bits per symbol per stream (20 MHz vs. 40 MHz variant).
 * Macro arguments are parenthesized in the expansion.
 */
#define MCS_GROUP(_streams, _sgi, _ht40)				\
	[GROUP_IDX(_streams, _sgi, _ht40)] = {				\
	.streams = (_streams),						\
	.flags =							\
		((_sgi) ? IEEE80211_TX_RC_SHORT_GI : 0) |		\
		((_ht40) ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0),		\
	.duration = {							\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 54 : 26),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 108 : 52),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 162 : 78),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 216 : 104),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 324 : 156),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 432 : 208),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 486 : 234),	\
		MCS_DURATION(_streams, _sgi, (_ht40) ? 540 : 260)	\
	}								\
}
65 
/*
 * Airtime in us of a CCK frame of (_len) bytes sent at (_bitrate), given in
 * units of 100 kbit/s (10 = 1 Mbit/s, 110 = 11 Mbit/s), plus 10 us SIFS.
 * The 72 + 24 / 144 + 48 terms are presumably preamble + PLCP header time
 * for short / long preamble (TODO confirm).
 */
#define CCK_DURATION(_bitrate, _short, _len)		\
	(10 /* SIFS */ +				\
	 ((_short) ? 72 + 24 : 144 + 48) +		\
	 (8 * ((_len) + 4) * 10) / (_bitrate))

/*
 * Airtime of an average sized data frame at (_bitrate) plus its response,
 * which is sent with long preamble at 2 Mbit/s for data rates above
 * 1 Mbit/s and at 1 Mbit/s otherwise.
 */
#define CCK_ACK_DURATION(_bitrate, _short)			\
	(CCK_DURATION(((_bitrate) > 10 ? 20 : 10), false, 60) +	\
	 CCK_DURATION(_bitrate, _short, AVG_PKT_SIZE))

/* Durations for 1, 2, 5.5 and 11 Mbit/s with the given preamble type */
#define CCK_DURATION_LIST(_short)			\
	CCK_ACK_DURATION(10, _short),			\
	CCK_ACK_DURATION(20, _short),			\
	CCK_ACK_DURATION(55, _short),			\
	CCK_ACK_DURATION(110, _short)

/*
 * The CCK pseudo-group sits after all MCS groups. Entries 0-3 are the long
 * preamble durations, entries 4-7 the short preamble ones.
 */
#define CCK_GROUP						\
	[MINSTREL_MAX_STREAMS * MINSTREL_STREAM_GROUPS] = {	\
		.streams = 0,					\
		.duration = {					\
			CCK_DURATION_LIST(false),		\
			CCK_DURATION_LIST(true)			\
		}						\
	}
89 
90 /*
91  * To enable sufficiently targeted rate sampling, MCS rates are divided into
92  * groups, based on the number of streams and flags (HT40, SGI) that they
93  * use.
94  *
95  * Sortorder has to be fixed for GROUP_IDX macro to be applicable:
96  * HT40 -> SGI -> #streams
97  */
98 const struct mcs_group minstrel_mcs_groups[] = {
99 	MCS_GROUP(1, 0, 0),
100 	MCS_GROUP(2, 0, 0),
101 #if MINSTREL_MAX_STREAMS >= 3
102 	MCS_GROUP(3, 0, 0),
103 #endif
104 
105 	MCS_GROUP(1, 1, 0),
106 	MCS_GROUP(2, 1, 0),
107 #if MINSTREL_MAX_STREAMS >= 3
108 	MCS_GROUP(3, 1, 0),
109 #endif
110 
111 	MCS_GROUP(1, 0, 1),
112 	MCS_GROUP(2, 0, 1),
113 #if MINSTREL_MAX_STREAMS >= 3
114 	MCS_GROUP(3, 0, 1),
115 #endif
116 
117 	MCS_GROUP(1, 1, 1),
118 	MCS_GROUP(2, 1, 1),
119 #if MINSTREL_MAX_STREAMS >= 3
120 	MCS_GROUP(3, 1, 1),
121 #endif
122 
123 	/* must be last */
124 	CCK_GROUP
125 };
126 
/* Index of the CCK pseudo-group, which is the last minstrel_mcs_groups[] entry */
#define MINSTREL_CCK_GROUP	(ARRAY_SIZE(minstrel_mcs_groups) - 1)

/*
 * Sampling order shared by all stations: every column holds a random
 * permutation of the per-group rate indices 0..MCS_GROUP_RATES-1,
 * filled in by init_sample_table().
 */
static u8 sample_table[SAMPLE_COLUMNS][MCS_GROUP_RATES];
130 
131 /*
132  * Perform EWMA (Exponentially Weighted Moving Average) calculation
133  */
134 static int
135 minstrel_ewma(int old, int new, int weight)
136 {
137 	return (new * (100 - weight) + old * weight) / 100;
138 }
139 
140 /*
141  * Look up an MCS group index based on mac80211 rate information
142  */
143 static int
144 minstrel_ht_get_group_idx(struct ieee80211_tx_rate *rate)
145 {
146 	return GROUP_IDX((rate->idx / MCS_GROUP_RATES) + 1,
147 			 !!(rate->flags & IEEE80211_TX_RC_SHORT_GI),
148 			 !!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH));
149 }
150 
151 static struct minstrel_rate_stats *
152 minstrel_ht_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
153 		      struct ieee80211_tx_rate *rate)
154 {
155 	int group, idx;
156 
157 	if (rate->flags & IEEE80211_TX_RC_MCS) {
158 		group = minstrel_ht_get_group_idx(rate);
159 		idx = rate->idx % MCS_GROUP_RATES;
160 	} else {
161 		group = MINSTREL_CCK_GROUP;
162 
163 		for (idx = 0; idx < ARRAY_SIZE(mp->cck_rates); idx++)
164 			if (rate->idx == mp->cck_rates[idx])
165 				break;
166 
167 		/* short preamble */
168 		if (!(mi->groups[group].supported & BIT(idx)))
169 			idx += 4;
170 	}
171 	return &mi->groups[group].rates[idx];
172 }
173 
174 static inline struct minstrel_rate_stats *
175 minstrel_get_ratestats(struct minstrel_ht_sta *mi, int index)
176 {
177 	return &mi->groups[index / MCS_GROUP_RATES].rates[index % MCS_GROUP_RATES];
178 }
179 
180 
181 /*
182  * Recalculate success probabilities and counters for a rate using EWMA
183  */
184 static void
185 minstrel_calc_rate_ewma(struct minstrel_rate_stats *mr)
186 {
187 	if (unlikely(mr->attempts > 0)) {
188 		mr->sample_skipped = 0;
189 		mr->cur_prob = MINSTREL_FRAC(mr->success, mr->attempts);
190 		if (!mr->att_hist)
191 			mr->probability = mr->cur_prob;
192 		else
193 			mr->probability = minstrel_ewma(mr->probability,
194 				mr->cur_prob, EWMA_LEVEL);
195 		mr->att_hist += mr->attempts;
196 		mr->succ_hist += mr->success;
197 	} else {
198 		mr->sample_skipped++;
199 	}
200 	mr->last_success = mr->success;
201 	mr->last_attempts = mr->attempts;
202 	mr->success = 0;
203 	mr->attempts = 0;
204 }
205 
206 /*
207  * Calculate throughput based on the average A-MPDU length, taking into account
208  * the expected number of retransmissions and their expected length
209  */
210 static void
211 minstrel_ht_calc_tp(struct minstrel_ht_sta *mi, int group, int rate)
212 {
213 	struct minstrel_rate_stats *mr;
214 	unsigned int usecs = 0;
215 
216 	mr = &mi->groups[group].rates[rate];
217 
218 	if (mr->probability < MINSTREL_FRAC(1, 10)) {
219 		mr->cur_tp = 0;
220 		return;
221 	}
222 
223 	if (group != MINSTREL_CCK_GROUP)
224 		usecs = mi->overhead / MINSTREL_TRUNC(mi->avg_ampdu_len);
225 
226 	usecs += minstrel_mcs_groups[group].duration[rate];
227 	mr->cur_tp = MINSTREL_TRUNC((1000000 / usecs) * mr->probability);
228 }
229 
230 /*
231  * Update rate statistics and select new primary rates
232  *
233  * Rules for rate selection:
234  *  - max_prob_rate must use only one stream, as a tradeoff between delivery
235  *    probability and throughput during strong fluctuations
236  *  - as long as the max prob rate has a probability of more than 3/4, pick
237  *    higher throughput rates, even if the probablity is a bit lower
238  */
239 static void
240 minstrel_ht_update_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi)
241 {
242 	struct minstrel_mcs_group_data *mg;
243 	struct minstrel_rate_stats *mr;
244 	int cur_prob, cur_prob_tp, cur_tp, cur_tp2;
245 	int group, i, index;
246 
247 	if (mi->ampdu_packets > 0) {
248 		mi->avg_ampdu_len = minstrel_ewma(mi->avg_ampdu_len,
249 			MINSTREL_FRAC(mi->ampdu_len, mi->ampdu_packets), EWMA_LEVEL);
250 		mi->ampdu_len = 0;
251 		mi->ampdu_packets = 0;
252 	}
253 
254 	mi->sample_slow = 0;
255 	mi->sample_count = 0;
256 	mi->max_tp_rate = 0;
257 	mi->max_tp_rate2 = 0;
258 	mi->max_prob_rate = 0;
259 
260 	for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) {
261 		cur_prob = 0;
262 		cur_prob_tp = 0;
263 		cur_tp = 0;
264 		cur_tp2 = 0;
265 
266 		mg = &mi->groups[group];
267 		if (!mg->supported)
268 			continue;
269 
270 		mg->max_tp_rate = 0;
271 		mg->max_tp_rate2 = 0;
272 		mg->max_prob_rate = 0;
273 		mi->sample_count++;
274 
275 		for (i = 0; i < MCS_GROUP_RATES; i++) {
276 			if (!(mg->supported & BIT(i)))
277 				continue;
278 
279 			mr = &mg->rates[i];
280 			mr->retry_updated = false;
281 			index = MCS_GROUP_RATES * group + i;
282 			minstrel_calc_rate_ewma(mr);
283 			minstrel_ht_calc_tp(mi, group, i);
284 
285 			if (!mr->cur_tp)
286 				continue;
287 
288 			if ((mr->cur_tp > cur_prob_tp && mr->probability >
289 			     MINSTREL_FRAC(3, 4)) || mr->probability > cur_prob) {
290 				mg->max_prob_rate = index;
291 				cur_prob = mr->probability;
292 				cur_prob_tp = mr->cur_tp;
293 			}
294 
295 			if (mr->cur_tp > cur_tp) {
296 				swap(index, mg->max_tp_rate);
297 				cur_tp = mr->cur_tp;
298 				mr = minstrel_get_ratestats(mi, index);
299 			}
300 
301 			if (index >= mg->max_tp_rate)
302 				continue;
303 
304 			if (mr->cur_tp > cur_tp2) {
305 				mg->max_tp_rate2 = index;
306 				cur_tp2 = mr->cur_tp;
307 			}
308 		}
309 	}
310 
311 	/* try to sample up to half of the available rates during each interval */
312 	mi->sample_count *= 4;
313 
314 	cur_prob = 0;
315 	cur_prob_tp = 0;
316 	cur_tp = 0;
317 	cur_tp2 = 0;
318 	for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) {
319 		mg = &mi->groups[group];
320 		if (!mg->supported)
321 			continue;
322 
323 		mr = minstrel_get_ratestats(mi, mg->max_prob_rate);
324 		if (cur_prob_tp < mr->cur_tp &&
325 		    minstrel_mcs_groups[group].streams == 1) {
326 			mi->max_prob_rate = mg->max_prob_rate;
327 			cur_prob = mr->cur_prob;
328 			cur_prob_tp = mr->cur_tp;
329 		}
330 
331 		mr = minstrel_get_ratestats(mi, mg->max_tp_rate);
332 		if (cur_tp < mr->cur_tp) {
333 			mi->max_tp_rate2 = mi->max_tp_rate;
334 			cur_tp2 = cur_tp;
335 			mi->max_tp_rate = mg->max_tp_rate;
336 			cur_tp = mr->cur_tp;
337 		}
338 
339 		mr = minstrel_get_ratestats(mi, mg->max_tp_rate2);
340 		if (cur_tp2 < mr->cur_tp) {
341 			mi->max_tp_rate2 = mg->max_tp_rate2;
342 			cur_tp2 = mr->cur_tp;
343 		}
344 	}
345 
346 	mi->stats_update = jiffies;
347 }
348 
349 static bool
350 minstrel_ht_txstat_valid(struct minstrel_priv *mp, struct ieee80211_tx_rate *rate)
351 {
352 	if (rate->idx < 0)
353 		return false;
354 
355 	if (!rate->count)
356 		return false;
357 
358 	if (rate->flags & IEEE80211_TX_RC_MCS)
359 		return true;
360 
361 	return rate->idx == mp->cck_rates[0] ||
362 	       rate->idx == mp->cck_rates[1] ||
363 	       rate->idx == mp->cck_rates[2] ||
364 	       rate->idx == mp->cck_rates[3];
365 }
366 
367 static void
368 minstrel_next_sample_idx(struct minstrel_ht_sta *mi)
369 {
370 	struct minstrel_mcs_group_data *mg;
371 
372 	for (;;) {
373 		mi->sample_group++;
374 		mi->sample_group %= ARRAY_SIZE(minstrel_mcs_groups);
375 		mg = &mi->groups[mi->sample_group];
376 
377 		if (!mg->supported)
378 			continue;
379 
380 		if (++mg->index >= MCS_GROUP_RATES) {
381 			mg->index = 0;
382 			if (++mg->column >= ARRAY_SIZE(sample_table))
383 				mg->column = 0;
384 		}
385 		break;
386 	}
387 }
388 
389 static void
390 minstrel_downgrade_rate(struct minstrel_ht_sta *mi, unsigned int *idx,
391 			bool primary)
392 {
393 	int group, orig_group;
394 
395 	orig_group = group = *idx / MCS_GROUP_RATES;
396 	while (group > 0) {
397 		group--;
398 
399 		if (!mi->groups[group].supported)
400 			continue;
401 
402 		if (minstrel_mcs_groups[group].streams >
403 		    minstrel_mcs_groups[orig_group].streams)
404 			continue;
405 
406 		if (primary)
407 			*idx = mi->groups[group].max_tp_rate;
408 		else
409 			*idx = mi->groups[group].max_tp_rate2;
410 		break;
411 	}
412 }
413 
414 static void
415 minstrel_aggr_check(struct ieee80211_sta *pubsta, struct sk_buff *skb)
416 {
417 	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data;
418 	struct sta_info *sta = container_of(pubsta, struct sta_info, sta);
419 	u16 tid;
420 
421 	if (unlikely(!ieee80211_is_data_qos(hdr->frame_control)))
422 		return;
423 
424 	if (unlikely(skb->protocol == cpu_to_be16(ETH_P_PAE)))
425 		return;
426 
427 	tid = *ieee80211_get_qos_ctl(hdr) & IEEE80211_QOS_CTL_TID_MASK;
428 	if (likely(sta->ampdu_mlme.tid_tx[tid]))
429 		return;
430 
431 	if (skb_get_queue_mapping(skb) == IEEE80211_AC_VO)
432 		return;
433 
434 	ieee80211_start_tx_ba_session(pubsta, tid, 5000);
435 }
436 
437 static void
438 minstrel_ht_tx_status(void *priv, struct ieee80211_supported_band *sband,
439                       struct ieee80211_sta *sta, void *priv_sta,
440                       struct sk_buff *skb)
441 {
442 	struct minstrel_ht_sta_priv *msp = priv_sta;
443 	struct minstrel_ht_sta *mi = &msp->ht;
444 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
445 	struct ieee80211_tx_rate *ar = info->status.rates;
446 	struct minstrel_rate_stats *rate, *rate2;
447 	struct minstrel_priv *mp = priv;
448 	bool last;
449 	int i;
450 
451 	if (!msp->is_ht)
452 		return mac80211_minstrel.tx_status(priv, sband, sta, &msp->legacy, skb);
453 
454 	/* This packet was aggregated but doesn't carry status info */
455 	if ((info->flags & IEEE80211_TX_CTL_AMPDU) &&
456 	    !(info->flags & IEEE80211_TX_STAT_AMPDU))
457 		return;
458 
459 	if (!(info->flags & IEEE80211_TX_STAT_AMPDU)) {
460 		info->status.ampdu_ack_len =
461 			(info->flags & IEEE80211_TX_STAT_ACK ? 1 : 0);
462 		info->status.ampdu_len = 1;
463 	}
464 
465 	mi->ampdu_packets++;
466 	mi->ampdu_len += info->status.ampdu_len;
467 
468 	if (!mi->sample_wait && !mi->sample_tries && mi->sample_count > 0) {
469 		mi->sample_wait = 16 + 2 * MINSTREL_TRUNC(mi->avg_ampdu_len);
470 		mi->sample_tries = 2;
471 		mi->sample_count--;
472 	}
473 
474 	if (info->flags & IEEE80211_TX_CTL_RATE_CTRL_PROBE)
475 		mi->sample_packets += info->status.ampdu_len;
476 
477 	last = !minstrel_ht_txstat_valid(mp, &ar[0]);
478 	for (i = 0; !last; i++) {
479 		last = (i == IEEE80211_TX_MAX_RATES - 1) ||
480 		       !minstrel_ht_txstat_valid(mp, &ar[i + 1]);
481 
482 		rate = minstrel_ht_get_stats(mp, mi, &ar[i]);
483 
484 		if (last)
485 			rate->success += info->status.ampdu_ack_len;
486 
487 		rate->attempts += ar[i].count * info->status.ampdu_len;
488 	}
489 
490 	/*
491 	 * check for sudden death of spatial multiplexing,
492 	 * downgrade to a lower number of streams if necessary.
493 	 */
494 	rate = minstrel_get_ratestats(mi, mi->max_tp_rate);
495 	if (rate->attempts > 30 &&
496 	    MINSTREL_FRAC(rate->success, rate->attempts) <
497 	    MINSTREL_FRAC(20, 100))
498 		minstrel_downgrade_rate(mi, &mi->max_tp_rate, true);
499 
500 	rate2 = minstrel_get_ratestats(mi, mi->max_tp_rate2);
501 	if (rate2->attempts > 30 &&
502 	    MINSTREL_FRAC(rate2->success, rate2->attempts) <
503 	    MINSTREL_FRAC(20, 100))
504 		minstrel_downgrade_rate(mi, &mi->max_tp_rate2, false);
505 
506 	if (time_after(jiffies, mi->stats_update + (mp->update_interval / 2 * HZ) / 1000)) {
507 		minstrel_ht_update_stats(mp, mi);
508 		if (!(info->flags & IEEE80211_TX_CTL_AMPDU) &&
509 		    mi->max_prob_rate / MCS_GROUP_RATES != MINSTREL_CCK_GROUP)
510 			minstrel_aggr_check(sta, skb);
511 	}
512 }
513 
/*
 * Compute the retry counts used for rate @index in a multi-rate retry chain.
 *
 * Rates with a delivery probability below 10% get a single try. In that
 * case retry_updated is left false, so minstrel_ht_set_rate() recomputes
 * the counts on the rate's next use.
 *
 * Otherwise start from two tries and keep adding tries while the
 * accumulated expected airtime (contention backoff + per-frame overhead +
 * data time of an average A-MPDU) stays below mp->segment_size and
 * retry_count stays below mp->max_retry. retry_count_rtscts is bumped in
 * the same loop whenever the RTS/CTS variant of the airtime still fits.
 */
static void
minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
                         int index)
{
	struct minstrel_rate_stats *mr;
	const struct mcs_group *group;
	unsigned int tx_time, tx_time_rtscts, tx_time_data;
	unsigned int cw = mp->cw_min;
	unsigned int ctime = 0;
	unsigned int t_slot = 9; /* FIXME */
	unsigned int ampdu_len = MINSTREL_TRUNC(mi->avg_ampdu_len);
	unsigned int overhead = 0, overhead_rtscts = 0;

	mr = minstrel_get_ratestats(mi, index);
	if (mr->probability < MINSTREL_FRAC(1, 10)) {
		mr->retry_count = 1;
		mr->retry_count_rtscts = 1;
		return;
	}

	mr->retry_count = 2;
	mr->retry_count_rtscts = 2;
	mr->retry_updated = true;

	/* Data airtime of one average-length A-MPDU at this rate */
	group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
	tx_time_data = group->duration[index % MCS_GROUP_RATES] * ampdu_len;

	/*
	 * Contention time for first 2 tries: half of the contention window
	 * (in slots), with the window doubled (capped at cw_max) per try.
	 */
	ctime = (t_slot * cw) >> 1;
	cw = min((cw << 1) | 1, mp->cw_max);
	ctime += (t_slot * cw) >> 1;
	cw = min((cw << 1) | 1, mp->cw_max);

	/* The CCK group is accounted without per-frame overhead */
	if (index / MCS_GROUP_RATES != MINSTREL_CCK_GROUP) {
		overhead = mi->overhead;
		overhead_rtscts = mi->overhead_rtscts;
	}

	/* Total TX time for data and Contention after first 2 tries */
	tx_time = ctime + 2 * (overhead + tx_time_data);
	tx_time_rtscts = ctime + 2 * (overhead_rtscts + tx_time_data);

	/* See how many more tries we can fit inside segment size */
	do {
		/* Contention time for this try */
		ctime = (t_slot * cw) >> 1;
		cw = min((cw << 1) | 1, mp->cw_max);

		/* Total TX time after this try */
		tx_time += ctime + overhead + tx_time_data;
		tx_time_rtscts += ctime + overhead_rtscts + tx_time_data;

		if (tx_time_rtscts < mp->segment_size)
			mr->retry_count_rtscts++;
	} while ((tx_time < mp->segment_size) &&
	         (++mr->retry_count < mp->max_retry));
}
571 
572 
573 static void
574 minstrel_ht_set_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
575                      struct ieee80211_tx_rate *rate, int index,
576                      bool sample, bool rtscts)
577 {
578 	const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
579 	struct minstrel_rate_stats *mr;
580 
581 	mr = minstrel_get_ratestats(mi, index);
582 	if (!mr->retry_updated)
583 		minstrel_calc_retransmit(mp, mi, index);
584 
585 	if (sample)
586 		rate->count = 1;
587 	else if (mr->probability < MINSTREL_FRAC(20, 100))
588 		rate->count = 2;
589 	else if (rtscts)
590 		rate->count = mr->retry_count_rtscts;
591 	else
592 		rate->count = mr->retry_count;
593 
594 	rate->flags = 0;
595 	if (rtscts)
596 		rate->flags |= IEEE80211_TX_RC_USE_RTS_CTS;
597 
598 	if (index / MCS_GROUP_RATES == MINSTREL_CCK_GROUP) {
599 		rate->idx = mp->cck_rates[index % ARRAY_SIZE(mp->cck_rates)];
600 		return;
601 	}
602 
603 	rate->flags |= IEEE80211_TX_RC_MCS | group->flags;
604 	rate->idx = index % MCS_GROUP_RATES + (group->streams - 1) * MCS_GROUP_RATES;
605 }
606 
607 static inline int
608 minstrel_get_duration(int index)
609 {
610 	const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
611 	return group->duration[index % MCS_GROUP_RATES];
612 }
613 
/*
 * Pick a rate to sample for the next frame.
 *
 * Returns the global rate index (group * MCS_GROUP_RATES + rate) to probe,
 * or -1 if this frame should not be used for sampling. While
 * mi->sample_wait is non-zero it is only decremented. Afterwards each call
 * consumes one of mi->sample_tries and takes the next candidate from the
 * current group's sample_table column. The sample pointer is advanced even
 * when the candidate is then rejected.
 */
static int
minstrel_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi)
{
	struct minstrel_rate_stats *mr;
	struct minstrel_mcs_group_data *mg;
	int sample_idx = 0;

	if (mi->sample_wait > 0) {
		mi->sample_wait--;
		return -1;
	}

	if (!mi->sample_tries)
		return -1;

	mi->sample_tries--;
	mg = &mi->groups[mi->sample_group];
	sample_idx = sample_table[mg->column][mg->index];
	mr = &mg->rates[sample_idx];
	/* convert the per-group rate into a global rate index */
	sample_idx += mi->sample_group * MCS_GROUP_RATES;
	minstrel_next_sample_idx(mi);

	/*
	 * Sampling might add some overhead (RTS, no aggregation)
	 * to the frame. Hence, don't use sampling for the currently
	 * used max TP rate.
	 */
	if (sample_idx == mi->max_tp_rate)
		return -1;
	/*
	 * When not using MRR, do not sample if the probability is already
	 * higher than 95% to avoid wasting airtime
	 */
	if (!mp->has_mrr && (mr->probability > MINSTREL_FRAC(95, 100)))
		return -1;

	/*
	 * Make sure that lower rates get sampled only occasionally,
	 * if the link is working perfectly.
	 * A slower candidate (longer airtime than max_tp_rate) is only
	 * probed after it went 20 statistics intervals without attempts,
	 * and at most three times per interval (mi->sample_slow is reset
	 * by minstrel_ht_update_stats()).
	 */
	if (minstrel_get_duration(sample_idx) >
	    minstrel_get_duration(mi->max_tp_rate)) {
		if (mr->sample_skipped < 20)
			return -1;

		if (mi->sample_slow++ > 2)
			return -1;
	}

	return sample_idx;
}
665 
666 static void
667 minstrel_ht_check_cck_shortpreamble(struct minstrel_priv *mp,
668 				    struct minstrel_ht_sta *mi, bool val)
669 {
670 	u8 supported = mi->groups[MINSTREL_CCK_GROUP].supported;
671 
672 	if (!supported || !mi->cck_supported_short)
673 		return;
674 
675 	if (supported & (mi->cck_supported_short << (val * 4)))
676 		return;
677 
678 	supported ^= mi->cck_supported_short | (mi->cck_supported_short << 4);
679 	mi->groups[MINSTREL_CCK_GROUP].supported = supported;
680 }
681 
682 static void
683 minstrel_ht_get_rate(void *priv, struct ieee80211_sta *sta, void *priv_sta,
684                      struct ieee80211_tx_rate_control *txrc)
685 {
686 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
687 	struct ieee80211_tx_rate *ar = info->status.rates;
688 	struct minstrel_ht_sta_priv *msp = priv_sta;
689 	struct minstrel_ht_sta *mi = &msp->ht;
690 	struct minstrel_priv *mp = priv;
691 	int sample_idx;
692 	bool sample = false;
693 
694 	if (rate_control_send_low(sta, priv_sta, txrc))
695 		return;
696 
697 	if (!msp->is_ht)
698 		return mac80211_minstrel.get_rate(priv, sta, &msp->legacy, txrc);
699 
700 	info->flags |= mi->tx_flags;
701 	minstrel_ht_check_cck_shortpreamble(mp, mi, txrc->short_preamble);
702 
703 	/* Don't use EAPOL frames for sampling on non-mrr hw */
704 	if (mp->hw->max_rates == 1 &&
705 	    txrc->skb->protocol == cpu_to_be16(ETH_P_PAE))
706 		sample_idx = -1;
707 	else
708 		sample_idx = minstrel_get_sample_rate(mp, mi);
709 
710 #ifdef CONFIG_MAC80211_DEBUGFS
711 	/* use fixed index if set */
712 	if (mp->fixed_rate_idx != -1) {
713 		mi->max_tp_rate = mp->fixed_rate_idx;
714 		mi->max_tp_rate2 = mp->fixed_rate_idx;
715 		mi->max_prob_rate = mp->fixed_rate_idx;
716 		sample_idx = -1;
717 	}
718 #endif
719 
720 	if (sample_idx >= 0) {
721 		sample = true;
722 		minstrel_ht_set_rate(mp, mi, &ar[0], sample_idx,
723 			true, false);
724 		info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE;
725 	} else {
726 		minstrel_ht_set_rate(mp, mi, &ar[0], mi->max_tp_rate,
727 			false, false);
728 	}
729 
730 	if (mp->hw->max_rates >= 3) {
731 		/*
732 		 * At least 3 tx rates supported, use
733 		 * sample_rate -> max_tp_rate -> max_prob_rate for sampling and
734 		 * max_tp_rate -> max_tp_rate2 -> max_prob_rate by default.
735 		 */
736 		if (sample_idx >= 0)
737 			minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_tp_rate,
738 				false, false);
739 		else
740 			minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_tp_rate2,
741 				false, true);
742 
743 		minstrel_ht_set_rate(mp, mi, &ar[2], mi->max_prob_rate,
744 				     false, !sample);
745 
746 		ar[3].count = 0;
747 		ar[3].idx = -1;
748 	} else if (mp->hw->max_rates == 2) {
749 		/*
750 		 * Only 2 tx rates supported, use
751 		 * sample_rate -> max_prob_rate for sampling and
752 		 * max_tp_rate -> max_prob_rate by default.
753 		 */
754 		minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_prob_rate,
755 				     false, !sample);
756 
757 		ar[2].count = 0;
758 		ar[2].idx = -1;
759 	} else {
760 		/* Not using MRR, only use the first rate */
761 		ar[1].count = 0;
762 		ar[1].idx = -1;
763 	}
764 
765 	mi->total_packets++;
766 
767 	/* wraparound */
768 	if (mi->total_packets == ~0) {
769 		mi->total_packets = 0;
770 		mi->sample_packets = 0;
771 	}
772 }
773 
774 static void
775 minstrel_ht_update_cck(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
776 		       struct ieee80211_supported_band *sband,
777 		       struct ieee80211_sta *sta)
778 {
779 	int i;
780 
781 	if (sband->band != IEEE80211_BAND_2GHZ)
782 		return;
783 
784 	mi->cck_supported = 0;
785 	mi->cck_supported_short = 0;
786 	for (i = 0; i < 4; i++) {
787 		if (!rate_supported(sta, sband->band, mp->cck_rates[i]))
788 			continue;
789 
790 		mi->cck_supported |= BIT(i);
791 		if (sband->bitrates[i].flags & IEEE80211_RATE_SHORT_PREAMBLE)
792 			mi->cck_supported_short |= BIT(i);
793 	}
794 
795 	mi->groups[MINSTREL_CCK_GROUP].supported = mi->cck_supported;
796 }
797 
798 static void
799 minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband,
800                         struct ieee80211_sta *sta, void *priv_sta)
801 {
802 	struct minstrel_priv *mp = priv;
803 	struct minstrel_ht_sta_priv *msp = priv_sta;
804 	struct minstrel_ht_sta *mi = &msp->ht;
805 	struct ieee80211_mcs_info *mcs = &sta->ht_cap.mcs;
806 	u16 sta_cap = sta->ht_cap.cap;
807 	int n_supported = 0;
808 	int ack_dur;
809 	int stbc;
810 	int i;
811 
812 	/* fall back to the old minstrel for legacy stations */
813 	if (!sta->ht_cap.ht_supported)
814 		goto use_legacy;
815 
816 	BUILD_BUG_ON(ARRAY_SIZE(minstrel_mcs_groups) !=
817 		MINSTREL_MAX_STREAMS * MINSTREL_STREAM_GROUPS + 1);
818 
819 	msp->is_ht = true;
820 	memset(mi, 0, sizeof(*mi));
821 	mi->stats_update = jiffies;
822 
823 	ack_dur = ieee80211_frame_duration(sband->band, 10, 60, 1, 1);
824 	mi->overhead = ieee80211_frame_duration(sband->band, 0, 60, 1, 1) + ack_dur;
825 	mi->overhead_rtscts = mi->overhead + 2 * ack_dur;
826 
827 	mi->avg_ampdu_len = MINSTREL_FRAC(1, 1);
828 
829 	/* When using MRR, sample more on the first attempt, without delay */
830 	if (mp->has_mrr) {
831 		mi->sample_count = 16;
832 		mi->sample_wait = 0;
833 	} else {
834 		mi->sample_count = 8;
835 		mi->sample_wait = 8;
836 	}
837 	mi->sample_tries = 4;
838 
839 	stbc = (sta_cap & IEEE80211_HT_CAP_RX_STBC) >>
840 		IEEE80211_HT_CAP_RX_STBC_SHIFT;
841 	mi->tx_flags |= stbc << IEEE80211_TX_CTL_STBC_SHIFT;
842 
843 	if (sta_cap & IEEE80211_HT_CAP_LDPC_CODING)
844 		mi->tx_flags |= IEEE80211_TX_CTL_LDPC;
845 
846 	for (i = 0; i < ARRAY_SIZE(mi->groups); i++) {
847 		mi->groups[i].supported = 0;
848 		if (i == MINSTREL_CCK_GROUP) {
849 			minstrel_ht_update_cck(mp, mi, sband, sta);
850 			continue;
851 		}
852 
853 		if (minstrel_mcs_groups[i].flags & IEEE80211_TX_RC_SHORT_GI) {
854 			if (minstrel_mcs_groups[i].flags & IEEE80211_TX_RC_40_MHZ_WIDTH) {
855 				if (!(sta_cap & IEEE80211_HT_CAP_SGI_40))
856 					continue;
857 			} else {
858 				if (!(sta_cap & IEEE80211_HT_CAP_SGI_20))
859 					continue;
860 			}
861 		}
862 
863 		if (minstrel_mcs_groups[i].flags & IEEE80211_TX_RC_40_MHZ_WIDTH &&
864 		    sta->bandwidth < IEEE80211_STA_RX_BW_40)
865 			continue;
866 
867 		/* Mark MCS > 7 as unsupported if STA is in static SMPS mode */
868 		if (sta->smps_mode == IEEE80211_SMPS_STATIC &&
869 		    minstrel_mcs_groups[i].streams > 1)
870 			continue;
871 
872 		mi->groups[i].supported =
873 			mcs->rx_mask[minstrel_mcs_groups[i].streams - 1];
874 
875 		if (mi->groups[i].supported)
876 			n_supported++;
877 	}
878 
879 	if (!n_supported)
880 		goto use_legacy;
881 
882 	return;
883 
884 use_legacy:
885 	msp->is_ht = false;
886 	memset(&msp->legacy, 0, sizeof(msp->legacy));
887 	msp->legacy.r = msp->ratelist;
888 	msp->legacy.sample_table = msp->sample_table;
889 	return mac80211_minstrel.rate_init(priv, sband, sta, &msp->legacy);
890 }
891 
/*
 * rate_init callback. Initial setup and later capability changes share
 * one code path: the whole per-station state is rebuilt from @sta.
 */
static void
minstrel_ht_rate_init(void *priv, struct ieee80211_supported_band *sband,
                      struct ieee80211_sta *sta, void *priv_sta)
{
	minstrel_ht_update_caps(priv, sband, sta, priv_sta);
}
898 
899 static void
900 minstrel_ht_rate_update(void *priv, struct ieee80211_supported_band *sband,
901                         struct ieee80211_sta *sta, void *priv_sta,
902                         u32 changed)
903 {
904 	minstrel_ht_update_caps(priv, sband, sta, priv_sta);
905 }
906 
907 static void *
908 minstrel_ht_alloc_sta(void *priv, struct ieee80211_sta *sta, gfp_t gfp)
909 {
910 	struct ieee80211_supported_band *sband;
911 	struct minstrel_ht_sta_priv *msp;
912 	struct minstrel_priv *mp = priv;
913 	struct ieee80211_hw *hw = mp->hw;
914 	int max_rates = 0;
915 	int i;
916 
917 	for (i = 0; i < IEEE80211_NUM_BANDS; i++) {
918 		sband = hw->wiphy->bands[i];
919 		if (sband && sband->n_bitrates > max_rates)
920 			max_rates = sband->n_bitrates;
921 	}
922 
923 	msp = kzalloc(sizeof(*msp), gfp);
924 	if (!msp)
925 		return NULL;
926 
927 	msp->ratelist = kzalloc(sizeof(struct minstrel_rate) * max_rates, gfp);
928 	if (!msp->ratelist)
929 		goto error;
930 
931 	msp->sample_table = kmalloc(SAMPLE_COLUMNS * max_rates, gfp);
932 	if (!msp->sample_table)
933 		goto error1;
934 
935 	return msp;
936 
937 error1:
938 	kfree(msp->ratelist);
939 error:
940 	kfree(msp);
941 	return NULL;
942 }
943 
944 static void
945 minstrel_ht_free_sta(void *priv, struct ieee80211_sta *sta, void *priv_sta)
946 {
947 	struct minstrel_ht_sta_priv *msp = priv_sta;
948 
949 	kfree(msp->sample_table);
950 	kfree(msp->ratelist);
951 	kfree(msp);
952 }
953 
954 static void *
955 minstrel_ht_alloc(struct ieee80211_hw *hw, struct dentry *debugfsdir)
956 {
957 	return mac80211_minstrel.alloc(hw, debugfsdir);
958 }
959 
960 static void
961 minstrel_ht_free(void *priv)
962 {
963 	mac80211_minstrel.free(priv);
964 }
965 
966 static struct rate_control_ops mac80211_minstrel_ht = {
967 	.name = "minstrel_ht",
968 	.tx_status = minstrel_ht_tx_status,
969 	.get_rate = minstrel_ht_get_rate,
970 	.rate_init = minstrel_ht_rate_init,
971 	.rate_update = minstrel_ht_rate_update,
972 	.alloc_sta = minstrel_ht_alloc_sta,
973 	.free_sta = minstrel_ht_free_sta,
974 	.alloc = minstrel_ht_alloc,
975 	.free = minstrel_ht_free,
976 #ifdef CONFIG_MAC80211_DEBUGFS
977 	.add_sta_debugfs = minstrel_ht_add_sta_debugfs,
978 	.remove_sta_debugfs = minstrel_ht_remove_sta_debugfs,
979 #endif
980 };
981 
982 
983 static void
984 init_sample_table(void)
985 {
986 	int col, i, new_idx;
987 	u8 rnd[MCS_GROUP_RATES];
988 
989 	memset(sample_table, 0xff, sizeof(sample_table));
990 	for (col = 0; col < SAMPLE_COLUMNS; col++) {
991 		for (i = 0; i < MCS_GROUP_RATES; i++) {
992 			get_random_bytes(rnd, sizeof(rnd));
993 			new_idx = (i + rnd[i]) % MCS_GROUP_RATES;
994 
995 			while (sample_table[col][new_idx] != 0xff)
996 				new_idx = (new_idx + 1) % MCS_GROUP_RATES;
997 
998 			sample_table[col][new_idx] = i;
999 		}
1000 	}
1001 }
1002 
1003 int __init
1004 rc80211_minstrel_ht_init(void)
1005 {
1006 	init_sample_table();
1007 	return ieee80211_rate_control_register(&mac80211_minstrel_ht);
1008 }
1009 
1010 void
1011 rc80211_minstrel_ht_exit(void)
1012 {
1013 	ieee80211_rate_control_unregister(&mac80211_minstrel_ht);
1014 }
1015