数据集来源：Geolife Trajectories 1.3
加载数据
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from geopy.distance import great_circle
from matplotlib.colors import rgb2hex
from shapely.geometry import MultiPoint
from sklearn.cluster import DBSCAN
from sklearn.cluster import KMeans

# Load every trajectory file of one Geolife user into a single DataFrame
# (one row per GPS fix) and draw a quick overview of the raw trace.

# Directory holding the trajectory files of user 001.
userdata = '../Lab-work/Geolife Trajectories 1.3/Data/001/Trajectory/'

# os.listdir() returns entries in arbitrary order; sorting makes the
# concatenation order (and therefore the line plot below) deterministic.
filelist = sorted(os.listdir(userdata))

names = ['lat', 'lng', 'zero', 'alt', 'days', 'date', 'time']

# Per the Geolife user guide each file starts with six preamble lines that
# carry no data.  Skip exactly those lines and declare that no header row
# exists: the previous `header=6` made pandas treat the first data row as the
# header and silently drop one GPS point per file.
df_list = [
    pd.read_csv(
        os.path.join(userdata, f),
        skiprows=6,
        header=None,
        names=names,
        index_col=False,
    )
    for f in filelist
]
df = pd.concat(df_list, ignore_index=True)
print(df.head(10))

# NOTE: latitude is on the x-axis and longitude on the y-axis here.
plt.plot(df.lat, df.lng)
lat lng zero alt days date time 0 39.984198 116.319322 0 492 39744.245208 2008-10-23 05:53:06 1 39.984224 116.319402 0 492 39744.245266 2008-10-23 05:53:11 2 39.984211 116.319389 0 492 39744.245324 2008-10-23 05:53:16 3 39.984217 116.319422 0 491 39744.245382 2008-10-23 05:53:21 4 39.984710 116.319865 0 320 39744.245405 2008-10-23 05:53:23 5 39.984674 116.319810 0 325 39744.245463 2008-10-23 05:53:28 6 39.984623 116.319773 0 326 39744.245521 2008-10-23 05:53:33 7 39.984606 116.319732 0 327 39744.245579 2008-10-23 05:53:38 8 39.984555 116.319728 0 324 39744.245637 2008-10-23 05:53:43 9 39.984579 116.319769 0 309 39744.245694 2008-10-23 05:53:48 [<matplotlib.lines.Line2D at 0x17efc43eac8>]
K-Means
# Partition all GPS fixes into a fixed number of K-Means clusters and draw
# each cluster in its own randomly chosen colour.
coords = df[['lat', 'lng']].values
n_clusters = 100
cls = KMeans(n_clusters).fit(coords)

# One random RGB triple per cluster, stored as a hex colour string.
colors = []
for _ in range(n_clusters):
    rgb = (np.random.random(), np.random.random(), np.random.random())
    colors.append(rgb2hex(rgb))

for cluster_id, hex_color in enumerate(colors):
    # Rows of `coords` that K-Means assigned to this cluster.
    cluster_pts = coords[cls.labels_ == cluster_id]
    plt.scatter(cluster_pts[:, 0], cluster_pts[:, 1], s=60, c=hex_color, alpha=0.5)
plt.show()
获取 K-Means 聚类结果
# Collect the member points of every cluster into a Series of arrays
# (entry n holds the rows of `coords` labelled n).
cluster_labels = cls.labels_

# -1 is excluded from the cluster count; with the K-Means labels used here it
# is not expected to occur, so this is effectively the number of distinct labels.
num_clusters = len(set(cluster_labels) - {-1})

# `coords` is the array that was actually clustered, so its length is the
# number of clustered points (the previously referenced `df_min` is not
# defined anywhere in this file).
print('Clustered ' + str(len(coords)) + ' points to ' + str(num_clusters) + ' clusters')

clusters = pd.Series([coords[cluster_labels == n] for n in range(num_clusters)])
print(clusters)
Clustered 9045 points to 100 clusters 0 [[40.014459, 116.305603], [40.014363, 116.3056... 1 [[39.975246000000006, 116.358976], [39.975244,... 2 [[40.001312, 116.193358], [40.001351, 116.1932... 3 [[39.984559000000004, 116.326696], [39.984669,... 4 [[39.964969, 116.434923], [39.964886, 116.4350... ... 95 [[40.004549, 116.260581], [40.004515999999995,... 96 [[39.97964, 116.323856], [39.979701, 116.32396... 97 [[40.0009, 116.23948500000002], [40.000831, 11... 98 [[39.962336, 116.32817800000001], [39.96223300... 99 [[39.9663, 116.353677], [39.966291999999996, 1... Length: 100, dtype: object
获取每个群集的中心点
def get_centermost_point(cluster):
    """Return the member of *cluster* that lies closest to its centroid.

    *cluster* is a sequence of coordinate pairs; the result is the chosen
    pair as a tuple, in the same component order as the input.
    """
    center = MultiPoint(cluster).centroid
    centroid = (center.x, center.y)

    def distance_to_centroid(point):
        # Great-circle distance in metres between a member and the centroid.
        return great_circle(point, centroid).m

    return tuple(min(cluster, key=distance_to_centroid))


# Representative point of each cluster; the points are (lat, lng) pairs, so
# unzipping yields the latitudes first and the longitudes second.
centermost_points = clusters.map(get_centermost_point)
lats, lons = zip(*centermost_points)
rep_points = pd.DataFrame({'lon': lons, 'lat': lats})
print(rep_points)
lon lat 0 116.306558 40.013751 1 116.353295 39.975357 2 116.190167 40.004290 3 116.326944 39.986492 4 116.438241 39.961273 .. ... ... 95 116.256309 40.004774 96 116.326462 39.978752 97 116.232672 39.998630 98 116.328847 39.958271 99 116.358655 39.966451 [100 rows x 2 columns]
描绘中心点
# Overlay the representative point of every cluster on the full GPS trace
# and label each one with its 1-based cluster number.
fig, ax = plt.subplots(figsize=[10, 6])

# Draw all representative points in one call.  The earlier version plotted
# only four hard-coded indices, which left most annotated centres invisible
# and failed outright with fewer than four clusters.
rs_scatter = ax.scatter(rep_points['lon'], rep_points['lat'],
                        c='#99cc99', edgecolor='None', alpha=0.7, s=250)

# Full trace from `df`, the frame that was clustered (`df_min` is not defined
# anywhere in this file).
df_scatter = ax.scatter(df['lng'], df['lat'], c='k', alpha=0.9, s=3)

# The centres come from the K-Means fit, so the title names K-Means.
ax.set_title('Full GPS trace vs. K-Means clusters')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.legend([df_scatter, rs_scatter], ['GPS points', 'Cluster centers'], loc='upper right')

labels = ['cluster{0}'.format(i) for i in range(1, num_clusters + 1)]
for label, x, y in zip(labels, rep_points['lon'], rep_points['lat']):
    ax.annotate(
        label,
        xy=(x, y),
        xytext=(-25, -30),
        textcoords='offset points',
        ha='right',
        va='bottom',
        bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.5),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
    )

plt.show()