From 08d3236555660cbc7b18f4ba1ee89e53fa0bb89b Mon Sep 17 00:00:00 2001
From: Vicent Marti <tanoku@gmail.com>
Date: Fri, 22 Jan 2016 16:07:49 +0100
Subject: [PATCH] Add `default_sample_rate` config

---
 lib/statsd.rb | 61 +++++++++++++++++++++++++--------------------------
 1 file changed, 30 insertions(+), 31 deletions(-)

diff --git a/lib/statsd.rb b/lib/statsd.rb
index 0881fb2..1335020 100644
--- a/lib/statsd.rb
+++ b/lib/statsd.rb
@@ -71,6 +71,7 @@ def nonce
 
   # A namespace to prepend to all statsd calls.
   attr_reader :namespace
+  attr_accessor :default_sample_rate
 
   def namespace=(namespace)
     @namespace = namespace
@@ -88,9 +89,10 @@ def namespace=(namespace)
   GAUGE_TYPE = "g".freeze
   HISTOGRAM_TYPE = "h".freeze
 
-  def initialize(client_class = nil)
+  def initialize(client_class: nil, default_sample_rate: 1)
     @shards = []
     @client_class = client_class || UDPClient
+    @default_sample_rate = default_sample_rate
     self.namespace = nil
   end
 
@@ -127,21 +129,21 @@ def flush_all
   # @param stat (see #count)
   # @param sample_rate (see #count)
   # @see #count
-  def increment(stat, sample_rate=1); count stat, 1, sample_rate end
+  def increment(stat, sample_rate=nil); count stat, 1, sample_rate end
 
   # Sends a decrement (count = -1) for the given stat to the statsd server.
   #
   # @param stat (see #count)
   # @param sample_rate (see #count)
   # @see #count
-  def decrement(stat, sample_rate=1); count stat, -1, sample_rate end
+  def decrement(stat, sample_rate=nil); count stat, -1, sample_rate end
 
   # Sends an arbitrary count for the given stat to the statsd server.
   #
   # @param [String] stat stat name
   # @param [Integer] count count
   # @param [Integer] sample_rate sample rate, 1 for always
-  def count(stat, count, sample_rate=1); send stat, count, COUNTER_TYPE, sample_rate end
+  def count(stat, count, sample_rate=nil); send stat, count, COUNTER_TYPE, sample_rate end
 
   # Sends an arbitary gauge value for the given stat to the statsd server.
   #
@@ -161,7 +163,7 @@ def gauge(stat, value)
   # @param stat stat name
   # @param [Integer] ms timing in milliseconds
   # @param [Integer] sample_rate sample rate, 1 for always
-  def timing(stat, ms, sample_rate=1); send stat, ms, TIMING_TYPE, sample_rate end
+  def timing(stat, ms, sample_rate=nil); send stat, ms, TIMING_TYPE, sample_rate end
 
   # Reports execution time of the provided block using {#timing}.
   #
@@ -171,7 +173,7 @@ def timing(stat, ms, sample_rate=1); send stat, ms, TIMING_TYPE, sample_rate end
   # @see #timing
   # @example Report the time (in ms) taken to activate an account
   #   $statsd.time('account.activate') { @account.activate! }
-  def time(stat, sample_rate=1)
+  def time(stat, sample_rate=nil)
     start = Time.now
     result = yield
     timing(stat, ((Time.now - start) * 1000).round(5), sample_rate)
@@ -182,34 +184,31 @@ def time(stat, sample_rate=1)
   # sample_rate determines what percentage of the time this report is sent. The
   # statsd server then uses the sample_rate to correctly track the average
   # for the stat.
-  def histogram(stat, value, sample_rate=1); send stat, value, HISTOGRAM_TYPE, sample_rate end
+  def histogram(stat, value, sample_rate=nil); send stat, value, HISTOGRAM_TYPE, sample_rate end
 
   private
-  def sampled(sample_rate)
-    yield unless sample_rate < 1 and rand > sample_rate
-  end
-
-  def send(stat, delta, type, sample_rate=1)
-    sampled(sample_rate) do
-      stat = stat.to_s.dup
-      stat.gsub!(/::/, ".".freeze)
-      stat.gsub!(RESERVED_CHARS_REGEX, "_".freeze)
-
-      msg = String.new
-      msg << @prefix
-      msg << stat
-      msg << ":".freeze
-      msg << delta.to_s
-      msg << "|".freeze
-      msg << type
-      if sample_rate < 1
-        msg << "|@".freeze
-        msg << sample_rate.to_s
-      end
-
-      shard = select_shard(stat)
-      shard.send(msg)
+  def send(stat, delta, type, sample_rate=nil)
+    sample_rate ||= default_sample_rate
+    return if sample_rate < 1 and rand > sample_rate
+
+    stat = stat.to_s.dup
+    stat.gsub!(/::/, ".".freeze)
+    stat.gsub!(RESERVED_CHARS_REGEX, "_".freeze)
+
+    msg = String.new
+    msg << @prefix
+    msg << stat
+    msg << ":".freeze
+    msg << delta.to_s
+    msg << "|".freeze
+    msg << type
+    if sample_rate < 1
+      msg << "|@".freeze
+      msg << sample_rate.to_s
     end
+
+    shard = select_shard(stat)
+    shard.send(msg)
   end
 
   def select_shard(stat)