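What it does:
Draw a scatter plot in Python to show the relationship between two variables. When the data contains several groups, the groups are distinguished by different colors and marker shapes.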
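Below is a minimal, self-contained sketch of the grouped scatter described above. It runs on hypothetical synthetic data because midwest_filter.csv is not included here. The helper name `grouped_scatter`, the marker list and the A/B/C group labels are illustrative assumptions. The author's listing varies only color per group. This sketch also cycles the marker shape, which matches the "colors and shapes" goal. It selects the non-interactive Agg backend so it can run without a display. To see a window, drop that line and call `plt.show()`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def grouped_scatter(df, x, y, group, markers=("o", "s", "^", "D", "v", "P", "X")):
    """Hypothetical helper: one scatter layer per group, varying color and marker."""
    groups = np.unique(df[group])
    n = max(len(groups) - 1, 1)  # guard against division by zero with one group
    fig, ax = plt.subplots(figsize=(8, 5))
    for i, g in enumerate(groups):
        sub = df[df[group] == g]
        ax.scatter(sub[x], sub[y], s=20,
                   color=plt.cm.Set1(i / n),          # evenly spaced Set1 colors
                   marker=markers[i % len(markers)],  # cycle through marker shapes
                   label=str(g))
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend()
    return ax


# Hypothetical synthetic data standing in for midwest_filter.csv
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "area": rng.uniform(0.0, 0.1, 60),
    "poptotal": rng.uniform(0, 90000, 60),
    "category": np.repeat(["A", "B", "C"], 20),
})
ax = grouped_scatter(df, "area", "poptotal", "category")
```

Each group is a separate `scatter` call. As a result, each group gets its own legend entry automatically.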
Code:
```python
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings(action='once')
# plt.style.use('seaborn-whitegrid') is not accepted by recent Matplotlib releases
# (the style was renamed 'seaborn-v0_8-whitegrid' in 3.6); seaborn's own
# whitegrid style gives the same white background with grid lines.
sns.set_style("whitegrid")
print(mpl.__version__)
print(sns.__version__)


def draw_scatter(file):
    # Import dataset
    midwest = pd.read_csv(file)

    # Prepare data: one color per unique value of midwest['category'],
    # sampled evenly from the qualitative Set1 colormap
    categories = np.unique(midwest['category'])
    n = max(len(categories) - 1, 1)  # avoid division by zero with a single category
    colors = [plt.cm.Set1(i / n) for i in range(len(categories))]

    # Draw one scatter layer per category so each gets its own color and legend entry
    plt.figure(figsize=(10, 6), dpi=100, facecolor='w', edgecolor='k')
    for i, category in enumerate(categories):
        plt.scatter('area', 'poptotal',
                    data=midwest.loc[midwest.category == category, :],
                    s=20, color=colors[i], label=str(category))

    # Decorations: fixed axis window (points outside it are not shown)
    plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000))
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.xlabel('Area', fontdict={'fontsize': 10})
    plt.ylabel('Population', fontdict={'fontsize': 10})
    plt.title("Scatterplot of Midwest Area vs Population", fontsize=12)
    plt.legend(fontsize=10)
    plt.show()


# Raw string so the backslashes in the Windows path are not read as escape sequences
draw_scatter(r"F:\数据杂坛\datasets\midwest_filter.csv")
```
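The listing imports seaborn but uses it only for styling. Here is an alternative sketch, not the author's method. `seaborn.scatterplot` can map the group column to both color (`hue`) and marker shape (`style`) in one call, so no manual loop is needed. The synthetic data and the A/B/C labels are hypothetical stand-ins that reuse the listing's column names.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical synthetic data with the same column names as midwest_filter.csv
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "area": rng.uniform(0.0, 0.1, 45),
    "poptotal": rng.uniform(0, 90000, 45),
    "category": np.repeat(["A", "B", "C"], 15),
})

fig, ax = plt.subplots(figsize=(8, 5))
# hue -> one color per category, style -> one marker shape per category
sns.scatterplot(data=df, x="area", y="poptotal", hue="category",
                style="category", palette="Set1", s=20, ax=ax)
ax.set(xlim=(0.0, 0.1), ylim=(0, 90000),
       title="Scatterplot of Midwest Area vs Population")
```

With the real file, build `df` with `pd.read_csv(...)` instead of the synthetic frame. The legend is then built from the actual category values.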
Result:
[Figure: Scatterplot of Midwest Area vs Population]
Follow the WeChat subscription account 数据杂坛 to get the dataset, the complete code and the resulting figure; more posts will follow.