/*
// Definition for a Node.
class Node {
public:
int val = NULL;
Node* prev = NULL;
Node* next = NULL;
Node* child = NULL;
Node() {}
Node(int _val, Node* _prev, Node* _next, Node* _child) {
val = _val;
prev = _prev;
next = _next;
child = _child;
}
};
*/
class Solution {
public:
Node* flatten(Node* head) {
if (head == NULL) return head;
flatten2(head);
return head;
}
Node* flatten2(Node* head) { // flatten and return tail
Node* ret = head;
while (head) {
ret = head;
if (head->child) {
Node* tail = flatten2(head->child);
tail->next = head->next;
if (tail->next)
tail->next->prev = tail;
head->next = head->child;
head->next->prev = head;
head->child = NULL;
head = tail->next;
ret = tail;
}
else {
head = head->next;
}
}
return ret;
}
};