A generic way to handle your problem would be to index the DataFrame and keep only the rows whose index is greater than 2.
As suggested in another answer, you may try adding an index with monotonically_increasing_id.
    df.withColumn("Index", monotonically_increasing_id())
      .filter('Index > 2)
Yet, that only works if the first 3 rows are in the first partition. Moreover, as mentioned in the comments, this happens to be the case today, but this code may break completely in future versions of Spark, and that would be very hard to debug. Indeed, the contract in the API is only "The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive". It is therefore not safe to assume that the IDs will always start from zero. There might even be other cases in the current version in which this does not work (I'm not sure though).
To illustrate my first concern, have a look at this:
    +---+------+
    | id| Index|
    +---+------+
    |  0|     0|
    |  1|     1|
    +---+------+
If the first partition holds only these two rows, every index in the later partitions is far greater than 2, so the filter would remove only these two rows instead of the intended three.
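To see where such a jump comes from, here is a small sketch of how the IDs are composed in the current implementation: the partition index goes into the upper 31 bits and the row number within the partition into the lower 33 bits. Note that `generatedId` is just an illustrative helper, not a Spark API, and this bit layout is an implementation detail that the contract quoted above explicitly does not promise.

```scala
// Illustrative helper mimicking monotonically_increasing_id's current scheme:
// partition index in the upper 31 bits, per-partition row number in the lower 33.
def generatedId(partitionIndex: Long, rowInPartition: Long): Long =
  (partitionIndex << 33) + rowInPartition

// First partition: IDs are 0, 1, 2, ...
val a = generatedId(0, 0) // 0
val b = generatedId(0, 1) // 1
// Second partition: IDs jump straight to 2^33
val c = generatedId(1, 0) // 8589934592
```

So if the first partition contains only two rows, every remaining index is at least 8589934592; `'Index > 2` keeps all of those rows, and only two rows are removed.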