
The main idea of k-means is the following. Suppose you have a bunch of elements and a given number of clusters, k:
- Create k initial clusters at random within the space spanned by the elements (typically (always?) you would simply pick k of your elements at random as the initial centers).
- Lob each element into the cluster with the nearest center (typically using Euclidean distance).
- Recenter each cluster on the average (mean) of its elements.
- If necessary, move elements to their now-nearest clusters.
- Repeat the re-centering and moving of elements until the result is "stable" (enough).
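The steps above can be sketched as a short loop, often called Lloyd's algorithm. This is a minimal sketch that uses plain coordinate arrays instead of the `Point`/`Cluster` classes below. The helper names (`euclidean`, `mean`, `lloyd`) and the `max_iterations` cap are my own, added for illustration. Unlike the implementation below, an empty cluster here simply keeps its old center.

```ruby
# Minimal sketch of the k-means loop (Lloyd's algorithm) on plain
# coordinate arrays. Helper names are illustrative, not from the post.

# Euclidean distance between two coordinate arrays of equal length.
def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |x, y| (x - y)**2 })
end

# Component-wise mean of a non-empty list of coordinate arrays.
def mean(points)
  points.transpose.map { |coords| coords.sum.to_f / points.length }
end

# Alternates assignment and re-centering until no center moves by delta or
# more (or max_iterations is reached). An empty cluster keeps its old center.
def lloyd(points, centers, delta: 0.001, max_iterations: 100)
  max_iterations.times do
    # Assignment: group each point under its nearest center.
    groups = points.group_by { |p| centers.min_by { |c| euclidean(p, c) } }
    # Update: move each center to the mean of its group.
    new_centers = centers.map { |c| groups.key?(c) ? mean(groups[c]) : c }
    moved = centers.zip(new_centers).any? { |before, after| euclidean(before, after) >= delta }
    centers = new_centers
    break unless moved
  end
  centers
end

points = [[0, 0], [0, 1], [10, 10], [10, 11]]
p lloyd(points, [[0, 0], [0, 1]]) # => [[0.0, 0.5], [10.0, 10.5]]
```

The starting centers are fixed here so the result is predictable. In practice you would start from `points.sample(k)`, as in the first step.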
First I implemented some simple data structures, starting with Points:
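```ruby
# A point with named coordinates, e.g. Point.new({x: 1.0, y: 2.0}).
# Coordinates are read and written as methods (point.x, point.x = 3.0).
class Point
  def initialize(coords)
    @coords = coords
  end

  # Euclidean distance over this point's dimensions.
  def distance_to(point)
    Math.sqrt(@coords.keys.inject(0) { |sum, key| sum + (@coords[key] - point.send(key))**2 })
  end

  def to_s
    @coords.values.join(", ")
  end

  def dimensions
    @coords.keys
  end

  # point.x returns @coords[:x]; point.x = 1.0 sets it.
  def method_missing(m, *args)
    if m.to_s.end_with?('=')
      @coords[m.to_s.chop.to_sym] = args.first
    elsif @coords.key?(m)
      @coords[m]
    else
      super
    end
  end

  # Keep respond_to? consistent with method_missing.
  def respond_to_missing?(m, include_private = false)
    m.to_s.end_with?('=') || @coords.key?(m) || super
  end
end
```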
Clusters are simply:
```ruby
# A cluster: a center Point plus the points currently assigned to it.
class Cluster
  attr_reader :center

  def initialize(center)
    @center = center
    @points = []
    @moved = true # so that k_means runs at least one iteration
  end

  def add_point(point)
    @points << point
  end

  # Move the center to the mean of the assigned points. The cluster only
  # counts as "moved" if the center shifts by at least delta.
  def update_center(delta = 0.001)
    @moved = false
    return if @points.empty? # nothing to average over

    averages = {}
    @center.dimensions.each do |dimension|
      averages[dimension] =
        @points.inject(0.0) { |sum, point| sum + point.send(dimension) } / @points.length
    end
    new_center = Point.new(averages)
    unless new_center.distance_to(@center) < delta
      @center = new_center
      @moved = true
    end
  end

  def clear_points
    @points = []
  end

  # All values of one coordinate, e.g. collect(:x).
  def collect(dimension)
    @points.collect { |p| p.send(dimension) }
  end

  def distance_to(point)
    @center.distance_to(point)
  end

  def number_of_points
    @points.length
  end

  # "cost" is the sum of (unsquared) distances from the center.
  def to_s
    cost = @points.inject(0) { |sum, point| sum + point.distance_to(@center) }
    "#{@center}: #{number_of_points} points, cost: #{cost}"
  end

  def moved?
    @moved
  end
end
```
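For reference, here is the quantity that each assignment/re-centering round never increases: the within-cluster sum of squared distances. In the formula, $C_j$ is the set of points in cluster $j$ and $\mu_j$ is its center. Note that the `cost` printed by `Cluster#to_s` is the sum of plain, unsquared distances, which is related but not the same number. The `delta` in `update_center` corresponds to the stopping rule on the second line.

$$
J = \sum_{j=1}^{k} \sum_{x \in C_j} \lVert x - \mu_j \rVert^2,
\qquad
\mu_j = \frac{1}{\lvert C_j \rvert} \sum_{x \in C_j} x
$$

$$
\text{stop when } \lVert \mu_j^{(t+1)} - \mu_j^{(t)} \rVert < \delta \ \text{for every cluster } j
$$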
I also implemented a special case of points, 2D points:
```ruby
# Convenience subclass: Point2D.new(1.0, 2.0) or Point2D.new({x: 1.0, y: 2.0}).
class Point2D < Point
  def initialize(x, y = nil)
    y.nil? ? super(x) : super({ x: x, y: y })
  end
end
```
The last helper method simply plots the clusters (in 2D) with gnuplot:
```ruby
# Plots each cluster's points plus all cluster centers to
# outfiles/clusters-<seq>.png (the outfiles/ directory must already exist).
# Requires the gnuplot gem to be loaded; otherwise this does nothing.
def cluster_plot_2D(clusters, seq = 1, title = "Random noise", get_x = 'x', get_y = 'y')
  return unless defined?(Gnuplot)

  Gnuplot.open do |gp|
    Gnuplot::Plot.new(gp) do |plot|
      plot.terminal "png"
      plot.output File.expand_path("../outfiles/clusters-#{seq}.png", __FILE__)
      plot.title title
      # Plot each cluster's points as its own data set
      clusters.each do |cluster|
        # Collect all x and y coords for this cluster
        x = cluster.collect(get_x.to_sym)
        y = cluster.collect(get_y.to_sym)
        # Plot w/o a title (clutters things up)
        plot.data << Gnuplot::DataSet.new([x, y]) do |ds|
          ds.notitle
        end
      end
      # Plot all cluster centers as one extra data set
      x = clusters.collect { |c| c.center.send(get_x.to_sym) }
      y = clusters.collect { |c| c.center.send(get_y.to_sym) }
      plot.data << Gnuplot::DataSet.new([x, y]) do |ds|
        ds.notitle
      end
    end
  end
end
```
The real meat is, of course, in the k-means method itself:
```ruby
# k-means over an Array of Point objects.
# Returns the non-empty clusters (possibly fewer than k).
def k_means(points, k = 5, delta = 0.001, plot_on = false)
  k = points.length if points.length < k
  # Seed with k distinct elements (sample(k) never picks the same element twice)
  clusters = points.sample(k).map { |point| Cluster.new(point) }
  iterations = 0
  while clusters.any?(&:moved?)
    clusters.each(&:clear_points)
    # Assignment step: each point goes to the cluster with the nearest center
    points.each do |point|
      nearest = clusters.min_by { |cluster| cluster.distance_to(point) }
      nearest.add_point(point)
    end
    # Drop clusters that ended up with no points
    clusters.delete_if { |cluster| cluster.number_of_points == 0 }
    # Update step: move each center to the mean of its points
    clusters.each { |cluster| cluster.update_center(delta) }
    iterations += 1
    # Optionally plot every iteration (zero-padded so the files sort in order)
    cluster_plot_2D(clusters, format("%d-%03d", k, iterations)) if plot_on
  end
  clusters
end
```
This takes as parameters the points themselves (as an Array of Points, since it relies on `sample` and `length`), the number of clusters (k), the delta (the minimum distance a center has to shift to count as a "move"), and whether to plot each iteration (especially interesting while debugging the algorithm... :-)). No real surprises here, I guess. I chose to delete clusters with no points, which means the method can return fewer than k clusters. I'm not entirely sure whether that is what the original algorithm does. (I guess by now I should duckduckgo it. :-))
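If you would rather always get k clusters back, one commonly used alternative to deleting an empty cluster is to re-seed it. Below is a minimal sketch of that idea (not what `k_means` above does), written for plain coordinate arrays. `assign` and `reseed_empty` are hypothetical helpers. Moving the empty cluster onto the point that is farthest from its nearest center is just one heuristic; picking a random element would also work.

```ruby
# Hypothetical alternative to deleting empty clusters: re-seed each empty
# cluster on the point currently farthest from its nearest center.

def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |x, y| (x - y)**2 })
end

# Returns one (possibly empty) array of points per center, by index.
def assign(points, centers)
  groups = Array.new(centers.length) { [] }
  points.each do |p|
    nearest = centers.each_index.min_by { |i| euclidean(p, centers[i]) }
    groups[nearest] << p
  end
  groups
end

# Keeps non-empty clusters' centers; replaces each empty one with a far point.
def reseed_empty(points, centers)
  groups = assign(points, centers)
  taken = [] # points already used as new seeds in this call
  centers.each_index.map do |i|
    next centers[i] unless groups[i].empty?
    candidate = (points - taken).max_by { |p| centers.map { |c| euclidean(p, c) }.min }
    taken << candidate
    candidate
  end
end

p reseed_empty([[0, 0], [1, 0], [5, 0]], [[0, 0], [100, 0]]) # => [[0, 0], [5, 0]]
```

In `k_means`, a call like this would take the place of the `delete_if` line, before the centers are updated.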
To test the algorithm, I then wrote a very silly little script. It generates 10000 random points in 2D and (tries to) put them into 7 clusters, optionally gnuplot'ing each iteration (`plot_on`).
```ruby
# Parameters
k = 7
number_of_points = 10000
delta = 0.001
plot_on = false # true writes one PNG per iteration (requires Gnuplot to be loaded)

require_relative 'k-means'

# Uniformly random points in the square [-5, 5) x [-5, 5)
points = Array.new(number_of_points) { Point2D.new(10 * rand - 5.0, 10 * rand - 5.0) }

clusters = k_means(points, k, delta, plot_on)
clusters.each_with_index do |cluster, index|
  puts "#{cluster.center}\t#{index}"
end
cluster_plot_2D(clusters, "#{k}-final") if plot_on
```
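Both the point generation (`rand`) and the choice of initial centers (`points.sample`) draw from Ruby's default random number generator. Seeding it once with `srand` should therefore make a run repeatable, which is handy when comparing plots while debugging. Here is a small sketch of that. `random_points` and `initial_centers` are hypothetical helpers that mirror what the script and `k_means` do, and the seed 42 is arbitrary.

```ruby
# Seeding the default generator makes Kernel#rand and Array#sample
# (without a random: argument) return the same sequence on every run.
def random_points(n, seed)
  srand(seed)
  Array.new(n) { [10 * rand - 5.0, 10 * rand - 5.0] }
end

def initial_centers(points, k, seed)
  srand(seed)
  points.sample(k)
end

points = random_points(1000, 42)
puts points == random_points(1000, 42)                                 # prints "true"
puts initial_centers(points, 7, 42) == initial_centers(points, 7, 42)  # prints "true"
```

In the script itself, this amounts to a single `srand(42)` before the points are generated.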
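Voilà, the results are in. The per-iteration plots can be tiled into a single image with ImageMagick's `montage`:

```shell
montage -geometry "320"x"240" -tile 7x7 outfiles/clusters-7* clusters-all.png
```

They can also be combined into an animated GIF with `convert`:

```shell
convert outfiles/clusters-7* clusters-all.gif
```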
Epilogue
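I guess a variant of this could serve as a solution to the partitioning problem. First, map the categorical variables into a continuous space (using the weights?). Then cluster the population into (population/number of groups + 1) clusters whose size is locked. Finally, form each group by sampling one element from each cluster. Locking the cluster size is not something plain k-means does, though: Lloyd's algorithm is just the standard iteration implemented above. So this would need a size-constrained variant.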
Moral
It is useful to (try to) implement an algorithm you think you know from scratch from time to time. It shows whether you really know all the details of its implementation... especially in a new language.
Update: Clone this GitHub repository if, for any reason, you want to play around with this yourself: https://github.com/mortenjohs/k-means...