《重构:改善既有代码的设计》中提到过很多重构方法,关于简化条件表达式的方法有8种。本文介绍:
以多态取代条件表达式 replace conditional with polymorphism
- 名称:以多态取代条件表达式 replace conditional with polymorphism
- 概要:有个条件表达式,它根据对象类型的不同而选择不同的行为。将这个条件表达式的每个分支放进一个子类内的覆写函数中,然后将原始函数声明为抽象函数。
- 动机: 多态的好处:如果你需要根据对象的不同类型而采取不同的行为,多态使你不必编写明显的条件表达式
- 做法:
- 在使用此重构之前,必须有一个继承结构。可以采用replace type code with subclasses和replace type code with state/strategy.
- 如果要处理的条件表达式是一个更大函数的一部分,首先对条件表达式进行分析,然后使用extract method将它提炼到一个独立函数去。
- 如果有必要,使用move method将条件表达式放置到继承结构的顶端
- 任选一个子类,在其中建立一个函数,使之覆写超类中容纳条件表达式的那个函数。将与该子类相关的条件表达式分支复制到新建函数中,并对它进行适当调整。为了顺利进行这一步骤,你可能需要将超类中的某些private 字段声明为protected
- 编译,测试
- 在超类中删掉条件表达式内被复制了的分支
- 编译,测试
- 针对条件表达式的每个分支,重复上述过程,直到所有分支都被移到子类内的函数为止
- 将超类之中容纳条件表达式的函数声明为抽象函数
- 代码演示
修改之前的代码:
/////////////////////////.h
class EmployeeType;
class Employee
{
public:
Employee(EmployeeType* code);
double payAmount();
int getCode() const;
void setCode(int code);
private:
//int m_code;
EmployeeType* m_code;
double m_monthlySalary;
double m_commission;
double m_bonus;
};
class EmployeeType
{
public:
static const int ENGINEER;
static const int SALESMAN;
static const int MANAGER;
static EmployeeType* newType(int code);
virtual int getTypeCode() = 0;
};
class Engineer : public EmployeeType
{
public:
int getTypeCode();
};
class Salesman : public EmployeeType
{
public:
int getTypeCode();
};
class Manager : public EmployeeType
{
public:
int getTypeCode();
};
class Error : public EmployeeType
{
public:
int getTypeCode();
};
//////////////////////////.cpp
const int EmployeeType::ENGINEER = 0;
const int EmployeeType::SALESMAN = 1;
const int EmployeeType::MANAGER = 2;
Employee::Employee(EmployeeType* code):
m_monthlySalary(100),
m_commission(10),
m_bonus(200)
{
m_code = code;
}
double Employee::payAmount()
{
switch (getCode()) {
case EmployeeType::ENGINEER:
return m_monthlySalary;
case EmployeeType::SALESMAN:
return m_monthlySalary + m_commission;
case EmployeeType::MANAGER:
return m_monthlySalary + m_bonus;
default:
return 0;
}
}
int Employee::getCode() const
{
return m_code->getTypeCode();
}
void Employee::setCode(int code)
{
m_code = m_code->newType(code);
}
int Engineer::getTypeCode()
{
return ENGINEER;
}
int Salesman::getTypeCode()
{
return SALESMAN;
}
int Manager::getTypeCode()
{
return MANAGER;
}
int Error::getTypeCode()
{
return MANAGER;
}
EmployeeType* EmployeeType::newType(int code)
{
switch (code) {
case ENGINEER:
return new Engineer();
case SALESMAN:
return new Salesman();
case MANAGER:
return new Manager();
default:
return nullptr;
}
}
////////////////////////main.cpp
EmployeeType* engineer_type = new Engineer();
EmployeeType* manager_type = new Manager();
Employee *engineer = new Employee(engineer_type);
qDebug() << "employee code = " <<engineer->getCode();
qDebug() << "employee payamount = " <<engineer->payAmount();
engineer->setCode(Manager().getTypeCode());
qDebug() << "employee code = " <<engineer->getCode();
qDebug() << "employee payamount = " <<engineer->payAmount();
1)由于需要employee的数据,所以将employee对象作为参数传递给payamount()
2) 修改employee中的payAmount(),令它委托employeeType
3) 将switch直接返回具体对象的payAmount(),并将超类的payAmount()声明为抽象函数
修改之后的代码:
///////////////////////////.h
class EmployeeType;
class Employee
{
public:
Employee(EmployeeType* code);
double payAmount();
int getCode() const;
void setCode(int code);
double getMonthlySalary();
double getCommission();
double getBonus();
private:
//int m_code;
EmployeeType* m_code;
double m_monthlySalary;
double m_commission;
double m_bonus;
};
class EmployeeType
{
public:
static const int ENGINEER;
static const int SALESMAN;
static const int MANAGER;
static EmployeeType* newType(int code);
virtual int getTypeCode() = 0;
virtual int payAmount(Employee* emp) = 0;
};
class Engineer : public EmployeeType
{
public:
int getTypeCode();
int payAmount(Employee* emp);
};
class Salesman : public EmployeeType
{
public:
int getTypeCode();
int payAmount(Employee* emp);
};
class Manager : public EmployeeType
{
public:
int getTypeCode();
int payAmount(Employee* emp);
};
class Error : public EmployeeType
{
public:
int getTypeCode();
int payAmount(Employee* emp);
};
///////////////////////////.cpp
const int EmployeeType::ENGINEER = 0;
const int EmployeeType::SALESMAN = 1;
const int EmployeeType::MANAGER = 2;
Employee::Employee(EmployeeType* code):
m_monthlySalary(100),
m_commission(10),
m_bonus(200)
{
m_code = code;
}
double Employee::payAmount()
{
return m_code->payAmount(this);
}
int Employee::getCode() const
{
return m_code->getTypeCode();
}
void Employee::setCode(int code)
{
m_code = m_code->newType(code);
}
double Employee::getMonthlySalary()
{
return m_monthlySalary;
}
double Employee::getCommission()
{
return m_commission;
}
double Employee::getBonus()
{
return m_bonus;
}
int Engineer::getTypeCode()
{
return ENGINEER;
}
int Engineer::payAmount(Employee *emp)
{
return emp->getMonthlySalary();
}
int Salesman::getTypeCode()
{
return SALESMAN;
}
int Salesman::payAmount(Employee *emp)
{
return emp->getMonthlySalary() + emp->getCommission();
}
int Manager::getTypeCode()
{
return MANAGER;
}
int Manager::payAmount(Employee *emp)
{
return emp->getMonthlySalary() + emp->getBonus();
}
int Error::getTypeCode()
{
return MANAGER;
}
int Error::payAmount(Employee *emp)
{
return 0;
}
EmployeeType* EmployeeType::newType(int code)
{
switch (code) {
case ENGINEER:
return new Engineer();
case SALESMAN:
return new Salesman();
case MANAGER:
return new Manager();
default:
return nullptr;
}
}
//int EmployeeType::payAmount(Employee *emp)
//{
// switch (getTypeCode()) {
// case ENGINEER:
// return emp->getMonthlySalary();
// case SALESMAN:
// return emp->getMonthlySalary() + emp->getCommission();
// case MANAGER:
// return emp->getMonthlySalary() + emp->getBonus();
// }
//}
///////////////////////////////////main.cpp
EmployeeType* engineer_type = new Engineer();
EmployeeType* manager_type = new Manager();
Employee *engineer = new Employee(engineer_type);
qDebug() << "employee code = " <<engineer->getCode();
qDebug() << "employee payamount = " <<engineer->payAmount();
engineer->setCode(Manager().getTypeCode());
qDebug() << "employee code = " <<engineer->getCode();
qDebug() << "employee payamount = " <<engineer->payAmount();